-
-
Notifications
You must be signed in to change notification settings - Fork 28
Feature/complete sentencepiece tokenizer #65
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c096dbb
e4f567d
0b54e19
1bdeb40
323983c
c73d046
9406c0b
becd3c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -333,4 +333,5 @@ __pycache__/ | |
| *.pyd | ||
| *.bz2 | ||
|
|
||
| *.venv | ||
| venv/ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,25 +4,84 @@ | |
|
|
||
from .base import BaseTokenizer

# Artifact file names written by SentencePiece training (model_prefix "spm").
SPM_MODEL_FILE = "spm.model"
SPM_VOCAB_FILE = "spm.vocab"


class SentencePieceTokenizer(BaseTokenizer):
    """
    SentencePiece tokenizer implementation.

    Trains a BPE model via the ``sentencepiece`` library (imported as
    ``spm`` at module top) and exposes encode/decode plus helpers that
    report where the trained artifacts live on disk.
    """

    def __init__(self, vocab_size: int, min_frequency: int):
        super().__init__(vocab_size, min_frequency)
        # Loaded SentencePieceProcessor; populated by train() or load().
        self._model = None

    def train(self, text_file: Path, save_path: Path):
        """Train a BPE model on *text_file* and write ``spm.model`` /
        ``spm.vocab`` into *save_path*, then load the trained model.

        Args:
            text_file: path to a plain-text corpus file.
            save_path: directory for the artifacts (created if missing).

        Raises:
            FileNotFoundError: if *text_file* is not an existing file.
            NotImplementedError: if ``min_frequency != 1`` — the Python
                wrapper exposes no confirmed min-count option.
        """
        text_file = Path(text_file)
        save_path = Path(save_path)

        if not text_file.is_file():
            raise FileNotFoundError(
                f"Training file not found at {text_file}. Please provide a valid text corpus file."
            )

        if self.min_frequency != 1:
            raise NotImplementedError(
                f"min_frequency={self.min_frequency} is not supported. "
                "SentencePiece does not expose a confirmed min_count option via the Python wrapper. "
                "Set min_frequency=1 to use the default behaviour, or confirm the upstream option before enabling filtering."
            )

        save_path.mkdir(parents=True, exist_ok=True)

        model_prefix = save_path / "spm"

        spm.SentencePieceTrainer.train(
            input=str(text_file),
            model_prefix=str(model_prefix),
            vocab_size=self.vocab_size,
            # Pin special-token ids and pieces so downstream code can
            # rely on a fixed layout: 0=<pad>, 1=<unk>, 2=<s>, 3=</s>.
            pad_id=0,
            unk_id=1,
            bos_id=2,
            eos_id=3,
            pad_piece="<pad>",
            unk_piece="<unk>",
            bos_piece="<s>",
            eos_piece="</s>",
            character_coverage=1.0,
            model_type="bpe",
        )

        self._load_model(save_path)

    def encode(self, text: str) -> list:
        """Return the list of token ids for *text*."""
        self._check_loaded()
        return self._model.encode(text, out_type=int)

    def decode(self, ids: list) -> str:
        """Return the text reconstructed from the id list *ids*."""
        self._check_loaded()
        return self._model.decode(ids)

    def load(self, tokenizer_dir: Path):
        """Load a previously trained model from *tokenizer_dir*."""
        tokenizer_dir = Path(tokenizer_dir)
        self._load_model(tokenizer_dir)

    def get_vocab_path(self, tokenizer_dir: Path) -> Path:
        """Return the path of the vocab artifact inside *tokenizer_dir*."""
        return Path(tokenizer_dir) / SPM_VOCAB_FILE

    def get_merges_path(self, tokenizer_dir: Path) -> Path:
        """Return the model artifact path.

        SentencePiece has no separate merges file; the serialized model
        plays that role for callers expecting a merges artifact.
        """
        return Path(tokenizer_dir) / SPM_MODEL_FILE

    def _load_model(self, tokenizer_dir: Path):
        """Load ``spm.model`` from *tokenizer_dir* into ``self._model``.

        Raises:
            FileNotFoundError: if the model artifact is missing.
        """
        model_path = Path(tokenizer_dir) / SPM_MODEL_FILE

        if not model_path.is_file():
            raise FileNotFoundError(
                f"SentencePiece model not found at {model_path}. Please train the tokenizer first."
            )

        self._model = spm.SentencePieceProcessor()
        self._model.load(str(model_path))

    def _check_loaded(self):
        # Guard called by encode/decode before touching self._model.
        if self._model is None:
            raise RuntimeError(
                "SentencePiece model is not loaded. Call train() or load() before encode/decode."
            )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,174 @@ | ||
import pytest

from openverifiablellm.tokenizer.sentencepiece_tokenizer import SentencePieceTokenizer


@pytest.fixture
def sample_text_file(tmp_path):
    """Write a small repetitive corpus and return its path."""
    text = (
        "Wikipedia is a free online encyclopedia.\n"
        "It is written collaboratively by volunteers.\n"
        "Anyone can edit Wikipedia articles.\n"
        "Wikipedia was launched on January 15 2001.\n"
        "It is one of the most popular websites in the world.\n"
    ) * 500

    text_file = tmp_path / "sample.txt"
    text_file.write_text(text)
    return text_file


@pytest.fixture
def trained_tokenizer(tmp_path, sample_text_file):
    """Train a tokenizer once and return the artifact directory."""
    tokenizer = SentencePieceTokenizer(vocab_size=200, min_frequency=1)
    tokenizer.train(sample_text_file, tmp_path / "tokenizer")
    return tmp_path / "tokenizer"


def test_spm_train_creates_artifacts(tmp_path, sample_text_file):
    tokenizer = SentencePieceTokenizer(vocab_size=200, min_frequency=1)
    save_path = tmp_path / "tokenizer"

    tokenizer.train(sample_text_file, save_path)

    assert (save_path / "spm.model").is_file()
    assert (save_path / "spm.vocab").is_file()


def test_spm_train_creates_save_directory(tmp_path, sample_text_file):
    tokenizer = SentencePieceTokenizer(vocab_size=200, min_frequency=1)
    save_path = tmp_path / "nested" / "tokenizer" / "dir"

    assert not save_path.exists()

    tokenizer.train(sample_text_file, save_path)

    assert save_path.exists()


def test_spm_train_raises_file_not_found(tmp_path):
    tokenizer = SentencePieceTokenizer(vocab_size=200, min_frequency=1)

    with pytest.raises(FileNotFoundError, match="Training file not found"):
        tokenizer.train(tmp_path / "nonexistent.txt", tmp_path / "tokenizer")


def test_spm_train_raises_if_directory_passed(tmp_path):
    # A directory is not a valid corpus file.
    tokenizer = SentencePieceTokenizer(vocab_size=200, min_frequency=1)

    with pytest.raises(FileNotFoundError, match="Training file not found"):
        tokenizer.train(tmp_path, tmp_path / "tokenizer")


def test_spm_train_raises_if_min_frequency_not_one(tmp_path, sample_text_file):
    tokenizer = SentencePieceTokenizer(vocab_size=200, min_frequency=2)

    with pytest.raises(NotImplementedError, match="min_frequency=2 is not supported"):
        tokenizer.train(sample_text_file, tmp_path / "tokenizer")


def test_spm_encode_returns_list_of_ints(trained_tokenizer):
    tokenizer = SentencePieceTokenizer(vocab_size=200, min_frequency=1)
    tokenizer.load(trained_tokenizer)

    ids = tokenizer.encode("hello world")

    assert isinstance(ids, list)
    assert all(isinstance(i, int) for i in ids)


def test_spm_encode_decode_roundtrip(trained_tokenizer):
    tokenizer = SentencePieceTokenizer(vocab_size=200, min_frequency=1)
    tokenizer.load(trained_tokenizer)

    text = "Wikipedia is a free online encyclopedia"
    ids = tokenizer.encode(text)
    decoded = tokenizer.decode(ids)

    assert decoded.strip() == text.strip()


def test_spm_encode_raises_if_not_loaded():
    tokenizer = SentencePieceTokenizer(vocab_size=200, min_frequency=1)

    with pytest.raises(RuntimeError, match="not loaded"):
        tokenizer.encode("hello world")


def test_spm_decode_raises_if_not_loaded():
    tokenizer = SentencePieceTokenizer(vocab_size=200, min_frequency=1)

    with pytest.raises(RuntimeError, match="not loaded"):
        tokenizer.decode([1, 2, 3])


def test_spm_load_from_disk(trained_tokenizer):
    tokenizer = SentencePieceTokenizer(vocab_size=200, min_frequency=1)
    tokenizer.load(trained_tokenizer)

    assert tokenizer._model is not None


def test_spm_encode_works_after_load(trained_tokenizer):
    tokenizer = SentencePieceTokenizer(vocab_size=200, min_frequency=1)
    tokenizer.load(trained_tokenizer)

    ids = tokenizer.encode("hello world")

    assert isinstance(ids, list)
    assert len(ids) > 0


def test_spm_load_raises_if_model_missing(tmp_path):
    tokenizer = SentencePieceTokenizer(vocab_size=200, min_frequency=1)

    with pytest.raises(FileNotFoundError, match="SentencePiece model not found"):
        tokenizer.load(tmp_path)


def test_spm_get_vocab_path(tmp_path):
    tokenizer = SentencePieceTokenizer(vocab_size=200, min_frequency=1)
    vocab_path = tokenizer.get_vocab_path(tmp_path)

    assert vocab_path == tmp_path / "spm.vocab"


def test_spm_get_merges_path_returns_model_path(tmp_path):
    tokenizer = SentencePieceTokenizer(vocab_size=200, min_frequency=1)
    merges_path = tokenizer.get_merges_path(tmp_path)

    assert merges_path == tmp_path / "spm.model"


def test_spm_special_tokens_in_vocabulary(trained_tokenizer):
    tokenizer = SentencePieceTokenizer(vocab_size=200, min_frequency=1)
    tokenizer.load(trained_tokenizer)

    vocab_path = trained_tokenizer / "spm.vocab"
    vocab_content = vocab_path.read_text(encoding="utf-8")

    assert "<pad>" in vocab_content
    assert "<unk>" in vocab_content
    assert "<s>" in vocab_content
    assert "</s>" in vocab_content


def test_spm_training_is_deterministic(tmp_path, sample_text_file):
    save_path_1 = tmp_path / "tokenizer_1"
    save_path_2 = tmp_path / "tokenizer_2"

    tokenizer_1 = SentencePieceTokenizer(vocab_size=200, min_frequency=1)
    tokenizer_1.train(sample_text_file, save_path_1)

    tokenizer_2 = SentencePieceTokenizer(vocab_size=200, min_frequency=1)
    tokenizer_2.train(sample_text_file, save_path_2)

    vocab_1 = (save_path_1 / "spm.vocab").read_text(encoding="utf-8")
    vocab_2 = (save_path_2 / "spm.vocab").read_text(encoding="utf-8")

    assert vocab_1 == vocab_2

    model_1 = (save_path_1 / "spm.model").read_bytes()
    model_2 = (save_path_2 / "spm.model").read_bytes()

    assert model_1 == model_2
Uh oh!
There was an error while loading. Please reload this page.