Skip to content

Commit dbe0ade

Browse files
authored
Adjust tag propagation in EmbeddingOperator (#146)
This limits the propagation of tags from the input id column to the output embedding column.
1 parent a2cc7ee commit dbe0ade

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

merlin/dataloader/ops/embeddings.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,23 @@ def compute_output_schema(
9898
col_schemas.append(col_schema)
9999
id_schema = input_schema.column_schemas[self.lookup_key]
100100
embedding_dim = self.embeddings.shape[1]
101+
102+
new_tags = [Tags.EMBEDDING]
103+
propagated_tags = [
104+
Tags.LIST,
105+
Tags.SEQUENCE,
106+
Tags.USER,
107+
Tags.ITEM,
108+
Tags.CONTEXT,
109+
Tags.SESSION,
110+
]
111+
for tag in propagated_tags:
112+
if tag in id_schema.tags:
113+
new_tags.append(tag)
101114
col_schemas.append(
102115
ColumnSchema(
103116
name=self.embedding_name,
104-
tags=[Tags.EMBEDDING],
117+
tags=new_tags,
105118
dtype=self.embeddings.dtype,
106119
dims=id_schema.shape.as_tuple + (embedding_dim,),
107120
)

tests/unit/dataloader/test_embeddings.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from merlin.dataloader.ops.padding import Padding
2626
from merlin.io import Dataset
2727
from merlin.schema import Tags
28+
from merlin.schema.tags import TagSet
2829
from merlin.table import TensorColumn, TensorTable
2930

3031

@@ -71,7 +72,7 @@ def test_embedding_np_mmap_dl_no_lookup(tmpdir, embedding_ids, np_embeddings_fro
7172
dataset = dataset.repartition(10)
7273
schema = dataset.schema
7374
for col_name in cat_names:
74-
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
75+
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.ITEM, Tags.ID])
7576
dataset.schema = schema
7677
data_loader = Loader(
7778
dataset,
@@ -85,8 +86,10 @@ def test_embedding_np_mmap_dl_no_lookup(tmpdir, embedding_ids, np_embeddings_fro
8586
assert data_loader.output_schema.column_names == ["id", "embeddings"]
8687

8788
embeddings_dim = 1024
88-
embeddings_value_count = data_loader.output_schema["embeddings"].value_count
89+
embedding_schema = data_loader.output_schema["embeddings"]
90+
embeddings_value_count = embedding_schema.value_count
8991
assert embeddings_value_count.min == embeddings_value_count.max == embeddings_dim
92+
assert embedding_schema.tags == TagSet([Tags.EMBEDDING, Tags.ITEM])
9093

9194
full_len = 0
9295
for idx, batch in enumerate(data_loader):

0 commit comments

Comments
 (0)