Fix bug #538

Merged: 5 commits, Mar 29, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libai/inference/generator/generation_utils.py

```diff
@@ -189,7 +189,7 @@ def _get_decoder_start_token_id(
         elif bos_token_id is not None:
             return bos_token_id
         else:
-            return self.cfg.bos_token_idx
+            return self.cfg.bos_token_id
 
     @staticmethod
     def _expand_inputs_for_generation(
```
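The one-line fix above corrects an attribute typo: the config field is `bos_token_id`, while the old spelling `bos_token_idx` pointed at a field the config does not define, so the final fallback broke. A minimal standalone sketch of the fallback order (the function here is a hypothetical free-standing version; in LiBAI this is a method on the generator):

```python
# Sketch of the start-token resolution order fixed above; `cfg` stands in
# for the generator's model config.
def get_decoder_start_token_id(cfg, decoder_start_token_id=None, bos_token_id=None):
    if decoder_start_token_id is not None:
        return decoder_start_token_id
    elif bos_token_id is not None:
        return bos_token_id
    # Last resort: read from the config. The attribute is `bos_token_id`;
    # the old `bos_token_idx` spelling failed here.
    return cfg.bos_token_id
```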
38 changes: 33 additions & 5 deletions libai/models/utils/model_loader/base_loader.py

```diff
@@ -20,6 +20,7 @@
 
 import omegaconf
 import oneflow as flow
+from safetensors import safe_open
 from termcolor import colored
 
 import libai.utils.distributed as dist
@@ -457,17 +458,35 @@ def _load_config_from_json(self, config_file):
 
         raise NotImplementedError("_load_config_from_json not implemented")
 
-    def _load_torch_state_dict(self, state_dict_file):
+    def _load_torch_state_dict(self, state_dict_file, use_safetensors=False):
         try:
             import torch
         except ImportError:
             raise ImportError("Load torch state dict need torch.")
 
+        if use_safetensors:
+            if isinstance(state_dict_file, str):
+                state_dict = {}
+                with safe_open(state_dict_file, framework="pt", device=0) as f:
+                    for k in f.keys():
+                        state_dict[k] = f.get_tensor(k)
+                return state_dict
+
+            elif isinstance(state_dict_file, list):
+                merged_state_dict = {}
+                for file in state_dict_file:
+                    state_dict = {}
+                    with safe_open(file, framework="pt") as f:
+                        for k in f.keys():
+                            state_dict[k] = f.get_tensor(k).to(torch.float)
+                    merged_state_dict.update(state_dict)
+                return merged_state_dict
+
         # load pytorch_model.bin
         if isinstance(state_dict_file, str):
             return torch.load(state_dict_file, map_location="cpu")
 
-        if isinstance(state_dict_file, list):
+        elif isinstance(state_dict_file, list):
             merged_state_dict = {}
             for file in state_dict_file:
                 state_dict = torch.load(file, map_location="cpu")
@@ -532,6 +551,7 @@ def load(self):
             >>> bert = loader.load()
 
         """
+        use_safetensors = False
         if dist.is_main_process():
             if os.path.isdir(self.pretrained_model_path):
                 # state_dict file pytorch
@@ -541,10 +561,18 @@
                     if file.endswith(".bin")
                 ]
 
+                if len(model_files) == 0:
+                    use_safetensors = True
+                    model_files = [
+                        os.path.join(self.pretrained_model_path, file)
+                        for file in os.listdir(self.pretrained_model_path)
+                        if file.endswith(".safetensors")
+                    ]
+
                 if len(model_files) == 0:
                     raise EnvironmentError(
-                        f"Error: no file named endswith '.bin' found"
-                        f"in directory {self.pretrained_model_path}."
+                        f"Error: no file named endswith '.bin' or '.safetensors' "
+                        f"found in directory {self.pretrained_model_path}."
                     )
 
                 # config file
@@ -565,7 +593,7 @@
             raise EnvironmentError(f"{self.pretrained_model_path} is not a directory.")
 
         logger.info("loading torch model...")
-        torch_state_dict = self._load_torch_state_dict(model_files)
+        torch_state_dict = self._load_torch_state_dict(model_files, use_safetensors)
         torch_state_dict = self._fix_key(torch_state_dict)
         logger.info("transfering torch model into oneflow model...")
         flow_state_dict = self._convert_tensors(torch_state_dict)
```
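With this change the loader falls back to `*.safetensors` shards whenever a checkpoint directory contains no `.bin` files. For reviewers who want to try the new branch outside LiBAI, here is a minimal sketch of the merge it performs; the shard paths below are hypothetical, any `*.safetensors` files work:

```python
# Minimal sketch of the sharded-safetensors merge added in this PR.
# The paths are placeholders for illustration only.
import torch
from safetensors import safe_open

shards = [
    "checkpoint/model-00001-of-00002.safetensors",
    "checkpoint/model-00002-of-00002.safetensors",
]

merged_state_dict = {}
for path in shards:
    with safe_open(path, framework="pt") as f:  # lazy reader, tensors loaded on demand
        for name in f.keys():
            # The loader upcasts to float32 before the torch -> oneflow conversion.
            merged_state_dict[name] = f.get_tensor(name).to(torch.float)

print(f"loaded {len(merged_state_dict)} tensors")
```

One design point worth noting in the diff: the single-file branch opens with `device=0` (materializing tensors on GPU 0), while the sharded branch stays on the default CPU device and upcasts to float32.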
10 changes: 6 additions & 4 deletions libai/tokenizer/tokenization_base.py

```diff
@@ -807,17 +807,19 @@ def encode(self, text, return_tensors=None, is_global=False, **kwargs):
         if isinstance(text, str):
             tokens = self.tokenize(text)
             token_ids = self.convert_tokens_to_ids(tokens)
-            token_ids = self.build_inputs_with_special_tokens(token_ids)
+            if hasattr(self, "build_inputs_with_special_tokens"):
+                token_ids = self.build_inputs_with_special_tokens(token_ids)
             token_ids = self.convert_to_tensors(
                 token_ids, return_tensors=return_tensors, is_global=is_global, **kwargs
             )
             return token_ids
         elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
             tokens = [self.tokenize(t) for t in text]
             token_ids_list = self.convert_tokens_to_ids(tokens)
-            token_ids_list = [
-                self.build_inputs_with_special_tokens(token_ids) for token_ids in token_ids_list
-            ]
+            if hasattr(self, "build_inputs_with_special_tokens"):
+                token_ids_list = [
+                    self.build_inputs_with_special_tokens(token_ids) for token_ids in token_ids_list
+                ]
             token_ids_list = self.convert_to_tensors(
                 token_ids_list, return_tensors=return_tensors, is_global=is_global, **kwargs
             )
```
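The `hasattr` guard lets tokenizers that do not define `build_inputs_with_special_tokens` pass through `encode` without raising. A runnable sketch of why the guard matters, using a hypothetical minimal tokenizer that lacks the method:

```python
# Hypothetical minimal tokenizer: it implements tokenize and
# convert_tokens_to_ids but not build_inputs_with_special_tokens.
class TinyTokenizer:
    vocab = {"hello": 0, "world": 1}

    def tokenize(self, text):
        return text.lower().split()

    def convert_tokens_to_ids(self, tokens):
        return [self.vocab[t] for t in tokens]


tok = TinyTokenizer()
token_ids = tok.convert_tokens_to_ids(tok.tokenize("Hello world"))
# Before this PR, encode() called the method unconditionally and raised
# AttributeError on tokenizers like this one; now the step is skipped.
if hasattr(tok, "build_inputs_with_special_tokens"):
    token_ids = tok.build_inputs_with_special_tokens(token_ids)
print(token_ids)  # [0, 1]
```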
238 changes: 122 additions & 116 deletions projects/mock_transformers/mock_tokenization.py

```diff
@@ -16,125 +16,131 @@
 import os
 
 import oneflow as flow
+import oneflow.mock_torch as mock
 
 from libai.utils import distributed as dist
 
-flow.mock_torch.enable()
-
-from transformers import BertTokenizer, GPT2Tokenizer, MT5Tokenizer, T5Tokenizer  # noqa
-from transformers.tokenization_utils_base import *  # noqa
-from transformers.utils import generic  # noqa
-from transformers.utils.generic import TensorType  # noqa
-
-
-# ---------------- mock TensorType ------------------
-class TensorType(ExplicitEnum):  # noqa
-    PYTORCH = "pt"
-    TENSORFLOW = "tf"
-    ONEFLOW = "of"
-    NUMPY = "np"
-    JAX = "jax"
-
-
-generic.TensorType = TensorType
-
-
-# ---------------- mock convert_to_tensors ------------------
-def flow_convert_to_tensors(self, tensor_type=None, prepend_batch_axis=False):
-    if tensor_type is None:
-        return self
-
-    # Convert to TensorType
-    if not isinstance(tensor_type, TensorType):
-        tensor_type = TensorType(tensor_type)
-    as_tensor = None
-    is_tensor = None
-    # Get a function reference for the correct framework
-    if tensor_type == TensorType.TENSORFLOW:
-        if not is_tf_available():  # noqa
-            raise ImportError(
-                "Unable to convert output to TensorFlow tensors format, TensorFlow is not "
-                "installed."
-            )
-        import tensorflow as tf
-
-        as_tensor = tf.constant
-        is_tensor = tf.is_tensor
-    elif tensor_type == TensorType.PYTORCH:
-        if not is_torch_available():  # noqa
-            raise ImportError(
-                "Unable to convert output to PyTorch tensors format, PyTorch is not installed."
-            )
-        import torch
-
-        as_tensor = torch.tensor
-        is_tensor = torch.is_tensor
-    elif tensor_type == TensorType.ONEFLOW:
-        try:
-            import oneflow  # noqa
-        except ImportError as e:
-            msg = "Unable to convert output to OneFlow tensors format, OneFlow is not installed."
-            raise ImportError(msg) from e
-        as_tensor = flow.tensor
-        is_tensor = flow.is_tensor
-    elif tensor_type == TensorType.JAX:
-        if not is_flax_available():  # noqa
-            raise ImportError(
-                "Unable to convert output to JAX tensors format, JAX is not installed."
-            )
-        import jax.numpy as jnp  # noqa: F811
-
-        as_tensor = jnp.array
-        is_tensor = is_jax_tensor  # noqa
-    else:
-        as_tensor = np.asarray  # noqa
-        is_tensor = is_numpy_array  # noqa
-
-    # Do the tensor conversion in batch
-    for key, value in self.items():
-        try:
-            if prepend_batch_axis:
-                value = [value]
-
-            if not is_tensor(value):
-                tensor = as_tensor(value)
-
-                # Removing this for now in favor of controlling the shape with `prepend_batch_axis`
-                # # at-least2d
-                # if tensor.ndim > 2:
-                #     tensor = tensor.squeeze(0)
-                # elif tensor.ndim < 2:
-                #     tensor = tensor[None, :]
-
-                self[key] = tensor
-        except Exception as e:
-            if key == "overflowing_tokens":
-                raise ValueError(
-                    "Unable to create tensor returning overflowing tokens of different lengths. "
-                    "Please see if a fast version of this tokenizer is available to have this "
-                    "feature available."
-                ) from e
-            raise ValueError(
-                "Unable to create tensor, you should probably activate truncation and/or "
-                "padding with 'padding=True' 'truncation=True' to have batched tensors with "
-                f"the same length. Perhaps your features (`{key}` in this case) have "
-                "excessive nesting (inputs type `list` where type `int` is expected)."
-            ) from e
-    if os.getenv("IS_GLOBAL", True) is True:
-        size = self["input_ids"].size()
-        sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
-
-        for k, v in self.items():
-            if is_tensor != flow.is_tensor:
-                raise ValueError(
-                    "Unable to create tensor, you should probably set `return_tensors='of'` "
-                )
-            if v.size() != size:
-                raise ValueError(
-                    "Unable to create tensor, you should probably padding with `padding=True` "
-                )
-            self[k] = v.to_global(sbp=sbp, placement=dist.get_layer_placement(0))
-    return self
-
-
-BatchEncoding.convert_to_tensors = flow_convert_to_tensors  # noqa
+with mock.enable(lazy=True):
+
+    from transformers import (  # noqa
+        BertTokenizer,
+        GPT2Tokenizer,
+        MT5Tokenizer,
+        Qwen2Tokenizer,
+        T5Tokenizer,
+    )
+    from transformers.tokenization_utils_base import *  # noqa
+    from transformers.utils import generic  # noqa
+    from transformers.utils.generic import TensorType  # noqa
+
+    # ---------------- mock TensorType ------------------
+    class TensorType(ExplicitEnum):  # noqa
+        PYTORCH = "pt"
+        TENSORFLOW = "tf"
+        ONEFLOW = "of"
+        NUMPY = "np"
+        JAX = "jax"
+
+    generic.TensorType = TensorType
+
+    # ---------------- mock convert_to_tensors ------------------
+    def flow_convert_to_tensors(self, tensor_type=None, prepend_batch_axis=False):
+        if tensor_type is None:
+            return self
+
+        # Convert to TensorType
+        if not isinstance(tensor_type, TensorType):
+            tensor_type = TensorType(tensor_type)
+        as_tensor = None
+        is_tensor = None
+        # Get a function reference for the correct framework
+        if tensor_type == TensorType.TENSORFLOW:
+            if not is_tf_available():  # noqa
+                raise ImportError(
+                    "Unable to convert output to TensorFlow tensors format, TensorFlow is not "
+                    "installed."
+                )
+            import tensorflow as tf
+
+            as_tensor = tf.constant
+            is_tensor = tf.is_tensor
+        elif tensor_type == TensorType.PYTORCH:
+            if not is_torch_available():  # noqa
+                raise ImportError(
+                    "Unable to convert output to PyTorch tensors format, PyTorch is not installed."
+                )
+            import torch
+
+            as_tensor = torch.tensor
+            is_tensor = torch.is_tensor
+        elif tensor_type == TensorType.ONEFLOW:
+            try:
+                import oneflow  # noqa
+            except ImportError as e:
+                msg = (
+                    "Unable to convert output to OneFlow tensors format, OneFlow is not installed."
+                )
+                raise ImportError(msg) from e
+            as_tensor = flow.tensor
+            is_tensor = flow.is_tensor
+        elif tensor_type == TensorType.JAX:
+            if not is_flax_available():  # noqa
+                raise ImportError(
+                    "Unable to convert output to JAX tensors format, JAX is not installed."
+                )
+            import jax.numpy as jnp  # noqa: F811
+
+            as_tensor = jnp.array
+            is_tensor = is_jax_tensor  # noqa
+        else:
+            as_tensor = np.asarray  # noqa
+            is_tensor = is_numpy_array  # noqa
+
+        # Do the tensor conversion in batch
+        for key, value in self.items():
+            try:
+                if prepend_batch_axis:
+                    value = [value]
+
+                if not is_tensor(value):
+                    tensor = as_tensor(value)
+
+                    # Removing this for now in favor of controlling the shape
+                    # with `prepend_batch_axis`
+                    # # at-least2d
+                    # if tensor.ndim > 2:
+                    #     tensor = tensor.squeeze(0)
+                    # elif tensor.ndim < 2:
+                    #     tensor = tensor[None, :]
+
+                    self[key] = tensor
+            except Exception as e:
+                if key == "overflowing_tokens":
+                    raise ValueError(
+                        "Unable to create tensor returning overflowing tokens of different "
+                        "lengths. Please see if a fast version of this tokenizer is "
+                        "available to have this feature available."
+                    ) from e
+                raise ValueError(
+                    "Unable to create tensor, you should probably activate truncation and/or "
+                    "padding with 'padding=True' 'truncation=True' to have batched tensors with "
+                    f"the same length. Perhaps your features (`{key}` in this case) have "
+                    "excessive nesting (inputs type `list` where type `int` is expected)."
+                ) from e
+        if os.getenv("IS_GLOBAL", True) is True:
+            size = self["input_ids"].size()
+            sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
+
+            for k, v in self.items():
+                if is_tensor != flow.is_tensor:
+                    raise ValueError(
+                        "Unable to create tensor, you should probably set `return_tensors='of'` "
+                    )
+                if v.size() != size:
+                    raise ValueError(
+                        "Unable to create tensor, you should probably padding with `padding=True` "
+                    )
+                self[k] = v.to_global(sbp=sbp, placement=dist.get_layer_placement(0))
+        return self
+
+    BatchEncoding.convert_to_tensors = flow_convert_to_tensors  # noqa
```
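Switching from the global `flow.mock_torch.enable()` to `with mock.enable(lazy=True):` scopes the torch-to-oneflow mock to the import block and defers module loading; the import list also gains `Qwen2Tokenizer`. A sketch of one possible call pattern, assuming the mocks have been applied by importing this module (the checkpoint name below is illustrative, not from this PR):

```python
# Sketch: importing the module applies the mocks, after which HF tokenizers
# can return OneFlow tensors via return_tensors="of". The checkpoint name is
# hypothetical; any Qwen2 tokenizer files on disk should work.
import projects.mock_transformers.mock_tokenization  # noqa: F401  (applies mocks)
from transformers import Qwen2Tokenizer

tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen1.5-7B")
batch = tokenizer("a photo of a cat", return_tensors="of", padding=True)
# With the IS_GLOBAL env var unset (the default path in the code above),
# input_ids is converted to a global broadcast tensor placed by libai's
# distributed config.
print(type(batch["input_ids"]))
```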
1 change: 1 addition & 0 deletions requirements.txt

```diff
@@ -27,3 +27,4 @@ black==21.4b2
 autoflake
 tensorboardX<=2.5.1
 pytest
+safetensors
```
1 change: 1 addition & 0 deletions setup.py

```diff
@@ -140,6 +140,7 @@ def get_libai_configs() -> List[str]:
             "autoflake",
             "tensorboardX<=2.5.1",
             "pytest",
+            "safetensors",
         ],
         packages=find_packages(),
         package_data={"libai.config": get_libai_configs()},
```