22import math
33import os
44import random
5+ from collections .abc import Callable
56from pathlib import Path
67from typing import Any
78
@@ -114,14 +115,57 @@ def split_files(datapath: str, ratio: float = 0.9, seed: int = 0):
114115 return train_files , val_files
115116
116117
118+ # Data standardization functions
119+ StandardizeFnSig = Callable [[dict [str , Any ]], dict [str , Any ]]
120+
121+
def standardize_data_v0(data: dict[str, Any]) -> dict[str, Any]:
    """Convert a v0-format sample dict to the standardized key layout.

    v0 data format:
        {
            "input_ids": [seq_len],
            "loss_mask": [seq_len],
            "hidden_state": [seq_len, 3 * hidden_size],
            "target": [seq_len, hidden_size],
        }

    Only the four keys above are carried over; any extra keys in *data*
    are dropped. Raises KeyError if one of them is missing.
    """
    # standardized key -> v0 key it is sourced from
    key_map = {
        "hidden_states": "hidden_state",
        "input_ids": "input_ids",
        "verifier_last_hidden_states": "target",
        "loss_mask": "loss_mask",
    }
    return {out_key: data[src_key] for out_key, src_key in key_map.items()}
138+
def standardize_data_v1(data: dict[str, Any]) -> dict[str, Any]:
    """Convert a v1-format sample dict to the standardized key layout.

    v1 data format:
        {
            "input_ids": [seq_len],
            "loss_mask": [seq_len],
            "hidden_states": [
                [seq_len, hidden_size],
                [seq_len, hidden_size],
                [seq_len, hidden_size],
                ...
            ],
        }

    All hidden-state tensors except the last are concatenated along the
    feature (last) dimension to form the draft input; the final tensor is
    the verifier target.

    Raises:
        ValueError: if fewer than two hidden-state tensors are present —
            without this guard ``torch.cat([])`` would raise a cryptic
            RuntimeError for such malformed samples.
        KeyError: if a required key is missing from *data*.
    """
    hidden_states = data["hidden_states"]
    if len(hidden_states) < 2:
        raise ValueError(
            "standardize_data_v1 requires at least 2 hidden-state tensors "
            f"(got {len(hidden_states)}): the last one is the verifier "
            "target and the rest are concatenated as the draft input"
        )

    return {
        "hidden_states": torch.cat(hidden_states[:-1], dim=-1),
        "input_ids": data["input_ids"],
        "verifier_last_hidden_states": hidden_states[-1],
        "loss_mask": data["loss_mask"],
    }
158+
159+
117160class Eagle3SampleFileDataset (Dataset ):
118161 def __init__ (
119162 self ,
120163 max_len : int ,
121164 datapath : str | None = None ,
122165 file_list : list [str ] | None = None ,
123- transform = None ,
166+ transform : TransformTensors | None = None ,
124167 hidden_states_dtype = torch .float ,
168+ standardize_fn : StandardizeFnSig = standardize_data_v1 ,
125169 ):
126170 if datapath is not None and file_list is not None :
127171 raise ValueError ("Only one of datapath or file_list may be provided" )
@@ -134,6 +178,7 @@ def __init__(
134178 self .data : list [str ] = file_list
135179 self .max_len = max_len
136180 self .transform = transform
181+ self .standardize_fn = standardize_fn
137182 self .hidden_states_dtype = hidden_states_dtype
138183 self .approx_lengths = self ._compute_approx_lengths ()
139184
@@ -155,24 +200,24 @@ def _compute_approx_lengths(self) -> list[int]:
155200 def __getitem__ (self , index ) -> BatchType :
156201 data = torch .load (self .data [index ])
157202
158- # todo: standardize names during data generation and then remove this
159- data ["hidden_states" ] = data ["hidden_state" ]
160- data ["verifier_last_hidden_states" ] = data ["target" ]
161- del data ["hidden_state" ]
162- del data ["target" ]
203+ data = self .standardize_fn (data )
204+ # data structure: {
205+ # "hidden_states": [seq_len, 3 * hidden_size],
206+ # "input_ids": [seq_len],
207+ # "verifier_last_hidden_states": [seq_len, hidden_size],
208+ # "loss_mask": [seq_len],
209+ # }
163210
164- # todo: standardize dtypes during data generation and then remove this
211+ # Convert hidden states to the correct dtype
165212 data = {
166213 k : v .to (self .hidden_states_dtype ) if "hidden_states" in k else v
167214 for k , v in data .items ()
168215 }
169216
170- seq_len = data ["input_ids" ].shape [0 ]
171217 # Add lengths tensor
218+ seq_len = data ["input_ids" ].shape [0 ]
172219 data ["lengths" ] = torch .tensor ([seq_len ], dtype = torch .long )
173-
174- if self .transform :
175- data = self .transform (data )
220+ # shape: [1]
176221
177222 data ["position_ids" ] = torch .arange (seq_len , dtype = torch .long )
178223 # shape: [seq_len]
@@ -186,6 +231,10 @@ def __getitem__(self, index) -> BatchType:
186231 # "position_ids": [seq_len],
187232 # }
188233
234+ # Apply transform
235+ if self .transform :
236+ data = self .transform (data )
237+
189238 # Note: shift_batch will reduce seq_len by 1
190239 return shift_batch (data )
191240
@@ -194,15 +243,20 @@ def create_collate_fn(max_len: int):
194243 def collate_fn (batch : list [BatchType ]) -> BatchType :
195244 collated_data = {}
196245 for key in batch [0 ]:
246+ # Concatenate the tensors along the seq (0th) dimension
197247 collated_data [key ] = torch .cat ([b [key ] for b in batch ], dim = 0 )
248+ # shape: [total_seq_len, ...]
198249
199250 if key != "lengths" :
251+ # Slice and pad on seq (0th) dimension to max_len
200252 collated_data [key ] = slice_and_pad_to_length (
201253 collated_data [key ], max_len
202254 ).unsqueeze (0 )
203255 # shape: [1, max_len, ...]
204256
205- # Handle lengths update
257+ # Include lengths while they fit in max_len
258+ # The last included length is (if necessary) truncated
259+ # Any additional lengths are discarded
206260 lengths = collated_data ["lengths" ]
207261 new_lengths = []
208262 cum_length = 0
0 commit comments