Skip to content

Commit ec9bedf

Browse files
Use Dictionary lookup for supplied IDs to Embedding Operator (#148)
* Use lookup dict for Embedding operator with ids to speed up transform * Add test for embedding operator with unknown value * Update typehint for unknown_value * Remove embedding tag from input schema to embeddings tests The op now adds the embedding tag automatically, and the input schema cannot have both a CATEGORICAL and EMBEDDING tag * Flatten array passed as ids to EmbeddingOperator * Add assertion for ids shape * Set default value of `embedding_index_mapping` * Correct casing of message in test
1 parent dbe0ade commit ec9bedf

File tree

2 files changed

+129
-15
lines changed

2 files changed

+129
-15
lines changed

merlin/dataloader/ops/embeddings.py

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16+
import os
1617
from typing import Optional, Union
1718

1819
import numpy as np
@@ -25,9 +26,9 @@
2526

2627

2728
class EmbeddingOperator(BaseOperator):
28-
"""Create an operator that will apply a torch embedding table to supplied indices.
29-
This operator allows the user to supply an id lookup table if the indices supplied
30-
via the id_lookup_table.
29+
"""Create an operator that will apply an embedding table to supplied indices.
30+
31+
An id lookup table for the embeddings can be supplied with the argument `id_lookup_table`.
3132
3233
Parameters
3334
----------
@@ -39,6 +40,12 @@ class EmbeddingOperator(BaseOperator):
3940
name of new column of embeddings, added to output, by default "embeddings"
4041
id_lookup_table : np.array, optional
4142
numpy array of values that represent embedding indices, by default None
43+
mmap : bool, default False
44+
When loading embeddings from a file, specify whether we should memory map the file
45+
This is useful for accessing a large file without reading the entire file into memory.
46+
unknown_value : Union[float, int, np.ndarray]
47+
If an embedding index is not found.
48+
Specifies the value should we return for the corresponding embedding.
4249
"""
4350

4451
def __init__(
@@ -47,24 +54,75 @@ def __init__(
4754
lookup_key: str = "id",
4855
embedding_name: str = "embeddings",
4956
id_lookup_table: Optional[Union[np.ndarray, str]] = None,
50-
mmap=False,
57+
mmap: bool = False,
58+
unknown_value: Union[float, int, np.ndarray] = 0,
5159
):
52-
if mmap:
53-
embeddings = np.load(embeddings, mmap_mode="r")
54-
id_lookup_table = np.load(id_lookup_table) if id_lookup_table else None
60+
if isinstance(embeddings, (str, os.PathLike)):
61+
mmap_mode = "r" if mmap else None
62+
embeddings = np.load(embeddings, mmap_mode=mmap_mode)
63+
elif isinstance(embeddings, np.ndarray):
64+
pass
65+
else:
66+
raise ValueError(
67+
f"Unsupported type '{type(embeddings)}' passed to argument `embeddings` "
68+
f"of '{type(self).__name__}'. "
69+
"Expected either a numpy.ndarray "
70+
"or a (string or pathlike object corresponding to a numpy file) "
71+
"containing embeddings. "
72+
)
5573
self.embeddings = embeddings
74+
75+
embedding_index_mapping = None
76+
if isinstance(id_lookup_table, (str, os.PathLike)):
77+
_ids = np.load(id_lookup_table)
78+
embedding_index_mapping = self._get_embedding_index_mapping(_ids)
79+
elif isinstance(id_lookup_table, np.ndarray):
80+
_ids = id_lookup_table
81+
embedding_index_mapping = self._get_embedding_index_mapping(_ids)
82+
elif id_lookup_table is None:
83+
pass
84+
else:
85+
raise ValueError(
86+
f"Unsupported type '{type(id_lookup_table)}' passed to argument `id_lookup_table` "
87+
f"of '{type(self).__name__}'. "
88+
"Expected either a numpy.ndarray "
89+
"or a (string or pathlike object corresponding to a numpy file) "
90+
"containing the IDs that correspond to the embeddings. "
91+
)
92+
self.embedding_index_mapping = embedding_index_mapping
93+
5694
self.lookup_key = lookup_key
5795
self.embedding_name = embedding_name
58-
self.id_lookup_table = id_lookup_table
96+
self.unknown_value = unknown_value
97+
98+
def _get_embedding_index_mapping(self, ids):
99+
expected_ids_shape = (self.embeddings.shape[0],)
100+
assert ids.shape == expected_ids_shape, (
101+
"IDs provided must match the number of embeddings. "
102+
f"Expected IDs with shape {expected_ids_shape} "
103+
f"Received IDs with shape: {ids.shape} "
104+
f"Embeddings shape: {self.embeddings.shape} "
105+
)
106+
id_to_index_mapping = dict(zip(ids, range(len(ids))))
107+
return id_to_index_mapping
59108

60109
def transform(
61110
self, col_selector: ColumnSelector, transformable: Transformable
62111
) -> Transformable:
63112
keys = transformable[self.lookup_key]
64113
indices = keys.cpu().values
65-
if self.id_lookup_table is not None:
66-
indices = np.in1d(self.id_lookup_table, indices)
114+
115+
if self.embedding_index_mapping is not None:
116+
indices = np.array([self.embedding_index_mapping.get(_id, -1) for _id in indices])
117+
67118
embeddings = self.embeddings[indices]
119+
120+
# set unknown embedding to zero
121+
for idx in np.ndindex(indices.shape):
122+
embedding_index = indices[idx]
123+
if embedding_index == -1:
124+
embeddings[idx] = self.unknown_value
125+
68126
embeddings_col = TensorColumn(embeddings, offsets=keys.cpu().offsets)
69127
transformable[self.embedding_name] = (
70128
embeddings_col.gpu() if keys.device == Device.GPU else embeddings_col

tests/unit/dataloader/test_embeddings.py

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pandas as pd
2020
import pytest
2121

22+
from merlin.core.compat import cupy
2223
from merlin.core.dispatch import HAS_GPU
2324
from merlin.dataloader.loader_base import LoaderBase as Loader # noqa
2425
from merlin.dataloader.ops.embeddings import EmbeddingOperator
@@ -29,6 +30,59 @@
2930
from merlin.table import TensorColumn, TensorTable
3031

3132

33+
def test_embeddings_invalid_ids():
34+
ids = np.array(["a", "b"])
35+
embeddings = np.random.rand(3, 10)
36+
with pytest.raises(AssertionError) as exc_info:
37+
EmbeddingOperator(
38+
embeddings,
39+
lookup_key="id",
40+
embedding_name="id_embedding",
41+
id_lookup_table=ids,
42+
)
43+
assert "IDs provided must match the number of embeddings" in str(exc_info.value)
44+
assert "Expected IDs with shape (3,)" in str(exc_info.value)
45+
46+
47+
@pytest.mark.parametrize("unknown_value", [0, 1, np.random.uniform(size=10)])
48+
def test_embedding_lookup_with_unknown_value(unknown_value):
49+
ids = np.array(["a", "b", "c"])
50+
embeddings = np.random.rand(3, 10)
51+
df = pd.DataFrame(
52+
{
53+
"id": ["a", "unknown"],
54+
"feature": [1, 2],
55+
}
56+
)
57+
58+
dataset = Dataset(df, cpu=True)
59+
60+
data_loader = Loader(
61+
dataset,
62+
batch_size=3,
63+
transforms=[
64+
EmbeddingOperator(
65+
embeddings,
66+
lookup_key="id",
67+
embedding_name="id_embedding",
68+
id_lookup_table=ids,
69+
unknown_value=unknown_value,
70+
),
71+
],
72+
shuffle=False,
73+
)
74+
x, y = data_loader.peek()
75+
76+
assert x["id"].values.shape == (2,)
77+
embedding_values = x["id_embedding"].values
78+
if cupy and isinstance(embedding_values, cupy.ndarray):
79+
embedding_values = embedding_values.get()
80+
assert embedding_values.shape == (2, 10)
81+
np.testing.assert_equal(embedding_values[0], embeddings[0])
82+
np.testing.assert_equal(embedding_values[1], unknown_value)
83+
assert data_loader.output_schema.column_names == ["id", "feature", "id_embedding"]
84+
85+
3286
def test_embedding_with_target():
3387
id_embeddings = np.random.rand(1000, 10)
3488
df = pd.DataFrame(
@@ -116,7 +170,7 @@ def test_embedding_np_mmap_dl_with_lookup(tmpdir, rev_embedding_ids, np_embeddin
116170
dataset = dataset.repartition(10)
117171
schema = dataset.schema
118172
for col_name in cat_names:
119-
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
173+
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL])
120174
dataset.schema = schema
121175

122176
data_loader = Loader(
@@ -148,7 +202,7 @@ def test_embedding_np_dl_no_lookup(tmpdir, embedding_ids, embeddings_from_datafr
148202
dataset = dataset.repartition(10)
149203
schema = dataset.schema
150204
for col_name in cat_names:
151-
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
205+
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL])
152206
dataset.schema = schema
153207
paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*"))
154208
embeddings_ds = Dataset(paths)
@@ -183,15 +237,17 @@ def test_embedding_np_dl_with_lookup(tmpdir, rev_embedding_ids, embeddings_from_
183237
dataset = dataset.repartition(10)
184238
schema = dataset.schema
185239
for col_name in cat_names:
186-
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
240+
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL])
187241
dataset.schema = schema
188242
paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*"))
189243
embeddings_ds = Dataset(paths)
190244
embeddings_np = embeddings_ds.to_ddf().compute().to_numpy()[:, 1:]
191245
data_loader = Loader(
192246
dataset,
193247
batch_size=batch_size,
194-
transforms=[EmbeddingOperator(embeddings_np, id_lookup_table=embedding_ids.to_numpy())],
248+
transforms=[
249+
EmbeddingOperator(embeddings_np, id_lookup_table=embedding_ids.to_numpy().ravel())
250+
],
195251
shuffle=False,
196252
device=cpu,
197253
)
@@ -222,7 +278,7 @@ def test_embedding_np_dl_with_lookup_ragged(
222278
dataset = dataset.repartition(10)
223279
schema = dataset.schema
224280
for col_name in cat_names:
225-
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL, Tags.EMBEDDING])
281+
schema[col_name] = schema[col_name].with_tags([Tags.CATEGORICAL])
226282
dataset.schema = schema
227283
paths = sorted(glob.glob(f"{embeddings_from_dataframe}/*"))
228284
embeddings_ds = Dataset(paths)

0 commit comments

Comments
 (0)