@@ -160,16 +160,35 @@ def encoderConf(self) -> RobertaEncoderConf:
         return self._encoder_conf
 
 
-XLMR_BASE_ENCODER = RobertaBundle(
-    _path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
-    _encoder_conf=RobertaEncoderConf(vocab_size=250002),
-    transform=lambda: T.Sequential(
+def xlmr_transform(truncate_length: int) -> Module:
+    """Standard transform for XLMR models."""
+    return T.Sequential(
         T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
         T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
-        T.Truncate(254),
+        T.Truncate(truncate_length),
         T.AddToken(token=0, begin=True),
         T.AddToken(token=2, begin=False),
-    ),
+    )
+
+
+def roberta_transform(truncate_length: int) -> Module:
+    """Standard transform for RoBERTa models."""
+    return T.Sequential(
+        T.GPT2BPETokenizer(
+            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
+            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
+        ),
+        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
+        T.Truncate(truncate_length),
+        T.AddToken(token=0, begin=True),
+        T.AddToken(token=2, begin=False),
+    )
+
+
+XLMR_BASE_ENCODER = RobertaBundle(
+    _path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
+    _encoder_conf=RobertaEncoderConf(vocab_size=250002),
+    transform=lambda: xlmr_transform(254),
 )
 
 XLMR_BASE_ENCODER.__doc__ = """
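The two helpers factored out above differ only in tokenizer and vocab; the truncation length is the sole per-bundle parameter (254 for base checkpoints, 510 for large ones). A minimal sketch of calling the new helper directly, assuming this module's imports (`torchtext.transforms as T`, `urljoin`, `_TEXT_BUCKET`) are in scope; the token ids and lengths come from the diff itself:

```python
# Sketch: standalone use of the helper introduced in this diff.
# xlmr_transform(254) tokenizes with SentencePiece, maps tokens to vocab ids,
# truncates to 254 ids, then adds BOS (0) at the front and EOS (2) at the end.
transform = xlmr_transform(254)

ids = transform(["Hello World"])  # -> [[0, ..., 2]]
assert len(ids[0]) <= 256         # 254 tokens + BOS + EOS
```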
@@ -193,13 +212,7 @@ def encoderConf(self) -> RobertaEncoderConf:
     _encoder_conf=RobertaEncoderConf(
         vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24
     ),
-    transform=lambda: T.Sequential(
-        T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: xlmr_transform(510),
 )
 
 XLMR_LARGE_ENCODER.__doc__ = """
@@ -221,16 +234,7 @@ def encoderConf(self) -> RobertaEncoderConf:
 ROBERTA_BASE_ENCODER = RobertaBundle(
     _path=urljoin(_TEXT_BUCKET, "roberta.base.encoder.pt"),
     _encoder_conf=RobertaEncoderConf(vocab_size=50265),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(254),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: roberta_transform(254),
 )
 
 ROBERTA_BASE_ENCODER.__doc__ = """
@@ -263,16 +267,7 @@ def encoderConf(self) -> RobertaEncoderConf:
         num_attention_heads=16,
         num_encoder_layers=24,
     ),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: roberta_transform(510),
 )
 
 ROBERTA_LARGE_ENCODER.__doc__ = """
@@ -302,16 +297,7 @@ def encoderConf(self) -> RobertaEncoderConf:
         num_encoder_layers=6,
         padding_idx=1,
     ),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: roberta_transform(510),
 )
 
 ROBERTA_DISTILLED_ENCODER.__doc__ = """
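End-to-end behavior of the bundles is unchanged by this refactor; each `transform` lambda now simply delegates to one of the two helpers. A sketch of typical bundle usage via the public torchtext bundle API (`get_model()` and `transform()`); the output width (768 for the base XLMR config) is a property of the checkpoint, not of this diff:

```python
import torch
import torchtext

# Sketch: exercising XLMR_BASE_ENCODER after the refactor.
xlmr_base = torchtext.models.XLMR_BASE_ENCODER
model = xlmr_base.get_model()      # weights from xlmr.base.encoder.pt
transform = xlmr_base.transform()  # now: lambda: xlmr_transform(254)

model_input = torch.tensor(transform(["Hello World"]))
output = model(model_input)        # shape: (1, num_tokens, 768)
```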