diff --git a/src/ldp/nn/handlers/chunking.py b/src/ldp/nn/handlers/chunking.py index 9369d6d3..283c9dc1 100644 --- a/src/ldp/nn/handlers/chunking.py +++ b/src/ldp/nn/handlers/chunking.py @@ -154,18 +154,18 @@ def _split_value(self, value): """ if isinstance(value, torch.Tensor): chunks = list(torch.chunk(value, self.num_chunks, dim=0)) - dummy_chunk_flags = [] - for i in range(self.num_chunks): - if i >= len(chunks): - # Chunk 0 will always exist, and we need only a batch of one ([:1]) - # to activate the model. - # We use the first element of the existing chunks as real data to avoid - # errors in the model that may expect a specific token structure. - chunks.append(chunks[0][:1]) - dummy_chunk_flags.append(True) - else: - dummy_chunk_flags.append(False) + real_chunks_len = len(chunks) + + # Pre-determining if dummy chunks are needed and their count + if real_chunks_len < self.num_chunks: + dummy_count = self.num_chunks - real_chunks_len + dummy_chunks = [chunks[0][:1]] * dummy_count + dummy_chunk_flags = [False] * real_chunks_len + [True] * dummy_count + chunks.extend(dummy_chunks) + else: + dummy_chunk_flags = [False] * self.num_chunks return chunks, dummy_chunk_flags + # Non-tensor values are replicated return [value] * self.num_chunks, None