Fix tests for mps support #1
base: master
Changes from 43 commits
```diff
@@ -10,6 +10,7 @@ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary``
 =================== =========== ============ ================= =============== ================
 ARS [#f1]_          ✔️          ✔️           ❌                ❌              ✔️
 A2C                 ✔️          ✔️           ✔️                ✔️              ✔️
+CrossQ [#f1]_       ✔️          ❌           ❌                ❌              ✔️
 DDPG                ✔️          ❌           ❌                ❌              ✔️
 DQN                 ❌          ✔️           ❌                ❌              ✔️
 HER                 ✔️          ✔️           ❌                ❌              ✔️
```

**Review comment — inconsistency in CrossQ implementation location.** There appears to be an inconsistency in the documentation: the table indicates CrossQ is implemented in SB3 Contrib (via footnote `[#f1]_`), but the note at the bottom of the file says it is in the SBX repo. Please clarify which repository actually contains the CrossQ implementation.
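To make the new row concrete, here is a minimal usage sketch, assuming CrossQ is importable from `sb3_contrib` as the `[#f1]_` footnote implies (the comment above questions exactly that, so it may instead live in SBX):

```python
# Hedged sketch: assumes CrossQ ships in SB3 Contrib, which the review
# comment above flags as unconfirmed (it may be in the SBX repo instead).
from sb3_contrib import CrossQ

# Per the table, CrossQ supports Box action spaces, so a continuous-control
# env such as Pendulum is a fitting example.
model = CrossQ("MlpPolicy", "Pendulum-v1", verbose=1)
model.learn(total_timesteps=1_000)
```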
```diff
@@ -135,6 +135,8 @@ def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
         :return:
         """
         if copy:
+            if hasattr(th, "backends") and th.backends.mps.is_built():
+                return th.tensor(array, dtype=th.float32, device=self.device)
             return th.tensor(array, device=self.device)
         return th.as_tensor(array, device=self.device)
```

**Review comment (🛠️ refactor suggestion) on lines +139 to +140 — modify the MPS check in `to_torch`.** Consider changing the conditional to check whether the buffer's device is actually MPS via `self.device.type`, rather than whether the MPS backend is merely built: `th.backends.mps.is_built()` is also true when running on the CPU of an MPS-capable PyTorch build, which would downcast arrays to `float32` unnecessarily. Apply this diff to modify the condition:

```diff
 def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor:
     if copy:
-        if hasattr(th, "backends") and th.backends.mps.is_built():
+        if self.device.type == "mps":
             return th.tensor(array, dtype=th.float32, device=self.device)
         return th.tensor(array, device=self.device)
     return th.as_tensor(array, device=self.device)
```
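To illustrate why the suggested check is preferable, here is a standalone sketch of the proposed logic as a hypothetical free function (the real code is a buffer method); the key fact is that MPS tensors do not support `float64`, while `th.backends.mps.is_built()` is true even when the chosen device is the CPU:

```python
import numpy as np
import torch as th

# Hypothetical standalone version of the suggested conversion logic:
# downcast to float32 only when the target device really is MPS.
def to_torch(array: np.ndarray, device: th.device, copy: bool = True) -> th.Tensor:
    if copy:
        if device.type == "mps":
            # MPS has no float64 kernels, so NumPy's default float64 must become float32.
            return th.tensor(array, dtype=th.float32, device=device)
        return th.tensor(array, device=device)
    return th.as_tensor(array, device=device)

# A float64 array becomes float32 on MPS but keeps float64 on CPU.
device = th.device("mps") if th.backends.mps.is_available() else th.device("cpu")
print(to_torch(np.zeros(4), device).dtype)
```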
```diff
@@ -78,11 +78,11 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
         if self.discrete_obs_space:
             # The internal state is the binary representation of the
             # observed one
-            return int(sum(state[i] * 2**i for i in range(len(state))))
+            return int(sum(int(state[i]) * 2**i for i in range(len(state))))
```

**Review comment — simplify state conversion using NumPy vectorization.** Currently, the state conversion uses a generator expression with an explicit loop and per-element integer casting:

```python
return int(sum(int(state[i]) * 2**i for i in range(len(state))))
```

This can be simplified and optimized by leveraging NumPy's vectorized operations, which are more efficient and concise. Apply this diff:

```diff
-return int(sum(int(state[i]) * 2**i for i in range(len(state))))
+return int(state.dot(2 ** np.arange(len(state))))
```

This approach eliminates the explicit loop and casting, improving performance and readability.
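A quick equivalence check of the suggested vectorization, using an illustrative bit vector (the values are hypothetical, not from the PR):

```python
import numpy as np

state = np.array([1.0, 0.0, 1.0, 1.0], dtype=np.float32)  # hypothetical bit vector

loop_result = int(sum(int(state[i]) * 2**i for i in range(len(state))))
vectorized = int(state.dot(2 ** np.arange(len(state))))
# Both read the bits in little-endian order: 1*1 + 0*2 + 1*4 + 1*8 = 13.
assert loop_result == vectorized == 13
```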
```diff
         if self.image_obs_space:
             size = np.prod(self.image_shape)
-            image = np.concatenate((state * 255, np.zeros(size - len(state), dtype=np.uint8)))
+            image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8)))
             return image.reshape(self.image_shape).astype(np.uint8)
         return state
```

**Review comment — optimize image creation by preallocating the array.** Instead of concatenating arrays to create the image, you can preallocate the array and assign values directly. This avoids unnecessary intermediate memory allocation and improves performance. Apply this diff:

```diff
-image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8)))
+image = np.zeros(size, dtype=np.uint8)
+image[:len(state)] = state.astype(np.uint8) * 255
```

The refactored code preallocates a zero-filled array of the required size and writes the scaled state values directly into its beginning.
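Similarly, a small sanity check that the preallocation variant produces the same buffer as the concatenation it replaces (the `(6, 6)` image shape is hypothetical):

```python
import numpy as np

state = np.array([1.0, 0.0, 1.0], dtype=np.float32)  # hypothetical bit vector
size = int(np.prod((6, 6)))                          # hypothetical image_shape

via_concat = np.concatenate(
    (state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8))
)
via_prealloc = np.zeros(size, dtype=np.uint8)
via_prealloc[: len(state)] = state.astype(np.uint8) * 255
assert np.array_equal(via_concat, via_prealloc)
```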
```diff
@@ -1,5 +1,6 @@
 import sys
 import time
+import warnings
 from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

 import numpy as np
```

```diff
@@ -135,6 +136,28 @@ def _setup_model(self) -> None:
             self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
         )
         self.policy = self.policy.to(self.device)
+        # Warn when not using CPU with MlpPolicy
+        self._maybe_recommend_cpu()
+
+    def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None:
+        """
+        Recommend to use CPU only when using A2C/PPO with MlpPolicy.
+
+        :param mlp_class_name: The name of the class for the default MlpPolicy.
+        """
+        policy_class_name = self.policy_class.__name__
+        if self.device != th.device("cpu") and policy_class_name == mlp_class_name:
+            warnings.warn(
+                f"You are trying to run {self.__class__.__name__} on the GPU, "
+                "but it is primarily intended to run on the CPU when not using a CNN policy "
+                f"(you are using {policy_class_name} which should be a MlpPolicy). "
+                "See https://github.com/DLR-RM/stable-baselines3/issues/1245 "
+                "for more info. "
+                "You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU. "
+                "Note: The model will train, but the GPU utilization will be poor and "
+                "the training might take longer than on CPU.",
+                UserWarning,
+            )

     def collect_rollouts(
         self,
```

**Review comment (🛠️ refactor suggestion) on lines +142 to +160 — adjust the default policy class name and enhance MPS-specific guidance.** The implementation is good but needs two adjustments: the default `mlp_class_name` should be `"MlpPolicy"` to match the alias users actually pass, and the warning should call out MPS (Apple Silicon GPU) users explicitly, since that is the focus of this PR. Suggested diffs:

```diff
-    def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None:
+    def _maybe_recommend_cpu(self, mlp_class_name: str = "MlpPolicy") -> None:
```

```diff
                 f"(you are using {policy_class_name} which should be a MlpPolicy). "
+                "This is especially important for MPS (Apple Silicon GPU) users. "
                 "See https://github.com/DLR-RM/stable-baselines3/issues/1245 "
```
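From the user's side, acting on the warning is a one-argument change; a minimal sketch, assuming stable-baselines3 with a Gymnasium `CartPole-v1` env:

```python
from stable_baselines3 import PPO

# Forcing CPU avoids the poor GPU/MPS utilization the warning describes
# for MlpPolicy-based on-policy algorithms such as A2C/PPO.
model = PPO("MlpPolicy", "CartPole-v1", device="cpu", verbose=1)
model.learn(total_timesteps=1_000)
```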
**Review comment (💡 codebase verification) — Ubuntu 24.04 is not yet officially supported by Read the Docs.** Ubuntu 24.04 (Noble) was only released in April 2024, and Read the Docs typically takes some time to officially support new Ubuntu LTS versions. I recommend keeping `ubuntu-22.04` for now to ensure stable documentation builds, and upgrading to 24.04 once it is officially supported by Read the Docs:

- Set the build OS in `.readthedocs.yml` back to `ubuntu-22.04` (see the sketch after this comment).
- `mambaforge-23.11` can stay, as it's independent of the OS version.

🔗 **Analysis chain — verify Ubuntu 24.04 support on Read the Docs.** While upgrading to Ubuntu 24.04 LTS is forward-thinking, we should verify that Read the Docs officially supports this version to ensure stable documentation builds.
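A sketch of the suggested revert in `.readthedocs.yml`, assuming the Read the Docs v2 config schema; the surrounding keys are illustrative and only `build.os` changes:

```yaml
version: 2
build:
  os: ubuntu-22.04              # reverted from ubuntu-24.04 until RTD supports it
  tools:
    python: "mambaforge-23.11"  # independent of the OS version, can stay
```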