Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
ShawnXuan committed Sep 18, 2024
1 parent e42b8d0 commit ef8ce79
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions projects/mock_transformers/mock_tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,17 @@

with mock.enable(lazy=True):

from transformers import BertTokenizer, GPT2Tokenizer, MT5Tokenizer, T5Tokenizer, Qwen2Tokenizer # noqa
from transformers import (
BertTokenizer,
GPT2Tokenizer,
MT5Tokenizer,
T5Tokenizer,
Qwen2Tokenizer,
) # noqa
from transformers.tokenization_utils_base import * # noqa
from transformers.utils import generic # noqa
from transformers.utils.generic import TensorType # noqa


# ---------------- mock TensorType ------------------
class TensorType(ExplicitEnum): # noqa
PYTORCH = "pt"
Expand All @@ -36,10 +41,8 @@ class TensorType(ExplicitEnum): # noqa
NUMPY = "np"
JAX = "jax"


generic.TensorType = TensorType


# ---------------- mock convert_to_tensors ------------------
def flow_convert_to_tensors(self, tensor_type=None, prepend_batch_axis=False):
if tensor_type is None:
Expand Down Expand Up @@ -74,7 +77,9 @@ def flow_convert_to_tensors(self, tensor_type=None, prepend_batch_axis=False):
try:
import oneflow # noqa
except ImportError as e:
msg = "Unable to convert output to OneFlow tensors format, OneFlow is not installed."
msg = (
"Unable to convert output to OneFlow tensors format, OneFlow is not installed."
)
raise ImportError(msg) from e
as_tensor = flow.tensor
is_tensor = flow.is_tensor
Expand Down Expand Up @@ -137,5 +142,4 @@ def flow_convert_to_tensors(self, tensor_type=None, prepend_batch_axis=False):
self[k] = v.to_global(sbp=sbp, placement=dist.get_layer_placement(0))
return self


BatchEncoding.convert_to_tensors = flow_convert_to_tensors # noqa

0 comments on commit ef8ce79

Please sign in to comment.