@@ -816,6 +816,159 @@ def test_row_wise_set_heterogenous_device(self, data_type: DataType) -> None:
816
816
0 ,
817
817
)
818
818
819
    # pyre-fixme[56]
    @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16]))
    @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
    def test_row_wise_bucket_level_sharding(self, data_type: DataType) -> None:
        """Row-wise sharding with per-rank bucket metadata on a heterogenous device mix.

        Builds two tables and shards them with ``row_wise``:
        ``table_0`` is placed on [cpu, cuda, cuda] with ``num_buckets_per_rank``
        supplied, ``table_1`` is placed entirely on cpu with no bucket info.
        Verifies (a) per-shard placement device/index/rank and (b) that the
        resulting plan carries ``bucket_id_offset``/``num_buckets`` through to
        ``ShardMetadata`` (and ``None`` for the bucket-less table).
        """
        embedding_config = [
            EmbeddingBagConfig(
                name=f"table_{idx}",
                feature_names=[f"feature_{idx}"],
                embedding_dim=64,
                num_embeddings=4096,
                data_type=data_type,
            )
            for idx in range(2)
        ]
        module_sharding_plan = construct_module_sharding_plan(
            EmbeddingCollection(tables=embedding_config),
            per_param_sharding={
                # table_0: one cpu shard (2048 rows) + two cuda shards (1024 rows
                # each), with explicit bucket counts per rank.
                "table_0": row_wise(
                    sizes_placement=(
                        [2048, 1024, 1024],
                        ["cpu", "cuda", "cuda"],
                    ),
                    num_buckets_per_rank=[20, 30, 40],
                ),
                # table_1: all-cpu placement, no bucket metadata.
                "table_1": row_wise(
                    sizes_placement=([2048, 1024, 1024], ["cpu", "cpu", "cpu"])
                ),
            },
            local_size=1,
            world_size=2,
            # Default device; per_param_sharding placements above should win.
            device_type="cuda",
        )

        # Make sure per_param_sharding setting override the default device_type
        device_table_0_shard_0 = (
            # pyre-ignore[16]
            module_sharding_plan["table_0"]
            .sharding_spec.shards[0]
            .placement
        )
        self.assertEqual(
            device_table_0_shard_0.device().type,
            "cpu",
        )
        # cpu always has rank 0
        self.assertEqual(
            device_table_0_shard_0.rank(),
            0,
        )
        # Remaining table_0 shards land on cuda devices with sequential ranks.
        for i in range(1, 3):
            device_table_0_shard_i = (
                module_sharding_plan["table_0"].sharding_spec.shards[i].placement
            )
            self.assertEqual(
                device_table_0_shard_i.device().type,
                "cuda",
            )
            # first rank is assigned to cpu so index = rank - 1
            self.assertEqual(
                device_table_0_shard_i.device().index,
                i - 1,
            )
            self.assertEqual(
                device_table_0_shard_i.rank(),
                i,
            )
        # All table_1 shards are cpu-placed and share rank 0.
        for i in range(3):
            device_table_1_shard_i = (
                module_sharding_plan["table_1"].sharding_spec.shards[i].placement
            )
            self.assertEqual(
                device_table_1_shard_i.device().type,
                "cpu",
            )
            # cpu always has rank 0
            self.assertEqual(
                device_table_1_shard_i.rank(),
                0,
            )

        # Full structural expectation. NOTE(review): compute kernels differ by
        # placement mix — "key_value" for the bucketed cpu+cuda table vs
        # "quant" for the all-cpu table; bucket_id_offset accumulates the
        # per-rank bucket counts (0, 20, 20+30=50).
        expected = {
            "table_0": ParameterSharding(
                sharding_type="row_wise",
                compute_kernel="key_value",
                ranks=[
                    0,
                    1,
                    2,
                ],
                sharding_spec=EnumerableShardingSpec(
                    shards=[
                        ShardMetadata(
                            shard_offsets=[0, 0],
                            shard_sizes=[2048, 64],
                            placement="rank:0/cpu",
                            bucket_id_offset=0,
                            num_buckets=20,
                        ),
                        ShardMetadata(
                            shard_offsets=[2048, 0],
                            shard_sizes=[1024, 64],
                            placement="rank:1/cuda:0",
                            bucket_id_offset=20,
                            num_buckets=30,
                        ),
                        ShardMetadata(
                            shard_offsets=[3072, 0],
                            shard_sizes=[1024, 64],
                            placement="rank:2/cuda:1",
                            bucket_id_offset=50,
                            num_buckets=40,
                        ),
                    ]
                ),
            ),
            "table_1": ParameterSharding(
                sharding_type="row_wise",
                compute_kernel="quant",
                ranks=[
                    0,
                    1,
                    2,
                ],
                sharding_spec=EnumerableShardingSpec(
                    shards=[
                        ShardMetadata(
                            shard_offsets=[0, 0],
                            shard_sizes=[2048, 64],
                            placement="rank:0/cpu",
                            bucket_id_offset=None,
                            num_buckets=None,
                        ),
                        ShardMetadata(
                            shard_offsets=[2048, 0],
                            shard_sizes=[1024, 64],
                            placement="rank:0/cpu",
                            bucket_id_offset=None,
                            num_buckets=None,
                        ),
                        ShardMetadata(
                            shard_offsets=[3072, 0],
                            shard_sizes=[1024, 64],
                            placement="rank:0/cpu",
                            bucket_id_offset=None,
                            num_buckets=None,
                        ),
                    ]
                ),
            ),
        }
        self.assertDictEqual(expected, module_sharding_plan)
819
972
# pyre-fixme[56]
820
973
@given (data_type = st .sampled_from ([DataType .FP32 , DataType .FP16 ]))
821
974
@settings (verbosity = Verbosity .verbose , max_examples = 8 , deadline = None )
@@ -929,18 +1082,89 @@ def test_str(self) -> None:
929
1082
)
930
1083
expected = """module: ebc
931
1084
932
- param | sharding type | compute kernel | ranks
1085
+ param | sharding type | compute kernel | ranks
933
1086
-------- | ------------- | -------------- | ------
934
1087
user_id | table_wise | dense | [0]
935
1088
movie_id | row_wise | dense | [0, 1]
936
1089
937
- param | shard offsets | shard sizes | placement
1090
+ param | shard offsets | shard sizes | placement
938
1091
-------- | ------------- | ----------- | -------------
939
1092
user_id | [0, 0] | [4096, 32] | rank:0/cuda:0
940
1093
movie_id | [0, 0] | [2048, 32] | rank:0/cuda:0
941
1094
movie_id | [2048, 0] | [2048, 32] | rank:0/cuda:1
942
1095
"""
943
1096
self .maxDiff = None
1097
+ print ("STR PLAN" )
1098
+ print (str (plan ))
1099
+ print ("=======" )
1100
+ for i in range (len (expected .splitlines ())):
1101
+ self .assertEqual (
1102
+ expected .splitlines ()[i ].strip (), str (plan ).splitlines ()[i ].strip ()
1103
+ )
1104
+
1105
    def test_str_bucket_wise_sharding(self) -> None:
        """``str(ShardingPlan)`` rendering when shards carry bucket metadata.

        Builds a plan where ``movie_id`` shards set ``bucket_id_offset`` /
        ``num_buckets`` and ``user_id`` does not, then checks that the shard
        table gains "bucket id offset" and "num buckets" columns (``None`` for
        the bucket-less shard). Comparison is line-by-line with ``.strip()``,
        so only per-line content matters, not column alignment.
        """
        plan = ShardingPlan(
            {
                "ebc": EmbeddingModuleShardingPlan(
                    {
                        "user_id": ParameterSharding(
                            sharding_type="table_wise",
                            compute_kernel="dense",
                            ranks=[0],
                            sharding_spec=EnumerableShardingSpec(
                                [
                                    # No bucket metadata on this shard.
                                    ShardMetadata(
                                        shard_offsets=[0, 0],
                                        shard_sizes=[4096, 32],
                                        placement="rank:0/cuda:0",
                                    ),
                                ]
                            ),
                        ),
                        "movie_id": ParameterSharding(
                            sharding_type="row_wise",
                            compute_kernel="dense",
                            ranks=[0, 1],
                            sharding_spec=EnumerableShardingSpec(
                                [
                                    ShardMetadata(
                                        shard_offsets=[0, 0],
                                        shard_sizes=[2048, 32],
                                        placement="rank:0/cuda:0",
                                        bucket_id_offset=0,
                                        num_buckets=20,
                                    ),
                                    ShardMetadata(
                                        shard_offsets=[2048, 0],
                                        shard_sizes=[2048, 32],
                                        placement="rank:0/cuda:1",
                                        bucket_id_offset=20,
                                        num_buckets=30,
                                    ),
                                ]
                            ),
                        ),
                    }
                )
            }
        )
        expected = """module: ebc

 param   | sharding type | compute kernel | ranks
-------- | ------------- | -------------- | ------
user_id  | table_wise    | dense          | [0]
movie_id | row_wise      | dense          | [0, 1]

 param   | shard offsets | shard sizes | placement     | bucket id offset | num buckets
-------- | ------------- | ----------- | ------------- | ---------------- | -----------
user_id  | [0, 0]        | [4096, 32]  | rank:0/cuda:0 | None             | None
movie_id | [0, 0]        | [2048, 32]  | rank:0/cuda:0 | 0                | 20
movie_id | [2048, 0]     | [2048, 32]  | rank:0/cuda:1 | 20               | 30
"""
        self.maxDiff = None
        print("STR PLAN BUCKET WISE")
        print(str(plan))
        print("=======")
        # Compare stripped lines so intra-line padding differences don't fail
        # the test; only the rendered content per line is asserted.
        for i in range(len(expected.splitlines())):
            self.assertEqual(
                expected.splitlines()[i].strip(), str(plan).splitlines()[i].strip()
            )
0 commit comments