@@ -31,8 +31,10 @@ def __init__(self, cache: Cache, seed: int, drop_last: bool):
3131 self .drop_last = drop_last
3232
3333 @lru_cache (maxsize = 10 )
34- def get_len (self , distributed_env : _DistributedEnv , current_epoch : int ) -> int :
35- _ , intervals_per_ranks = self .get_chunks_and_intervals_per_ranks (distributed_env , current_epoch )
34+ def get_len (self , distributed_env : _DistributedEnv , num_workers : int , batch_size : int , current_epoch : int ) -> int :
35+ _ , intervals_per_ranks = self .get_chunks_and_intervals_per_ranks (
36+ distributed_env , num_workers , batch_size , current_epoch
37+ )
3638
3739 if self .drop_last :
3840 items_per_process = [
@@ -46,7 +48,9 @@ def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
4648 return sum ((interval [- 1 ] - interval [0 ]) for interval in intervals_per_ranks [distributed_env .global_rank ])
4749
4850 @abstractmethod
49- def get_chunks_and_intervals_per_ranks (self , distributed_env : _DistributedEnv , current_epoch : int ) -> Any :
51+ def get_chunks_and_intervals_per_ranks (
52+ self , distributed_env : _DistributedEnv , num_workers : int , batch_size : int , current_epoch : int
53+ ) -> Any :
5054 pass
5155
5256 @abstractmethod
@@ -59,14 +63,16 @@ class NoShuffle(Shuffle):
5963 is True."""
6064
6165 @lru_cache (maxsize = 10 )
62- def get_chunks_and_intervals_per_ranks (self , distributed_env : _DistributedEnv , current_epoch : int ) -> Any :
66+ def get_chunks_and_intervals_per_ranks (
67+ self , distributed_env : _DistributedEnv , num_workers : int , batch_size : int , current_epoch : int
68+ ) -> Any :
6369 # 1. Get the intervals
6470 chunk_intervals = self .cache .get_chunk_intervals ()
6571 indexes = range (len (chunk_intervals ))
6672
6773 # 2. Compute the items budget of each rank
6874 chunks_per_ranks , intervals_per_ranks = _associate_chunks_and_internals_to_ranks (
69- distributed_env , indexes , chunk_intervals , self .drop_last
75+ distributed_env , indexes , chunk_intervals , self .drop_last , num_workers , batch_size
7076 )
7177
7278 return chunks_per_ranks , intervals_per_ranks
@@ -94,7 +100,9 @@ class FullShuffle(Shuffle):
94100 """
95101
96102 @lru_cache (maxsize = 10 )
97- def get_chunks_and_intervals_per_ranks (self , distributed_env : _DistributedEnv , current_epoch : int ) -> Any :
103+ def get_chunks_and_intervals_per_ranks (
104+ self , distributed_env : _DistributedEnv , num_workers : int , batch_size : int , current_epoch : int
105+ ) -> Any :
98106 # 1. Get the intervals
99107 chunk_intervals = self .cache .get_chunk_intervals ()
100108
@@ -113,7 +121,7 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, c
113121
114122 # 3. Compute the items budget of each rank
115123 chunks_per_ranks , intervals_per_ranks = _associate_chunks_and_internals_to_ranks (
116- distributed_env , shuffled_indexes , shuffled_chunk_intervals , self .drop_last
124+ distributed_env , shuffled_indexes , shuffled_chunk_intervals , self .drop_last , num_workers , batch_size
117125 )
118126
119127 # For the first epoch, no need of further shuffling
@@ -126,7 +134,7 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, c
126134 shuffled_chunk_intervals = np .asarray (chunk_intervals )[shuffled_indexes ].tolist ()
127135
128136 chunks_per_ranks , intervals_per_ranks = _associate_chunks_and_internals_to_ranks (
129- distributed_env , shuffled_indexes , shuffled_chunk_intervals , self .drop_last
137+ distributed_env , shuffled_indexes , shuffled_chunk_intervals , self .drop_last , num_workers , batch_size
130138 )
131139
132140 return chunks_per_ranks , intervals_per_ranks
0 commit comments