Commit 8813a5f

Yixin Bao authored and meta-codesync[bot] committed
Sync D80683711 with torchrec/modules folder (#3407)
Summary: Pull Request resolved: #3407. As titled.

Reviewed By: zlzhao1104

Differential Revision: D83709897

fbshipit-source-id: 86558cda0d94b397f4c9b56af1a70047f0614933
1 parent 48a8651 commit 8813a5f

File tree

3 files changed: +268, -7 lines changed


torchrec/modules/hash_mc_metrics.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -42,6 +42,7 @@ def __init__(
         zch_size: int,
         frequency: int,
         start_bucket: int,
+        disable_fallback: bool,
         log_file_path: str = "",
     ) -> None:
         super().__init__()
@@ -56,6 +57,7 @@ def __init__(
         self._zch_size: int = zch_size
         self._frequency: int = frequency
         self._start_bucket: int = start_bucket
+        self._disable_fallback: bool = disable_fallback

         self._dtype_checked: bool = False
         self._total_cnt: int = 0
@@ -114,7 +116,11 @@ def update(

         self._insert_cnt += insert_cnt
         self._total_cnt += values.numel()
-        hits = torch.eq(remapped_identities_0, values)
+        if self._disable_fallback:
+            hits = torch.isin(remapped_identities_0, values)
+        else:
+            # Cannot use isin() as it is possible that a cache miss falls back to another element in values.
+            hits = torch.eq(remapped_identities_0, values)
         hit_cnt = int(torch.sum(hits).item())
         self._hit_cnt += hit_cnt
         self._collision_cnt += values.numel() - hit_cnt - insert_cnt
```
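Note on the eq/isin change above: with fallback enabled, a missed id can be remapped to a slot whose identity equals a different element of the same batch, so hits must be compared positionally with torch.eq; with fallback disabled that cross-contamination cannot occur, and set membership via torch.isin suffices. A minimal standalone sketch with made-up tensors (illustrative only, not part of the commit):

```python
import torch

# Hypothetical batch of queried ids and the identities read back after remap.
values = torch.tensor([7, 8, 9])

# With fallback enabled, the miss on 9 happens to land on a slot holding 8.
remapped_with_fallback = torch.tensor([7, 8, 8])

hits_eq = torch.eq(remapped_with_fallback, values)      # positional: [T, T, F]
hits_isin = torch.isin(remapped_with_fallback, values)  # membership: [T, T, T]

# isin() overcounts the fallback case, which is why the metric keeps eq()
# unless disable_fallback is set.
print(int(hits_eq.sum().item()), int(hits_isin.sum().item()))  # 2 3
```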

torchrec/modules/hash_mc_modules.py

Lines changed: 13 additions & 4 deletions
```diff
@@ -211,6 +211,7 @@ def __init__(
         end_bucket: Optional[int] = None,
         opt_in_prob: int = -1,
         percent_reserved_slots: float = 0,
+        disable_fallback: bool = False,
     ) -> None:
         if output_segments is None:
             assert (
@@ -273,6 +274,7 @@ def __init__(
             self._eviction_policy_name is None
             or self._eviction_policy_name != HashZchEvictionPolicyName.LRU_EVICTION
         ), "LRU eviction is not compatible with opt-in at this time"
+        self._disable_fallback: bool = disable_fallback

         if torch.jit.is_scripting() or self._is_inference or self._name is None:
             self._tb_logging_frequency = 0
@@ -284,6 +286,7 @@ def __init__(
                 zch_size=self._zch_size,
                 frequency=self._tb_logging_frequency,
                 start_bucket=self._start_bucket,
+                disable_fallback=self._disable_fallback,
             )
         else:
             logger.info(
@@ -349,7 +352,7 @@ def __init__(
             f"{self._buckets=}, {self._start_bucket=}, {self._end_bucket=}, "
             f"{self._output_global_offset_tensor=}, {self._output_segments=}, "
             f"{inference_dispatch_div_train_world_size=}, "
-            f"{self._opt_in_prob=}, {self._percent_reserved_slots=}"
+            f"{self._opt_in_prob=}, {self._percent_reserved_slots=}, {self._disable_fallback=}"
         )

     @property
@@ -525,7 +528,7 @@ def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
                 # Use self._is_inference to turn on writing to pinned
                 # CPU memory directly. But may not have perf benefit.
                 output_on_uvm=False,  # self._is_inference,
-                disable_fallback=False,
+                disable_fallback=self._disable_fallback,
                 _modulo_identity_DPRECATED=False,  # deprecated, always False
                 input_metadata=input_metadata,
                 eviction_threshold=eviction_threshold,
@@ -537,7 +540,12 @@ def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:

             # record the on-device remapped ids
             self.table_name_on_device_remapped_ids_dict[name] = remapped_ids.clone()
-
+            lengths: torch.Tensor = feature.lengths()
+            if self._disable_fallback:
+                # Only works on GPU when read-only is true.
+                hit_indices = remapped_ids != -1
+                remapped_ids = remapped_ids[hit_indices]
+                lengths = torch.masked_fill(lengths, ~hit_indices, 0)
             if self._scalar_logger is not None:
                 assert identities_0 is not None
                 self._scalar_logger.update(
@@ -560,7 +568,7 @@ def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:

             remapped_features[name] = JaggedTensor(
                 values=remapped_ids,
-                lengths=feature.lengths(),
+                lengths=lengths,
                 offsets=feature.offsets(),
                 weights=feature.weights_or_none(),
             )
@@ -623,6 +631,7 @@ def rebuild_with_output_id_range(
             eviction_config=self._eviction_config,
             opt_in_prob=self._opt_in_prob,
             percent_reserved_slots=self._percent_reserved_slots,
+            disable_fallback=self._disable_fallback,
         )
```
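The remap change reads cleanly in isolation: when disable_fallback is set, the lookup kernel reports misses as -1, and the module drops those values while zeroing the matching lengths so the returned JaggedTensor stays self-consistent. A self-contained sketch of that filtering step, assuming one id per bag as in the tests below (the function name is illustrative, not the module's API):

```python
import torch

def drop_missed_ids(
    remapped_ids: torch.Tensor, lengths: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Drop ids the kernel marked as misses (-1) and zero out their lengths.

    Assumes each length entry corresponds to exactly one id, matching the
    single-id-per-bag inputs used in the tests below.
    """
    hit_indices = remapped_ids != -1                       # True where lookup hit
    values = remapped_ids[hit_indices]                     # keep hits only
    lengths = torch.masked_fill(lengths, ~hit_indices, 0)  # misses contribute 0 ids
    return values, lengths

# Mirrors the first test below: the id at position 2 missed the cache.
remapped = torch.tensor([8, 15, -1, 11])
lengths = torch.tensor([1, 1, 1, 1])
print(drop_missed_ids(remapped, lengths))
# (tensor([ 8, 15, 11]), tensor([1, 1, 0, 1]))
```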

torchrec/modules/tests/test_hash_mc_modules.py

Lines changed: 248 additions & 2 deletions
```diff
@@ -8,19 +8,27 @@
 # pyre-strict

 import unittest
-from typing import cast
+from typing import cast, Dict
 from unittest.mock import patch

 import torch
 from hypothesis import given, settings, strategies as st
+
 from pyre_extensions import none_throws
 from torchrec.distributed.embedding_sharding import bucketize_kjt_before_all2all
-from torchrec.modules.embedding_configs import EmbeddingConfig
+from torchrec.modules.embedding_configs import (
+    DataType,
+    EmbeddingBagConfig,
+    EmbeddingConfig,
+    PoolingType,
+)
+from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.modules.hash_mc_evictions import (
     HashZchEvictionConfig,
     HashZchEvictionPolicyName,
 )
 from torchrec.modules.hash_mc_modules import HashZchManagedCollisionModule
+from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
 from torchrec.modules.mc_modules import (
     ManagedCollisionCollection,
     ManagedCollisionModule,
```
```diff
@@ -680,3 +688,241 @@ def test_dynamically_switch_inference_training_mode(self) -> None:
         self.assertTrue(m._is_inference)
         self.assertTrue(m._eviction_policy_name is None)
         self.assertTrue(m._eviction_module is None)
+
+    # Pyre-ignore [56]: Pyre was not able to infer the type of argument `torch.cuda.device_count() < 1` to decorator factory `unittest.skipIf`
+    @unittest.skipIf(
+        torch.cuda.device_count() < 1,
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_zch_hash_disable_fallback(self) -> None:
+        m = HashZchManagedCollisionModule(
+            zch_size=30,
+            device=torch.device("cuda"),
+            total_num_buckets=2,
+            eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
+            eviction_config=HashZchEvictionConfig(
+                features=[],
+                single_ttl=10,
+            ),
+            max_probe=4,
+            disable_fallback=True,
+            start_bucket=1,
+            output_segments=[0, 10, 20],
+        )
+        jt = JaggedTensor(
+            values=torch.arange(0, 4, dtype=torch.int64, device="cuda"),
+            lengths=torch.tensor([1, 1, 1, 1], dtype=torch.int64, device="cuda"),
+        )
+        # Run once to insert ids
+        output0 = m.remap({"test": jt})
+        self.assertTrue(
+            torch.equal(
+                output0["test"].values(),
+                torch.tensor([8, 15, 11], dtype=torch.int64, device="cuda:0"),
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                output0["test"].lengths(),
+                torch.tensor([1, 1, 0, 1], dtype=torch.int64, device="cuda:0"),
+            )
+        )
+        m.reset_inference_mode()
+        jt = JaggedTensor(
+            values=torch.tensor([9, 0, 1, 4, 6, 8], dtype=torch.int64, device="cuda"),
+            lengths=torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.int64, device="cuda"),
+        )
+        # Run again in inference mode; only values 0 and 1 exist.
+        output1 = m.remap({"test": jt})
+        self.assertTrue(
+            torch.equal(
+                output1["test"].values(),
+                torch.tensor([8, 15], dtype=torch.int64, device="cuda:0"),
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                output1["test"].lengths(),
+                torch.tensor([0, 1, 1, 0, 0, 0], dtype=torch.int64, device="cuda:0"),
+            )
+        )
+
+        m = HashZchManagedCollisionModule(
+            zch_size=10,
+            device=torch.device("cuda"),
+            total_num_buckets=2,
+            eviction_policy_name=HashZchEvictionPolicyName.SINGLE_TTL_EVICTION,
+            eviction_config=HashZchEvictionConfig(
+                features=[],
+                single_ttl=10,
+            ),
+            max_probe=4,
+            start_bucket=0,
+            output_segments=None,
+            disable_fallback=True,
+        )
+        jt = JaggedTensor(
+            values=torch.arange(0, 4, dtype=torch.int64, device="cuda"),
+            lengths=torch.tensor([1, 1, 1, 1], dtype=torch.int64, device="cuda"),
+        )
+        # Run once to insert ids
+        output0 = m.remap({"test": jt})
+        self.assertTrue(
+            torch.equal(
+                output0["test"].values(),
+                torch.tensor([3, 5, 4, 6], dtype=torch.int64, device="cuda:0"),
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                output0["test"].lengths(),
+                torch.tensor([1, 1, 1, 1], dtype=torch.int64, device="cuda:0"),
+            )
+        )
+        m.reset_inference_mode()
+        jt = JaggedTensor(
+            values=torch.tensor([9, 0, 1, 4, 6, 8], dtype=torch.int64, device="cuda"),
+            lengths=torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.int64, device="cuda"),
+        )
+        # Run again in inference mode; only values 0 and 1 exist.
+        output1 = m.remap({"test": jt})
+        self.assertTrue(
+            torch.equal(
+                output1["test"].values(),
+                torch.tensor([3, 5], dtype=torch.int64, device="cuda:0"),
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                output1["test"].lengths(),
+                torch.tensor([0, 1, 1, 0, 0, 0], dtype=torch.int64, device="cuda:0"),
+            )
+        )
+
+    # Pyre-ignore [56]: Pyre was not able to infer the type of argument `torch.cuda.device_count() < 1` to decorator factory `unittest.skipIf`
+    @unittest.skipIf(
+        torch.cuda.device_count() < 1,
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_zch_hash_zero_rows(self) -> None:
+        # When fallback is disabled, missed ids should produce zero rows in the output embeddings.
+        mc_emb_configs = [
+            EmbeddingBagConfig(
+                num_embeddings=10,
+                embedding_dim=3,
+                name="table_0",
+                data_type=DataType.FP32,
+                feature_names=["table_0"],
+                pooling=PoolingType.SUM,
+                weight_init_max=None,
+                weight_init_min=None,
+                init_fn=None,
+                use_virtual_table=False,
+                virtual_table_eviction_policy=None,
+                total_num_buckets=1,
+            )
+        ]
+        mc_modules: Dict[str, ManagedCollisionModule] = {
+            "table_0": HashZchManagedCollisionModule(
+                zch_size=10,
+                device=torch.device("cuda"),
+                max_probe=512,
+                tb_logging_frequency=100,
+                name="table_0",
+                total_num_buckets=1,
+                eviction_config=None,
+                eviction_policy_name=None,
+                opt_in_prob=-1,
+                percent_reserved_slots=0,
+                disable_fallback=True,
+            )
+        }
+        mcebc = ManagedCollisionEmbeddingBagCollection(
+            EmbeddingBagCollection(
+                device=torch.device("cuda"),
+                tables=mc_emb_configs,
+                is_weighted=False,
+            ),
+            ManagedCollisionCollection(
+                managed_collision_modules=mc_modules,
+                embedding_configs=mc_emb_configs,
+            ),
+            return_remapped_features=True,
+        )
+        lengths = torch.tensor(
+            [1, 1, 1, 1, 1], dtype=torch.int64, device=torch.device("cuda")
+        )
+        values = torch.tensor(
+            [3, 4, 5, 6, 8],
+            dtype=torch.int64,
+            device=torch.device("cuda"),
+        )
+        features = KeyedJaggedTensor(
+            keys=["table_0"],
+            values=values,
+            lengths=lengths,
+        )
+        # Run once to insert ids
+        res = mcebc.forward(features)
+        # Pyre-ignore [6]: In call `torch._C._VariableFunctions.abs`, for 1st positional argument, expected `Tensor` but got `Union[JaggedTensor, Tensor]`
+        mask = torch.abs(res[0]["table_0"]) == 0
+        # For each row, check whether all elements are True (i.e., the row is all zeros)
+        row_mask = mask.all(dim=1)
+        # No rows should be zero on the insertion pass
+        self.assertEqual(torch.nonzero(row_mask, as_tuple=False).squeeze().numel(), 0)
+        self.assertIsNotNone(res[1])
+        self.assertTrue(
+            torch.equal(
+                # Pyre-ignore [16]: Optional type has no attribute `__getitem__`.
+                res[1]["table_0"].values(),
+                torch.tensor([1, 2, 8, 9, 3], dtype=torch.int64, device="cuda:0"),
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                res[1]["table_0"].lengths(),
+                torch.tensor([1, 1, 1, 1, 1], dtype=torch.int64, device="cuda:0"),
+            )
+        )
+        # Pyre-ignore [29]: `typing.Union[torch._tensor.Tensor, torch.nn.modules.module.Module]` is not a function
+        mcebc._managed_collision_collection._managed_collision_modules[
+            "table_0"
+        ].reset_inference_mode()
+        lengths = torch.tensor(
+            [1, 1, 1, 1, 1, 1], dtype=torch.int64, device=torch.device("cuda")
+        )
+        values = torch.tensor(
+            [0, 4, 5, 1, 2, 8],
+            dtype=torch.int64,
+            device=torch.device("cuda"),
+        )
+        features = KeyedJaggedTensor(
+            keys=["table_0"],
+            values=values,
+            lengths=lengths,
+        )
+        # Run again in inference mode; ids 0, 1, and 2 were never inserted.
+        res = mcebc.forward(features)
+        self.assertTrue(
+            torch.equal(
+                res[1]["table_0"].values(),
+                torch.tensor([2, 8, 3], dtype=torch.int64, device="cuda:0"),
+            )
+        )
+        self.assertTrue(
+            torch.equal(
+                res[1]["table_0"].lengths(),
+                torch.tensor([0, 1, 1, 0, 0, 1], dtype=torch.int64, device="cuda:0"),
+            )
+        )
+        # Pyre-ignore [6]: In call `torch._C._VariableFunctions.abs`, for 1st positional argument, expected `Tensor` but got `Union[JaggedTensor, Tensor]`
+        mask = torch.abs(res[0]["table_0"]) == 0
+        # For each row, check whether all elements are True (i.e., the row is all zeros)
+        row_mask = mask.all(dim=1)
+        # The missed positions should surface as zero rows
+        self.assertTrue(
+            torch.equal(
+                torch.tensor([0, 3, 4], device="cuda:0"),
+                torch.nonzero(row_mask, as_tuple=False).squeeze(),
+            )
+        )
```
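The zero-row assertions in test_zch_hash_zero_rows rest on standard sum-pooling semantics: a bag whose length was zeroed contributes no ids, and torch.nn.EmbeddingBag with mode="sum" returns the zero vector for an empty bag. A standalone illustration outside TorchRec (tensors chosen to mirror the second inference pass above):

```python
import torch

# Surviving ids after the miss-filter, with offsets encoding lengths
# [0, 1, 1, 0, 0, 1]: bags 0, 3, and 4 are empty.
values = torch.tensor([2, 8, 3])
offsets = torch.tensor([0, 0, 1, 2, 2, 2])

bag = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=3, mode="sum")
out = bag(values, offsets)  # shape [6, 3], one pooled row per bag

# Empty bags pool to the zero vector under sum pooling.
zero_rows = (out == 0).all(dim=1)
print(torch.nonzero(zero_rows).squeeze())  # tensor([0, 3, 4])
```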
