
Commit cf487cd

Authored by yjc9696, pridejcyang, and mingjihan (Tencent)
HunYuan opensource (huggingface#39606)
* merge opensource_hunyuan
* add head_dim
* fix assertion error
* fix seen_tokens
* ready_for_upstream (merge request !17): squash merge branch 'ready_for_upstream' into 'main'
* fix configuration type & docstring
* fix style
* ready_for_upstream (merge request !18): squash merge branch 'ready_for_upstream' into 'main'
* add doc
* fix test code
* fix configuration type & docstring
* rename base model
* remove assert
* update
* remove tiktoken
* update
* fix moe and code style (#3)
* update
* fix format
* update
* revert makefile
* fix moe config
* fix numel()
* remove prepare_inputs_for_generation
* fix kv_seq_len
* add docs/toctree
* remove unused parameter & add licence
* add licence
* remove unused parameter
* fix code
* dense modular: update import; fix; use MistralModel; fix qk-norm; add sliding_window; make style; dense done; hunyuan moe: fix import; fix modular; fixup
* update model path
* fix mlp_bias
* fix modular
* Fix modeling (huggingface#5): fix attention; use LlamaModel; fix code
* Fix qk (huggingface#6): fix qk_norm; fix; fix modular
* Fix moe (huggingface#7): fix some moe code; fix einsum; try top1; use top1
* Fix rotary (huggingface#8): fix rotary; fix modeling; fix modular; fix test code; remove A13B unit test
* Fix moe v1 (huggingface#9): fix moe & gate
* Fix gate norm (huggingface#10): add norm_topk_prob
* Fix testcase (huggingface#11): fix & skip test
* Fix testcase (huggingface#12): skip testcase
* Fix norm topk (huggingface#13): hardcode norm_topk_prob; fix testcase

Co-authored-by: pridejcyang <[email protected]>
Co-authored-by: Mingji Han <[email protected]>
1 parent 8365f70 commit cf487cd

18 files changed (+2342, −0 lines)

docs/source/en/_toctree.yml

Lines changed: 6 additions & 0 deletions
@@ -529,6 +529,12 @@
       title: Helium
     - local: model_doc/herbert
       title: HerBERT
+    - local: model_doc/hgnet_v2
+      title: HGNet-V2
+    - local: model_doc/hunyuan_v1_dense
+      title: HunYuanDenseV1
+    - local: model_doc/hunyuan_v1_moe
+      title: HunYuanMoEV1
     - local: model_doc/ibert
       title: I-BERT
     - local: model_doc/jamba
docs/source/en/model_doc/hunyuan_v1_dense.md

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+<!--Copyright (C) 2024 THL A29 Limited, a Tencent company and The HuggingFace Inc. team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+
+⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
+rendered properly in your Markdown viewer.
+
+-->
+
+# HunYuanDenseV1
+
+## Overview
+
+To be released with the official model launch.
+
+### Model Details
+
+To be released with the official model launch.
+
+## Usage tips
+
+To be released with the official model launch.
+
+## HunYuanDenseV1Config
+
+[[autodoc]] HunYuanDenseV1Config
+
+## HunYuanDenseV1Model
+
+[[autodoc]] HunYuanDenseV1Model
+    - forward
+
+## HunYuanDenseV1ForCausalLM
+
+[[autodoc]] HunYuanDenseV1ForCausalLM
+    - forward
+
+## HunYuanDenseV1ForSequenceClassification
+
+[[autodoc]] HunYuanDenseV1ForSequenceClassification
+    - forward
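Since the "Usage tips" section above is still a placeholder, here is a minimal text-generation sketch. It assumes the `tencent/Hunyuan-7B-Instruct` checkpoint named in the configuration docstring and the standard `Auto*` loading path once this commit is merged; it is illustrative, not the official usage example.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint name taken from the config docstring; availability is assumed, not guaranteed.
model_id = "tencent/Hunyuan-7B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

inputs = tokenizer("Write a haiku about mountains.", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```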
docs/source/en/model_doc/hunyuan_v1_moe.md

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
+<!--Copyright (C) 2024 THL A29 Limited, a Tencent company and The HuggingFace Inc. team. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+specific language governing permissions and limitations under the License.
+
+⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
+rendered properly in your Markdown viewer.
+
+-->
+
+# HunYuanMoEV1
+
+## Overview
+
+To be released with the official model launch.
+
+### Model Details
+
+To be released with the official model launch.
+
+## Usage tips
+
+To be released with the official model launch.
+
+## HunYuanMoEV1Config
+
+[[autodoc]] HunYuanMoEV1Config
+
+## HunYuanMoEV1Model
+
+[[autodoc]] HunYuanMoEV1Model
+    - forward
+
+## HunYuanMoEV1ForCausalLM
+
+[[autodoc]] HunYuanMoEV1ForCausalLM
+    - forward
+
+## HunYuanMoEV1ForSequenceClassification
+
+[[autodoc]] HunYuanMoEV1ForSequenceClassification
+    - forward
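The MoE variant also registers a sequence-classification head. As a hedged sketch of how that would be loaded (the checkpoint path below is a placeholder, not a released model name):

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Placeholder path — substitute the real HunYuan MoE checkpoint once it is published.
model_id = "path/to/hunyuan-moe-checkpoint"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)

inputs = tokenizer("This movie was great!", return_tensors="pt")
logits = model(**inputs).logits
print(logits.argmax(dim=-1))
```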

src/transformers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -158,6 +158,8 @@
 from .hgnet_v2 import *
 from .hiera import *
 from .hubert import *
+from .hunyuan_v1_dense import *
+from .hunyuan_v1_moe import *
 from .ibert import *
 from .idefics import *
 from .idefics2 import *

src/transformers/models/auto/configuration_auto.py

Lines changed: 4 additions & 0 deletions
@@ -193,6 +193,8 @@
         ("hgnet_v2", "HGNetV2Config"),
         ("hiera", "HieraConfig"),
         ("hubert", "HubertConfig"),
+        ("hunyuan_v1_dense", "HunYuanDenseV1Config"),
+        ("hunyuan_v1_moe", "HunYuanMoEV1Config"),
         ("ibert", "IBertConfig"),
         ("idefics", "IdeficsConfig"),
         ("idefics2", "Idefics2Config"),
@@ -613,6 +615,8 @@
         ("hgnet_v2", "HGNet-V2"),
         ("hiera", "Hiera"),
         ("hubert", "Hubert"),
+        ("hunyuan_v1_dense", "HunYuanDenseV1"),
+        ("hunyuan_v1_moe", "HunYuanMoeV1"),
         ("ibert", "I-BERT"),
         ("idefics", "IDEFICS"),
         ("idefics2", "Idefics2"),

src/transformers/models/auto/modeling_auto.py

Lines changed: 6 additions & 0 deletions
@@ -193,6 +193,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("hgnet_v2", "HGNetV2Backbone"),
         ("hiera", "HieraModel"),
         ("hubert", "HubertModel"),
+        ("hunyuan_v1_dense", "HunYuanDenseV1Model"),
+        ("hunyuan_v1_moe", "HunYuanMoEV1Model"),
         ("ibert", "IBertModel"),
         ("idefics", "IdeficsModel"),
         ("idefics2", "Idefics2Model"),
@@ -664,6 +666,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("granitemoehybrid", "GraniteMoeHybridForCausalLM"),
         ("granitemoeshared", "GraniteMoeSharedForCausalLM"),
         ("helium", "HeliumForCausalLM"),
+        ("hunyuan_v1_dense", "HunYuanDenseV1ForCausalLM"),
+        ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"),
         ("jamba", "JambaForCausalLM"),
         ("jetmoe", "JetMoeForCausalLM"),
         ("lfm2", "Lfm2ForCausalLM"),
@@ -1209,6 +1213,8 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
         ("gpt_oss", "GptOssForSequenceClassification"),
         ("gptj", "GPTJForSequenceClassification"),
         ("helium", "HeliumForSequenceClassification"),
+        ("hunyuan_v1_dense", "HunYuanDenseV1ForSequenceClassification"),
+        ("hunyuan_v1_moe", "HunYuanMoEV1ForSequenceClassification"),
         ("ibert", "IBertForSequenceClassification"),
         ("jamba", "JambaForSequenceClassification"),
         ("jetmoe", "JetMoeForSequenceClassification"),
src/transformers/models/hunyuan_v1_dense/__init__.py

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+    from .configuration_hunyuan_v1_dense import *
+    from .modeling_hunyuan_v1_dense import *
+    from .tokenization_hy import *
+else:
+    import sys
+
+    _file = globals()["__file__"]
+    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
src/transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py

Lines changed: 189 additions & 0 deletions

@@ -0,0 +1,189 @@
+# coding=utf-8
+# Copyright (C) 2025 THL A29 Limited, a Tencent company and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""HunYuanDenseV1 model configuration"""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class HunYuanDenseV1Config(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`HunYuanDenseV1Model`]. It is used to instantiate a
+    HunYuan model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield a configuration similar to that of the HunYuan-7B model, e.g.
+    [tencent/Hunyuan-7B-Instruct](https://huggingface.co/tencent/Hunyuan-7B-Instruct).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 290943):
+            Vocabulary size of the HunYuan model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`HunYuanDenseV1Model`].
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 11008):
+            Dimension of the MLP representations or shared MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+            `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*, defaults to 0):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            End of stream token id.
+        eod_token_id (`int`, *optional*, defaults to 3):
+            Token ID representing the end-of-document marker. Used to indicate the termination of a text sequence.
+            Example: in multi-document processing, this token helps the model distinguish between separate documents.
+        pretraining_tp (`int`, *optional*, defaults to 1):
+            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
+            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
+            issue](https://github.com/pytorch/pytorch/issues/76232).
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie weight embeddings.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+            these scaling strategies behave:
+            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+            experimental feature, subject to breaking API changes in future versions.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        head_dim (`int`, *optional*):
+            The attention head dimension.
+    """
+
+    model_type = "hunyuan_v1_dense"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=290943,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        rms_norm_eps=1e-5,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        eod_token_id=3,
+        pretraining_tp=1,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        head_dim=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.head_dim = head_dim
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.pretraining_tp = pretraining_tp
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        # self._rope_scaling_validation()  # TODO: Need validation?
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+    def _rope_scaling_validation(self):
+        """
+        Validate the `rope_scaling` configuration.
+        """
+        if self.rope_scaling is None:
+            return
+
+        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+            raise ValueError(
+                "`rope_scaling` must be a dictionary with two fields, `type` and `factor` or `type` and `alpha`, "
+                f"got {self.rope_scaling}"
+            )
+        rope_scaling_type = self.rope_scaling.get("type", None)
+        rope_scaling_factor = self.rope_scaling.get("factor", None)
+        rope_scaling_alpha = self.rope_scaling.get("alpha", None)
+        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+            raise ValueError(
+                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+            )
+        if rope_scaling_factor is None and rope_scaling_alpha is None:
+            raise ValueError("`rope_scaling` must set either the `factor` or the `alpha` field, got neither")
+        if rope_scaling_factor is not None:
+            if not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
+                raise ValueError(f"`rope_scaling`'s factor field must be a float > 1.0, got {rope_scaling_factor}")
+        if rope_scaling_alpha is not None:
+            if not isinstance(rope_scaling_alpha, float) or rope_scaling_alpha <= 1.0:
+                raise ValueError(f"`rope_scaling`'s alpha field must be a float > 1.0, got {rope_scaling_alpha}")
+
+
+__all__ = ["HunYuanDenseV1Config"]
