Skip to content

Commit f738e20

Browse files
zejunh authored and facebook-github-bot committed
support permute_multi_embedding_function on torch.export (#3897)
Summary: X-link: facebookresearch/FBGEMM#988. Pull Request resolved: #3897. Support `fbgemm.permute_multi_embedding_function.default` on torch.export for the LPV model. Registered with a separate entry in kernel.yaml for graph-mode lowering; added an fp16 reference kernel (sitecao). Differential Revision: D71821354
1 parent d1b19e4 commit f738e20

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

Diff for: fbgemm_gpu/fbgemm_gpu/sparse_ops.py

+22
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,24 @@ def generic_histogram_binning_calibration_by_feature(
11211121
)
11221122

11231123

1124+
def permute_multi_embedding_function(
1125+
pooled_embs: List[Tensor],
1126+
permutes: Tensor,
1127+
in_shapes: Tensor,
1128+
out_shapes: Tensor,
1129+
out_lengths: List[int],
1130+
reverse: bool = False,
1131+
) -> List[Tensor]:
1132+
out_dtype = pooled_embs[0].dtype
1133+
bs = pooled_embs[0].shape[0]
1134+
torch._check(permutes.shape[1] == 6, lambda: "permutes must have 6 columns")
1135+
1136+
output = []
1137+
for i in range(len(out_lengths)):
1138+
output.append(torch.empty([bs, out_lengths[i]], dtype=out_dtype))
1139+
return output
1140+
1141+
11241142
def _setup() -> None:
11251143
# pyre-ignore[16]
11261144
_setup.done = getattr(_setup, "done", False)
@@ -1258,6 +1276,10 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None
12581276
"fbgemm::generic_histogram_binning_calibration_by_feature",
12591277
generic_histogram_binning_calibration_by_feature,
12601278
)
1279+
impl_abstract(
1280+
"fbgemm::permute_multi_embedding_function",
1281+
permute_multi_embedding_function,
1282+
)
12611283
impl_abstract(
12621284
"fbgemm::FloatToHFP8Quantized",
12631285
float_to_hfp8_quantized,

Diff for: fbgemm_gpu/src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ std::vector<Tensor> regroup_keyed_tensor_meta(
344344
} // namespace fbgemm_gpu
345345

346346
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
347+
m.set_python_module("fbgemm_gpu.sparse_ops");
347348
// register the forward function for internal (autograd) usage
348349
m.def(
349350
"permute_multi_embedding_function(Tensor[] pooled_embs, Tensor permutes, Tensor in_shapes, Tensor out_shapes, SymInt[] out_lengths, bool reverse=False) -> Tensor[]");

0 commit comments

Comments
 (0)