Skip to content

Commit 6e21cd4

Browse files
aliafzalfacebook-github-bot
authored andcommitted
Add plan paste for easily accessing the sharding plan
Summary: X-link: #3389 internal This diff introduces additional logging of the complete sharding plan in a human-readable format, making it easily accessible for any review and analysis through the planner db dataset. Differential Revision: D82945862
1 parent e083ca6 commit 6e21cd4

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

torchrec/distributed/planner/types.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -667,15 +667,22 @@ def __deepcopy__(
667667
return result
668668

669669
def __str__(self) -> str:
670-
str_obj: str = ""
671-
str_obj += f"name: {self.name}"
672-
str_obj += f"\nsharding type: {self.sharding_type}"
673-
str_obj += f"\ncompute kernel: {self.compute_kernel}"
674-
str_obj += f"\nnum shards: {len(self.shards)}"
675-
for shard in self.shards:
676-
str_obj += f"\n\t{str(shard)}"
677-
678-
return str_obj
670+
tensor_metadata = f"{{shape: {tuple(self.tensor.shape)}, dtype: {self.tensor.dtype}, device: {self.tensor.device}}}"
671+
shards_str = f"[{', '.join(str(shard) for shard in self.shards)}]"
672+
return f"""{{
673+
"name": "{self.name}",
674+
"module_fqn": "{self.module[0]}",
675+
"tensor": {tensor_metadata},
676+
"input_lengths": {self.input_lengths},
677+
"batch_size": {self.batch_size},
678+
"sharding_type": "{self.sharding_type}",
679+
"compute_kernel": "{self.compute_kernel}",
680+
"shards": {shards_str},
681+
"is_pooled": {self.is_pooled if self.module[1] else None},
682+
"feature_names": {self.feature_names},
683+
"cache_params": {self.cache_params},
684+
"is_weighted": {self.is_weighted}
685+
}}"""
679686

680687

681688
class PartitionByType(Enum):

0 commit comments

Comments
 (0)