Skip to content

Commit 228f430

Browse files
aliafzalfacebook-github-bot
authored andcommitted
PlanLoader addition into planner (#3355)
Summary: Pull Request resolved: #3355 **Summary:** * Added PlanLoader abstract base class to enable loading pre-computed sharding plans from stored locations within planner. * Supports two key scenarios: 1. Reusing previously computed and stored sharding plans to avoid regeneration costs 2. Using sharding plans from previous runs as starting points for iterative improvements * Defines two abstract methods: * `load()`: Returns a dictionary mapping sharding option hashes to ShardingOption objects * `plan_validation_str()`: Provides validation string for plan integrity checks * Part of the broader effort to improve planner UX and reliability by enabling plan persistence and reuse across training runs Reviewed By: mserturk Differential Revision: D81571293 fbshipit-source-id: e71336faeed40d95d53afae0c7cb27762574d396
1 parent b444149 commit 228f430

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

torchrec/distributed/planner/types.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,40 @@ def log(
998998
...
999999

10001000

1001+
class PlanLoader(abc.ABC):
1002+
"""
1003+
Retrieves a pre-computed sharding plan from its stored location. This is useful in two scenarios:
1004+
1. To utilize a specific sharding plan that was previously computed and stored, saving the cost of re-generating the plan
1005+
2. To use a sharding plan from previous runs as a starting point for the next run, allowing for improvement over time.
1006+
"""
1007+
1008+
@abc.abstractmethod
1009+
def load(
1010+
self,
1011+
) -> Optional[Dict[int, ShardingOption]]:
1012+
"""
1013+
Load sharding plan from its stored location.
1014+
1015+
Returns:
1016+
Dict[int, ShardingOption]: loaded sharding plan. key is hash of sharding option to map to sharding option with enumerated sharding option.
1017+
"""
1018+
...
1019+
1020+
@abc.abstractmethod
1021+
def plan_context_hash(
1022+
self,
1023+
) -> Optional[str]:
1024+
"""
1025+
Input context hash of a sharding plan.
1026+
1027+
Returns:
1028+
str: hash of sharding plan context.
1029+
"""
1030+
...
1031+
1032+
...
1033+
1034+
10011035
@dataclass
10021036
class CriticalPathEstimate:
10031037
comms_estimate: float

0 commit comments

Comments
 (0)