diff --git a/DESCRIPTION b/DESCRIPTION
index 4772f52fd..045182c1f 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -91,6 +91,7 @@ Collate:
     'CallbackSetUnfreeze.R'
     'ContextTorch.R'
     'DataBackendLazy.R'
+    'DataBackendLazyTensors.R'
     'utils.R'
     'DataDescriptor.R'
     'LearnerTorch.R'
diff --git a/NAMESPACE b/NAMESPACE
index 122ebafb6..1ed76e181 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -7,11 +7,15 @@ S3method("[[<-",lazy_tensor)
 S3method(as.data.table,DictionaryMlr3torchCallbacks)
 S3method(as.data.table,DictionaryMlr3torchLosses)
 S3method(as.data.table,DictionaryMlr3torchOptimizers)
+S3method(as_data_backend,dataset)
 S3method(as_data_descriptor,dataset)
 S3method(as_lazy_tensor,DataDescriptor)
 S3method(as_lazy_tensor,dataset)
 S3method(as_lazy_tensor,numeric)
 S3method(as_lazy_tensor,torch_tensor)
+S3method(as_lazy_tensors,dataset)
+S3method(as_task_classif,dataset)
+S3method(as_task_regr,dataset)
 S3method(as_torch_callback,R6ClassGenerator)
 S3method(as_torch_callback,TorchCallback)
 S3method(as_torch_callback,character)
@@ -27,6 +31,8 @@ S3method(as_torch_optimizer,character)
 S3method(as_torch_optimizer,torch_optimizer_generator)
 S3method(c,lazy_tensor)
 S3method(col_info,DataBackendLazy)
+S3method(col_info,DataBackendLazyTensors)
+S3method(distinct_values,lazy_tensor)
 S3method(format,lazy_tensor)
 S3method(hash_input,TorchIngressToken)
 S3method(hash_input,lazy_tensor)
@@ -71,6 +77,7 @@ export(CallbackSetTB)
 export(CallbackSetUnfreeze)
 export(ContextTorch)
 export(DataBackendLazy)
+export(DataBackendLazyTensors)
 export(DataDescriptor)
 export(LearnerTorch)
 export(LearnerTorchFeatureless)
@@ -161,6 +168,7 @@ export(TorchLoss)
 export(TorchOptimizer)
 export(as_data_descriptor)
 export(as_lazy_tensor)
+export(as_lazy_tensors)
 export(as_lr_scheduler)
 export(as_torch_callback)
 export(as_torch_callbacks)
diff --git a/NEWS.md b/NEWS.md
index 9f2cbf174..1d581c0b2 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -11,6 +11,8 @@
   This means that for binary classification tasks, `t_loss("cross_entropy")` now generates
   `nn_bce_with_logits_loss` instead of `nn_cross_entropy_loss`.
   This also came with a reparametrization of the `t_loss("cross_entropy")` loss (thanks to @tdhock, #374).
+* fix: `NA` is now a valid shape for lazy tensors.
+* feat: `lazy_tensor`s of length 0 can now be materialized.
 
 # mlr3torch 0.2.1
 
diff --git a/R/DataBackendLazyTensors.R b/R/DataBackendLazyTensors.R
new file mode 100644
index 000000000..2328a86b4
--- /dev/null
+++ b/R/DataBackendLazyTensors.R
@@ -0,0 +1,249 @@
+
+#' @title Special Backend for Lazy Tensors
+#' @description
+#' This backend essentially allows you to use a [`torch::dataset`] directly with
+#' an [`mlr3::Learner`].
+#'
+#' * The data cannot contain missing values, as [`lazy_tensor`]s do not support them.
+#'   For this reason, calling `$missings()` will always return `0` for all columns.
+#' * The `$distinct()` method will consider two lazy tensors that refer to the same element of a
+#'   [`DataDescriptor`] to be identical.
+#'   This means that it might underreport the number of distinct values of lazy tensor columns.
+#'
+#' @export
+#' @examplesIf torch::torch_is_installed()
+#' # used as feature in all backends
+#' x = torch_randn(100, 10)
+#' # regression
+#' ds_regr = tensor_dataset(x = x, y = torch_randn(100, 1))
+#' be_regr = as_data_backend(ds_regr, converter = list(y = as.numeric))
+#' be_regr$head()
+#'
+#'
+#' # binary classification: underlying target tensor must be float in [0, 1]
+#' ds_binary = tensor_dataset(x = x, y = torch_randint(0, 2, c(100, 1))$float())
+#' be_binary = as_data_backend(ds_binary, converter = list(
+#'   y = function(x) factor(as.integer(x), levels = c(0, 1), labels = c("A", "yes"))
+#' ))
+#' be_binary$head()
+#'
+#' # multi-class classification: underlying target tensor must be integer in [1, K]
+#' ds_multiclass = tensor_dataset(x = x, y = torch_randint(1, 4, size = c(100, 1)))
+#' be_multiclass = as_data_backend(ds_multiclass, converter = list(y = as.numeric))
+#' be_multiclass$head()
+
+DataBackendLazyTensors = R6Class("DataBackendLazyTensors",
+  cloneable = FALSE,
+  inherit = DataBackendDataTable,
+  public = list(
+    chunk_size = NULL,
+    #' @description
+    #' Create a new instance of this [R6][R6::R6Class] class.
+    #' @param data (`data.table`)\cr
+    #'   Data containing (among others) [`lazy_tensor`] columns.
+    #' @param primary_key (`character(1)`)\cr
+    #'   Name of the column used as primary key.
+    #' @param converter (named `list()` of `function`s)\cr
+    #'   A named list of functions that convert the lazy tensor columns to their R representation.
+    #'   The names must be the names of the columns that need conversion.
+    #' @param cache (`character()`)\cr
+    #'   Names of the columns that should be cached.
+    #'   Per default, all columns that are converted are cached.
+    initialize = function(data, primary_key, converter, cache = names(converter), chunk_size = 100) {
+      private$.converter = assert_list(converter, types = "function", any.missing = FALSE)
+      assert_subset(names(converter), colnames(data))
+      # an empty `cache` passes this check (`empty.ok` defaults to TRUE), i.e. caching can be disabled
+      private$.cached_cols = assert_subset(cache, names(converter))
+      self$chunk_size = assert_int(chunk_size, lower = 1L)
+      walk(names(private$.converter), function(nm) {
+        if (!inherits(data[[nm]], "lazy_tensor")) {
+          stopf("Column '%s' is not a lazy tensor.", nm)
+        }
+      })
+      super$initialize(data, primary_key)
+      # the cache starts out with the primary key column only; converted columns are added on demand
+      private$.data_cache = private$.data[, primary_key, with = FALSE]
+    },
+    data = function(rows, cols) {
+      rows = assert_integerish(rows, coerce = TRUE)
+      assert_names(cols, type = "unique")
+      if (getOption("mlr3torch.data_loading", FALSE)) {
+        # no caching, no materialization as this is called in the training loop
+        return(super$data(rows, cols))
+      }
+      expensive_cols = intersect(cols, private$.cached_cols)
+      # only consult the cache if a cached column is requested; otherwise `cache_hit` would have no columns
+      if (length(expensive_cols) > 0L && all(expensive_cols %in% names(private$.data_cache))) {
+        other_cols = setdiff(cols, expensive_cols)
+        # NOTE(review): converter columns that are not cached are returned unconverted in this branch -- confirm
+        cache_hit = private$.data_cache[list(rows), expensive_cols, on = self$primary_key, with = FALSE]
+        complete = complete.cases(cache_hit)
+        cache_hit = cache_hit[complete]
+        if (nrow(cache_hit) == length(rows)) {
+          tbl = cbind(cache_hit, super$data(rows, other_cols))
+          setcolorder(tbl, cols)
+          return(tbl)
+        }
+        # load the rows missing from the cache, then restore the order in which the rows were requested
+        combined = rbindlist(list(cache_hit, private$.load_and_cache(rows[!complete], expensive_cols)))
+        reorder = vector("integer", nrow(combined))
+        reorder[complete] = seq_len(nrow(cache_hit))
+        reorder[!complete] = nrow(cache_hit) + seq_len(nrow(combined) - nrow(cache_hit))
+        tbl = cbind(combined[reorder], super$data(rows, other_cols))
+        setcolorder(tbl, cols)
+        return(tbl)
+      }
+      private$.load_and_cache(rows, cols)
+    },
+    head = function(n = 6L) {
+      if (getOption("mlr3torch.data_loading", FALSE)) {
+        return(super$head(n))
+      }
+      # NOTE(review): uses 1, ..., n as row ids rather than the first n primary key values -- confirm
+      self$data(seq_len(n), self$colnames)
+    },
+    missings = function(rows, cols) {
+      set_names(rep(0L, length(cols)), cols)
+    }
+  ),
+  active = list(
+    converter = function(rhs) {
+      assert_ro_binding(rhs)
+      private$.converter
+    }
+  ),
+  private = list(
+    # call this function only with rows that are not in the cache yet
+    .load_and_cache = function(rows, cols) {
+      # load the raw data, convert the converter columns among `cols` and cache those in `.cached_cols`
+      tbl = super$data(rows, cols)
+      cols_to_convert = intersect(names(private$.converter), names(tbl))
+      tbl_to_mat = tbl[, cols_to_convert, with = FALSE]
+      # chunk the rows of tbl_to_mat into chunks of size self$chunk_size, apply materialize
+      n = nrow(tbl_to_mat)
+      chunks = split(seq_len(n), rep(seq_len(ceiling(n / self$chunk_size)), each = self$chunk_size, length.out = n))
+      tbl_mat = if (n == 0) {
+        # one empty tensor per column to convert, so that zero or several converter columns work as well
+        set_names(lapply(cols_to_convert, function(nm) torch_empty(0)), cols_to_convert)
+      } else {
+        set_names(lapply(transpose_list(lapply(chunks, function(chunk) {
+          materialize(tbl_to_mat[chunk, ], rbind = TRUE)
+        })), torch_cat, dim = 1L), names(tbl_to_mat))
+      }
+      # `set()` addresses rows by position, so the primary key values have to be translated first
+      cache_rows = private$.data_cache[list(rows), on = self$primary_key, which = TRUE, nomatch = NULL]
+      for (nm in cols_to_convert) {
+        converted = private$.converter[[nm]](tbl_mat[[nm]])
+        tbl[[nm]] = converted
+        if (nm %in% private$.cached_cols) {
+          set(private$.data_cache, i = cache_rows, j = nm, value = converted)
+        }
+      }
+      return(tbl)
+    },
+    .data_cache = NULL,
+    .converter = NULL,
+    .cached_cols = NULL
+  )
+)
+
+#' @export
+as_data_backend.dataset = function(x, dataset_shapes = NULL, ...) { # NULL: shapes are inferred from the dataset
+  tbl = as_lazy_tensors(x, dataset_shapes, ...) # one lazy tensor column per dataset element
+  tbl$row_id = seq_len(nrow(tbl)) # consecutive integers serve as the primary key
+  DataBackendLazyTensors$new(tbl, primary_key = "row_id", ...) # `...`: e.g. converter, cache, chunk_size
+}
+
+#' @export
+as_task_classif.dataset = function(x, target, levels, converter = NULL, dataset_shapes = NULL, chunk_size = 100, cache = names(converter), ...) {
+  if (length(x) < 2) {
+    stopf("Dataset must have at least 2 rows.")
+  }
+  batch = dataloader(x, batch_size = 2)$.iter()$.next()
+  if (is.null(converter)) {
+    y = batch[[target]]
+    if (length(levels) == 2) {
+      # binary: float tensor of shape (batch_size, 1); 0 and 1 are mapped to the two levels
+      if (y$dtype != torch_float()) {
+        stopf("Target must be a float tensor, but has dtype %s", y$dtype)
+      }
+      if (!test_equal(y$shape, c(2L, 1L))) {
+        stopf("Target must be a float tensor of shape (batch_size, 1), but has shape (batch_size, %s)",
+          paste(y$shape[-1L], collapse = ", "))
+      }
+      to_factor = crate(function(x) factor(as.integer(x), levels = 0:1, labels = levels), levels)
+    } else {
+      # multi-class: integer tensor of shape (batch_size); class k is mapped to the k-th level
+      if (y$dtype != torch_int()) {
+        stopf("Target must be an integer tensor, but has dtype %s", y$dtype)
+      }
+      if (!test_equal(y$shape, 2L)) {
+        stopf("Target must be an integer tensor of shape (batch_size), but has shape (batch_size, %s)",
+          paste(y$shape[-1L], collapse = ", "))
+      }
+      # explicit `levels =` keeps the mapping stable when a chunk does not contain every class
+      to_factor = crate(function(x) factor(as.integer(x), levels = seq_along(levels), labels = levels), levels)
+    }
+    converter = set_names(list(to_factor), target)
+  }
+  be = as_data_backend(x, dataset_shapes, converter = converter, cache = cache, chunk_size = chunk_size)
+  as_task_classif(be, target = target, ...)
+}
+
+#' @export
+as_task_regr.dataset = function(x, target, converter = NULL, dataset_shapes = NULL, chunk_size = 100, cache = names(converter), ...) {
+  if (length(x) < 2) {
+    stopf("Dataset must have at least 2 rows.")
+  }
+  # default: the target tensor is turned into a plain numeric vector
+  # (`cache` is evaluated lazily, so its default already sees this converter)
+  if (is.null(converter)) {
+    converter = set_names(list(as.numeric), target)
+  }
+  # inspect the target of the first two observations
+  y = dataloader(x, batch_size = 2)$.iter()$.next()[[target]]
+  if (y$dtype != torch_float()) {
+    stopf("Target must be a float tensor, but has dtype %s", y$dtype)
+  }
+  if (!test_equal(y$shape, c(2L, 1L))) {
+    stopf("Target must be a float tensor of shape (batch_size, 1), but has shape (batch_size, %s)",
+      paste(y$shape[-1L], collapse = ", "))
+  }
+  dataset_shapes = get_or_check_dataset_shapes(x, dataset_shapes)
+  backend = as_data_backend(x, dataset_shapes, converter = converter, cache = cache, chunk_size = chunk_size)
+  as_task_regr(backend, target = target, ...)
+}
+
+#' @export
+col_info.DataBackendLazyTensors = function(x, ...) { # nolint
+  row1 = x$head(1L) # types and levels are read off the first row as returned by `$head()`
+  types = map_chr(row1, function(col) class(col)[1L])
+  discrete = setdiff(names(types)[types %chin% c("factor", "ordered")], x$primary_key)
+  lvls = map(row1[, discrete, with = FALSE], levels)
+  data.table(id = names(types), type = unname(types), levels = insert_named(named_list(names(types)), lvls), key = "id")
+}
+
+
+# Conservative check that prevents a pseudo-lazy-tensor, i.e. a converter column of a
+# DataBackendLazyTensors, from being preprocessed by some PipeOp (nested backends are traversed).
+# @param be
+#   the backend
+# @param candidates
+#   the feature and target names
+# @param visited
+#   union of all colnames already visited
+# @return visited, extended by the candidates that `be` provides
+check_lazy_tensors_backend = function(be, candidates, visited = character()) {
+  if (inherits(be, c("DataBackendRbind", "DataBackendCbind"))) {
+    parts = be$.__enclos_env__$private$.data
+    # b2 is checked before b1, because b2 possibly overshadows some b1 rows/cols
+    visited = check_lazy_tensors_backend(parts$b2, candidates, visited)
+    return(check_lazy_tensors_backend(parts$b1, candidates, visited))
+  }
+  if (inherits(be, "DataBackendLazyTensors")) {
+    converter_cols = names(be$converter)[names(be$converter) %in% visited]
+    if (length(converter_cols) > 0L) {
+      stopf("A converter column ('%s') from a DataBackendLazyTensors was presumably preprocessed by some PipeOp. This can cause inefficiencies and is therefore not allowed. If you want to preprocess them, please directly encode them as R types.", paste0(converter_cols, collapse = ", ")) # nolint
+    }
+  }
+  union(visited, intersect(candidates, be$colnames))
+}
diff --git a/R/DataDescriptor.R b/R/DataDescriptor.R
index 1bf3cd68d..6a1d65740 100644
--- a/R/DataDescriptor.R
+++ b/R/DataDescriptor.R
@@ -60,14 +60,7 @@ DataDescriptor = R6Class("DataDescriptor",
       # For simplicity we here require the first dimension of the shape to be NA so we don't have to deal with it,
       # e.g. during subsetting
 
-      if (is.null(dataset_shapes)) {
-        if (is.null(dataset$.getbatch)) {
-          stopf("dataset_shapes must be provided if dataset does not have a `.getbatch` method.")
-        }
-        dataset_shapes = infer_shapes_from_getbatch(dataset)
-      } else {
-        assert_compatible_shapes(dataset_shapes, dataset)
-      }
+      dataset_shapes = get_or_check_dataset_shapes(dataset, dataset_shapes)
 
       if (is.null(graph)) {
         # avoid name conflicts
@@ -84,8 +77,7 @@ DataDescriptor = R6Class("DataDescriptor",
         assert_true(length(graph$pipeops) >= 1L)
       }
       # no preprocessing, dataset returns only a single element (there we can infer a lot)
-      simple_case = length(graph$pipeops) == 1L && inherits(graph$pipeops[[1L]], "PipeOpNOP") &&
-        length(dataset_shapes) == 1L
+      simple_case = (length(graph$pipeops) == 1L) && inherits(graph$pipeops[[1L]], "PipeOpNOP")
 
       if (is.null(input_map) && nrow(graph$input) == 1L && length(dataset_shapes) == 1L) {
         input_map = names(dataset_shapes)
@@ -100,7 +92,7 @@ DataDescriptor = R6Class("DataDescriptor",
         assert_choice(pointer[[2]], graph$pipeops[[pointer[[1]]]]$output$name)
       }
       if (is.null(pointer_shape) && simple_case) {
-        pointer_shape = dataset_shapes[[1L]]
+        pointer_shape = dataset_shapes[[input_map]]
       } else {
         assert_shape(pointer_shape, null_ok = TRUE)
       }
@@ -225,13 +217,14 @@ infer_shapes_from_getbatch = function(ds) {
 }
 
 assert_compatible_shapes = function(shapes, dataset) {
-  assert_shapes(shapes, null_ok = TRUE, unknown_batch = TRUE, named = TRUE)
+  shapes = assert_shapes(shapes, null_ok = TRUE, unknown_batch = TRUE, named = TRUE, coerce = TRUE)
 
   # prevent user from e.g. forgetting to wrap the return in a list
-  example = if (is.null(dataset$.getbatch)) {
-    dataset$.getitem(1L)
-  } else {
+  has_getbatch = !is.null(dataset$.getbatch)
+  example = if (has_getbatch) {
     dataset$.getbatch(1L)
+  } else {
+    dataset$.getitem(1L)
   }
   if (!test_list(example, names = "unique") || !test_permutation(names(example), names(shapes))) {
     stopf("Dataset must return a list with named elements that are a permutation of the dataset_shapes names.")
@@ -242,17 +235,17 @@ assert_compatible_shapes = function(shapes, dataset) {
     }
   })
 
-  if (is.null(dataset$.getbatch)) {
-    example = map(example, function(x) x$unsqueeze(1))
-  }
-
   iwalk(shapes, function(dataset_shape, name) {
-    if (!is.null(dataset_shape) && !test_equal(shapes[[name]][-1], example[[name]]$shape[-1L])) {
-      expected_shape = example[[name]]$shape
-      expected_shape[1] = NA
+    observed_shape = example[[name]]$shape
+    if (has_getbatch) {
+      observed_shape[1L] = NA_integer_
+    } else {
+      observed_shape = c(NA_integer_, observed_shape)
+    }
+    if (!is.null(dataset_shape) && !test_equal(observed_shape, dataset_shape)) {
       stopf(paste0("First batch from dataset is incompatible with the provided shape of %s:\n",
-        "* Provided shape: %s.\n* Expected shape: %s."), name,
-        shape_to_str(unname(shapes[name])), shape_to_str(list(expected_shape)))
+        "* Provided shape: %s.\n* Observed shape: %s."), name,
+        shape_to_str(unname(shapes[name])), shape_to_str(list(observed_shape)))
     }
   })
 }
diff --git a/R/LearnerTorch.R b/R/LearnerTorch.R
index af1db6e5e..068d54a84 100644
--- a/R/LearnerTorch.R
+++ b/R/LearnerTorch.R
@@ -109,7 +109,7 @@
 #'
 #'   For information on the expected target encoding of `y`, see section *Network Head and Target Encoding*.
 #'   Moreover, one needs to pay attention respect the row ids of the provided task.
-#'   It is recommended to relu on [`task_dataset`] for creating the [`dataset`][torch::dataset].
+#'   It is strongly recommended to use the [`task_dataset`] class to create the dataset.
 #'
 #' It is also possible to overwrite the private `.dataloader()` method.
 #' This must respect the dataloader parameters from the [`ParamSet`][paradox::ParamSet].
diff --git a/R/lazy_tensor.R b/R/lazy_tensor.R
index d050f8545..00b397575 100644
--- a/R/lazy_tensor.R
+++ b/R/lazy_tensor.R
@@ -197,6 +197,19 @@ as_lazy_tensor.torch_tensor = function(x, ...) { # nolint
   as_lazy_tensor(ds, dataset_shapes = list(x = c(NA, dim(x)[-1])))
 }
 
+# S3 generic that converts `x` to lazy tensors; dispatches on the class of `x`
+# and passes the arguments in `...` on to the selected method.
+#' @export
+as_lazy_tensors = function(x, ...) UseMethod("as_lazy_tensors")
+
+#' @export
+as_lazy_tensors.dataset = function(x, dataset_shapes = NULL, ...) {
+  dataset_shapes = get_or_check_dataset_shapes(x, dataset_shapes)
+  # one lazy tensor column per dataset element, each pointing at "its" element via `input_map`
+  ids = names(dataset_shapes)
+  set_names(map_dtc(ids, function(id) as_lazy_tensor(x, dataset_shapes = dataset_shapes, input_map = id)), ids)
+}
+
 #' Assert Lazy Tensor
 #'
 #' Asserts whether something is a lazy tensor.
@@ -339,3 +352,13 @@ rep.lazy_tensor = function(x, ...) {
 rep_len.lazy_tensor = function(x, ...) {
   set_class(NextMethod(), c("lazy_tensor", "list"))
 }
+
+
+#' @export
+distinct_values.lazy_tensor = function(x, drop = TRUE, na_rm = TRUE) {
+  # elements are compared via their ids (first entry of each element), not via the tensors
+  if (length(x) == 0L) {
+    return(x)
+  }
+  lazy_tensor(dd(x), distinct_values(map_int(x, 1)))
+}
\ No newline at end of file
diff --git a/R/learner_torch_methods.R b/R/learner_torch_methods.R
index 79cebaa4a..7259bf587 100644
--- a/R/learner_torch_methods.R
+++ b/R/learner_torch_methods.R
@@ -18,8 +18,10 @@ learner_torch_predict = function(self, private, super, task, param_vals) {
   private$.encode_prediction(predict_tensor = predict_tensor, task = task)
 }
 
+
 learner_torch_train = function(self, private, super, task, param_vals) {
   # Here, all param_vals (like seed = "random" or device = "auto") have already been resolved
+  check_lazy_tensors_backend(task$backend, c(task$feature_names, task$target_names))
   dataset_train = private$.dataset(task, param_vals)
   dataset_train = as_multi_tensor_dataset(dataset_train, param_vals)
   loader_train = private$.dataloader(dataset_train, param_vals)
@@ -356,3 +358,5 @@ as_multi_tensor_dataset = function(dataset, param_vals) {
     dataset
   }
 }
+
+
diff --git a/R/materialize.R b/R/materialize.R
index 849024ad4..ee113830d 100644
--- a/R/materialize.R
+++ b/R/materialize.R
@@ -63,6 +63,13 @@ materialize.list = function(x, device = "cpu", rbind = FALSE, cache = "auto", ..
 
   map(x, function(col) {
     if (is_lazy_tensor(col)) {
+      if (length(col) == 0L) {
+        if (rbind) {
+          return(torch_empty(0L))
+        } else {
+          return(list())
+        }
+      }
       materialize_internal(col, device = device, cache = cache, rbind = rbind)
     } else {
       col
@@ -76,16 +83,30 @@ materialize.list = function(x, device = "cpu", rbind = FALSE, cache = "auto", ..
 #' @method materialize data.frame
 #' @export
 materialize.data.frame = function(x, device = "cpu", rbind = FALSE, cache = "auto", ...) { # nolint
+  if (nrow(x) == 0L) {
+    # Zero rows: return right away instead of discarding the result and falling through.
+    # As in `materialize.list()`, only lazy tensor columns are replaced by an empty
+    # tensor (`rbind = TRUE`) or an empty list; all other columns are passed through.
+    empty = function() if (rbind) torch_empty(0L) else list()
+    return(map(as.list(x), function(col) if (is_lazy_tensor(col)) empty() else col))
+  }
   materialize(as.list(x), device = device, rbind = rbind, cache = cache)
 }
 
 
 #' @export
 materialize.lazy_tensor = function(x, device = "cpu", rbind = FALSE, ...) { # nolint
+  if (length(x) == 0L) {
+    # nothing to load: an empty lazy tensor becomes an empty tensor (`rbind = TRUE`)
+    # or an empty list (`rbind = FALSE`)
+    return(if (rbind) torch_empty(0L) else list())
+  }
+  # non-empty input is delegated to the internal implementation;
+  # no cache is handed over (`cache = NULL`)
   materialize_internal(x = x, device = device, cache = NULL, rbind = rbind)
 }
 
-get_input = function(ds, ids, varying_shapes, rbind) {
+get_input = function(ds, ids, varying_shapes) {
   if (is.null(ds$.getbatch)) { # .getindex is never NULL but a function that errs if it was not defined
     x = map(ids, function(id) map(ds$.getitem(id), function(x) x$unsqueeze(1)))
     if (varying_shapes) {
@@ -154,9 +175,6 @@ get_output = function(input, graph, varying_shapes, rbind, device) {
 #' @return [`lazy_tensor()`]
 #' @keywords internal
 materialize_internal = function(x, device = "cpu", cache = NULL, rbind) {
-  if (!length(x)) {
-    stopf("Cannot materialize lazy tensor of length 0.")
-  }
   do_caching = !is.null(cache)
   ids = map_int(x, 1)
 
@@ -183,7 +201,7 @@ materialize_internal = function(x, device = "cpu", cache = NULL, rbind) {
   }
 
   if (!do_caching || !input_hit) {
-    input = get_input(ds, ids, varying_shapes, rbind)
+    input = get_input(ds, ids, varying_shapes)
   }
 
   if (do_caching && !input_hit) {
diff --git a/R/shape.R b/R/shape.R
index d1fdda83d..7970c37ec 100644
--- a/R/shape.R
+++ b/R/shape.R
@@ -30,7 +30,7 @@ test_shape = function(shape, null_ok = FALSE, unknown_batch = NULL, len = NULL)
   if (is.null(shape) && null_ok) {
     return(TRUE)
   }
-  ok = test_integerish(shape, min.len = 2L, all.missing = FALSE, any.missing = TRUE, len = len)
+  ok = test_integerish(shape, min.len = 1L, any.missing = TRUE, len = len)
 
   if (!ok) {
     return(FALSE)
diff --git a/R/task_dataset.R b/R/task_dataset.R
index bd088d1bf..10ab6e268 100644
--- a/R/task_dataset.R
+++ b/R/task_dataset.R
@@ -81,13 +81,21 @@ task_dataset = dataset("task_dataset",
   .getbatch = function(index) {
     cache = if (self$cache_lazy_tensors) new.env()
 
-    datapool = self$task$data(rows = self$task$row_ids[index], cols = self$all_features)
+    datapool = withr::with_options(list(mlr3torch.data_loading = TRUE), {
+      self$task$data(rows = self$task$row_ids[index], cols = self$all_features)
+    })
+
     x = lapply(self$feature_ingress_tokens, function(it) {
       it$batchgetter(datapool[, it$features, with = FALSE], cache = cache)
     })
 
     y = if (!is.null(self$target_batchgetter)) {
-      self$target_batchgetter(datapool[, self$task$target_names, with = FALSE])
+      target = datapool[, self$task$target_names, with = FALSE]
+      if (!inherits(target[[1L]], "lazy_tensor")) {
+        self$target_batchgetter(target)
+      } else {
+        materialize(target[[1L]], rbind = TRUE)
+      }
     }
     out = list(x = x, .index = torch_tensor(index, dtype = torch_long()))
     if (!is.null(y)) out$y = y
diff --git a/R/utils.R b/R/utils.R
index 2f9d68302..bcb17af66 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -190,7 +190,10 @@ list_to_batch = function(tensors) {
 }
 
 auto_cache_lazy_tensors = function(lts) {
-  any(duplicated(map_chr(lts, function(x) dd(x)$dataset_hash)))
+  # fewer than two lazy tensors can never share a dataset
+  if (length(lts) <= 1L) return(FALSE)
+  hashes = unlist(map_if(lts, function(x) length(x) > 0, function(x) dd(x)$dataset_hash))
+  anyDuplicated(hashes) > 0L
 }
 
 #' Replace the head of a network
@@ -300,6 +303,18 @@ infer_shapes = function(shapes_in, param_vals, output_names, fn, rowwise, id) {
   set_names(list(sout), output_names)
 }
 
+get_or_check_dataset_shapes = function(dataset, dataset_shapes) {
+  # user-supplied shapes are only validated against the dataset and returned as they are
+  if (!is.null(dataset_shapes)) {
+    assert_compatible_shapes(dataset_shapes, dataset)
+    return(dataset_shapes)
+  }
+  if (is.null(dataset$.getbatch)) {
+    stopf("dataset_shapes must be provided if dataset does not have a `.getbatch` method.")
+  }
+  infer_shapes_from_getbatch(dataset)
+}
+
 #' @title Network Output Dimension
 #' @description
 #' Calculates the output dimension of a neural network for a given task that is expected by
diff --git a/TODO.md b/TODO.md
new file mode 100644
index 000000000..62d953809
--- /dev/null
+++ b/TODO.md
@@ -0,0 +1,25 @@
+* Add `as_lazy_tensors()`
+* Make it easier to se
+* Fix the bug that the shapes are reported as unknown below and make the code easier.
+  ```r
+  ds = dataset("test",
+    initialize = function() {
+      self$x = torch_randn(100, 10)
+      self$y = torch_randn(100, 1)
+    },
+    .getitem = function(i) {
+      list(x = self$x[i, ], y = self$y[i])
+    },
+    .length = function() {
+      nrow(self$x)
+    }
+  )()
+  x_lt = as_lazy_tensor(ds, list(x = c(NA, 10), y = c(NA, 1)), input_map = "x")
+  y_lt = as_lazy_tensor(ds, list(x = c(NA, 10), y = c(NA, 1)), input_map = "y")
+
+  tbl = data.table(x = x_lt, y = y_lt)
+  ```
+* Add checks on usage of `DataBackendLazyTensors` in `task_dataset`
+* Add an optimization so that truth values don't have to be loaded twice during resampling, i.e.
+  once for making the predictions and once for retrieving the truth column.
+* only allow caching converter columns in `DataBackendLazyTensors` (probably just remove the `cache` parameter)
\ No newline at end of file
diff --git a/man/DataBackendLazyTensors.Rd b/man/DataBackendLazyTensors.Rd
new file mode 100644
index 000000000..42880dec7
--- /dev/null
+++ b/man/DataBackendLazyTensors.Rd
@@ -0,0 +1,123 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/DataBackendLazyTensors.R
+\name{DataBackendLazyTensors}
+\alias{DataBackendLazyTensors}
+\title{Special Backend for Lazy Tensors}
+\description{
+This backend essentially allows you to use a \code{\link[torch:dataset]{torch::dataset}} directly with
+an \code{\link[mlr3:Learner]{mlr3::Learner}}.
+\itemize{
+\item The data cannot contain missing values, as \code{\link{lazy_tensor}}s do not support them.
+For this reason, calling \verb{$missings()} will always return \code{0} for all columns.
+\item The \verb{$distinct()} method will consider two lazy tensors that refer to the same element of a
+\code{\link{DataDescriptor}} to be identical.
+This means, that it might be underreporting the number of distinct values of lazy tensor columns.
+}
+}
+\examples{
+\dontshow{if (torch::torch_is_installed()) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf}
+# used as feature in all backends
+x = torch_randn(100, 10)
+# regression
+ds_regr = tensor_dataset(x = x, y = torch_randn(100, 1))
+be_regr = as_data_backend(ds_regr, converter = list(y = as.numeric))
+be_regr$head()
+
+
+# binary classification: underlying target tensor must be float in [0, 1]
+ds_binary = tensor_dataset(x = x, y = torch_randint(0, 2, c(100, 1))$float())
+be_binary = as_data_backend(ds_binary, converter = list(
+  y = function(x) factor(as.integer(x), levels = c(0, 1), labels = c("A", "yes"))
+))
+be_binary$head()
+
+# multi-class classification: underlying target tensor must be integer in [1, K]
+ds_multiclass = tensor_dataset(x = x, y = torch_randint(1, 4, size = c(100, 1)))
+be_multiclass = as_data_backend(ds_multiclass, converter = list(y = as.numeric))
+be_multiclass$head()
+\dontshow{\}) # examplesIf}
+}
+\section{Super classes}{
+\code{\link[mlr3:DataBackend]{mlr3::DataBackend}} -> \code{\link[mlr3:DataBackendDataTable]{mlr3::DataBackendDataTable}} -> \code{DataBackendLazyTensors}
+}
+\section{Methods}{
+\subsection{Public methods}{
+\itemize{
+\item \href{#method-DataBackendLazyTensors-new}{\code{DataBackendLazyTensors$new()}}
+\item \href{#method-DataBackendLazyTensors-data}{\code{DataBackendLazyTensors$data()}}
+\item \href{#method-DataBackendLazyTensors-head}{\code{DataBackendLazyTensors$head()}}
+\item \href{#method-DataBackendLazyTensors-missings}{\code{DataBackendLazyTensors$missings()}}
+}
+}
+\if{html}{\out{
+Inherited methods
+
+