-
Notifications
You must be signed in to change notification settings - Fork 312
Add Devstral Small 1.1 #2468
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
omkar-334
wants to merge
13
commits into
keras-team:master
Choose a base branch
from
omkar-334:devstral_1_1
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+496
−14
Draft
Add Devstral Small 1.1 #2468
Changes from 9 commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
61cbf9f
add devstral in preset_map
omkar-334 abf1276
changes for devstral tokenizer
omkar-334 64107d0
edit hf tokenizer for devstral
omkar-334 70a8bc8
add preset
omkar-334 59df4e4
linting fixes
omkar-334 1325743
Update keras_hub/src/utils/transformers/convert_mistral.py
omkar-334 71d1b9f
Update tools/checkpoint_conversion/convert_mistral_checkpoints.py
omkar-334 0ee78e2
fix
omkar-334 5fe3b35
add tiktoken tokenizer for mistral
omkar-334 9b1be1e
fixes (need to test this script)
omkar-334 1e071d1
fixes
omkar-334 e0985f5
fix rope_theta
omkar-334 9792af1
try
omkar-334 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,14 +3,13 @@ | |
| from keras_hub.src.tokenizers.sentence_piece_tokenizer import ( | ||
| SentencePieceTokenizer, | ||
| ) | ||
| from keras_hub.src.tokenizers.tiktoken_tokenizer import TiktokenTokenizer | ||
|
|
||
|
|
||
| @keras_hub_export( | ||
| [ | ||
| "keras_hub.tokenizers.MistralTokenizer", | ||
| "keras_hub.models.MistralTokenizer", | ||
| ] | ||
| ) | ||
| @keras_hub_export([ | ||
| "keras_hub.tokenizers.MistralTokenizer", | ||
| "keras_hub.models.MistralTokenizer", | ||
| ]) | ||
| class MistralTokenizer(SentencePieceTokenizer): | ||
| """Mistral tokenizer layer based on SentencePiece. | ||
|
|
||
|
|
@@ -55,3 +54,32 @@ def __init__(self, proto, **kwargs): | |
| self._add_special_token("</s>", "end_token") | ||
| self.pad_token_id = 0 | ||
| super().__init__(proto=proto, **kwargs) | ||
|
|
||
|
|
||
| @keras_hub_export([ | ||
| "keras_hub.tokenizers.NewMistralTokenizer", | ||
| "keras_hub.models.NewMistralTokenizer", | ||
| ]) | ||
| class NewMistralTokenizer(TiktokenTokenizer): | ||
| """ | ||
| Tekken-based tokenizer for Mistral models. | ||
|
|
||
| Responsibilities: | ||
| • Add required Mistral special tokens (<s>, </s>, pad) | ||
| • Delegate tekken.json parsing to TiktokenTokenizer | ||
| • Use Tiktoken backend via TiktokenTokenizer normalisation | ||
| """ | ||
|
Comment on lines
68
to
70
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The docstring for References
|
||
|
|
||
| backbone_cls = MistralBackbone | ||
|
|
||
| def __init__(self, proto, sequence_length=None, dtype="int32", **kwargs): | ||
| self._add_special_token("<s>", "start_token") | ||
| self._add_special_token("</s>", "end_token") | ||
| self.pad_token_id = 0 | ||
|
|
||
| super().__init__( | ||
| proto=proto, | ||
| sequence_length=sequence_length, | ||
| dtype=dtype, | ||
| **kwargs, | ||
| ) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.