
Commit cfffac6

faran928 authored and facebook-github-bot committed
Bucket offsets and sizes in torchrec shard metadata for bucket wise sharding (#2885)
Summary: Bucket offsets and sizes in torchrec shard metadata for bucket-wise sharding for ZCH v.Next.

Differential Revision: D72921209
1 parent eb5cb59 commit cfffac6

File tree

3 files changed (+304, -9)
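For context, the new `num_buckets_per_rank` argument to `row_wise` is exercised by the test added in this commit. The snippet below is a condensed sketch of that test (not additional API); the table name, feature name, per-rank sizes, placements, and bucket counts are the test's illustrative values.

# Condensed from test_row_wise_bucket_level_sharding added in this commit.
# Values (sizes, placements, bucket counts) are the test's illustrative choices.
from torchrec.distributed.sharding_plan import (
    construct_module_sharding_plan,
    row_wise,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingCollection

tables = [
    EmbeddingBagConfig(
        name="table_0",
        feature_names=["feature_0"],
        embedding_dim=64,
        num_embeddings=4096,
    )
]
plan = construct_module_sharding_plan(
    EmbeddingCollection(tables=tables),
    per_param_sharding={
        # Three row-wise shards of 2048/1024/1024 rows holding 20/30/40 buckets;
        # num_buckets and bucket_id_offset are propagated into each ShardMetadata.
        "table_0": row_wise(
            sizes_placement=([2048, 1024, 1024], ["cpu", "cuda", "cuda"]),
            num_buckets_per_rank=[20, 30, 40],
        ),
    },
    local_size=1,
    world_size=2,
    device_type="cuda",
)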

Diff for: torchrec/distributed/sharding_plan.py (+41, -5)
@@ -361,6 +361,7 @@ def _get_parameter_sharding(
     sharder: ModuleSharder[nn.Module],
     placements: Optional[List[str]] = None,
     compute_kernel: Optional[str] = None,
+    bucket_offset_sizes: Optional[List[Tuple[int, int]]] = None,
 ) -> ParameterSharding:
     return ParameterSharding(
         sharding_spec=(
@@ -371,6 +372,8 @@ def _get_parameter_sharding(
                 ShardMetadata(
                     shard_sizes=size,
                     shard_offsets=offset,
+                    bucket_id_offset=bucket_id_offset,
+                    num_buckets=num_buckets,
                     placement=(
                         placement(
                             device_type,
@@ -381,9 +384,17 @@
                         else device_placement
                     ),
                 )
-                for (size, offset, rank), device_placement in zip(
+                for (size, offset, rank), device_placement, (
+                    num_buckets,
+                    bucket_id_offset,
+                ) in zip(
                     size_offset_ranks,
                     placements if placements else [None] * len(size_offset_ranks),
+                    (
+                        bucket_offset_sizes
+                        if bucket_offset_sizes
+                        else [(None, None)] * len(size_offset_ranks)
+                    ),
                 )
             ]
         )
@@ -512,7 +523,8 @@ def _parameter_sharding_generator(


 def row_wise(
-    sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None
+    sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None,
+    num_buckets_per_rank: Optional[List[int]] = None,  # propagate num buckets per rank
 ) -> ParameterShardingGenerator:
     """
     Returns a generator of ParameterShardingPlan for `ShardingType::ROW_WISE` for construct_module_sharding_plan.
@@ -545,6 +557,7 @@ def _parameter_sharding_generator(
         device_type: str,
         sharder: ModuleSharder[nn.Module],
     ) -> ParameterSharding:
+        bucket_offset_sizes = None
         if sizes_placement is None:
             size_and_offsets = _get_parameter_size_offsets(
                 param,
@@ -558,17 +571,34 @@
                 size_offset_ranks.append((size, offset, rank))
         else:
            size_offset_ranks = []
+            bucket_offset_sizes = None if num_buckets_per_rank is None else []
             sizes = sizes_placement[0]
+            if num_buckets_per_rank is not None:
+                assert len(sizes) == len(
+                    num_buckets_per_rank
+                ), f"sizes and num_buckets_per_rank must have the same length during row_wise sharding, got {len(sizes)} and {len(num_buckets_per_rank)} respectively"
             (rows, cols) = param.shape
             cur_offset = 0
             prev_offset = 0
+            prev_bucket_offset = 0
+            cur_bucket_offset = 0
             for rank, size in enumerate(sizes):
                 per_rank_row = size
+                per_rank_bucket_size = None
+                if num_buckets_per_rank is not None:
+                    per_rank_bucket_size = num_buckets_per_rank[rank]
+                    cur_bucket_offset += per_rank_bucket_size
                 cur_offset += per_rank_row
                 cur_offset = min(cur_offset, rows)
                 per_rank_row = cur_offset - prev_offset
                 size_offset_ranks.append(([per_rank_row, cols], [prev_offset, 0], rank))
                 prev_offset = cur_offset
+                if num_buckets_per_rank is not None:
+                    # bucket has only one col for now
+                    none_throws(bucket_offset_sizes).append(
+                        (per_rank_bucket_size, prev_bucket_offset)
+                    )
+                    prev_bucket_offset = cur_bucket_offset

             if cur_offset < rows:
                 raise ValueError(
@@ -590,6 +620,13 @@
                 if device_type == "cuda":
                     index += 1

+        compute_kernel = None
+        if sizes_placement is not None:
+            if num_buckets_per_rank is not None:
+                compute_kernel = EmbeddingComputeKernel.KEY_VALUE.value
+            else:
+                compute_kernel = EmbeddingComputeKernel.QUANT.value
+
         return _get_parameter_sharding(
             param,
             ShardingType.ROW_WISE.value,
@@ -598,9 +635,8 @@
             device_type,
             sharder,
             placements=placements if sizes_placement else None,
-            compute_kernel=(
-                EmbeddingComputeKernel.QUANT.value if sizes_placement else None
-            ),
+            compute_kernel=compute_kernel,
+            bucket_offset_sizes=bucket_offset_sizes,
         )

     return _parameter_sharding_generator
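The per-rank bucket bookkeeping in the row_wise hunk above amounts to a running prefix sum over num_buckets_per_rank. The standalone sketch below (plain Python for illustration, not a torchrec API) mirrors that loop and shows the (num_buckets, bucket_id_offset) pair each shard's metadata receives.

# Standalone sketch of the (num_buckets, bucket_id_offset) bookkeeping added
# to row_wise() above; plain Python for illustration, not a torchrec API.
from typing import List, Tuple


def bucket_offset_sizes_per_rank(
    num_buckets_per_rank: List[int],
) -> List[Tuple[int, int]]:
    pairs = []
    bucket_id_offset = 0
    for num_buckets in num_buckets_per_rank:
        # Each rank's shard records its bucket count and the id of its first bucket.
        pairs.append((num_buckets, bucket_id_offset))
        bucket_id_offset += num_buckets
    return pairs


# Matches the expected ShardMetadata values in the new test below:
# rank 0 -> (20, 0), rank 1 -> (30, 20), rank 2 -> (40, 50)
assert bucket_offset_sizes_per_rank([20, 30, 40]) == [(20, 0), (30, 20), (40, 50)]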

Diff for: torchrec/distributed/tests/test_sharding_plan.py (+226, -2)
@@ -816,6 +816,159 @@ def test_row_wise_set_heterogenous_device(self, data_type: DataType) -> None:
             0,
         )

+    # pyre-fixme[56]
+    @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16]))
+    @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
+    def test_row_wise_bucket_level_sharding(self, data_type: DataType) -> None:
+
+        embedding_config = [
+            EmbeddingBagConfig(
+                name=f"table_{idx}",
+                feature_names=[f"feature_{idx}"],
+                embedding_dim=64,
+                num_embeddings=4096,
+                data_type=data_type,
+            )
+            for idx in range(2)
+        ]
+        module_sharding_plan = construct_module_sharding_plan(
+            EmbeddingCollection(tables=embedding_config),
+            per_param_sharding={
+                "table_0": row_wise(
+                    sizes_placement=(
+                        [2048, 1024, 1024],
+                        ["cpu", "cuda", "cuda"],
+                    ),
+                    num_buckets_per_rank=[20, 30, 40],
+                ),
+                "table_1": row_wise(
+                    sizes_placement=([2048, 1024, 1024], ["cpu", "cpu", "cpu"])
+                ),
+            },
+            local_size=1,
+            world_size=2,
+            device_type="cuda",
+        )
+
+        # Make sure per_param_sharding setting override the default device_type
+        device_table_0_shard_0 = (
+            # pyre-ignore[16]
+            module_sharding_plan["table_0"]
+            .sharding_spec.shards[0]
+            .placement
+        )
+        self.assertEqual(
+            device_table_0_shard_0.device().type,
+            "cpu",
+        )
+        # cpu always has rank 0
+        self.assertEqual(
+            device_table_0_shard_0.rank(),
+            0,
+        )
+        for i in range(1, 3):
+            device_table_0_shard_i = (
+                module_sharding_plan["table_0"].sharding_spec.shards[i].placement
+            )
+            self.assertEqual(
+                device_table_0_shard_i.device().type,
+                "cuda",
+            )
+            # first rank is assigned to cpu so index = rank - 1
+            self.assertEqual(
+                device_table_0_shard_i.device().index,
+                i - 1,
+            )
+            self.assertEqual(
+                device_table_0_shard_i.rank(),
+                i,
+            )
+        for i in range(3):
+            device_table_1_shard_i = (
+                module_sharding_plan["table_1"].sharding_spec.shards[i].placement
+            )
+            self.assertEqual(
+                device_table_1_shard_i.device().type,
+                "cpu",
+            )
+            # cpu always has rank 0
+            self.assertEqual(
+                device_table_1_shard_i.rank(),
+                0,
+            )
+
+        expected = {
+            "table_0": ParameterSharding(
+                sharding_type="row_wise",
+                compute_kernel="key_value",
+                ranks=[
+                    0,
+                    1,
+                    2,
+                ],
+                sharding_spec=EnumerableShardingSpec(
+                    shards=[
+                        ShardMetadata(
+                            shard_offsets=[0, 0],
+                            shard_sizes=[2048, 64],
+                            placement="rank:0/cpu",
+                            bucket_id_offset=0,
+                            num_buckets=20,
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[2048, 0],
+                            shard_sizes=[1024, 64],
+                            placement="rank:1/cuda:0",
+                            bucket_id_offset=20,
+                            num_buckets=30,
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[3072, 0],
+                            shard_sizes=[1024, 64],
+                            placement="rank:2/cuda:1",
+                            bucket_id_offset=50,
+                            num_buckets=40,
+                        ),
+                    ]
+                ),
+            ),
+            "table_1": ParameterSharding(
+                sharding_type="row_wise",
+                compute_kernel="quant",
+                ranks=[
+                    0,
+                    1,
+                    2,
+                ],
+                sharding_spec=EnumerableShardingSpec(
+                    shards=[
+                        ShardMetadata(
+                            shard_offsets=[0, 0],
+                            shard_sizes=[2048, 64],
+                            placement="rank:0/cpu",
+                            bucket_id_offset=None,
+                            num_buckets=None,
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[2048, 0],
+                            shard_sizes=[1024, 64],
+                            placement="rank:0/cpu",
+                            bucket_id_offset=None,
+                            num_buckets=None,
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[3072, 0],
+                            shard_sizes=[1024, 64],
+                            placement="rank:0/cpu",
+                            bucket_id_offset=None,
+                            num_buckets=None,
+                        ),
+                    ]
+                ),
+            ),
+        }
+        self.assertDictEqual(expected, module_sharding_plan)
+
     # pyre-fixme[56]
     @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16]))
     @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
@@ -929,18 +1082,89 @@ def test_str(self) -> None:
         )
         expected = """module: ebc

-param | sharding type | compute kernel | ranks
+param | sharding type | compute kernel | ranks
 -------- | ------------- | -------------- | ------
 user_id | table_wise | dense | [0]
 movie_id | row_wise | dense | [0, 1]

-param | shard offsets | shard sizes | placement
+param | shard offsets | shard sizes | placement
 -------- | ------------- | ----------- | -------------
 user_id | [0, 0] | [4096, 32] | rank:0/cuda:0
 movie_id | [0, 0] | [2048, 32] | rank:0/cuda:0
 movie_id | [2048, 0] | [2048, 32] | rank:0/cuda:1
 """
         self.maxDiff = None
+        print("STR PLAN")
+        print(str(plan))
+        print("=======")
+        for i in range(len(expected.splitlines())):
+            self.assertEqual(
+                expected.splitlines()[i].strip(), str(plan).splitlines()[i].strip()
+            )
+
+    def test_str_bucket_wise_sharding(self) -> None:
+        plan = ShardingPlan(
+            {
+                "ebc": EmbeddingModuleShardingPlan(
+                    {
+                        "user_id": ParameterSharding(
+                            sharding_type="table_wise",
+                            compute_kernel="dense",
+                            ranks=[0],
+                            sharding_spec=EnumerableShardingSpec(
+                                [
+                                    ShardMetadata(
+                                        shard_offsets=[0, 0],
+                                        shard_sizes=[4096, 32],
+                                        placement="rank:0/cuda:0",
+                                    ),
+                                ]
+                            ),
+                        ),
+                        "movie_id": ParameterSharding(
+                            sharding_type="row_wise",
+                            compute_kernel="dense",
+                            ranks=[0, 1],
+                            sharding_spec=EnumerableShardingSpec(
+                                [
+                                    ShardMetadata(
+                                        shard_offsets=[0, 0],
+                                        shard_sizes=[2048, 32],
+                                        placement="rank:0/cuda:0",
+                                        bucket_id_offset=0,
+                                        num_buckets=20,
+                                    ),
+                                    ShardMetadata(
+                                        shard_offsets=[2048, 0],
+                                        shard_sizes=[2048, 32],
+                                        placement="rank:0/cuda:1",
+                                        bucket_id_offset=20,
+                                        num_buckets=30,
+                                    ),
+                                ]
+                            ),
+                        ),
+                    }
+                )
+            }
+        )
+        expected = """module: ebc
+
+param | sharding type | compute kernel | ranks
+-------- | ------------- | -------------- | ------
+user_id | table_wise | dense | [0]
+movie_id | row_wise | dense | [0, 1]
+
+param | shard offsets | shard sizes | placement | bucket id offset | num buckets
+-------- | ------------- | ----------- | ------------- | ---------------- | -----------
+user_id | [0, 0] | [4096, 32] | rank:0/cuda:0 | None | None
+movie_id | [0, 0] | [2048, 32] | rank:0/cuda:0 | 0 | 20
+movie_id | [2048, 0] | [2048, 32] | rank:0/cuda:1 | 20 | 30
+"""
+        self.maxDiff = None
+        print("STR PLAN BUCKET WISE")
+        print(str(plan))
+        print("=======")
         for i in range(len(expected.splitlines())):
             self.assertEqual(
                 expected.splitlines()[i].strip(), str(plan).splitlines()[i].strip()
