Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ShawnXuan committed Sep 18, 2024
2 parents b68f005 + 4c803dc commit bb3b59c
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion projects/mock_transformers/mock_tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
import os

import oneflow as flow
import oneflow.mock_torch as mock

from libai.utils import distributed as dist
import oneflow.mock_torch as mock

with mock.enable(lazy=True):

from transformers import ( # noqa
BertTokenizer,
GPT2Tokenizer,
Expand All @@ -33,6 +33,7 @@
from transformers.utils import generic # noqa
from transformers.utils.generic import TensorType # noqa


# ---------------- mock TensorType ------------------
class TensorType(ExplicitEnum): # noqa
PYTORCH = "pt"
Expand Down Expand Up @@ -77,6 +78,7 @@ def flow_convert_to_tensors(self, tensor_type=None, prepend_batch_axis=False):
try:
import oneflow # noqa
except ImportError as e:

msg = (
"Unable to convert output to OneFlow tensors format, OneFlow is not installed."
)
Expand Down Expand Up @@ -143,4 +145,5 @@ def flow_convert_to_tensors(self, tensor_type=None, prepend_batch_axis=False):
self[k] = v.to_global(sbp=sbp, placement=dist.get_layer_placement(0))
return self


BatchEncoding.convert_to_tensors = flow_convert_to_tensors # noqa

0 comments on commit bb3b59c

Please sign in to comment.