diff --git a/src/ldp/nn/handlers/chunking.py b/src/ldp/nn/handlers/chunking.py index 9369d6d3..543d099e 100644 --- a/src/ldp/nn/handlers/chunking.py +++ b/src/ldp/nn/handlers/chunking.py @@ -153,19 +153,15 @@ def _split_value(self, value): Right now, only torch.Tensor values are split. Non-tensor values are replicated. """ if isinstance(value, torch.Tensor): - chunks = list(torch.chunk(value, self.num_chunks, dim=0)) - dummy_chunk_flags = [] - for i in range(self.num_chunks): - if i >= len(chunks): - # Chunk 0 will always exist, and we need only a batch of one ([:1]) - # to activate the model. - # We use the first element of the existing chunks as real data to avoid - # errors in the model that may expect a specific token structure. - chunks.append(chunks[0][:1]) - dummy_chunk_flags.append(True) - else: - dummy_chunk_flags.append(False) - - return chunks, dummy_chunk_flags + chunks = torch.chunk(value, self.num_chunks, dim=0) + num_real_chunks = len(chunks) + dummy_chunk_flags = [False] * num_real_chunks + + if num_real_chunks < self.num_chunks: + dummy_chunk_flags.extend([True] * (self.num_chunks - num_real_chunks)) + dummy_chunk = chunks[0][:1] + chunks = chunks + (dummy_chunk,) * (self.num_chunks - num_real_chunks) + + return list(chunks), dummy_chunk_flags # Non-tensor values are replicated return [value] * self.num_chunks, None