Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion merlin/dataloader/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#
import contextlib
from functools import partial
from typing import Dict, Union

from merlin.core.compat.torch import torch as th
from merlin.dataloader.loader_base import LoaderBase
Expand Down Expand Up @@ -100,6 +101,7 @@ def convert_batch(self, batch):
inputs_table = TensorTable(inputs, _unsafe=True)
for col_name, col in inputs_table.items():
torch_inputs[col_name] = self.convert_col(col, column_type)
torch_inputs = TensorTable(torch_inputs, _unsafe=True).to_dict()

torch_targets = None
if targets is not None:
Expand All @@ -113,7 +115,29 @@ def convert_batch(self, batch):
targets_col = TensorColumn(targets, _unsafe=True)
torch_targets = self.convert_col(targets_col, column_type).values

return (TensorTable(torch_inputs, _unsafe=True).to_dict(), torch_targets)
# Convert float values to match the torch default floating point dtype
self._to_default_dtype(torch_inputs)
if torch_targets:
self._to_default_dtype(torch_targets)

return (torch_inputs, torch_targets)

def _to_default_dtype(self, inputs: Union[th.Tensor, Dict[str, th.Tensor]]) -> None:
    """Cast floating point tensors, in place, to the default floating point dtype.

    Tensors that are not floating point, or that already use the dtype
    reported by ``th.get_default_dtype()``, are left untouched.

    Parameters
    ----------
    inputs : Union[th.Tensor, Dict[str, th.Tensor]]
        Input tensor or dictionary of tensors. A dictionary is updated by
        rebinding each affected key to the converted tensor. A bare tensor
        is updated by replacing its underlying data, so the object held by
        the caller changes dtype. Any other type is ignored.
    """
    default_float_dtype = th.get_default_dtype()

    def needs_cast(tensor):
        return th.is_floating_point(tensor) and tensor.dtype != default_float_dtype

    if isinstance(inputs, dict):
        # Only existing keys are rebound, so the dict keeps its size and
        # iterating over it while assigning is safe.
        for name, tensor in inputs.items():
            if needs_cast(tensor):
                inputs[name] = tensor.to(default_float_dtype)
    elif isinstance(inputs, th.Tensor):
        if needs_cast(inputs):
            # Tensor.to() returns a new tensor instead of modifying its
            # receiver, and this method returns None, so the converted data
            # has to be written back onto the caller's tensor object.
            inputs.data = inputs.data.to(default_float_dtype)

def map(self, fn):
"""
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/dataloader/test_torch_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
# If pytorch isn't installed skip these tests. Note that the
# torch_dataloader import needs to happen after this line
torch = pytest.importorskip("torch") # noqa

from torch import nn # noqa
from torch.utils.data import DataLoader, IterableDataset # noqa

import merlin.dataloader.torch as torch_dataloader # noqa
Expand All @@ -42,6 +44,25 @@
)


def test_default_dtype():
    """Float64 input columns come out of the loader in torch's default float dtype."""
    df = pd.DataFrame({"feature": np.array([0.1, 0.2], dtype="float64")})
    dataset = Dataset(df)
    dataloader = torch_dataloader.Loader(dataset, batch_size=1)
    x, _ = dataloader.peek()

    # Check the dtype directly so a regression is reported as a dtype
    # mismatch rather than as an error raised from inside the layer.
    assert x["feature"].dtype == torch.get_default_dtype()

    # End-to-end check: the batch can be fed to a layer built with default settings.
    model = nn.Sequential(nn.LazyLinear(1))
    model(x["feature"])


def test_default_dtype_double():
    """Float64 input columns stay float64 when torch's default dtype is float64."""
    df = pd.DataFrame({"feature": np.array([0.1, 0.2], dtype="float64")})
    dataset = Dataset(df)

    # The default dtype is process-wide state: remember it and restore it in
    # ``finally`` so this test cannot change the dtype seen by later tests.
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        dataloader = torch_dataloader.Loader(dataset, batch_size=1)
        x, _ = dataloader.peek()
        assert x["feature"].dtype == torch.float64

        # End-to-end check: the batch can be fed to a double-precision layer.
        model = nn.Sequential(nn.LazyLinear(1)).double()
        model(x["feature"])
    finally:
        torch.set_default_dtype(previous_dtype)


def test_iterable_dataset():
df = pd.DataFrame({"feature": [1, 2, 3], "target": [0, 1, 0]})
dataset = Dataset(df)
Expand Down