Commit 46e7eef

Add XLMR and RoBERTa transforms as factory functions (#2102)
* Add XLMR and RoBERTa transforms as classes
* Remove unused import
* Add roberta transform and xlmr transform as factory methods
1 parent 9d42632 commit 46e7eef

File tree

1 file changed: +29 -43 lines changed

torchtext/models/roberta/bundler.py (+29 -43)
@@ -160,16 +160,35 @@ def encoderConf(self) -> RobertaEncoderConf:
         return self._encoder_conf
 
 
-XLMR_BASE_ENCODER = RobertaBundle(
-    _path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
-    _encoder_conf=RobertaEncoderConf(vocab_size=250002),
-    transform=lambda: T.Sequential(
+def xlmr_transform(truncate_length: int) -> Module:
+    """Standard transform for XLMR models."""
+    return T.Sequential(
         T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
         T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
-        T.Truncate(254),
+        T.Truncate(truncate_length),
         T.AddToken(token=0, begin=True),
         T.AddToken(token=2, begin=False),
-    ),
+    )
+
+
+def roberta_transform(truncate_length: int) -> Module:
+    """Standard transform for RoBERTa models."""
+    return T.Sequential(
+        T.GPT2BPETokenizer(
+            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
+            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
+        ),
+        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
+        T.Truncate(truncate_length),
+        T.AddToken(token=0, begin=True),
+        T.AddToken(token=2, begin=False),
+    )
+
+
+XLMR_BASE_ENCODER = RobertaBundle(
+    _path=urljoin(_TEXT_BUCKET, "xlmr.base.encoder.pt"),
+    _encoder_conf=RobertaEncoderConf(vocab_size=250002),
+    transform=lambda: xlmr_transform(254),
 )
 
 XLMR_BASE_ENCODER.__doc__ = """
@@ -193,13 +212,7 @@ def encoderConf(self) -> RobertaEncoderConf:
     _encoder_conf=RobertaEncoderConf(
         vocab_size=250002, embedding_dim=1024, ffn_dimension=4096, num_attention_heads=16, num_encoder_layers=24
     ),
-    transform=lambda: T.Sequential(
-        T.SentencePieceTokenizer(urljoin(_TEXT_BUCKET, "xlmr.sentencepiece.bpe.model")),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "xlmr.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: xlmr_transform(510),
 )
 
 XLMR_LARGE_ENCODER.__doc__ = """
@@ -221,16 +234,7 @@ def encoderConf(self) -> RobertaEncoderConf:
 ROBERTA_BASE_ENCODER = RobertaBundle(
     _path=urljoin(_TEXT_BUCKET, "roberta.base.encoder.pt"),
     _encoder_conf=RobertaEncoderConf(vocab_size=50265),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(254),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: roberta_transform(254),
 )
 
 ROBERTA_BASE_ENCODER.__doc__ = """
@@ -263,16 +267,7 @@ def encoderConf(self) -> RobertaEncoderConf:
         num_attention_heads=16,
         num_encoder_layers=24,
     ),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: roberta_transform(510),
 )
 
 ROBERTA_LARGE_ENCODER.__doc__ = """
@@ -302,16 +297,7 @@ def encoderConf(self) -> RobertaEncoderConf:
         num_encoder_layers=6,
         padding_idx=1,
     ),
-    transform=lambda: T.Sequential(
-        T.GPT2BPETokenizer(
-            encoder_json_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_encoder.json"),
-            vocab_bpe_path=urljoin(_TEXT_BUCKET, "gpt2_bpe_vocab.bpe"),
-        ),
-        T.VocabTransform(load_state_dict_from_url(urljoin(_TEXT_BUCKET, "roberta.vocab.pt"))),
-        T.Truncate(510),
-        T.AddToken(token=0, begin=True),
-        T.AddToken(token=2, begin=False),
-    ),
+    transform=lambda: roberta_transform(510),
 )
 
 ROBERTA_DISTILLED_ENCODER.__doc__ = """
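
For context, a minimal usage sketch of a refactored bundle, following the documented torchtext bundle API (get_model() / transform()). The input strings are illustrative only, and this assumes a torchtext build that includes this commit:

# Minimal usage sketch: the bundle's `transform` lambda now defers to the
# shared xlmr_transform factory, so the pipeline (and the download of its
# vocab/tokenizer assets) is only built when .transform() is actually called.
import torch
import torchtext

xlmr_base = torchtext.models.XLMR_BASE_ENCODER
model = xlmr_base.get_model()
transform = xlmr_base.transform()  # builds xlmr_transform(254) on demand

input_batch = ["Hello world", "How are you!"]  # illustrative inputs
model_input = torchtext.functional.to_tensor(transform(input_batch), padding_value=1)
output = model(model_input)  # contextual features, one vector per token

Keeping the factory call wrapped in a lambda preserves the previous lazy behavior: nothing is fetched at import time, while the five bundles now share two transform definitions instead of repeating the same T.Sequential pipeline inline.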
