From 04392c4a4725f1dd9afa679a464d32bc3987d832 Mon Sep 17 00:00:00 2001
From: Ali Afzal
Date: Wed, 1 Oct 2025 18:38:57 -0700
Subject: [PATCH] Add plan paste for easily accessing the sharding plan (#3389)

Summary:
X-link: https://github.com/pytorch/torchrec/pull/3389

internal

This diff introduces additional logging of the complete sharding plan in a
human-readable format, making it easily accessible for review and analysis
through the planner db dataset.

Differential Revision: D82945862
---
 torchrec/distributed/planner/types.py | 25 ++++++++++++++++---------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py
index fe676cf9f..92606e0f6 100644
--- a/torchrec/distributed/planner/types.py
+++ b/torchrec/distributed/planner/types.py
@@ -667,15 +667,22 @@ def __deepcopy__(
         return result
 
     def __str__(self) -> str:
-        str_obj: str = ""
-        str_obj += f"name: {self.name}"
-        str_obj += f"\nsharding type: {self.sharding_type}"
-        str_obj += f"\ncompute kernel: {self.compute_kernel}"
-        str_obj += f"\nnum shards: {len(self.shards)}"
-        for shard in self.shards:
-            str_obj += f"\n\t{str(shard)}"
-
-        return str_obj
+        tensor_metadata = f"{{shape: {tuple(self.tensor.shape)}, dtype: {self.tensor.dtype}, device: {self.tensor.device}}}"
+        shards_str = f"[{', '.join(str(shard) for shard in self.shards)}]"
+        return f"""{{
+            "name": "{self.name}",
+            "module_fqn": "{self.module[0]}",
+            "tensor": {tensor_metadata},
+            "input_lengths": {self.input_lengths},
+            "batch_size": {self.batch_size},
+            "sharding_type": "{self.sharding_type}",
+            "compute_kernel": "{self.compute_kernel}",
+            "shards": {shards_str},
+            "is_pooled": {self.is_pooled if self.module[1] else None},
+            "feature_names": {self.feature_names},
+            "cache_params": {self.cache_params},
+            "is_weighted": {self.is_weighted}
+        }}"""
 
 
 class PartitionByType(Enum):
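
---
Note (reviewer sketch, not part of the patch): the snippet below approximates
what the patched __str__ emits, for anyone inspecting the plan paste. The
_FakeShardingOption class and every field value in it are hypothetical
stand-ins; the real ShardingOption carries planner objects (torch tensors,
Shard instances, CacheParams) in these slots.

import torch


class _FakeShardingOption:
    """Mirrors just the attributes the new __str__ reads."""

    def __init__(self) -> None:
        self.name = "t_user_id"
        # (module_fqn, module); a truthy second element enables is_pooled.
        self.module = ("sparse.ebc", object())
        self.tensor = torch.empty(100_000, 128)
        self.input_lengths = [1.0]
        self.batch_size = 512
        self.sharding_type = "table_wise"
        self.compute_kernel = "fused"
        # str(shard) stand-ins; real entries are planner Shard objects.
        self.shards = ["{rank: 0, size: [100000, 128]}"]
        self.is_pooled = True
        self.feature_names = ["user_id"]
        self.cache_params = None
        self.is_weighted = False

    # Same formatting logic as the patched __str__.
    def __str__(self) -> str:
        tensor_metadata = f"{{shape: {tuple(self.tensor.shape)}, dtype: {self.tensor.dtype}, device: {self.tensor.device}}}"
        shards_str = f"[{', '.join(str(shard) for shard in self.shards)}]"
        return f"""{{
            "name": "{self.name}",
            "module_fqn": "{self.module[0]}",
            "tensor": {tensor_metadata},
            "input_lengths": {self.input_lengths},
            "batch_size": {self.batch_size},
            "sharding_type": "{self.sharding_type}",
            "compute_kernel": "{self.compute_kernel}",
            "shards": {shards_str},
            "is_pooled": {self.is_pooled if self.module[1] else None},
            "feature_names": {self.feature_names},
            "cache_params": {self.cache_params},
            "is_weighted": {self.is_weighted}
        }}"""


print(_FakeShardingOption())

One caveat for downstream consumers: although the string is shaped like JSON,
it is not guaranteed to parse as strict JSON. tuple(self.tensor.shape) renders
with parentheses, the device renders unquoted, and Python's True/None are not
valid JSON literals, so the planner db dataset should treat this field as
human-readable text rather than feed it to json.loads.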