# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ElasticBERT model configuration """

from fastNLP.core.log import logger
from fastNLP.transformers.torch.configuration_utils import PretrainedConfig

__all__ = [
    "ELASTICBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
    "ElasticBertConfig",
]

ELASTICBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "elasticbert-base": "https://huggingface.co/fnlp/elasticbert-base/resolve/main/config.json",
    "elasticbert-large": "https://huggingface.co/fnlp/elasticbert-large/resolve/main/config.json",
    "elasticbert-base-chinese": "https://huggingface.co/fnlp/elasticbert-chinese-base/resolve/main/config.json",
}
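
# The shortcut names above map to hosted config files. A config can presumably be
# loaded either via such a shortcut or directly from the Hub repo name, e.g.
# ``ElasticBertConfig.from_pretrained("fnlp/elasticbert-base")``; ``from_pretrained``
# is inherited from ``PretrainedConfig``.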


class ElasticBertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`ElasticBertModel`.

    Args:
        max_output_layers (:obj:`int`, defaults to 12):
            The maximum number of classification layers.
        num_output_layers (:obj:`int`, defaults to 12):
            The number of classification layers, i.e. how many classification heads are
            attached to the encoder. It is 1 in static usage, and equal to
            ``num_hidden_layers`` in dynamic usage.
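
    Examples (a minimal sketch; the exact import path is an assumption based on this
    module living under ``fastNLP.transformers.torch``)::

        >>> from fastNLP.transformers.torch import ElasticBertConfig
        >>> # Static usage: a single classification head on top of the encoder.
        >>> static_config = ElasticBertConfig(num_output_layers=1)
        >>> # Dynamic usage: one classification head per hidden layer.
        >>> dynamic_config = ElasticBertConfig(num_hidden_layers=12, num_output_layers=12)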
    """

    model_type = "elasticbert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        max_output_layers=12,
        num_output_layers=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        gradient_checkpointing=False,
        position_embedding_type="absolute",
        use_cache=True,
        **kwargs
    ):
        # Forward pad_token_id and any remaining kwargs to the base PretrainedConfig.
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_output_layers = max_output_layers
        self.num_output_layers = num_output_layers
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.gradient_checkpointing = gradient_checkpointing
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache