From 54f0afc2cd26a19e027298f7575581a377cf5db8 Mon Sep 17 00:00:00 2001
From: Oliver Holworthy <1216955+oliverholworthy@users.noreply.github.com>
Date: Tue, 20 Jun 2023 11:44:03 +0100
Subject: [PATCH] Coerce float dtypes to default torch dtype

---
 merlin/dataloader/torch.py                    | 34 ++++++++++++++++++++-
 .../unit/dataloader/test_torch_dataloader.py  | 28 ++++++++++++++++
 2 files changed, 61 insertions(+), 1 deletion(-)

diff --git a/merlin/dataloader/torch.py b/merlin/dataloader/torch.py
index c32642f1..7c8bd22b 100644
--- a/merlin/dataloader/torch.py
+++ b/merlin/dataloader/torch.py
@@ -15,6 +15,7 @@
 #
 import contextlib
 from functools import partial
+from typing import Dict, Union
 
 from merlin.core.compat.torch import torch as th
 from merlin.dataloader.loader_base import LoaderBase
@@ -100,6 +101,7 @@ def convert_batch(self, batch):
         inputs_table = TensorTable(inputs, _unsafe=True)
         for col_name, col in inputs_table.items():
             torch_inputs[col_name] = self.convert_col(col, column_type)
+        torch_inputs = TensorTable(torch_inputs, _unsafe=True).to_dict()
 
         torch_targets = None
         if targets is not None:
@@ -113,7 +115,37 @@ def convert_batch(self, batch):
                 targets_col = TensorColumn(targets, _unsafe=True)
                 torch_targets = self.convert_col(targets_col, column_type).values
 
-        return (TensorTable(torch_inputs, _unsafe=True).to_dict(), torch_targets)
+        # Convert float values to match the torch default floating point dtype
+        torch_inputs = self._to_default_dtype(torch_inputs)
+        if torch_targets is not None:
+            torch_targets = self._to_default_dtype(torch_targets)
+
+        return (torch_inputs, torch_targets)
+
+    def _to_default_dtype(
+        self, inputs: Union[th.Tensor, Dict[str, th.Tensor]]
+    ) -> Union[th.Tensor, Dict[str, th.Tensor]]:
+        """Convert floating point tensors to the torch default floating point dtype
+
+        Parameters
+        ----------
+        inputs : Union[th.Tensor, Dict[str, th.Tensor]]
+            Input tensor or dictionary of tensors
+
+        Returns
+        -------
+        Union[th.Tensor, Dict[str, th.Tensor]]
+            Inputs with floating point tensors cast to the default floating point dtype
+        """
+        default_float_dtype = th.get_default_dtype()
+        if isinstance(inputs, dict):
+            for name, tensor in inputs.items():
+                if th.is_floating_point(tensor) and tensor.dtype != default_float_dtype:
+                    inputs[name] = tensor.to(default_float_dtype)
+        elif isinstance(inputs, th.Tensor):
+            if th.is_floating_point(inputs) and inputs.dtype != default_float_dtype:
+                inputs = inputs.to(default_float_dtype)
+        return inputs
 
     def map(self, fn):
         """
diff --git a/tests/unit/dataloader/test_torch_dataloader.py b/tests/unit/dataloader/test_torch_dataloader.py
index 753dd8ec..a1bfe5f0 100644
--- a/tests/unit/dataloader/test_torch_dataloader.py
+++ b/tests/unit/dataloader/test_torch_dataloader.py
@@ -33,6 +33,8 @@
 # If pytorch isn't installed skip these tests. Note that the
 # torch_dataloader import needs to happen after this line
 torch = pytest.importorskip("torch")  # noqa
+
+from torch import nn  # noqa
 from torch.utils.data import DataLoader, IterableDataset  # noqa
 
 import merlin.dataloader.torch as torch_dataloader  # noqa
@@ -42,6 +44,32 @@
 )
 
 
+def test_default_dtype():
+    df = pd.DataFrame({"feature": np.array([0.1, 0.2], dtype="float64")})
+    dataset = Dataset(df)
+    dataloader = torch_dataloader.Loader(dataset, batch_size=1)
+    x, _ = dataloader.peek()
+    assert x["feature"].dtype == torch.get_default_dtype()
+    model = nn.Sequential(nn.LazyLinear(1))
+    model(x["feature"])
+
+
+def test_default_dtype_double():
+    df = pd.DataFrame({"feature": np.array([0.1, 0.2], dtype="float64")})
+    dataset = Dataset(df)
+    default_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(torch.float64)
+    try:
+        dataloader = torch_dataloader.Loader(dataset, batch_size=1)
+        x, _ = dataloader.peek()
+        assert x["feature"].dtype == torch.float64
+        model = nn.Sequential(nn.LazyLinear(1)).double()
+        model(x["feature"])
+    finally:
+        # Restore the global default so other tests are unaffected
+        torch.set_default_dtype(default_dtype)
+
+
 def test_iterable_dataset():
     df = pd.DataFrame({"feature": [1, 2, 3], "target": [0, 1, 0]})
     dataset = Dataset(df)
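
The coercion rule this patch adds only rewrites floating point tensors whose dtype
differs from torch.get_default_dtype(); integer and boolean tensors pass through
untouched. Below is a minimal standalone sketch of that rule in plain PyTorch, for
illustration only: it has no Merlin dependency, and the helper name
coerce_to_default_float is hypothetical, not part of this patch.

    import torch

    def coerce_to_default_float(tensor: torch.Tensor) -> torch.Tensor:
        """Cast floating point tensors to torch's default dtype; leave others as-is."""
        default_dtype = torch.get_default_dtype()
        if torch.is_floating_point(tensor) and tensor.dtype != default_dtype:
            return tensor.to(default_dtype)
        return tensor

    # float64 input is cast down to the default (float32 unless the default was changed)
    print(coerce_to_default_float(torch.zeros(2, dtype=torch.float64)).dtype)  # torch.float32
    # integer input passes through unchanged
    print(coerce_to_default_float(torch.zeros(2, dtype=torch.int64)).dtype)  # torch.int64

Changing the global default, as test_default_dtype_double does via
torch.set_default_dtype(torch.float64), changes the target dtype of the coercion,
which is why that test expects float64 output from the loader.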