Home
Bo仔很忙 edited this page Apr 26, 2024 · 1 revision
- Pretrained models can be loaded from code in several ways
from bert4torch.models import build_transformer_model
# 1. config_path only: initialize the model structure from scratch, without loading pretrained weights
model = build_transformer_model('./model/bert4torch_config.json')
# 2. checkpoint_path only:
## 2.1 Directory path: automatically finds the *.bin/*.safetensors weight files and the bert4torch_config.json/config.json file under that path
model = build_transformer_model(checkpoint_path='./model')
## 2.2 File path or list of paths: the path(s) are used as the weight file(s); the config is looked up in the same directory
model = build_transformer_model(checkpoint_path='./pytorch_model.bin')
## 2.3 model_name: name of a pretrained checkpoint on Hugging Face; the weights and the bert4torch_config.json file are downloaded automatically
model = build_transformer_model(checkpoint_path='bert-base-chinese')
# 3. Both config_path and checkpoint_path (any combination of local paths and model_name):
config_path = './model/bert4torch_config.json'  # or 'bert-base-chinese'
checkpoint_path = './model/pytorch_model.bin'  # or 'bert-base-chinese'
model = build_transformer_model(config_path, checkpoint_path)
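The lookup rules in the comments above (directory, single file, list of files, or model name) can be sketched as a small standalone helper. This is only an illustration of those rules, not bert4torch's actual implementation: `resolve_checkpoint` is a hypothetical name, and preferring `bert4torch_config.json` over `config.json` when both exist is an assumption.

```python
import os
from glob import glob

# Assumed preference order when both config files are present.
CONFIG_NAMES = ("bert4torch_config.json", "config.json")


def resolve_checkpoint(checkpoint_path):
    """Hypothetical helper: classify checkpoint_path and locate related files.

    Returns (kind, weight_files, config_file), where kind is one of
    "directory", "file", "file_list" or "model_name".
    """
    if isinstance(checkpoint_path, (list, tuple)):
        if not checkpoint_path:
            raise ValueError("checkpoint_path list is empty")
        kind = "file_list"
        weights = list(checkpoint_path)
        search_dir = os.path.dirname(weights[0])
    elif os.path.isdir(checkpoint_path):
        kind = "directory"
        weights = sorted(
            glob(os.path.join(checkpoint_path, "*.bin"))
            + glob(os.path.join(checkpoint_path, "*.safetensors"))
        )
        search_dir = checkpoint_path
    elif os.path.isfile(checkpoint_path):
        kind = "file"
        weights = [checkpoint_path]
        search_dir = os.path.dirname(checkpoint_path)
    else:
        # Not a local path: treat it as a model name to be downloaded.
        return "model_name", [], None

    # Look for a config file next to the weights.
    config = None
    for name in CONFIG_NAMES:
        candidate = os.path.join(search_dir, name)
        if os.path.isfile(candidate):
            config = candidate
            break
    return kind, weights, config
```

For example, `resolve_checkpoint('./model')` would report a directory together with whatever weight and config files it finds there, while a string that is not a local path falls through to the model-name case.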
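*Note:
- Names shown in highlighted format (such as `bert-base-chinese`) can be passed directly to build_transformer_model(), which downloads them over the network.
- To speed up downloads with a mirror site in mainland China, use one of the following:
  - Prefix the command with the variable:
    HF_ENDPOINT=https://hf-mirror.com python your_script.py
  - Export the variable first, then run the Python code:
    export HF_ENDPOINT=https://hf-mirror.com
  - Set it at the top of the Python code:
    import os
    os.environ['HF_ENDPOINT'] = "https://hf-mirror.com"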
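The in-code option can be wrapped so that an `HF_ENDPOINT` already set in the shell is not silently overwritten. The sketch below is a minimal example under that assumption; `use_hf_mirror` is a hypothetical helper, not part of bert4torch or any Hugging Face library. It should run before importing anything that reads `HF_ENDPOINT`, since such libraries may read the variable once at import time.

```python
import os

MIRROR = "https://hf-mirror.com"


def use_hf_mirror(endpoint=MIRROR, override=False):
    """Hypothetical helper: point HF_ENDPOINT at a mirror.

    An existing HF_ENDPOINT value is kept unless override=True.
    Returns the value that is in effect afterwards.
    """
    if override or "HF_ENDPOINT" not in os.environ:
        os.environ["HF_ENDPOINT"] = endpoint
    return os.environ["HF_ENDPOINT"]


# Call this at the very top of the script, before importing bert4torch.
use_hf_mirror()
```

Keeping the shell value by default means the first two options (command prefix or `export`) still take precedence over the in-code default.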