Added support for electra tokenizer as a special case of BERT tokenizer (#28)

* Added support for electra tokenizer as a special case of BERT tokenizer

* Updated based on review

* formatted with black

* fix: Added correct tests

* re-enable skips when hf is not installed (these were disabled for testing)
KennethEnevoldsen authored Mar 24, 2024
1 parent 246a5b3 commit e8657ff
Showing 2 changed files with 41 additions and 3 deletions.
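For orientation before the diff: a minimal usage sketch of the new code path, modeled on the test added below. The hf_loader import path comes from the changed file; the wordpiece_encoder import path and the example text are assumptions, not part of the commit.

# Sketch: initialize the WordPiece piece encoder from an ELECTRA checkpoint.
# The wordpiece_encoder module path below is assumed; only the function names
# and the checkpoint name come from the diff.
import spacy
from spacy_curated_transformers.tokenization.hf_loader import (
    build_hf_piece_encoder_loader_v1,
)
from spacy_curated_transformers.tokenization.wordpiece_encoder import (
    build_wordpiece_encoder_v1,
)

nlp = spacy.blank("en")
docs = [nlp("This is a short example sentence.")]

encoder = build_wordpiece_encoder_v1()
encoder.init = build_hf_piece_encoder_loader_v1(
    name="google/electra-small-discriminator"
)
encoder.initialize()

pieces = encoder.predict(docs)  # one Ragged of piece IDs per Doc
print(pieces[0].dataXd)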
@@ -49,6 +49,38 @@ def test_wordpiece_encoder_hf_model(sample_docs):
     )


+@pytest.mark.slow
+@pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
+def test_wordpiece_encoder_hf_model_w_electra(sample_docs):
+    ops = get_current_ops()
+    encoder = build_wordpiece_encoder_v1()
+    encoder.init = build_hf_piece_encoder_loader_v1(
+        name="google/electra-small-discriminator"
+    )
+    encoder.initialize()
+
+    encoding = encoder.predict(sample_docs)
+
+    assert isinstance(encoding, list)
+    assert len(encoding) == 2
+
+    assert isinstance(encoding[0], Ragged)
+    ops.xp.testing.assert_array_equal(
+        encoding[0].lengths, [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+    )
+    print(encoding[0].dataXd)
+    ops.xp.testing.assert_array_equal(
+        encoding[0].dataXd, [101, 1045, 2387, 1037, 2611, 2007, 1037, 12772, 1012, 102]
+    )
+
+    ops.xp.testing.assert_array_equal(encoding[1].lengths, [1, 1, 1, 1, 1, 1, 1, 1, 1])
+    ops.xp.testing.assert_array_equal(
+        encoding[1].dataXd,
+        [101, 2651, 2057, 2097, 4521, 26202, 4605, 1012, 102],
+    )
+    print(encoding[1].dataXd)
+
+
 @pytest.mark.slow
 @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
 def test_wordpiece_encoder_hf_model_uncased(sample_docs):
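As a side check (not part of the commit), the expected piece IDs asserted above can be decoded with the same checkpoint's tokenizer, assuming huggingface transformers is installed:

# Hypothetical check: map the asserted piece IDs back to WordPiece strings.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
ids = [101, 1045, 2387, 1037, 2611, 2007, 1037, 12772, 1012, 102]
print(tok.convert_ids_to_tokens(ids))  # piece strings bounded by [CLS] ... [SEP]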
12 changes: 9 additions & 3 deletions spacy_curated_transformers/tokenization/hf_loader.py
@@ -1,5 +1,5 @@
 import json
-from typing import Callable, Optional
+from typing import Callable, Optional, Union

 from .._compat import has_hf_transformers, transformers
 from .bbpe_encoder import ByteBPEProcessor
@@ -14,6 +14,7 @@
         transformers.XLMRobertaTokenizerFast,
         transformers.CamembertTokenizerFast,
         transformers.BertJapaneseTokenizer,
+        transformers.ElectraTokenizerFast,
     )
 else:
     SUPPORTED_TOKENIZERS = ()  # type: ignore
@@ -59,7 +60,9 @@ def build_hf_piece_encoder_loader_v1(
 def _convert_encoder(
     model: Tok2PiecesModelT, tokenizer: "transformers.PreTrainedTokenizerBase"
 ) -> Tok2PiecesModelT:
-    if isinstance(tokenizer, transformers.BertTokenizerFast):
+    if isinstance(
+        tokenizer, (transformers.BertTokenizerFast, transformers.ElectraTokenizerFast)
+    ):
         return _convert_wordpiece_encoder(model, tokenizer)
     elif isinstance(tokenizer, transformers.RobertaTokenizerFast):
         return _convert_byte_bpe_encoder(model, tokenizer)
@@ -109,7 +112,10 @@ def _convert_sentencepiece_encoder(


 def _convert_wordpiece_encoder(
-    model: Tok2PiecesModelT, tokenizer: "transformers.BertTokenizerFast"
+    model: Tok2PiecesModelT,
+    tokenizer: Union[
+        "transformers.BertTokenizerFast", "transformers.ElectraTokenizerFast"
+    ],
 ) -> Tok2PiecesModelT:
     # Seems like we cannot get the vocab file name for a BERT vocabulary? So,
     # instead, copy the vocabulary.
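For context on the loader change above: ELECTRA checkpoints use a BERT-style WordPiece vocabulary but load as ElectraTokenizerFast, which the previous isinstance check on BertTokenizerFast did not match; hence the new entries in SUPPORTED_TOKENIZERS and in the WordPiece branch. A small sanity check, assuming huggingface transformers is installed:

# Not part of the commit: confirm which tokenizer class the test checkpoint
# resolves to.
from transformers import AutoTokenizer, ElectraTokenizerFast

tok = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
print(type(tok).__name__)  # ElectraTokenizerFast
assert isinstance(tok, ElectraTokenizerFast)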
