
Commit 1f8d908

[model] support Tencent-Hunyuan/Hunyuan-A13B-Instruct (#4745)
1 parent 696fad6 commit 1f8d908

File tree: 16 files changed, +142 additions, -23 deletions

docs/source/Instruction/支持的模型和数据集.md
Lines changed: 1 addition & 0 deletions

@@ -559,6 +559,7 @@
 |[XiaomiMiMo/MiMo-7B-RL-0530](https://modelscope.cn/models/XiaomiMiMo/MiMo-7B-RL-0530)|mimo_rl|mimo_rl|transformers>=4.37|✔|-|[XiaomiMiMo/MiMo-7B-RL-0530](https://huggingface.co/XiaomiMiMo/MiMo-7B-RL-0530)|
 |[rednote-hilab/dots.llm1.base](https://modelscope.cn/models/rednote-hilab/dots.llm1.base)|dots1|dots1|transformers>=4.53.0.dev0|✔|-|[rednote-hilab/dots.llm1.base](https://huggingface.co/rednote-hilab/dots.llm1.base)|
 |[rednote-hilab/dots.llm1.inst](https://modelscope.cn/models/rednote-hilab/dots.llm1.inst)|dots1|dots1|transformers>=4.53.0.dev0|✔|-|[rednote-hilab/dots.llm1.inst](https://huggingface.co/rednote-hilab/dots.llm1.inst)|
+|[Tencent-Hunyuan/Hunyuan-A13B-Instruct](https://modelscope.cn/models/Tencent-Hunyuan/Hunyuan-A13B-Instruct)|hunyuan|hunyuan|-|✘|-|[tencent/Hunyuan-A13B-Instruct](https://huggingface.co/tencent/Hunyuan-A13B-Instruct)|
 |[answerdotai/ModernBERT-base](https://modelscope.cn/models/answerdotai/ModernBERT-base)|modern_bert|dummy|transformers>=4.48|✘|bert|[answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)|
 |[answerdotai/ModernBERT-large](https://modelscope.cn/models/answerdotai/ModernBERT-large)|modern_bert|dummy|transformers>=4.48|✘|bert|[answerdotai/ModernBERT-large](https://huggingface.co/answerdotai/ModernBERT-large)|
 |[iic/gte-modernbert-base](https://modelscope.cn/models/iic/gte-modernbert-base)|modern_bert_gte|dummy|transformers>=4.48|✘|bert, embedding|[Alibaba-NLP/gte-modernbert-base](https://huggingface.co/Alibaba-NLP/gte-modernbert-base)|

docs/source_en/Instruction/Supported-models-and-datasets.md
Lines changed: 1 addition & 0 deletions

@@ -559,6 +559,7 @@ The table below introduces the models integrated with ms-swift:
 |[XiaomiMiMo/MiMo-7B-RL-0530](https://modelscope.cn/models/XiaomiMiMo/MiMo-7B-RL-0530)|mimo_rl|mimo_rl|transformers>=4.37|✔|-|[XiaomiMiMo/MiMo-7B-RL-0530](https://huggingface.co/XiaomiMiMo/MiMo-7B-RL-0530)|
 |[rednote-hilab/dots.llm1.base](https://modelscope.cn/models/rednote-hilab/dots.llm1.base)|dots1|dots1|transformers>=4.53.0.dev0|✔|-|[rednote-hilab/dots.llm1.base](https://huggingface.co/rednote-hilab/dots.llm1.base)|
 |[rednote-hilab/dots.llm1.inst](https://modelscope.cn/models/rednote-hilab/dots.llm1.inst)|dots1|dots1|transformers>=4.53.0.dev0|✔|-|[rednote-hilab/dots.llm1.inst](https://huggingface.co/rednote-hilab/dots.llm1.inst)|
+|[Tencent-Hunyuan/Hunyuan-A13B-Instruct](https://modelscope.cn/models/Tencent-Hunyuan/Hunyuan-A13B-Instruct)|hunyuan|hunyuan|-|✘|-|[tencent/Hunyuan-A13B-Instruct](https://huggingface.co/tencent/Hunyuan-A13B-Instruct)|
 |[answerdotai/ModernBERT-base](https://modelscope.cn/models/answerdotai/ModernBERT-base)|modern_bert|dummy|transformers>=4.48|✘|bert|[answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)|
 |[answerdotai/ModernBERT-large](https://modelscope.cn/models/answerdotai/ModernBERT-large)|modern_bert|dummy|transformers>=4.48|✘|bert|[answerdotai/ModernBERT-large](https://huggingface.co/answerdotai/ModernBERT-large)|
 |[iic/gte-modernbert-base](https://modelscope.cn/models/iic/gte-modernbert-base)|modern_bert_gte|dummy|transformers>=4.48|✘|bert, embedding|[Alibaba-NLP/gte-modernbert-base](https://huggingface.co/Alibaba-NLP/gte-modernbert-base)|

swift/llm/model/constant.py
Lines changed: 1 addition & 0 deletions

@@ -119,6 +119,7 @@ class LLMModelType:
     mimo = 'mimo'
     mimo_rl = 'mimo_rl'
     dots1 = 'dots1'
+    hunyuan = 'hunyuan'


 class BertModelType:

swift/llm/model/model/llm.py
Lines changed: 11 additions & 0 deletions

@@ -343,3 +343,14 @@ def forward(self, **kwargs):
         architectures=['Dots1ForCausalLM'],
         requires=['transformers>=4.53.0.dev0'],
     ))
+
+register_model(
+    ModelMeta(
+        LLMModelType.hunyuan,
+        [ModelGroup([
+            Model('Tencent-Hunyuan/Hunyuan-A13B-Instruct', 'tencent/Hunyuan-A13B-Instruct'),
+        ])],
+        TemplateType.hunyuan,
+        get_model_tokenizer_with_flash_attn,
+        architectures=['HunYuanMoEV1ForCausalLM'],
+    ))
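For context, this `ModelMeta` entry is what lets the rest of ms-swift resolve the new model id. A minimal usage sketch, assuming ms-swift's public `get_model_tokenizer` helper; only the model ids, the template/model types, and the architecture name come from this commit:

```python
# Minimal sketch, assuming ms-swift's public API surface.
from swift.llm import get_model_tokenizer

# Resolved through the ModelMeta registered above: the first id is the
# ModelScope repo, the second the Hugging Face mirror; the loader used is
# get_model_tokenizer_with_flash_attn.
model, tokenizer = get_model_tokenizer('Tencent-Hunyuan/Hunyuan-A13B-Instruct')
print(model.config.architectures)  # expected: ['HunYuanMoEV1ForCausalLM']
```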

swift/llm/template/base.py
Lines changed: 4 additions & 1 deletion

@@ -1034,7 +1034,10 @@ def _swift_encode(self, inputs: StdTemplateInputs):
             res_context_list.append(bos_token)
             res_context_types.append(ContextType.OTHER)

-        prefix = template_meta.system_prefix if system else template_meta.prefix
+        if self.template_meta.is_post_system or not system:
+            prefix = template_meta.prefix
+        else:
+            prefix = template_meta.system_prefix
         self._concat_context_list(prefix, res_context_list, res_context_types, system=system)

         n_round = len(inputs.messages) // 2
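The replaced one-liner chose `system_prefix` whenever a system message was present; for post-system templates, where `{{SYSTEM}}` lives inside the prompt and `system_prefix` is `None` (e.g. mistral_nemo), that presumably selected a `None` prefix. A self-contained sketch of the before/after behavior, with a hypothetical stand-in for `TemplateMeta`:

```python
# Hypothetical stand-in for TemplateMeta; only the branch logic mirrors the diff.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Meta:
    prefix: List[str]
    system_prefix: Optional[List[str]]
    is_post_system: bool

meta = Meta(prefix=['<s>'], system_prefix=None, is_post_system=True)
system = 'You are a helpful assistant.'

old_prefix = meta.system_prefix if system else meta.prefix  # None: the bug
new_prefix = meta.prefix if meta.is_post_system or not system else meta.system_prefix
assert old_prefix is None and new_prefix == ['<s>']
```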

swift/llm/template/constant.py
Lines changed: 1 addition & 0 deletions

@@ -85,6 +85,7 @@ class LLMTemplateType:
     moonlight = 'moonlight'
     mimo_rl = 'mimo_rl'
     dots1 = 'dots1'
+    hunyuan = 'hunyuan'

     aya = 'aya'
     c4ai = 'c4ai'

swift/llm/template/template/llm.py
Lines changed: 10 additions & 0 deletions

@@ -289,3 +289,13 @@ class TeleChatTemplateMeta(TemplateMeta):
         suffix=['<|endofresponse|>'],
         default_system='You are a helpful assistant.',
     ))
+
+register_template(
+    TemplateMeta(
+        LLMTemplateType.hunyuan,
+        prefix=['<|startoftext|>'],
+        system_prefix=['<|startoftext|>{{SYSTEM}}<|extra_4|>'],
+        prompt=['{{QUERY}}<|extra_0|>'],
+        chat_sep=['<|eos|><|startoftext|>'],
+        suffix=['<|eos|>'],
+    ))
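Reading the fields together, a two-round conversation should serialize as follows. This is hand-assembled from the `TemplateMeta` above with illustrative messages, not output captured from ms-swift:

```python
# Hand-assembled from the TemplateMeta fields above (illustrative messages).
system, q1, a1, q2, a2 = 'You are a helpful assistant.', 'Hi!', 'Hello!', 'Who are you?', 'I am Hunyuan.'
rendered = (
    f'<|startoftext|>{system}<|extra_4|>'  # system_prefix
    f'{q1}<|extra_0|>{a1}'                 # prompt + first response
    '<|eos|><|startoftext|>'               # chat_sep between rounds
    f'{q2}<|extra_0|>{a2}'                 # prompt + final response
    '<|eos|>'                              # suffix
)
print(rendered)
```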

swift/llm/template/template_meta.py
Lines changed: 1 addition & 1 deletion

@@ -82,8 +82,8 @@ def __post_init__(self):

         self.is_post_system = self._has_system(self.prompt)  # mistral_nemo
         if self.is_post_system:
-            self.prompt = [context for context in self.prompt if '{{SYSTEM}}' not in context]
             self.system_prompt = self.prompt
+            self.prompt = [context for context in self.prompt if '{{SYSTEM}}' not in context]

         if self.system_prefix is None and not self.is_post_system:
             self.support_system = False
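The swap matters because the list comprehension drops every prompt element containing `{{SYSTEM}}`: under the old order, `system_prompt` was copied from the already-stripped prompt and so lost the placeholder. A pure-Python sketch with a hypothetical post-system prompt:

```python
# Hypothetical post-system prompt, split into context elements as TemplateMeta does.
prompt = ['[INST] ', '{{SYSTEM}}\n\n', '{{QUERY}} [/INST]']

# Old order: strip first, then copy; the placeholder is already gone.
old_prompt = [c for c in prompt if '{{SYSTEM}}' not in c]
old_system_prompt = old_prompt  # ['[INST] ', '{{QUERY}} [/INST]']: {{SYSTEM}} lost

# New order: copy first, then strip; system_prompt keeps the placeholder.
new_system_prompt = prompt
new_prompt = [c for c in prompt if '{{SYSTEM}}' not in c]
assert '{{SYSTEM}}\n\n' in new_system_prompt and '{{SYSTEM}}\n\n' not in new_prompt
```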

swift/megatron/model/config.py
Lines changed: 5 additions & 2 deletions

@@ -24,7 +24,7 @@
     # moe
     'moe_ffn_hidden_size': ['moe_intermediate_size'],
     'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'],
-    'moe_router_topk': ['num_experts_per_tok', 'n_group'],
+    'moe_router_topk': ['num_experts_per_tok', 'n_group', 'moe_topk'],
     'num_experts': ['num_experts', 'n_routed_experts'],
     'moe_router_pre_softmax': ['norm_topk_prob'],
     'moe_aux_loss_coeff': ['router_aux_loss_coef'],
@@ -35,11 +35,12 @@
     'qk_head_dim': ['qk_nope_head_dim'],
     'qk_pos_emb_head_dim': ['qk_rope_head_dim'],
     'moe_router_topk_scaling_factor': ['routed_scaling_factor'],
+    'qk_layernorm': ['use_qk_norm'],
     # other
     'original_max_position_embeddings': ['original_max_position_embeddings'],
     'partial_rotary_factor': ['partial_rotary_factor'],
     'first_k_dense_replace': ['first_k_dense_replace'],
-    'n_shared_experts': ['n_shared_experts']
+    'n_shared_experts': ['n_shared_experts', 'num_shared_expert'],
 }


@@ -49,6 +50,8 @@ def convert_hf_config(config) -> Dict[str, Any]:
         for hf_k in hf_keys:
             if hasattr(config, hf_k):
                 hf_v = getattr(config, hf_k)
+                if hf_v is None:
+                    continue
                 if k == 'rotary_base':
                     megatron_config[k] = int(hf_v)
                 elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}:
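Each row of the alias table maps one megatron-core argument to the HF config attribute names it may appear under; Hunyuan's config uses `moe_topk`, `use_qk_norm`, and `num_shared_expert`, hence the new aliases, and the `None` guard stops a present-but-null alias from clobbering a value found under another name. A condensed sketch of the resolution loop with illustrative values (not the real Hunyuan-A13B config):

```python
# Condensed restatement of the alias resolution above; values are illustrative.
from types import SimpleNamespace

config_mapping = {
    'moe_router_topk': ['num_experts_per_tok', 'n_group', 'moe_topk'],
    'n_shared_experts': ['n_shared_experts', 'num_shared_expert'],
    'qk_layernorm': ['use_qk_norm'],
}
hf_config = SimpleNamespace(num_experts_per_tok=None, moe_topk=8, num_shared_expert=1, use_qk_norm=True)

megatron_config = {}
for k, hf_keys in config_mapping.items():
    for hf_k in hf_keys:
        if hasattr(hf_config, hf_k):
            hf_v = getattr(hf_config, hf_k)
            if hf_v is None:  # the new guard: null aliases no longer win
                continue
            megatron_config[k] = hf_v
print(megatron_config)  # {'moe_router_topk': 8, 'n_shared_experts': 1, 'qk_layernorm': True}
```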

swift/megatron/model/gpt/config.py
Lines changed: 10 additions & 0 deletions

@@ -30,4 +30,14 @@ def convert_gpt_hf_config(config) -> Dict[str, Any]:
     if res.get('moe_router_score_function', 'softmax') == 'sigmoid':
         res['moe_router_enable_expert_bias'] = True
     res['moe_layer_freq'] = f'[0]*{first_k_dense_replace}+[1]*{res["num_layers"] - first_k_dense_replace}'
+    if architectures == 'HunYuanMoEV1ForCausalLM':
+        # Since HunYuan's attention applies RoPE before the q/k layernorms, which is
+        # incompatible with megatron-core, support for that ordering is not provided here.
+        res['n_shared_experts'] = n_shared_experts
+        for key in ['moe_ffn_hidden_size', 'n_shared_experts', 'moe_router_topk']:
+            val = res.get(key)
+            if isinstance(val, list) and val and min(val) == max(val):
+                res[key] = val[0]
+        n_shared_experts = res.pop('n_shared_experts')
+        res['moe_shared_expert_intermediate_size'] = n_shared_experts * res['moe_ffn_hidden_size']
     return res
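Hunyuan's HF config stores several MoE fields as per-layer lists; when every layer agrees, the branch above collapses them to scalars and then derives the shared-expert FFN width. A worked example with illustrative values (not necessarily the real Hunyuan-A13B config):

```python
# Worked example of the Hunyuan branch above; the per-layer values are assumptions.
res = {
    'moe_ffn_hidden_size': [3072] * 32,  # per-layer list in the HF config
    'n_shared_experts': [1] * 32,
    'moe_router_topk': [8] * 32,
}
for key in ['moe_ffn_hidden_size', 'n_shared_experts', 'moe_router_topk']:
    val = res.get(key)
    if isinstance(val, list) and val and min(val) == max(val):
        res[key] = val[0]  # uniform across layers -> collapse to a scalar

n_shared_experts = res.pop('n_shared_experts')
res['moe_shared_expert_intermediate_size'] = n_shared_experts * res['moe_ffn_hidden_size']
print(res)  # {'moe_ffn_hidden_size': 3072, 'moe_router_topk': 8, 'moe_shared_expert_intermediate_size': 3072}
```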

swift/megatron/model/gpt/hf2mcore.py
Lines changed: 15 additions & 5 deletions

@@ -39,8 +39,10 @@ def set_attn_state(args, mg_attn, hf_attn):
         ],
                   dim=1).reshape(-1))
     if args.qk_layernorm:
-        mg_attn.q_layernorm.weight.data.copy_(hf_attn.q_norm.weight)
-        mg_attn.k_layernorm.weight.data.copy_(hf_attn.k_norm.weight)
+        q_norm = hf_attn.query_layernorm if hasattr(hf_attn, 'query_layernorm') else hf_attn.q_norm
+        k_norm = hf_attn.key_layernorm if hasattr(hf_attn, 'key_layernorm') else hf_attn.k_norm
+        mg_attn.q_layernorm.weight.data.copy_(q_norm.weight)
+        mg_attn.k_layernorm.weight.data.copy_(k_norm.weight)


 def _set_mlp_state(mg_mlp, hf_mlp):
@@ -52,11 +54,19 @@ def _set_mlp_state(mg_mlp, hf_mlp):


 def _set_moe_state(args, mg_mlp, hf_mlp):
-    mg_mlp.router.weight.data.copy_(hf_mlp.gate.weight)
+    hf_gate = hf_mlp.gate
+    if hasattr(hf_gate, 'wg'):
+        hf_gate = hf_gate.wg
+    mg_mlp.router.weight.data.copy_(hf_gate.weight)
     if args.moe_router_enable_expert_bias:
-        mg_mlp.router.expert_bias.data.copy_(hf_mlp.gate.e_score_correction_bias)
+        mg_mlp.router.expert_bias.data.copy_(hf_gate.e_score_correction_bias)
     if mg_mlp.shared_experts is not None:
-        hf_shared_expert = hf_mlp.shared_expert if hasattr(hf_mlp, 'shared_expert') else hf_mlp.shared_experts
+        if hasattr(hf_mlp, 'shared_experts'):
+            hf_shared_expert = hf_mlp.shared_experts
+        elif hasattr(hf_mlp, 'shared_mlp'):
+            hf_shared_expert = hf_mlp.shared_mlp
+        else:
+            hf_shared_expert = hf_mlp.shared_expert
         _set_mlp_state(mg_mlp.shared_experts, hf_shared_expert)
         if mg_mlp.shared_experts.gate_weight is not None:
             mg_mlp.shared_experts.gate_weight.data.copy_(hf_mlp.shared_expert_gate.weight)
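Hunyuan departs from the attribute names the converter previously assumed: `query_layernorm`/`key_layernorm` instead of `q_norm`/`k_norm`, a router weight nested under `gate.wg`, and `shared_mlp` for the shared expert; hence the `hasattr` probing above, which reappears in the mcore2hf direction below. The same idea as a small helper, a hypothetical refactor rather than code from this commit:

```python
def first_attr(obj, *names):
    # Return the first attribute that exists; mirrors the hasattr probing above.
    for name in names:
        if hasattr(obj, name):
            return getattr(obj, name)
    raise AttributeError(f'none of {names} found on {type(obj).__name__}')

# q_norm = first_attr(hf_attn, 'query_layernorm', 'q_norm')
# hf_shared_expert = first_attr(hf_mlp, 'shared_experts', 'shared_mlp', 'shared_expert')
```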

swift/megatron/model/gpt/mcore2hf.py
Lines changed: 15 additions & 5 deletions

@@ -34,16 +34,26 @@ def set_attn_state(args, mg_attn, hf_attn):
     hf_attn.v_proj.bias.data.copy_(mg_attn_bias[:, -kv_dim:].reshape(-1))

     if args.qk_layernorm:
-        hf_attn.q_norm.weight.data.copy_(mg_attn.q_layernorm.weight)
-        hf_attn.k_norm.weight.data.copy_(mg_attn.k_layernorm.weight)
+        q_norm = hf_attn.query_layernorm if hasattr(hf_attn, 'query_layernorm') else hf_attn.q_norm
+        k_norm = hf_attn.key_layernorm if hasattr(hf_attn, 'key_layernorm') else hf_attn.k_norm
+        q_norm.weight.data.copy_(mg_attn.q_layernorm.weight)
+        k_norm.weight.data.copy_(mg_attn.k_layernorm.weight)


 def _set_moe_state(args, mg_mlp, hf_mlp):
-    hf_mlp.gate.weight.data.copy_(mg_mlp.router.weight)
+    hf_gate = hf_mlp.gate
+    if hasattr(hf_gate, 'wg'):
+        hf_gate = hf_gate.wg
+    hf_gate.weight.data.copy_(mg_mlp.router.weight)
     if args.moe_router_enable_expert_bias:
-        hf_mlp.gate.e_score_correction_bias.data.copy_(mg_mlp.router.expert_bias)
+        hf_gate.e_score_correction_bias.data.copy_(mg_mlp.router.expert_bias)
     if mg_mlp.shared_experts is not None:
-        hf_shared_expert = hf_mlp.shared_expert if hasattr(hf_mlp, 'shared_expert') else hf_mlp.shared_experts
+        if hasattr(hf_mlp, 'shared_experts'):
+            hf_shared_expert = hf_mlp.shared_experts
+        elif hasattr(hf_mlp, 'shared_mlp'):
+            hf_shared_expert = hf_mlp.shared_mlp
+        else:
+            hf_shared_expert = hf_mlp.shared_expert
         _set_mlp_state(mg_mlp.shared_experts, hf_shared_expert)
         if mg_mlp.shared_experts.gate_weight is not None:
             hf_mlp.shared_expert_gate.weight.data.copy_(mg_mlp.shared_experts.gate_weight)

swift/megatron/model/rope.py
Lines changed: 43 additions & 4 deletions

@@ -1,4 +1,8 @@
+from typing import Any, Dict, Optional
+
+import torch
 from megatron.training import get_args
+from transformers import PretrainedConfig

 from swift.utils import get_logger

@@ -30,11 +34,22 @@ def _get_dummy_config(args):
     return dummy_config


+EXTENDED_ROPE_INIT_FUNCTIONS = {}
+
+
+def _get_rope_type(rope_scaling: Dict[str, Any]):
+    rope_type = rope_scaling['rope_type']
+    if rope_type == 'dynamic' and rope_scaling.get('alpha') is not None:
+        rope_type = 'dynamic_alpha'
+    return rope_type
+
+
 def get_rope_inv_freq(seq_len=None):
     from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
     args = get_args()
+    ROPE_INIT_FUNCTIONS.update(EXTENDED_ROPE_INIT_FUNCTIONS)
     dummy_config = _get_dummy_config(args)
-    rope_init_fn = ROPE_INIT_FUNCTIONS[args.rope_scaling['rope_type']]
+    rope_init_fn = ROPE_INIT_FUNCTIONS[_get_rope_type(args.rope_scaling)]
     inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len)
     if attention_scaling is None:
         attention_scaling = 1.
@@ -49,7 +64,7 @@ def longrope_frequency_update(args, model, inv_freq, seq_len: int):
     original_max_position_embeddings = args.max_position_embeddings

     if not hasattr(model, 'long_inv_freq'):
-        model.long_inv_freq, _ = get_rope_inv_freq(inv_freq.device, seq_len=original_max_position_embeddings + 1)
+        model.long_inv_freq, _ = get_rope_inv_freq(seq_len=original_max_position_embeddings + 1)
         model.original_inv_freq = inv_freq.clone()

     if seq_len > original_max_position_embeddings:
@@ -66,7 +81,7 @@ def dynamic_frequency_update(args, model, inv_freq, seq_len: int):
         model.original_inv_freq = inv_freq.clone()
     attention_scaling = None
     if seq_len > model.max_seq_len_cached:  # growth
-        new_inv_freq, attention_scaling = get_rope_inv_freq(inv_freq.device, seq_len=seq_len)
+        new_inv_freq, attention_scaling = get_rope_inv_freq(seq_len=seq_len)
         inv_freq.data.copy_(new_inv_freq)
         model.max_seq_len_cached = seq_len

@@ -78,10 +93,34 @@ def dynamic_frequency_update(args, model, inv_freq, seq_len: int):

 def dynamic_rope_update(model, inv_freq, seq_len: int):
     args = get_args()
-    rope_type = args.rope_scaling['rope_type']
+    rope_type = _get_rope_type(args.rope_scaling)
     attention_scaling = None
     if rope_type == 'dynamic':
         attention_scaling = dynamic_frequency_update(args, model, inv_freq, seq_len)
     elif rope_type == 'longrope':
         attention_scaling = longrope_frequency_update(args, model, inv_freq, seq_len)
     return attention_scaling
+
+
+def _compute_dynamic_alpha_ntk_parameters(
+        config: Optional[PretrainedConfig] = None,
+        device: Optional['torch.device'] = None,
+        seq_len: Optional[int] = None,
+        **rope_kwargs,
+) -> tuple['torch.Tensor', float]:
+    # Code borrowed from Tencent-Hunyuan/Hunyuan-A13B-Instruct
+    base = config.rope_theta
+    partial_rotary_factor = config.partial_rotary_factor if hasattr(config, 'partial_rotary_factor') else 1.0
+    head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
+    dim = int(head_dim * partial_rotary_factor)
+    alpha = config.rope_scaling['alpha']
+
+    attention_factor = 1.0  # Unused in this type of RoPE
+
+    # Compute the inverse frequencies
+    base = base * alpha**(dim / (dim - 2))
+    inv_freq = 1.0 / (base**(torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
+    return inv_freq, attention_factor
+
+
+EXTENDED_ROPE_INIT_FUNCTIONS['dynamic_alpha'] = _compute_dynamic_alpha_ntk_parameters
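Hunyuan ships a `rope_scaling` of type `dynamic` that carries an `alpha` field instead of the usual `factor`, so the adapter registers the NTK-alpha rule under a synthetic `dynamic_alpha` key. The rule rescales the RoPE base to base * alpha^(dim/(dim-2)) and derives the inverse frequencies from the scaled base. A quick numeric check; base=10000, head dim=128, and alpha=1000 are illustrative assumptions, not values read from the Hunyuan config:

```python
# Numeric sanity check of the NTK-alpha rule above (illustrative parameters).
import torch

base, dim, alpha = 10000.0, 128, 1000.0
scaled_base = base * alpha**(dim / (dim - 2))  # ~1.1e+07: lower frequencies, longer usable context
inv_freq = 1.0 / (scaled_base**(torch.arange(0, dim, 2, dtype=torch.float) / dim))
print(f'{scaled_base:.3e}')
print(inv_freq[0], inv_freq[-1])  # inv_freq[0] is always 1.0; the tail shrinks with the base
```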

swift/megatron/utils/convert.py
Lines changed: 8 additions & 3 deletions

@@ -55,7 +55,7 @@ def _find_modules(model, recurse: bool = True):


 @contextmanager
-def _model_cpu_forward_context(modules, torch_dtype=None, device=None):
+def _model_cpu_forward_context(modules, torch_dtype=None, device=None, share_embedding: bool = False):
     origin_torch_dtype = next(modules[0].parameters()).dtype

     def _to_cuda_hook(module, args):
@@ -65,6 +65,8 @@ def _to_cuda_hook(module, args):
             module.to(torch_dtype)

     def _to_cpu_hook(module, args, output):
+        if share_embedding and module is modules[0]:
+            return
         module.to('cpu')
         if torch_dtype is not None:
             module.to(origin_torch_dtype)
@@ -89,9 +91,11 @@
     input_ids = torch.tensor(input_ids)[None].to('cuda')

     HfConfigFactory.set_model_config_attr(hf_model, 'use_cache', False)
+    share_embedding = mg_model.share_embeddings_and_output_weights
     hf_modules = _find_modules(hf_model)
-    with torch.inference_mode(), _model_cpu_forward_context(hf_modules, torch_dtype):
+    with torch.inference_mode(), _model_cpu_forward_context(hf_modules, torch_dtype, share_embedding=share_embedding):
         hf_logits = hf_model(input_ids).logits
+    hf_model = hf_model.to('cpu')

     attention_mask, _, position_ids = get_ltor_masks_and_position_ids(input_ids, -100, True, True, True)
     packed_seq_params = None
@@ -102,7 +106,8 @@
     # packed_seq_params = get_packed_seq_params(position_ids)
     # attention_mask = None
     mg_modules = _find_modules(mg_model)
-    with torch.inference_mode(), _model_cpu_forward_context(mg_modules, mg_torch_dtype, 'cuda'):
+    with torch.inference_mode(), _model_cpu_forward_context(
+            mg_modules, mg_torch_dtype, 'cuda', share_embedding=share_embedding):
         mg_logits = mg_model(
             input_ids=input_ids,
             attention_mask=attention_mask,
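The precision check streams modules through the GPU one at a time, offloading each back to CPU right after its forward. When the input embedding and the output head share one weight (`share_embeddings_and_output_weights`), offloading the first module after it runs would presumably drag the tied weight to CPU before the head reads it, so that module is now left resident and the whole model is moved back to CPU afterwards. A condensed sketch of the hook pair; the helper name is hypothetical, and the hook signatures follow `torch.nn.Module.register_forward_pre_hook` / `register_forward_hook`:

```python
# Condensed sketch of the offload pattern used above (hypothetical helper name).
import torch.nn as nn

def make_offload_hooks(modules: list, share_embedding: bool):
    def to_cuda_hook(module: nn.Module, args):
        module.to('cuda')  # bring the module in just before its forward

    def to_cpu_hook(module: nn.Module, args, output):
        if share_embedding and module is modules[0]:
            return  # keep the tied embedding resident for the output head
        module.to('cpu')

    return to_cuda_hook, to_cpu_hook
```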

tests/megatron/test_align/test_llm.py
Lines changed: 6 additions & 1 deletion

@@ -116,6 +116,10 @@ def test_kimi_dev():
     _test_model('moonshotai/Kimi-Dev-72B')


+def test_hunyuan():
+    _test_model('Tencent-Hunyuan/Hunyuan-A13B-Instruct')
+
+
 if __name__ == '__main__':
     # test_qwen2()
     # test_llama2()
@@ -137,4 +141,5 @@
     # test_deepseek_v2()
     # test_deepseek_moe()
     # test_dots()
-    test_kimi_dev()
+    # test_kimi_dev()
+    test_hunyuan()

tests/test_align/test_template/test_llm.py
Lines changed: 10 additions & 1 deletion

@@ -430,6 +430,14 @@ def test_kimi_dev():
     assert res == res2, f'res: {res}, res2: {res2}'


+def test_hunyuan():
+    pt_engine = PtEngine('Tencent-Hunyuan/Hunyuan-A13B-Instruct')
+    res = _infer_model(pt_engine)
+    pt_engine.default_template.template_backend = 'jinja'
+    res2 = _infer_model(pt_engine)
+    assert res == res2, f'res: {res}, res2: {res2}'
+
+
 if __name__ == '__main__':
     from swift.llm import PtEngine, RequestConfig
     from swift.utils import get_logger, seed_everything
@@ -471,4 +479,5 @@
     # test_mimo()
     # test_minicpm()
     # test_minimax()
-    test_kimi_dev()
+    # test_kimi_dev()
+    test_hunyuan()
