diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1f0b2254..8e356b82 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,26 +1,22 @@ -# See https://pre-commit.com for more information -# See https://pre-commit.com/hooks.html for more hooks -fail_fast: true +# Usage +# uv run pre-commit install +# uv run pre-commit run --all-files repos: -- repo: https://github.com/psf/black - rev: 25.1.0 - hooks: - - id: black - args: [--config, pyproject.toml] - types: [python] + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v6.0.0 + hooks: + - id: check-added-large-files + - id: check-case-conflict + - id: check-merge-conflict + - id: check-symlinks + - id: mixed-line-ending + - id: trailing-whitespace -- repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.0 - hooks: - - id: ruff - args: [ --fix ] - -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 - hooks: - - id: check-toml - - id: check-yaml - - id: detect-private-key - - id: end-of-file-fixer - - id: trailing-whitespace + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.13.2 + hooks: + - id: ruff-check + args: [ --fix ] + - id: ruff-format + types_or: [ python, pyi ] diff --git a/configs/experiment/graph/am.yaml b/configs/experiment/graph/am.yaml index 302a6d51..424ecaae 100644 --- a/configs/experiment/graph/am.yaml +++ b/configs/experiment/graph/am.yaml @@ -21,7 +21,7 @@ model: test_data_size: 1000 optimizer_kwargs: lr: 1e-4 - + trainer: max_epochs: 100 diff --git a/configs/experiment/routing/mdpomo.yaml b/configs/experiment/routing/mdpomo.yaml index 2de2a532..5d5b4e30 100644 --- a/configs/experiment/routing/mdpomo.yaml +++ b/configs/experiment/routing/mdpomo.yaml @@ -11,7 +11,7 @@ env: generator_params: num_loc: 50 loc_distribution: "mix_distribution" - + logger: wandb: diff --git a/docs/content/api/zoo/improvement.md b/docs/content/api/zoo/improvement.md index 8ca68c75..90e763bc 100644 --- a/docs/content/api/zoo/improvement.md +++ b/docs/content/api/zoo/improvement.md @@ -2,7 +2,7 @@ These methods are trained to improve existing solutions iteratively, akin to local search algorithms. They focus on refining existing solutions rather than generating them from scratch. -### DACT +### DACT :::models.zoo.dact.encoder options: @@ -19,7 +19,7 @@ These methods are trained to improve existing solutions iteratively, akin to loc :::models.zoo.dact.model options: show_root_heading: false - + ### N2S diff --git a/docs/content/general/faq.md b/docs/content/general/faq.md index c708171e..2f67fbff 100644 --- a/docs/content/general/faq.md +++ b/docs/content/general/faq.md @@ -2,7 +2,7 @@ -You can submit your questions via [GitHub Issues](https://github.com/ai4co/rl4co/issues) or [Discussions](https://github.com/ai4co/rl4co/discussions). +You can submit your questions via [GitHub Issues](https://github.com/ai4co/rl4co/issues) or [Discussions](https://github.com/ai4co/rl4co/discussions). You may search for your question in the existing issues or discussions before submitting a new one. If asked more than a few times, we will add it here! diff --git a/docs/content/intro/environments.md b/docs/content/intro/environments.md index 1c6872c9..8a28e0a4 100644 --- a/docs/content/intro/environments.md +++ b/docs/content/intro/environments.md @@ -60,7 +60,7 @@ Click [here](../api/envs/routing.md) for API documentation on routing problems. ## Scheduling Problems -Scheduling problems are a fundamental class of problems in operations research and industrial engineering, where the objective is to optimize the allocation of resources over time. These problems are critical in various industries, such as manufacturing, computer science, and project management. +Scheduling problems are a fundamental class of problems in operations research and industrial engineering, where the objective is to optimize the allocation of resources over time. These problems are critical in various industries, such as manufacturing, computer science, and project management. @@ -68,24 +68,24 @@ Scheduling problems are a fundamental class of problems in operations research a Here we show a general constructive MDP formulation based on the Job Shop Scheduling Problem (JSSP), a well-known scheduling problem, which can be adapted to other scheduling problems. -- **State** $s_t \in \mathcal{S}$: +- **State** $s_t \in \mathcal{S}$: The state is represented by a disjunctive graph, where: - Operations are nodes - Processing orders between operations are shown by directed arcs - This graph encapsulates both the problem instance and the current partial schedule -- **Action** $a_t \in \mathcal{A}$: +- **Action** $a_t \in \mathcal{A}$: An action involves selecting a feasible operation to assign to its designated machine, a process often referred to as dispatching. The action space consists of all operations that can be feasibly scheduled at the current state. -- **Transition** $\mathcal{T}$: +- **Transition** $\mathcal{T}$: The transition function deterministically updates the disjunctive graph based on the dispatched operation. This includes: - Modifying the graph's topology (e.g., adding new connections between operations) - Updating operation attributes (e.g., start times) -- **Reward** $\mathcal{R}$: +- **Reward** $\mathcal{R}$: The reward function is designed to align with the optimization objective. For instance, if minimizing makespan is the goal, the reward could be the negative change in makespan resulting from the latest action. -- **Policy** $\pi$: +- **Policy** $\pi$: The policy, typically stochastic, takes the current disjunctive graph as input and outputs a probability distribution over feasible dispatching actions. This process continues until a complete schedule is constructed. @@ -103,25 +103,25 @@ Electronic Design Automation (EDA) is a sophisticated process that involves the EDA encompasses many problem types; here we'll focus on placement problems, which are fundamental in the physical design of integrated circuits and printed circuit boards. We'll use the Decap Placement Problem (DPP) as an example to illustrate a typical MDP formulation for EDA placement problems. -- **State** $s_t \in \mathcal{S}$: +- **State** $s_t \in \mathcal{S}$: The state typically represents the current configuration of the design space, which may include: - Locations of fixed elements (e.g., ports, keepout regions) - Current placements of movable elements - Remaining resources or components to be placed -- **Action** $a_t \in \mathcal{A}$: +- **Action** $a_t \in \mathcal{A}$: An action usually involves placing a component at a valid location within the design space. The action space consists of all feasible placement locations, considering design rules and constraints. -- **Transition** $\mathcal{T}$: +- **Transition** $\mathcal{T}$: The transition function updates the design state based on the placement action, which may include: - Updating the placement map - Adjusting available resources or remaining components - Recalculating relevant metrics (e.g., wire length, power distribution) -- **Reward** $\mathcal{R}$: +- **Reward** $\mathcal{R}$: The reward is typically based on the improvement in the design objective resulting from the latest placement action. This could involve metrics such as area efficiency, signal integrity, or power consumption. -- **Policy** $\pi$: +- **Policy** $\pi$: The policy takes the current design state as input and outputs a probability distribution over possible placement actions. Note that specific problems may introduce additional complexities or constraints. @@ -142,26 +142,26 @@ In graph problems, we typically work with a graph $G = (V, E)$, where $V$ is a s Graph problems can be effectively modeled using a Markov Decision Process (MDP) framework in a constructive fashion. Here, we outline the key components of the MDP formulation for graph problems: -- **State** $s_t \in \mathcal{S}$: +- **State** $s_t \in \mathcal{S}$: The state encapsulates the current configuration of the graph and the optimization progress. It typically includes: - The graph structure (vertices and edges) - Attributes associated with vertices or edges - The set of elements (vertices, edges, or subgraphs) selected so far - Problem-specific information, such as remaining selections or resources -- **Action** $a_t \in \mathcal{A}$: +- **Action** $a_t \in \mathcal{A}$: An action usually involves selecting a graph element (e.g., a vertex, edge, or subgraph). The action space comprises all valid selections based on the problem constraints and the current state. -- **Transition** $\mathcal{T}$: +- **Transition** $\mathcal{T}$: The transition function $\mathcal{T}(s_t, a_t) \rightarrow s_{t+1}$ updates the graph state based on the selected action. This typically involves: - Updating the set of selected elements - Modifying graph attributes affected by the selection - Updating problem-specific information (e.g., remaining selections or resources) -- **Reward** $\mathcal{R}$: +- **Reward** $\mathcal{R}$: The reward function $\mathcal{R}(s_t, a_t)$ quantifies the quality of the action taken. It is typically based on the improvement in the optimization objective resulting from the latest selection. This could involve metrics such as coverage, distance, connectivity, or any other problem-specific criteria. -- **Policy** $\pi$: +- **Policy** $\pi$: The policy $\pi(a_t|s_t)$ is a probability distribution over possible actions given the current state. It guides the decision-making process, determining which graph elements to select at each step to optimize the objective. Specific problems may introduce additional complexities or constraints, which can often be incorporated through careful design of the state space, action space, and reward function. diff --git a/docs/content/intro/policies.md b/docs/content/intro/policies.md index 9b5ea4ec..0483a868 100644 --- a/docs/content/intro/policies.md +++ b/docs/content/intro/policies.md @@ -11,7 +11,7 @@ A policy $\pi$ is used to construct a solution from scratch for a given problem An AR policy is composed of an encoder $f$ that maps the instance $\mathbf{x}$ into an embedding space $\mathbf{h}=f(\mathbf{x})$ and by a decoder $g$ that iteratively determines a sequence of actions $\mathbf{a}$ as follows: $$ -a_t \sim g(a_t | a_{t-1}, ... ,a_0, s_t, \mathbf{h}), \quad +a_t \sim g(a_t | a_{t-1}, ... ,a_0, s_t, \mathbf{h}), \quad \pi(\mathbf{a}|\mathbf{x}) \triangleq \prod_{t=1}^{T-1} g(a_{t} | a_{t-1}, \ldots ,a_0, s_t, \mathbf{h}). $$ diff --git a/docs/content/intro/rl.md b/docs/content/intro/rl.md index 1e4510cf..3b7a436d 100644 --- a/docs/content/intro/rl.md +++ b/docs/content/intro/rl.md @@ -19,7 +19,7 @@ $$ \nabla_{\theta} \mathcal{L}_a(\theta|\mathbf{x}) = \mathbb{E}_{\pi(\mathbf{a}|\mathbf{x})} \left[(R(\mathbf{a}, \mathbf{x}) - b(\mathbf{x})) \nabla_{\theta}\log \pi(\mathbf{a}|\mathbf{x})\right], $$ -where $b(\cdot)$ is a baseline function used to stabilize training and reduce gradient variance. +where $b(\cdot)$ is a baseline function used to stabilize training and reduce gradient variance. We also distinguish between two types of RL (pre)training: diff --git a/docs/content/start/hydra.md b/docs/content/start/hydra.md index b08b7e52..626ba5a9 100644 --- a/docs/content/start/hydra.md +++ b/docs/content/start/hydra.md @@ -102,7 +102,7 @@ defaults: This section sets the default configuration for the model, environment, callbacks, trainer, and logger. This means that if a key is not specified in the experiment configuration, the default value will be used. Note that these are set in the root [configs/](https://github.com/ai4co/rl4co/tree/main/configs) folder, and are useful for better organization and reusability. ```yaml linenums="11" -env: +env: generator_params: loc_distribution: "uniform" num_loc: 50 @@ -153,7 +153,7 @@ logger: Finally, this section specifies the logger configuration. In this case, we are using Weights & Biases (WandB) to log the results of the experiment. We specify the project name, tags, group, and name of the experiment. -That's it! 🎉 +That's it! 🎉 !!! tip diff --git a/docs/hooks.py b/docs/hooks.py index e4e6950c..7fe10eef 100644 --- a/docs/hooks.py +++ b/docs/hooks.py @@ -26,12 +26,12 @@ def on_startup(*args, **kwargs): def append_tricks_to_readme(file_path): # read the tricks from docs/overrides/fancylogo.txt # and put them at the beginning of the file - with open("docs/overrides/fancylogo.txt", "r") as fancylogo: + with open("docs/overrides/fancylogo.txt") as fancylogo: tricks = fancylogo.read() if not os.path.exists(file_path): print(f"Error: The file {file_path} does not exist.") return - with open(file_path, "r") as original: + with open(file_path) as original: data = original.read() # remove first 33 lines. yeah, it's a hack to remove unneded stuff lol data = "\n".join(data.split("\n")[33:]) diff --git a/docs/js/autolink.js b/docs/js/autolink.js index 30a295a5..3ac4d515 100644 --- a/docs/js/autolink.js +++ b/docs/js/autolink.js @@ -3,16 +3,16 @@ const convertLinks = ( input ) => { let text = input; const linksFound = text.match( /(?:www|https?)[^\s]+/g ); const aLink = []; - + if ( linksFound != null ) { - + for ( let i=0; i' ) } @@ -26,7 +26,7 @@ const convertLinks = ( input ) => { text = text.split( linksFound[i] ).map(item => { return aLink[i].includes('iframe') ? item.trim() : item } ).join( aLink[i] ); } return text; - + } else { return input; diff --git a/docs/js/katex.js b/docs/js/katex.js index 841e35ad..2ab434fc 100644 --- a/docs/js/katex.js +++ b/docs/js/katex.js @@ -1,4 +1,4 @@ -document$.subscribe(({ body }) => { +document$.subscribe(({ body }) => { renderMathInElement(body, { delimiters: [ { left: "$$", right: "$$", display: true }, diff --git a/docs/overrides/fancylogo.txt b/docs/overrides/fancylogo.txt index 67ee9347..fae2b082 100644 --- a/docs/overrides/fancylogo.txt +++ b/docs/overrides/fancylogo.txt @@ -2,16 +2,16 @@ hide: - navigation - toc ---- +--- -
+
-
+ +
@@ -83,11 +83,11 @@ hide: const setContainerDimensions = () => { const container = document.querySelector('.md-main__inner #particles-container'); const mainContent = document.querySelector('.md-main__inner'); - + if (mainContent && container) { const containerWidth = mainContent.offsetWidth; container.style.width = `${containerWidth}px`; - + // Calculate height based on the aspect ratio and 60% width const imageWidth = containerWidth * 0.6; const imageHeight = imageWidth * ASPECT_RATIO; @@ -103,16 +103,16 @@ hide: const backgroundColor = computedStyle.backgroundColor; mask.style.background = ` - linear-gradient(to right, - ${backgroundColor} 0%, - rgba(0,0,0,0) 10%, - rgba(0,0,0,0) 90%, + linear-gradient(to right, + ${backgroundColor} 0%, + rgba(0,0,0,0) 10%, + rgba(0,0,0,0) 90%, ${backgroundColor} 100% ), - linear-gradient(to bottom, - ${backgroundColor} 0%, - rgba(0,0,0,0) 10%, - rgba(0,0,0,0) 90%, + linear-gradient(to bottom, + ${backgroundColor} 0%, + rgba(0,0,0,0) 10%, + rgba(0,0,0,0) 90%, ${backgroundColor} 100% ) `; diff --git a/examples/2d-meta_train.py b/examples/2d-meta_train.py index 1f3fb8d4..cd8438c7 100644 --- a/examples/2d-meta_train.py +++ b/examples/2d-meta_train.py @@ -4,34 +4,38 @@ from rl4co.envs import CVRPEnv from rl4co.models.zoo.am import AttentionModelPolicy from rl4co.models.zoo.pomo import POMO -from rl4co.utils.trainer import RL4COTrainer from rl4co.utils.meta_trainer import ReptileCallback +from rl4co.utils.trainer import RL4COTrainer + def main(): # Set device device_id = 0 # RL4CO env based on TorchRL - env = CVRPEnv(generator_params={'num_loc': 50}) + env = CVRPEnv(generator_params={"num_loc": 50}) # Policy: neural network, in this case with encoder-decoder architecture # Note that this is adapted the same as POMO did in the original paper - policy = AttentionModelPolicy(env_name=env.name, - embed_dim=128, - num_encoder_layers=6, - num_heads=8, - normalization="instance", - use_graph_context=False - ) + policy = AttentionModelPolicy( + env_name=env.name, + embed_dim=128, + num_encoder_layers=6, + num_heads=8, + normalization="instance", + use_graph_context=False, + ) # RL Model (POMO) - model = POMO(env, - policy, - batch_size=64, # meta_batch_size - train_data_size=64 * 50, # equals to (meta_batch_size) * (gradient decent steps in the inner-loop optimization of meta-learning method) - val_data_size=0, - optimizer_kwargs={"lr": 1e-4, "weight_decay": 1e-6}, - ) + model = POMO( + env, + policy, + batch_size=64, # meta_batch_size + train_data_size=64 + * 50, # equals to (meta_batch_size) * (gradient decent steps in the inner-loop optimization of meta-learning method) + val_data_size=0, + optimizer_kwargs={"lr": 1e-4, "weight_decay": 1e-6}, + ) # Example callbacks checkpoint_callback = ModelCheckpoint( @@ -46,14 +50,14 @@ def main(): # Meta callbacks meta_callback = ReptileCallback( - num_tasks = 1, # the number of tasks in a mini-batch, i.e. `B` in the original paper - alpha = 0.9, # initial weight of the task model for the outer-loop optimization of reptile - alpha_decay = 1, # weight decay of the task model for the outer-loop optimization of reptile. No decay performs better. - min_size = 20, # minimum of sampled size in meta tasks (only supported in cross-size generalization) - max_size= 150, # maximum of sampled size in meta tasks (only supported in cross-size generalization) + num_tasks=1, # the number of tasks in a mini-batch, i.e. `B` in the original paper + alpha=0.9, # initial weight of the task model for the outer-loop optimization of reptile + alpha_decay=1, # weight decay of the task model for the outer-loop optimization of reptile. No decay performs better. + min_size=20, # minimum of sampled size in meta tasks (only supported in cross-size generalization) + max_size=150, # maximum of sampled size in meta tasks (only supported in cross-size generalization) data_type="size_distribution", # choose from ["size", "distribution", "size_distribution"] sch_bar=0.9, # for the task scheduler of size setting, where lr_decay_epoch = sch_bar * epochs, i.e. after this epoch, learning rate will decay with a weight 0.1 - print_log=True # whether to print the sampled tasks in each meta iteration + print_log=True, # whether to print the sampled tasks in each meta iteration ) callbacks = [meta_callback, checkpoint_callback, rich_model_summary] @@ -68,7 +72,7 @@ def main(): accelerator="gpu", devices=[device_id], logger=logger, - limit_train_batches=50 # gradient decent steps in the inner-loop optimization of meta-learning method + limit_train_batches=50, # gradient decent steps in the inner-loop optimization of meta-learning method ) # Fit @@ -77,4 +81,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/pyproject.toml b/pyproject.toml index d22a3f2e..69be938b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,6 @@ dependencies = [ [project.optional-dependencies] dev = [ - "black", "pre-commit>=3.3.3", "ruff", "pytest", @@ -120,34 +119,15 @@ include = ["rl4co"] requires = ["hatchling"] build-backend = "hatchling.build" -[tool.black] -line-length = 90 -target-version = ["py311"] -include = '\.pyi?$' -exclude = ''' -( - /( - \.direnv - | \.eggs - | \.git - | \.tox - | \.venv - | _build - | build - | dist - | venv - )/ -) -''' - [tool.ruff] -line-length = 90 -target-version = "py311" +line-length = 100 +target-version = "py310" show-fixes = false +extend-exclude = ["*.ipynb"] [tool.ruff.lint] -select = ["F", "E", "W", "I001"] -ignore = ["E501"] # never enforce `E501` (line length violations), handled by Black +select = ["F", "E", "W", "I001", "UP"] +ignore = ["E501"] # Ignore line too long errors task-tags = ["TODO", "FIXME"] [tool.ruff.lint.per-file-ignores] @@ -167,6 +147,7 @@ combine-as-imports = true split-on-trailing-comma = false lines-between-types = 1 + [tool.coverage] include = ["rl4co.*"] diff --git a/rl4co/data/dataset.py b/rl4co/data/dataset.py index 60e0fbe5..c29c3eca 100644 --- a/rl4co/data/dataset.py +++ b/rl4co/data/dataset.py @@ -50,9 +50,7 @@ class TensorDictDataset(Dataset): def __init__(self, td: TensorDict): self.data_len = td.batch_size[0] - self.data = [ - {key: value[i] for key, value in td.items()} for i in range(self.data_len) - ] + self.data = [{key: value[i] for key, value in td.items()} for i in range(self.data_len)] def __len__(self): return self.data_len diff --git a/rl4co/data/generate_data.py b/rl4co/data/generate_data.py index 12313515..58126ffb 100644 --- a/rl4co/data/generate_data.py +++ b/rl4co/data/generate_data.py @@ -29,17 +29,13 @@ def generate_env_data(env_type, *args, **kwargs): # remove all None values from args args = [arg for arg in args if arg is not None] - return getattr(sys.modules[__name__], f"generate_{env_type}_data")( - *args, **kwargs - ) + return getattr(sys.modules[__name__], f"generate_{env_type}_data")(*args, **kwargs) except AttributeError: raise NotImplementedError(f"Environment type {env_type} not implemented") def generate_tsp_data(dataset_size, tsp_size): - return { - "locs": np.random.uniform(size=(dataset_size, tsp_size, 2)).astype(np.float32) - } + return {"locs": np.random.uniform(size=(dataset_size, tsp_size, 2)).astype(np.float32)} def generate_vrp_data(dataset_size, vrp_size, capacities=None): @@ -69,9 +65,7 @@ def generate_vrp_data(dataset_size, vrp_size, capacities=None): CAPACITIES[k] = v return { - "depot": np.random.uniform(size=(dataset_size, 2)).astype( - np.float32 - ), # Depot location + "depot": np.random.uniform(size=(dataset_size, 2)).astype(np.float32), # Depot location "locs": np.random.uniform(size=(dataset_size, vrp_size, 2)).astype( np.float32 ), # Node locations @@ -103,9 +97,7 @@ def generate_op_data(dataset_size, op_size, prize_type="const", max_lengths=None else: # Based on distance to depot assert prize_type == "dist" prize_ = np.linalg.norm(depot[:, None, :] - loc, axis=-1) - prize = ( - 1 + (prize_ / prize_.max(axis=-1, keepdims=True) * 99).astype(int) - ) / 100.0 + prize = (1 + (prize_ / prize_.max(axis=-1, keepdims=True) * 99).astype(int)) / 100.0 # Max length is approximately half of optimal TSP tour, such that half (a bit more) of the nodes can be visited # which is maximally difficult as this has the largest number of possibilities @@ -140,17 +132,13 @@ def generate_pctsp_data(dataset_size, pctsp_size, penalty_factor=3, max_lengths= # Now expectation is 0.5 so expected total prize is n / 2, we want to force to visit approximately half of the nodes # so the constraint will be that total prize >= (n / 2) / 2 = n / 4 # equivalently, we divide all prizes by n / 4 and the total prize should be >= 1 - deterministic_prize = ( - np.random.uniform(size=(dataset_size, pctsp_size)) * 4 / float(pctsp_size) - ) + deterministic_prize = np.random.uniform(size=(dataset_size, pctsp_size)) * 4 / float(pctsp_size) # In the deterministic setting, the stochastic_prize is not used and the deterministic prize is known # In the stochastic setting, the deterministic prize is the expected prize and is known up front but the # stochastic prize is only revealed once the node is visited # Stochastic prize is between (0, 2 * expected_prize) such that E(stochastic prize) = E(deterministic_prize) - stochastic_prize = ( - np.random.uniform(size=(dataset_size, pctsp_size)) * deterministic_prize * 2 - ) + stochastic_prize = np.random.uniform(size=(dataset_size, pctsp_size)) * deterministic_prize * 2 return { "locs": loc.astype(np.float32), @@ -287,11 +275,7 @@ def generate_dataset( datadir, "{}{}{}_{}_seed{}.npz".format( problem, - ( - "_{}".format(distribution) - if distribution is not None - else "" - ), + (f"_{distribution}" if distribution is not None else ""), graph_size, name, seed, @@ -304,19 +288,13 @@ def generate_dataset( os.makedirs(os.path.dirname(fname), exist_ok=True) iter += 1 except Exception: - raise ValueError( - "Number of filenames does not match number of problems" - ) + raise ValueError("Number of filenames does not match number of problems") fname = check_extension(filename, extension=".npz") - if not overwrite and os.path.isfile( - check_extension(fname, extension=".npz") - ): + if not overwrite and os.path.isfile(check_extension(fname, extension=".npz")): if not disable_warning: log.info( - "File {} already exists! Run with -f option to overwrite. Skipping...".format( - fname - ) + f"File {fname} already exists! Run with -f option to overwrite. Skipping..." ) continue @@ -324,14 +302,12 @@ def generate_dataset( np.random.seed(seed) # Automatically generate dataset - dataset = generate_env_data( - problem, dataset_size, graph_size, distribution - ) + dataset = generate_env_data(problem, dataset_size, graph_size, distribution) # A function can return None in case of an error or a skip if dataset is not None: # Save to disk as dict - log.info("Saving {} dataset to {}".format(problem, fname)) + log.info(f"Saving {problem} dataset to {fname}") np.savez(fname, **dataset) @@ -354,23 +330,18 @@ def generate_default_datasets(data_dir, generate_eda=False): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "--filename", help="Filename of the dataset to create (ignores datadir)" - ) + parser.add_argument("--filename", help="Filename of the dataset to create (ignores datadir)") parser.add_argument( "--data_dir", default="data", help="Create datasets in data_dir/problem (default 'data')", ) - parser.add_argument( - "--name", type=str, required=True, help="Name to identify dataset" - ) + parser.add_argument("--name", type=str, required=True, help="Name to identify dataset") parser.add_argument( "--problem", type=str, default="all", - help="Problem, 'tsp', 'vrp', 'pctsp' or 'op_const', 'op_unif' or 'op_dist'" - " or 'all' to generate all", + help="Problem, 'tsp', 'vrp', 'pctsp' or 'op_const', 'op_unif' or 'op_dist' or 'all' to generate all", ) parser.add_argument( "--data_distribution", @@ -378,9 +349,7 @@ def generate_default_datasets(data_dir, generate_eda=False): default="all", help="Distributions to generate for problem, default 'all'.", ) - parser.add_argument( - "--dataset_size", type=int, default=10000, help="Size of the dataset" - ) + parser.add_argument("--dataset_size", type=int, default=10000, help="Size of the dataset") parser.add_argument( "--graph_sizes", type=int, diff --git a/rl4co/data/transforms.py b/rl4co/data/transforms.py index bafa6b2f..e29d3925 100644 --- a/rl4co/data/transforms.py +++ b/rl4co/data/transforms.py @@ -1,6 +1,6 @@ import math -from typing import Callable +from collections.abc import Callable import torch @@ -38,9 +38,7 @@ def dihedral_8_augmentation(xy: Tensor) -> Tensor: return aug_xy -def dihedral_8_augmentation_wrapper( - xy: Tensor, reduce: bool = True, *args, **kw -) -> Tensor: +def dihedral_8_augmentation_wrapper(xy: Tensor, reduce: bool = True, *args, **kw) -> Tensor: """Wrapper for dihedral_8_augmentation. If reduce, only return the first 1/8 of the augmented data since the augmentation augments the data 8 times. """ @@ -104,7 +102,7 @@ def get_augment_function(augment_fn: str | Callable): ) -class StateAugmentation(object): +class StateAugmentation: """Augment state by N times via symmetric rotation/reflection transform Args: @@ -125,9 +123,9 @@ def __init__( feats: list = None, ): self.augmentation = get_augment_function(augment_fn) - assert not ( - self.augmentation == dihedral_8_augmentation_wrapper and num_augment != 8 - ), "When using the `dihedral8` augmentation function, then num_augment must be 8" + assert not (self.augmentation == dihedral_8_augmentation_wrapper and num_augment != 8), ( + "When using the `dihedral8` augmentation function, then num_augment must be 8" + ) if feats is None: log.info("Features not passed, defaulting to 'locs'") diff --git a/rl4co/envs/common/base.py b/rl4co/envs/common/base.py index 7b04867b..cc6fbea9 100644 --- a/rl4co/envs/common/base.py +++ b/rl4co/envs/common/base.py @@ -1,7 +1,7 @@ import abc +from collections.abc import Iterable from os.path import join as pjoin -from typing import Iterable, Optional import torch @@ -95,9 +95,9 @@ def get_multiple_dataloader_names(f, names): if names is None: names = [f"{i}" for i in range(len(f))] else: - assert len(names) == len( - f - ), "Number of dataloader names must match number of files" + assert len(names) == len(f), ( + "Number of dataloader names must match number of files" + ) else: if names is not None: log.warning( @@ -132,7 +132,7 @@ def step(self, td: TensorDict) -> TensorDict: # Since we simplify the syntax return self._torchrl_step(td) - def reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + def reset(self, td: TensorDict | None = None, batch_size=None) -> TensorDict: """Reset function to call at the beginning of each episode""" if batch_size is None: batch_size = self.batch_size if td is None else td.batch_size @@ -154,9 +154,7 @@ def _torchrl_step(self, td: TensorDict) -> TensorDict: self._assert_tensordict_shape(td) next_preset = td.get("next", None) - next_tensordict = self._step( - td.clone() - ) # NOTE: we clone to avoid recursion error + next_tensordict = self._step(td.clone()) # NOTE: we clone to avoid recursion error next_tensordict = self._step_proc_data(next_tensordict) if next_preset is not None: next_tensordict.update(next_preset.exclude(*next_tensordict.keys(True, True))) @@ -171,7 +169,7 @@ def _step(self, td: TensorDict) -> TensorDict: raise NotImplementedError @abc.abstractmethod - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size=None) -> TensorDict: """Reset function to call at the beginning of each episode""" raise NotImplementedError @@ -180,7 +178,7 @@ def _make_spec(self, td_params: TensorDict = None): raise NotImplementedError def get_reward( - self, td: TensorDict, actions: torch.Tensor, check_solution: Optional[bool] = None + self, td: TensorDict, actions: torch.Tensor, check_solution: bool | None = None ) -> torch.Tensor: """Function to compute the reward. Can be called by the agent to compute the reward of the current state This is faster than calling step() and getting the reward from the returned TensorDict at each time for CO tasks @@ -227,15 +225,11 @@ def replace_selected_actions( """ raise NotImplementedError - def local_search( - self, td: TensorDict, actions: torch.Tensor, **kwargs - ) -> torch.Tensor: + def local_search(self, td: TensorDict, actions: torch.Tensor, **kwargs) -> torch.Tensor: """Function to improve the solution. Can be called by the agent to improve the current state This is called with the full solution (i.e. all actions) at the end of the episode """ - raise NotImplementedError( - f"Local is not implemented yet for {self.name} environment" - ) + raise NotImplementedError(f"Local is not implemented yet for {self.name} environment") def dataset(self, batch_size=[], phase="train", filename=None): """Return a dataset of observations @@ -291,7 +285,7 @@ def load_data(fpath, batch_size=[]): """Dataset loading from file""" return load_npz_to_tensordict(fpath) - def _set_seed(self, seed: Optional[int]): + def _set_seed(self, seed: int | None): """Set the seed for the environment""" rng = torch.manual_seed(seed) self.rng = rng @@ -383,9 +377,7 @@ def _get_real_solution(rec): visited_time = torch.zeros((batch_size, seq_length)).to(rec.device) pre = torch.zeros((batch_size), device=rec.device).long() for i in range(seq_length): - visited_time[torch.arange(batch_size), rec[torch.arange(batch_size), pre]] = ( - i + 1 - ) + visited_time[torch.arange(batch_size), rec[torch.arange(batch_size), pre]] = i + 1 pre = rec[torch.arange(batch_size), pre] visited_time = visited_time % seq_length diff --git a/rl4co/envs/common/distribution_utils.py b/rl4co/envs/common/distribution_utils.py index ab8e1449..277ea57b 100644 --- a/rl4co/envs/common/distribution_utils.py +++ b/rl4co/envs/common/distribution_utils.py @@ -19,13 +19,10 @@ def __init__(self, n_cluster: int = 3): self.n_cluster = n_cluster def sample(self, size): - batch_size, num_loc, _ = size # Generate the centers of the clusters - center = self.lower + (self.upper - self.lower) * torch.rand( - batch_size, self.n_cluster * 2 - ) + center = self.lower + (self.upper - self.lower) * torch.rand(batch_size, self.n_cluster * 2) # Pre-define the coordinates coords = torch.zeros(batch_size, num_loc, 2) @@ -69,7 +66,6 @@ def __init__(self, n_cluster_mix=1): self.n_cluster_mix = n_cluster_mix def sample(self, size): - batch_size, num_loc, _ = size # Generate the centers of the clusters @@ -126,7 +122,6 @@ def __init__(self, num_modes: int = 0, cdist: int = 0): self.cdist = cdist def sample(self, size): - batch_size, num_loc, _ = size if self.num_modes == 0: # (0, 0) - uniform @@ -175,9 +170,7 @@ def generate_gaussian(self, batch_size, num_loc): for i in range(batch_size): # Construct covariance matrix for each sample cov_matrix = torch.tensor([[1.0, covs[i]], [covs[i], 1.0]]) - m = torch.distributions.MultivariateNormal( - mean[i], covariance_matrix=cov_matrix - ) + m = torch.distributions.MultivariateNormal(mean[i], covariance_matrix=cov_matrix) coords[i] = m.sample() # Shuffle the coordinates @@ -187,7 +180,6 @@ def generate_gaussian(self, batch_size, num_loc): return self._batch_normalize_and_center(coords) def _global_min_max_scaling(self, coords): - # Scale the points to [0, 1] using min-max scaling coords_min = coords.min(0, keepdim=True).values coords_max = coords.max(0, keepdim=True).values @@ -201,18 +193,14 @@ def _batch_normalize_and_center(self, coords): coords_max = coords.max(dim=1, keepdim=True).values # Step 2: Normalize coordinates to range [0, 1] - coords = ( - coords - coords_min - ) # Broadcasting subtracts min value on each coordinate + coords = coords - coords_min # Broadcasting subtracts min value on each coordinate range_max = ( (coords_max - coords_min).max(dim=-1, keepdim=True).values ) # The maximum range among both coordinates coords = coords / range_max # Divide by the max range to normalize # Step 3: Center the batch in the middle of the [0, 1] range - coords = ( - coords + (1 - coords.max(dim=1, keepdim=True).values) / 2 - ) # Centering the batch + coords = coords + (1 - coords.max(dim=1, keepdim=True).values) / 2 # Centering the batch return coords @@ -235,7 +223,6 @@ def __init__(self, n_cluster=3, n_cluster_mix=1): self.Cluster = Cluster(n_cluster=n_cluster) def sample(self, size): - batch_size, num_loc, _ = size # Pre-define the coordinates sampled under uniform distribution @@ -268,9 +255,7 @@ class Mix_Multi_Distributions: def __init__(self): super().__init__() - self.dist_set = [(0, 0), (1, 1)] + [ - (m, c) for m in [3, 5, 7] for c in [10, 30, 50] - ] + self.dist_set = [(0, 0), (1, 1)] + [(m, c) for m in [3, 5, 7] for c in [10, 30, 50]] def sample(self, size): batch_size, num_loc, _ = size @@ -278,9 +263,7 @@ def sample(self, size): # Pre-select distributions for the entire batch dists = [random.choice(self.dist_set) for _ in range(batch_size)] - unique_dists = list( - set(dists) - ) # Unique distributions to minimize re-instantiation + unique_dists = list(set(dists)) # Unique distributions to minimize re-instantiation # Instantiate Gaussian_Mixture only once per unique distribution gm_instances = {dist: Gaussian_Mixture(*dist) for dist in unique_dists} diff --git a/rl4co/envs/common/utils.py b/rl4co/envs/common/utils.py index 9b6bb8f7..bb792ff6 100644 --- a/rl4co/envs/common/utils.py +++ b/rl4co/envs/common/utils.py @@ -1,6 +1,6 @@ import abc -from typing import Callable +from collections.abc import Callable import torch @@ -61,29 +61,27 @@ def get_sampler( elif distribution == Uniform or distribution == "uniform": return Uniform(low=low, high=high) elif distribution == Normal or distribution == "normal" or distribution == "gaussian": - assert ( - kwargs.get(val_name + "_mean", None) is not None - ), "mean is required for Normal distribution" - assert ( - kwargs.get(val_name + "_std", None) is not None - ), "std is required for Normal distribution" + assert kwargs.get(val_name + "_mean", None) is not None, ( + "mean is required for Normal distribution" + ) + assert kwargs.get(val_name + "_std", None) is not None, ( + "std is required for Normal distribution" + ) return Normal(loc=kwargs[val_name + "_mean"], scale=kwargs[val_name + "_std"]) elif distribution == Exponential or distribution == "exponential": - assert ( - kwargs.get(val_name + "_rate", None) is not None - ), "rate is required for Exponential/Poisson distribution" + assert kwargs.get(val_name + "_rate", None) is not None, ( + "rate is required for Exponential/Poisson distribution" + ) return Exponential(rate=kwargs[val_name + "_rate"]) elif distribution == Poisson or distribution == "poisson": - assert ( - kwargs.get(val_name + "_rate", None) is not None - ), "rate is required for Exponential/Poisson distribution" + assert kwargs.get(val_name + "_rate", None) is not None, ( + "rate is required for Exponential/Poisson distribution" + ) return Poisson(rate=kwargs[val_name + "_rate"]) elif distribution == "center": return Uniform(low=(high - low) / 2, high=(high - low) / 2) elif distribution == "corner": - return Uniform( - low=low, high=low - ) # todo: should be also `low, high` and any other corner + return Uniform(low=low, high=low) # todo: should be also `low, high` and any other corner elif isinstance(distribution, Callable): return distribution(**kwargs) elif distribution == "gaussian_mixture": diff --git a/rl4co/envs/eda/dpp/env.py b/rl4co/envs/eda/dpp/env.py index 10def833..94421075 100644 --- a/rl4co/envs/eda/dpp/env.py +++ b/rl4co/envs/eda/dpp/env.py @@ -1,7 +1,5 @@ import os -from typing import Optional - import numpy as np import torch @@ -90,7 +88,7 @@ def _step(self, td: TensorDict) -> TensorDict: ) return td - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size=None) -> TensorDict: device = td.device # Other variables @@ -152,9 +150,7 @@ def _get_reward(self, td, actions): td = td.unsqueeze(0) actions = actions.unsqueeze(0) probes = td["probe"] - reward = torch.stack( - [self._decap_simulator(p, a) for p, a in zip(probes, actions)] - ) + reward = torch.stack([self._decap_simulator(p, a) for p, a in zip(probes, actions)]) return reward @staticmethod @@ -169,9 +165,7 @@ def _decap_placement(self, pi, probe): z1 = self.raw_pdn.to(device) decap = self.decap.reshape(-1).to(device) - z2 = torch.zeros( - (self.num_freq, num_decap, num_decap), dtype=torch.float32, device=device - ) + z2 = torch.zeros((self.num_freq, num_decap, num_decap), dtype=torch.float32, device=device) qIndx = torch.arange(num_decap, device=device) @@ -181,9 +175,7 @@ def _decap_placement(self, pi, probe): pIndx = pi.long() aIndx = torch.arange(len(z1[0]), device=device) - aIndx = torch.tensor( - list(set(aIndx.tolist()) - set(pIndx.tolist())), device=device - ) + aIndx = torch.tensor(list(set(aIndx.tolist()) - set(pIndx.tolist())), device=device) z1aa = z1[:, aIndx, :][:, :, aIndx] z1ap = z1[:, aIndx, :][:, :, pIndx] @@ -220,9 +212,9 @@ def _decap_simulator(self, probe, solution, keepout=None): probe = probe.item() - assert len(solution) == len( - torch.unique(solution) - ), "An Element of Decap Sequence must be Unique" + assert len(solution) == len(torch.unique(solution)), ( + "An Element of Decap Sequence must be Unique" + ) if keepout is not None: keepout = torch.tensor(keepout) diff --git a/rl4co/envs/eda/dpp/generator.py b/rl4co/envs/eda/dpp/generator.py index d34b8e7c..dc3ef78a 100644 --- a/rl4co/envs/eda/dpp/generator.py +++ b/rl4co/envs/eda/dpp/generator.py @@ -1,22 +1,19 @@ import os import zipfile -from typing import Union, Callable -import torch import numpy as np +import torch from robust_downloader import download -from torch.distributions import Uniform from tensordict.tensordict import TensorDict from rl4co.data.utils import load_npz_to_tensordict +from rl4co.envs.common.utils import Generator from rl4co.utils.pylogger import get_pylogger -from rl4co.envs.common.utils import get_sampler, Generator log = get_pylogger(__name__) - class DPPGenerator(Generator): """Data generator for the Decap Placement Problem (DPP). @@ -32,7 +29,7 @@ class DPPGenerator(Generator): decap_file: Name of the decap file. Defaults to "01nF_decap.npy". freq_file: Name of the frequency file. Defaults to "freq_201.npy". url: URL to download data from. Defaults to None. - + Returns: A TensorDict with the following keys: locs [batch_size, num_loc, 2]: locations of each customer @@ -40,6 +37,7 @@ class DPPGenerator(Generator): demand [batch_size, num_loc]: demand of each customer capacity [batch_size]: capacity of the vehicle """ + def __init__( self, min_loc: float = 0.0, @@ -52,7 +50,7 @@ def __init__( decap_file: str = "01nF_decap.npy", freq_file: str = "freq_201.npy", url: str = None, - **unused_kwargs + **unused_kwargs, ): self.min_loc = min_loc self.max_loc = max_loc @@ -65,25 +63,20 @@ def __init__( if len(unused_kwargs) > 0: log.error(f"Found {len(unused_kwargs)} unused kwargs: {unused_kwargs}") - # Download and load the data from online dataset self.url = ( "https://github.com/kaist-silab/devformer/raw/main/data/data.zip" if url is None else url ) - self.backup_url = ( - "https://drive.google.com/uc?id=1IEuR2v8Le-mtHWHxwTAbTOPIkkQszI95" - ) + self.backup_url = "https://drive.google.com/uc?id=1IEuR2v8Le-mtHWHxwTAbTOPIkkQszI95" self._load_dpp_data(chip_file, decap_file, freq_file) # Check the validity of the keepout parameters - assert ( - num_keepout_min <= num_keepout_max - ), "num_keepout_min must be <= num_keepout_max" - assert ( - num_keepout_max <= self.size**2 - ), "num_keepout_max must be <= size * size (total number of locations)" + assert num_keepout_min <= num_keepout_max, "num_keepout_min must be <= num_keepout_max" + assert num_keepout_max <= self.size**2, ( + "num_keepout_max must be <= size * size (total number of locations)" + ) def _generate(self, batch_size) -> TensorDict: """ @@ -97,9 +90,7 @@ def _generate(self, batch_size) -> TensorDict: bs = [1] if not batched else batch_size # Create a list of locs on a grid - locs = torch.meshgrid( - torch.arange(m), torch.arange(n) - ) + locs = torch.meshgrid(torch.arange(m), torch.arange(n)) locs = torch.stack(locs, dim=-1).reshape(-1, 2) # normalize the locations by the number of rows and columns locs = locs / torch.tensor([m, n], dtype=torch.float) @@ -155,9 +146,7 @@ def _download_data(self): ) download(self.backup_url, self.data_dir, "data.zip") log.info("Download complete. Unzipping...") - zipfile.ZipFile(os.path.join(self.data_dir, "data.zip"), "r").extractall( - self.data_dir - ) + zipfile.ZipFile(os.path.join(self.data_dir, "data.zip"), "r").extractall(self.data_dir) log.info("Unzip complete. Removing zip file") os.remove(os.path.join(self.data_dir, "data.zip")) diff --git a/rl4co/envs/eda/dpp/render.py b/rl4co/envs/eda/dpp/render.py index 7e8a3db9..ebe8ce39 100644 --- a/rl4co/envs/eda/dpp/render.py +++ b/rl4co/envs/eda/dpp/render.py @@ -66,8 +66,7 @@ def render(self, decaps, probe, action_mask, ax=None, legend=True): if legend: num_unique = 4 handles = [ - plt.Rectangle((0, 0), 1, 1, color=settings[i]["color"]) - for i in range(num_unique) + plt.Rectangle((0, 0), 1, 1, color=settings[i]["color"]) for i in range(num_unique) ] ax.legend( handles, diff --git a/rl4co/envs/eda/mdpp/env.py b/rl4co/envs/eda/mdpp/env.py index 5c4400d8..30f7ac5b 100644 --- a/rl4co/envs/eda/mdpp/env.py +++ b/rl4co/envs/eda/mdpp/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from tensordict.tensordict import TensorDict @@ -70,7 +68,7 @@ def _step(self, td: TensorDict) -> TensorDict: # Step function is the same as DPPEnv, only masking changes return super()._step(td) - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size=None) -> TensorDict: # Reset function is the same as DPPEnv, only masking changes due to probes td_reset = super()._reset(td, batch_size=batch_size) @@ -131,10 +129,7 @@ def _get_reward(self, td, actions): # Reward calculation is expensive since we need to run decap simulation (not vectorizable) reward = torch.stack( - [ - self._single_env_reward(td_single, action) - for td_single, action in zip(td, actions) - ] + [self._single_env_reward(td_single, action) for td_single, action in zip(td, actions)] ) return reward diff --git a/rl4co/envs/eda/mdpp/generator.py b/rl4co/envs/eda/mdpp/generator.py index 75767150..92e33d3e 100644 --- a/rl4co/envs/eda/mdpp/generator.py +++ b/rl4co/envs/eda/mdpp/generator.py @@ -1,17 +1,15 @@ import os import zipfile -from typing import Union, Callable -import torch import numpy as np +import torch from robust_downloader import download -from torch.distributions import Uniform from tensordict.tensordict import TensorDict from rl4co.data.utils import load_npz_to_tensordict +from rl4co.envs.common.utils import Generator from rl4co.utils.pylogger import get_pylogger -from rl4co.envs.common.utils import get_sampler, Generator log = get_pylogger(__name__) @@ -31,7 +29,7 @@ class MDPPGenerator(Generator): decap_file: Name of the decap file. Defaults to "01nF_decap.npy". freq_file: Name of the frequency file. Defaults to "freq_201.npy". url: URL to download data from. Defaults to None. - + Returns: A TensorDict with the following keys: locs [batch_size, num_loc, 2]: locations of each customer @@ -39,6 +37,7 @@ class MDPPGenerator(Generator): demand [batch_size, num_loc]: demand of each customer capacity [batch_size]: capacity of the vehicle """ + def __init__( self, min_loc: float = 0.0, @@ -53,7 +52,7 @@ def __init__( decap_file: str = "01nF_decap.npy", freq_file: str = "freq_201.npy", url: str = None, - **unused_kwargs + **unused_kwargs, ): self.min_loc = min_loc self.max_loc = max_loc @@ -68,25 +67,20 @@ def __init__( if len(unused_kwargs) > 0: log.error(f"Found {len(unused_kwargs)} unused kwargs: {unused_kwargs}") - # Download and load the data from online dataset self.url = ( "https://github.com/kaist-silab/devformer/raw/main/data/data.zip" if url is None else url ) - self.backup_url = ( - "https://drive.google.com/uc?id=1IEuR2v8Le-mtHWHxwTAbTOPIkkQszI95" - ) + self.backup_url = "https://drive.google.com/uc?id=1IEuR2v8Le-mtHWHxwTAbTOPIkkQszI95" self._load_dpp_data(chip_file, decap_file, freq_file) # Check the validity of the keepout parameters - assert ( - num_keepout_min <= num_keepout_max - ), "num_keepout_min must be <= num_keepout_max" - assert ( - num_keepout_max <= self.size**2 - ), "num_keepout_max must be <= size * size (total number of locations)" + assert num_keepout_min <= num_keepout_max, "num_keepout_min must be <= num_keepout_max" + assert num_keepout_max <= self.size**2, ( + "num_keepout_max must be <= size * size (total number of locations)" + ) def _generate(self, batch_size) -> TensorDict: m = n = self.size @@ -164,9 +158,7 @@ def _download_data(self): ) download(self.backup_url, self.data_dir, "data.zip") log.info("Download complete. Unzipping...") - zipfile.ZipFile(os.path.join(self.data_dir, "data.zip"), "r").extractall( - self.data_dir - ) + zipfile.ZipFile(os.path.join(self.data_dir, "data.zip"), "r").extractall(self.data_dir) log.info("Unzip complete. Removing zip file") os.remove(os.path.join(self.data_dir, "data.zip")) diff --git a/rl4co/envs/eda/mdpp/render.py b/rl4co/envs/eda/mdpp/render.py index 61194900..b3d89de7 100644 --- a/rl4co/envs/eda/mdpp/render.py +++ b/rl4co/envs/eda/mdpp/render.py @@ -81,9 +81,7 @@ def draw_probe(ax, x, y, color="black"): def draw_keepout(ax, x, y, color="black"): # Backgrund rectangle: same as color but with alpha=0.5 ax.add_patch(Rectangle((x, y), 1, 1, color=color, alpha=0.5)) - ax.add_patch( - RegularPolygon((x + 0.5, y + 0.5), numVertices=6, radius=0.45, color=color) - ) + ax.add_patch(RegularPolygon((x + 0.5, y + 0.5), numVertices=6, radius=0.45, color=color)) size = self.size td = td.detach().cpu() @@ -138,9 +136,7 @@ def draw_keepout(ax, x, y, color="black"): colors = [settings[k]["color"] for k in settings.keys()] labels = [settings[k]["label"] for k in settings.keys()] handles = [ - plt.Rectangle( - (0, 0), 1, 1, color=c, edgecolor="k", linestyle="-", linewidth=1 - ) + plt.Rectangle((0, 0), 1, 1, color=c, edgecolor="k", linestyle="-", linewidth=1) for c in colors ] ax.legend( diff --git a/rl4co/envs/graph/flp/env.py b/rl4co/envs/graph/flp/env.py index aa73b3f9..89331b9f 100644 --- a/rl4co/envs/graph/flp/env.py +++ b/rl4co/envs/graph/flp/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from tensordict.tensordict import TensorDict @@ -73,9 +71,7 @@ def _step(self, td: TensorDict) -> TensorDict: orig_distances = td["orig_distances"] # (batch_size, n_points, n_points) cur_min_dist = ( - gather_by_index( - orig_distances, chosen.nonzero(as_tuple=True)[1].view(batch_size, -1) - ) + gather_by_index(orig_distances, chosen.nonzero(as_tuple=True)[1].view(batch_size, -1)) .view(batch_size, -1, n_points_) .min(dim=1) .values @@ -97,16 +93,14 @@ def _step(self, td: TensorDict) -> TensorDict: ) return td - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size=None) -> TensorDict: self.to(td.device) return TensorDict( { # given information "locs": td["locs"], # (batch_size, n_points, dim_loc) - "orig_distances": td[ - "orig_distances" - ], # (batch_size, n_points, n_points) + "orig_distances": td["orig_distances"], # (batch_size, n_points, n_points) "distances": td["distances"], # (batch_size, n_points, n_points) # states changed by actions "chosen": torch.zeros( @@ -137,9 +131,7 @@ def _get_reward(self, td: TensorDict, actions: torch.Tensor) -> torch.Tensor: n_points_ = td["chosen"].shape[-1] orig_distances = td["orig_distances"] cur_min_dist = ( - gather_by_index( - orig_distances, chosen.nonzero(as_tuple=True)[1].view(batch_size_, -1) - ) + gather_by_index(orig_distances, chosen.nonzero(as_tuple=True)[1].view(batch_size_, -1)) .view(batch_size_, -1, n_points_) .min(1) .values.sum(-1) @@ -163,7 +155,4 @@ def get_num_starts(td): @staticmethod def select_start_nodes(td, num_starts): num_loc = td["action_mask"].shape[-1] - return ( - torch.arange(num_starts, device=td.device).repeat_interleave(td.shape[0]) - % num_loc - ) + return torch.arange(num_starts, device=td.device).repeat_interleave(td.shape[0]) % num_loc diff --git a/rl4co/envs/graph/flp/generator.py b/rl4co/envs/graph/flp/generator.py index 6b25ec38..32c835bc 100644 --- a/rl4co/envs/graph/flp/generator.py +++ b/rl4co/envs/graph/flp/generator.py @@ -1,6 +1,6 @@ import math -from typing import Callable +from collections.abc import Callable import torch @@ -50,9 +50,7 @@ def __init__( if kwargs.get("loc_sampler", None) is not None: self.loc_sampler = kwargs["loc_sampler"] else: - self.loc_sampler = get_sampler( - "loc", loc_distribution, min_loc, max_loc, **kwargs - ) + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) def _generate(self, batch_size) -> TensorDict: # Sample locations @@ -64,9 +62,7 @@ def _generate(self, batch_size) -> TensorDict: { "locs": locs, "orig_distances": distances, - "distances": torch.full( - (*batch_size, self.num_loc), max_dist, dtype=torch.float - ), + "distances": torch.full((*batch_size, self.num_loc), max_dist, dtype=torch.float), "chosen": torch.zeros(*batch_size, self.num_loc, dtype=torch.bool), "to_choose": torch.ones(*batch_size, dtype=torch.long) * self.to_choose, }, diff --git a/rl4co/envs/graph/mcp/env.py b/rl4co/envs/graph/mcp/env.py index 3f0275e0..6bfc8fd5 100644 --- a/rl4co/envs/graph/mcp/env.py +++ b/rl4co/envs/graph/mcp/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from tensordict.tensordict import TensorDict @@ -74,9 +72,7 @@ def _step(self, td: TensorDict) -> TensorDict: remaining_membership = remaining_sets.unsqueeze(-1) * td["membership"] batch_indices, set_indices, item_indices = chosen_membership_nonzero.T - chosen_items_indices = chosen_membership[ - batch_indices, set_indices, item_indices - ].long() + chosen_items_indices = chosen_membership[batch_indices, set_indices, item_indices].long() batch_size, n_items = td["weights"].shape @@ -107,7 +103,7 @@ def _step(self, td: TensorDict) -> TensorDict: ) return td - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size=None) -> TensorDict: self.to(td.device) return TensorDict( @@ -150,9 +146,7 @@ def _get_reward(self, td: TensorDict, actions: torch.Tensor) -> torch.Tensor: chosen_membership_nonzero = chosen_membership.nonzero() batch_indices, set_indices, item_indices = chosen_membership_nonzero.T - chosen_items_indices = chosen_membership[ - batch_indices, set_indices, item_indices - ].long() + chosen_items_indices = chosen_membership[batch_indices, set_indices, item_indices].long() batch_size, n_items = weights.shape @@ -187,7 +181,4 @@ def get_num_starts(td): @staticmethod def select_start_nodes(td, num_starts): num_sets = td["action_mask"].shape[-1] - return ( - torch.arange(num_starts, device=td.device).repeat_interleave(td.shape[0]) - % num_sets - ) + return torch.arange(num_starts, device=td.device).repeat_interleave(td.shape[0]) % num_sets diff --git a/rl4co/envs/graph/mcp/generator.py b/rl4co/envs/graph/mcp/generator.py index 29ee56d1..fba568d8 100644 --- a/rl4co/envs/graph/mcp/generator.py +++ b/rl4co/envs/graph/mcp/generator.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import torch @@ -116,9 +116,7 @@ def _generate(self, batch_size) -> TensorDict: 1, self.num_items + 1, (batch_size, self.num_sets, max_size) ) - cutoffs_masks = torch.arange(self.max_size).view(1, 1, -1) < set_sizes.unsqueeze( - -1 - ) + cutoffs_masks = torch.arange(self.max_size).view(1, 1, -1) < set_sizes.unsqueeze(-1) # Take the masked elements, 0 means the item is invalid membership_tensor = ( membership_tensor_max_size * cutoffs_masks diff --git a/rl4co/envs/routing/atsp/env.py b/rl4co/envs/routing/atsp/env.py index ed944e50..86098664 100644 --- a/rl4co/envs/routing/atsp/env.py +++ b/rl4co/envs/routing/atsp/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from tensordict.tensordict import TensorDict @@ -84,7 +82,7 @@ def _step(td: TensorDict) -> TensorDict: ) return td - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size=None) -> TensorDict: # Initialize distance matrix cost_matrix = td["cost_matrix"] device = td.device @@ -148,18 +146,16 @@ def _get_reward(self, td: TensorDict, actions: torch.Tensor) -> torch.Tensor: # Get indexes of tour edges nodes_src = actions nodes_tgt = torch.roll(actions, -1, dims=1) - batch_idx = torch.arange( - distance_matrix.shape[0], device=distance_matrix.device - ).unsqueeze(1) + batch_idx = torch.arange(distance_matrix.shape[0], device=distance_matrix.device).unsqueeze( + 1 + ) # return negative tour length return -distance_matrix[batch_idx, nodes_src, nodes_tgt].sum(-1) @staticmethod def check_solution_validity(td: TensorDict, actions: torch.Tensor): assert ( - torch.arange(actions.size(1), out=actions.data.new()) - .view(1, -1) - .expand_as(actions) + torch.arange(actions.size(1), out=actions.data.new()).view(1, -1).expand_as(actions) == actions.data.sort(1)[0] ).all(), "Invalid tour" diff --git a/rl4co/envs/routing/atsp/generator.py b/rl4co/envs/routing/atsp/generator.py index 5e23a83b..22dc84d6 100644 --- a/rl4co/envs/routing/atsp/generator.py +++ b/rl4co/envs/routing/atsp/generator.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import torch @@ -53,12 +53,12 @@ def _generate(self, batch_size) -> TensorDict: # We satifsy the triangle inequality (TMAT class) in a batch batch_size = [batch_size] if isinstance(batch_size, int) else batch_size dms = ( - self.dist_sampler.sample((batch_size + [self.num_loc, self.num_loc])) + self.dist_sampler.sample(batch_size + [self.num_loc, self.num_loc]) * (self.max_dist - self.min_dist) + self.min_dist ) dms[..., torch.arange(self.num_loc), torch.arange(self.num_loc)] = 0 - log.info("Using TMAT class (triangle inequality): {}".format(self.tmat_class)) + log.info(f"Using TMAT class (triangle inequality): {self.tmat_class}") if self.tmat_class: for i in range(self.num_loc): dms = torch.minimum(dms, dms[..., :, [i]] + dms[..., [i], :]) diff --git a/rl4co/envs/routing/cvrp/env.py b/rl4co/envs/routing/cvrp/env.py index 413d8a7c..b1548b0e 100644 --- a/rl4co/envs/routing/cvrp/env.py +++ b/rl4co/envs/routing/cvrp/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from tensordict.tensordict import TensorDict @@ -14,7 +12,7 @@ try: from .local_search import local_search -except: # In case when we fail to build HGS +except Exception: # In case when we fail to build HGS local_search = None from .render import render @@ -75,9 +73,7 @@ def _step(self, td: TensorDict) -> TensorDict: ) # Increase capacity if depot is not visited, otherwise set to 0 - used_capacity = (td["used_capacity"] + selected_demand) * ( - current_node != 0 - ).float() + used_capacity = (td["used_capacity"] + selected_demand) * (current_node != 0).float() # Note: here we do not subtract one as we have to scatter so the first column allows scattering depot # Add one dimension since we write a single value @@ -101,8 +97,8 @@ def _step(self, td: TensorDict) -> TensorDict: def _reset( self, - td: Optional[TensorDict] = None, - batch_size: Optional[list] = None, + td: TensorDict | None = None, + batch_size: list | None = None, ) -> TensorDict: device = td.device @@ -111,9 +107,7 @@ def _reset( { "locs": torch.cat((td["depot"][:, None, :], td["locs"]), -2), "demand": td["demand"], - "current_node": torch.zeros( - *batch_size, 1, dtype=torch.long, device=device - ), + "current_node": torch.zeros(*batch_size, 1, dtype=torch.long, device=device), "used_capacity": torch.zeros((*batch_size, 1), device=device), "vehicle_capacity": torch.full( (*batch_size, 1), self.generator.vehicle_capacity, device=device @@ -138,9 +132,7 @@ def get_action_mask(td: TensorDict) -> torch.Tensor: mask_loc = td["visited"][..., 1:].to(exceeds_cap.dtype) | exceeds_cap # Cannot visit the depot if just visited and still unserved nodes - mask_depot = (td["current_node"] == 0) & ((mask_loc == 0).int().sum(-1) > 0)[ - :, None - ] + mask_depot = (td["current_node"] == 0) & ((mask_loc == 0).int().sum(-1) > 0)[:, None] return ~torch.cat((mask_depot, mask_loc), -1) def _get_reward(self, td: TensorDict, actions: TensorDict) -> TensorDict: @@ -180,9 +172,9 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor): ] # This will reset/make capacity negative if i == 0, e.g. depot visited # Cannot use less than 0 used_cap[used_cap < 0] = 0 - assert ( - used_cap <= td["vehicle_capacity"][:, 0] + 1e-5 - ).all(), "Used more than capacity" + assert (used_cap <= td["vehicle_capacity"][:, 0] + 1e-5).all(), ( + "Used more than capacity" + ) @staticmethod def load_data(fpath, batch_size=[]): @@ -254,9 +246,9 @@ def replace_selected_actions( @staticmethod def local_search(td: TensorDict, actions: torch.Tensor, **kwargs) -> torch.Tensor: - assert ( - local_search is not None - ), "Cannot import local_search module. Check `rl4co/envs/routing/cvrp/README.md` for instructions to build HGS." + assert local_search is not None, ( + "Cannot import local_search module. Check `rl4co/envs/routing/cvrp/README.md` for instructions to build HGS." + ) return local_search(td, actions, **kwargs) @staticmethod diff --git a/rl4co/envs/routing/cvrp/generator.py b/rl4co/envs/routing/cvrp/generator.py index b6620ee1..c158c3c0 100644 --- a/rl4co/envs/routing/cvrp/generator.py +++ b/rl4co/envs/routing/cvrp/generator.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import torch @@ -77,9 +77,7 @@ def __init__( if kwargs.get("loc_sampler", None) is not None: self.loc_sampler = kwargs["loc_sampler"] else: - self.loc_sampler = get_sampler( - "loc", loc_distribution, min_loc, max_loc, **kwargs - ) + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) # Depot distribution if kwargs.get("depot_sampler", None) is not None: @@ -100,9 +98,7 @@ def __init__( ) # Capacity - if ( - capacity is None - ): # If not provided, use the default capacity from Kool et al. 2019 + if capacity is None: # If not provided, use the default capacity from Kool et al. 2019 capacity = CAPACITIES.get(num_loc, None) if ( capacity is None @@ -116,7 +112,6 @@ def __init__( self.capacity = capacity def _generate(self, batch_size) -> TensorDict: - # Sample locations: depot and customers if self.depot_sampler is not None: depot = self.depot_sampler.sample((*batch_size, 2)) diff --git a/rl4co/envs/routing/cvrp/local_search.py b/rl4co/envs/routing/cvrp/local_search.py index 7b6580de..8217f408 100644 --- a/rl4co/envs/routing/cvrp/local_search.py +++ b/rl4co/envs/routing/cvrp/local_search.py @@ -1,31 +1,21 @@ +import concurrent.futures import os import platform -from ctypes import ( - Structure, - CDLL, - POINTER, - c_int, - c_double, - c_char, - sizeof, - cast, - byref, -) +import random +import sys +import time + +from ctypes import CDLL, POINTER, Structure, byref, c_char, c_double, c_int, cast, sizeof from dataclasses import dataclass -from typing import List -import concurrent.futures import numpy as np -import sys -import random -import time import torch + from tensordict.tensordict import TensorDict from rl4co.utils.ops import get_distance_matrix from rl4co.utils.pylogger import get_pylogger - log = get_pylogger(__name__) @@ -64,7 +54,7 @@ def local_search( else: distances_np = distances.detach().cpu().numpy() - subroutes_all: List[List[List[int]]] = [get_subroutes(path) for path in actions_np] + subroutes_all: list[list[list[int]]] = [get_subroutes(path) for path in actions_np] with concurrent.futures.ThreadPoolExecutor() as executor: futures = [] for i in range(len(subroutes_all)): @@ -86,32 +76,35 @@ def local_search( # Remove heading and tailing zeros max_pos = np.max(np.where(new_actions != 0)[1]) - new_actions = new_actions[:, 1: max_pos + 1] + new_actions = new_actions[:, 1 : max_pos + 1] new_actions = torch.from_numpy(new_actions).to(td.device) # Check the validity of the solution and use the original solution if the new solution is invalid isvalid = check_validity(td, new_actions) - if not isvalid.all(): + if not isvalid.all(): new_actions[~isvalid] = 0 orig_valid_actions = actions[~isvalid] # pad if needed orig_max_pos = torch.max(torch.where(orig_valid_actions != 0)[1]) + 1 if orig_max_pos > max_pos: new_actions = torch.nn.functional.pad( - new_actions, (0, orig_max_pos - max_pos, 0, 0), mode="constant", value=0 # type: ignore + new_actions, + (0, orig_max_pos - max_pos, 0, 0), + mode="constant", + value=0, # type: ignore ) new_actions[~isvalid, :orig_max_pos] = orig_valid_actions[:, :orig_max_pos] return new_actions -def get_subroutes(path, end_with_zero = True) -> List[List[int]]: +def get_subroutes(path, end_with_zero=True) -> list[list[int]]: x = np.where(path == 0)[0] subroutes = [] for i, j in zip(x, x[1:]): if j - i > 1: if end_with_zero: j = j + 1 - subroutes.append(path[i: j]) + subroutes.append(path[i:j]) return subroutes @@ -121,7 +114,7 @@ def merge_subroutes(subroutes, length): for r in subroutes: if len(r) > 2: r = r[:-1] # remove the last zero - route[i: i + len(r)] = r + route[i : i + len(r)] = r i += len(r) return route @@ -150,14 +143,10 @@ def check_validity(td: TensorDict, actions: torch.Tensor) -> torch.Tensor: used_cap = torch.zeros_like(td["demand"][:, 0]) valid = torch.ones(batch_size, dtype=torch.bool) for i in range(actions.size(1)): - used_cap += d[ - :, i - ] # This will reset/make capacity negative if i == 0, e.g. depot visited + used_cap += d[:, i] # This will reset/make capacity negative if i == 0, e.g. depot visited # Cannot use less than 0 used_cap[used_cap < 0] = 0 - valid &= ( - used_cap <= td["vehicle_capacity"][:, 0] + 1e-5 - ) + valid &= used_cap <= td["vehicle_capacity"][:, 0] + 1e-5 return valid @@ -171,16 +160,16 @@ def check_validity(td: TensorDict, actions: torch.Tensor) -> torch.Tensor: C_DBL_MAX = sys.float_info.max -def write_routes(routes: List[np.ndarray], filepath: str): +def write_routes(routes: list[np.ndarray], filepath: str): with open(filepath, "w") as f: for i, r in enumerate(routes): - f.write(f"Route #{i + 1}: "+' '.join([str(x) for x in r if x > 0])+"\n") + f.write(f"Route #{i + 1}: " + " ".join([str(x) for x in r if x > 0]) + "\n") return -def read_routes(filepath) -> List[np.ndarray]: +def read_routes(filepath) -> list[np.ndarray]: routes = [] - with open(filepath, "r") as f: + with open(filepath) as f: while 1: line = f.readline().strip() if line.startswith("Route"): @@ -315,8 +304,8 @@ def __init__(self, parameters=AlgorithmParameters(), verbose=False): self._c_api_delete_sol = hgs_library.delete_solution self._c_api_delete_sol.restype = None self._c_api_delete_sol.argtypes = [POINTER(_Solution)] - - def local_search(self, data, routes: List[np.ndarray], count: int = 1) -> List[np.ndarray]: + + def local_search(self, data, routes: list[np.ndarray], count: int = 1) -> list[np.ndarray]: # required data demand = np.asarray(data["demands"]) vehicle_capacity = data["vehicle_capacity"] @@ -367,10 +356,10 @@ def local_search(self, data, routes: List[np.ndarray], count: int = 1) -> List[n assert dist_mtx.shape[0] == dist_mtx.shape[1] assert (dist_mtx >= 0.0).all() - callid = (time.time_ns()*100000+random.randint(0,100000))%C_INT_MAX + callid = (time.time_ns() * 100000 + random.randint(0, 100000)) % C_INT_MAX - tmppath = "/tmp/route-{}".format(callid) - resultpath = "/tmp/swapstar-result-{}".format(callid) + tmppath = f"/tmp/route-{callid}" + resultpath = f"/tmp/swapstar-result-{callid}" write_routes(routes, tmppath) try: self._local_search( @@ -396,7 +385,7 @@ def local_search(self, data, routes: List[np.ndarray], count: int = 1) -> List[n os.remove(resultpath) finally: os.remove(tmppath) - + return result def _local_search( @@ -413,7 +402,7 @@ def _local_search( algorithm_parameters: AlgorithmParameters, verbose: bool, callid: int, - count:int, + count: int, ): n_nodes = x_coords.size @@ -425,7 +414,6 @@ def _local_search( m_ct = dist_mtx.reshape(n_nodes * n_nodes).astype(c_double).ctypes ap_ct = algorithm_parameters.ctypes - # struct Solution *solve_cvrp_dist_mtx( # int n, double* x, double* y, double *dist_mtx, double *serv_time, double *dem, # double vehicleCapacity, double durationLimit, char isDurationConstraint, @@ -451,22 +439,22 @@ def _local_search( return result -def swapstar(demands, matrix, positions, routes: List[np.ndarray], count=1): +def swapstar(demands, matrix, positions, routes: list[np.ndarray], count=1): ap = AlgorithmParameters() hgs_solver = Solver(parameters=ap, verbose=False) data = dict() x = positions[:, 0] y = positions[:, 1] - data['x_coordinates'] = x - data['y_coordinates'] = y + data["x_coordinates"] = x + data["y_coordinates"] = y - data['depot'] = 0 - data['demands'] = demands * 1000 + data["depot"] = 0 + data["demands"] = demands * 1000 data["num_vehicles"] = len(routes) - data['vehicle_capacity'] = 1000.001 # to avoid floating-point error + data["vehicle_capacity"] = 1000.001 # to avoid floating-point error # Solve with calculated distances - data['distance_matrix'] = matrix + data["distance_matrix"] = matrix result = hgs_solver.local_search(data, routes, count) return result diff --git a/rl4co/envs/routing/cvrp/render.py b/rl4co/envs/routing/cvrp/render.py index dd99aa75..9ce27c21 100644 --- a/rl4co/envs/routing/cvrp/render.py +++ b/rl4co/envs/routing/cvrp/render.py @@ -86,9 +86,9 @@ def render(td, actions=None, ax=None, skip_depot=True, integer_demands=True): # text demand for node_idx in range(1, len(locs)): demand_text = ( - f"{demands[node_idx-1].int().item()}" + f"{demands[node_idx - 1].int().item()}" if integer_demands - else f"{demands[node_idx-1].item():.2f}" + else f"{demands[node_idx - 1].item():.2f}" ) ax.text( locs[node_idx, 0], diff --git a/rl4co/envs/routing/cvrpmvc/env.py b/rl4co/envs/routing/cvrpmvc/env.py index 5a8780ba..c3ee24dd 100644 --- a/rl4co/envs/routing/cvrpmvc/env.py +++ b/rl4co/envs/routing/cvrpmvc/env.py @@ -28,9 +28,7 @@ def _step(self, td: TensorDict) -> TensorDict: ) # Increase capacity if depot is not visited, otherwise set to 0 - used_capacity = (td["used_capacity"] + selected_demand) * ( - current_node != 0 - ).float() + used_capacity = (td["used_capacity"] + selected_demand) * (current_node != 0).float() demand_remaining = td["demand_remaining"] - selected_demand @@ -56,9 +54,7 @@ def _step(self, td: TensorDict) -> TensorDict: td.set("action_mask", self.get_action_mask(td)) return td - def _reset( - self, td: TensorDict | None = None, batch_size: list | None = None - ) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size: list | None = None) -> TensorDict: td = super()._reset(td, batch_size) batch_size = batch_size or list(td.batch_size) td.set( @@ -66,9 +62,7 @@ def _reset( torch.ones((*batch_size, 1), dtype=torch.int, device=td.device), ) td.set("demand_remaining", td["demand"].sum(-1, keepdim=True)) - td.set( - "max_vehicle", torch.ceil(td["demand_remaining"] / td["vehicle_capacity"]) + 1 - ) + td.set("max_vehicle", torch.ceil(td["demand_remaining"] / td["vehicle_capacity"]) + 1) return td @staticmethod @@ -82,20 +76,14 @@ def get_action_mask(td: TensorDict) -> torch.Tensor: if "vehicles_used" in td.keys(): max_vehicle = td["max_vehicle"] demand_remaining = td["demand_remaining"] - capacity_remaining = (max_vehicle - td["vehicles_used"]) * td[ - "vehicle_capacity" - ] + capacity_remaining = (max_vehicle - td["vehicles_used"]) * td["vehicle_capacity"] mask_depot = ( # mask the depot (td["current_node"] == 0) # if the depot is just visited | ( demand_remaining > capacity_remaining ) # or the unassigned vehicles' capacity can't sastify remaining demands - ) & ~torch.all( - mask_loc, dim=-1, keepdim=True - ) # unless there's no other choices + ) & ~torch.all(mask_loc, dim=-1, keepdim=True) # unless there's no other choices else: # Cannot visit the depot if just visited and still unserved nodes - mask_depot = (td["current_node"] == 0) & ~torch.all( - mask_loc, dim=-1, keepdim=True - ) + mask_depot = (td["current_node"] == 0) & ~torch.all(mask_loc, dim=-1, keepdim=True) return ~torch.cat((mask_depot, mask_loc), -1) diff --git a/rl4co/envs/routing/cvrptw/env.py b/rl4co/envs/routing/cvrptw/env.py index 5a26b8cc..6c5a19d7 100644 --- a/rl4co/envs/routing/cvrptw/env.py +++ b/rl4co/envs/routing/cvrptw/env.py @@ -1,15 +1,9 @@ -from typing import Optional - import torch from tensordict.tensordict import TensorDict from torchrl.data import Bounded, Composite, Unbounded -from rl4co.data.utils import ( - load_npz_to_tensordict, - load_solomon_instance, - load_solomon_solution, -) +from rl4co.data.utils import load_npz_to_tensordict, load_solomon_instance, load_solomon_solution from rl4co.envs.routing.cvrp.env import CVRPEnv from rl4co.utils.ops import gather_by_index, get_distance @@ -129,20 +123,14 @@ def _step(self, td: TensorDict) -> TensorDict: td = super()._step(td) return td - def _reset( - self, td: Optional[TensorDict] = None, batch_size: Optional[list] = None - ) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size: list | None = None) -> TensorDict: device = td.device td_reset = TensorDict( { "locs": torch.cat((td["depot"][..., None, :], td["locs"]), -2), "demand": td["demand"], - "current_node": torch.zeros( - *batch_size, 1, dtype=torch.long, device=device - ), - "current_time": torch.zeros( - *batch_size, 1, dtype=torch.float32, device=device - ), + "current_node": torch.zeros(*batch_size, 1, dtype=torch.long, device=device), + "current_time": torch.zeros(*batch_size, 1, dtype=torch.float32, device=device), "used_capacity": torch.zeros((*batch_size, 1), device=device), "vehicle_capacity": torch.full( (*batch_size, 1), self.generator.vehicle_capacity, device=device @@ -170,9 +158,7 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor) -> None: CVRPEnv.check_solution_validity(td, actions) batch_size = td["locs"].shape[0] # distances to depot - distances = get_distance( - td["locs"][..., 0, :], td["locs"].transpose(0, 1) - ).transpose(0, 1) + distances = get_distance(td["locs"][..., 0, :], td["locs"].transpose(0, 1)).transpose(0, 1) # basic checks on time windows assert torch.all(distances >= 0.0), "Distances must be non-negative." assert torch.all(td["time_windows"] >= 0.0), "Time windows must be non-negative." @@ -180,12 +166,10 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor) -> None: td["time_windows"][..., :, 0] + distances + td["durations"] <= td["time_windows"][..., 0, 1][0] # max_time is the same for all batches ), "vehicle cannot perform service and get back to depot in time." - assert torch.all( - td["durations"] >= 0.0 - ), "Service durations must be non-negative." - assert torch.all( - td["time_windows"][..., 0] < td["time_windows"][..., 1] - ), "there are unfeasible time windows" + assert torch.all(td["durations"] >= 0.0), "Service durations must be non-negative." + assert torch.all(td["time_windows"][..., 0] < td["time_windows"][..., 1]), ( + "there are unfeasible time windows" + ) # check vehicles can meet deadlines curr_time = torch.zeros(batch_size, 1, dtype=torch.float32, device=td.device) curr_node = torch.zeros_like(curr_time, dtype=torch.int64, device=td.device) @@ -197,15 +181,11 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor) -> None: ).reshape([batch_size, 1]) curr_time = torch.max( (curr_time + dist).int(), - gather_by_index(td["time_windows"], next_node)[..., 0].reshape( - [batch_size, 1] - ), + gather_by_index(td["time_windows"], next_node)[..., 0].reshape([batch_size, 1]), ) assert torch.all( curr_time - <= gather_by_index(td["time_windows"], next_node)[..., 1].reshape( - [batch_size, 1] - ) + <= gather_by_index(td["time_windows"], next_node)[..., 1].reshape([batch_size, 1]) ), "vehicle cannot start service before deadline" curr_time = curr_time + gather_by_index(td["durations"], next_node).reshape( [batch_size, 1] @@ -250,9 +230,7 @@ def extract_from_solomon(self, instance: dict, batch_size: int = 1): self.max_time = instance["time_window"][:, 1].max() # assert the time window of the depot starts at 0 and ends at max_time assert self.min_time == 0, "Time window of depot must start at 0." - assert ( - self.max_time == instance["time_window"][0, 1] - ), "Depot must have latest end time." + assert self.max_time == instance["time_window"][0, 1], "Depot must have latest end time." # convert to format used in CVRPTWEnv td = TensorDict( { diff --git a/rl4co/envs/routing/cvrptw/generator.py b/rl4co/envs/routing/cvrptw/generator.py index 8938c5de..899ba46c 100644 --- a/rl4co/envs/routing/cvrptw/generator.py +++ b/rl4co/envs/routing/cvrptw/generator.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import torch @@ -140,9 +140,9 @@ def _generate(self, batch_size) -> TensorDict: # 8. stack to tensor time_windows time_windows = torch.stack((min_times, max_times), dim=-1) - assert torch.all( - min_times < max_times - ), "Please make sure the relation between max_loc and max_time allows for feasible solutions." + assert torch.all(min_times < max_times), ( + "Please make sure the relation between max_loc and max_time allows for feasible solutions." + ) # Reset duration at depot to 0 durations[:, 0] = 0.0 diff --git a/rl4co/envs/routing/cvrptw/render.py b/rl4co/envs/routing/cvrptw/render.py index 88b7d35c..f0a9bf6b 100644 --- a/rl4co/envs/routing/cvrptw/render.py +++ b/rl4co/envs/routing/cvrptw/render.py @@ -88,7 +88,7 @@ def render(td, actions=None, ax=None): ax.text( locs[node_idx, 0], locs[node_idx, 1] - 0.025, - f"{demands[node_idx-1].item():.2f}", + f"{demands[node_idx - 1].item():.2f}", horizontalalignment="center", verticalalignment="top", fontsize=10, diff --git a/rl4co/envs/routing/mdcpdp/env.py b/rl4co/envs/routing/mdcpdp/env.py index e9597122..a8a76390 100644 --- a/rl4co/envs/routing/mdcpdp/env.py +++ b/rl4co/envs/routing/mdcpdp/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from tensordict.tensordict import TensorDict @@ -117,14 +115,10 @@ def _step(self, td: TensorDict) -> TensorDict: # TODO: better way? available = td["available"] if td["i"][0] > 0: - available = available.scatter( - -1, current_node.expand_as(td["action_mask"]), 0 - ) + available = available.scatter(-1, current_node.expand_as(td["action_mask"]), 0) # Record the to be delivered node - to_deliver = td["to_deliver"].scatter( - -1, new_to_deliver.expand_as(td["to_deliver"]), 1 - ) + to_deliver = td["to_deliver"].scatter(-1, new_to_deliver.expand_as(td["to_deliver"]), 1) # Update number of current carry orders current_carry = td["current_carry"] @@ -136,9 +130,7 @@ def _step(self, td: TensorDict) -> TensorDict: # Update the current depot # current_depot = td["current_depot"] # current_depot = torch.where(back_flag, current_node, current_depot) - current_depot = torch.where( - current_node < num_agents, current_node, td["current_depot"] - ) + current_depot = torch.where(current_node < num_agents, current_node, td["current_depot"]) # Update the length of current tour current_length = td["current_length"] @@ -168,9 +160,7 @@ def _step(self, td: TensorDict) -> TensorDict: # Update the arrive time for each city arrivetime_record = td["arrivetime_record"] - arrivetime_record.scatter_( - -1, current_node, current_length.gather(-1, current_depot) - ) + arrivetime_record.scatter_(-1, current_node, current_length.gather(-1, current_depot)) # Action is feasible if the node is not visited and is to deliver action_mask = available & to_deliver @@ -184,39 +174,27 @@ def _step(self, td: TensorDict) -> TensorDict: # If back to the current depot, this tour is done, set other depots to availbe to start # a new tour. Must start from a depot. - action_mask[..., num_agents:] &= ~back_flag.expand_as( - action_mask[..., num_agents:] - ) + action_mask[..., num_agents:] &= ~back_flag.expand_as(action_mask[..., num_agents:]) # If back to the depot, other unvisited depots are available # if not back to the depot, depots are not available except the current depot - action_mask[..., :num_agents] &= back_flag.expand_as( - action_mask[..., :num_agents] - ) + action_mask[..., :num_agents] &= back_flag.expand_as(action_mask[..., :num_agents]) action_mask[..., :num_agents].scatter_(-1, current_depot, ~back_flag) # If this is the last agent, it has to finish all the left taks - last_depot_flag = ( - torch.sum(available[..., :num_agents].long(), dim=-1, keepdim=True) == 0 - ) - action_mask[..., :num_agents] &= ~last_depot_flag.expand_as( - action_mask[..., :num_agents] - ) + last_depot_flag = torch.sum(available[..., :num_agents].long(), dim=-1, keepdim=True) == 0 + action_mask[..., :num_agents] &= ~last_depot_flag.expand_as(action_mask[..., :num_agents]) # Update depot mask carry_flag = current_carry > 0 # If agent is carrying orders - action_mask[ - ..., :num_agents - ] &= ~carry_flag # If carrying orders, depot is not available + action_mask[..., :num_agents] &= ~carry_flag # If carrying orders, depot is not available # 1) current node is a depot # 2) we did not just come back # 3) it is not the first step # cannot go to other depots prev_depot_flag = (current_node < num_agents) & (td["i"] > 0) & ~back_flag - action_mask[..., :num_agents] &= ~prev_depot_flag.expand_as( - action_mask[..., :num_agents] - ) + action_mask[..., :num_agents] &= ~prev_depot_flag.expand_as(action_mask[..., :num_agents]) # We are done there are no unvisited locations # done = torch.count_nonzero(available, dim=-1) == 0 @@ -249,7 +227,7 @@ def _step(self, td: TensorDict) -> TensorDict: ) return td - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size=None) -> TensorDict: device = td.device if "depots" in td: @@ -290,17 +268,13 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict low=0, high=num_agents, size=(*batch_size, 1), device=device ) elif self.start_mode == "order": - current_depot = torch.zeros( - (*batch_size, 1), dtype=torch.int64, device=device - ) + current_depot = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device) # Current carry order number current_carry = torch.zeros((*batch_size, 1), dtype=torch.int64, device=device) # Current length of each depot - current_length = torch.zeros( - (*batch_size, num_agents), dtype=torch.float32, device=device - ) + current_length = torch.zeros((*batch_size, num_agents), dtype=torch.float32, device=device) # Arrive time for each city arrivetime_record = torch.zeros( diff --git a/rl4co/envs/routing/mdcpdp/generator.py b/rl4co/envs/routing/mdcpdp/generator.py index 9839d3e4..79e492d0 100644 --- a/rl4co/envs/routing/mdcpdp/generator.py +++ b/rl4co/envs/routing/mdcpdp/generator.py @@ -1,4 +1,4 @@ -from typing import Callable, Union +from collections.abc import Callable import torch @@ -41,15 +41,15 @@ def __init__( num_loc: int = 20, min_loc: float = 0.0, max_loc: float = 1.0, - loc_distribution: Union[int, float, str, type, Callable] = Uniform, + loc_distribution: int | float | str | type | Callable = Uniform, num_agents: int = 5, depot_mode: str = "multiple", - depot_distribution: Union[int, float, str, type, Callable] = Uniform, + depot_distribution: int | float | str | type | Callable = Uniform, min_capacity: int = 3, max_capacity: int = 3, min_lateness_weight: float = 1.0, max_lateness_weight: float = 1.0, - lateness_weight_distribution: Union[int, float, str, type, Callable] = Uniform, + lateness_weight_distribution: int | float | str | type | Callable = Uniform, **kwargs, ): self.num_loc = num_loc @@ -64,9 +64,7 @@ def __init__( # Number of locations must be even if num_loc % 2 != 0: - log.warning( - "Number of locations must be even. Adding 1 to the number of locations." - ) + log.warning("Number of locations must be even. Adding 1 to the number of locations.") self.num_loc += 1 # Check depot mode validity @@ -76,9 +74,7 @@ def __init__( if kwargs.get("loc_sampler", None) is not None: self.loc_sampler = kwargs["loc_sampler"] else: - self.loc_sampler = get_sampler( - "loc", loc_distribution, min_loc, max_loc, **kwargs - ) + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) # Depot distribution if kwargs.get("depot_sampler", None) is not None: diff --git a/rl4co/envs/routing/mpdp/env.py b/rl4co/envs/routing/mpdp/env.py index 21143abe..9d35ab05 100644 --- a/rl4co/envs/routing/mpdp/env.py +++ b/rl4co/envs/routing/mpdp/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from tensordict.tensordict import TensorDict @@ -88,9 +86,9 @@ def _step(self, td: TensorDict) -> TensorDict: remain_pickup_max_distance = depot_distance[:, : agent_num + 1 + n_loc // 2].max( dim=-1, keepdim=True )[0] - remain_delivery_max_distance = depot_distance[ - :, agent_num + 1 + n_loc // 2 : - ].max(dim=-1, keepdim=True)[0] + remain_delivery_max_distance = depot_distance[:, agent_num + 1 + n_loc // 2 :].max( + dim=-1, keepdim=True + )[0] # Calculate makespan cur_coord = gather_by_index(td["locs"], selected) @@ -99,9 +97,9 @@ def _step(self, td: TensorDict) -> TensorDict: td["lengths"].scatter_add_(-1, td["count_depot"], path_lengths.unsqueeze(-1)) # If visit depot then plus one to count_depot\ - td["count_depot"][ - (selected == td["agent_idx"]) & (td["agent_idx"] < agent_num) - ] += 1 # torch.ones(td["count_depot"][(selected == 0) & (td["agent_idx"] < agent_num)].shape, dtype=torch.int64, device=td["count_depot"].device) + td["count_depot"][(selected == td["agent_idx"]) & (td["agent_idx"] < agent_num)] += ( + 1 # torch.ones(td["count_depot"][(selected == 0) & (td["agent_idx"] < agent_num)].shape, dtype=torch.int64, device=td["count_depot"].device) + ) # `agent_idx` is added by 1 if the current agent comes back to depot agent_idx = (td["count_depot"] + 1) * torch.ones( @@ -134,9 +132,9 @@ def _step(self, td: TensorDict) -> TensorDict: def _reset( self, - td: Optional[TensorDict] = None, - batch_size: Optional[list] = None, - agent_num: Optional[int] = None, # NOTE hardcoded from ET + td: TensorDict | None = None, + batch_size: list | None = None, + agent_num: int | None = None, # NOTE hardcoded from ET ) -> TensorDict: device = td.device @@ -215,9 +213,7 @@ def _reset( batch_size, dtype=torch.int64, device=device ), # Vector with length num_steps "to_delivery": to_delivery, - "count_depot": torch.zeros( - batch_size, 1, dtype=torch.int64, device=device - ), + "count_depot": torch.zeros(batch_size, 1, dtype=torch.int64, device=device), "agent_idx": torch.ones(batch_size, 1, dtype=torch.long, device=device), "left_request": left_request * torch.ones(batch_size, 1, dtype=torch.long, device=device), @@ -267,9 +263,7 @@ def get_action_mask(td: TensorDict) -> torch.Tensor: return ( torch.cat( [ - torch.zeros( - batch_size, 1, 1, dtype=torch.uint8, device=mask_loc.device - ), + torch.zeros(batch_size, 1, 1, dtype=torch.uint8, device=mask_loc.device), torch.ones( batch_size, 1, diff --git a/rl4co/envs/routing/mpdp/generator.py b/rl4co/envs/routing/mpdp/generator.py index 70b92b24..fc8f135b 100644 --- a/rl4co/envs/routing/mpdp/generator.py +++ b/rl4co/envs/routing/mpdp/generator.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import torch @@ -48,18 +48,14 @@ def __init__( # Number of locations must be even if num_loc % 2 != 0: - log.warning( - "Number of locations must be even. Adding 1 to the number of locations." - ) + log.warning("Number of locations must be even. Adding 1 to the number of locations.") self.num_loc += 1 # Location distribution if kwargs.get("loc_sampler", None) is not None: self.loc_sampler = kwargs["loc_sampler"] else: - self.loc_sampler = get_sampler( - "loc", loc_distribution, min_loc, max_loc, **kwargs - ) + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) # Depot distribution if kwargs.get("depot_sampler", None) is not None: diff --git a/rl4co/envs/routing/mtsp/env.py b/rl4co/envs/routing/mtsp/env.py index bf39cf74..04ab2059 100644 --- a/rl4co/envs/routing/mtsp/env.py +++ b/rl4co/envs/routing/mtsp/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from tensordict.tensordict import TensorDict @@ -132,7 +130,7 @@ def _step(td: TensorDict) -> TensorDict: return td - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size=None) -> TensorDict: device = td.device # Keep track of the agent number to know when to stop diff --git a/rl4co/envs/routing/mtsp/generator.py b/rl4co/envs/routing/mtsp/generator.py index d2402580..22f44854 100644 --- a/rl4co/envs/routing/mtsp/generator.py +++ b/rl4co/envs/routing/mtsp/generator.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import torch @@ -48,9 +48,7 @@ def __init__( if kwargs.get("loc_sampler", None) is not None: self.loc_sampler = kwargs["loc_sampler"] else: - self.loc_sampler = get_sampler( - "loc", loc_distribution, min_loc, max_loc, **kwargs - ) + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) def _generate(self, batch_size) -> TensorDict: # Sample locations diff --git a/rl4co/envs/routing/mtsp/render.py b/rl4co/envs/routing/mtsp/render.py index 7b76ce3b..9b05a797 100644 --- a/rl4co/envs/routing/mtsp/render.py +++ b/rl4co/envs/routing/mtsp/render.py @@ -73,9 +73,7 @@ def discrete_cmap(num, base_cmap="nipy_spectral"): color = cmap(num_agents - agent_idx) from_node = actions[i] - to_node = ( - actions[i + 1] if i < len(actions) - 1 else actions[0] - ) # last goes back to depot + to_node = actions[i + 1] if i < len(actions) - 1 else actions[0] # last goes back to depot from_loc = td["locs"][from_node] to_loc = td["locs"][to_node] ax.plot([from_loc[0], to_loc[0]], [from_loc[1], to_loc[1]], color=color) diff --git a/rl4co/envs/routing/mtvrp/baselines/ortools.py b/rl4co/envs/routing/mtvrp/baselines/ortools.py index 67b31c1a..4743fac0 100644 --- a/rl4co/envs/routing/mtvrp/baselines/ortools.py +++ b/rl4co/envs/routing/mtvrp/baselines/ortools.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional import numpy as np import routefinder.baselines.pyvrp as pyvrp @@ -59,8 +58,8 @@ class ORToolsData: vehicle_capacities: list[int] max_distance: int demands: list[int] - time_windows: Optional[list[list[int]]] - backhauls: Optional[list[int]] + time_windows: list[list[int]] | None + backhauls: list[int] | None @property def num_locations(self) -> int: @@ -132,9 +131,7 @@ def _solve(data: ORToolsData, max_runtime: float, log: bool = False): """ # Manager for converting between nodes (location indices) and index # (internal CP variable indices). - manager = pywrapcp.RoutingIndexManager( - data.num_locations, data.num_vehicles, data.depot - ) + manager = pywrapcp.RoutingIndexManager(data.num_locations, data.num_vehicles, data.depot) routing = pywrapcp.RoutingModel(manager) # Set arc costs equal to distances. diff --git a/rl4co/envs/routing/mtvrp/env.py b/rl4co/envs/routing/mtvrp/env.py index 6beb8eb8..a0b60eaa 100644 --- a/rl4co/envs/routing/mtvrp/env.py +++ b/rl4co/envs/routing/mtvrp/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from tensordict.tensordict import TensorDict @@ -100,22 +98,17 @@ def _step(self, td: TensorDict) -> TensorDict: distance = get_distance(prev_loc, curr_loc)[..., None] # Update current time - service_time = gather_by_index( - src=td["service_time"], idx=curr_node, dim=1, squeeze=False - ) - start_times = gather_by_index( - src=td["time_windows"], idx=curr_node, dim=1, squeeze=False - )[..., 0] + service_time = gather_by_index(src=td["service_time"], idx=curr_node, dim=1, squeeze=False) + start_times = gather_by_index(src=td["time_windows"], idx=curr_node, dim=1, squeeze=False)[ + ..., 0 + ] # we cannot start before we arrive and we should start at least at start times curr_time = (curr_node[:, None] != 0) * ( - torch.max(td["current_time"] + distance / td["speed"], start_times) - + service_time + torch.max(td["current_time"] + distance / td["speed"], start_times) + service_time ) # Update current route length (reset at depot) - curr_route_length = (curr_node[:, None] != 0) * ( - td["current_route_length"] + distance - ) + curr_route_length = (curr_node[:, None] != 0) * (td["current_route_length"] + distance) # Linehaul (delivery) demands selected_demand_linehaul = gather_by_index( @@ -158,8 +151,8 @@ def _step(self, td: TensorDict) -> TensorDict: def _reset( self, - td: Optional[TensorDict] = None, - batch_size: Optional[list] = None, + td: TensorDict | None = None, + batch_size: list | None = None, ) -> TensorDict: device = td.device @@ -176,9 +169,7 @@ def _reset( "vehicle_capacity": td["vehicle_capacity"], "capacity_original": td["capacity_original"], "speed": td["speed"], - "current_node": torch.zeros( - (*batch_size,), dtype=torch.long, device=device - ), + "current_node": torch.zeros((*batch_size,), dtype=torch.long, device=device), "current_route_length": torch.zeros( (*batch_size, 1), dtype=torch.float32, device=device ), # for distance limits @@ -228,15 +219,12 @@ def get_action_mask(td: TensorDict) -> torch.Tensor: # Distance limit (L): do not add distance to depot if open route (O) exceeds_dist_limit = ( - td["current_route_length"] + d_ij + (d_j0 * ~td["open_route"]) - > td["distance_limit"] + td["current_route_length"] + d_ij + (d_j0 * ~td["open_route"]) > td["distance_limit"] ) # Linehaul demand / delivery (C) and backhaul demand / pickup (B) # All linehauls are visited before backhauls - linehauls_missing = ((td["demand_linehaul"] * ~td["visited"]).sum(-1) > 0)[ - ..., None - ] + linehauls_missing = ((td["demand_linehaul"] * ~td["visited"]).sum(-1) > 0)[..., None] is_carrying_backhaul = ( gather_by_index( src=td["demand_backhaul"], @@ -307,9 +295,9 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor): d_j0 = get_distance(locs, locs[..., 0:1, :]) # j (next) -> 0 (depot) assert torch.all(td["time_windows"] >= 0.0), "Time windows must be non-negative." assert torch.all(td["service_time"] >= 0.0), "Service time must be non-negative." - assert torch.all( - td["time_windows"][..., 0] < td["time_windows"][..., 1] - ), "there are unfeasible time windows" + assert torch.all(td["time_windows"][..., 0] < td["time_windows"][..., 1]), ( + "there are unfeasible time windows" + ) assert torch.all( td["time_windows"][..., :, 0] + d_j0 + td["service_time"] <= td["time_windows"][..., 0, 1, None] @@ -328,17 +316,17 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor): curr_length = curr_length + dist * ~( td["open_route"].squeeze(-1) & (next_node == 0) ) # do not count back to depot for open route - assert torch.all( - curr_length <= td["distance_limit"].squeeze(-1) - ), "Route exceeds distance limit" + assert torch.all(curr_length <= td["distance_limit"].squeeze(-1)), ( + "Route exceeds distance limit" + ) curr_length[next_node == 0] = 0.0 # reset length for depot curr_time = torch.max( curr_time + dist, gather_by_index(td["time_windows"], next_node)[..., 0] ) - assert torch.all( - curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1] - ), "vehicle cannot start service before deadline" + assert torch.all(curr_time <= gather_by_index(td["time_windows"], next_node)[..., 1]), ( + "vehicle cannot start service before deadline" + ) curr_time = curr_time + gather_by_index(td["service_time"], next_node) curr_node = next_node curr_time[curr_node == 0] = 0.0 # reset time for depot @@ -352,9 +340,9 @@ def _check_c1(feature="demand_linehaul"): # reset at depot used_cap = used_cap * (actions[:, ii] != 0) used_cap += demand[:, ii] - assert ( - used_cap <= td["vehicle_capacity"] - ).all(), "Used more than capacity for {}: {}".format(feature, used_cap) + assert (used_cap <= td["vehicle_capacity"]).all(), ( + f"Used more than capacity for {feature}: {used_cap}" + ) _check_c1("demand_linehaul") _check_c1("demand_backhaul") @@ -386,9 +374,7 @@ def select_start_nodes(self, td, num_starts): """Select available start nodes for the environment (e.g. for POMO-based training)""" num_loc = td["locs"].shape[-2] - 1 selected = ( - torch.arange(num_starts, device=td.device).repeat_interleave(td.shape[0]) - % num_loc - + 1 + torch.arange(num_starts, device=td.device).repeat_interleave(td.shape[0]) % num_loc + 1 ) return selected @@ -472,9 +458,7 @@ def get_variant_names(td): has_backhaul, ) = MTVRPEnv.check_variants(td) instance_names = [] - for o, b, l_, tw in zip( - has_open, has_backhaul, has_duration_limit, has_time_window - ): + for o, b, l_, tw in zip(has_open, has_backhaul, has_duration_limit, has_time_window): if not o and not b and not l_ and not tw: instance_name = "CVRP" else: diff --git a/rl4co/envs/routing/mtvrp/generator.py b/rl4co/envs/routing/mtvrp/generator.py index c4258eee..82b4118c 100644 --- a/rl4co/envs/routing/mtvrp/generator.py +++ b/rl4co/envs/routing/mtvrp/generator.py @@ -1,4 +1,4 @@ -from typing import Callable, Tuple +from collections.abc import Callable import torch @@ -121,9 +121,7 @@ def __init__( if kwargs.get("loc_sampler", None) is not None: self.loc_sampler = kwargs["loc_sampler"] else: - self.loc_sampler = get_sampler( - "loc", loc_distribution, min_loc, max_loc, **kwargs - ) + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) if capacity is None: capacity = get_vehicle_capacity(num_loc) @@ -139,16 +137,16 @@ def __init__( self.distance_limit = distance_limit self.speed = speed - assert not ( - subsample and (variant_preset is None) - ), "Cannot use subsample if variant_preset is not specified. " + assert not (subsample and (variant_preset is None)), ( + "Cannot use subsample if variant_preset is not specified. " + ) if variant_preset is not None: log.info(f"Using variant generation preset {variant_preset}") variant_probs = VARIANT_GENERATION_PRESETS.get(variant_preset) - assert ( - variant_probs is not None - ), f"Variant generation preset {variant_preset} not found. \ + assert variant_probs is not None, ( + f"Variant generation preset {variant_preset} not found. \ Available presets are {VARIANT_GENERATION_PRESETS.keys()} with probabilities {VARIANT_GENERATION_PRESETS.values()}" + ) else: variant_probs = { "O": prob_open, @@ -172,9 +170,7 @@ def _generate(self, batch_size) -> TensorDict: locs = self.generate_locations(batch_size=batch_size, num_loc=self.num_loc) # Vehicle capacity (C, B) - applies to both linehaul and backhaul - vehicle_capacity = torch.full( - (*batch_size, 1), self.capacity, dtype=torch.float32 - ) + vehicle_capacity = torch.full((*batch_size, 1), self.capacity, dtype=torch.float32) capacity_original = vehicle_capacity.clone() # linehaul demand / delivery (C) and backhaul / pickup demand (B) @@ -182,12 +178,8 @@ def _generate(self, batch_size) -> TensorDict: batch_size=batch_size, num_loc=self.num_loc ) # add empty depot demands - demand_linehaul = torch.cat( - [torch.zeros(size=(*batch_size, 1)), demand_linehaul], dim=1 - ) - demand_backhaul = torch.cat( - [torch.zeros(size=(*batch_size, 1)), demand_backhaul], dim=1 - ) + demand_linehaul = torch.cat([torch.zeros(size=(*batch_size, 1)), demand_linehaul], dim=1) + demand_backhaul = torch.cat([torch.zeros(size=(*batch_size, 1)), demand_backhaul], dim=1) # Open (O) open_route = self.generate_open_route(shape=(*batch_size, 1)) @@ -260,9 +252,9 @@ def subsample_problems(self, td): cvrp_prob = 0.5 if self.variant_preset in ("all", "cvrp", "single_feat", "single_feat_otw"): indices = torch.distributions.Categorical( - torch.Tensor(list(self.variant_probs.values()) + [cvrp_prob])[ - None - ].repeat(batch_size, 1) + torch.Tensor(list(self.variant_probs.values()) + [cvrp_prob])[None].repeat( + batch_size, 1 + ) ).sample() if self.variant_preset == "single_feat_otw": keep_mask = torch.zeros((batch_size, 6), dtype=torch.bool) @@ -320,9 +312,7 @@ def generate_locations(self, batch_size, num_loc) -> torch.Tensor: Returns: locs: [B, N+1, 2] where the first location is the depot. """ - locs = torch.FloatTensor(*batch_size, num_loc + 1, 2).uniform_( - self.min_loc, self.max_loc - ) + locs = torch.FloatTensor(*batch_size, num_loc + 1, 2).uniform_(self.min_loc, self.max_loc) return locs def generate_demands(self, batch_size: int, num_loc: int) -> torch.Tensor: @@ -399,9 +389,7 @@ def generate_time_windows( service_time = torch.cat((torch.zeros(batch_size, 1), service_time), dim=-1) return time_windows, service_time # [B, N+1, 2], [B, N+1] - def generate_distance_limit( - self, shape: Tuple[int, int], locs: torch.Tensor - ) -> torch.Tensor: + def generate_distance_limit(self, shape: tuple[int, int], locs: torch.Tensor) -> torch.Tensor: """Generates distance limits (L) and checks their feasibilities. Returns: @@ -414,13 +402,13 @@ def generate_distance_limit( ).all(), "Distance limit too low, not all nodes can be reached from the depot." return torch.full(shape, self.distance_limit, dtype=torch.float32) - def generate_open_route(self, shape: Tuple[int, int]): + def generate_open_route(self, shape: tuple[int, int]): """Generate open route flags (O). Here we could have a sampler but we simply return True here so all routes are open. Afterwards, we subsample the problems. """ return torch.ones(shape, dtype=torch.bool) - def generate_speed(self, shape: Tuple[int, int]): + def generate_speed(self, shape: tuple[int, int]): """We simply generate the speed as constant here""" # in this version, the speed is constant but this class may be overridden return torch.full(shape, self.speed, dtype=torch.float32) diff --git a/rl4co/envs/routing/mtvrp/render.py b/rl4co/envs/routing/mtvrp/render.py index e5a3a9b8..465fc154 100644 --- a/rl4co/envs/routing/mtvrp/render.py +++ b/rl4co/envs/routing/mtvrp/render.py @@ -7,9 +7,7 @@ log = get_pylogger(__name__) -def render( - td: TensorDict, actions=None, ax=None, scale_xy: bool = False, vehicle_capacity=None -): +def render(td: TensorDict, actions=None, ax=None, scale_xy: bool = False, vehicle_capacity=None): import matplotlib.pyplot as plt import numpy as np diff --git a/rl4co/envs/routing/op/env.py b/rl4co/envs/routing/op/env.py index d15617d1..8a29ff82 100644 --- a/rl4co/envs/routing/op/env.py +++ b/rl4co/envs/routing/op/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn.functional as F @@ -104,8 +102,8 @@ def _step(self, td: TensorDict) -> TensorDict: def _reset( self, - td: Optional[TensorDict] = None, - batch_size: Optional[list] = None, + td: TensorDict | None = None, + batch_size: list | None = None, ) -> TensorDict: device = td.device @@ -116,29 +114,21 @@ def _reset( td_reset = TensorDict( { "locs": locs_with_depot, - "prize": F.pad( - td["prize"], (1, 0), mode="constant", value=0 - ), # add 0 for depot + "prize": F.pad(td["prize"], (1, 0), mode="constant", value=0), # add 0 for depot "tour_length": torch.zeros(*batch_size, device=device), # max_length is max length allowed when arriving at node, so subtract distance to return to depot # Additionally, substract epsilon margin for numeric stability "max_length": td["max_length"][..., None] - (td["depot"][..., None, :] - locs_with_depot).norm(p=2, dim=-1) - 1e-6, - "current_node": torch.zeros( - *batch_size, 1, dtype=torch.long, device=device - ), + "current_node": torch.zeros(*batch_size, 1, dtype=torch.long, device=device), "visited": torch.zeros( (*batch_size, locs_with_depot.shape[-2]), dtype=torch.bool, device=device, ), - "current_total_prize": torch.zeros( - *batch_size, dtype=torch.float, device=device - ), - "i": torch.zeros( - (*batch_size,), dtype=torch.int64, device=device - ), # counter + "current_total_prize": torch.zeros(*batch_size, dtype=torch.float, device=device), + "i": torch.zeros((*batch_size,), dtype=torch.int64, device=device), # counter }, batch_size=batch_size, ) @@ -187,8 +177,7 @@ def check_solution_validity( sorted_actions = actions.data.sort(1)[0] # Make sure each node visited once at most (except for depot) assert ( - (sorted_actions[:, 1:] == 0) - | (sorted_actions[:, 1:] > sorted_actions[:, :-1]) + (sorted_actions[:, 1:] == 0) | (sorted_actions[:, 1:] > sorted_actions[:, :-1]) ).all(), "Duplicates" # Gather locations in order of tour and get the length of tours @@ -198,14 +187,10 @@ def check_solution_validity( max_length = td["max_length"] if add_distance_to_depot: max_length = ( - max_length - + (td["locs"][..., 0:1, :] - td["locs"]).norm(p=2, dim=-1) - + 1e-6 + max_length + (td["locs"][..., 0:1, :] - td["locs"]).norm(p=2, dim=-1) + 1e-6 ) - assert ( - length[..., None] <= max_length + 1e-5 - ).all(), "Max length exceeded by {}".format( - (length[..., None] - max_length).max() + assert (length[..., None] <= max_length + 1e-5).all(), ( + f"Max length exceeded by {(length[..., None] - max_length).max()}" ) def _make_spec(self, generator: OPGenerator): diff --git a/rl4co/envs/routing/op/generator.py b/rl4co/envs/routing/op/generator.py index fbfa0f65..df102685 100644 --- a/rl4co/envs/routing/op/generator.py +++ b/rl4co/envs/routing/op/generator.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import torch @@ -62,9 +62,7 @@ def __init__( if kwargs.get("loc_sampler", None) is not None: self.loc_sampler = kwargs["loc_sampler"] else: - self.loc_sampler = get_sampler( - "loc", loc_distribution, min_loc, max_loc, **kwargs - ) + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) # Depot distribution if kwargs.get("depot_sampler", None) is not None: @@ -119,18 +117,11 @@ def _generate(self, batch_size) -> TensorDict: prize = torch.ones(*batch_size, self.num_loc, device=self.device) elif self.prize_type == "unif": prize = ( - 1 - + torch.randint( - 0, 100, (*batch_size, self.num_loc), device=self.device - ).float() + 1 + torch.randint(0, 100, (*batch_size, self.num_loc), device=self.device).float() ) / 100 elif self.prize_type == "dist": # based on the distance to the depot - prize = (locs_with_depot[..., 0:1, :] - locs_with_depot[..., 1:, :]).norm( - p=2, dim=-1 - ) - prize = ( - 1 + (prize / prize.max(dim=-1, keepdim=True)[0] * 99).int() - ).float() / 100 + prize = (locs_with_depot[..., 0:1, :] - locs_with_depot[..., 1:, :]).norm(p=2, dim=-1) + prize = (1 + (prize / prize.max(dim=-1, keepdim=True)[0] * 99).int()).float() / 100 else: raise ValueError(f"Invalid prize_type: {self.prize_type}") diff --git a/rl4co/envs/routing/pctsp/env.py b/rl4co/envs/routing/pctsp/env.py index 222b3cad..7746649e 100644 --- a/rl4co/envs/routing/pctsp/env.py +++ b/rl4co/envs/routing/pctsp/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn.functional as F @@ -65,12 +63,8 @@ def _step(self, td: TensorDict) -> TensorDict: current_node = td["action"] # Get current coordinates, prize, and penalty - cur_total_prize = td["cur_total_prize"] + gather_by_index( - td["real_prize"], current_node - ) - cur_total_penalty = td["cur_total_penalty"] + gather_by_index( - td["penalty"], current_node - ) + cur_total_prize = td["cur_total_prize"] + gather_by_index(td["real_prize"], current_node) + cur_total_penalty = td["cur_total_penalty"] + gather_by_index(td["penalty"], current_node) # Update visited visited = td["visited"].scatter(-1, current_node[..., None], 1) @@ -96,16 +90,12 @@ def _step(self, td: TensorDict) -> TensorDict: td.set("action_mask", self.get_action_mask(td)) return td - def _reset( - self, td: Optional[TensorDict] = None, batch_size: Optional[list] = None - ) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size: list | None = None) -> TensorDict: device = td.device locs = torch.cat([td["depot"][..., None, :], td["locs"]], dim=-2) expected_prize = td["deterministic_prize"] - real_prize = ( - td["stochastic_prize"] if self.stochastic else td["deterministic_prize"] - ) + real_prize = td["stochastic_prize"] if self.stochastic else td["deterministic_prize"] penalty = td["penalty"] # Concatenate depots @@ -124,9 +114,7 @@ def _reset( (*batch_size, self.generator.num_loc + 1), dtype=torch.bool, device=device ) i = torch.zeros((*batch_size,), dtype=torch.int64, device=device) - prize_required = torch.full( - (*batch_size,), self.generator.prize_required, device=device - ) + prize_required = torch.full((*batch_size,), self.generator.prize_required, device=device) td_reset = TensorDict( { @@ -186,8 +174,7 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor) -> None: # Make sure each node visited once at most (except for depot) assert ( - (sorted_actions[..., 1:] == 0) - | (sorted_actions[..., 1:] > sorted_actions[..., :-1]) + (sorted_actions[..., 1:] == 0) | (sorted_actions[..., 1:] > sorted_actions[..., :-1]) ).all(), "Duplicates" prize = td["real_prize"][..., 1:] # Remove depot @@ -210,9 +197,7 @@ def stochastic(self): @stochastic.setter def stochastic(self, state: bool): if state is True: - log.warning( - "Stochastic mode should not be used for PCTSP. Use SPCTSP instead." - ) + log.warning("Stochastic mode should not be used for PCTSP. Use SPCTSP instead.") def _make_spec(self, generator): """Make the locs and action specs from the parameters.""" diff --git a/rl4co/envs/routing/pctsp/generator.py b/rl4co/envs/routing/pctsp/generator.py index 246318e3..0144dfa9 100644 --- a/rl4co/envs/routing/pctsp/generator.py +++ b/rl4co/envs/routing/pctsp/generator.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable from tensordict.tensordict import TensorDict from torch.distributions import Uniform @@ -55,9 +55,7 @@ def __init__( if kwargs.get("loc_sampler", None) is not None: self.loc_sampler = kwargs["loc_sampler"] else: - self.loc_sampler = get_sampler( - "loc", loc_distribution, min_loc, max_loc, **kwargs - ) + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) # Depot distribution if kwargs.get("depot_sampler", None) is not None: @@ -99,9 +97,7 @@ def __init__( # Adjust as in Kool et al. (2019) self.max_penalty *= penalty_factor / self.num_loc - self.penalty_sampler = get_sampler( - "penalty", "uniform", 0.0, self.max_penalty, **kwargs - ) + self.penalty_sampler = get_sampler("penalty", "uniform", 0.0, self.max_penalty, **kwargs) def _generate(self, batch_size) -> TensorDict: # Sample locations: depot and customers @@ -121,17 +117,14 @@ def _generate(self, batch_size) -> TensorDict: # Now expectation is 0.5 so expected total prize is n / 2, we want to force to visit approximately half of the nodes # so the constraint will be that total prize >= (n / 2) / 2 = n / 4 # equivalently, we divide all prizes by n / 4 and the total prize should be >= 1 - deterministic_prize = self.deterministic_prize_sampler.sample( - (*batch_size, self.num_loc) - ) + deterministic_prize = self.deterministic_prize_sampler.sample((*batch_size, self.num_loc)) # In the deterministic setting, the stochastic_prize is not used and the deterministic prize is known # In the stochastic setting, the deterministic prize is the expected prize and is known up front but the # stochastic prize is only revealed once the node is visited # Stochastic prize is between (0, 2 * expected_prize) such that E(stochastic prize) = E(deterministic_prize) stochastic_prize = ( - self.stochastic_prize_sampler.sample((*batch_size, self.num_loc)) - * deterministic_prize + self.stochastic_prize_sampler.sample((*batch_size, self.num_loc)) * deterministic_prize ) return TensorDict( diff --git a/rl4co/envs/routing/pctsp/render.py b/rl4co/envs/routing/pctsp/render.py index 71242c3e..0fe607a5 100644 --- a/rl4co/envs/routing/pctsp/render.py +++ b/rl4co/envs/routing/pctsp/render.py @@ -32,9 +32,7 @@ def render(td, actions=None, ax=None): 200 * (prizes - torch.min(prizes)) / (torch.max(prizes) - torch.min(prizes)) + 10 ) normalized_penalties = ( - 3 - * (penalties - torch.min(penalties)) - / (torch.max(penalties) - torch.min(penalties)) + 3 * (penalties - torch.min(penalties)) / (torch.max(penalties) - torch.min(penalties)) ) # Represent penalty with colormap and size of edges diff --git a/rl4co/envs/routing/pdp/env.py b/rl4co/envs/routing/pdp/env.py index 5996415f..86c9bbb1 100644 --- a/rl4co/envs/routing/pdp/env.py +++ b/rl4co/envs/routing/pdp/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from tensordict.tensordict import TensorDict @@ -73,13 +71,9 @@ def _step(td: TensorDict) -> TensorDict: new_to_deliver = (current_node + num_loc // 2) % (num_loc + 1) # Set available to 0 (i.e., we visited the node) - available = td["available"].scatter( - -1, current_node.expand_as(td["action_mask"]), 0 - ) + available = td["available"].scatter(-1, current_node.expand_as(td["action_mask"]), 0) - to_deliver = td["to_deliver"].scatter( - -1, new_to_deliver.expand_as(td["to_deliver"]), 1 - ) + to_deliver = td["to_deliver"].scatter(-1, new_to_deliver.expand_as(td["to_deliver"]), 1) # Action is feasible if the node is not visited and is to deliver # action_mask = torch.logical_and(available, to_deliver) @@ -90,7 +84,7 @@ def _step(td: TensorDict) -> TensorDict: # The reward is calculated outside via get_reward for efficiency, so we set it to 0 here reward = torch.zeros_like(done) - + # Update step td.update( { @@ -105,7 +99,7 @@ def _step(td: TensorDict) -> TensorDict: ) return td - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size=None) -> TensorDict: device = td.device locs = torch.cat((td["depot"][:, None, :], td["locs"]), -2) @@ -128,16 +122,16 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict ) # Masking variables - available = torch.ones( - (*batch_size, self.generator.num_loc + 1), dtype=torch.bool - ).to(device) - action_mask = torch.ones_like(available) # [batch_size, graph_size+1] + available = torch.ones((*batch_size, self.generator.num_loc + 1), dtype=torch.bool).to( + device + ) + action_mask = torch.ones_like(available) # [batch_size, graph_size+1] if self.force_start_at_depot: - action_mask[..., 1:] = False # can only visit the depot at the first step + action_mask[..., 1:] = False # can only visit the depot at the first step else: action_mask = action_mask & to_deliver - available[..., 0] = False # depot is already visited (during reward calculation) - action_mask[..., 0] = False # depot is not available to visit + available[..., 0] = False # depot is already visited (during reward calculation) + action_mask[..., 0] = False # depot is not available to visit # Other variables current_node = torch.zeros((*batch_size, 1), dtype=torch.int64).to(device) @@ -208,18 +202,16 @@ def check_solution_validity(self, td, actions): actions = torch.cat((torch.zeros_like(actions[:, 0:1]), actions), dim=-1) assert ( - (torch.arange(actions.size(1), out=actions.data.new())) - .view(1, -1) - .expand_as(actions) + (torch.arange(actions.size(1), out=actions.data.new())).view(1, -1).expand_as(actions) == actions.data.sort(1)[0] ).all(), "Not visiting all nodes" - + # make sure we don't go back to the depot in the middle of the tour - assert (actions[:, 1:-1] != 0).all(), "Going back to depot in the middle of the tour (not allowed)" + assert (actions[:, 1:-1] != 0).all(), ( + "Going back to depot in the middle of the tour (not allowed)" + ) - visited_time = torch.argsort( - actions, 1 - ) # index of pickup less than index of delivery + visited_time = torch.argsort(actions, 1) # index of pickup less than index of delivery assert ( visited_time[:, 1 : actions.size(1) // 2 + 1] < visited_time[:, actions.size(1) // 2 + 1 :] @@ -352,7 +344,7 @@ def _step(self, td: TensorDict, solution_to=None) -> TensorDict: return td - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size=None) -> TensorDict: device = td.device locs = torch.cat((td["depot"][:, None, :], td["locs"]), -2) @@ -363,7 +355,7 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict bs = batch_size[0] seq_length = self.generator.num_loc + 1 visited_time = torch.zeros((bs, seq_length)).to(device) - pre = torch.zeros((bs)).to(device).long() + pre = torch.zeros(bs).to(device).long() arange = torch.arange(bs) for i in range(seq_length): current_nodes = current_rec[arange, pre] @@ -483,9 +475,7 @@ def check_solution_validity(self, td, actions=None): batch_size, graph_size = solution.size() assert ( - torch.arange(graph_size, out=solution.data.new()) - .view(1, -1) - .expand_as(solution) + torch.arange(graph_size, out=solution.data.new()).view(1, -1).expand_as(solution) == solution.data.sort(1)[0] ).all(), "Not visiting all nodes" @@ -497,8 +487,7 @@ def check_solution_validity(self, td, actions=None): pre = solution[arange, pre] assert ( - visited_time[:, 1 : graph_size // 2 + 1] - < visited_time[:, graph_size // 2 + 1 :] + visited_time[:, 1 : graph_size // 2 + 1] < visited_time[:, graph_size // 2 + 1 :] ).all(), "Deliverying without pick-up" @staticmethod @@ -524,17 +513,13 @@ def get_mask(selected_node, td): @classmethod def _random_action(cls, td): batch_size, graph_size = td["rec_best"].size() - selected_node = ( - (torch.rand(batch_size, 1) * graph_size // 2) % (graph_size // 2) - ).long() + selected_node = ((torch.rand(batch_size, 1) * graph_size // 2) % (graph_size // 2)).long() mask = cls.get_mask(selected_node + 1, td) logits = torch.rand(batch_size, graph_size, graph_size) logits[~mask] = -1e20 prob = torch.softmax(logits.view(batch_size, -1), -1) sample = prob.multinomial(1) - action = torch.cat( - (selected_node, sample // (graph_size), sample % (graph_size)), -1 - ) + action = torch.cat((selected_node, sample // (graph_size), sample % (graph_size)), -1) td["action"] = action return action diff --git a/rl4co/envs/routing/pdp/generator.py b/rl4co/envs/routing/pdp/generator.py index f1cbf79c..fec320d8 100644 --- a/rl4co/envs/routing/pdp/generator.py +++ b/rl4co/envs/routing/pdp/generator.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import torch @@ -47,18 +47,14 @@ def __init__( # Number of locations must be even if num_loc % 2 != 0: - log.warning( - "Number of locations must be even. Adding 1 to the number of locations." - ) + log.warning("Number of locations must be even. Adding 1 to the number of locations.") self.num_loc += 1 # Location distribution if kwargs.get("loc_sampler", None) is not None: self.loc_sampler = kwargs["loc_sampler"] else: - self.loc_sampler = get_sampler( - "loc", loc_distribution, min_loc, max_loc, **kwargs - ) + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) # Depot distribution if kwargs.get("depot_sampler", None) is not None: @@ -110,8 +106,7 @@ def _get_initial_solutions(self, coordinates): add_index = (next_selected_node <= order_size).view(-1) pairing = ( - next_selected_node[next_selected_node <= order_size].view(-1, 1) - + order_size + next_selected_node[next_selected_node <= order_size].view(-1, 1) + order_size ) candidates[add_index] = candidates[add_index].scatter_(1, pairing, 1) @@ -139,8 +134,7 @@ def _get_initial_solutions(self, coordinates): add_index = (next_selected_node <= order_size).view(-1) pairing = ( - next_selected_node[next_selected_node <= order_size].view(-1, 1) - + order_size + next_selected_node[next_selected_node <= order_size].view(-1, 1) + order_size ) candidates[add_index] = candidates[add_index].scatter_(1, pairing, 1) diff --git a/rl4co/envs/routing/pdp/render.py b/rl4co/envs/routing/pdp/render.py index 8217a964..74a8761a 100644 --- a/rl4co/envs/routing/pdp/render.py +++ b/rl4co/envs/routing/pdp/render.py @@ -29,9 +29,7 @@ def render(td, actions=None, ax=None): # Plot the actions in order for i in range(len(actions)): from_node = actions[i] - to_node = ( - actions[i + 1] if i < len(actions) - 1 else actions[0] - ) # last goes back to depot + to_node = actions[i + 1] if i < len(actions) - 1 else actions[0] # last goes back to depot from_loc = td["locs"][from_node] to_loc = td["locs"][to_node] ax.plot([from_loc[0], to_loc[0]], [from_loc[1], to_loc[1]], "k-") @@ -88,16 +86,10 @@ def render_improvement(td, current_soltuion, best_soltuion): if ax == ax1: ax.axis([-0.05, 1.05] * 2) # plot the nodes - ax.scatter( - coordinates[:, 0], coordinates[:, 1], marker="H", s=55, c="blue", zorder=2 - ) + ax.scatter(coordinates[:, 0], coordinates[:, 1], marker="H", s=55, c="blue", zorder=2) # plot the tour - real_seq_coordinates = coordinates.gather( - 0, real_seq[0].unsqueeze(1).repeat(1, 2) - ) - real_seq_coordinates = torch.cat( - (real_seq_coordinates, real_seq_coordinates[:1]), 0 - ) + real_seq_coordinates = coordinates.gather(0, real_seq[0].unsqueeze(1).repeat(1, 2)) + real_seq_coordinates = torch.cat((real_seq_coordinates, real_seq_coordinates[:1]), 0) ax.plot( real_seq_coordinates[:, 0], real_seq_coordinates[:, 1], @@ -114,16 +106,10 @@ def render_improvement(td, current_soltuion, best_soltuion): else: ax.axis([-0.05, 1.05] * 2) # plot the nodes - ax.scatter( - coordinates[:, 0], coordinates[:, 1], marker="H", s=55, c="blue", zorder=2 - ) + ax.scatter(coordinates[:, 0], coordinates[:, 1], marker="H", s=55, c="blue", zorder=2) # plot the tour - real_best_coordinates = coordinates.gather( - 0, real_best[0].unsqueeze(1).repeat(1, 2) - ) - real_best_coordinates = torch.cat( - (real_best_coordinates, real_best_coordinates[:1]), 0 - ) + real_best_coordinates = coordinates.gather(0, real_best[0].unsqueeze(1).repeat(1, 2)) + real_best_coordinates = torch.cat((real_best_coordinates, real_best_coordinates[:1]), 0) ax.plot( real_best_coordinates[:, 0], real_best_coordinates[:, 1], diff --git a/rl4co/envs/routing/sdvrp/env.py b/rl4co/envs/routing/sdvrp/env.py index 916a7a9f..374181f9 100644 --- a/rl4co/envs/routing/sdvrp/env.py +++ b/rl4co/envs/routing/sdvrp/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from tensordict.tensordict import TensorDict @@ -63,19 +61,13 @@ def _step(self, td: TensorDict) -> TensorDict: selected_demand = gather_by_index( td["demand_with_depot"], current_node, dim=-1, squeeze=False )[..., :1] - delivered_demand = torch.min( - selected_demand, td["vehicle_capacity"] - td["used_capacity"] - ) + delivered_demand = torch.min(selected_demand, td["vehicle_capacity"] - td["used_capacity"]) # Increase capacity if depot is not visited, otherwise set to 0 - used_capacity = (td["used_capacity"] + delivered_demand) * ( - current_node != 0 - ).float() + used_capacity = (td["used_capacity"] + delivered_demand) * (current_node != 0).float() # Update demand - demand_with_depot = td["demand_with_depot"].scatter_add( - -1, current_node, -delivered_demand - ) + demand_with_depot = td["demand_with_depot"].scatter_add(-1, current_node, -delivered_demand) # Get done done = ~(demand_with_depot > 0).any(-1) @@ -98,8 +90,8 @@ def _step(self, td: TensorDict) -> TensorDict: def _reset( self, - td: Optional[TensorDict] = None, - batch_size: Optional[list] = None, + td: TensorDict | None = None, + batch_size: list | None = None, ) -> TensorDict: device = td.device @@ -111,9 +103,7 @@ def _reset( "demand_with_depot": torch.cat( (torch.zeros_like(td["demand"][..., 0:1]), td["demand"]), -1 ), - "current_node": torch.zeros( - *batch_size, 1, dtype=torch.long, device=device - ), + "current_node": torch.zeros(*batch_size, 1, dtype=torch.long, device=device), "used_capacity": torch.zeros((*batch_size, 1), device=device), "vehicle_capacity": torch.full( (*batch_size, 1), self.generator.vehicle_capacity, device=device @@ -129,9 +119,7 @@ def get_action_mask(td: TensorDict) -> torch.Tensor: mask_loc = (td["demand_with_depot"][..., 1:] == 0) | ( td["used_capacity"] >= td["vehicle_capacity"] ) - mask_depot = (td["current_node"] == 0).squeeze(-1) & ( - (mask_loc == 0).int().sum(-1) > 0 - ) + mask_depot = (td["current_node"] == 0).squeeze(-1) & ((mask_loc == 0).int().sum(-1) > 0) return ~torch.cat((mask_depot[..., None], mask_loc), -1) @staticmethod @@ -148,9 +136,9 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor) -> None: used_cap = torch.zeros_like(td["demand"][..., 0]) a_prev = None for a in actions.transpose(0, 1): - assert ( - a_prev is None or (demands[((a_prev == 0) & (a == 0)), :] == 0).all() - ), "Cannot visit depot twice if any nonzero demand" + assert a_prev is None or (demands[((a_prev == 0) & (a == 0)), :] == 0).all(), ( + "Cannot visit depot twice if any nonzero demand" + ) d = torch.min(demands[rng, a], td["vehicle_capacity"].squeeze(-1) - used_cap) demands[rng, a] -= d used_cap += d diff --git a/rl4co/envs/routing/shpp/env.py b/rl4co/envs/routing/shpp/env.py index 45363469..ddb20938 100644 --- a/rl4co/envs/routing/shpp/env.py +++ b/rl4co/envs/routing/shpp/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from tensordict.tensordict import TensorDict @@ -93,7 +91,7 @@ def _step(td: TensorDict) -> TensorDict: ) return td - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size=None) -> TensorDict: """Note: the first node is the starting node; the last node is the terminating node""" device = td.device locs = td["locs"] @@ -103,9 +101,7 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict # Other variables current_node = torch.zeros((batch_size), dtype=torch.int64, device=device) - last_node = torch.full( - (batch_size), num_loc - 1, dtype=torch.int64, device=device - ) + last_node = torch.full((batch_size), num_loc - 1, dtype=torch.int64, device=device) available = torch.ones( (*batch_size, num_loc), dtype=torch.bool, device=device ) # 1 means not visited, i.e. action is allowed @@ -136,9 +132,7 @@ def _get_reward(self, td, actions) -> TensorDict: def check_solution_validity(td: TensorDict, actions: torch.Tensor): """Check that solution is valid: nodes are visited exactly once""" assert ( - torch.arange(actions.size(1), out=actions.data.new()) - .view(1, -1) - .expand_as(actions) + torch.arange(actions.size(1), out=actions.data.new()).view(1, -1).expand_as(actions) == actions.data.sort(1)[0] ).all(), "Invalid tour" diff --git a/rl4co/envs/routing/shpp/generator.py b/rl4co/envs/routing/shpp/generator.py index aaa6ecb0..d23ebfe6 100644 --- a/rl4co/envs/routing/shpp/generator.py +++ b/rl4co/envs/routing/shpp/generator.py @@ -1,4 +1,4 @@ -from typing import Callable, Union +from collections.abc import Callable from tensordict.tensordict import TensorDict from torch.distributions import Uniform @@ -27,7 +27,7 @@ def __init__( num_loc: int = 20, min_loc: float = 0.0, max_loc: float = 1.0, - loc_distribution: Union[int, float, str, type, Callable] = Uniform, + loc_distribution: int | float | str | type | Callable = Uniform, **kwargs, ): self.num_loc = num_loc @@ -38,9 +38,7 @@ def __init__( if kwargs.get("loc_sampler", None) is not None: self.loc_sampler = kwargs["loc_sampler"] else: - self.loc_sampler = get_sampler( - "loc", loc_distribution, min_loc, max_loc, **kwargs - ) + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) def _generate(self, batch_size) -> TensorDict: # Sample locations diff --git a/rl4co/envs/routing/shpp/render.py b/rl4co/envs/routing/shpp/render.py index bc1edb54..aca77d00 100644 --- a/rl4co/envs/routing/shpp/render.py +++ b/rl4co/envs/routing/shpp/render.py @@ -9,7 +9,6 @@ def render(td, actions=None, ax=None): - if ax is None: # Create a plot of the nodes _, ax = plt.subplots(figsize=(3, 3)) diff --git a/rl4co/envs/routing/spctsp/env.py b/rl4co/envs/routing/spctsp/env.py index 4f99c070..186d0e94 100644 --- a/rl4co/envs/routing/spctsp/env.py +++ b/rl4co/envs/routing/spctsp/env.py @@ -26,6 +26,4 @@ def stochastic(self): @stochastic.setter def stochastic(self, state: bool): if state is False: - log.warning( - "Deterministic mode should not be used for SPCTSP. Use PCTSP instead." - ) + log.warning("Deterministic mode should not be used for SPCTSP. Use PCTSP instead.") diff --git a/rl4co/envs/routing/svrp/env.py b/rl4co/envs/routing/svrp/env.py index afcbae6a..084eb2de 100644 --- a/rl4co/envs/routing/svrp/env.py +++ b/rl4co/envs/routing/svrp/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from tensordict.tensordict import TensorDict @@ -107,9 +105,7 @@ def get_action_mask(td: TensorDict) -> torch.Tensor: current_tech_skill = gather_by_index(td["techs"], td["current_tech"]).reshape( [batch_size, 1] ) - can_service = td["skills"] <= current_tech_skill.unsqueeze(1).expand_as( - td["skills"] - ) + can_service = td["skills"] <= current_tech_skill.unsqueeze(1).expand_as(td["skills"]) mask_loc = td["visited"][..., 1:, :].to(can_service.dtype) | ~can_service # Cannot visit the depot if there are still unserved nodes and I either just visited the depot or am the last technician mask_depot = ( @@ -144,9 +140,7 @@ def _step(self, td: TensorDict) -> torch.Tensor: td.set("action_mask", self.get_action_mask(td)) return td - def _reset( - self, td: Optional[TensorDict] = None, batch_size: Optional[list] = None - ) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size: list | None = None) -> TensorDict: device = td.device # Create reset TensorDict @@ -155,12 +149,8 @@ def _reset( "locs": torch.cat((td["depot"][:, None, :], td["locs"]), -2), "techs": td["techs"], "skills": td["skills"], - "current_node": torch.zeros( - *batch_size, 1, dtype=torch.long, device=device - ), - "current_tech": torch.zeros( - *batch_size, 1, dtype=torch.long, device=device - ), + "current_node": torch.zeros(*batch_size, 1, dtype=torch.long, device=device), + "current_tech": torch.zeros(*batch_size, 1, dtype=torch.long, device=device), "visited": torch.zeros( (*batch_size, td["locs"].shape[-2] + 1, 1), dtype=torch.uint8, @@ -185,9 +175,7 @@ def _get_reward(self, td: TensorDict, actions: torch.Tensor) -> torch.Tensor: locs_ordered = torch.cat( [ depot, - gather_by_index(td["locs"], actions).reshape( - [batch_size, actions.size(-1), 2] - ), + gather_by_index(td["locs"], actions).reshape([batch_size, actions.size(-1), 2]), ], dim=1, ) @@ -230,19 +218,15 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor) -> None: # make sure all required skill levels are met indices = torch.nonzero(actions == 0) - skills = torch.cat( - [torch.zeros(batch_size, 1, 1, device=td.device), td["skills"]], 1 - ) - skills_ordered = gather_by_index(skills, actions).reshape( - [batch_size, actions.size(-1), 1] - ) + skills = torch.cat([torch.zeros(batch_size, 1, 1, device=td.device), td["skills"]], 1) + skills_ordered = gather_by_index(skills, actions).reshape([batch_size, actions.size(-1), 1]) batch = start = tech = 0 for each in indices: if each[0] > batch: start = tech = 0 batch = each[0] - assert ( - skills_ordered[batch, start : each[1]] <= td["techs"][batch, tech] - ).all(), "Skill level not met" + assert (skills_ordered[batch, start : each[1]] <= td["techs"][batch, tech]).all(), ( + "Skill level not met" + ) start = each[1] + 1 # skip the depot tech += 1 diff --git a/rl4co/envs/routing/svrp/generator.py b/rl4co/envs/routing/svrp/generator.py index efd480fd..1c4fdb6d 100644 --- a/rl4co/envs/routing/svrp/generator.py +++ b/rl4co/envs/routing/svrp/generator.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import torch @@ -55,9 +55,7 @@ def __init__( if kwargs.get("loc_sampler", None) is not None: self.loc_sampler = kwargs["loc_sampler"] else: - self.loc_sampler = get_sampler( - "loc", loc_distribution, min_loc, max_loc, **kwargs - ) + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) # Depot distribution if kwargs.get("depot_sampler", None) is not None: diff --git a/rl4co/envs/routing/tsp/env.py b/rl4co/envs/routing/tsp/env.py index 79913ecb..ecea37ee 100644 --- a/rl4co/envs/routing/tsp/env.py +++ b/rl4co/envs/routing/tsp/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from tensordict.tensordict import TensorDict @@ -87,7 +85,7 @@ def _step(td: TensorDict) -> TensorDict: ) return td - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size=None) -> TensorDict: # Initialize locations device = td.device init_locs = td["locs"] @@ -161,9 +159,7 @@ def _get_reward(self, td: TensorDict, actions: torch.Tensor) -> torch.Tensor: def check_solution_validity(td: TensorDict, actions: torch.Tensor) -> None: """Check that solution is valid: nodes are visited exactly once""" assert ( - torch.arange(actions.size(1), out=actions.data.new()) - .view(1, -1) - .expand_as(actions) + torch.arange(actions.size(1), out=actions.data.new()).view(1, -1).expand_as(actions) == actions.data.sort(1)[0] ).all(), "Invalid tour" @@ -186,9 +182,9 @@ def replace_selected_actions( @staticmethod def local_search(td: TensorDict, actions: torch.Tensor, **kwargs) -> torch.Tensor: - assert ( - local_search is not None - ), "Cannot import local_search module. Check if `numba` is installed." + assert local_search is not None, ( + "Cannot import local_search module. Check if `numba` is installed." + ) return local_search(td, actions, **kwargs) @staticmethod @@ -288,7 +284,7 @@ def _step(self, td: TensorDict, solution_to=None) -> TensorDict: return td - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size=None) -> TensorDict: device = td.device locs = td["locs"] @@ -300,7 +296,7 @@ def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict bs = batch_size[0] seq_length = self.generator.num_loc visited_time = torch.zeros((bs, seq_length)).to(device) - pre = torch.zeros((bs)).to(device).long() + pre = torch.zeros(bs).to(device).long() arange = torch.arange(bs) for i in range(seq_length): current_nodes = current_rec[arange, pre] @@ -346,9 +342,7 @@ def _local_operator(self, solution, action): cur = first for i in range(self.generator.num_loc): cur_next = solution.gather(1, cur) - rec.scatter_( - 1, cur_next, torch.where(cur != second, cur, rec.gather(1, cur_next)) - ) + rec.scatter_(1, cur_next, torch.where(cur != second, cur, rec.gather(1, cur_next))) cur = torch.where(cur != second, cur_next, cur) rec_next = rec @@ -438,9 +432,7 @@ def check_solution_validity(self, td, actions=None): batch_size, graph_size = solution.size() assert ( - torch.arange(graph_size, out=solution.data.new()) - .view(1, -1) - .expand_as(solution) + torch.arange(graph_size, out=solution.data.new()).view(1, -1).expand_as(solution) == solution.data.sort(1)[0] ).all(), "Not visiting all nodes" @@ -487,9 +479,7 @@ def _random_action(self, td): 1 - value_max.view(-1, 1) < 1e-5, action_max.view(-1, 1), action ) ### fix bug of pytorch if i > 0: - action = torch.where( - stopped.unsqueeze(-1), action_index[:, :1], action - ) + action = torch.where(stopped.unsqueeze(-1), action_index[:, :1], action) # Store and Process actions next_of_new_action = rec.gather(1, action) @@ -508,9 +498,7 @@ def _random_action(self, td): # Calc next basic masks if i == 0: - visited_time_tag = ( - visited_time - visited_time.gather(1, action) - ) % gs + visited_time_tag = (visited_time - visited_time.gather(1, action)) % gs mask &= False mask[(visited_time_tag <= visited_time_tag.gather(1, action))] = True if i == 0: @@ -522,9 +510,7 @@ def _random_action(self, td): index_allow_first_node = (~stopped) & ( next_of_new_action.squeeze() == action_index[:, 0] ) - mask[index_allow_first_node, action_index[index_allow_first_node, 0]] = ( - False - ) + mask[index_allow_first_node, action_index[index_allow_first_node, 0]] = False # Move to next next_of_last_action = next_of_new_action @@ -551,9 +537,7 @@ class DenseRewardTSPEnv(TSPEnv): to the current tour by the given action. """ - def __init__( - self, generator: TSPGenerator = None, generator_params: dict = {}, **kwargs - ): + def __init__(self, generator: TSPGenerator = None, generator_params: dict = {}, **kwargs): super().__init__( generator, generator_params, diff --git a/rl4co/envs/routing/tsp/generator.py b/rl4co/envs/routing/tsp/generator.py index cb34fb7c..8d5a56c5 100644 --- a/rl4co/envs/routing/tsp/generator.py +++ b/rl4co/envs/routing/tsp/generator.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import torch @@ -44,9 +44,7 @@ def __init__( if kwargs.get("loc_sampler", None) is not None: self.loc_sampler = kwargs["loc_sampler"] else: - self.loc_sampler = get_sampler( - "loc", loc_distribution, min_loc, max_loc, **kwargs - ) + self.loc_sampler = get_sampler("loc", loc_distribution, min_loc, max_loc, **kwargs) def _generate(self, batch_size) -> TensorDict: # Sample locations diff --git a/rl4co/envs/routing/tsp/local_search.py b/rl4co/envs/routing/tsp/local_search.py index 78bcf43a..42ea3e3d 100644 --- a/rl4co/envs/routing/tsp/local_search.py +++ b/rl4co/envs/routing/tsp/local_search.py @@ -1,9 +1,10 @@ import os -import numpy as np import numba as nb -from numba import set_num_threads +import numpy as np import torch + +from numba import set_num_threads from tensordict.tensordict import TensorDict from rl4co.utils.ops import get_distance_matrix @@ -42,31 +43,34 @@ def local_search( return torch.from_numpy(numba_results.astype(np.int64)).to(actions.device) -@nb.njit(nb.float32(nb.float32[:,:], nb.uint16[:], nb.uint16), nogil=True) -def two_opt_once(distmat, tour, fixed_i = 0): - '''in-place operation''' +@nb.njit(nb.float32(nb.float32[:, :], nb.uint16[:], nb.uint16), nogil=True) +def two_opt_once(distmat, tour, fixed_i=0): + """in-place operation""" n = tour.shape[0] p = q = 0 delta = 0 - for i in range(1, n - 1) if fixed_i==0 else range(fixed_i, fixed_i + 1): + for i in range(1, n - 1) if fixed_i == 0 else range(fixed_i, fixed_i + 1): for j in range(i + 1, n): node_i, node_j = tour[i], tour[j] node_prev, node_next = tour[i - 1], tour[(j + 1) % n] if node_prev == node_j or node_next == node_i: continue change = ( - distmat[node_prev, node_j] + distmat[node_i, node_next] - - distmat[node_prev, node_i] - distmat[node_j, node_next] + distmat[node_prev, node_j] + + distmat[node_i, node_next] + - distmat[node_prev, node_i] + - distmat[node_j, node_next] ) if change < delta: p, q, delta = i, j, change if delta < -1e-6: - tour[p: q + 1] = np.flip(tour[p: q + 1]) + tour[p : q + 1] = np.flip(tour[p : q + 1]) return delta else: return 0.0 -@nb.njit(nb.uint16[:,:](nb.float32[:,:,:], nb.uint16[:,:], nb.int64), nogil=True, parallel=True) + +@nb.njit(nb.uint16[:, :](nb.float32[:, :, :], nb.uint16[:, :], nb.int64), nogil=True, parallel=True) def _two_opt_python(distmat, tour, max_iterations=1000): for i in nb.prange(tour.shape[0]): iterations = 0 diff --git a/rl4co/envs/routing/tsp/render.py b/rl4co/envs/routing/tsp/render.py index dde4c86f..624ce968 100644 --- a/rl4co/envs/routing/tsp/render.py +++ b/rl4co/envs/routing/tsp/render.py @@ -56,16 +56,10 @@ def render_improvement(td, current_soltuion, best_soltuion): if ax == ax1: ax.axis([-0.05, 1.05] * 2) # plot the nodes - ax.scatter( - coordinates[:, 0], coordinates[:, 1], marker="H", s=55, c="blue", zorder=2 - ) + ax.scatter(coordinates[:, 0], coordinates[:, 1], marker="H", s=55, c="blue", zorder=2) # plot the tour - real_seq_coordinates = coordinates.gather( - 0, real_seq[0].unsqueeze(1).repeat(1, 2) - ) - real_seq_coordinates = torch.cat( - (real_seq_coordinates, real_seq_coordinates[:1]), 0 - ) + real_seq_coordinates = coordinates.gather(0, real_seq[0].unsqueeze(1).repeat(1, 2)) + real_seq_coordinates = torch.cat((real_seq_coordinates, real_seq_coordinates[:1]), 0) ax.plot( real_seq_coordinates[:, 0], real_seq_coordinates[:, 1], @@ -82,16 +76,10 @@ def render_improvement(td, current_soltuion, best_soltuion): else: ax.axis([-0.05, 1.05] * 2) # plot the nodes - ax.scatter( - coordinates[:, 0], coordinates[:, 1], marker="H", s=55, c="blue", zorder=2 - ) + ax.scatter(coordinates[:, 0], coordinates[:, 1], marker="H", s=55, c="blue", zorder=2) # plot the tour - real_best_coordinates = coordinates.gather( - 0, real_best[0].unsqueeze(1).repeat(1, 2) - ) - real_best_coordinates = torch.cat( - (real_best_coordinates, real_best_coordinates[:1]), 0 - ) + real_best_coordinates = coordinates.gather(0, real_best[0].unsqueeze(1).repeat(1, 2)) + real_best_coordinates = torch.cat((real_best_coordinates, real_best_coordinates[:1]), 0) ax.plot( real_best_coordinates[:, 0], real_best_coordinates[:, 1], diff --git a/rl4co/envs/scheduling/ffsp/env.py b/rl4co/envs/scheduling/ffsp/env.py index 26191053..ddb293a5 100644 --- a/rl4co/envs/scheduling/ffsp/env.py +++ b/rl4co/envs/scheduling/ffsp/env.py @@ -1,7 +1,6 @@ import itertools from math import factorial -from typing import Optional import torch @@ -242,9 +241,7 @@ def _step(self, td: TensorDict) -> TensorDict: return td - def _reset( - self, td: Optional[TensorDict] = None, batch_size: Optional[list] = None - ) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size: list | None = None) -> TensorDict: """ Args: @@ -318,9 +315,7 @@ def _reset( fill_value=False, ) - action_mask = torch.ones( - size=(*batch_size, self.num_job + 1), dtype=bool, device=device - ) + action_mask = torch.ones(size=(*batch_size, self.num_job + 1), dtype=bool, device=device) action_mask[..., -1] = 0 batch_idx = torch.arange(*batch_size, dtype=torch.long, device=td.device) diff --git a/rl4co/envs/scheduling/fjsp/env.py b/rl4co/envs/scheduling/fjsp/env.py index 060991b7..3b026d62 100644 --- a/rl4co/envs/scheduling/fjsp/env.py +++ b/rl4co/envs/scheduling/fjsp/env.py @@ -118,9 +118,7 @@ def _decode_graph_structure(self, td: TensorDict): # generate for each batch a sequence specifying the position of all operations in their respective jobs, # e.g. [0,1,0,0,1,2,0,1,2,3,0,0] for jops with n_ops=[2,1,3,4,1,1] # (bs, max_ops) - ops_seq_order = torch.sum( - ops_job_bin_map * (ops_job_bin_map.cumsum(2) - 1), dim=1 - ) + ops_seq_order = torch.sum(ops_job_bin_map * (ops_job_bin_map.cumsum(2) - 1), dim=1) # predecessor and successor adjacency matrices pred = torch.diag_embed(torch.ones(n_ops_max - 1), offset=-1)[None].expand( @@ -203,9 +201,7 @@ def _get_job_machine_availability(self, td: TensorDict): batch_size = td.size(0) # (bs, jobs, machines) - action_mask = torch.full((batch_size, self.num_jobs, self.num_mas), False).to( - td.device - ) + action_mask = torch.full((batch_size, self.num_jobs, self.num_mas), False).to(td.device) # mask jobs that are done already action_mask.add_(td["job_done"].unsqueeze(2)) @@ -230,9 +226,7 @@ def get_action_mask(self, td: TensorDict) -> torch.Tensor: no_op_mask = td["done"] else: # if no job is currently processed and instance is not finished yet, waiting is not allowed - no_op_mask = ( - td["job_in_process"].any(1, keepdims=True) & (~td["done"]) - ) | td["done"] + no_op_mask = (td["job_in_process"].any(1, keepdims=True) & (~td["done"])) | td["done"] # flatten action mask to correspond with logit shape action_mask = rearrange(action_mask, "bs j m -> bs (j m)") # NOTE: 1 means feasible action, 0 means infeasible action @@ -370,9 +364,7 @@ def _transit_to_next_time(self, step_complete, td: TensorDict) -> TensorDict: # we want to transition to the next time step where a machine becomes idle again. This time step must be # in the future, therefore we mask all machine idle times lying in the past / present available_time = ( - torch.where( - available_time_ma > td["time"][:, None], available_time_ma, torch.inf - ) + torch.where(available_time_ma > td["time"][:, None], available_time_ma, torch.inf) .min(1) .values ) @@ -403,12 +395,10 @@ def _get_reward(self, td, actions=None) -> TensorDict: if self.stepwise_reward and actions is None: return td["reward"] else: - assert td[ - "done" - ].all(), "Set stepwise_reward to True if you want reward prior to completion" - return ( - -td["finish_times"].masked_fill(td["pad_mask"], -torch.inf).max(1).values + assert td["done"].all(), ( + "Set stepwise_reward to True if you want reward prior to completion" ) + return -td["finish_times"].masked_fill(td["pad_mask"], -torch.inf).max(1).values def _make_spec(self, generator: FJSPGenerator): self.observation_spec = Composite( diff --git a/rl4co/envs/scheduling/fjsp/generator.py b/rl4co/envs/scheduling/fjsp/generator.py index ca3a47a9..04428c66 100644 --- a/rl4co/envs/scheduling/fjsp/generator.py +++ b/rl4co/envs/scheduling/fjsp/generator.py @@ -62,9 +62,7 @@ def __init__( if len(unused_kwargs) > 0: log.error(f"Found {len(unused_kwargs)} unused kwargs: {unused_kwargs}") - def _simulate_processing_times( - self, n_eligible_per_ops: torch.Tensor - ) -> torch.Tensor: + def _simulate_processing_times(self, n_eligible_per_ops: torch.Tensor) -> torch.Tensor: bs, n_ops_max = n_eligible_per_ops.shape # (bs, max_ops, machines) @@ -102,8 +100,7 @@ def _simulate_processing_times( + 1 ) proc_times = ( - torch.randint(2**63 - 1, size=proc_times.shape) - % (high_bounds - low_bounds) + torch.randint(2**63 - 1, size=proc_times.shape) % (high_bounds - low_bounds) + low_bounds ) else: @@ -229,9 +226,7 @@ def list_files(path): import os files = [ - os.path.join(path, f) - for f in os.listdir(path) - if os.path.isfile(os.path.join(path, f)) + os.path.join(path, f) for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) ] assert len(files) > 0 return files diff --git a/rl4co/envs/scheduling/fjsp/parser.py b/rl4co/envs/scheduling/fjsp/parser.py index 21f55738..130acda5 100644 --- a/rl4co/envs/scheduling/fjsp/parser.py +++ b/rl4co/envs/scheduling/fjsp/parser.py @@ -2,27 +2,24 @@ from functools import partial from pathlib import Path -from typing import Tuple import torch from tensordict import TensorDict -ProcessingData = list[Tuple[int, int]] +ProcessingData = list[tuple[int, int]] def list_files(path): import os files = [ - os.path.join(path, f) - for f in os.listdir(path) - if os.path.isfile(os.path.join(path, f)) + os.path.join(path, f) for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) ] return files -def parse_job_line(line: Tuple[int]) -> Tuple[ProcessingData]: +def parse_job_line(line: tuple[int]) -> tuple[ProcessingData]: """ Parses a FJSPLIB job data line of the following form: @@ -119,7 +116,7 @@ def read(loc: Path, max_ops=None): def file2lines(loc: Path | str) -> list[list[int]]: - with open(loc, "r") as fh: + with open(loc) as fh: lines = [line for line in fh.readlines() if line.strip()] def parse_num(word: str): @@ -130,9 +127,7 @@ def parse_num(word: str): def write_one(args, where=None): id, instance = args - assert ( - len(instance["proc_times"].shape) == 2 - ), "no batch dimension allowed in write operation" + assert len(instance["proc_times"].shape) == 2, "no batch dimension allowed in write operation" lines = [] # The flexibility is the average number of eligible machines per operation. @@ -164,7 +159,7 @@ def write_one(args, where=None): formatted = "\n".join(lines) - file_name = f"{str(id+1).rjust(4, '0')}_{num_jobs}j_{num_machines}m.txt" + file_name = f"{str(id + 1).rjust(4, '0')}_{num_jobs}j_{num_machines}m.txt" full_path = os.path.join(where, file_name) with open(full_path, "w") as fh: diff --git a/rl4co/envs/scheduling/fjsp/render.py b/rl4co/envs/scheduling/fjsp/render.py index 6f5e27c8..21d9f4aa 100644 --- a/rl4co/envs/scheduling/fjsp/render.py +++ b/rl4co/envs/scheduling/fjsp/render.py @@ -48,9 +48,7 @@ def render(td: TensorDict, idx: int): linewidth=1, ) - ax.text( - start + (end - start) / 2, ma, op, ha="center", va="center", color="white" - ) + ax.text(start + (end - start) / 2, ma, op, ha="center", va="center", color="white") # Set labels and title ax.set_yticks(range(len(schedule))) diff --git a/rl4co/envs/scheduling/fjsp/utils.py b/rl4co/envs/scheduling/fjsp/utils.py index 0865eb3e..7eaf97ac 100644 --- a/rl4co/envs/scheduling/fjsp/utils.py +++ b/rl4co/envs/scheduling/fjsp/utils.py @@ -1,7 +1,5 @@ import logging -from typing import Tuple - import torch from tensordict import TensorDict @@ -16,9 +14,7 @@ def get_op_features(td: TensorDict): return torch.stack((td["lbs"], td["is_ready"], td["num_eligible"]), dim=-1) -def cat_and_norm_features( - td: TensorDict, feats: list[str], time_feats: list[str], norm_const: int -): +def cat_and_norm_features(td: TensorDict, feats: list[str], time_feats: list[str], norm_const: int): # logger.info(f"will scale the features {','.join(time_feats)} with a constant ({norm_const})") feature_list = [] for feat in feats: @@ -32,7 +28,7 @@ def cat_and_norm_features( def view( tensor: Tensor, - idx: Tuple[Tensor], + idx: tuple[Tensor], pad_mask: Tensor, new_shape: Size | list[int], pad_value: float | int, @@ -93,9 +89,7 @@ def get_job_op_view(td: TensorDict, keys: list[str] = [], pad_value: float | int if "pad_mask" not in keys: keys.append("pad_mask") - new_views = dict( - map(lambda key: (key, view(td[key], idx, pad_mask, new_shape)), keys) - ) + new_views = dict(map(lambda key: (key, view(td[key], idx, pad_mask, new_shape)), keys)) # update tensordict clone with reshaped tensors return {"proc_times": new_proc_times_view, **new_views} @@ -119,13 +113,13 @@ def blockify(td, tensor: Tensor, pad_value: float | int = 0): return new_view_tensor -def unblockify( - td: TensorDict, tensor: Tensor, mask: Tensor = None, pad_value: float | int = 0 -): +def unblockify(td: TensorDict, tensor: Tensor, mask: Tensor = None, pad_value: float | int = 0): assert len(tensor.shape) in [ 3, 4, - ], "blockify only supports tensors of shape (bs, nb, s, (d)), where the feature dim d is optional" + ], ( + "blockify only supports tensors of shape (bs, nb, s, (d)), where the feature dim d is optional" + ) # get the size of the blockified tensor bs, _, _, *d = tensor.shape n_ops_per_batch = td["job_ops_adj"].sum((1, 2)).unsqueeze(1) # (bs) @@ -176,9 +170,9 @@ def spatial_encoding(td: TensorDict): same_job[pad_mask.unsqueeze(2).expand_as(same_job)] = 0 same_job[pad_mask.unsqueeze(1).expand_as(same_job)] = 0 # take upper triangular of same_job and set diagonal to zero for counting purposes - upper_tri = torch.triu(same_job) - torch.diag( - torch.ones(n_total_ops, device=td.device) - )[None].expand_as(same_job) + upper_tri = torch.triu(same_job) - torch.diag(torch.ones(n_total_ops, device=td.device))[ + None + ].expand_as(same_job) # cumsum and masking of operations that do not belong to the same job num_jumps = upper_tri.cumsum(2) * upper_tri # mirror the matrix @@ -228,9 +222,7 @@ def calc_lower_bound(td: TensorDict): # using the start_time, we can determine if and how long an op needs to wait for a machine to finish wait_for_ma_offset = torch.clip(busy_until[..., None] - maybe_start_at[:, None], 0) # we add this required waiting time to the respective processing time - proc_time_plus_wait = torch.where( - proc_times == 0, proc_times, proc_times + wait_for_ma_offset - ) + proc_time_plus_wait = torch.where(proc_times == 0, proc_times, proc_times + wait_for_ma_offset) # NOTE get the mean processing time over all eligible machines for lb calulation # ops_proc_times = torch.where(proc_times == 0, torch.inf, proc_time_plus_wait).min(1).values) ops_proc_times = proc_time_plus_wait.sum(1) / (proc_times.gt(0).sum(1) + 1e-9) @@ -242,17 +234,13 @@ def calc_lower_bound(td: TensorDict): # sum over the processing time to determine the lower bound of unscheduled operations... proc_matrix = job_ops_adj ops_assigned = proc_matrix * op_scheduled[:, None] - proc_matrix_not_scheduled = proc_matrix * ( - torch.ones_like(proc_matrix) - op_scheduled[:, None] - ) + proc_matrix_not_scheduled = proc_matrix * (torch.ones_like(proc_matrix) - op_scheduled[:, None]) # ...and add the finish_time of the last scheduled operation of the respective job to that. To make this work, using the cumsum logic, # we calc the first differences of the finish times and seperate by job. # We use the first differences, so that the finish times do not add up during cumulative sum below # (bs, num_jobs, num_ops) - finish_times_1st_diff = ops_assigned * first_diff( - ops_assigned * finish_times[:, None], 2 - ) + finish_times_1st_diff = ops_assigned * first_diff(ops_assigned * finish_times[:, None], 2) # masking the processing time of scheduled operations and add their finish times instead (first diff thereof) lb_end_expand = ( @@ -265,9 +253,7 @@ def calc_lower_bound(td: TensorDict): LBs = torch.nan_to_num(LBs, nan=0.0) # test - assert torch.where( - finish_times != INIT_FINISH, torch.isclose(LBs, finish_times), True - ).all() + assert torch.where(finish_times != INIT_FINISH, torch.isclose(LBs, finish_times), True).all() return LBs @@ -286,7 +272,7 @@ def op_is_ready(td: TensorDict): def get_job_ops_mapping( start_op_per_job: torch.Tensor, end_op_per_job: torch.Tensor, n_ops_max: int -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: """Implements a mapping function from operations to jobs :param torch.Tensor start_op_per_job: index of first operation of each job @@ -306,9 +292,7 @@ def get_job_ops_mapping( # here we will generate the operations-job mapping: # Therefore we first generate a sequence of operation ids and expand it the the size of the mapping matrix: # (bs, jobs, max_ops) - ops_seq_exp = torch.arange(n_ops_max, device=device)[None, None].expand( - bs, num_jobs, -1 - ) + ops_seq_exp = torch.arange(n_ops_max, device=device)[None, None].expand(bs, num_jobs, -1) # (bs, jobs, max_ops) # expanding start and end operation ids end_op_per_job_exp = end_op_per_job[..., None].expand_as(ops_seq_exp) start_op_per_job_exp = start_op_per_job[..., None].expand_as(ops_seq_exp) diff --git a/rl4co/envs/scheduling/jssp/env.py b/rl4co/envs/scheduling/jssp/env.py index 2381836a..3bc989db 100644 --- a/rl4co/envs/scheduling/jssp/env.py +++ b/rl4co/envs/scheduling/jssp/env.py @@ -100,9 +100,7 @@ def get_action_mask(self, td: TensorDict) -> Tensor: no_op_mask = td["done"] else: # if no job is currently processed and instance is not finished yet, waiting is not allowed - no_op_mask = ( - td["job_in_process"].any(1, keepdims=True) & (~td["done"]) - ) | td["done"] + no_op_mask = (td["job_in_process"].any(1, keepdims=True) & (~td["done"])) | td["done"] # reduce action mask to correspond with logit shape action_mask = reduce(action_mask, "bs j m -> bs j", reduction="all") # NOTE: 1 means feasible action, 0 means infeasible action diff --git a/rl4co/envs/scheduling/jssp/generator.py b/rl4co/envs/scheduling/jssp/generator.py index 3e3437b6..c5581a17 100644 --- a/rl4co/envs/scheduling/jssp/generator.py +++ b/rl4co/envs/scheduling/jssp/generator.py @@ -70,9 +70,7 @@ def __init__( def _simulate_processing_times(self, bs, n_ops_max) -> torch.Tensor: if self.one2one_ma_map: ops_machine_ids = ( - torch.rand((*bs, self.num_jobs, self.num_mas)) - .argsort(dim=-1) - .flatten(1, 2) + torch.rand((*bs, self.num_jobs, self.num_mas)).argsort(dim=-1).flatten(1, 2) ) else: ops_machine_ids = torch.randint( @@ -159,9 +157,7 @@ class JSSPFileGenerator(Generator): """ def __init__(self, file_path: str, n_ops_max: int = None, **unused_kwargs): - self.files = ( - [file_path] if os.path.isfile(file_path) else self.list_files(file_path) - ) + self.files = [file_path] if os.path.isfile(file_path) else self.list_files(file_path) self.num_samples = len(self.files) if len(unused_kwargs) > 0: @@ -198,9 +194,7 @@ def _generate(self, batch_size: list[int]) -> TensorDict: @staticmethod def list_files(path): files = [ - os.path.join(path, f) - for f in os.listdir(path) - if os.path.isfile(os.path.join(path, f)) + os.path.join(path, f) for f in os.listdir(path) if os.path.isfile(os.path.join(path, f)) ] assert len(files) > 0, "No files found in the specified path" return files diff --git a/rl4co/envs/scheduling/jssp/parser.py b/rl4co/envs/scheduling/jssp/parser.py index 129838e7..88dca5e1 100644 --- a/rl4co/envs/scheduling/jssp/parser.py +++ b/rl4co/envs/scheduling/jssp/parser.py @@ -1,14 +1,13 @@ from pathlib import Path -from typing import Tuple import torch from tensordict import TensorDict -ProcessingData = list[Tuple[int, int]] +ProcessingData = list[tuple[int, int]] -def parse_job_line(line: Tuple[int]) -> Tuple[ProcessingData]: +def parse_job_line(line: tuple[int]) -> tuple[ProcessingData]: """ Parses a JSSP job data line of the following form: @@ -101,7 +100,7 @@ def read(loc: Path, max_ops=None): def file2lines(loc: Path | str) -> list[list[int]]: - with open(loc, "r") as fh: + with open(loc) as fh: lines = [line for line in fh.readlines() if line.strip()] def parse_num(word: str): diff --git a/rl4co/envs/scheduling/smtwtp/env.py b/rl4co/envs/scheduling/smtwtp/env.py index c7ee54f2..83c8a9f9 100644 --- a/rl4co/envs/scheduling/smtwtp/env.py +++ b/rl4co/envs/scheduling/smtwtp/env.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch from tensordict.tensordict import TensorDict @@ -98,7 +96,7 @@ def _step(td: TensorDict) -> TensorDict: ) return td - def _reset(self, td: Optional[TensorDict] = None, batch_size=None) -> TensorDict: + def _reset(self, td: TensorDict | None = None, batch_size=None) -> TensorDict: device = td.device init_job_due_time = td["job_due_time"] @@ -180,9 +178,7 @@ def _get_reward(self, td, actions) -> TensorDict: ordered_process_time = job_process_time[batch_idx, actions] ordered_due_time = job_due_time[batch_idx, actions] ordered_job_weight = job_weight[batch_idx, actions] - presum_process_time = torch.cumsum( - ordered_process_time, dim=1 - ) # ending time of each job + presum_process_time = torch.cumsum(ordered_process_time, dim=1) # ending time of each job job_tardiness = presum_process_time - ordered_due_time job_tardiness[job_tardiness < 0] = 0 job_weighted_tardiness = ordered_job_weight * job_tardiness diff --git a/rl4co/envs/scheduling/smtwtp/generator.py b/rl4co/envs/scheduling/smtwtp/generator.py index 39701478..a6d08338 100644 --- a/rl4co/envs/scheduling/smtwtp/generator.py +++ b/rl4co/envs/scheduling/smtwtp/generator.py @@ -1,24 +1,16 @@ -import os -import zipfile -from typing import Union, Callable - import torch -import numpy as np -from robust_downloader import download -from torch.distributions import Uniform from tensordict.tensordict import TensorDict -from rl4co.data.utils import load_npz_to_tensordict +from rl4co.envs.common.utils import Generator from rl4co.utils.pylogger import get_pylogger -from rl4co.envs.common.utils import get_sampler, Generator log = get_pylogger(__name__) class SMTWTPGenerator(Generator): """Data generator for the Single Machine Total Weighted Tardiness Problem (SMTWTP) environment - + Args: num_job: number of jobs min_time_span: lower bound of jobs' due time. By default, jobs' due time is uniformly sampled from (min_time_span, max_time_span) @@ -27,23 +19,24 @@ class SMTWTPGenerator(Generator): max_job_weight: upper bound of jobs' weights min_process_time: lower bound of jobs' process time. By default, jobs' process time is uniformly sampled from (min_process_time, max_process_time) max_process_time: upper bound of jobs' process time - + Returns: A TensorDict with the following key: job_due_time [batch_size, num_job + 1]: the due time of each job job_weight [batch_size, num_job + 1]: the weight of each job job_process_time [batch_size, num_job + 1]: the process time of each job """ + def __init__( self, num_job: int = 10, min_time_span: float = 0, - max_time_span: float = None, # will be set to num_job / 2 by default. In DeepACO, it is set to num_job, which would be too simple + max_time_span: float = None, # will be set to num_job / 2 by default. In DeepACO, it is set to num_job, which would be too simple min_job_weight: float = 0, max_job_weight: float = 1, min_process_time: float = 0, max_process_time: float = 1, - **unused_kwargs + **unused_kwargs, ): self.num_job = num_job self.min_time_span = min_time_span @@ -60,17 +53,14 @@ def __init__( def _generate(self, batch_size) -> TensorDict: batch_size = [batch_size] if isinstance(batch_size, int) else batch_size # Sampling according to Ye et al. (2023) - job_due_time = ( - torch.FloatTensor(*batch_size, self.num_job + 1) - .uniform_(self.min_time_span, self.max_time_span) + job_due_time = torch.FloatTensor(*batch_size, self.num_job + 1).uniform_( + self.min_time_span, self.max_time_span ) - job_weight = ( - torch.FloatTensor(*batch_size, self.num_job + 1) - .uniform_(self.min_job_weight, self.max_job_weight) + job_weight = torch.FloatTensor(*batch_size, self.num_job + 1).uniform_( + self.min_job_weight, self.max_job_weight ) - job_process_time = ( - torch.FloatTensor(*batch_size, self.num_job + 1) - .uniform_(self.min_process_time, self.max_process_time) + job_process_time = torch.FloatTensor(*batch_size, self.num_job + 1).uniform_( + self.min_process_time, self.max_process_time ) # Rollouts begin at dummy node 0, whose features are set to 0 diff --git a/rl4co/envs/scheduling/smtwtp/render.py b/rl4co/envs/scheduling/smtwtp/render.py index 9f8eedf0..07211e8a 100644 --- a/rl4co/envs/scheduling/smtwtp/render.py +++ b/rl4co/envs/scheduling/smtwtp/render.py @@ -1,11 +1,5 @@ -import torch -import numpy as np -import matplotlib.pyplot as plt - -from matplotlib import cm, colormaps from tensordict.tensordict import TensorDict -from rl4co.utils.ops import gather_by_index from rl4co.utils.pylogger import get_pylogger log = get_pylogger(__name__) diff --git a/rl4co/models/__init__.py b/rl4co/models/__init__.py index 85e27115..41a94423 100644 --- a/rl4co/models/__init__.py +++ b/rl4co/models/__init__.py @@ -26,19 +26,10 @@ from rl4co.models.zoo.dact import DACT, DACTPolicy from rl4co.models.zoo.deepaco import DeepACO, DeepACOPolicy from rl4co.models.zoo.eas import EAS, EASEmb, EASLay -from rl4co.models.zoo.glop import GLOP, GLOPPolicy from rl4co.models.zoo.gfacs import GFACS, GFACSPolicy -from rl4co.models.zoo.ham import ( - HeterogeneousAttentionModel, - HeterogeneousAttentionModelPolicy, -) -from rl4co.models.zoo.l2d import ( - L2DAttnPolicy, - L2DModel, - L2DPolicy, - L2DPolicy4PPO, - L2DPPOModel, -) +from rl4co.models.zoo.glop import GLOP, GLOPPolicy +from rl4co.models.zoo.ham import HeterogeneousAttentionModel, HeterogeneousAttentionModelPolicy +from rl4co.models.zoo.l2d import L2DAttnPolicy, L2DModel, L2DPolicy, L2DPolicy4PPO, L2DPPOModel from rl4co.models.zoo.matnet import MatNet, MatNetPolicy from rl4co.models.zoo.mdam import MDAM, MDAMPolicy from rl4co.models.zoo.mvmoe import MVMoE_AM, MVMoE_POMO diff --git a/rl4co/models/common/constructive/autoregressive/policy.py b/rl4co/models/common/constructive/autoregressive/policy.py index 3cad5961..627b7ad1 100644 --- a/rl4co/models/common/constructive/autoregressive/policy.py +++ b/rl4co/models/common/constructive/autoregressive/policy.py @@ -32,7 +32,7 @@ def __init__( if decoder is None: raise ValueError("AutoregressivePolicy requires a decoder to be provided.") - super(AutoregressivePolicy, self).__init__( + super().__init__( encoder=encoder, decoder=decoder, env_name=env_name, diff --git a/rl4co/models/common/constructive/base.py b/rl4co/models/common/constructive/base.py index 804bb867..c6e5be52 100644 --- a/rl4co/models/common/constructive/base.py +++ b/rl4co/models/common/constructive/base.py @@ -1,6 +1,7 @@ import abc -from typing import Any, Callable, Optional, Tuple +from collections.abc import Callable +from typing import Any import torch.nn as nn @@ -8,11 +9,7 @@ from torch import Tensor from rl4co.envs import RL4COEnvBase, get_env -from rl4co.utils.decoding import ( - DecodingStrategy, - get_decoding_strategy, - get_log_likelihood, -) +from rl4co.utils.decoding import DecodingStrategy, get_decoding_strategy, get_log_likelihood from rl4co.utils.ops import calculate_entropy from rl4co.utils.pylogger import get_pylogger @@ -23,7 +20,7 @@ class ConstructiveEncoder(nn.Module, metaclass=abc.ABCMeta): """Base class for the encoder of constructive models""" @abc.abstractmethod - def forward(self, td: TensorDict) -> Tuple[Any, Tensor]: + def forward(self, td: TensorDict) -> tuple[Any, Tensor]: """Forward pass for the encoder Args: @@ -43,7 +40,7 @@ class ConstructiveDecoder(nn.Module, metaclass=abc.ABCMeta): @abc.abstractmethod def forward( self, td: TensorDict, hidden: Any = None, num_starts: int = 0 - ) -> Tuple[Tensor, Tensor]: + ) -> tuple[Tensor, Tensor]: """Obtain logits for current action to the next ones Args: @@ -58,7 +55,7 @@ def forward( def pre_decoder_hook( self, td: TensorDict, env: RL4COEnvBase, hidden: Any = None, num_starts: int = 0 - ) -> Tuple[TensorDict, RL4COEnvBase, Any]: + ) -> tuple[TensorDict, RL4COEnvBase, Any]: """By default, we don't need to do anything here. Args: @@ -76,7 +73,7 @@ def pre_decoder_hook( class NoEncoder(ConstructiveEncoder): """Default encoder decoder-only models, i.e. autoregressive models that re-encode all the state at each decoding step.""" - def forward(self, td: TensorDict) -> Tuple[Tensor, Tensor]: + def forward(self, td: TensorDict) -> tuple[Tensor, Tensor]: """Return Nones for the hidden state and initial embeddings""" return None, None @@ -132,7 +129,7 @@ def __init__( test_decode_type: str = "greedy", **unused_kw, ): - super(ConstructivePolicy, self).__init__() + super().__init__() if len(unused_kw) > 0: log.error(f"Found {len(unused_kw)} unused kwargs: {unused_kw}") @@ -157,7 +154,7 @@ def __init__( def forward( self, td: TensorDict, - env: Optional[str | RL4COEnvBase] = None, + env: str | RL4COEnvBase | None = None, phase: str = "train", calc_reward: bool = True, return_actions: bool = True, @@ -237,9 +234,7 @@ def forward( td = env.step(td)["next"] step += 1 if step > max_steps: - log.error( - f"Exceeded maximum number of steps ({max_steps}) duing decoding" - ) + log.error(f"Exceeded maximum number of steps ({max_steps}) duing decoding") break # Post-decoding hook: used for the final step(s) of the decoding strategy diff --git a/rl4co/models/common/constructive/nonautoregressive/__init__.py b/rl4co/models/common/constructive/nonautoregressive/__init__.py index f170d079..25a71815 100644 --- a/rl4co/models/common/constructive/nonautoregressive/__init__.py +++ b/rl4co/models/common/constructive/nonautoregressive/__init__.py @@ -1,9 +1,3 @@ -from rl4co.models.common.constructive.nonautoregressive.decoder import ( - NonAutoregressiveDecoder, -) -from rl4co.models.common.constructive.nonautoregressive.encoder import ( - NonAutoregressiveEncoder, -) -from rl4co.models.common.constructive.nonautoregressive.policy import ( - NonAutoregressivePolicy, -) +from rl4co.models.common.constructive.nonautoregressive.decoder import NonAutoregressiveDecoder +from rl4co.models.common.constructive.nonautoregressive.encoder import NonAutoregressiveEncoder +from rl4co.models.common.constructive.nonautoregressive.policy import NonAutoregressivePolicy diff --git a/rl4co/models/common/constructive/nonautoregressive/policy.py b/rl4co/models/common/constructive/nonautoregressive/policy.py index 655b9cc7..f6b3dd4f 100644 --- a/rl4co/models/common/constructive/nonautoregressive/policy.py +++ b/rl4co/models/common/constructive/nonautoregressive/policy.py @@ -1,4 +1,3 @@ -from typing import Optional from rl4co.models.common.constructive.base import ConstructivePolicy from .decoder import NonAutoregressiveDecoder @@ -13,7 +12,7 @@ class NonAutoregressivePolicy(ConstructivePolicy): def __init__( self, encoder: NonAutoregressiveEncoder, - decoder: Optional[NonAutoregressiveDecoder] = None, + decoder: NonAutoregressiveDecoder | None = None, env_name: str = "tsp", temperature: float = 1.0, tanh_clipping: float = 0, @@ -27,7 +26,7 @@ def __init__( if decoder is None: decoder = NonAutoregressiveDecoder() - super(NonAutoregressivePolicy, self).__init__( + super().__init__( encoder=encoder, decoder=decoder, env_name=env_name, diff --git a/rl4co/models/common/improvement/__init__.py b/rl4co/models/common/improvement/__init__.py index 3eb4a954..d13a2b52 100644 --- a/rl4co/models/common/improvement/__init__.py +++ b/rl4co/models/common/improvement/__init__.py @@ -1 +1,5 @@ -from rl4co.models.common.improvement.base import ImprovementDecoder, ImprovementEncoder, ImprovementPolicy \ No newline at end of file +from rl4co.models.common.improvement.base import ( + ImprovementDecoder, + ImprovementEncoder, + ImprovementPolicy, +) diff --git a/rl4co/models/common/improvement/base.py b/rl4co/models/common/improvement/base.py index e4f5cb15..256206f3 100644 --- a/rl4co/models/common/improvement/base.py +++ b/rl4co/models/common/improvement/base.py @@ -1,7 +1,5 @@ import abc -from typing import Tuple - import torch.nn as nn from tensordict import TensorDict @@ -31,15 +29,13 @@ def __init__( feedforward_hidden: int = 128, linear_bias: bool = False, ): - super(ImprovementEncoder, self).__init__() + super().__init__() if isinstance(env_name, RL4COEnvBase): env_name = env_name.name self.env_name = env_name self.init_embedding = ( - env_init_embedding( - self.env_name, {"embed_dim": embed_dim, "linear_bias": linear_bias} - ) + env_init_embedding(self.env_name, {"embed_dim": embed_dim, "linear_bias": linear_bias}) if init_embedding is None else init_embedding ) @@ -52,7 +48,7 @@ def __init__( ) @abc.abstractmethod - def _encoder_forward(self, init_h: Tensor, init_p: Tensor) -> Tuple[Tensor, Tensor]: + def _encoder_forward(self, init_h: Tensor, init_p: Tensor) -> tuple[Tensor, Tensor]: """Process the node embeddings and positional embeddings to the final embeddings Args: @@ -64,7 +60,7 @@ def _encoder_forward(self, init_h: Tensor, init_p: Tensor) -> Tuple[Tensor, Tens """ raise NotImplementedError("Implement me in subclass!") - def forward(self, td: TensorDict) -> Tuple[Tensor, Tensor]: + def forward(self, td: TensorDict) -> tuple[Tensor, Tensor]: """Forward pass of the encoder. Transform the input TensorDict into a latent representation. diff --git a/rl4co/models/common/transductive/base.py b/rl4co/models/common/transductive/base.py index 5e8c0b12..f0c2c317 100644 --- a/rl4co/models/common/transductive/base.py +++ b/rl4co/models/common/transductive/base.py @@ -1,6 +1,6 @@ import abc -from typing import Any, Optional +from typing import Any from lightning.pytorch.utilities.types import STEP_OUTPUT from torch.utils.data import Dataset @@ -36,8 +36,8 @@ def __init__( dataset: Dataset | str, batch_size: int = 1, max_iters: int = 100, - max_runtime: Optional[int] = 86_400, - save_path: Optional[str] = None, + max_runtime: int | None = 86_400, + save_path: str | None = None, **kwargs, ): self.save_hyperparameters(logger=False, ignore=["env", "policy", "dataset"]) @@ -72,9 +72,7 @@ def training_step(self, batch, batch_idx): """Main search loop. We use the training step to effectively adapt to a `batch` of instances.""" raise NotImplementedError("Implement in subclass") - def on_train_batch_end( - self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int - ) -> None: + def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None: """Called when the train batch ends. This can be used for instance for logging or clearing cache. """ diff --git a/rl4co/models/nn/attention.py b/rl4co/models/nn/attention.py index d091a591..fd506512 100644 --- a/rl4co/models/nn/attention.py +++ b/rl4co/models/nn/attention.py @@ -2,7 +2,7 @@ import math import warnings -from typing import Callable, Optional +from collections.abc import Callable import torch import torch.nn as nn @@ -16,9 +16,7 @@ log = get_pylogger(__name__) -def scaled_dot_product_attention_simple( - q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False -): +def scaled_dot_product_attention_simple(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False): """Simple (exact) Scaled Dot-Product Attention in RL4CO without customized kernels (i.e. no Flash Attention).""" # Check for causal and attn_mask conflict @@ -90,7 +88,7 @@ def __init__( causal: bool = False, device: str = None, dtype: torch.dtype = None, - sdpa_fn: Optional[Callable] = None, + sdpa_fn: Callable | None = None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -102,9 +100,9 @@ def __init__( self.num_heads = num_heads assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" self.head_dim = self.embed_dim // num_heads - assert ( - self.head_dim % 8 == 0 and self.head_dim <= 128 - ), "Only support head_dim <= 128 and divisible by 8" + assert self.head_dim % 8 == 0 and self.head_dim <= 128, ( + "Only support head_dim <= 128 and divisible by 8" + ) self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) @@ -171,7 +169,7 @@ def __init__( attention_dropout: float = 0.0, device: str = None, dtype: torch.dtype = None, - sdpa_fn: Optional[Callable | nn.Module] = None, + sdpa_fn: Callable | nn.Module | None = None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -186,9 +184,9 @@ def __init__( self.num_heads = num_heads assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" self.head_dim = self.embed_dim // num_heads - assert ( - self.head_dim % 8 == 0 and self.head_dim <= 128 - ), "Only support head_dim <= 128 and divisible by 8" + assert self.head_dim % 8 == 0 and self.head_dim <= 128, ( + "Only support head_dim <= 128 and divisible by 8" + ) self.Wq = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) self.Wkv = nn.Linear(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs) @@ -196,14 +194,10 @@ def __init__( def forward(self, q_input, kv_input, cross_attn_mask=None, dmat=None): # Project query, key, value - q = rearrange( - self.Wq(q_input), "b m (h d) -> b h m d", h=self.num_heads - ) # [b, h, m, d] + q = rearrange(self.Wq(q_input), "b m (h d) -> b h m d", h=self.num_heads) # [b, h, m, d] k, v = rearrange( self.Wkv(kv_input), "b n (two h d) -> two b h n d", two=2, h=self.num_heads - ).unbind( - dim=0 - ) # [b, h, n, d] + ).unbind(dim=0) # [b, h, n, d] if cross_attn_mask is not None: # add head dim @@ -252,7 +246,7 @@ def __init__( sdpa_fn: Callable | str = "default", **kwargs, ): - super(PointerAttention, self).__init__() + super().__init__() self.num_heads = num_heads self.mask_inner = mask_inner @@ -274,9 +268,7 @@ def __init__( else: if sdpa_fn is None: sdpa_fn = scaled_dot_product_attention - log.info( - "Using default scaled_dot_product_attention for PointerAttention" - ) + log.info("Using default scaled_dot_product_attention for PointerAttention") self.sdpa_fn = sdpa_fn def forward(self, query, key, value, logit_key, attn_mask=None): @@ -358,12 +350,10 @@ def __init__( mask_inner: bool = True, out_bias: bool = False, check_nan: bool = True, - sdpa_fn: Optional[Callable] = None, - moe_kwargs: Optional[dict] = None, + sdpa_fn: Callable | None = None, + moe_kwargs: dict | None = None, ): - super(PointerAttnMoE, self).__init__( - embed_dim, num_heads, mask_inner, out_bias, check_nan, sdpa_fn - ) + super().__init__(embed_dim, num_heads, mask_inner, out_bias, check_nan, sdpa_fn) self.moe_kwargs = moe_kwargs self.project_out = None @@ -381,17 +371,11 @@ def _project_out(self, out, attn_mask): # only do this at the "second" step, which is depot -> pomo -> first select if (num_available_nodes >= num_nodes - 1).any(): self.probs = F.softmax( - self.dense_or_moe( - out.view(-1, out.size(-1)).mean(dim=0, keepdim=True) - ), + self.dense_or_moe(out.view(-1, out.size(-1)).mean(dim=0, keepdim=True)), dim=-1, ) selected = self.probs.multinomial(1).squeeze(0) - out = ( - self.project_out_moe(out) - if selected.item() == 1 - else self.project_out(out) - ) + out = self.project_out_moe(out) if selected.item() == 1 else self.project_out(out) glimpse = out * self.probs.squeeze(0)[selected] else: glimpse = self.project_out_moe(out) @@ -408,13 +392,13 @@ def __init__(self, *args, **kwargs): "Note that several components of the previous LogitAttention have moved to `rl4co.models.nn.dec_strategies`.", category=DeprecationWarning, ) - super(LogitAttention, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) # MultiHeadCompat class MultiHeadCompat(nn.Module): def __init__(self, n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None): - super(MultiHeadCompat, self).__init__() + super().__init__() if val_dim is None: # assert embed_dim is not None, "Provide either embed_dim or val_dim" @@ -499,17 +483,13 @@ class PolyNetAttention(PointerAttention): sdpa_fn: scaled dot product attention function (SDPA) implementation """ - def __init__( - self, k: int, embed_dim: int, poly_layer_dim: int, num_heads: int, **kwargs - ): - super(PolyNetAttention, self).__init__(embed_dim, num_heads, **kwargs) + def __init__(self, k: int, embed_dim: int, poly_layer_dim: int, num_heads: int, **kwargs): + super().__init__(embed_dim, num_heads, **kwargs) self.k = k self.binary_vector_dim = math.ceil(math.log2(k)) self.binary_vectors = torch.nn.Parameter( - torch.Tensor( - list(itertools.product([0, 1], repeat=self.binary_vector_dim))[:k] - ), + torch.Tensor(list(itertools.product([0, 1], repeat=self.binary_vector_dim))[:k]), requires_grad=False, ) @@ -532,9 +512,7 @@ def forward(self, query, key, value, logit_key, attn_mask=None): glimpse = self.project_out(heads) num_solutions = glimpse.shape[1] - z = self.binary_vectors.repeat(math.ceil(num_solutions / self.k), 1)[ - :num_solutions - ] + z = self.binary_vectors.repeat(math.ceil(num_solutions / self.k), 1)[:num_solutions] z = z[None].expand(glimpse.shape[0], num_solutions, self.binary_vector_dim) # PolyNet layers diff --git a/rl4co/models/nn/env_embeddings/context.py b/rl4co/models/nn/env_embeddings/context.py index 56e286ff..ae0904dd 100644 --- a/rl4co/models/nn/env_embeddings/context.py +++ b/rl4co/models/nn/env_embeddings/context.py @@ -53,7 +53,7 @@ class EnvContext(nn.Module): Consists of a linear layer that projects the node features to the embedding space.""" def __init__(self, embed_dim, step_context_dim=None, linear_bias=False): - super(EnvContext, self).__init__() + super().__init__() self.embed_dim = embed_dim step_context_dim = step_context_dim if step_context_dim is not None else embed_dim self.project_context = nn.Linear(step_context_dim, embed_dim, bias=linear_bias) @@ -110,17 +110,13 @@ class TSPContext(EnvContext): """ def __init__(self, embed_dim): - super(TSPContext, self).__init__(embed_dim, 2 * embed_dim) - self.W_placeholder = nn.Parameter( - torch.Tensor(2 * self.embed_dim).uniform_(-1, 1) - ) + super().__init__(embed_dim, 2 * embed_dim) + self.W_placeholder = nn.Parameter(torch.Tensor(2 * self.embed_dim).uniform_(-1, 1)) def forward(self, embeddings, td): batch_size = embeddings.size(0) # By default, node_dim = -1 (we only have one node embedding per node) - node_dim = ( - (-1,) if td["first_node"].dim() == 1 else (td["first_node"].size(-1), -1) - ) + node_dim = (-1,) if td["first_node"].dim() == 1 else (td["first_node"].size(-1), -1) if td["i"][(0,) * td["i"].dim()].item() < 1: # get first item fast if len(td.batch_size) < 2: context_embedding = self.W_placeholder[None, :].expand( @@ -133,9 +129,7 @@ def forward(self, embeddings, td): else: context_embedding = gather_by_index( embeddings, - torch.stack([td["first_node"], td["current_node"]], -1).view( - batch_size, -1 - ), + torch.stack([td["first_node"], td["current_node"]], -1).view(batch_size, -1), ).view(batch_size, *node_dim) return self.project_context(context_embedding) @@ -148,9 +142,7 @@ class VRPContext(EnvContext): """ def __init__(self, embed_dim): - super(VRPContext, self).__init__( - embed_dim=embed_dim, step_context_dim=embed_dim + 1 - ) + super().__init__(embed_dim=embed_dim, step_context_dim=embed_dim + 1) def _state_embedding(self, embeddings, td): state_embedding = td["vehicle_capacity"] - td["used_capacity"] @@ -166,9 +158,7 @@ class VRPTWContext(VRPContext): """ def __init__(self, embed_dim): - super(VRPContext, self).__init__( - embed_dim=embed_dim, step_context_dim=embed_dim + 2 - ) + super(VRPContext, self).__init__(embed_dim=embed_dim, step_context_dim=embed_dim + 2) def _state_embedding(self, embeddings, td): capacity = super()._state_embedding(embeddings, td) @@ -184,7 +174,7 @@ class SVRPContext(EnvContext): """ def __init__(self, embed_dim): - super(SVRPContext, self).__init__(embed_dim=embed_dim, step_context_dim=embed_dim) + super().__init__(embed_dim=embed_dim, step_context_dim=embed_dim) def forward(self, embeddings, td): cur_node_embedding = self._cur_node_embedding(embeddings, td).squeeze() @@ -199,12 +189,12 @@ class PCTSPContext(EnvContext): """ def __init__(self, embed_dim): - super(PCTSPContext, self).__init__(embed_dim, embed_dim + 1) + super().__init__(embed_dim, embed_dim + 1) def _state_embedding(self, embeddings, td): - state_embedding = torch.clamp( - td["prize_required"] - td["cur_total_prize"], min=0 - )[..., None] + state_embedding = torch.clamp(td["prize_required"] - td["cur_total_prize"], min=0)[ + ..., None + ] return state_embedding @@ -216,7 +206,7 @@ class OPContext(EnvContext): """ def __init__(self, embed_dim): - super(OPContext, self).__init__(embed_dim, embed_dim + 1) + super().__init__(embed_dim, embed_dim + 1) def _state_embedding(self, embeddings, td): state_embedding = td["max_length"][..., 0] - td["tour_length"] @@ -230,7 +220,7 @@ class DPPContext(EnvContext): """ def __init__(self, embed_dim): - super(DPPContext, self).__init__(embed_dim) + super().__init__(embed_dim) def forward(self, embeddings, td): """Context cannot be defined by a single node embedding for DPP, hence 0. @@ -246,7 +236,7 @@ class PDPContext(EnvContext): """ def __init__(self, embed_dim): - super(PDPContext, self).__init__(embed_dim, embed_dim) + super().__init__(embed_dim, embed_dim) def forward(self, embeddings, td): cur_node_embedding = self._cur_node_embedding(embeddings, td).squeeze() @@ -264,10 +254,8 @@ class MTSPContext(EnvContext): """ def __init__(self, embed_dim, linear_bias=False): - super(MTSPContext, self).__init__(embed_dim, 2 * embed_dim) - proj_in_dim = ( - 4 # remaining_agents, current_length, max_subtour_length, distance_from_depot - ) + super().__init__(embed_dim, 2 * embed_dim) + proj_in_dim = 4 # remaining_agents, current_length, max_subtour_length, distance_from_depot self.proj_dynamic_feats = nn.Linear(proj_in_dim, embed_dim, bias=linear_bias) def _cur_node_embedding(self, embeddings, td): @@ -300,7 +288,7 @@ class SMTWTPContext(EnvContext): """ def __init__(self, embed_dim): - super(SMTWTPContext, self).__init__(embed_dim, embed_dim + 1) + super().__init__(embed_dim, embed_dim + 1) def _cur_node_embedding(self, embeddings, td): cur_node_embedding = gather_by_index(embeddings, td["current_job"]) @@ -318,7 +306,7 @@ class MDCPDPContext(EnvContext): """ def __init__(self, embed_dim): - super(MDCPDPContext, self).__init__(embed_dim, embed_dim * 2 + 5) + super().__init__(embed_dim, embed_dim * 2 + 5) def _state_embedding(self, embeddings, td): # get number of visited cities over total @@ -380,17 +368,11 @@ class MTVRPContext(VRPContext): """ def __init__(self, embed_dim): - super(VRPContext, self).__init__( - embed_dim=embed_dim, step_context_dim=embed_dim + 5 - ) + super(VRPContext, self).__init__(embed_dim=embed_dim, step_context_dim=embed_dim + 5) def _state_embedding(self, embeddings, td): - remaining_linehaul_capacity = ( - td["vehicle_capacity"] - td["used_capacity_linehaul"] - ) - remaining_backhaul_capacity = ( - td["vehicle_capacity"] - td["used_capacity_backhaul"] - ) + remaining_linehaul_capacity = td["vehicle_capacity"] - td["used_capacity_linehaul"] + remaining_backhaul_capacity = td["vehicle_capacity"] - td["used_capacity_backhaul"] current_time = td["current_time"] current_route_length = td["current_route_length"] open_route = td["open_route"] @@ -410,7 +392,7 @@ class FLPContext(EnvContext): """Context embedding for the Facility Location Problem (FLP).""" def __init__(self, embed_dim: int): - super(FLPContext, self).__init__(embed_dim=embed_dim) + super().__init__(embed_dim=embed_dim) self.embed_dim = embed_dim self.project_context = nn.Linear(embed_dim, embed_dim, bias=True) @@ -429,7 +411,7 @@ class MCPContext(EnvContext): """Context embedding for the Maximum Coverage Problem (MCP).""" def __init__(self, embed_dim: int): - super(MCPContext, self).__init__(embed_dim=embed_dim) + super().__init__(embed_dim=embed_dim) self.embed_dim = embed_dim self.project_context = nn.Linear(embed_dim, embed_dim, bias=True) @@ -441,8 +423,6 @@ def forward(self, embeddings, td): # membership_weighted: [batch_size, n_sets] # softmax; higher weights for better sets - membership_weighted = torch.softmax( - membership_weighted, dim=-1 - ) # (batch_size, n_sets) + membership_weighted = torch.softmax(membership_weighted, dim=-1) # (batch_size, n_sets) context_embedding = (membership_weighted.unsqueeze(-1) * embeddings).sum(1) return self.project_context(context_embedding) diff --git a/rl4co/models/nn/env_embeddings/dynamic.py b/rl4co/models/nn/env_embeddings/dynamic.py index 470af835..84533df9 100644 --- a/rl4co/models/nn/env_embeddings/dynamic.py +++ b/rl4co/models/nn/env_embeddings/dynamic.py @@ -51,7 +51,7 @@ class StaticEmbedding(nn.Module): """ def __init__(self, *args, **kwargs): - super(StaticEmbedding, self).__init__() + super().__init__() def forward(self, td): return 0, 0, 0 @@ -66,7 +66,7 @@ class SDVRPDynamicEmbedding(nn.Module): """ def __init__(self, embed_dim, linear_bias=False): - super(SDVRPDynamicEmbedding, self).__init__() + super().__init__() self.projection = nn.Linear(1, 3 * embed_dim, bias=linear_bias) def forward(self, td): @@ -104,9 +104,7 @@ def forward(self, td, cache): # bs, ma, ops masked_proc_times[ma_busy] = 0.0 # bs, ops, ma, 3 - edge_feat = self.project_edge_step(masked_proc_times.unsqueeze(-1)).transpose( - 1, 2 - ) + edge_feat = self.project_edge_step(masked_proc_times.unsqueeze(-1)).transpose(1, 2) job_edge_feat = gather_by_index(edge_feat, td["next_op"], dim=1) # bs, nodes, 3*emb edge_upd = torch.einsum("ijkl,ikm->ijlm", job_edge_feat, ma_emb).view( @@ -115,7 +113,5 @@ def forward(self, td, cache): updates = updates + edge_upd # (bs, nodes, emb) - glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic = updates.chunk( - 3, dim=-1 - ) + glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic = updates.chunk(3, dim=-1) return glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic diff --git a/rl4co/models/nn/env_embeddings/edge.py b/rl4co/models/nn/env_embeddings/edge.py index be50f9a6..fcbe0858 100644 --- a/rl4co/models/nn/env_embeddings/edge.py +++ b/rl4co/models/nn/env_embeddings/edge.py @@ -1,4 +1,4 @@ -from typing import Callable, Union +from collections.abc import Callable import torch import torch.nn as nn @@ -64,14 +64,14 @@ def __init__( embed_dim, linear_bias=True, sparsify=True, - k_sparse: Union[int, Callable[[int], int], None] = None, + k_sparse: int | Callable[[int], int] | None = None, ): assert Batch is not None, ( "torch_geometric not found. Please install torch_geometric using instructions from " "https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html." ) - super(TSPEdgeEmbedding, self).__init__() + super().__init__() if k_sparse is None: self._get_k_sparse = lambda n: max(n // 5, 10) @@ -101,18 +101,14 @@ def _cost_matrix_to_graph(self, batch_cost_matrix: Tensor, init_embeddings: Tens graph_data = [] for index, cost_matrix in enumerate(batch_cost_matrix): if self.sparsify: - edge_index, edge_attr = sparsify_graph( - cost_matrix, k_sparse, self_loop=False - ) + edge_index, edge_attr = sparsify_graph(cost_matrix, k_sparse, self_loop=False) else: - edge_index = get_full_graph_edge_index( - cost_matrix.shape[0], self_loop=False - ).to(cost_matrix.device) + edge_index = get_full_graph_edge_index(cost_matrix.shape[0], self_loop=False).to( + cost_matrix.device + ) edge_attr = cost_matrix[edge_index[0], edge_index[1]].unsqueeze(-1) - graph = Data( - x=init_embeddings[index], edge_index=edge_index, edge_attr=edge_attr - ) + graph = Data(x=init_embeddings[index], edge_index=edge_index, edge_attr=edge_attr) graph_data.append(graph) batch = Batch.from_data_list(graph_data) @@ -167,14 +163,12 @@ def _cost_matrix_to_graph(self, batch_cost_matrix: Tensor, init_embeddings: Tens ) else: - edge_index = get_full_graph_edge_index( - cost_matrix.shape[0], self_loop=False - ).to(cost_matrix.device) + edge_index = get_full_graph_edge_index(cost_matrix.shape[0], self_loop=False).to( + cost_matrix.device + ) edge_attr = cost_matrix[edge_index[0], edge_index[1]].unsqueeze(-1) - graph = Data( - x=init_embeddings[index], edge_index=edge_index, edge_attr=edge_attr - ) + graph = Data(x=init_embeddings[index], edge_index=edge_index, edge_attr=edge_attr) graph_data.append(graph) batch = Batch.from_data_list(graph_data) # type: ignore @@ -269,7 +263,7 @@ def __init__(self, embed_dim, self_loop=False, **kwargs): "https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html." ) - super(NoEdgeEmbedding, self).__init__() + super().__init__() self.embed_dim = embed_dim self.self_loop = self_loop diff --git a/rl4co/models/nn/env_embeddings/init.py b/rl4co/models/nn/env_embeddings/init.py index 1d0202c3..c1d0e6bf 100644 --- a/rl4co/models/nn/env_embeddings/init.py +++ b/rl4co/models/nn/env_embeddings/init.py @@ -59,7 +59,7 @@ class TSPInitEmbedding(nn.Module): """ def __init__(self, embed_dim, linear_bias=True): - super(TSPInitEmbedding, self).__init__() + super().__init__() node_dim = 2 # x, y self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias) @@ -120,7 +120,7 @@ class VRPInitEmbedding(nn.Module): """ def __init__(self, embed_dim, linear_bias=True, node_dim: int = 3): - super(VRPInitEmbedding, self).__init__() + super().__init__() node_dim = node_dim # 3: x, y, demand self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias) self.init_embed_depot = nn.Linear(2, embed_dim, linear_bias) # depot embedding @@ -130,9 +130,7 @@ def forward(self, td): depot, cities = td["locs"][:, :1, :], td["locs"][:, 1:, :] depot_embedding = self.init_embed_depot(depot) # [batch, n_city, 2, batch, n_city, 1] -> [batch, n_city, embed_dim] - node_embeddings = self.init_embed( - torch.cat((cities, td["demand"][..., None]), -1) - ) + node_embeddings = self.init_embed(torch.cat((cities, td["demand"][..., None]), -1)) # [batch, n_city+1, embed_dim] out = torch.cat((depot_embedding, node_embeddings), -2) return out @@ -141,7 +139,7 @@ def forward(self, td): class VRPTWInitEmbedding(VRPInitEmbedding): def __init__(self, embed_dim, linear_bias=True, node_dim: int = 6): # node_dim = 6: x, y, demand, tw start, tw end, service time - super(VRPTWInitEmbedding, self).__init__(embed_dim, linear_bias, node_dim) + super().__init__(embed_dim, linear_bias, node_dim) def forward(self, td): depot, cities = td["locs"][:, :1, :], td["locs"][:, 1:, :] @@ -150,9 +148,7 @@ def forward(self, td): # embeddings depot_embedding = self.init_embed_depot(depot) node_embeddings = self.init_embed( - torch.cat( - (cities, td["demand"][..., None], time_windows, durations[..., None]), -1 - ) + torch.cat((cities, td["demand"][..., None], time_windows, durations[..., None]), -1) ) return torch.cat((depot_embedding, node_embeddings), -2) @@ -171,15 +167,13 @@ def __init__( node_dim: int = 3, attach_cartesian_coords=False, ): - super(VRPPolarInitEmbedding, self).__init__() + super().__init__() self.node_dim = node_dim + ( 2 if attach_cartesian_coords else 0 ) # 3: r, theta, demand; 5: r, theta, demand, x, y; self.attach_cartesian_coords = attach_cartesian_coords self.init_embed = nn.Linear(self.node_dim, embed_dim, linear_bias) - self.init_embed_depot = nn.Linear( - self.node_dim, embed_dim, linear_bias - ) # depot embedding + self.init_embed_depot = nn.Linear(self.node_dim, embed_dim, linear_bias) # depot embedding def forward(self, td): with torch.no_grad(): @@ -208,7 +202,7 @@ def forward(self, td): class SVRPInitEmbedding(nn.Module): def __init__(self, embed_dim, linear_bias=True, node_dim: int = 3): - super(SVRPInitEmbedding, self).__init__() + super().__init__() node_dim = node_dim # 3: x, y, skill self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias) self.init_embed_depot = nn.Linear(2, embed_dim, linear_bias) # depot embedding @@ -234,7 +228,7 @@ class PCTSPInitEmbedding(nn.Module): """ def __init__(self, embed_dim, linear_bias=True): - super(PCTSPInitEmbedding, self).__init__() + super().__init__() node_dim = 4 # x, y, prize, penalty self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias) self.init_embed_depot = nn.Linear(2, embed_dim, linear_bias) @@ -265,7 +259,7 @@ class OPInitEmbedding(nn.Module): """ def __init__(self, embed_dim, linear_bias=True): - super(OPInitEmbedding, self).__init__() + super().__init__() node_dim = 3 # x, y, prize self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias) self.init_embed_depot = nn.Linear(2, embed_dim, linear_bias) # depot embedding @@ -294,16 +288,14 @@ class DPPInitEmbedding(nn.Module): """ def __init__(self, embed_dim, linear_bias=True): - super(DPPInitEmbedding, self).__init__() + super().__init__() node_dim = 2 # x, y self.init_embed = nn.Linear(node_dim, embed_dim // 2, linear_bias) # locs self.init_embed_probe = nn.Linear(1, embed_dim // 2, linear_bias) # probe def forward(self, td): node_embeddings = self.init_embed(td["locs"]) - probe_embedding = self.init_embed_probe( - self._distance_probe(td["locs"], td["probe"]) - ) + probe_embedding = self.init_embed_probe(self._distance_probe(td["locs"], td["probe"])) return torch.cat([node_embeddings, probe_embedding], -1) def _distance_probe(self, locs, probe): @@ -320,12 +312,10 @@ class MDPPInitEmbedding(nn.Module): """ def __init__(self, embed_dim, linear_bias=True): - super(MDPPInitEmbedding, self).__init__() + super().__init__() node_dim = 2 # x, y self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias) # locs - self.init_embed_probe_distance = nn.Linear( - 1, embed_dim, linear_bias - ) # probe_distance + self.init_embed_probe_distance = nn.Linear(1, embed_dim, linear_bias) # probe_distance self.project_out = nn.Linear(embed_dim * 2, embed_dim, linear_bias) def forward(self, td): @@ -339,9 +329,7 @@ def forward(self, td): min_dist, _ = torch.min(dist, dim=1) min_probe_dist_embedding = self.init_embed_probe_distance(min_dist[..., None]) - return self.project_out( - torch.cat([node_embeddings, min_probe_dist_embedding], -1) - ) + return self.project_out(torch.cat([node_embeddings, min_probe_dist_embedding], -1)) class PDPInitEmbedding(nn.Module): @@ -352,7 +340,7 @@ class PDPInitEmbedding(nn.Module): """ def __init__(self, embed_dim, linear_bias=True): - super(PDPInitEmbedding, self).__init__() + super().__init__() node_dim = 2 # x, y self.init_embed_depot = nn.Linear(2, embed_dim, linear_bias) self.init_embed_pick = nn.Linear(node_dim * 2, embed_dim, linear_bias) @@ -380,7 +368,7 @@ class MTSPInitEmbedding(nn.Module): def __init__(self, embed_dim, linear_bias=True): """NOTE: new made by Fede. May need to be checked""" - super(MTSPInitEmbedding, self).__init__() + super().__init__() node_dim = 2 # x, y self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias) self.init_embed_depot = nn.Linear(2, embed_dim, linear_bias) # depot embedding @@ -400,7 +388,7 @@ class SMTWTPInitEmbedding(nn.Module): """ def __init__(self, embed_dim, linear_bias=True): - super(SMTWTPInitEmbedding, self).__init__() + super().__init__() node_dim = 3 # job_due_time, job_weight, job_process_time self.init_embed = nn.Linear(node_dim, embed_dim, linear_bias) @@ -421,7 +409,7 @@ class MDCPDPInitEmbedding(nn.Module): """ def __init__(self, embed_dim, linear_bias=True): - super(MDCPDPInitEmbedding, self).__init__() + super().__init__() node_dim = 2 # x, y self.init_embed_depot = nn.Linear(2, embed_dim, linear_bias) self.init_embed_pick = nn.Linear(node_dim * 2, embed_dim, linear_bias) @@ -450,7 +438,7 @@ def __init__( scaling_factor: int = 1000, num_op_feats=5, ): - super(JSSPInitEmbedding, self).__init__() + super().__init__() self.embed_dim = embed_dim self.scaling_factor = scaling_factor self.init_ops_embed = nn.Linear(num_op_feats, embed_dim, linear_bias) @@ -537,7 +525,7 @@ def forward(self, td: TensorDict): class MTVRPInitEmbedding(VRPInitEmbedding): def __init__(self, embed_dim, linear_bias=True, node_dim: int = 7): # node_dim = 7: x, y, demand_linehaul, demand_backhaul, tw start, tw end, service time - super(MTVRPInitEmbedding, self).__init__(embed_dim, linear_bias, node_dim) + super().__init__(embed_dim, linear_bias, node_dim) def forward(self, td): depot, cities = td["locs"][:, :1, :], td["locs"][:, 1:, :] diff --git a/rl4co/models/nn/flash_attention.py b/rl4co/models/nn/flash_attention.py index 28dff562..da215201 100644 --- a/rl4co/models/nn/flash_attention.py +++ b/rl4co/models/nn/flash_attention.py @@ -22,15 +22,13 @@ def fused_chunk_linear_attn_wrapper( normalize: bool = True, **kwargs, ): - assert ( - fused_chunk_linear_attn is not None - ), "fused_chunk_linear_attn not found. Install Flash Linear Attention using instructions from https://github.com/sustcsonglin/flash-linear-attention" - assert ( - kwargs.get("attn_mask", None) is None - ), "attn_mask is not supported in Flash Linear Attention" - return fused_chunk_linear_attn( - q, k, v, scale, initial_state, output_final_state, normalize - )[0] + assert fused_chunk_linear_attn is not None, ( + "fused_chunk_linear_attn not found. Install Flash Linear Attention using instructions from https://github.com/sustcsonglin/flash-linear-attention" + ) + assert kwargs.get("attn_mask", None) is None, ( + "attn_mask is not supported in Flash Linear Attention" + ) + return fused_chunk_linear_attn(q, k, v, scale, initial_state, output_final_state, normalize)[0] def scaled_dot_product_attention_flash_attn( diff --git a/rl4co/models/nn/graph/attnnet.py b/rl4co/models/nn/graph/attnnet.py index 9bfc29c6..bd5c44c9 100644 --- a/rl4co/models/nn/graph/attnnet.py +++ b/rl4co/models/nn/graph/attnnet.py @@ -1,12 +1,12 @@ -from typing import Callable, Optional +from collections.abc import Callable import torch.nn as nn from torch import Tensor +from rl4co.models.nn.attention import MultiHeadAttention from rl4co.models.nn.mlp import MLP from rl4co.models.nn.moe import MoE -from rl4co.models.nn.attention import MultiHeadAttention from rl4co.models.nn.ops import Normalization, SkipConnection from rl4co.utils.pylogger import get_pylogger @@ -30,21 +30,24 @@ def __init__( embed_dim: int, num_heads: int = 8, feedforward_hidden: int = 512, - normalization: Optional[str] = "batch", + normalization: str | None = "batch", bias: bool = True, - sdpa_fn: Optional[Callable] = None, - moe_kwargs: Optional[dict] = None, + sdpa_fn: Callable | None = None, + moe_kwargs: dict | None = None, ): num_neurons = [feedforward_hidden] if feedforward_hidden > 0 else [] if moe_kwargs is not None: ffn = MoE(embed_dim, embed_dim, num_neurons=num_neurons, **moe_kwargs) else: - ffn = MLP(input_dim=embed_dim, output_dim=embed_dim, num_neurons=num_neurons, hidden_act="ReLU") + ffn = MLP( + input_dim=embed_dim, + output_dim=embed_dim, + num_neurons=num_neurons, + hidden_act="ReLU", + ) - super(MultiHeadAttentionLayer, self).__init__( - SkipConnection( - MultiHeadAttention(embed_dim, num_heads, bias=bias, sdpa_fn=sdpa_fn) - ), + super().__init__( + SkipConnection(MultiHeadAttention(embed_dim, num_heads, bias=bias, sdpa_fn=sdpa_fn)), Normalization(embed_dim, normalization), SkipConnection(ffn), Normalization(embed_dim, normalization), @@ -72,10 +75,10 @@ def __init__( num_layers: int, normalization: str = "batch", feedforward_hidden: int = 512, - sdpa_fn: Optional[Callable] = None, - moe_kwargs: Optional[dict] = None, + sdpa_fn: Callable | None = None, + moe_kwargs: dict | None = None, ): - super(GraphAttentionNetwork, self).__init__() + super().__init__() self.layers = nn.Sequential( *( @@ -91,7 +94,7 @@ def __init__( ) ) - def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: + def forward(self, x: Tensor, mask: Tensor | None = None) -> Tensor: """Forward pass of the encoder Args: diff --git a/rl4co/models/nn/graph/gcn.py b/rl4co/models/nn/graph/gcn.py index 18740d41..d4f8a77a 100644 --- a/rl4co/models/nn/graph/gcn.py +++ b/rl4co/models/nn/graph/gcn.py @@ -1,4 +1,4 @@ -from typing import Callable, Tuple +from collections.abc import Callable import torch.nn as nn import torch.nn.functional as F @@ -72,9 +72,7 @@ def __init__( [GCNConv(embed_dim, embed_dim, bias=bias) for _ in range(num_layers)] ) - def forward( - self, td: TensorDict, mask: Tensor | None = None - ) -> Tuple[Tensor, Tensor]: + def forward(self, td: TensorDict, mask: Tensor | None = None) -> tuple[Tensor, Tensor]: """Forward pass of the encoder. Transform the input TensorDict into a latent representation. diff --git a/rl4co/models/nn/graph/gnn.py b/rl4co/models/nn/graph/gnn.py index 91e84ffe..8c5cb940 100644 --- a/rl4co/models/nn/graph/gnn.py +++ b/rl4co/models/nn/graph/gnn.py @@ -26,7 +26,7 @@ def __init__(self, units: int, act_fn: str = "silu", agg_fn: str = "mean"): "https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html." ) - super(GNNLayer, self).__init__() + super().__init__() self.units = units self.act_fn = getattr(nn.functional, act_fn) self.agg_fn = getattr(gnn, f"global_{agg_fn}_pool") @@ -52,9 +52,7 @@ def forward(self, x, edge_index, edge_attr): x3 = self.v_lin3(x0) x4 = self.v_lin4(x0) x = x0 + self.act_fn( - self.v_bn( - x1 + self.agg_fn(torch.sigmoid(w0) * x2[edge_index[1]], edge_index[0]) - ) + self.v_bn(x1 + self.agg_fn(torch.sigmoid(w0) * x2[edge_index[1]], edge_index[0])) ) # Edge updates @@ -74,7 +72,7 @@ class GNNEncoder(nn.Module): """ def __init__(self, num_layers: int, embed_dim: int, act_fn="silu", agg_fn="mean"): - super(GNNEncoder, self).__init__() + super().__init__() self.act_fn = getattr(nn.functional, act_fn) self.agg_fn = agg_fn diff --git a/rl4co/models/nn/graph/hgnn.py b/rl4co/models/nn/graph/hgnn.py index bd4ce0d2..8963dbd6 100644 --- a/rl4co/models/nn/graph/hgnn.py +++ b/rl4co/models/nn/graph/hgnn.py @@ -24,9 +24,7 @@ def __init__( self.activation = nn.ReLU() self.scale = 1 / math.sqrt(embed_dim) - def forward( - self, self_emb: Tensor, other_emb: Tensor, edge_emb: Tensor, edges: Tensor - ): + def forward(self, self_emb: Tensor, other_emb: Tensor, edge_emb: Tensor, edges: Tensor): bs, n_rows, _ = self_emb.shape # concat operation embeddings and o-m edge features (proc times) diff --git a/rl4co/models/nn/graph/mpnn.py b/rl4co/models/nn/graph/mpnn.py index f6ef519a..b8af255c 100644 --- a/rl4co/models/nn/graph/mpnn.py +++ b/rl4co/models/nn/graph/mpnn.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch import torch.nn as nn @@ -26,7 +24,7 @@ def __init__( residual=False, **mlp_params, ): - super(MessagePassingLayer, self).__init__(aggr=aggregation) + super().__init__(aggr=aggregation) # Init message passing models self.edge_model = MLP( input_dim=edge_indim + 2 * node_indim, output_dim=edge_outdim, **mlp_params @@ -79,7 +77,7 @@ def __init__( Note: - Support fully connected graph for now. """ - super(MessagePassingEncoder, self).__init__() + super().__init__() self.env_name = env_name @@ -114,9 +112,7 @@ def __init__( self.self_loop = self_loop # def forward(self, x, mask=None): - def forward( - self, td: TensorDict, mask: Tensor | None = None - ) -> Tuple[Tensor, Tensor]: + def forward(self, td: TensorDict, mask: Tensor | None = None) -> tuple[Tensor, Tensor]: init_h = self.init_embedding(td) num_node = init_h.size(-2) diff --git a/rl4co/models/nn/mlp.py b/rl4co/models/nn/mlp.py index 295c15a5..92249e4a 100644 --- a/rl4co/models/nn/mlp.py +++ b/rl4co/models/nn/mlp.py @@ -17,7 +17,7 @@ def __init__( input_norm: str = "None", output_norm: str = "None", ): - super(MLP, self).__init__() + super().__init__() assert input_norm in ["Batch", "Layer", "None"] assert output_norm in ["Batch", "Layer", "None"] @@ -69,9 +69,7 @@ def _get_norm_layer(norm_method, dim): elif norm_method == "None": in_norm = nn.Identity() # kinda placeholder else: - raise RuntimeError( - "Not implemented normalization layer type {}".format(norm_method) - ) + raise RuntimeError(f"Not implemented normalization layer type {norm_method}") return in_norm def _get_act(self, is_last): diff --git a/rl4co/models/nn/moe.py b/rl4co/models/nn/moe.py index a2b04584..64cd2f25 100644 --- a/rl4co/models/nn/moe.py +++ b/rl4co/models/nn/moe.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn + from torch.distributions.normal import Normal from rl4co.models.nn.mlp import MLP @@ -11,7 +12,7 @@ """ -class SparseDispatcher(object): +class SparseDispatcher: """ Helper for implementing a mixture of experts. The purpose of this class is to create input minibatches for the experts @@ -96,7 +97,9 @@ def combine(self, expert_out, multiply_by_gates=True): if multiply_by_gates: stitched = stitched.mul(self._nonzero_gates) - zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(-1), requires_grad=True, device=stitched.device) + zeros = torch.zeros( + self._gates.size(0), expert_out[-1].size(-1), requires_grad=True, device=stitched.device + ) # combine samples that have been processed by the same k experts combined = zeros.index_add(0, self._batch_index, stitched.float()) return combined @@ -122,8 +125,19 @@ class MoE(nn.Module): k: an integer - how many experts to use for each batch element """ - def __init__(self, input_size, output_size, num_neurons=[], hidden_act="ReLU", out_bias=True, num_experts=4, k=2, noisy_gating=True, **kwargs): - super(MoE, self).__init__() + def __init__( + self, + input_size, + output_size, + num_neurons=[], + hidden_act="ReLU", + out_bias=True, + num_experts=4, + k=2, + noisy_gating=True, + **kwargs, + ): + super().__init__() self.noisy_gating = noisy_gating self.num_experts = num_experts self.output_size = output_size @@ -132,10 +146,24 @@ def __init__(self, input_size, output_size, num_neurons=[], hidden_act="ReLU", o # instantiate experts if num_neurons != []: - self.experts = nn.ModuleList([MLP(input_dim=input_size, output_dim=output_size, num_neurons=num_neurons, - hidden_act=hidden_act) for _ in range(self.num_experts)]) + self.experts = nn.ModuleList( + [ + MLP( + input_dim=input_size, + output_dim=output_size, + num_neurons=num_neurons, + hidden_act=hidden_act, + ) + for _ in range(self.num_experts) + ] + ) else: - self.experts = nn.ModuleList([nn.Linear(self.input_size, self.output_size, bias=out_bias) for _ in range(self.num_experts)]) + self.experts = nn.ModuleList( + [ + nn.Linear(self.input_size, self.output_size, bias=out_bias) + for _ in range(self.num_experts) + ] + ) self.w_gate = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True) self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True) @@ -143,7 +171,7 @@ def __init__(self, input_size, output_size, num_neurons=[], hidden_act="ReLU", o self.softmax = nn.Softmax(-1) self.register_buffer("mean", torch.tensor([0.0])) self.register_buffer("std", torch.tensor([1.0])) - assert(self.k <= self.num_experts) + assert self.k <= self.num_experts def cv_squared(self, x): """The squared coefficient of variation of a sample. @@ -160,7 +188,7 @@ def cv_squared(self, x): if x.shape[0] == 1: return torch.tensor([0], device=x.device, dtype=x.dtype) - return x.float().var() / (x.float().mean()**2 + eps) + return x.float().var() / (x.float().mean() ** 2 + eps) def _gates_to_load(self, gates): """Compute the true load per expert, given the gates. @@ -194,14 +222,18 @@ def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_val top_values_flat = noisy_top_values.flatten() threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.k - threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1) + threshold_if_in = torch.unsqueeze( + torch.gather(top_values_flat, 0, threshold_positions_if_in), 1 + ) is_in = torch.gt(noisy_values, threshold_if_in) threshold_positions_if_out = threshold_positions_if_in - 1 - threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1) + threshold_if_out = torch.unsqueeze( + torch.gather(top_values_flat, 0, threshold_positions_if_out), 1 + ) # is each value currently in the top k. normal = Normal(self.mean, self.std) - prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev) - prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev) + prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev) + prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev) prob = torch.where(is_in, prob_if_in, prob_if_out) return prob @@ -228,20 +260,22 @@ def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2): # calculate topk + 1 that will be needed for the noisy gates logits = self.softmax(logits) top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=-1) - top_k_logits = top_logits[:, :self.k] - top_k_indices = top_indices[:, :self.k] + top_k_logits = top_logits[:, : self.k] + top_k_indices = top_indices[:, : self.k] top_k_gates = top_k_logits / (top_k_logits.sum(1, keepdim=True) + 1e-6) # normalization zeros = torch.zeros_like(logits, requires_grad=True) gates = zeros.scatter(-1, top_k_indices, top_k_gates) # non-topk elements will be 0 if self.noisy_gating and self.k < self.num_experts and train: - load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0) + load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum( + 0 + ) else: load = self._gates_to_load(gates) return gates, load - def forward(self, x, loss_coef=0.): + def forward(self, x, loss_coef=0.0): """ Token/Node-level Gating with the default gating algorithm in . In specific, each token/node chooses TopK experts, auxiliary losses required for load balancing. diff --git a/rl4co/models/nn/ops.py b/rl4co/models/nn/ops.py index 42ebb559..0dce793c 100644 --- a/rl4co/models/nn/ops.py +++ b/rl4co/models/nn/ops.py @@ -1,7 +1,5 @@ import math -from typing import Tuple - import torch import torch.nn as nn @@ -10,7 +8,7 @@ class SkipConnection(nn.Module): def __init__(self, module): - super(SkipConnection, self).__init__() + super().__init__() self.module = module def forward(self, x): @@ -19,10 +17,10 @@ def forward(self, x): class AdaptiveSequential(nn.Sequential): def forward( - self, *inputs: Tuple[torch.Tensor] | torch.Tensor - ) -> Tuple[torch.Tensor] | torch.Tensor: + self, *inputs: tuple[torch.Tensor] | torch.Tensor + ) -> tuple[torch.Tensor] | torch.Tensor: for module in self._modules.values(): - if type(inputs) == tuple: + if isinstance(inputs, tuple): inputs = module(*inputs) else: inputs = module(inputs) @@ -31,7 +29,7 @@ def forward( class Normalization(nn.Module): def __init__(self, embed_dim, normalization="batch"): - super(Normalization, self).__init__() + super().__init__() if normalization != "layer": normalizer_class = { "batch": nn.BatchNorm1d, @@ -63,9 +61,7 @@ def __init__(self, embed_dim: int, dropout: float = 0.1, max_len: int = 1000): self.d_model = embed_dim max_len = max_len position = torch.arange(max_len).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2) * (-math.log(10000.0) / self.d_model) - ) + div_term = torch.exp(torch.arange(0, self.d_model, 2) * (-math.log(10000.0) / self.d_model)) pe = torch.zeros(max_len, 1, self.d_model) pe[:, 0, 0::2] = torch.sin(position * div_term) pe[:, 0, 1::2] = torch.cos(position * div_term) @@ -126,9 +122,7 @@ def forward(self, hidden: torch.Tensor, classes=None) -> torch.Tensor: b, s, _ = hidden.shape if classes is None: classes = torch.eye(s).unsqueeze(0).expand(b, s) - assert ( - classes.max() < self.max_classes - ), "number of classes larger than embedding table" + assert classes.max() < self.max_classes, "number of classes larger than embedding table" classes = classes.unsqueeze(-1).expand(-1, -1, self.embed_dim) rand_idx = torch.rand(b, self.max_classes).argsort(dim=1) embs_permuted = self.emb[rand_idx] diff --git a/rl4co/models/nn/pos_embeddings.py b/rl4co/models/nn/pos_embeddings.py index 9d217e63..ca50a373 100644 --- a/rl4co/models/nn/pos_embeddings.py +++ b/rl4co/models/nn/pos_embeddings.py @@ -27,7 +27,7 @@ class AbsolutePositionalEmbedding(nn.Module): """Absolute Positional Embedding in the original Transformer.""" def __init__(self, embed_dim): - super(AbsolutePositionalEmbedding, self).__init__() + super().__init__() self.embed_dim = embed_dim self.pattern = None @@ -75,7 +75,7 @@ class CyclicPositionalEmbedding(nn.Module): """ def __init__(self, embed_dim, mean_pooling=True): - super(CyclicPositionalEmbedding, self).__init__() + super().__init__() self.embed_dim = embed_dim self.mean_pooling = mean_pooling self.pattern = None @@ -96,28 +96,16 @@ def _init(self, n_position, emb_dim, mean_pooling): x = np.zeros((n_position, emb_dim)) for i in range(emb_dim): - Td = ( - Td_set[i // 3 * 3 + 1] - if (i // 3 * 3 + 1) < (emb_dim // 2) - else Td_set[-1] - ) - fai = ( - 0 - if i <= (emb_dim // 2) - else 2 * np.pi * ((-i + (emb_dim // 2)) / (emb_dim // 2)) - ) + Td = Td_set[i // 3 * 3 + 1] if (i // 3 * 3 + 1) < (emb_dim // 2) else Td_set[-1] + fai = 0 if i <= (emb_dim // 2) else 2 * np.pi * ((-i + (emb_dim // 2)) / (emb_dim // 2)) longer_pattern = np.arange(0, np.ceil((n_position) / Td) * Td, 0.01) if i % 2 == 1: x[:, i] = self._basecos(longer_pattern, Td, fai)[ - np.linspace( - 0, len(longer_pattern), n_position, dtype="int", endpoint=False - ) + np.linspace(0, len(longer_pattern), n_position, dtype="int", endpoint=False) ] else: x[:, i] = self._basesin(longer_pattern, Td, fai)[ - np.linspace( - 0, len(longer_pattern), n_position, dtype="int", endpoint=False - ) + np.linspace(0, len(longer_pattern), n_position, dtype="int", endpoint=False) ] pattern = torch.from_numpy(x).type(torch.FloatTensor) diff --git a/rl4co/models/rl/a2c/a2c.py b/rl4co/models/rl/a2c/a2c.py index 09b3f980..2ca2733f 100644 --- a/rl4co/models/rl/a2c/a2c.py +++ b/rl4co/models/rl/a2c/a2c.py @@ -35,7 +35,7 @@ def __init__( **kwargs, ): if critic is None: - log.info("Creating critic network for {}".format(env.name)) + log.info(f"Creating critic network for {env.name}") critic = create_critic_from_actor(policy, **critic_kwargs) # The baseline is directly created here, so we eliminate the baseline argument diff --git a/rl4co/models/rl/common/base.py b/rl4co/models/rl/common/base.py index 7315798d..e0a830c5 100644 --- a/rl4co/models/rl/common/base.py +++ b/rl4co/models/rl/common/base.py @@ -1,7 +1,8 @@ import abc +from collections.abc import Iterable from functools import partial -from typing import Any, Iterable +from typing import Any import torch import torch.nn as nn @@ -98,9 +99,7 @@ def __init__( self._optimizer_name_or_cls: str | torch.optim.Optimizer = optimizer self.optimizer_kwargs: dict = optimizer_kwargs - self._lr_scheduler_name_or_cls: str | torch.optim.lr_scheduler.LRScheduler = ( - lr_scheduler - ) + self._lr_scheduler_name_or_cls: str | torch.optim.lr_scheduler.LRScheduler = lr_scheduler self.lr_scheduler_kwargs: dict = lr_scheduler_kwargs self.lr_scheduler_interval: str = lr_scheduler_interval self.lr_scheduler_monitor: str = lr_scheduler_monitor @@ -137,9 +136,7 @@ def setup(self, stage="fit"): self.test_batch_size = self.val_batch_size if test_bs is None else test_bs if self.data_cfg["generate_default_data"]: - log.info( - "Generating default datasets. If found, they will not be overwritten" - ) + log.info("Generating default datasets. If found, they will not be overwritten") generate_default_datasets(data_dir=self.data_cfg["data_dir"]) log.info("Setting up datasets") @@ -147,9 +144,7 @@ def setup(self, stage="fit"): self.env.dataset(self.data_cfg["train_data_size"], phase="train") ) self.val_dataset = self.env.dataset(self.data_cfg["val_data_size"], phase="val") - self.test_dataset = self.env.dataset( - self.data_cfg["test_data_size"], phase="test" - ) + self.test_dataset = self.env.dataset(self.data_cfg["test_data_size"], phase="test") self.dataloader_names = None self.setup_loggers() self.post_setup_hook() @@ -157,9 +152,7 @@ def setup(self, stage="fit"): def setup_loggers(self): """Log all hyperparameters except those in `nn.Module`""" if self.loggers is not None: - hparams_save = { - k: v for k, v in self.hparams.items() if not isinstance(v, nn.Module) - } + hparams_save = {k: v for k, v in self.hparams.items() if not isinstance(v, nn.Module)} for logger in self.loggers: logger.log_hyperparams(hparams_save) logger.log_graph(self) @@ -200,9 +193,7 @@ def configure_optimizers(self, parameters=None): optimizer, self._lr_scheduler_name_or_cls, **self.lr_scheduler_kwargs ) elif isinstance(self._lr_scheduler_name_or_cls, partial): - scheduler = self._lr_scheduler_name_or_cls( - optimizer, **self.lr_scheduler_kwargs - ) + scheduler = self._lr_scheduler_name_or_cls(optimizer, **self.lr_scheduler_kwargs) else: # User-defined scheduler scheduler_cls = self._lr_scheduler_name_or_cls scheduler = scheduler_cls(optimizer, **self.lr_scheduler_kwargs) @@ -213,18 +204,14 @@ def configure_optimizers(self, parameters=None): "monitor": self.lr_scheduler_monitor, } - def log_metrics( - self, metric_dict: dict, phase: str, dataloader_idx: int | None = None - ): + def log_metrics(self, metric_dict: dict, phase: str, dataloader_idx: int | None = None): """Log metrics to logger and progress bar""" metrics = getattr(self, f"{phase}_metrics") dataloader_name = "" if dataloader_idx is not None and self.dataloader_names is not None: dataloader_name = "/" + self.dataloader_names[dataloader_idx] metrics = { - f"{phase}/{k}{dataloader_name}": ( - v.mean() if isinstance(v, torch.Tensor) else v - ) + f"{phase}/{k}{dataloader_name}": (v.mean() if isinstance(v, torch.Tensor) else v) for k, v in metric_dict.items() if k in metrics } @@ -258,14 +245,10 @@ def training_step(self, batch: Any, batch_idx: int): return self.shared_step(batch, batch_idx, phase="train") def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = None): - return self.shared_step( - batch, batch_idx, phase="val", dataloader_idx=dataloader_idx - ) + return self.shared_step(batch, batch_idx, phase="val", dataloader_idx=dataloader_idx) def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = None): - return self.shared_step( - batch, batch_idx, phase="test", dataloader_idx=dataloader_idx - ) + return self.shared_step(batch, batch_idx, phase="test", dataloader_idx=dataloader_idx) def train_dataloader(self): return self._dataloader( @@ -306,18 +289,18 @@ def _dataloader(self, dataset, batch_size, shuffle=False): # if batch size is int, make it into list if isinstance(batch_size, int): batch_size = [batch_size] * len(self.dataloader_names) - assert len(batch_size) == len( - self.dataloader_names - ), f"Batch size must match number of datasets. \ + assert len(batch_size) == len(self.dataloader_names), ( + f"Batch size must match number of datasets. \ Found: {len(batch_size)} and {len(self.dataloader_names)}" + ) return [ self._dataloader_single(dset, bsize, shuffle) for dset, bsize in zip(dataset.values(), batch_size) ] else: - assert isinstance( - batch_size, int - ), f"Batch size must be an integer for a single dataset, found {batch_size}" + assert isinstance(batch_size, int), ( + f"Batch size must be an integer for a single dataset, found {batch_size}" + ) return self._dataloader_single(dataset, batch_size, shuffle) def _dataloader_single(self, dataset, batch_size, shuffle=False): diff --git a/rl4co/models/rl/common/critic.py b/rl4co/models/rl/common/critic.py index 64ec6534..dc062314 100644 --- a/rl4co/models/rl/common/critic.py +++ b/rl4co/models/rl/common/critic.py @@ -1,7 +1,5 @@ import copy -from typing import Optional - from tensordict import TensorDict from torch import Tensor, nn @@ -24,19 +22,20 @@ class CriticNetwork(nn.Module): def __init__( self, encoder: nn.Module, - value_head: Optional[nn.Module] = None, + value_head: nn.Module | None = None, embed_dim: int = 128, hidden_dim: int = 512, customized: bool = False, ): - super(CriticNetwork, self).__init__() + super().__init__() self.encoder = encoder if value_head is None: # check if embed dim of encoder is different, if so, use it if getattr(encoder, "embed_dim", embed_dim) != embed_dim: log.warning( - f"Found encoder with different embed_dim {encoder.embed_dim} than the value head {embed_dim}. Using encoder embed_dim for value head." + f"Found encoder with different embed_dim {encoder.embed_dim} than the value head {embed_dim}. \ + Using encoder embed_dim for value head." ) embed_dim = getattr(encoder, "embed_dim", embed_dim) value_head = nn.Sequential( @@ -62,15 +61,11 @@ def forward(self, x: Tensor | TensorDict, hidden=None) -> Tensor: return self.value_head(h, hidden) -def create_critic_from_actor( - policy: nn.Module, backbone: str = "encoder", **critic_kwargs -): +def create_critic_from_actor(policy: nn.Module, backbone: str = "encoder", **critic_kwargs): # we reuse the network of the policy's backbone, such as an encoder encoder = getattr(policy, backbone, None) if encoder is None: - raise ValueError( - f"CriticBaseline requires a backbone in the policy network: {backbone}" - ) + raise ValueError(f"CriticBaseline requires a backbone in the policy network: {backbone}") critic = CriticNetwork(copy.deepcopy(encoder), **critic_kwargs).to( next(policy.parameters()).device ) diff --git a/rl4co/models/rl/common/utils.py b/rl4co/models/rl/common/utils.py index 6c16976a..4e6dd2bd 100644 --- a/rl4co/models/rl/common/utils.py +++ b/rl4co/models/rl/common/utils.py @@ -32,7 +32,7 @@ def __call__(self, scores: torch.Tensor): elif self.scale == "scale": scores /= score_scaling_factor else: - raise ValueError("unknown scaling operation requested: %s" % self.scale) + raise ValueError(f"unknown scaling operation requested: {self.scale}") return scores @torch.no_grad() diff --git a/rl4co/models/rl/ppo/n_step_ppo.py b/rl4co/models/rl/ppo/n_step_ppo.py index 57d96815..711a45c2 100644 --- a/rl4co/models/rl/ppo/n_step_ppo.py +++ b/rl4co/models/rl/ppo/n_step_ppo.py @@ -96,9 +96,9 @@ def __init__( } def configure_optimizers(self): - parameters = [ - {"params": self.policy.parameters(), "lr": self.ppo_cfg["lr_policy"]} - ] + [{"params": self.critic.parameters(), "lr": self.ppo_cfg["lr_critic"]}] + parameters = [{"params": self.policy.parameters(), "lr": self.ppo_cfg["lr_policy"]}] + [ + {"params": self.critic.parameters(), "lr": self.ppo_cfg["lr_critic"]} + ] return super().configure_optimizers(parameters) @@ -113,9 +113,7 @@ def on_train_epoch_end(self): # CL scheduler self.CL_num += 1 / self.CL_scalar - def shared_step( - self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None - ): + def shared_step(self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None): if phase != "train": with torch.no_grad(): td = self.env.reset(batch) @@ -140,9 +138,9 @@ def shared_step( cost_init = td["cost_current"] # perform gradiant updates every n_step untill reaching T_max - assert ( - self.ppo_cfg["T_train"] % self.ppo_cfg["n_step"] == 0 - ), "T_max should be divided by n_step with no remainder" + assert self.ppo_cfg["T_train"] % self.ppo_cfg["n_step"] == 0, ( + "T_max should be divided by n_step with no remainder" + ) t = 0 while t < self.ppo_cfg["T_train"]: memory.clear_memory() @@ -153,9 +151,7 @@ def shared_step( memory.tds.append(td.clone()) out = self.policy(td, self.env, phase=phase, return_embeds=True) - value_pred = self.critic( - out["embeds"].detach(), td["cost_bsf"].unsqueeze(-1) - ) + value_pred = self.critic(out["embeds"].detach(), td["cost_bsf"].unsqueeze(-1)) memory.actions.append(out["actions"].clone()) memory.logprobs.append(out["log_likelihood"].clone()) diff --git a/rl4co/models/rl/ppo/ppo.py b/rl4co/models/rl/ppo/ppo.py index eea2ed20..8b9397ec 100644 --- a/rl4co/models/rl/ppo/ppo.py +++ b/rl4co/models/rl/ppo/ppo.py @@ -25,8 +25,8 @@ class PPO(RL4COLitModule): choice for tractable CO solution generation. This choice aligns with the Attention Model (AM) (https://openreview.net/forum?id=ByxBFsRqYm), which treats decoding steps as a single-step MDP in Equation 9. - Modeling autoregressive decoding steps as a single-step MDP introduces significant changes to the PPO implementation, - including: + Modeling autoregressive decoding steps as a single-step MDP introduces significant changes to the PPO + implementation, including: - Generalized Advantage Estimation (GAE) (https://arxiv.org/abs/1506.02438) is not applicable since we are dealing with a single-step MDP. - The definition of policy entropy can differ from the commonly implemented manner. @@ -80,13 +80,11 @@ def __init__( self.automatic_optimization = False # PPO uses custom optimization routine if critic is None: - log.info("Creating critic network for {}".format(env.name)) + log.info(f"Creating critic network for {env.name}") critic = create_critic_from_actor(policy, **critic_kwargs) self.critic = critic - if isinstance(mini_batch_size, float) and ( - mini_batch_size <= 0 or mini_batch_size > 1 - ): + if isinstance(mini_batch_size, float) and (mini_batch_size <= 0 or mini_batch_size > 1): default_mini_batch_fraction = 0.25 log.warning( f"mini_batch_size must be an integer or a float in the range (0, 1], got {mini_batch_size}. Setting mini_batch_size to {default_mini_batch_fraction}." @@ -125,9 +123,7 @@ def on_train_epoch_end(self): if isinstance(sch, torch.optim.lr_scheduler.MultiStepLR): sch.step() - def shared_step( - self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None - ): + def shared_step(self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None): # Evaluate old actions, log probabilities, and rewards with torch.no_grad(): td = self.env.reset(batch) # note: clone needed for dataloader @@ -175,9 +171,7 @@ def shared_step( ll, entropy = out["log_likelihood"], out["entropy"] # Compute the ratio of probabilities of new and old actions - ratio = torch.exp(ll.sum(dim=-1) - sub_td["logprobs"]).view( - -1, 1 - ) # [batch, 1] + ratio = torch.exp(ll.sum(dim=-1) - sub_td["logprobs"]).view(-1, 1) # [batch, 1] # Compute the advantage value_pred = self.critic(sub_td) # [batch, 1] diff --git a/rl4co/models/rl/ppo/stepwise_ppo.py b/rl4co/models/rl/ppo/stepwise_ppo.py index 69387547..af23827f 100644 --- a/rl4co/models/rl/ppo/stepwise_ppo.py +++ b/rl4co/models/rl/ppo/stepwise_ppo.py @@ -138,9 +138,7 @@ def update(self, device): outs = {k: torch.stack([dic[k] for dic in outs], dim=0) for k in outs[0]} return outs - def shared_step( - self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None - ): + def shared_step(self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None): next_td = self.env.reset(batch) device = next_td.device if phase == "train": @@ -163,9 +161,7 @@ def shared_step( self.rb.empty() else: - out = self.policy.generate( - next_td, self.env, phase=phase, select_best=phase != "train" - ) + out = self.policy.generate(next_td, self.env, phase=phase, select_best=phase != "train") metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx) return {"loss": out.get("loss", None), **metrics} diff --git a/rl4co/models/rl/reinforce/baselines.py b/rl4co/models/rl/reinforce/baselines.py index e59b1969..1e9aa2a7 100644 --- a/rl4co/models/rl/reinforce/baselines.py +++ b/rl4co/models/rl/reinforce/baselines.py @@ -28,9 +28,7 @@ def wrap_dataset(self, dataset: Dataset, *args, **kw): return dataset @abc.abstractmethod - def eval( - self, td: TensorDict, reward: torch.Tensor, env: RL4COEnvBase = None, **kwargs - ): + def eval(self, td: TensorDict, reward: torch.Tensor, env: RL4COEnvBase = None, **kwargs): """Evaluate baseline""" raise NotImplementedError @@ -134,7 +132,7 @@ def epoch_callback(self, *args, **kw): self.baseline.epoch_callback(*args, **kw) if kw["epoch"] < self.n_epochs: self.alpha = (kw["epoch"] + 1) / float(self.n_epochs) - log.info("Set warmup alpha = {}".format(self.alpha)) + log.info(f"Set warmup alpha = {self.alpha}") class CriticBaseline(REINFORCEBaseline): @@ -145,12 +143,12 @@ class CriticBaseline(REINFORCEBaseline): """ def __init__(self, critic: CriticNetwork = None, **unused_kw): - super(CriticBaseline, self).__init__() + super().__init__() self.critic = critic def setup(self, policy, env, **kwargs): if self.critic is None: - log.info("Critic not found. Creating critic network for {}".format(env.name)) + log.info(f"Critic not found. Creating critic network for {env.name}") self.critic = create_critic_from_actor(policy) def eval(self, x, c, env=None): @@ -167,7 +165,7 @@ class RolloutBaseline(REINFORCEBaseline): """ def __init__(self, bl_alpha=0.05, **kw): - super(RolloutBaseline, self).__init__() + super().__init__() self.bl_alpha = bl_alpha def setup(self, *args, **kw): @@ -207,18 +205,14 @@ def epoch_callback( candidate_vals = self.rollout(policy, env, batch_size, device).cpu().numpy() candidate_mean = candidate_vals.mean() - log.info( - "Candidate mean: {:.3f}, Baseline mean: {:.3f}".format( - candidate_mean, self.mean - ) - ) + log.info(f"Candidate mean: {candidate_mean:.3f}, Baseline mean: {self.mean:.3f}") if candidate_mean - self.mean > 0: # Calc p value with inverse logic (costs) t, p = ttest_rel(-candidate_vals, -self.bl_vals) p_val = p / 2 # one-sided assert t < 0, "T-statistic should be negative" - log.info("p-value: {:.3f}".format(p_val)) + log.info(f"p-value: {p_val:.3f}") if p_val < self.bl_alpha: log.info("Updating baseline") self._update_policy(policy, env, batch_size, device, dataset_size) @@ -250,11 +244,7 @@ def wrap_dataset(self, dataset, env, batch_size=64, device="cpu", **kw): at every call but just once. Values are added to the dataset. This also allows for larger batch sizes since we evauate the policy without gradients. """ - rewards = ( - self.rollout(self.policy, env, batch_size, device, dataset=dataset) - .detach() - .cpu() - ) + rewards = self.rollout(self.policy, env, batch_size, device, dataset=dataset).detach().cpu() return dataset.add_key("extra", rewards) def __getstate__(self): @@ -297,9 +287,7 @@ def get_reinforce_baseline(name, **kw): warmup_epochs = kw.get("n_epochs", 1) warmup_exp_beta = kw.get("exp_beta", 0.8) bl_alpha = kw.get("bl_alpha", 0.05) - return WarmupBaseline( - RolloutBaseline(bl_alpha=bl_alpha), warmup_epochs, warmup_exp_beta - ) + return WarmupBaseline(RolloutBaseline(bl_alpha=bl_alpha), warmup_epochs, warmup_exp_beta) if name is None: name = "no" # default to no baseline diff --git a/rl4co/models/rl/reinforce/reinforce.py b/rl4co/models/rl/reinforce/reinforce.py index 81f96121..db49fd5e 100644 --- a/rl4co/models/rl/reinforce/reinforce.py +++ b/rl4co/models/rl/reinforce/reinforce.py @@ -1,4 +1,4 @@ -from typing import IO, Any, Optional, cast +from typing import IO, Any, cast import torch import torch.nn as nn @@ -56,9 +56,7 @@ def __init__( self.baseline = baseline self.advantage_scaler = RewardScaler(reward_scale) - def shared_step( - self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None - ): + def shared_step(self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None): td = self.env.reset(batch) # Perform forward pass (i.e., constructing solution and computing log-likelihoods) out = self.policy(td, self.env, phase=phase, select_best=phase != "train") @@ -75,8 +73,8 @@ def calculate_loss( td: TensorDict, batch: TensorDict, policy_out: dict, - reward: Optional[torch.Tensor] = None, - log_likelihood: Optional[torch.Tensor] = None, + reward: torch.Tensor | None = None, + log_likelihood: torch.Tensor | None = None, ): """Calculate loss for REINFORCE algorithm. @@ -95,9 +93,7 @@ def calculate_loss( ) # REINFORCE baseline - bl_val, bl_loss = ( - self.baseline.eval(td, reward, self.env) if extra is None else (extra, 0) - ) + bl_val, bl_loss = self.baseline.eval(td, reward, self.env) if extra is None else (extra, 0) # Main loss function advantage = reward - bl_val # advantage = reward - baseline @@ -169,7 +165,7 @@ def load_from_checkpoint( cls, checkpoint_path: _PATH | IO, map_location: _MAP_LOCATION_TYPE = None, - hparams_file: Optional[_PATH] = None, + hparams_file: _PATH | None = None, strict: bool = False, load_baseline: bool = True, **kwargs: Any, @@ -201,9 +197,9 @@ def load_from_checkpoint( loaded.setup() loaded.post_setup_hook() # load baseline state dict - state_dict = torch.load( - checkpoint_path, map_location=map_location, weights_only=False - )["state_dict"] + state_dict = torch.load(checkpoint_path, map_location=map_location, weights_only=False)[ + "state_dict" + ] # get only baseline parameters state_dict = {k: v for k, v in state_dict.items() if "baseline" in k} state_dict = {k.replace("baseline.", "", 1): v for k, v in state_dict.items()} diff --git a/rl4co/models/zoo/__init__.py b/rl4co/models/zoo/__init__.py index 8cfc2da1..81b9cac2 100644 --- a/rl4co/models/zoo/__init__.py +++ b/rl4co/models/zoo/__init__.py @@ -8,17 +8,8 @@ from rl4co.models.zoo.deepaco import DeepACO, DeepACOPolicy from rl4co.models.zoo.eas import EAS, EASEmb, EASLay from rl4co.models.zoo.glop import GLOP, GLOPPolicy -from rl4co.models.zoo.ham import ( - HeterogeneousAttentionModel, - HeterogeneousAttentionModelPolicy, -) -from rl4co.models.zoo.l2d import ( - L2DAttnPolicy, - L2DModel, - L2DPolicy, - L2DPolicy4PPO, - L2DPPOModel, -) +from rl4co.models.zoo.ham import HeterogeneousAttentionModel, HeterogeneousAttentionModelPolicy +from rl4co.models.zoo.l2d import L2DAttnPolicy, L2DModel, L2DPolicy, L2DPolicy4PPO, L2DPPOModel from rl4co.models.zoo.matnet import MatNet, MatNetPolicy from rl4co.models.zoo.mdam import MDAM, MDAMPolicy from rl4co.models.zoo.mvmoe import MVMoE_AM, MVMoE_POMO diff --git a/rl4co/models/zoo/active_search/search.py b/rl4co/models/zoo/active_search/search.py index d92b82b2..332bd270 100644 --- a/rl4co/models/zoo/active_search/search.py +++ b/rl4co/models/zoo/active_search/search.py @@ -57,7 +57,7 @@ def __init__( assert batch_size == 1, "Batch size must be 1 for active search" - super(ActiveSearch, self).__init__( + super().__init__( env, policy=policy, dataset=dataset, @@ -77,7 +77,7 @@ def setup(self, stage="fit"): - original policy state dict """ log.info("Setting up active search...") - super(ActiveSearch, self).setup(stage) + super().setup(stage) # Instantiate augmentation self.augmentation = StateAugmentation( @@ -92,9 +92,7 @@ def setup(self, stage="fit"): dataset_size = len(self.dataset) _batch = next(iter(self.train_dataloader())) self.problem_size = self.env.reset(_batch)["action_mask"].shape[-1] - self.instance_solutions = torch.zeros( - dataset_size, self.problem_size * 2, dtype=int - ) + self.instance_solutions = torch.zeros(dataset_size, self.problem_size * 2, dtype=int) self.instance_rewards = torch.zeros(dataset_size) def on_train_batch_start(self, batch: Any, batch_idx: int): @@ -174,15 +172,11 @@ def training_step(self, batch, batch_idx): return {"max_reward": max_reward, "best_solutions": best_solutions} - def on_train_batch_end( - self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int - ) -> None: + def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None: """We store the best solution and reward found.""" max_rewards, best_solutions = outputs["max_reward"], outputs["best_solutions"] self.instance_rewards[batch_idx] = max_rewards - self.instance_solutions[batch_idx, :] = best_solutions.squeeze( - 0 - ) # only one instance + self.instance_solutions[batch_idx, :] = best_solutions.squeeze(0) # only one instance log.info(f"Best reward: {max_rewards.mean():.2f}") def on_train_epoch_end(self) -> None: diff --git a/rl4co/models/zoo/am/decoder.py b/rl4co/models/zoo/am/decoder.py index f9809610..2d26389f 100644 --- a/rl4co/models/zoo/am/decoder.py +++ b/rl4co/models/zoo/am/decoder.py @@ -1,5 +1,4 @@ from dataclasses import dataclass, fields -from typing import Tuple import torch import torch.nn as nn @@ -108,9 +107,7 @@ def __init__( if pointer is None: # MHA with Pointer mechanism (https://arxiv.org/abs/1506.03134) - pointer_attn_class = ( - PointerAttention if moe_kwargs is None else PointerAttnMoE - ) + pointer_attn_class = PointerAttention if moe_kwargs is None else PointerAttnMoE pointer = pointer_attn_class( embed_dim, num_heads, @@ -124,9 +121,7 @@ def __init__( self.pointer = pointer # For each node we compute (glimpse key, glimpse value, logit key) so 3 * embed_dim - self.project_node_embeddings = nn.Linear( - embed_dim, 3 * embed_dim, bias=linear_bias - ) + self.project_node_embeddings = nn.Linear(embed_dim, 3 * embed_dim, bias=linear_bias) self.project_fixed_context = nn.Linear(embed_dim, embed_dim, bias=linear_bias) self.use_graph_context = use_graph_context @@ -163,7 +158,7 @@ def forward( td: TensorDict, cached: PrecomputedCache, num_starts: int = 0, - ) -> Tuple[Tensor, Tensor]: + ) -> tuple[Tensor, Tensor]: """Compute the logits of the next actions given the current state Args: @@ -199,13 +194,11 @@ def forward( def pre_decoder_hook( self, td, env, embeddings, num_starts: int = 0 - ) -> Tuple[TensorDict, RL4COEnvBase, PrecomputedCache]: + ) -> tuple[TensorDict, RL4COEnvBase, PrecomputedCache]: """Precompute the embeddings cache before the decoder is called""" return td, env, self._precompute_cache(embeddings, num_starts=num_starts) - def _precompute_cache( - self, embeddings: torch.Tensor, num_starts: int = 0 - ) -> PrecomputedCache: + def _precompute_cache(self, embeddings: torch.Tensor, num_starts: int = 0) -> PrecomputedCache: """Compute the cached embeddings for the pointer attention. Args: diff --git a/rl4co/models/zoo/am/encoder.py b/rl4co/models/zoo/am/encoder.py index c5303bba..b2e48711 100644 --- a/rl4co/models/zoo/am/encoder.py +++ b/rl4co/models/zoo/am/encoder.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch.nn as nn from tensordict import TensorDict @@ -41,7 +39,7 @@ def __init__( sdpa_fn=None, moe_kwargs: dict = None, ): - super(AttentionModelEncoder, self).__init__() + super().__init__() if isinstance(env_name, RL4COEnvBase): env_name = env_name.name @@ -67,9 +65,7 @@ def __init__( else net ) - def forward( - self, td: TensorDict, mask: Tensor | None = None - ) -> Tuple[Tensor, Tensor]: + def forward(self, td: TensorDict, mask: Tensor | None = None) -> tuple[Tensor, Tensor]: """Forward pass of the encoder. Transform the input TensorDict into a latent representation. diff --git a/rl4co/models/zoo/am/policy.py b/rl4co/models/zoo/am/policy.py index d650b72a..804d1a9b 100644 --- a/rl4co/models/zoo/am/policy.py +++ b/rl4co/models/zoo/am/policy.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import torch.nn as nn @@ -108,7 +108,7 @@ def __init__( moe_kwargs=moe_kwargs["decoder"], ) - super(AttentionModelPolicy, self).__init__( + super().__init__( encoder=encoder, decoder=decoder, env_name=env_name, diff --git a/rl4co/models/zoo/amppo/model.py b/rl4co/models/zoo/amppo/model.py index 17a55257..ba600413 100644 --- a/rl4co/models/zoo/amppo/model.py +++ b/rl4co/models/zoo/amppo/model.py @@ -36,7 +36,7 @@ def __init__( policy = AttentionModelPolicy(env_name=env.name, **policy_kwargs) if critic is None: - log.info("Creating critic network for {}".format(env.name)) + log.info(f"Creating critic network for {env.name}") # we reuse the parameters of the model encoder = getattr(policy, "encoder", None) if encoder is None: diff --git a/rl4co/models/zoo/dact/decoder.py b/rl4co/models/zoo/dact/decoder.py index 81a684ad..ae650583 100644 --- a/rl4co/models/zoo/dact/decoder.py +++ b/rl4co/models/zoo/dact/decoder.py @@ -37,14 +37,10 @@ def __init__( self.hidden_dim = embed_dim # for MHC sublayer (NFE aspect) - self.compater_node = MultiHeadCompat( - num_heads, embed_dim, embed_dim, embed_dim, embed_dim - ) + self.compater_node = MultiHeadCompat(num_heads, embed_dim, embed_dim, embed_dim, embed_dim) # for MHC sublayer (PFE aspect) - self.compater_pos = MultiHeadCompat( - num_heads, embed_dim, embed_dim, embed_dim, embed_dim - ) + self.compater_pos = MultiHeadCompat(num_heads, embed_dim, embed_dim, embed_dim, embed_dim) self.norm_factor = 1 / math.sqrt(1 * self.hidden_dim) @@ -77,9 +73,9 @@ def forward(self, td: TensorDict, final_h: Tensor, final_p: Tensor) -> Tensor: h_node_refined = self.project_node_node(final_h) + self.project_graph_node( final_h.max(1)[0] )[:, None, :].expand(batch_size, graph_size, dim) - h_pos_refined = self.project_node_pos(final_p) + self.project_graph_pos( - final_p.max(1)[0] - )[:, None, :].expand(batch_size, graph_size, dim) + h_pos_refined = self.project_node_pos(final_p) + self.project_graph_pos(final_p.max(1)[0])[ + :, None, : + ].expand(batch_size, graph_size, dim) # MHC sublayer compatibility = torch.zeros( @@ -89,9 +85,9 @@ def forward(self, td: TensorDict, final_h: Tensor, final_p: Tensor) -> Tensor: compatibility[:, :, :, : self.n_heads] = self.compater_pos(h_pos_refined).permute( 1, 2, 3, 0 ) - compatibility[:, :, :, self.n_heads :] = self.compater_node( - h_node_refined - ).permute(1, 2, 3, 0) + compatibility[:, :, :, self.n_heads :] = self.compater_node(h_node_refined).permute( + 1, 2, 3, 0 + ) # FFA sublater return self.value_head(self.norm_factor * compatibility).squeeze(-1) @@ -118,9 +114,7 @@ def forward(self, x: torch.Tensor, hidden=None) -> torch.Tensor: graph_feature: torch.Tensor = self.project_graph(mean_pooling)[ :, None, : ] # (batch_size, 1, input_dim/2) - node_feature: torch.Tensor = self.project_node( - x - ) # (batch_size, graph_size+1, input_dim/2) + node_feature: torch.Tensor = self.project_node(x) # (batch_size, graph_size+1, input_dim/2) # pass through value_head, get estimated value fusion = node_feature + graph_feature.expand_as( diff --git a/rl4co/models/zoo/dact/encoder.py b/rl4co/models/zoo/dact/encoder.py index 0e263de0..61e8b6c4 100644 --- a/rl4co/models/zoo/dact/encoder.py +++ b/rl4co/models/zoo/dact/encoder.py @@ -1,7 +1,5 @@ import math -from typing import Tuple - import torch import torch.nn as nn import torch.nn.functional as F @@ -18,7 +16,7 @@ # implements the Multi-head DAC-Att module class DAC_ATT(nn.Module): def __init__(self, n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=None): - super(DAC_ATT, self).__init__() + super().__init__() self.n_heads = n_heads @@ -29,37 +27,23 @@ def __init__(self, n_heads, input_dim, embed_dim=None, val_dim=None, key_dim=Non self.norm_factor = 1 / math.sqrt(1 * self.key_dim) # W_h^Q in the paper - self.W_query_node = nn.Parameter( - torch.Tensor(n_heads, self.input_dim, self.key_dim) - ) + self.W_query_node = nn.Parameter(torch.Tensor(n_heads, self.input_dim, self.key_dim)) # W_g^Q in the paper - self.W_query_pos = nn.Parameter( - torch.Tensor(n_heads, self.input_dim, self.key_dim) - ) + self.W_query_pos = nn.Parameter(torch.Tensor(n_heads, self.input_dim, self.key_dim)) # W_h^K in the paper - self.W_key_node = nn.Parameter( - torch.Tensor(n_heads, self.input_dim, self.key_dim) - ) + self.W_key_node = nn.Parameter(torch.Tensor(n_heads, self.input_dim, self.key_dim)) # W_g^K in the paper self.W_key_pos = nn.Parameter(torch.Tensor(n_heads, self.input_dim, self.key_dim)) # W_h^V and W_h^Vref in the paper - self.W_val_node = nn.Parameter( - torch.Tensor(2 * n_heads, self.input_dim, self.val_dim) - ) + self.W_val_node = nn.Parameter(torch.Tensor(2 * n_heads, self.input_dim, self.val_dim)) # W_g^V and W_g^Vref in the paper - self.W_val_pos = nn.Parameter( - torch.Tensor(2 * n_heads, self.input_dim, self.val_dim) - ) + self.W_val_pos = nn.Parameter(torch.Tensor(2 * n_heads, self.input_dim, self.val_dim)) # W_h^O and W_g^O in the paper if embed_dim is not None: - self.W_out_node = nn.Parameter( - torch.Tensor(n_heads, 2 * self.key_dim, embed_dim) - ) - self.W_out_pos = nn.Parameter( - torch.Tensor(n_heads, 2 * self.key_dim, embed_dim) - ) + self.W_out_node = nn.Parameter(torch.Tensor(n_heads, 2 * self.key_dim, embed_dim)) + self.W_out_pos = nn.Parameter(torch.Tensor(n_heads, 2 * self.key_dim, embed_dim)) self.init_parameters() @@ -88,9 +72,7 @@ def forward(self, h_node_in, h_pos_in): # input (NFEs, PFEs) V_pos = torch.matmul(h_pos, self.W_val_pos).view(shp_v) # Get attention correlations and norm by softmax - node_correlations = self.norm_factor * torch.matmul( - Q_node, K_node.transpose(2, 3) - ) + node_correlations = self.norm_factor * torch.matmul(Q_node, K_node.transpose(2, 3)) pos_correlations = self.norm_factor * torch.matmul(Q_pos, K_pos.transpose(2, 3)) attn1 = F.softmax(node_correlations, dim=-1) # head, bs, n, n attn2 = F.softmax(pos_correlations, dim=-1) # head, bs, n, n @@ -106,16 +88,12 @@ def forward(self, h_node_in, h_pos_in): # input (NFEs, PFEs) # get output out_node = torch.mm( - heads_node.permute(1, 2, 0, 3) - .contiguous() - .view(-1, self.n_heads * 2 * self.val_dim), + heads_node.permute(1, 2, 0, 3).contiguous().view(-1, self.n_heads * 2 * self.val_dim), self.W_out_node.view(-1, self.embed_dim), ).view(batch_size, graph_size, self.embed_dim) out_pos = torch.mm( - heads_pos.permute(1, 2, 0, 3) - .contiguous() - .view(-1, self.n_heads * 2 * self.val_dim), + heads_pos.permute(1, 2, 0, 3).contiguous().view(-1, self.n_heads * 2 * self.val_dim), self.W_out_pos.view(-1, self.embed_dim), ).view(batch_size, graph_size, self.embed_dim) @@ -131,7 +109,7 @@ def __init__( feed_forward_hidden, normalization="layer", ): - super(DACTEncoderLayer, self).__init__() + super().__init__() self.MHA_sublayer = DACsubLayer( n_heads, @@ -161,7 +139,7 @@ def __init__( feed_forward_hidden, normalization="layer", ): - super(DACsubLayer, self).__init__() + super().__init__() self.MHA = DAC_ATT(n_heads, input_dim=embed_dim, embed_dim=embed_dim) @@ -184,7 +162,7 @@ def __init__( feed_forward_hidden, normalization="layer", ): - super(FFNsubLayer, self).__init__() + super().__init__() self.FF1 = ( nn.Sequential( @@ -244,7 +222,7 @@ def __init__( normalization: str = "layer", feedforward_hidden: int = 64, ): - super(DACTEncoder, self).__init__( + super().__init__( embed_dim=embed_dim, env_name=env_name, pos_type=pos_type, @@ -268,7 +246,7 @@ def __init__( ) ) - def _encoder_forward(self, init_h: Tensor, init_p: Tensor) -> Tuple[Tensor, Tensor]: + def _encoder_forward(self, init_h: Tensor, init_p: Tensor) -> tuple[Tensor, Tensor]: NFE, PFE = self.net(init_h, init_p) return NFE, PFE diff --git a/rl4co/models/zoo/dact/model.py b/rl4co/models/zoo/dact/model.py index 34bf9c5e..4aa80210 100644 --- a/rl4co/models/zoo/dact/model.py +++ b/rl4co/models/zoo/dact/model.py @@ -46,9 +46,7 @@ def __init__( critic_kwargs["feedforward_hidden"] * 2 if "feedforward_hidden" in critic_kwargs else 128, - critic_kwargs["normalization"] - if "normalization" in critic_kwargs - else "layer", + critic_kwargs["normalization"] if "normalization" in critic_kwargs else "layer", bias=False, ) value_head = CriticDecoder(embed_dim) diff --git a/rl4co/models/zoo/dact/policy.py b/rl4co/models/zoo/dact/policy.py index 34489c6f..aea923c6 100644 --- a/rl4co/models/zoo/dact/policy.py +++ b/rl4co/models/zoo/dact/policy.py @@ -53,7 +53,7 @@ def __init__( val_decode_type: str = "sampling", test_decode_type: str = "sampling", ): - super(DACTPolicy, self).__init__() + super().__init__() self.env_name = env_name @@ -153,11 +153,7 @@ def forward( logprob, action_sampled = decode_strategy.step( logits, mask, - action=( - actions[:, 0] * seq_length + actions[:, 1] - if actions is not None - else None - ), + action=(actions[:, 0] * seq_length + actions[:, 1] if actions is not None else None), ) action_sampled = action_sampled.unsqueeze(-1) if phase == "train": diff --git a/rl4co/models/zoo/deepaco/antsystem.py b/rl4co/models/zoo/deepaco/antsystem.py index f31d6e1a..102189b2 100644 --- a/rl4co/models/zoo/deepaco/antsystem.py +++ b/rl4co/models/zoo/deepaco/antsystem.py @@ -1,5 +1,4 @@ -from functools import lru_cache, cached_property -from typing import Optional, Tuple +from functools import cached_property, lru_cache import torch @@ -8,9 +7,7 @@ from tqdm import trange from rl4co.envs import RL4COEnvBase -from rl4co.models.common.constructive.nonautoregressive.decoder import ( - NonAutoregressiveDecoder, -) +from rl4co.models.common.constructive.nonautoregressive.decoder import NonAutoregressiveDecoder from rl4co.utils.decoding import Sampling from rl4co.utils.ops import batchify, get_distance_matrix, unbatchify @@ -40,8 +37,8 @@ def __init__( alpha: float = 1.0, beta: float = 1.0, decay: float = 0.95, - Q: Optional[float] = None, - pheromone: Optional[Tensor | int] = None, + Q: float | None = None, + pheromone: Tensor | int | None = None, use_local_search: bool = False, use_nls: bool = False, n_perturbations: int = 1, @@ -71,7 +68,9 @@ def __init__( assert not (use_nls and not use_local_search), "use_nls requires use_local_search" self.use_nls = use_nls self.n_perturbations = n_perturbations - self.local_search_params = local_search_params.copy() # needs to be copied to avoid side effects + self.local_search_params = ( + local_search_params.copy() + ) # needs to be copied to avoid side effects self.perturbation_params = perturbation_params.copy() self._batchindex = torch.arange(self.batch_size, device=log_heuristic.device) @@ -92,7 +91,7 @@ def run( n_iterations: int, decoding_kwargs: dict, disable_tqdm: bool = True, - ) -> Tuple[Tensor, dict[int, Tensor]]: + ) -> tuple[Tensor, dict[int, Tensor]]: """Run the Ant System algorithm for a specified number of iterations. Args: @@ -107,9 +106,7 @@ def run( actions: The final actions chosen by the algorithm. reward: The final reward achieved by the algorithm. """ - pbar = trange( - n_iterations, dynamic_ncols=True, desc="Running ACO", disable=disable_tqdm - ) + pbar = trange(n_iterations, dynamic_ncols=True, desc="Running ACO", disable=disable_tqdm) for i in pbar: # reset environment td = td_initial.clone() @@ -157,9 +154,7 @@ def _sampling( ): # Sample from heatmaps # p = phe**alpha * heu**beta <==> log(p) = alpha*log(phe) + beta*log(heu) - heatmaps_logits = ( - self.alpha * torch.log(self.pheromone) + self.beta * self.log_heuristic - ) + heatmaps_logits = self.alpha * torch.log(self.pheromone) + self.beta * self.log_heuristic decode_strategy = Sampling(**decoding_kwargs) td, env, num_starts = decode_strategy.pre_decoder_hook(td, env) @@ -177,7 +172,7 @@ def _sampling( def local_search( self, td: TensorDict, env: RL4COEnvBase, actions: Tensor, decoding_kwargs: dict - ) -> Tuple[Tensor, Tensor]: + ) -> tuple[Tensor, Tensor]: """Perform local search on the actions and reward obtained. Args: diff --git a/rl4co/models/zoo/deepaco/model.py b/rl4co/models/zoo/deepaco/model.py index af2bd49e..a43fc23c 100644 --- a/rl4co/models/zoo/deepaco/model.py +++ b/rl4co/models/zoo/deepaco/model.py @@ -1,8 +1,9 @@ -from typing import Any, Optional, Union +from typing import Any -from tensordict import TensorDict import torch +from tensordict import TensorDict + from rl4co.envs.common.base import RL4COEnvBase from rl4co.models.rl import REINFORCE from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline @@ -26,8 +27,8 @@ class DeepACO(REINFORCE): def __init__( self, env: RL4COEnvBase, - policy: Optional[DeepACOPolicy] = None, - baseline: Union[REINFORCEBaseline, str] = "no", # Shared baseline is manually implemented + policy: DeepACOPolicy | None = None, + baseline: REINFORCEBaseline | str = "no", # Shared baseline is manually implemented train_with_local_search: bool = True, ls_reward_aug_W: float = 0.95, policy_kwargs: dict = {}, @@ -45,7 +46,7 @@ def __init__( self.ls_reward_aug_W = ls_reward_aug_W def shared_step( - self, batch: Any, batch_idx: int, phase: str, dataloader_idx: Optional[int] = None + self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int | None = None ): td = self.env.reset(batch) # Perform forward pass (i.e., constructing solution and computing log-likelihoods) @@ -63,8 +64,8 @@ def calculate_loss( td: TensorDict, batch: TensorDict, policy_out: dict, - reward: Optional[torch.Tensor] = None, - log_likelihood: Optional[torch.Tensor] = None, + reward: torch.Tensor | None = None, + log_likelihood: torch.Tensor | None = None, ): """Calculate loss for REINFORCE algorithm. @@ -81,7 +82,9 @@ def calculate_loss( if self.train_with_local_search: ls_reward = policy_out["ls_reward"] ls_advantage = ls_reward - ls_reward.mean(dim=1, keepdim=True) # Shared baseline - weighted_advantage = advantage * (1 - self.ls_reward_aug_W) + ls_advantage * self.ls_reward_aug_W + weighted_advantage = ( + advantage * (1 - self.ls_reward_aug_W) + ls_advantage * self.ls_reward_aug_W + ) else: weighted_advantage = advantage diff --git a/rl4co/models/zoo/deepaco/policy.py b/rl4co/models/zoo/deepaco/policy.py index 9f43504a..d57c2da2 100644 --- a/rl4co/models/zoo/deepaco/policy.py +++ b/rl4co/models/zoo/deepaco/policy.py @@ -1,9 +1,9 @@ from functools import partial -from typing import Optional, Type, Union -from tensordict import TensorDict import torch +from tensordict import TensorDict + from rl4co.envs import RL4COEnvBase, get_env from rl4co.models.common.constructive.nonautoregressive import ( NonAutoregressiveEncoder, @@ -11,7 +11,10 @@ ) from rl4co.models.zoo.deepaco.antsystem import AntSystem from rl4co.models.zoo.nargnn.encoder import NARGNNEncoder -from rl4co.utils.decoding import modify_logits_for_top_k_filtering, modify_logits_for_top_p_filtering +from rl4co.utils.decoding import ( + modify_logits_for_top_k_filtering, + modify_logits_for_top_p_filtering, +) from rl4co.utils.ops import batchify, unbatchify from rl4co.utils.utils import merge_with_defaults @@ -35,19 +38,19 @@ class DeepACOPolicy(NonAutoregressivePolicy): def __init__( self, - encoder: Optional[NonAutoregressiveEncoder] = None, + encoder: NonAutoregressiveEncoder | None = None, env_name: str = "tsp", temperature: float = 1.0, top_p: float = 0.0, top_k: int = 0, - aco_class: Optional[Type[AntSystem]] = None, + aco_class: type[AntSystem] | None = None, aco_kwargs: dict = {}, train_with_local_search: bool = False, - n_ants: Optional[Union[int, dict]] = None, - n_iterations: Optional[Union[int, dict]] = None, - start_node: Optional[int] = None, + n_ants: int | dict | None = None, + n_iterations: int | dict | None = None, + start_node: int | None = None, multistart: bool = False, - k_sparse: Optional[int] = None, + k_sparse: int | None = None, **encoder_kwargs, ): if encoder is None: @@ -68,16 +71,16 @@ def __init__( self.default_decoding_kwargs = {} self.default_decoding_kwargs["select_best"] = False if k_sparse is not None: - self.default_decoding_kwargs["top_k"] = k_sparse + (0 if env_name == "tsp" else 1) # 1 for depot + self.default_decoding_kwargs["top_k"] = k_sparse + ( + 0 if env_name == "tsp" else 1 + ) # 1 for depot if "multistart" in self.decode_type: select_start_nodes_fn = partial(self.select_start_node_fn, start_node=start_node) self.default_decoding_kwargs.update( {"multistart": True, "select_start_nodes_fn": select_start_nodes_fn} ) else: - self.default_decoding_kwargs.update( - {"multisample": True} - ) + self.default_decoding_kwargs.update({"multisample": True}) # For now, top_p and top_k are only used to filter logits (not passed to decoder) self.top_p = top_p @@ -93,7 +96,7 @@ def __init__( @staticmethod def select_start_node_fn( - td: TensorDict, env: RL4COEnvBase, num_starts: int, start_node: Optional[int] = None + td: TensorDict, env: RL4COEnvBase, num_starts: int, start_node: int | None = None ): if env.name == "tsp" and start_node is not None: # For now, only TSP supports explicitly setting the start node @@ -105,7 +108,7 @@ def select_start_node_fn( def forward( self, td_initial: TensorDict, - env: Optional[Union[str, RL4COEnvBase]] = None, + env: str | RL4COEnvBase | None = None, phase: str = "train", return_actions: bool = True, return_hidden: bool = True, @@ -124,7 +127,9 @@ def forward( ) # Instantiate environment if needed - if (phase != "train" or self.train_with_local_search) and (env is None or isinstance(env, str)): + if (phase != "train" or self.train_with_local_search) and ( + env is None or isinstance(env, str) + ): env_name = self.env_name if env is None else env env = get_env(env_name) else: diff --git a/rl4co/models/zoo/eas/decoder.py b/rl4co/models/zoo/eas/decoder.py index fee3c6fa..f29817fe 100644 --- a/rl4co/models/zoo/eas/decoder.py +++ b/rl4co/models/zoo/eas/decoder.py @@ -25,8 +25,7 @@ def forward_pointer_attn_eas_lay(self, query, key, value, logit_key, mask): # Batch matrix multiplication to compute logits (batch_size, num_steps, graph_size) # bmm is slightly faster than einsum and matmul logits = ( - torch.bmm(glimpse, logit_key.squeeze(1).transpose(-2, -1)) - / math.sqrt(glimpse.size(-1)) + torch.bmm(glimpse, logit_key.squeeze(1).transpose(-2, -1)) / math.sqrt(glimpse.size(-1)) ).squeeze(1) return logits @@ -99,9 +98,7 @@ def forward_eas( logits, mask, temperature=self.temperature if self.temperature is not None else temperature, - tanh_clipping=( - self.tanh_clipping if self.tanh_clipping is not None else tanh_clipping - ), + tanh_clipping=(self.tanh_clipping if self.tanh_clipping is not None else tanh_clipping), mask_logits=self.mask_logits if self.mask_logits is not None else mask_logits, ) diff --git a/rl4co/models/zoo/eas/search.py b/rl4co/models/zoo/eas/search.py index f8ae4a00..a198988c 100644 --- a/rl4co/models/zoo/eas/search.py +++ b/rl4co/models/zoo/eas/search.py @@ -71,11 +71,11 @@ def __init__( ): self.save_hyperparameters(logger=False, ignore=["env", "policy", "dataset"]) - assert ( - self.hparams.use_eas_embedding or self.hparams.use_eas_layer - ), "At least one of `use_eas_embedding` or `use_eas_layer` must be True." + assert self.hparams.use_eas_embedding or self.hparams.use_eas_layer, ( + "At least one of `use_eas_embedding` or `use_eas_layer` must be True." + ) - super(EAS, self).__init__( + super().__init__( env, policy=policy, dataset=dataset, @@ -105,7 +105,7 @@ def setup(self, stage="fit"): f"- EAS Embedding: {self.hparams.use_eas_embedding} \n" f"- EAS Layer: {self.hparams.use_eas_layer} \n" ) - super(EAS, self).setup(stage) + super().setup(stage) # Instantiate augmentation self.augmentation = StateAugmentation( @@ -166,18 +166,14 @@ def training_step(self, batch, batch_idx): # EASLay: replace forward of logit attention computation. EASLayer eas_layer = EASLayerNet(num_instances, decoder.embed_dim).to(batch.device) decoder.pointer.eas_layer = partial(eas_layer, decoder.pointer) - decoder.pointer.forward = partial( - forward_pointer_attn_eas_lay, decoder.pointer - ) + decoder.pointer.forward = partial(forward_pointer_attn_eas_lay, decoder.pointer) for param in eas_layer.parameters(): opt_params.append(param) if self.hparams.use_eas_embedding: # EASEmb: set gradient of emb_key to True # for all the keys, wrap the embedding in a nn.Parameter for key in self.hparams.eas_emb_cache_keys: - setattr( - cached_embeds, key, torch.nn.Parameter(getattr(cached_embeds, key)) - ) + setattr(cached_embeds, key, torch.nn.Parameter(getattr(cached_embeds, key))) opt_params.append(getattr(cached_embeds, key)) decoder.forward_eas = partial(forward_eas, decoder) @@ -226,9 +222,7 @@ def set_attr_if_exists(attr): elif self.hparams.baseline == "symmetric": bl_val = group_reward.mean(dim=-2, keepdim=True) elif self.hparams.baseline == "full": - bl_val = group_reward.mean(dim=-1, keepdim=True).mean( - dim=-2, keepdim=True - ) + bl_val = group_reward.mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True) else: raise ValueError(f"Baseline {self.hparams.baseline} not supported.") @@ -269,8 +263,7 @@ def set_attr_if_exists(attr): ) log.info( - f"{iter_count}/{self.hparams.max_iters} | " - f" Reward: {max_reward.mean().item():.2f} " + f"{iter_count}/{self.hparams.max_iters} | Reward: {max_reward.mean().item():.2f} " ) # Stop if max runtime is exceeded @@ -280,9 +273,7 @@ def set_attr_if_exists(attr): return {"max_reward": max_reward, "best_solutions": best_solutions} - def on_train_batch_end( - self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int - ) -> None: + def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None: """We store the best solution and reward found.""" max_rewards, best_solutions = outputs["max_reward"], outputs["best_solutions"] self.instance_solutions.append(best_solutions) @@ -316,15 +307,13 @@ def __init__( *args, **kwargs, ): - if not kwargs.get("use_eas_embedding", False) or kwargs.get( - "use_eas_layer", True - ): + if not kwargs.get("use_eas_embedding", False) or kwargs.get("use_eas_layer", True): log.warning( "Setting `use_eas_embedding` to True and `use_eas_layer` to False. Use EAS base class to override." ) kwargs["use_eas_embedding"] = True kwargs["use_eas_layer"] = False - super(EASEmb, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) class EASLay(EAS): @@ -335,12 +324,10 @@ def __init__( *args, **kwargs, ): - if kwargs.get("use_eas_embedding", False) or not kwargs.get( - "use_eas_layer", True - ): + if kwargs.get("use_eas_embedding", False) or not kwargs.get("use_eas_layer", True): log.warning( "Setting `use_eas_embedding` to True and `use_eas_layer` to False. Use EAS base class to override." ) kwargs["use_eas_embedding"] = False kwargs["use_eas_layer"] = True - super(EASLay, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) diff --git a/rl4co/models/zoo/gfacs/encoder.py b/rl4co/models/zoo/gfacs/encoder.py index dbeaa81a..c9ed4914 100644 --- a/rl4co/models/zoo/gfacs/encoder.py +++ b/rl4co/models/zoo/gfacs/encoder.py @@ -1,7 +1,7 @@ -from typing import Optional -from tensordict import TensorDict import torch.nn as nn +from tensordict import TensorDict + from rl4co.models.zoo.nargnn.encoder import NARGNNEncoder @@ -10,21 +10,22 @@ class GFACSEncoder(NARGNNEncoder): NARGNNEncoder with log-partition function estimation for training with Trajectory Balance (TB) loss (Malkin et al., https://arxiv.org/abs/2201.13259) """ + def __init__( self, embed_dim: int = 64, env_name: str = "tsp", # TODO: pass network - init_embedding: Optional[nn.Module] = None, - edge_embedding: Optional[nn.Module] = None, - graph_network: Optional[nn.Module] = None, - heatmap_generator: Optional[nn.Module] = None, + init_embedding: nn.Module | None = None, + edge_embedding: nn.Module | None = None, + graph_network: nn.Module | None = None, + heatmap_generator: nn.Module | None = None, num_layers_heatmap_generator: int = 5, num_layers_graph_encoder: int = 15, act_fn="silu", agg_fn="mean", linear_bias: bool = True, - k_sparse: Optional[int] = None, + k_sparse: int | None = None, z_out_dim: int = 1, ): super().__init__( @@ -56,9 +57,7 @@ def forward(self, td: TensorDict): # Process embedding into graph # TODO: standardize? - graph.x, graph.edge_attr = self.graph_network( - graph.x, graph.edge_index, graph.edge_attr - ) + graph.x, graph.edge_attr = self.graph_network(graph.x, graph.edge_index, graph.edge_attr) logZ = self.Z_net(graph.edge_attr).reshape(-1, len(td), self.z_out_dim).mean(0) diff --git a/rl4co/models/zoo/gfacs/model.py b/rl4co/models/zoo/gfacs/model.py index c21c6841..fe0f486e 100644 --- a/rl4co/models/zoo/gfacs/model.py +++ b/rl4co/models/zoo/gfacs/model.py @@ -1,7 +1,5 @@ import math -from typing import Optional, Union - import numpy as np import scipy import torch @@ -37,8 +35,8 @@ class GFACS(DeepACO): def __init__( self, env: RL4COEnvBase, - policy: Optional[GFACSPolicy] = None, - baseline: Union[REINFORCEBaseline, str] = "no", + policy: GFACSPolicy | None = None, + baseline: REINFORCEBaseline | str = "no", train_with_local_search: bool = True, policy_kwargs: dict = {}, baseline_kwargs: dict = {}, @@ -77,8 +75,7 @@ def __init__( @property def alpha(self) -> float: return self.alpha_min + (self.alpha_max - self.alpha_min) * min( - self.current_epoch - / (self.trainer.max_epochs - self.alpha_flat_epochs), # type: ignore + self.current_epoch / (self.trainer.max_epochs - self.alpha_flat_epochs), # type: ignore 1.0, ) @@ -95,8 +92,8 @@ def calculate_loss( td: TensorDict, batch: TensorDict, policy_out: dict, - reward: Optional[torch.Tensor] = None, - log_likelihood: Optional[torch.Tensor] = None, + reward: torch.Tensor | None = None, + log_likelihood: torch.Tensor | None = None, ): """Calculate loss for REINFORCE algorithm. @@ -128,9 +125,9 @@ def calculate_loss( # Off-policy loss if self.train_with_local_search: - ls_forward_flow = policy_out["ls_log_likelihood"] + policy_out[ - "ls_logZ" - ].repeat(1, n_ants) + ls_forward_flow = policy_out["ls_log_likelihood"] + policy_out["ls_logZ"].repeat( + 1, n_ants + ) ls_backward_flow = ( self.calculate_log_pb_uniform(policy_out["ls_actions"], n_ants) + ls_advantage.detach() * self.beta @@ -152,9 +149,7 @@ def calculate_log_pb_uniform(self, actions: torch.Tensor, n_ants: int): n_routes = np.count_nonzero(_a2, axis=1) - n_nodes _a3 = _a1[:, 2:] - _a1[:, :-2] n_multinode_routes = np.count_nonzero(_a3, axis=1) - n_nodes - log_b_p = -scipy.special.gammaln( - n_routes + 1 - ) - n_multinode_routes * math.log(2) + log_b_p = -scipy.special.gammaln(n_routes + 1) - n_multinode_routes * math.log(2) return unbatchify(torch.from_numpy(log_b_p).to(actions.device), n_ants) case "op" | "pctsp": return math.log(1 / 2) diff --git a/rl4co/models/zoo/gfacs/policy.py b/rl4co/models/zoo/gfacs/policy.py index 19d49640..bc1a31f9 100644 --- a/rl4co/models/zoo/gfacs/policy.py +++ b/rl4co/models/zoo/gfacs/policy.py @@ -1,7 +1,6 @@ -from typing import Optional, Type, Union +import torch from tensordict import TensorDict -import torch from rl4co.envs import RL4COEnvBase, get_env from rl4co.models.zoo.deepaco import DeepACOPolicy @@ -12,12 +11,11 @@ get_decoding_strategy, get_log_likelihood, modify_logits_for_top_k_filtering, - modify_logits_for_top_p_filtering + modify_logits_for_top_p_filtering, ) from rl4co.utils.ops import batchify, unbatchify from rl4co.utils.pylogger import get_pylogger - log = get_pylogger(__name__) @@ -40,18 +38,18 @@ class GFACSPolicy(DeepACOPolicy): def __init__( self, - encoder: Optional[GFACSEncoder] = None, + encoder: GFACSEncoder | None = None, env_name: str = "tsp", temperature: float = 1.0, top_p: float = 0.0, top_k: int = 0, - aco_class: Optional[Type[AntSystem]] = None, + aco_class: type[AntSystem] | None = None, aco_kwargs: dict = {}, train_with_local_search: bool = True, - n_ants: Optional[Union[int, dict]] = None, - n_iterations: Optional[Union[int, dict]] = None, + n_ants: int | dict | None = None, + n_iterations: int | dict | None = None, multistart: bool = False, - k_sparse: Optional[int] = None, + k_sparse: int | None = None, **encoder_kwargs, ): if encoder is None: @@ -77,7 +75,7 @@ def __init__( def forward( self, td_initial: TensorDict, - env: Optional[Union[str, RL4COEnvBase]] = None, + env: str | RL4COEnvBase | None = None, phase: str = "train", return_actions: bool = True, return_hidden: bool = False, @@ -98,7 +96,9 @@ def forward( ) # Instantiate environment if needed - if (phase != "train" or self.train_with_local_search) and (env is None or isinstance(env, str)): + if (phase != "train" or self.train_with_local_search) and ( + env is None or isinstance(env, str) + ): env_name = self.env_name if env is None else env env = get_env(env_name) else: @@ -121,7 +121,7 @@ def forward( "reward": unbatchify(env.get_reward(td, actions), n_ants), "log_likelihood": unbatchify( get_log_likelihood(logprobs, actions, td.get("mask", None), True), n_ants - ) + ), } if return_actions: @@ -147,7 +147,7 @@ def forward( "ls_log_likelihood": unbatchify( get_log_likelihood(ls_logprobs, ls_actions, td.get("mask", None), True), n_ants, - ) + ), } ) if return_actions: @@ -185,7 +185,7 @@ def common_decoding( td: TensorDict, env: RL4COEnvBase, hidden: TensorDict, - actions: Optional[torch.Tensor] = None, + actions: torch.Tensor | None = None, max_steps: int = 1_000_000, **decoding_kwargs, ): @@ -199,11 +199,15 @@ def common_decoding( **decoding_kwargs, ) if actions is not None: - assert decoding_strategy.name == "evaluate", "decoding strategy must be 'evaluate' when actions are provided" + assert decoding_strategy.name == "evaluate", ( + "decoding strategy must be 'evaluate' when actions are provided" + ) # Pre-decoding hook: used for the initial step(s) of the decoding strategy td, env, num_starts = decoding_strategy.pre_decoder_hook( - td, env, actions[:, 0] if actions is not None and "multistart" in self.decode_type else None + td, + env, + actions[:, 0] if actions is not None and "multistart" in self.decode_type else None, ) # Additionally call a decoder hook if needed before main decoding @@ -222,11 +226,9 @@ def common_decoding( td = env.step(td)["next"] step += 1 if step > max_steps: - log.error( - f"Exceeded maximum number of steps ({max_steps}) duing decoding" - ) + log.error(f"Exceeded maximum number of steps ({max_steps}) duing decoding") break # Post-decoding hook: used for the final step(s) of the decoding strategy logprobs, actions, td, env = decoding_strategy.post_decoder_hook(td, env) - return logprobs, actions, td, env \ No newline at end of file + return logprobs, actions, td, env diff --git a/rl4co/models/zoo/glop/__init__.py b/rl4co/models/zoo/glop/__init__.py index 246acff2..c40aef59 100644 --- a/rl4co/models/zoo/glop/__init__.py +++ b/rl4co/models/zoo/glop/__init__.py @@ -1,2 +1,2 @@ from rl4co.models.zoo.glop.model import GLOP -from rl4co.models.zoo.glop.policy import GLOPPolicy \ No newline at end of file +from rl4co.models.zoo.glop.policy import GLOPPolicy diff --git a/rl4co/models/zoo/glop/adapter/tsp_adapter.py b/rl4co/models/zoo/glop/adapter/tsp_adapter.py index 43b1106d..4c024325 100644 --- a/rl4co/models/zoo/glop/adapter/tsp_adapter.py +++ b/rl4co/models/zoo/glop/adapter/tsp_adapter.py @@ -1,4 +1,5 @@ -from typing import Any, Generator, NamedTuple, Optional, Union +from collections.abc import Generator +from typing import Any, NamedTuple import torch @@ -22,8 +23,8 @@ def __init__( self, parent_td: TensorDict, actions: torch.Tensor, - subprob_batch_size: Optional[int] = None, - partition_node_count: Union[int, list[int]] = 20, + subprob_batch_size: int | None = None, + partition_node_count: int | list[int] = 20, shift: int = 0, ) -> None: batch_size = parent_td.batch_size[0] @@ -40,9 +41,7 @@ def __init__( self.subprob_batch_size = subprob_batch_size self.shift = shift - def _get_batched_subprobs_one_iter( - self, node_count: int - ) -> Generator[SHPPMapping, Any, None]: + def _get_batched_subprobs_one_iter(self, node_count: int) -> Generator[SHPPMapping, Any, None]: self.shpp_actions, shpp_coordinates, self.share_memory = self.action_partitioner( self._actions, self.coordinates, node_count ) @@ -52,9 +51,7 @@ def _get_batched_subprobs_one_iter( return batch_size = self.subprob_batch_size is None or shpp_count for start_index in range(0, shpp_count, batch_size): - map_action_index = torch.arange( - start_index, min(start_index + batch_size, shpp_count) - ) + map_action_index = torch.arange(start_index, min(start_index + batch_size, shpp_count)) map_node_index = self.shpp_actions[map_action_index] # shpp_index this_shpp_coordinates = shpp_coordinates[map_action_index] yield SHPPMapping(map_action_index, map_node_index, this_shpp_coordinates) @@ -84,18 +81,14 @@ def get_actions(self): return self._actions @staticmethod - def action_partitioner( - actions: torch.Tensor, coordinates: torch.Tensor, shpp_nodes: int - ): + def action_partitioner(actions: torch.Tensor, coordinates: torch.Tensor, shpp_nodes: int): batch_size, tsp_nodes, _ = coordinates.shape share_memory = tsp_nodes % shpp_nodes == 0 if share_memory: shpp_actions = actions.view(-1, shpp_nodes) else: tsp_nodes -= tsp_nodes % shpp_nodes - shpp_actions = actions[:, :tsp_nodes].reshape( - -1, shpp_nodes - ) # trim tail nodes + shpp_actions = actions[:, :tsp_nodes].reshape(-1, shpp_nodes) # trim tail nodes repeated_coordinates = ( coordinates.unsqueeze(1) diff --git a/rl4co/models/zoo/glop/adapter/vrp_adapter.py b/rl4co/models/zoo/glop/adapter/vrp_adapter.py index 173b333a..1115d9c7 100644 --- a/rl4co/models/zoo/glop/adapter/vrp_adapter.py +++ b/rl4co/models/zoo/glop/adapter/vrp_adapter.py @@ -1,4 +1,5 @@ -from typing import Any, Generator, NamedTuple, Optional +from collections.abc import Generator +from typing import Any, NamedTuple import numba as nb import numpy as np @@ -24,7 +25,7 @@ def __init__( self, parent_td: TensorDict, actions: torch.Tensor, - subprob_batch_size: Optional[int] = None, + subprob_batch_size: int | None = None, min_node_count: int = 4, ) -> None: batch_size = parent_td.batch_size[0] @@ -118,9 +119,7 @@ def _compose_subtsp_coordinates( n_samples = actions.shape[0] // batch_size max_subtsp_length = (map_action_index[:, 2] - map_action_index[:, 1]).max() subtsp_index = np.zeros((n_subtsp, max_subtsp_length + 1), dtype=np.int32) - subtsp_coordinates = np.zeros( - (n_subtsp, max_subtsp_length + 1, 2), dtype=coordinates.dtype - ) + subtsp_coordinates = np.zeros((n_subtsp, max_subtsp_length + 1, 2), dtype=coordinates.dtype) for idx in nb.prange(n_subtsp): route_idx, start, end = map_action_index[idx] inst_idx = route_idx // n_samples @@ -137,9 +136,7 @@ def _update_cvrp_actions( map_node_index: np.ndarray, ): subtsp_length = subtsp_actions.shape[1] - subtsp_underlying_actions = np.take_along_axis( - map_node_index, subtsp_actions, axis=-1 - ) + subtsp_underlying_actions = np.take_along_axis(map_node_index, subtsp_actions, axis=-1) for idx in nb.prange(subtsp_actions.shape[0]): route_idx, start, end = map_action_index[idx] real_nodes = subtsp_underlying_actions[idx] diff --git a/rl4co/models/zoo/glop/model.py b/rl4co/models/zoo/glop/model.py index 9fa4ca6c..50c42c28 100644 --- a/rl4co/models/zoo/glop/model.py +++ b/rl4co/models/zoo/glop/model.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any from rl4co.envs.common.base import RL4COEnvBase from rl4co.models.rl import REINFORCE @@ -22,8 +22,8 @@ class GLOP(REINFORCE): def __init__( self, env: RL4COEnvBase, - policy: Optional[GLOPPolicy] = None, - baseline: Union[REINFORCEBaseline, str] = "mean", + policy: GLOPPolicy | None = None, + baseline: REINFORCEBaseline | str = "mean", policy_kwargs={}, baseline_kwargs={}, **kwargs, @@ -34,7 +34,7 @@ def __init__( super().__init__(env, policy, baseline, baseline_kwargs, **kwargs) def shared_step( - self, batch: Any, batch_idx: int, phase: str, dataloader_idx: Optional[int] = None + self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int | None = None ): td = self.env.reset(batch) n_samples = self.policy.n_samples diff --git a/rl4co/models/zoo/glop/policy.py b/rl4co/models/zoo/glop/policy.py index e321b0e0..d1684f39 100644 --- a/rl4co/models/zoo/glop/policy.py +++ b/rl4co/models/zoo/glop/policy.py @@ -1,4 +1,5 @@ -from typing import Callable, Literal, Optional, Union +from collections.abc import Callable +from typing import Literal import numpy as np import torch @@ -27,12 +28,12 @@ log = get_pylogger(__name__) -SubProblemSolverType = Union[ - Literal["insertion"], - RL4COLitModule, - tuple[RL4COLitModule, dict], - Callable[[torch.Tensor], torch.Tensor], -] +SubProblemSolverType = ( + Literal["insertion"] + | RL4COLitModule + | tuple[RL4COLitModule, dict] + | Callable[[torch.Tensor], torch.Tensor] +) class GLOPPolicy(NonAutoregressivePolicy): @@ -53,8 +54,8 @@ class GLOPPolicy(NonAutoregressivePolicy): def __init__( self, - encoder: Optional[NonAutoregressiveEncoder] = None, - decoder: Optional[NonAutoregressiveDecoder] = None, + encoder: NonAutoregressiveEncoder | None = None, + decoder: NonAutoregressiveDecoder | None = None, env_name: str = "cvrp", n_samples: int = 10, temperature: float = 1.0, @@ -63,15 +64,14 @@ def __init__( subprob_solver: SubProblemSolverType = "insertion", **encoder_kwargs, ): - if subprob_adapter_class is None: - assert ( - env_name in adapter_map - ), f"{env_name} is not supported by {self.__class__.__name__} yet" + assert env_name in adapter_map, ( + f"{env_name} is not supported by {self.__class__.__name__} yet" + ) subprob_adapter_class = adapter_map.get(env_name) - assert ( - subprob_adapter_class is not None - ), "Can not import adapter module. Please check if `numba` is installed." + assert subprob_adapter_class is not None, ( + "Can not import adapter module. Please check if `numba` is installed." + ) if encoder is None: encoder_kwargs.setdefault("embed_dim", 64) @@ -107,14 +107,14 @@ def __init__( def forward( self, td: TensorDict, - env: Optional[Union[RL4COEnvBase, str]] = None, + env: RL4COEnvBase | str | None = None, phase: Literal["train", "val", "test"] = "test", calc_reward: bool = True, return_actions: bool = False, return_entropy: bool = False, return_init_embeds: bool = False, return_sum_log_likelihood: bool = False, - subprob_solver: Optional[SubProblemSolverType] = None, + subprob_solver: SubProblemSolverType | None = None, **decoding_kwargs, ) -> dict: """Forward pass of GLOP. @@ -140,9 +140,7 @@ def forward( or isinstance(env, RL4COEnvBase) and env.name.startswith("cvrp") ): - decoding_kwargs.setdefault( - "select_start_nodes_fn", select_start_nodes_by_distance - ) + decoding_kwargs.setdefault("select_start_nodes_fn", select_start_nodes_by_distance) par_out = super().forward( td=td, @@ -177,9 +175,7 @@ def forward( if calc_reward: if isinstance(env, str) or env is None: env_name = self.env_name if env is None else env - log.info( - f"Instantiated environment not provided; instantiating {env_name}" - ) + log.info(f"Instantiated environment not provided; instantiating {env_name}") env = get_env(env_name) td_repeated = batchify(td, self.n_samples) reward = env.get_reward(td_repeated, actions) @@ -210,13 +206,11 @@ def local_policy( adapter.update_actions(mapping, subprob_actions) actions_revised = adapter.get_actions().to(td.device) - actions_revised = rearrange( - actions_revised, "(b n) ... -> (n b) ...", n=self.n_samples - ) + actions_revised = rearrange(actions_revised, "(b n) ... -> (n b) ...", n=self.n_samples) return dict(actions=actions_revised) def _get_subprob_solver( - self, solver: Optional[SubProblemSolverType] + self, solver: SubProblemSolverType | None ) -> Callable[[torch.Tensor], torch.Tensor]: solver = self.subprob_solver if solver is None else solver env_name = self.subprob_adapter_class.subproblem_env_name diff --git a/rl4co/models/zoo/ham/attention.py b/rl4co/models/zoo/ham/attention.py index 0c4d593e..641790a8 100644 --- a/rl4co/models/zoo/ham/attention.py +++ b/rl4co/models/zoo/ham/attention.py @@ -10,7 +10,7 @@ def __init__(self, num_heads, input_dim, embed_dim=None, val_dim=None, key_dim=N Heterogenous Multi-Head Attention for Pickup and Delivery problems https://arxiv.org/abs/2110.02634 """ - super(HeterogenousMHA, self).__init__() + super().__init__() if val_dim is None: assert embed_dim is not None, "Provide either embed_dim or val_dim" @@ -56,7 +56,7 @@ def forward(self, q, h=None, mask=None): q: queries (batch_size, n_query, input_dim) h: data (batch_size, graph_size, input_dim) mask: mask (batch_size, n_query, graph_size) or viewable as that (i.e. can be 2 dim if n_query == 1) - + Mask should contain 1 if attention is not possible (i.e. mask is negative adjacency) """ if h is None: @@ -66,10 +66,10 @@ def forward(self, q, h=None, mask=None): batch_size, graph_size, input_dim = h.size() # Check if graph size is odd number - assert ( - graph_size % 2 == 1 - ), "Graph size should have odd number of nodes due to pickup-delivery problem \ + assert graph_size % 2 == 1, ( + "Graph size should have odd number of nodes due to pickup-delivery problem \ (n/2 pickup, n/2 delivery, 1 depot)" + ) n_query = q.size(1) assert q.size(0) == batch_size @@ -479,9 +479,7 @@ def forward(self, q, h=None, mask=None): ) out = torch.mm( - heads.permute(1, 2, 0, 3) - .contiguous() - .view(-1, self.num_heads * self.val_dim), + heads.permute(1, 2, 0, 3).contiguous().view(-1, self.num_heads * self.val_dim), self.W_out.view(-1, self.embed_dim), ).view(batch_size, n_query, self.embed_dim) diff --git a/rl4co/models/zoo/ham/encoder.py b/rl4co/models/zoo/ham/encoder.py index 8a116336..da579858 100644 --- a/rl4co/models/zoo/ham/encoder.py +++ b/rl4co/models/zoo/ham/encoder.py @@ -13,7 +13,7 @@ def __init__( feedforward_hidden=512, normalization="batch", ): - super(HeterogeneuousMHALayer, self).__init__( + super().__init__( SkipConnection(HeterogenousMHA(num_heads, embed_dim, embed_dim)), Normalization(embed_dim, normalization), SkipConnection( @@ -41,7 +41,7 @@ def __init__( feedforward_hidden=512, sdpa_fn=None, ): - super(GraphHeterogeneousAttentionEncoder, self).__init__() + super().__init__() # substitute env_name with pdp if none if env_name is None: diff --git a/rl4co/models/zoo/ham/model.py b/rl4co/models/zoo/ham/model.py index 9d022567..2b17191a 100644 --- a/rl4co/models/zoo/ham/model.py +++ b/rl4co/models/zoo/ham/model.py @@ -26,9 +26,9 @@ def __init__( baseline_kwargs={}, **kwargs, ): - assert ( - env.name == "pdp" - ), "HeterogeneousAttentionModel only works for PDP (Pickup and Delivery Problem)" + assert env.name == "pdp", ( + "HeterogeneousAttentionModel only works for PDP (Pickup and Delivery Problem)" + ) if policy is None: policy = HeterogeneousAttentionModelPolicy(env_name=env.name, **policy_kwargs) diff --git a/rl4co/models/zoo/ham/policy.py b/rl4co/models/zoo/ham/policy.py index 3dc8ddbc..7032f2db 100644 --- a/rl4co/models/zoo/ham/policy.py +++ b/rl4co/models/zoo/ham/policy.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional +from collections.abc import Callable import torch.nn as nn @@ -34,7 +34,7 @@ def __init__( num_heads: int = 8, normalization: str = "batch", feedforward_hidden: int = 512, - sdpa_fn: Optional[Callable] = None, + sdpa_fn: Callable | None = None, **kwargs, ): if encoder is None: @@ -51,7 +51,7 @@ def __init__( else: encoder = encoder - super(HeterogeneousAttentionModelPolicy, self).__init__( + super().__init__( env_name=env_name, encoder=encoder, embed_dim=embed_dim, diff --git a/rl4co/models/zoo/l2d/decoder.py b/rl4co/models/zoo/l2d/decoder.py index 833e9c6e..672fb7a2 100644 --- a/rl4co/models/zoo/l2d/decoder.py +++ b/rl4co/models/zoo/l2d/decoder.py @@ -1,6 +1,6 @@ import abc -from typing import Any, Tuple +from typing import Any import torch import torch.nn as nn @@ -31,7 +31,7 @@ class L2DActor(nn.Module, metaclass=abc.ABCMeta): @abc.abstractmethod def forward( self, td: TensorDict, hidden: Any = None, num_starts: int = 0 - ) -> Tuple[Tensor, Tensor]: + ) -> tuple[Tensor, Tensor]: """Obtain logits for current action to the next ones Args: @@ -46,7 +46,7 @@ def forward( def pre_decoder_hook( self, td: TensorDict, env=None, hidden: Any = None, num_starts: int = 0 - ) -> Tuple[TensorDict, Any]: + ) -> tuple[TensorDict, Any]: """By default, we only require the input for the actor to be a tuple (in JSSP we only have operation embeddings but in FJSP we have operation and machine embeddings. By expecting a tuple we can generalize things.) @@ -183,7 +183,7 @@ def __init__( stepwise: bool = False, scaling_factor: int = 1000, ): - super(L2DDecoder, self).__init__() + super().__init__() if feature_extractor is None and stepwise: if env_name == "fjsp" or (het_emb and env_name == "jssp"): @@ -307,7 +307,7 @@ def __init__( def pre_decoder_hook( self, td: TensorDict, env=None, hidden: Any = None, num_starts: int = 0 - ) -> Tuple[TensorDict, Any]: + ) -> tuple[TensorDict, Any]: cache = self._precompute_cache(hidden, num_starts=num_starts) return td, env, (cache,) @@ -327,9 +327,7 @@ def __init__( dynamic_embedding = None else: # otherwise we might want to update the static embeddings using dynamic updates - dynamic_embedding = JSSPDynamicEmbedding( - embed_dim, scaling_factor=scaling_factor - ) + dynamic_embedding = JSSPDynamicEmbedding(embed_dim, scaling_factor=scaling_factor) pointer = L2DAttnPointer(env_name, embed_dim, num_heads, check_nan=False) super().__init__( @@ -364,16 +362,14 @@ def _compute_kvl(self, cached: PrecomputedCache, td: TensorDict): return glimpse_k, glimpse_v, logit_k - def _precompute_cache(self, embeddings: Tuple[torch.Tensor, torch.Tensor], **kwargs): + def _precompute_cache(self, embeddings: tuple[torch.Tensor, torch.Tensor], **kwargs): ops_emb, ma_emb = embeddings ( glimpse_key_fixed, glimpse_val_fixed, logit_key, - ) = self.project_node_embeddings( - ops_emb - ).chunk(3, dim=-1) + ) = self.project_node_embeddings(ops_emb).chunk(3, dim=-1) embeddings = TensorDict( {"op_embeddings": ops_emb, "machine_embeddings": ma_emb}, diff --git a/rl4co/models/zoo/l2d/policy.py b/rl4co/models/zoo/l2d/policy.py index 8446bbf4..576c2405 100644 --- a/rl4co/models/zoo/l2d/policy.py +++ b/rl4co/models/zoo/l2d/policy.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn as nn @@ -28,15 +26,15 @@ class L2DPolicy(AutoregressivePolicy): def __init__( self, - encoder: Optional[AutoregressiveEncoder] = None, - decoder: Optional[AutoregressiveDecoder] = None, + encoder: AutoregressiveEncoder | None = None, + decoder: AutoregressiveDecoder | None = None, embed_dim: int = 64, num_encoder_layers: int = 2, env_name: str = "fjsp", het_emb: bool = True, scaling_factor: int = 1000, normalization: str = "batch", - init_embedding: Optional[nn.Module] = None, + init_embedding: nn.Module | None = None, stepwise_encoding: bool = False, tanh_clipping: float = 10, train_decode_type: str = "sampling", @@ -82,7 +80,7 @@ def __init__( ) # Pass to constructive policy - super(L2DPolicy, self).__init__( + super().__init__( encoder=encoder, decoder=decoder, env_name=env_name, @@ -97,15 +95,15 @@ def __init__( class L2DAttnPolicy(AutoregressivePolicy): def __init__( self, - encoder: Optional[AutoregressiveEncoder] = None, - decoder: Optional[AutoregressiveDecoder] = None, + encoder: AutoregressiveEncoder | None = None, + decoder: AutoregressiveDecoder | None = None, embed_dim: int = 256, num_heads: int = 8, num_encoder_layers: int = 4, scaling_factor: int = 1000, normalization: str = "batch", env_name: str = "fjsp", - init_embedding: Optional[nn.Module] = None, + init_embedding: nn.Module | None = None, tanh_clipping: float = 10, train_decode_type: str = "sampling", val_decode_type: str = "greedy", @@ -117,9 +115,7 @@ def __init__( if encoder is None: if init_embedding is None: - init_embedding = FJSPMatNetInitEmbedding( - embed_dim, scaling_factor=scaling_factor - ) + init_embedding = FJSPMatNetInitEmbedding(embed_dim, scaling_factor=scaling_factor) encoder = Encoder( embed_dim=embed_dim, @@ -141,7 +137,7 @@ def __init__( ) # Pass to constructive policy - super(L2DAttnPolicy, self).__init__( + super().__init__( encoder=encoder, decoder=decoder, env_name=env_name, @@ -199,9 +195,9 @@ def __init__( critic = MLP(input_dim, 1, num_neurons=[embed_dim] * 2) self.critic = critic - assert isinstance( - self.encoder, NoEncoder - ), "Define a feature extractor for decoder rather than an encoder in stepwise PPO" + assert isinstance(self.encoder, NoEncoder), ( + "Define a feature extractor for decoder rather than an encoder in stepwise PPO" + ) def evaluate(self, td): # Encoder: get encoder output and initial embeddings from initial state @@ -214,9 +210,7 @@ def evaluate(self, td): # pred value via the value head value_pred = self.critic(h_pooled) # pre decoder / actor hook - td, _, hidden = self.decoder.actor.pre_decoder_hook( - td, None, hidden, num_starts=0 - ) + td, _, hidden = self.decoder.actor.pre_decoder_hook(td, None, hidden, num_starts=0) logits, mask = self.decoder.actor(td, *hidden) # get logprobs and entropy over logp distribution logprobs = process_logits(logits, mask, tanh_clipping=self.tanh_clipping) diff --git a/rl4co/models/zoo/matnet/decoder.py b/rl4co/models/zoo/matnet/decoder.py index df0e2160..1a3ab91f 100644 --- a/rl4co/models/zoo/matnet/decoder.py +++ b/rl4co/models/zoo/matnet/decoder.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Tuple import torch import torch.nn as nn @@ -23,15 +22,13 @@ class PrecomputedCache: class MatNetDecoder(AttentionModelDecoder): - def _precompute_cache(self, embeddings: Tuple[Tensor, Tensor], *args, **kwargs): + def _precompute_cache(self, embeddings: tuple[Tensor, Tensor], *args, **kwargs): row_emb, col_emb = embeddings ( glimpse_key_fixed, glimpse_val_fixed, logit_key, - ) = self.project_node_embeddings( - col_emb - ).chunk(3, dim=-1) + ) = self.project_node_embeddings(col_emb).chunk(3, dim=-1) # Optionally disable the graph context from the initial embedding as done in POMO if self.use_graph_context: @@ -74,22 +71,18 @@ def __init__( self.no_job_emb = nn.Parameter(torch.rand(1, 1, embed_dim), requires_grad=True) - def _precompute_cache(self, embeddings: Tuple[Tensor, Tensor], **kwargs): + def _precompute_cache(self, embeddings: tuple[Tensor, Tensor], **kwargs): job_emb, ma_emb = embeddings bs, _, emb_dim = job_emb.shape - job_emb_plus_one = torch.cat( - (job_emb, self.no_job_emb.expand((bs, 1, emb_dim))), dim=1 - ) + job_emb_plus_one = torch.cat((job_emb, self.no_job_emb.expand((bs, 1, emb_dim))), dim=1) ( glimpse_key_fixed, glimpse_val_fixed, logit_key, - ) = self.project_node_embeddings( - job_emb_plus_one - ).chunk(3, dim=-1) + ) = self.project_node_embeddings(job_emb_plus_one).chunk(3, dim=-1) # Optionally disable the graph context from the initial embedding as done in POMO if self.use_graph_context: @@ -133,7 +126,7 @@ def __init__( self.cached_embs: PrecomputedCache = None self.tanh_clipping = tanh_clipping - def _precompute_cache(self, embeddings: Tuple[Tensor], **kwargs): + def _precompute_cache(self, embeddings: tuple[Tensor], **kwargs): self.cached_embs = super()._precompute_cache(embeddings, **kwargs) def forward( @@ -142,8 +135,7 @@ def forward( decode_type="sampling", num_starts: int = 1, **decoding_kwargs, - ) -> Tuple[Tensor, Tensor, TensorDict]: - + ) -> tuple[Tensor, Tensor, TensorDict]: logits, mask = super().forward(td, self.cached_embs, num_starts) logprobs = process_logits( logits, diff --git a/rl4co/models/zoo/matnet/encoder.py b/rl4co/models/zoo/matnet/encoder.py index 0af88e23..bed8a319 100644 --- a/rl4co/models/zoo/matnet/encoder.py +++ b/rl4co/models/zoo/matnet/encoder.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn as nn import torch.nn.functional as F @@ -107,9 +105,7 @@ def __init__( mix2_init=mix2_init, ) - super().__init__( - embed_dim=embed_dim, num_heads=num_heads, bias=bias, sdpa_fn=attn_fn - ) + super().__init__(embed_dim=embed_dim, num_heads=num_heads, bias=bias, sdpa_fn=attn_fn) class MatNetMHA(nn.Module): @@ -149,7 +145,7 @@ def __init__( num_heads: int, bias: bool = False, feedforward_hidden: int = 512, - normalization: Optional[str] = "instance", + normalization: str | None = "instance", ): super().__init__() self.MHA = MatNetMHA(embed_dim, num_heads, bias) diff --git a/rl4co/models/zoo/matnet/matnet_w_sa.py b/rl4co/models/zoo/matnet/matnet_w_sa.py index cf06056f..4b944434 100644 --- a/rl4co/models/zoo/matnet/matnet_w_sa.py +++ b/rl4co/models/zoo/matnet/matnet_w_sa.py @@ -116,9 +116,7 @@ def __init__( self.op_attn = MultiHeadAttention(embed_dim, num_heads, bias=bias) self.ma_attn = MultiHeadAttention(embed_dim, num_heads, bias=bias) - self.cross_attn = EfficientMixedScoreMultiHeadAttention( - embed_dim, num_heads, bias=bias - ) + self.cross_attn = EfficientMixedScoreMultiHeadAttention(embed_dim, num_heads, bias=bias) self.op_ffn = TransformerFFN(embed_dim, feedforward_hidden, normalization) self.ma_ffn = TransformerFFN(embed_dim, feedforward_hidden, normalization) @@ -126,9 +124,7 @@ def __init__( self.op_norm = Normalization(embed_dim, normalization) self.ma_norm = Normalization(embed_dim, normalization) - def forward( - self, op_in, ma_in, cost_mat, op_mask=None, ma_mask=None, cross_mask=None - ): + def forward(self, op_in, ma_in, cost_mat, op_mask=None, ma_mask=None, cross_mask=None): op_cross_out, ma_cross_out = self.cross_attn( op_in, ma_in, attn_mask=cross_mask, cost_mat=cost_mat ) diff --git a/rl4co/models/zoo/matnet/model.py b/rl4co/models/zoo/matnet/model.py index af26870a..f05ab8e4 100644 --- a/rl4co/models/zoo/matnet/model.py +++ b/rl4co/models/zoo/matnet/model.py @@ -44,7 +44,7 @@ def __init__( else: kwargs["num_augment"] = 0 - super(MatNet, self).__init__( + super().__init__( env=env, policy=policy, num_starts=num_starts, diff --git a/rl4co/models/zoo/matnet/policy.py b/rl4co/models/zoo/matnet/policy.py index 26e50a7d..bf2e9227 100644 --- a/rl4co/models/zoo/matnet/policy.py +++ b/rl4co/models/zoo/matnet/policy.py @@ -7,11 +7,7 @@ from rl4co.envs.scheduling.ffsp.env import FFSPEnv from rl4co.models.common.constructive.autoregressive import AutoregressivePolicy -from rl4co.models.zoo.matnet.decoder import ( - MatNetDecoder, - MatNetFFSPDecoder, - MultiStageFFSPDecoder, -) +from rl4co.models.zoo.matnet.decoder import MatNetDecoder, MatNetFFSPDecoder, MultiStageFFSPDecoder from rl4co.models.zoo.matnet.encoder import MatNetEncoder from rl4co.utils.ops import batchify from rl4co.utils.pylogger import get_pylogger @@ -68,7 +64,7 @@ def __init__( use_graph_context=use_graph_context, ) - super(MatNetPolicy, self).__init__( + super().__init__( env_name=env_name, encoder=MatNetEncoder( embed_dim=embed_dim, diff --git a/rl4co/models/zoo/mdam/decoder.py b/rl4co/models/zoo/mdam/decoder.py index 87ad9d7b..3bdf014d 100644 --- a/rl4co/models/zoo/mdam/decoder.py +++ b/rl4co/models/zoo/mdam/decoder.py @@ -39,7 +39,7 @@ def __init__( val_decode_type: str = "greedy", test_decode_type: str = "greedy", ): - super(MDAMDecoder, self).__init__() + super().__init__() self.dynamic_embedding = env_dynamic_embedding(env_name, {"embed_dim": embed_dim}) self.train_decode_type = train_decode_type @@ -47,15 +47,10 @@ def __init__( self.test_decode_type = test_decode_type self.W_placeholder = nn.Parameter(torch.Tensor(2 * embed_dim)) - self.W_placeholder.data.uniform_( - -1, 1 - ) # Placeholder should be in range of activations + self.W_placeholder.data.uniform_(-1, 1) # Placeholder should be in range of activations self.context = nn.ModuleList( - [ - env_context_embedding(env_name, {"embed_dim": embed_dim}) - for _ in range(num_paths) - ] + [env_context_embedding(env_name, {"embed_dim": embed_dim}) for _ in range(num_paths)] ) self.project_node_embeddings = [ @@ -73,9 +68,7 @@ def __init__( ] self.project_step_context = nn.ModuleList(self.project_step_context) - self.project_out = [ - nn.Linear(embed_dim, embed_dim, bias=False) for _ in range(num_paths) - ] + self.project_out = [nn.Linear(embed_dim, embed_dim, bias=False) for _ in range(num_paths)] self.project_out = nn.ModuleList(self.project_out) self.dynamic_embedding = env_dynamic_embedding(env_name, {"embed_dim": embed_dim}) @@ -224,9 +217,7 @@ def _precompute(self, embeddings, num_steps=1, path_index=None): glimpse_key_fixed, glimpse_val_fixed, logit_key_fixed, - ) = self.project_node_embeddings[path_index](embeddings[:, None, :, :]).chunk( - 3, dim=-1 - ) + ) = self.project_node_embeddings[path_index](embeddings[:, None, :, :]).chunk(3, dim=-1) fixed = PrecomputedCache( node_embeddings=embeddings, @@ -249,18 +240,12 @@ def _make_heads(self, v, num_steps=None): self.num_heads, -1, ) - .permute( - 3, 0, 1, 2, 4 - ) # (n_heads, batch_size, num_steps, graph_size, head_dim) + .permute(3, 0, 1, 2, 4) # (n_heads, batch_size, num_steps, graph_size, head_dim) ) def _get_logprobs(self, fixed, td, path_index, normalize=True): - step_context = self.context[path_index]( - fixed.node_embeddings, td - ) # [batch, embed_dim] - glimpse_q = fixed.graph_context + step_context.unsqueeze(1).to( - fixed.graph_context.device - ) + step_context = self.context[path_index](fixed.node_embeddings, td) # [batch, embed_dim] + glimpse_q = fixed.graph_context + step_context.unsqueeze(1).to(fixed.graph_context.device) # Compute keys and values for the nodes ( @@ -287,9 +272,9 @@ def _one_to_many_logits(self, query, glimpse_K, glimpse_V, logit_K, mask, path_i key_size = val_size = embed_dim // self.num_heads # Compute the glimpse, rearrange dimensions so the dimensions are (n_heads, batch_size, num_steps, 1, key_size) - glimpse_Q = query.view( - batch_size, num_steps, self.num_heads, 1, key_size - ).permute(2, 0, 1, 3, 4) + glimpse_Q = query.view(batch_size, num_steps, self.num_heads, 1, key_size).permute( + 2, 0, 1, 3, 4 + ) # Batch matrix multiplication to compute compatibilities (n_heads, batch_size, num_steps, graph_size) compatibility = torch.matmul(glimpse_Q, glimpse_K.transpose(-2, -1)) / math.sqrt( @@ -297,9 +282,7 @@ def _one_to_many_logits(self, query, glimpse_K, glimpse_V, logit_K, mask, path_i ) if self.mask_inner: assert self.mask_logits, "Cannot mask inner without masking logits" - compatibility[~mask[None, :, None, None, :].expand_as(compatibility)] = ( - -math.inf - ) + compatibility[~mask[None, :, None, None, :].expand_as(compatibility)] = -math.inf # Batch matrix multiplication to compute heads (n_heads, batch_size, num_steps, val_size) heads = torch.matmul(F.softmax(compatibility, dim=-1), glimpse_V) diff --git a/rl4co/models/zoo/mdam/encoder.py b/rl4co/models/zoo/mdam/encoder.py index bab7546f..3334eb66 100644 --- a/rl4co/models/zoo/mdam/encoder.py +++ b/rl4co/models/zoo/mdam/encoder.py @@ -1,13 +1,9 @@ -from typing import Callable, Optional +from collections.abc import Callable import torch import torch.nn as nn -from rl4co.models.nn.graph.attnnet import ( - MultiHeadAttentionLayer, - Normalization, - SkipConnection, -) +from rl4co.models.nn.graph.attnnet import MultiHeadAttentionLayer, Normalization, SkipConnection from rl4co.models.zoo.mdam.mha import MultiHeadAttentionMDAM @@ -20,9 +16,9 @@ def __init__( node_dim=None, normalization="batch", feedforward_hidden=512, - sdpa_fn: Optional[Callable] = None, + sdpa_fn: Callable | None = None, ): - super(MDAMGraphAttentionEncoder, self).__init__() + super().__init__() # To map input to embedding space self.init_embed = nn.Linear(node_dim, embed_dim) if node_dim is not None else None @@ -77,14 +73,10 @@ def forward(self, x, mask=None, return_transform_loss=False): def change(self, attn, V, h_old, mask): num_heads, batch_size, graph_size, feat_size = V.size() attn = ( - mask.float() - .view(1, batch_size, 1, graph_size) - .repeat(num_heads, 1, graph_size, 1) + mask.float().view(1, batch_size, 1, graph_size).repeat(num_heads, 1, graph_size, 1) * attn ) - attn = attn / ( - torch.sum(attn, dim=-1).view(num_heads, batch_size, graph_size, 1) + 1e-9 - ) + attn = attn / (torch.sum(attn, dim=-1).view(num_heads, batch_size, graph_size, 1) + 1e-9) heads = torch.matmul(attn, V) h_new = torch.mm( diff --git a/rl4co/models/zoo/mdam/mha.py b/rl4co/models/zoo/mdam/mha.py index 4499faa0..3be65fb4 100644 --- a/rl4co/models/zoo/mdam/mha.py +++ b/rl4co/models/zoo/mdam/mha.py @@ -11,7 +11,7 @@ class MultiHeadAttentionMDAM(nn.Module): def __init__(self, embed_dim, n_heads, last_one=False, sdpa_fn=None): - super(MultiHeadAttentionMDAM, self).__init__() + super().__init__() if sdpa_fn is not None: log.warning("sdpa_fn is not used in this implementation") @@ -77,9 +77,7 @@ def forward(self, q, h=None, mask=None): heads = torch.matmul(attn, V) out = torch.mm( - heads.permute(1, 2, 0, 3) - .contiguous() - .view(-1, self.n_heads * self.embed_dim), + heads.permute(1, 2, 0, 3).contiguous().view(-1, self.n_heads * self.embed_dim), self.W_out.view(-1, self.embed_dim), ).view(batch_size, n_query, self.embed_dim) if self.last_one: diff --git a/rl4co/models/zoo/mdam/model.py b/rl4co/models/zoo/mdam/model.py index 9485a3c1..0a54126d 100644 --- a/rl4co/models/zoo/mdam/model.py +++ b/rl4co/models/zoo/mdam/model.py @@ -6,11 +6,7 @@ from rl4co.envs.common.base import RL4COEnvBase from rl4co.models.rl import REINFORCE -from rl4co.models.rl.reinforce.baselines import ( - REINFORCEBaseline, - RolloutBaseline, - WarmupBaseline, -) +from rl4co.models.rl.reinforce.baselines import REINFORCEBaseline, RolloutBaseline, WarmupBaseline from rl4co.models.zoo.mdam.policy import MDAMPolicy @@ -100,9 +96,7 @@ def calculate_loss( ) # REINFORCE baseline - bl_val, bl_loss = ( - self.baseline.eval(td, reward, self.env) if extra is None else (extra, 0) - ) + bl_val, bl_loss = self.baseline.eval(td, reward, self.env) if extra is None else (extra, 0) # Main loss function # reward: [batch, num_paths]. Note that the baseline value is the max reward diff --git a/rl4co/models/zoo/mdam/policy.py b/rl4co/models/zoo/mdam/policy.py index 5064de56..207fafd8 100644 --- a/rl4co/models/zoo/mdam/policy.py +++ b/rl4co/models/zoo/mdam/policy.py @@ -49,9 +49,7 @@ def __init__( else decoder ) - super(MDAMPolicy, self).__init__( - env_name=env_name, encoder=encoder, decoder=decoder - ) + super().__init__(env_name=env_name, encoder=encoder, decoder=decoder) self.init_embedding = env_init_embedding(env_name, {"embed_dim": embed_dim}) diff --git a/rl4co/models/zoo/mvmoe/__init__.py b/rl4co/models/zoo/mvmoe/__init__.py index a26e33c4..0f2d2ba1 100644 --- a/rl4co/models/zoo/mvmoe/__init__.py +++ b/rl4co/models/zoo/mvmoe/__init__.py @@ -1,2 +1 @@ -from .model import MVMoE_POMO -from .model import MVMoE_AM +from .model import MVMoE_AM, MVMoE_POMO diff --git a/rl4co/models/zoo/mvmoe/model.py b/rl4co/models/zoo/mvmoe/model.py index aa557751..43e3a6c9 100644 --- a/rl4co/models/zoo/mvmoe/model.py +++ b/rl4co/models/zoo/mvmoe/model.py @@ -1,4 +1,4 @@ -from typing import Callable +from collections.abc import Callable import torch.nn as nn @@ -57,7 +57,7 @@ def __init__( policy = AttentionModelPolicy(env_name=env.name, **policy_kwargs) # Initialize with the shared baseline - super(MVMoE_POMO, self).__init__( + super().__init__( env, policy, policy_kwargs, @@ -111,6 +111,4 @@ def __init__( policy = AttentionModelPolicy(env_name=env.name, **policy_kwargs) # Initialize with the shared baseline - super(MVMoE_AM, self).__init__( - env, policy, baseline, policy_kwargs, baseline_kwargs, **kwargs - ) + super().__init__(env, policy, baseline, policy_kwargs, baseline_kwargs, **kwargs) diff --git a/rl4co/models/zoo/n2s/decoder.py b/rl4co/models/zoo/n2s/decoder.py index 2c843a8b..8f5097fd 100644 --- a/rl4co/models/zoo/n2s/decoder.py +++ b/rl4co/models/zoo/n2s/decoder.py @@ -38,12 +38,8 @@ def __init__( assert embed_dim % num_heads == 0 - self.W_Q = nn.Parameter( - torch.Tensor(self.n_heads, self.input_dim, self.hidden_dim) - ) - self.W_K = nn.Parameter( - torch.Tensor(self.n_heads, self.input_dim, self.hidden_dim) - ) + self.W_Q = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.hidden_dim)) + self.W_K = nn.Parameter(torch.Tensor(self.n_heads, self.input_dim, self.hidden_dim)) self.agg = MLP(input_dim=2 * self.n_heads + 4, output_dim=1, num_neurons=[32, 32]) @@ -69,9 +65,7 @@ def forward(self, td: TensorDict, final_h: Tensor, final_p: Tensor) -> Tensor: solution = td["rec_current"] pre = solution.argsort() # pre=[1,2,0] - post = solution.gather( - 1, solution - ) # post=[1,2,0] # the second neighbour works better + post = solution.gather(1, solution) # post=[1,2,0] # the second neighbour works better batch_size, graph_size_plus1, input_dim = final_h.size() hflat = final_h.contiguous().view(-1, input_dim) ################# reshape @@ -82,20 +76,14 @@ def forward(self, td: TensorDict, final_h: Tensor, final_p: Tensor) -> Tensor: hidden_Q = torch.matmul(hflat, self.W_Q).view(shp) hidden_K = torch.matmul(hflat, self.W_K).view(shp) - Q_pre = hidden_Q.gather( - 2, pre.view(1, batch_size, graph_size_plus1, 1).expand_as(hidden_Q) - ) + Q_pre = hidden_Q.gather(2, pre.view(1, batch_size, graph_size_plus1, 1).expand_as(hidden_Q)) K_post = hidden_K.gather( 2, post.view(1, batch_size, graph_size_plus1, 1).expand_as(hidden_Q) ) compatibility = ( - (Q_pre * hidden_K).sum(-1) - + (hidden_Q * K_post).sum(-1) - - (Q_pre * K_post).sum(-1) - )[ - :, :, 1: - ] # (n_heads, batch_size, graph_size) (12) + (Q_pre * hidden_K).sum(-1) + (hidden_Q * K_post).sum(-1) - (Q_pre * K_post).sum(-1) + )[:, :, 1:] # (n_heads, batch_size, graph_size) (12) compatibility_pairing = torch.cat( ( @@ -166,31 +154,22 @@ def forward(self, td: TensorDict, final_h: Tensor, final_p: Tensor) -> torch.Ten arange = torch.arange(batch_size, device=final_h.device) h_pickup = final_h[arange, pos_pickup].unsqueeze(1) # (batch_size, 1, input_dim) - h_delivery = final_h[arange, pos_delivery].unsqueeze( - 1 - ) # (batch_size, 1, input_dim) + h_delivery = final_h[arange, pos_delivery].unsqueeze(1) # (batch_size, 1, input_dim) h_K_neibour = final_h.gather( 1, solution.view(batch_size, graph_size_plus1, 1).expand_as(final_h) ) # (batch_size, graph_size+1, input_dim) compatibility_pickup_pre = ( - self.compater_insert1( - h_pickup, final_h - ) # (n_heads, batch_size, 1, graph_size+1) + self.compater_insert1(h_pickup, final_h) # (n_heads, batch_size, 1, graph_size+1) .permute(1, 2, 3, 0) # (batch_size, 1, graph_size+1, n_heads) .view(shp_p) # (batch_size, graph_size+1, 1, n_heads) .expand(shp) # (batch_size, graph_size+1, graph_size+1, n_heads) ) compatibility_pickup_post = ( - self.compater_insert2(h_pickup, h_K_neibour) - .permute(1, 2, 3, 0) - .view(shp_p) - .expand(shp) + self.compater_insert2(h_pickup, h_K_neibour).permute(1, 2, 3, 0).view(shp_p).expand(shp) ) compatibility_delivery_pre = ( - self.compater_insert1( - h_delivery, final_h - ) # (n_heads, batch_size, 1, graph_size+1) + self.compater_insert1(h_delivery, final_h) # (n_heads, batch_size, 1, graph_size+1) .permute(1, 2, 3, 0) # (batch_size, 1, graph_size+1, n_heads) .view(shp_d) # (batch_size, 1, graph_size+1, n_heads) .expand(shp) # (batch_size, graph_size+1, graph_size+1, n_heads) @@ -238,9 +217,7 @@ def forward(self, x: torch.Tensor, best_cost: torch.Tensor) -> torch.Tensor: graph_feature: torch.Tensor = self.project_graph(mean_pooling)[ :, None, : ] # (batch_size, 1, input_dim/2) - node_feature: torch.Tensor = self.project_node( - x - ) # (batch_size, graph_size+1, input_dim/2) + node_feature: torch.Tensor = self.project_node(x) # (batch_size, graph_size+1, input_dim/2) # pass through value_head, get estimated value fusion = node_feature + graph_feature.expand_as( diff --git a/rl4co/models/zoo/n2s/encoder.py b/rl4co/models/zoo/n2s/encoder.py index c219c3c3..a51050ee 100644 --- a/rl4co/models/zoo/n2s/encoder.py +++ b/rl4co/models/zoo/n2s/encoder.py @@ -1,6 +1,6 @@ import math -from typing import Callable, Tuple +from collections.abc import Callable import torch import torch.nn as nn @@ -48,7 +48,7 @@ def init_parameters(self): def forward( self, h_fea: torch.Tensor, aux_att_score: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # h should be (batch_size, n_query, input_dim) batch_size, n_query, input_dim = h_fea.size() @@ -64,9 +64,7 @@ def forward( # Calculate compatibility (n_heads, batch_size, n_query, n_key) compatibility = torch.cat((torch.matmul(Q, K.transpose(2, 3)), aux_att_score), 0) - attn_raw = compatibility.permute( - 1, 2, 3, 0 - ) # (batch_size, n_query, n_key, n_heads) + attn_raw = compatibility.permute(1, 2, 3, 0) # (batch_size, n_query, n_key, n_heads) attn = self.score_aggr(attn_raw).permute( 3, 0, 1, 2 ) # (n_heads, batch_size, n_query, n_key) @@ -94,11 +92,11 @@ def __init__(self, n_heads: int, input_dim: int, normalization: str) -> None: self.Norm = Normalization(input_dim, normalization) - __call__: Callable[..., Tuple[torch.Tensor, torch.Tensor]] + __call__: Callable[..., tuple[torch.Tensor, torch.Tensor]] def forward( self, h_fea: torch.Tensor, aux_att_score: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: # Attention and Residual connection h_wave, aux_att_score = self.SynthAtt(h_fea, aux_att_score) @@ -107,9 +105,7 @@ def forward( class FFNormSubLayer(nn.Module): - def __init__( - self, input_dim: int, feed_forward_hidden: int, normalization: str - ) -> None: + def __init__(self, input_dim: int, feed_forward_hidden: int, normalization: str) -> None: super().__init__() self.FF = ( @@ -139,19 +135,15 @@ def __init__( ) -> None: super().__init__() - self.SynthAttNorm_sublayer = SynthAttNormSubLayer( - n_heads, input_dim, normalization - ) + self.SynthAttNorm_sublayer = SynthAttNormSubLayer(n_heads, input_dim, normalization) - self.FFNorm_sublayer = FFNormSubLayer( - input_dim, feed_forward_hidden, normalization - ) + self.FFNorm_sublayer = FFNormSubLayer(input_dim, feed_forward_hidden, normalization) - __call__: Callable[..., Tuple[torch.Tensor, torch.Tensor]] + __call__: Callable[..., tuple[torch.Tensor, torch.Tensor]] def forward( self, h_fea: torch.Tensor, aux_att_score: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: h_wave, aux_att_score = self.SynthAttNorm_sublayer(h_fea, aux_att_score) return self.FFNorm_sublayer(h_wave), aux_att_score @@ -184,7 +176,7 @@ def __init__( normalization: str = "layer", feedforward_hidden: int = 128, ): - super(N2SEncoder, self).__init__( + super().__init__( embed_dim=embed_dim, init_embedding=init_embedding, pos_embedding=pos_embedding, @@ -210,7 +202,7 @@ def __init__( ) ) - def _encoder_forward(self, init_h: Tensor, init_p: Tensor) -> Tuple[Tensor, Tensor]: + def _encoder_forward(self, init_h: Tensor, init_p: Tensor) -> tuple[Tensor, Tensor]: embed_p = self.pos_net(init_p) final_h, final_p = self.net(init_h, embed_p) diff --git a/rl4co/models/zoo/n2s/model.py b/rl4co/models/zoo/n2s/model.py index d3080e4a..0483a496 100644 --- a/rl4co/models/zoo/n2s/model.py +++ b/rl4co/models/zoo/n2s/model.py @@ -46,9 +46,7 @@ def __init__( critic_kwargs["feedforward_hidden"] if "feedforward_hidden" in critic_kwargs else 128, - critic_kwargs["normalization"] - if "normalization" in critic_kwargs - else "layer", + critic_kwargs["normalization"] if "normalization" in critic_kwargs else "layer", bias=False, ) value_head = CriticDecoder(embed_dim) diff --git a/rl4co/models/zoo/n2s/policy.py b/rl4co/models/zoo/n2s/policy.py index 59eb7229..2693a145 100644 --- a/rl4co/models/zoo/n2s/policy.py +++ b/rl4co/models/zoo/n2s/policy.py @@ -5,10 +5,7 @@ from rl4co.envs import RL4COEnvBase, get_env from rl4co.models.common.improvement.base import ImprovementPolicy -from rl4co.models.zoo.n2s.decoder import ( - NodePairReinsertionDecoder, - NodePairRemovalDecoder, -) +from rl4co.models.zoo.n2s.decoder import NodePairReinsertionDecoder, NodePairRemovalDecoder from rl4co.models.zoo.n2s.encoder import N2SEncoder from rl4co.utils.decoding import DecodingStrategy, get_decoding_strategy from rl4co.utils.pylogger import get_pylogger @@ -57,7 +54,7 @@ def __init__( val_decode_type: str = "sampling", test_decode_type: str = "sampling", ): - super(N2SPolicy, self).__init__() + super().__init__() self.env_name = env_name @@ -74,9 +71,7 @@ def __init__( feedforward_hidden=feedforward_hidden, ) - self.removal_decoder = NodePairRemovalDecoder( - embed_dim=embed_dim, num_heads=num_heads - ) + self.removal_decoder = NodePairRemovalDecoder(embed_dim=embed_dim, num_heads=num_heads) self.reinsertion_decoder = NodePairReinsertionDecoder( embed_dim=embed_dim, num_heads=num_heads @@ -123,9 +118,7 @@ def forward( h_wave, final_p = self.encoder(td) if only_return_embed: return {"embeds": h_wave.detach()} - final_h = ( - self.project_node(h_wave) + self.project_graph(h_wave.max(1)[0])[:, None, :] - ) + final_h = self.project_node(h_wave) + self.project_graph(h_wave.max(1)[0])[:, None, :] # Instantiate environment if needed if isinstance(env, str) or env is None: @@ -184,11 +177,7 @@ def forward( logprob_reinsertion, action_reinsertion = decode_strategy.step( logits, mask, - action=( - actions[:, 1] * seq_length + actions[:, 2] - if actions is not None - else None - ), + action=(actions[:, 1] * seq_length + actions[:, 2] if actions is not None else None), ) action_reinsertion = action_reinsertion.unsqueeze(-1) if phase == "train": diff --git a/rl4co/models/zoo/nargnn/encoder.py b/rl4co/models/zoo/nargnn/encoder.py index e63dd84b..9d680758 100644 --- a/rl4co/models/zoo/nargnn/encoder.py +++ b/rl4co/models/zoo/nargnn/encoder.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional +from collections.abc import Callable import torch import torch.nn as nn @@ -36,13 +36,10 @@ def __init__( linear_bias: bool = True, undirected_graph: bool = True, ) -> None: - super(EdgeHeatmapGenerator, self).__init__() + super().__init__() self.linears = nn.ModuleList( - [ - nn.Linear(embed_dim, embed_dim, bias=linear_bias) - for _ in range(num_layers - 1) - ] + [nn.Linear(embed_dim, embed_dim, bias=linear_bias) for _ in range(num_layers - 1)] ) self.output = nn.Linear(embed_dim, 1, bias=linear_bias) @@ -84,7 +81,9 @@ def _make_heatmap_logits(self, batch_graph: Batch) -> Tensor: # type: ignore if heatmap.dtype == torch.float32 or heatmap.dtype == torch.bfloat16: small_value = 1e-12 elif heatmap.dtype == torch.float16: - small_value = 3e-8 # the smallest positive number such that log(small_value) is not -inf + small_value = ( + 3e-8 # the smallest positive number such that log(small_value) is not -inf + ) else: raise ValueError(f"Unsupported dtype: {heatmap.dtype}") @@ -128,16 +127,16 @@ def __init__( embed_dim: int = 64, env_name: str = "tsp", # TODO: pass network - init_embedding: Optional[nn.Module] = None, - edge_embedding: Optional[nn.Module] = None, - graph_network: Optional[nn.Module] = None, - heatmap_generator: Optional[nn.Module] = None, + init_embedding: nn.Module | None = None, + edge_embedding: nn.Module | None = None, + graph_network: nn.Module | None = None, + heatmap_generator: nn.Module | None = None, num_layers_heatmap_generator: int = 5, num_layers_graph_encoder: int = 15, act_fn="silu", agg_fn="mean", linear_bias: bool = True, - k_sparse: Optional[int] = None, + k_sparse: int | None = None, ): super(NonAutoregressiveEncoder, self).__init__() self.env_name = env_name @@ -149,9 +148,7 @@ def __init__( ) self.edge_embedding = ( - env_edge_embedding( - self.env_name, {"embed_dim": embed_dim, "k_sparse": k_sparse} - ) + env_edge_embedding(self.env_name, {"embed_dim": embed_dim, "k_sparse": k_sparse}) if edge_embedding is None else edge_embedding ) @@ -187,9 +184,7 @@ def forward(self, td: TensorDict): # Process embedding into graph # TODO: standardize? - graph.x, graph.edge_attr = self.graph_network( - graph.x, graph.edge_index, graph.edge_attr - ) + graph.x, graph.edge_attr = self.graph_network(graph.x, graph.edge_index, graph.edge_attr) # Generate heatmap logits heatmap_logits = self.heatmap_generator(graph) @@ -210,9 +205,7 @@ def forward(self, td: TensorDict): # Process embedding into graph # TODO: standardize? - graph.x, graph.edge_attr = self.graph_network( - graph.x, graph.edge_index, graph.edge_attr - ) + graph.x, graph.edge_attr = self.graph_network(graph.x, graph.edge_index, graph.edge_attr) proc_embeds = graph.x batch_size = node_embed.shape[0] diff --git a/rl4co/models/zoo/nargnn/policy.py b/rl4co/models/zoo/nargnn/policy.py index 1953628b..2fdbcbb8 100644 --- a/rl4co/models/zoo/nargnn/policy.py +++ b/rl4co/models/zoo/nargnn/policy.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch.nn as nn from rl4co.models.common.constructive.nonautoregressive import ( @@ -49,14 +47,14 @@ class NARGNNPolicy(NonAutoregressivePolicy): def __init__( self, - encoder: Optional[NonAutoregressiveEncoder] = None, - decoder: Optional[NonAutoregressiveDecoder] = None, + encoder: NonAutoregressiveEncoder | None = None, + decoder: NonAutoregressiveDecoder | None = None, embed_dim: int = 64, env_name: str = "tsp", - init_embedding: Optional[nn.Module] = None, - edge_embedding: Optional[nn.Module] = None, - graph_network: Optional[nn.Module] = None, - heatmap_generator: Optional[nn.Module] = None, + init_embedding: nn.Module | None = None, + edge_embedding: nn.Module | None = None, + graph_network: nn.Module | None = None, + heatmap_generator: nn.Module | None = None, num_layers_heatmap_generator: int = 5, num_layers_graph_encoder: int = 15, act_fn="silu", @@ -96,7 +94,7 @@ def __init__( ) # Pass to constructive policy - super(NARGNNPolicy, self).__init__( + super().__init__( encoder=encoder, decoder=decoder, env_name=env_name, diff --git a/rl4co/models/zoo/neuopt/decoder.py b/rl4co/models/zoo/neuopt/decoder.py index f9cca584..77ba778a 100644 --- a/rl4co/models/zoo/neuopt/decoder.py +++ b/rl4co/models/zoo/neuopt/decoder.py @@ -60,9 +60,7 @@ def forward(self, h, q1, q2, input_q1, input_q2) -> Tensor: + self.linear_Q1(q1).unsqueeze(1) + self.linear_K3(h) * self.linear_Q3(q1).unsqueeze(1) ) - ).sum( - -1 - ) # \mu stream + ).sum(-1) # \mu stream result += ( linear_V2.unsqueeze(1) * torch.tanh( @@ -70,8 +68,6 @@ def forward(self, h, q1, q2, input_q1, input_q2) -> Tensor: + self.linear_Q2(q2).unsqueeze(1) + self.linear_K4(h) * self.linear_Q4(q2).unsqueeze(1) ) - ).sum( - -1 - ) # \lambda stream + ).sum(-1) # \lambda stream return result, q1, q2 diff --git a/rl4co/models/zoo/neuopt/model.py b/rl4co/models/zoo/neuopt/model.py index 5bd7050d..1e0fa0f3 100644 --- a/rl4co/models/zoo/neuopt/model.py +++ b/rl4co/models/zoo/neuopt/model.py @@ -46,9 +46,7 @@ def __init__( critic_kwargs["feedforward_hidden"] if "feedforward_hidden" in critic_kwargs else 128, - critic_kwargs["normalization"] - if "normalization" in critic_kwargs - else "layer", + critic_kwargs["normalization"] if "normalization" in critic_kwargs else "layer", bias=False, ) value_head = CriticDecoder(embed_dim, dropout_rate=0.001) diff --git a/rl4co/models/zoo/neuopt/policy.py b/rl4co/models/zoo/neuopt/policy.py index 717ad164..0478468e 100644 --- a/rl4co/models/zoo/neuopt/policy.py +++ b/rl4co/models/zoo/neuopt/policy.py @@ -22,7 +22,7 @@ class CustomizeTSPInitEmbedding(nn.Module): """ def __init__(self, embed_dim, linear_bias=True): - super(CustomizeTSPInitEmbedding, self).__init__() + super().__init__() node_dim = 2 # x, y self.init_embed = nn.Sequential( nn.Linear(node_dim, embed_dim // 2, linear_bias), @@ -75,7 +75,7 @@ def __init__( val_decode_type: str = "sampling", test_decode_type: str = "sampling", ): - super(NeuOptPolicy, self).__init__() + super().__init__() self.env_name = env_name self.embed_dim = embed_dim @@ -183,9 +183,7 @@ def forward( action_index = torch.zeros(bs, env.k_max, dtype=torch.long).to(rec.device) k_action_left = torch.zeros(bs, env.k_max + 1, dtype=torch.long).to(rec.device) k_action_right = torch.zeros(bs, env.k_max, dtype=torch.long).to(rec.device) - next_of_last_action = ( - torch.zeros_like(rec[:, :1], dtype=torch.long).to(rec.device) - 1 - ) + next_of_last_action = torch.zeros_like(rec[:, :1], dtype=torch.long).to(rec.device) - 1 mask = torch.zeros_like(rec, dtype=torch.bool).to(rec.device) stopped = torch.ones(bs, dtype=torch.bool).to(rec.device) zeros = torch.zeros((bs, 1), device=td.device) @@ -242,9 +240,7 @@ def forward( input_q1.clone(), nfe.gather( 1, - (next_of_last_action % gs) - .view(bs, 1, 1) - .expand(bs, 1, self.embed_dim), + (next_of_last_action % gs).view(bs, 1, 1).expand(bs, 1, self.embed_dim), ).squeeze(1), ) @@ -261,9 +257,7 @@ def forward( # Calc next basic masks if i == 0: - visited_time_tag = ( - visited_time - visited_time.gather(1, action_sampled) - ) % gs + visited_time_tag = (visited_time - visited_time.gather(1, action_sampled)) % gs mask &= False mask[(visited_time_tag <= visited_time_tag.gather(1, action_sampled))] = True if i == 0: diff --git a/rl4co/models/zoo/polynet/decoder.py b/rl4co/models/zoo/polynet/decoder.py index 09580bc7..cdb7f0fe 100644 --- a/rl4co/models/zoo/polynet/decoder.py +++ b/rl4co/models/zoo/polynet/decoder.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Tuple import torch.nn as nn @@ -105,23 +104,17 @@ def __init__( ) # For each node we compute (glimpse key, glimpse value, logit key) so 3 * embed_dim - self.project_node_embeddings = nn.Linear( - embed_dim, 3 * embed_dim, bias=linear_bias - ) + self.project_node_embeddings = nn.Linear(embed_dim, 3 * embed_dim, bias=linear_bias) self.project_fixed_context = nn.Linear(embed_dim, embed_dim, bias=linear_bias) self.use_graph_context = use_graph_context - def _precompute_cache_matnet( - self, embeddings: Tuple[Tensor, Tensor], *args, **kwargs - ): + def _precompute_cache_matnet(self, embeddings: tuple[Tensor, Tensor], *args, **kwargs): col_emb, row_emb = embeddings ( glimpse_key_fixed, glimpse_val_fixed, logit_key, - ) = self.project_node_embeddings( - col_emb - ).chunk(3, dim=-1) + ) = self.project_node_embeddings(col_emb).chunk(3, dim=-1) # Optionally disable the graph context from the initial embedding as done in POMO if self.use_graph_context: @@ -138,7 +131,7 @@ def _precompute_cache_matnet( logit_key=logit_key, ) - def _precompute_cache(self, embeddings: Tuple[Tensor, Tensor], *args, **kwargs): + def _precompute_cache(self, embeddings: tuple[Tensor, Tensor], *args, **kwargs): if self.encoder_type == "AM": return super()._precompute_cache(embeddings, *args, **kwargs) elif self.encoder_type == "MatNet": diff --git a/rl4co/models/zoo/polynet/model.py b/rl4co/models/zoo/polynet/model.py index fa3f1ae3..63f8a544 100644 --- a/rl4co/models/zoo/polynet/model.py +++ b/rl4co/models/zoo/polynet/model.py @@ -1,6 +1,7 @@ import logging -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any import torch @@ -73,14 +74,12 @@ def __init__( policy_kwargs.get("val_decode_type") == "greedy" or policy_kwargs.get("test_decode_type") == "greedy" ): - assert ( - val_num_solutions <= k - ), "If greedy decoding is used val_num_solutions must be <= k" + assert val_num_solutions <= k, ( + "If greedy decoding is used val_num_solutions must be <= k" + ) if encoder_type == "MatNet": - assert ( - num_augment == 1 - ), "MatNet does not use symmetric or dihedral augmentation" + assert num_augment == 1, "MatNet does not use symmetric or dihedral augmentation" if policy is None: policy = PolyNetPolicy( @@ -88,9 +87,7 @@ def __init__( ) if base_model_checkpoint_path is not None: - logging.info( - f"Trying to load weights from baseline model {base_model_checkpoint_path}" - ) + logging.info(f"Trying to load weights from baseline model {base_model_checkpoint_path}") checkpoint = torch.load(base_model_checkpoint_path, weights_only=False) state_dict = checkpoint["state_dict"] state_dict = {k.replace("policy.", "", 1): v for k, v in state_dict.items()} @@ -104,7 +101,7 @@ def __init__( kwargs_with_defaults.update(kwargs) # Initialize with the shared baseline - super(PolyNet, self).__init__(env, policy, baseline, **kwargs_with_defaults) + super().__init__(env, policy, baseline, **kwargs_with_defaults) self.num_augment = num_augment if self.num_augment > 1: @@ -121,9 +118,7 @@ def __init__( # for phase in ["train", "val", "test"]: # self.set_decode_type_multistart(phase) - def shared_step( - self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None - ): + def shared_step(self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None): td = self.env.reset(batch) n_aug = self.num_augment @@ -184,9 +179,7 @@ def shared_step( out.update({"max_aug_reward": max_aug_reward}) if out.get("actions", None) is not None: - actions_ = ( - out["best_multistart_actions"] if n_start > 1 else out["actions"] - ) + actions_ = out["best_multistart_actions"] if n_start > 1 else out["actions"] out.update({"best_aug_actions": gather_by_index(actions_, max_idxs)}) metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx) @@ -197,8 +190,8 @@ def calculate_loss( td: TensorDict, batch: TensorDict, policy_out: dict, - reward: Optional[torch.Tensor] = None, - log_likelihood: Optional[torch.Tensor] = None, + reward: torch.Tensor | None = None, + log_likelihood: torch.Tensor | None = None, ): """Calculate loss following Poppy (https://arxiv.org/abs/2210.03475). @@ -217,9 +210,7 @@ def calculate_loss( ) # REINFORCE baseline - bl_val, bl_loss = ( - self.baseline.eval(td, reward, self.env) if extra is None else (extra, 0) - ) + bl_val, bl_loss = self.baseline.eval(td, reward, self.env) if extra is None else (extra, 0) # Log-likelihood mask. Mask everything but the best rollout per instance best_idx = (-reward).argsort(1).argsort(1) diff --git a/rl4co/models/zoo/polynet/policy.py b/rl4co/models/zoo/polynet/policy.py index 628858be..52a42461 100644 --- a/rl4co/models/zoo/polynet/policy.py +++ b/rl4co/models/zoo/polynet/policy.py @@ -85,7 +85,7 @@ def __init__( **kwargs, ) - super(PolyNetPolicy, self).__init__( + super().__init__( encoder=encoder, decoder=decoder, env_name=env_name, diff --git a/rl4co/models/zoo/pomo/model.py b/rl4co/models/zoo/pomo/model.py index b4057e30..bc9c9599 100644 --- a/rl4co/models/zoo/pomo/model.py +++ b/rl4co/models/zoo/pomo/model.py @@ -1,4 +1,5 @@ -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import torch.nn as nn @@ -61,14 +62,12 @@ def __init__( "use_graph_context": False, } policy_kwargs_with_defaults.update(policy_kwargs) - policy = AttentionModelPolicy( - env_name=env.name, **policy_kwargs_with_defaults - ) + policy = AttentionModelPolicy(env_name=env.name, **policy_kwargs_with_defaults) assert baseline == "shared", "POMO only supports shared baseline" # Initialize with the shared baseline - super(POMO, self).__init__(env, policy, baseline, **kwargs) + super().__init__(env, policy, baseline, **kwargs) self.num_starts = num_starts self.num_augment = num_augment @@ -86,9 +85,7 @@ def __init__( for phase in ["train", "val", "test"]: self.set_decode_type_multistart(phase) - def shared_step( - self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None - ): + def shared_step(self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None): td = self.env.reset(batch) n_aug, n_start = self.num_augment, self.num_starts n_start = self.env.get_num_starts(td) if n_start is None else n_start @@ -139,9 +136,7 @@ def shared_step( out.update({"max_aug_reward": max_aug_reward}) if out.get("actions", None) is not None: - actions_ = ( - out["best_multistart_actions"] if n_start > 1 else out["actions"] - ) + actions_ = out["best_multistart_actions"] if n_start > 1 else out["actions"] out.update({"best_aug_actions": gather_by_index(actions_, max_idxs)}) metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx) diff --git a/rl4co/models/zoo/ptrnet/critic.py b/rl4co/models/zoo/ptrnet/critic.py index efbda9ed..da02c39f 100644 --- a/rl4co/models/zoo/ptrnet/critic.py +++ b/rl4co/models/zoo/ptrnet/critic.py @@ -16,16 +16,14 @@ def __init__( tanh_exploration, use_tanh, ): - super(CriticNetworkLSTM, self).__init__() + super().__init__() self.hidden_dim = hidden_dim self.n_process_block_iters = n_process_block_iters self.encoder = Encoder(embed_dim, hidden_dim) - self.process_block = SimpleAttention( - hidden_dim, use_tanh=use_tanh, C=tanh_exploration - ) + self.process_block = SimpleAttention(hidden_dim, use_tanh=use_tanh, C=tanh_exploration) self.sm = nn.Softmax(dim=1) self.decoder = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1) @@ -38,12 +36,8 @@ def forward(self, inputs): """ inputs = inputs.transpose(0, 1).contiguous() - encoder_hx = ( - self.encoder.init_hx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0) - ) - encoder_cx = ( - self.encoder.init_cx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0) - ) + encoder_hx = self.encoder.init_hx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0) + encoder_cx = self.encoder.init_cx.unsqueeze(0).repeat(inputs.size(1), 1).unsqueeze(0) # encoder forward pass enc_outputs, (enc_h_t, enc_c_t) = self.encoder(inputs, (encoder_hx, encoder_cx)) diff --git a/rl4co/models/zoo/ptrnet/decoder.py b/rl4co/models/zoo/ptrnet/decoder.py index 710f03c0..631e20de 100644 --- a/rl4co/models/zoo/ptrnet/decoder.py +++ b/rl4co/models/zoo/ptrnet/decoder.py @@ -12,7 +12,7 @@ class SimpleAttention(nn.Module): """A generic attention module for a decoder in seq2seq""" def __init__(self, dim, use_tanh=False, C=10): - super(SimpleAttention, self).__init__() + super().__init__() self.use_tanh = use_tanh self.project_query = nn.Linear(dim, dim) self.project_ref = nn.Conv1d(dim, dim, 1, 1) @@ -58,7 +58,7 @@ def __init__( mask_glimpses=True, mask_logits=True, ): - super(Decoder, self).__init__() + super().__init__() self.embed_dim = embed_dim self.hidden_dim = hidden_dim @@ -76,9 +76,7 @@ def update_mask(self, mask, selected): return mask.clone().scatter_(1, selected.unsqueeze(-1), False) def recurrence(self, x, h_in, prev_mask, prev_idxs, step, context): - logit_mask = ( - self.update_mask(prev_mask, prev_idxs) if prev_idxs is not None else prev_mask - ) + logit_mask = self.update_mask(prev_mask, prev_idxs) if prev_idxs is not None else prev_mask logits, h_out = self.calc_logits( x, h_in, logit_mask, context, self.mask_glimpses, self.mask_logits @@ -92,9 +90,7 @@ def recurrence(self, x, h_in, prev_mask, prev_idxs, step, context): return h_out, log_p, logit_mask - def calc_logits( - self, x, h_in, logit_mask, context, mask_glimpses=None, mask_logits=None - ): + def calc_logits(self, x, h_in, logit_mask, context, mask_glimpses=None, mask_logits=None): if mask_glimpses is None: mask_glimpses = self.mask_glimpses @@ -152,9 +148,7 @@ def forward( ) for i in steps: - hidden, log_p, mask = self.recurrence( - decoder_input, hidden, mask, idxs, i, context - ) + hidden, log_p, mask = self.recurrence(decoder_input, hidden, mask, idxs, i, context) # select the next inputs for the decoder [batch_size x hidden_dim] idxs = ( decode_logprobs(log_p, mask, decode_type=decode_type) diff --git a/rl4co/models/zoo/ptrnet/encoder.py b/rl4co/models/zoo/ptrnet/encoder.py index 575c7430..46c6f10d 100644 --- a/rl4co/models/zoo/ptrnet/encoder.py +++ b/rl4co/models/zoo/ptrnet/encoder.py @@ -9,7 +9,7 @@ class Encoder(nn.Module): to a hidden vector""" def __init__(self, input_dim, hidden_dim): - super(Encoder, self).__init__() + super().__init__() self.hidden_dim = hidden_dim self.lstm = nn.LSTM(input_dim, hidden_dim) self.init_hx, self.init_cx = self.init_hidden(hidden_dim) diff --git a/rl4co/models/zoo/ptrnet/model.py b/rl4co/models/zoo/ptrnet/model.py index 0beea31a..c258ce29 100644 --- a/rl4co/models/zoo/ptrnet/model.py +++ b/rl4co/models/zoo/ptrnet/model.py @@ -27,7 +27,5 @@ def __init__( baseline_kwargs={}, **kwargs, ): - policy = ( - PointerNetworkPolicy(env=env, **policy_kwargs) if policy is None else policy - ) + policy = PointerNetworkPolicy(env=env, **policy_kwargs) if policy is None else policy super().__init__(env, policy, baseline, baseline_kwargs, **kwargs) diff --git a/rl4co/models/zoo/ptrnet/policy.py b/rl4co/models/zoo/ptrnet/policy.py index dc0373a3..be2d0f3b 100644 --- a/rl4co/models/zoo/ptrnet/policy.py +++ b/rl4co/models/zoo/ptrnet/policy.py @@ -22,7 +22,7 @@ def __init__( mask_logits=True, **kwargs, ): - super(PointerNetworkPolicy, self).__init__() + super().__init__() assert env_name == "tsp", "Only the Euclidean TSP env is implemented" self.env_name = env_name diff --git a/rl4co/models/zoo/symnco/losses.py b/rl4co/models/zoo/symnco/losses.py index 38f9265e..545abc9d 100644 --- a/rl4co/models/zoo/symnco/losses.py +++ b/rl4co/models/zoo/symnco/losses.py @@ -33,7 +33,5 @@ def invariance_loss(proj_embed, num_augment): Corresponds to `L_inv` in the SymNCO paper """ pe = rearrange(proj_embed, "(b a) ... -> b a ...", a=num_augment) - similarity = sum( - [cosine_similarity(pe[:, 0], pe[:, i], dim=-1) for i in range(1, num_augment)] - ) + similarity = sum([cosine_similarity(pe[:, 0], pe[:, i], dim=-1) for i in range(1, num_augment)]) return similarity.mean() diff --git a/rl4co/models/zoo/symnco/model.py b/rl4co/models/zoo/symnco/model.py index db4c26ad..0d702eca 100644 --- a/rl4co/models/zoo/symnco/model.py +++ b/rl4co/models/zoo/symnco/model.py @@ -1,4 +1,5 @@ -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import torch.nn as nn @@ -72,9 +73,7 @@ def __init__( for phase in ["train", "val", "test"]: self.set_decode_type_multistart(phase) - def shared_step( - self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None - ): + def shared_step(self, batch: Any, batch_idx: int, phase: str, dataloader_idx: int = None): td = self.env.reset(batch) n_aug, n_start = self.num_augment, self.num_starts n_start = get_num_starts(td, self.env.name) if n_start is None else n_start @@ -119,9 +118,7 @@ def shared_step( # Reshape batch to [batch, n_start, n_aug] if out.get("actions", None) is not None: actions = unbatchify(out["actions"], unbatch_dims) - out.update( - {"best_multistart_actions": gather_by_index(actions, max_idxs)} - ) + out.update({"best_multistart_actions": gather_by_index(actions, max_idxs)}) out["actions"] = actions # Get augmentation score only during inference @@ -153,9 +150,7 @@ def load_from_checkpoint( **kwargs, ): if kwargs.pop("baseline", "symnco") != "symnco": - log.warning( - "SymNCO only supports custom-symnco baseline. Setting to 'symnco'." - ) + log.warning("SymNCO only supports custom-symnco baseline. Setting to 'symnco'.") kwargs["baseline"] = "symnco" return super().load_from_checkpoint( checkpoint_path, diff --git a/rl4co/models/zoo/symnco/policy.py b/rl4co/models/zoo/symnco/policy.py index d76d660f..b86c3c37 100644 --- a/rl4co/models/zoo/symnco/policy.py +++ b/rl4co/models/zoo/symnco/policy.py @@ -39,7 +39,7 @@ def __init__( use_projection_head: bool = True, **kwargs, ): - super(SymNCOPolicy, self).__init__( + super().__init__( env_name=env_name, embed_dim=embed_dim, num_encoder_layers=num_encoder_layers, @@ -69,9 +69,9 @@ def forward( super().forward.__doc__ # trick to get docs from parent class # Ensure that if use_projection_head is True, then return_init_embeds is True - assert not ( - self.use_projection_head and not return_init_embeds - ), "If `use_projection_head` is True, then we must `return_init_embeds`" + assert not (self.use_projection_head and not return_init_embeds), ( + "If `use_projection_head` is True, then we must `return_init_embeds`" + ) out = super().forward( td, diff --git a/rl4co/tasks/README.md b/rl4co/tasks/README.md index 19c4dda7..bfd5df73 100644 --- a/rl4co/tasks/README.md +++ b/rl4co/tasks/README.md @@ -37,7 +37,7 @@ Arguments guideline: - `--method`: the evaluation method, e.g., `greedy`, `sampling`, `multistart_greedy`, `augment_dihedral_8`, `augment`, `multistart_greedy_augment_dihedral_8`, and `multistart_greedy_augment`. Default is `greedy`. - `--save-results`: whether to save the evaluation results as a `.pkl` file. Deafult is `True`. The results include `actions`, `rewards`, `inference_time`, and `avg_reward`. - `--save-path`: the path to save the evaluation results. Default is `results/`. -- `--num-instances`: the number of test instances to evaluate. Default is `1000`. +- `--num-instances`: the number of test instances to evaluate. Default is `1000`. If you use the `sampling` method, you may need to specify the following parameters: - `--samples`: the number of samples for the sampling method. Default is `1280`. diff --git a/rl4co/tasks/eval.py b/rl4co/tasks/eval.py index 3f32eacb..26ef5552 100644 --- a/rl4co/tasks/eval.py +++ b/rl4co/tasks/eval.py @@ -41,9 +41,7 @@ def __call__(self, policy, dataloader, **kwargs): rewards_list = [] actions_list = [] - for batch in tqdm( - dataloader, disable=not self.progress, desc=f"Running {self.name}" - ): + for batch in tqdm(dataloader, disable=not self.progress, desc=f"Running {self.name}"): td = batch.to(next(policy.parameters()).device) td = self.env.reset(td) actions, rewards = self._inner(policy, td, **kwargs) @@ -263,9 +261,9 @@ def __init__( assert num_starts is not None, "Must specify num_starts" self.num_starts = num_starts - assert not ( - num_augment != 8 and force_dihedral_8 - ), "Cannot force dihedral 8 when num_augment != 8" + assert not (num_augment != 8 and force_dihedral_8), ( + "Cannot force dihedral 8 when num_augment != 8" + ) self.augmentation = StateAugmentation( num_augment=num_augment, augment_fn="dihedral8" if force_dihedral_8 else "symmetric", @@ -378,7 +376,7 @@ def evaluate_policy( }, } - assert method in methods_mapping, "Method {} not found".format(method) + assert method in methods_mapping, f"Method {method} not found" # Set up the evaluation function eval_settings = methods_mapping[method] @@ -389,13 +387,11 @@ def evaluate_policy( eval_fn = func(env, **kwargs) if auto_batch_size: - assert ( - batch_size is None - ), "Cannot specify batch_size when auto_batch_size is True" + assert batch_size is None, "Cannot specify batch_size when auto_batch_size is True" batch_size = get_automatic_batch_size( eval_fn, max_batch_size=max_batch_size, start_batch_size=start_batch_size ) - print("Using automatic batch size: {}".format(batch_size)) + print(f"Using automatic batch size: {batch_size}") # Set up the dataloader dataloader = DataLoader( @@ -452,9 +448,7 @@ def evaluate_policy( default="checkpoints/am-tsp50.ckpt", help="The path of the checkpoint file", ) - parser.add_argument( - "--device", type=str, default="cuda:1", help="Device to run the evaluation" - ) + parser.add_argument("--device", type=str, default="cuda:1", help="Device to run the evaluation") # Evaluation parser.add_argument( @@ -465,9 +459,7 @@ def evaluate_policy( 'multistart_greedy', 'augment_dihedral_8', 'augment', 'multistart_greedy_augment_dihedral_8',\ 'multistart_greedy_augment'", ) - parser.add_argument( - "--temperature", type=float, default=1.0, help="Temperature for sampling" - ) + parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling") parser.add_argument( "--top-p", type=float, diff --git a/rl4co/tasks/train.py b/rl4co/tasks/train.py index 96a57311..526976fa 100644 --- a/rl4co/tasks/train.py +++ b/rl4co/tasks/train.py @@ -1,5 +1,3 @@ -from typing import Optional, Tuple - import hydra import lightning as L import pyrootutils @@ -19,7 +17,7 @@ @utils.task_wrapper -def run(cfg: DictConfig) -> Tuple[dict, dict]: +def run(cfg: DictConfig) -> tuple[dict, dict]: """Trains the model. Can additionally evaluate on a testset, using best weights obtained during training. This method is wrapped in optional @task_wrapper decorator, that controls the behavior during @@ -96,7 +94,7 @@ def run(cfg: DictConfig) -> Tuple[dict, dict]: @hydra.main(version_base="1.3", config_path="../../configs", config_name="main.yaml") -def train(cfg: DictConfig) -> Optional[float]: +def train(cfg: DictConfig) -> float | None: # apply extra utilities # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) utils.extras(cfg) diff --git a/rl4co/utils/callbacks/speed_monitor.py b/rl4co/utils/callbacks/speed_monitor.py index 3f1ab6ae..700ed481 100644 --- a/rl4co/utils/callbacks/speed_monitor.py +++ b/rl4co/utils/callbacks/speed_monitor.py @@ -32,9 +32,7 @@ def __init__( def on_train_start(self, trainer: "L.Trainer", L_module: "L.LightningModule") -> None: self._snap_epoch_time = None - def on_train_epoch_start( - self, trainer: "L.Trainer", L_module: "L.LightningModule" - ) -> None: + def on_train_epoch_start(self, trainer: "L.Trainer", L_module: "L.LightningModule") -> None: self._snap_intra_step_time = None self._snap_inter_step_time = None self._snap_epoch_time = time.time() @@ -44,9 +42,7 @@ def on_validation_epoch_start( ) -> None: self._snap_inter_step_time = None - def on_test_epoch_start( - self, trainer: "L.Trainer", L_module: "L.LightningModule" - ) -> None: + def on_test_epoch_start(self, trainer: "L.Trainer", L_module: "L.LightningModule") -> None: self._snap_inter_step_time = None @rank_zero_only @@ -65,9 +61,7 @@ def on_train_batch_start( logs = {} if self._log_stats.inter_step_time and self._snap_inter_step_time: # First log at beginning of second step - logs["time/inter_step (ms)"] = ( - time.time() - self._snap_inter_step_time - ) * 1000 + logs["time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000 if trainer.logger is not None: trainer.logger.log_metrics(logs, step=trainer.global_step) @@ -83,11 +77,7 @@ def on_train_batch_end( if self._log_stats.inter_step_time: self._snap_inter_step_time = time.time() - if ( - self.verbose - and self._log_stats.intra_step_time - and self._snap_intra_step_time - ): + if self.verbose and self._log_stats.intra_step_time and self._snap_intra_step_time: L_module.print( f"time/intra_step (ms): {(time.time() - self._snap_intra_step_time) * 1000}" ) @@ -97,9 +87,7 @@ def on_train_batch_end( logs = {} if self._log_stats.intra_step_time and self._snap_intra_step_time: - logs["time/intra_step (ms)"] = ( - time.time() - self._snap_intra_step_time - ) * 1000 + logs["time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000 if trainer.logger is not None: trainer.logger.log_metrics(logs, step=trainer.global_step) @@ -118,6 +106,4 @@ def on_train_epoch_end( @staticmethod def _should_log(trainer) -> bool: - return ( - trainer.global_step + 1 - ) % trainer.log_every_n_steps == 0 or trainer.should_stop + return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop diff --git a/rl4co/utils/decoding.py b/rl4co/utils/decoding.py index 4790b237..5f0ef5a2 100644 --- a/rl4co/utils/decoding.py +++ b/rl4co/utils/decoding.py @@ -1,6 +1,6 @@ import abc -from typing import Optional, Tuple +from collections.abc import Callable import torch import torch.nn.functional as F @@ -53,9 +53,7 @@ def get_log_likelihood(logprobs, actions=None, mask=None, return_sum: bool = Tru if mask is not None: logprobs[~mask] = 0 - assert ( - logprobs > -1000 - ).data.all(), "Logprobs should not be -inf, check sampling procedure!" + assert (logprobs > -1000).data.all(), "Logprobs should not be -inf, check sampling procedure!" # Calculate log_likelihood if return_sum: @@ -73,7 +71,7 @@ def decode_logprobs(logprobs, mask, decode_type="sampling"): elif "sampling" in decode_type: selected = DecodingStrategy.sampling(logprobs, mask) else: - assert False, "Unknown decode type: {}".format(decode_type) + assert False, f"Unknown decode type: {decode_type}" return selected @@ -221,11 +219,11 @@ def __init__( top_k: int = 0, mask_logits: bool = True, tanh_clipping: float = 0, - num_samples: Optional[int] = None, + num_samples: int | None = None, multisample: bool = False, - num_starts: Optional[int] = None, + num_starts: int | None = None, multistart: bool = False, - select_start_nodes_fn: Optional[callable] = None, + select_start_nodes_fn: Callable | None = None, improvement_method_mode: bool = False, select_best: bool = False, store_all_logp: bool = False, @@ -237,13 +235,13 @@ def __init__( self.mask_logits = mask_logits self.tanh_clipping = tanh_clipping # check if multistart (POMO) and multisample flags - assert not ( - multistart and multisample - ), "Using both multistart and multisample is not supported" + assert not (multistart and multisample), ( + "Using both multistart and multisample is not supported" + ) if num_samples and num_starts: - assert not ( - num_samples > 1 and num_starts > 1 - ), f"num_samples={num_samples} and num_starts={num_starts} are both > 1" + assert not (num_samples > 1 and num_starts > 1), ( + f"num_samples={num_samples} and num_starts={num_starts} are both > 1" + ) if num_samples is not None: multisample = True if num_samples > 1 else False if num_starts is not None: @@ -267,10 +265,10 @@ def _step( self, logprobs: torch.Tensor, mask: torch.Tensor, - td: Optional[TensorDict] = None, - action: Optional[torch.Tensor] = None, + td: TensorDict | None = None, + action: torch.Tensor | None = None, **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: + ) -> tuple[torch.Tensor, torch.Tensor, TensorDict]: """Main decoding operation. This method should be called in a loop until all sequences are done. Args: @@ -282,7 +280,7 @@ def _step( raise NotImplementedError("Must be implemented by subclass") def pre_decoder_hook( - self, td: TensorDict, env: RL4COEnvBase, action: Optional[torch.Tensor] = None + self, td: TensorDict, env: RL4COEnvBase, action: torch.Tensor | None = None ): """Pre decoding hook. This method is called before the main decoding operation.""" @@ -333,10 +331,10 @@ def pre_decoder_hook( def post_decoder_hook( self, td: TensorDict, env: RL4COEnvBase - ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict, RL4COEnvBase]: - assert ( - len(self.logprobs) > 0 - ), "No logprobs were collected because all environments were done. Check your initial state" + ) -> tuple[torch.Tensor, torch.Tensor, TensorDict, RL4COEnvBase]: + assert len(self.logprobs) > 0, ( + "No logprobs were collected because all environments were done. Check your initial state" + ) logprobs = torch.stack(self.logprobs, 1) actions = torch.stack(self.actions, 1) if self.num_starts > 0 and self.select_best: @@ -347,8 +345,8 @@ def step( self, logits: torch.Tensor, mask: torch.Tensor, - td: Optional[TensorDict] = None, - action: Optional[torch.Tensor] = None, + td: TensorDict | None = None, + action: torch.Tensor | None = None, **kwargs, ) -> TensorDict: """Main decoding operation. This method should be called in a loop until all sequences are done. @@ -371,9 +369,7 @@ def step( tanh_clipping=self.tanh_clipping, mask_logits=self.mask_logits, ) - logprobs, selected_action, td = self._step( - logprobs, mask, td, action=action, **kwargs - ) + logprobs, selected_action, td = self._step(logprobs, mask, td, action=action, **kwargs) # directly return for improvement methods, since the action for improvement methods is finalized in its own policy if self.improvement_method_mode: @@ -394,9 +390,9 @@ def greedy(logprobs, mask=None): # [BS], [BS] selected = logprobs.argmax(dim=-1) if mask is not None: - assert ( - not (~mask).gather(1, selected.unsqueeze(-1)).data.any() - ), "infeasible action selected" + assert not (~mask).gather(1, selected.unsqueeze(-1)).data.any(), ( + "infeasible action selected" + ) return selected @@ -410,9 +406,9 @@ def sampling(logprobs, mask=None): while (~mask).gather(1, selected.unsqueeze(-1)).data.any(): log.info("Sampled bad values, resampling!") selected = probs.multinomial(1).squeeze(1) - assert ( - not (~mask).gather(1, selected.unsqueeze(-1)).data.any() - ), "infeasible action selected" + assert not (~mask).gather(1, selected.unsqueeze(-1)).data.any(), ( + "infeasible action selected" + ) return selected @@ -432,7 +428,7 @@ class Greedy(DecodingStrategy): def _step( self, logprobs: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs - ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: + ) -> tuple[torch.Tensor, torch.Tensor, TensorDict]: """Select the action with the highest log probability""" selected = self.greedy(logprobs, mask) return logprobs, selected, td @@ -443,7 +439,7 @@ class Sampling(DecodingStrategy): def _step( self, logprobs: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs - ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: + ) -> tuple[torch.Tensor, torch.Tensor, TensorDict]: """Sample an action with a multinomial distribution given by the log probabilities.""" selected = self.sampling(logprobs, mask) return logprobs, selected, td @@ -459,7 +455,7 @@ def _step( td: TensorDict, action: torch.Tensor, **kwargs, - ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: + ) -> tuple[torch.Tensor, torch.Tensor, TensorDict]: """The action is provided externally, so we just return the action""" selected = action return logprobs, selected, td @@ -479,16 +475,16 @@ def __init__(self, beam_width=None, select_best=True, **kwargs) -> None: def _step( self, logprobs: torch.Tensor, mask: torch.Tensor, td: TensorDict, **kwargs - ) -> Tuple[torch.Tensor, torch.Tensor, TensorDict]: + ) -> tuple[torch.Tensor, torch.Tensor, TensorDict]: selected, batch_beam_idx = self._make_beam_step(logprobs) # select the correct state representation, logprobs and mask according to beam parent td = td[batch_beam_idx] logprobs = logprobs[batch_beam_idx] mask = mask[batch_beam_idx] - assert ( - not (~mask).gather(1, selected.unsqueeze(-1)).data.any() - ), "infeasible action selected" + assert not (~mask).gather(1, selected.unsqueeze(-1)).data.any(), ( + "infeasible action selected" + ) return logprobs, selected, td @@ -533,9 +529,9 @@ def _backtrack(self): actions = torch.stack(self.actions, 1) # [BS*BW, seq_len] logprobs = torch.stack(self.logprobs, 1) - assert actions.size(1) == len( - self.beam_path - ), "action idx shape and beam path shape dont match" + assert actions.size(1) == len(self.beam_path), ( + "action idx shape and beam path shape dont match" + ) # [BS*BW] cur_parent = self.beam_path[-1] @@ -545,9 +541,7 @@ def _backtrack(self): aug_batch_size = actions.size(0) batch_size = aug_batch_size // self.beam_width - batch_beam_sequence = ( - torch.arange(0, batch_size).repeat(self.beam_width).to(actions.device) - ) + batch_beam_sequence = torch.arange(0, batch_size).repeat(self.beam_width).to(actions.device) for k in reversed(range(len(self.beam_path) - 1)): batch_beam_idx = batch_beam_sequence + cur_parent * batch_size @@ -583,9 +577,7 @@ def _make_beam_step(self, logprobs: torch.Tensor): # [BS, num_nodes * BW] log_beam_prob_hstacked = torch.cat(log_beam_prob.split(batch_size), dim=1) # [BS, BW] - topk_logprobs, topk_ind = torch.topk( - log_beam_prob_hstacked, self.beam_width, dim=1 - ) + topk_logprobs, topk_ind = torch.topk(log_beam_prob_hstacked, self.beam_width, dim=1) # [BS*BW, 1] logprobs_selected = torch.hstack(torch.unbind(topk_logprobs, 1)).unsqueeze(1) diff --git a/rl4co/utils/lightning.py b/rl4co/utils/lightning.py index a3f29cb7..562b3c31 100644 --- a/rl4co/utils/lightning.py +++ b/rl4co/utils/lightning.py @@ -34,9 +34,7 @@ def remove_key(config, key="wandb"): return new_config -def clean_hydra_config( - config, keep_value_only=True, remove_keys="wandb", clean_cfg_path=True -): +def clean_hydra_config(config, keep_value_only=True, remove_keys="wandb", clean_cfg_path=True): """Clean hydra config by nesting dictionary and cleaning values""" # Remove keys containing `remove_keys` if not isinstance(remove_keys, list): diff --git a/rl4co/utils/meta_trainer.py b/rl4co/utils/meta_trainer.py index ccd64352..3f0fc2ba 100644 --- a/rl4co/utils/meta_trainer.py +++ b/rl4co/utils/meta_trainer.py @@ -1,18 +1,20 @@ +import copy +import math +import random + import lightning.pytorch as pl import torch -import math -import copy -from torch.optim import Adam from lightning import Callback +from torch.optim import Adam + from rl4co import utils -import random + log = utils.get_pylogger(__name__) class ReptileCallback(Callback): - - """ Meta training framework for addressing the generalization issue (implement the Reptile algorithm only) + """Meta training framework for addressing the generalization issue (implement the Reptile algorithm only) Based on Manchanda et al. 2022 (https://arxiv.org/abs/2206.00787) and Zhou et al. 2023 (https://arxiv.org/abs/2305.19587) Args: @@ -25,16 +27,18 @@ class ReptileCallback(Callback): - data_type: type of the tasks, chosen from ["size", "distribution", "size_distribution"] - print_log: whether to print the specific task sampled in each inner-loop optimization """ - def __init__(self, - num_tasks: int, - alpha: float, - alpha_decay: float, - min_size: int, - max_size: int, - sch_bar: float = 0.9, - data_type: str = "size", - print_log: bool =True): + def __init__( + self, + num_tasks: int, + alpha: float, + alpha_decay: float, + min_size: int, + max_size: int, + sch_bar: float = 0.9, + data_type: str = "size", + print_log: bool = True, + ): super().__init__() self.num_tasks = num_tasks @@ -46,7 +50,6 @@ def __init__(self, self.task_set = self._generate_task_set(data_type, min_size, max_size) def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - # Sample a batch of tasks self._sample_task() @@ -56,51 +59,66 @@ def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> No self.selected_tasks[0] = (pl_module.env.generator.num_loc, 0, 0) elif self.data_type == "size": pl_module.env.generator.loc_distribution = "uniform" - self.selected_tasks[0] = (pl_module.env.generator.num_loc, ) + self.selected_tasks[0] = (pl_module.env.generator.num_loc,) elif self.data_type == "distribution": pl_module.env.generator.loc_distribution = "gaussian_mixture" self.selected_tasks[0] = (0, 0) self.task_params = self.selected_tasks[0] def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - # Alpha scheduler (decay for the update of meta model) self._alpha_scheduler() # Reinitialize the task model with the parameters of the meta model - if trainer.current_epoch % self.num_tasks == 0: # Save the meta model + if trainer.current_epoch % self.num_tasks == 0: # Save the meta model self.meta_model_state_dict = copy.deepcopy(pl_module.state_dict()) self.task_models = [] # Print sampled tasks if self.print_log: - print('\n>> Meta epoch: {} (Exact epoch: {}), Training task: {}'.format(trainer.current_epoch//self.num_tasks, trainer.current_epoch, self.selected_tasks)) + print( + f"\n>> Meta epoch: {trainer.current_epoch // self.num_tasks} (Exact epoch: {trainer.current_epoch}), Training task: {self.selected_tasks}" + ) else: pl_module.load_state_dict(self.meta_model_state_dict) # Reinitialize the optimizer every epoch - lr_decay = 0.1 if trainer.current_epoch+1 == int(self.sch_bar * trainer.max_epochs) else 1 - old_lr = trainer.optimizers[0].param_groups[0]['lr'] + lr_decay = 0.1 if trainer.current_epoch + 1 == int(self.sch_bar * trainer.max_epochs) else 1 + old_lr = trainer.optimizers[0].param_groups[0]["lr"] new_optimizer = Adam(pl_module.parameters(), lr=old_lr * lr_decay) trainer.optimizers = [new_optimizer] # Print if self.print_log: - if hasattr(pl_module.env.generator, 'capacity'): - print('>> Training task: {}, capacity: {}'.format(self.task_params, pl_module.env.generator.capacity)) + if hasattr(pl_module.env.generator, "capacity"): + print( + f">> Training task: {self.task_params}, capacity: {pl_module.env.generator.capacity}" + ) else: - print('>> Training task: {}'.format(self.task_params)) - - def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + print(f">> Training task: {self.task_params}") + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): # Save the task model self.task_models.append(copy.deepcopy(pl_module.state_dict())) - if (trainer.current_epoch+1) % self.num_tasks == 0: + if (trainer.current_epoch + 1) % self.num_tasks == 0: # Outer-loop optimization (update the meta model with the parameters of the task model) with torch.no_grad(): - state_dict = {params_key: (self.meta_model_state_dict[params_key] + - self.alpha * torch.mean(torch.stack([fast_weight[params_key] - self.meta_model_state_dict[params_key] - for fast_weight in self.task_models], dim=0).float(), dim=0)) - for params_key in self.meta_model_state_dict} + state_dict = { + params_key: ( + self.meta_model_state_dict[params_key] + + self.alpha + * torch.mean( + torch.stack( + [ + fast_weight[params_key] - self.meta_model_state_dict[params_key] + for fast_weight in self.task_models + ], + dim=0, + ).float(), + dim=0, + ) + ) + for params_key in self.meta_model_state_dict + } pl_module.load_state_dict(state_dict) # Get ready for the next meta-training iteration @@ -109,10 +127,9 @@ def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule self._sample_task() # Load new training task (Update the environment) for the next meta-training iteration - self._load_task(pl_module, task_idx = (trainer.current_epoch+1) % self.num_tasks) + self._load_task(pl_module, task_idx=(trainer.current_epoch + 1) % self.num_tasks) def _sample_task(self): - # Sample a batch of tasks self.selected_tasks = [] for b in range(self.num_tasks): @@ -120,7 +137,6 @@ def _sample_task(self): self.selected_tasks.append(task_params) def _load_task(self, pl_module: pl.LightningModule, task_idx=0): - # Load new training task (Update the environment) self.task_params = self.selected_tasks[task_idx] @@ -129,23 +145,28 @@ def _load_task(self, pl_module: pl.LightningModule, task_idx=0): pl_module.env.generator.num_loc = self.task_params[0] pl_module.env.generator.num_modes = self.task_params[1] pl_module.env.generator.cdist = self.task_params[2] - elif self.data_type == "distribution": # fixed size + elif self.data_type == "distribution": # fixed size assert len(self.task_params) == 2 pl_module.env.generator.num_modes = self.task_params[0] pl_module.env.generator.cdist = self.task_params[1] - elif self.data_type == "size": # fixed distribution + elif self.data_type == "size": # fixed distribution assert len(self.task_params) == 1 pl_module.env.generator.num_loc = self.task_params[0] - if hasattr(pl_module.env.generator, 'capacity') and self.data_type in ["size_distribution", "size"]: - task_capacity = math.ceil(30 + self.task_params[0] / 5) if self.task_params[0] >= 20 else 20 + if hasattr(pl_module.env.generator, "capacity") and self.data_type in [ + "size_distribution", + "size", + ]: + task_capacity = ( + math.ceil(30 + self.task_params[0] / 5) if self.task_params[0] >= 20 else 20 + ) pl_module.env.generator.capacity = task_capacity def _alpha_scheduler(self): self.alpha = max(self.alpha * self.alpha_decay, 0.0001) def _generate_task_set(self, data_type, min_size, max_size): - """ + r""" Following the setting in Zhou et al. 2023 (https://arxiv.org/abs/2305.19587) Current setting: size: (n,) \in [20, 150] @@ -163,8 +184,7 @@ def _generate_task_set(self, data_type, min_size, max_size): else: raise NotImplementedError - print(">> Generating training task set: {} tasks with type {}".format(len(task_set), data_type)) - print(">> Training task set: {}".format(task_set)) + print(f">> Generating training task set: {len(task_set)} tasks with type {data_type}") + print(f">> Training task set: {task_set}") return task_set - diff --git a/rl4co/utils/ops.py b/rl4co/utils/ops.py index 8249ec6f..c2823eea 100644 --- a/rl4co/utils/ops.py +++ b/rl4co/utils/ops.py @@ -1,5 +1,4 @@ from functools import lru_cache -from typing import Optional import torch @@ -47,9 +46,7 @@ def unbatchify(x: Tensor | TensorDict, shape: tuple | int) -> Tensor | TensorDic >>> out.shape: [a, b, c, ...] """ shape = [shape] if isinstance(shape, int) else shape - for s in reversed( - shape - ): # we need to reverse the shape to unbatchify in the right order + for s in reversed(shape): # we need to reverse the shape to unbatchify in the right order x = _unbatchify_single(x, s) if s > 0 else x return x @@ -141,17 +138,14 @@ def select_start_nodes(td, env, num_starts): num_loc = env.generator.num_loc if hasattr(env.generator, "num_loc") else 0xFFFFFFFF if env.name in ["tsp", "atsp", "flp", "mcp"]: selected = ( - torch.arange(num_starts, device=td.device).repeat_interleave(td.shape[0]) - % num_loc + torch.arange(num_starts, device=td.device).repeat_interleave(td.shape[0]) % num_loc ) elif env.name in ["jssp", "fjsp"]: raise NotImplementedError("Multistart not yet supported for FJSP/JSSP") else: # Environments with depot: we do not select the depot as a start node selected = ( - torch.arange(num_starts, device=td.device).repeat_interleave(td.shape[0]) - % num_loc - + 1 + torch.arange(num_starts, device=td.device).repeat_interleave(td.shape[0]) % num_loc + 1 ) if env.name == "op": if (td["action_mask"][..., 1:].float().sum(-1) < num_starts).any(): @@ -172,7 +166,7 @@ def get_best_actions(actions, max_idxs): return actions.gather(0, max_idxs[..., None, None]) -def sparsify_graph(cost_matrix: Tensor, k_sparse: Optional[int] = None, self_loop=False): +def sparsify_graph(cost_matrix: Tensor, k_sparse: int | None = None, self_loop=False): """Generate a sparsified graph for the cost_matrix by selecting k edges with the lowest cost for each node. Args: @@ -246,7 +240,7 @@ def sample_n_random_actions(td: TensorDict, n: int): replace = True else: replace = False - ps = torch.rand((action_mask.shape)) + ps = torch.rand(action_mask.shape) ps[~action_mask] = -torch.inf ps = torch.softmax(ps, dim=1) selected = torch.multinomial(ps, n, replacement=replace).squeeze(1) @@ -254,7 +248,7 @@ def sample_n_random_actions(td: TensorDict, n: int): return selected.to(td.device) -def cartesian_to_polar(cartesian: torch.Tensor, origin: Optional[torch.Tensor] = None): +def cartesian_to_polar(cartesian: torch.Tensor, origin: torch.Tensor | None = None): """Convert Cartesian coordinates to polar coordinates. Args: @@ -278,9 +272,7 @@ def select_start_nodes_by_distance(td, env, num_starts, exclude_depot=True): radius = torch.norm(td["locs"], dim=-1) else: radius = polar_locs[..., 0] - _, node_index = torch.topk( - radius, k=num_starts + 1, dim=-1, sorted=True, largest=False - ) + _, node_index = torch.topk(radius, k=num_starts + 1, dim=-1, sorted=True, largest=False) selected_nodes = node_index[:, 1:] if exclude_depot else node_index[:, :-1] return rearrange(selected_nodes, "b n -> (n b)") @@ -294,7 +286,7 @@ def batched_scatter_sum(src, idx): idx (Tensor): A tensor of shape [batch_size, M, K] with zero-padding. Each non-zero element in idx represents an index (offset by 1) into src. A zero value indicates a padded (invalid) index. - + Returns: Tensor: A tensor of shape [batch_size, M, h] where for each batch and each index j, the output is computed as: @@ -305,11 +297,11 @@ def batched_scatter_sum(src, idx): - A temporary target tensor (tgt) of shape [batch_size, N+1, h] is created, where tgt[:, 1:] is populated with src. - The function reshapes idx to gather the corresponding values and then reshapes - the result back to [batch_size, M, K, h] before summing over the scattering dimension. + the result back to [batch_size, M, K, h] before summing over the scattering dimension. """ bs, N, h = src.shape bs, M, K = idx.shape tgt = torch.zeros(bs, N + 1, h, device=src.device) tgt[:, 1:] = src tgt = gather_by_index(tgt, idx.long().reshape(bs, -1), squeeze=False) - return tgt.reshape(bs, M, K, h).sum(-2) \ No newline at end of file + return tgt.reshape(bs, M, K, h).sum(-2) diff --git a/rl4co/utils/optim_helpers.py b/rl4co/utils/optim_helpers.py index 46367a37..20a15968 100644 --- a/rl4co/utils/optim_helpers.py +++ b/rl4co/utils/optim_helpers.py @@ -1,6 +1,7 @@ import inspect import torch + from torch.optim import Optimizer diff --git a/rl4co/utils/rich_utils.py b/rl4co/utils/rich_utils.py index 652ba568..bbe049e3 100644 --- a/rl4co/utils/rich_utils.py +++ b/rl4co/utils/rich_utils.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from pathlib import Path -from typing import Sequence import rich import rich.syntax diff --git a/rl4co/utils/trainer.py b/rl4co/utils/trainer.py index 497e6e8b..6c79d641 100644 --- a/rl4co/utils/trainer.py +++ b/rl4co/utils/trainer.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional +from collections.abc import Iterable import lightning.pytorch as pl import torch @@ -39,17 +39,18 @@ class RL4COTrainer(Trainer): disable_profiling_executor: Disable JIT profiling executor. This reduces memory and increases speed. auto_configure_ddp: Automatically configure DDP strategy if multiple GPUs are available. reload_dataloaders_every_n_epochs: Set to a value different than 1 to reload dataloaders every n epochs. - matmul_precision: Set matmul precision for faster inference https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision + matmul_precision: Set matmul precision for faster inference + See: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision **kwargs: Additional keyword arguments passed to the Lightning Trainer. See :class:`lightning.pytorch.trainer.Trainer` for details. """ def __init__( self, accelerator: str | Accelerator = "auto", - callbacks: Optional[list[Callback]] = None, - logger: Optional[Logger | Iterable[Logger]] = None, - min_epochs: Optional[int] = None, - max_epochs: Optional[int] = None, + callbacks: list[Callback] | None = None, + logger: Logger | Iterable[Logger] | None = None, + min_epochs: int | None = None, + max_epochs: int | None = None, strategy: str | Strategy = "auto", devices: list[int] | str | int = "auto", gradient_clip_val: int | float = 1.0, @@ -78,11 +79,7 @@ def __init__( else: n_devices = devices if n_devices > 1: - log.info( - "Configuring DDP strategy automatically with {} GPUs".format( - n_devices - ) - ) + log.info(f"Configuring DDP strategy automatically with {n_devices} GPUs") strategy = DDPStrategy( find_unused_parameters=True, # We set to True due to RL envs gradient_as_bucket_view=True, # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations @@ -94,15 +91,13 @@ def __init__( # Check if gradient_clip_val is set to None if gradient_clip_val is None: - log.warning( - "gradient_clip_val is set to None. This may lead to unstable training." - ) + log.warning("gradient_clip_val is set to None. This may lead to unstable training.") # We should reload dataloaders every epoch for RL training if reload_dataloaders_every_n_epochs != 1: log.warning( - "We reload dataloaders every epoch for RL training. Setting reload_dataloaders_every_n_epochs to a value different than 1 " - + "may lead to unexpected behavior since the initial conditions will be the same for `n_epochs` epochs." + "We reload dataloaders every epoch for RL training. Setting reload_dataloaders_every_n_epochs \ + different than 1 may lead to unexpected behavior since the conditions will be the same for `n_epochs`." ) # Main call to `Trainer` superclass @@ -123,10 +118,10 @@ def __init__( def fit( self, model: "pl.LightningModule", - train_dataloaders: Optional[TRAIN_DATALOADERS | LightningDataModule] = None, - val_dataloaders: Optional[EVAL_DATALOADERS] = None, - datamodule: Optional[LightningDataModule] = None, - ckpt_path: Optional[str] = None, + train_dataloaders: TRAIN_DATALOADERS | LightningDataModule | None = None, + val_dataloaders: EVAL_DATALOADERS | None = None, + datamodule: LightningDataModule | None = None, + ckpt_path: str | None = None, ) -> None: """ We override the `fit` method to automatically apply and handle RL4CO magic diff --git a/rl4co/utils/utils.py b/rl4co/utils/utils.py index a41e374e..3709b737 100644 --- a/rl4co/utils/utils.py +++ b/rl4co/utils/utils.py @@ -3,8 +3,8 @@ import sys import warnings +from collections.abc import Callable from importlib.util import find_spec -from typing import Callable import hydra @@ -12,9 +12,7 @@ from lightning.pytorch.loggers.logger import Logger # Import the necessary PyTorch Lightning component -from lightning.pytorch.trainer.connectors.accelerator_connector import ( - _AcceleratorConnector, -) +from lightning.pytorch.trainer.connectors.accelerator_connector import _AcceleratorConnector from lightning.pytorch.utilities.rank_zero import rank_zero_only from omegaconf import DictConfig, OmegaConf @@ -275,11 +273,11 @@ def show_versions(): version = "Not installed" print(f"{name.rjust(longest_name)} : {version}") # platform information - print(f'{"Python".rjust(longest_name)} : {sys.version.split()[0]}') - print(f'{"Platform".rjust(longest_name)} : {platform.platform()}') + print(f"{'Python'.rjust(longest_name)} : {sys.version.split()[0]}") + print(f"{'Platform'.rjust(longest_name)} : {platform.platform()}") try: lightning_auto_device = _AcceleratorConnector()._choose_auto_accelerator(None) except Exception: lightning_auto_device = _AcceleratorConnector()._choose_auto_accelerator() # lightning hardware accelerators - print(f'{"Lightning device".rjust(longest_name)} : {lightning_auto_device}') + print(f"{'Lightning device'.rjust(longest_name)} : {lightning_auto_device}") diff --git a/tests/test_envs.py b/tests/test_envs.py index 25b132fa..4b568b09 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -117,9 +117,7 @@ def test_jssp_lb(env_cls): env = env_cls(generator_params={"num_jobs": 2, "num_machines": 2}) td = TensorDict( { - "proc_times": torch.tensor( - [[[1, 0, 0, 4], [0, 2, 3, 0]]], dtype=torch.float32 - ), + "proc_times": torch.tensor([[[1, 0, 0, 4], [0, 2, 3, 0]]], dtype=torch.float32), "start_op_per_job": torch.tensor([[0, 2]], dtype=torch.long), "end_op_per_job": torch.tensor([[1, 3]], dtype=torch.long), "pad_mask": torch.tensor([[0, 0, 0, 0]], dtype=torch.bool), diff --git a/tests/test_policy.py b/tests/test_policy.py index da813180..5ea7ab86 100644 --- a/tests/test_policy.py +++ b/tests/test_policy.py @@ -63,9 +63,7 @@ def test_beam_search(env_name, select_best, size=20, batch_size=2): td = env.reset(x) policy = AttentionModelPolicy(env_name=env.name) beam_width = size // 2 if env.name in ["pdp"] else size - out = policy( - td, env, decode_type="beam_search", beam_width=beam_width, select_best=select_best - ) + out = policy(td, env, decode_type="beam_search", beam_width=beam_width, select_best=select_best) if select_best: expected_shape = (batch_size,) diff --git a/tests/test_training.py b/tests/test_training.py index b4f8b2eb..ca0ba2ff 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -74,9 +74,7 @@ def test_ppo(): env = TSPEnv(generator_params=dict(num_loc=20)) policy = AttentionModelPolicy(env_name=env.name) model = PPO(env, policy, train_data_size=10, val_data_size=10, test_data_size=10) - trainer = RL4COTrainer( - max_epochs=1, gradient_clip_val=None, devices=1, accelerator=accelerator - ) + trainer = RL4COTrainer(max_epochs=1, gradient_clip_val=None, devices=1, accelerator=accelerator) trainer.fit(model) trainer.test(model) @@ -183,25 +181,17 @@ def test_search_methods(SearchMethod): trainer.test(model) -@pytest.mark.skipif( - "torch_geometric" not in sys.modules, reason="PyTorch Geometric not installed" -) +@pytest.mark.skipif("torch_geometric" not in sys.modules, reason="PyTorch Geometric not installed") def test_nargnn(): env = TSPEnv(generator_params=dict(num_loc=20)) policy = NARGNNPolicy(env_name=env.name) - model = REINFORCE( - env, policy=policy, train_data_size=10, val_data_size=10, test_data_size=10 - ) - trainer = RL4COTrainer( - max_epochs=1, gradient_clip_val=None, devices=1, accelerator=accelerator - ) + model = REINFORCE(env, policy=policy, train_data_size=10, val_data_size=10, test_data_size=10) + trainer = RL4COTrainer(max_epochs=1, gradient_clip_val=None, devices=1, accelerator=accelerator) trainer.fit(model) trainer.test(model) -@pytest.mark.skipif( - "torch_geometric" not in sys.modules, reason="PyTorch Geometric not installed" -) +@pytest.mark.skipif("torch_geometric" not in sys.modules, reason="PyTorch Geometric not installed") @pytest.mark.skipif("numba" not in sys.modules, reason="Numba not installed") @pytest.mark.parametrize("use_local_search", [False]) def test_deepaco(use_local_search): @@ -214,16 +204,12 @@ def test_deepaco(use_local_search): train_with_local_search=use_local_search, policy_kwargs={"n_ants": 5, "aco_kwargs": {"use_local_search": use_local_search}}, ) - trainer = RL4COTrainer( - max_epochs=1, gradient_clip_val=1, devices=1, accelerator=accelerator - ) + trainer = RL4COTrainer(max_epochs=1, gradient_clip_val=1, devices=1, accelerator=accelerator) trainer.fit(model) trainer.test(model) -@pytest.mark.skipif( - "torch_geometric" not in sys.modules, reason="PyTorch Geometric not installed" -) +@pytest.mark.skipif("torch_geometric" not in sys.modules, reason="PyTorch Geometric not installed") @pytest.mark.parametrize( "Environment", [TSPEnv] if "numba" not in sys.modules else [TSPEnv, CVRPMVCEnv] ) @@ -243,9 +229,7 @@ def dummy_solver(c): "subprob_solver": dummy_solver, }, ) - trainer = RL4COTrainer( - max_epochs=1, gradient_clip_val=1, devices=1, accelerator=accelerator - ) + trainer = RL4COTrainer(max_epochs=1, gradient_clip_val=1, devices=1, accelerator=accelerator) trainer.fit(model) trainer.test(model) diff --git a/tests/test_utils.py b/tests/test_utils.py index c0f6041a..594d734a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,9 +13,7 @@ "a", [ torch.randn(10, 20, 2), - TensorDict( - {"a": torch.randn(10, 20, 2), "b": torch.randn(10, 20, 2)}, batch_size=10 - ), + TensorDict({"a": torch.randn(10, 20, 2), "b": torch.randn(10, 20, 2)}, batch_size=10), ], ) @pytest.mark.parametrize("shape", [(2,), (2, 2), (2, 2, 2)])