diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index e2ababc3e5f..327aeb482e0 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -216,6 +216,9 @@ class TransformerConfig(ModelParallelConfig):
     """Number of SMs to use for HybridEP. In pure NVL scenarios, 16 SMs can generally achieve good bandwidth."""
 
+    untie_embeddings_and_output_weights: bool = False
+    """If True, untie the input word embedding matrix from the output layer's weight matrix. By default (False) the two matrices are tied, i.e. shared."""
+
     ####################
     # initialization
     ####################
@@ -787,7 +790,6 @@ class TransformerConfig(ModelParallelConfig):
     lora_out_init_method: Optional[str] = None
     """Lora b init method"""
-
     ####################
     # TE_FL
     ####################
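
For reference, here is a minimal sketch of the semantics the new flag encodes. `TinyLM` and its wiring are illustrative only, not Megatron's model code: when the flag is False the output projection reuses the embedding matrix as its weight; when True the projection gets its own parameters.

```python
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    """Toy LM illustrating tied vs. untied embedding/output weights (hypothetical, not Megatron code)."""

    def __init__(self, vocab_size: int, hidden_size: int, untie_embeddings_and_output_weights: bool = False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, vocab_size, bias=False)
        if not untie_embeddings_and_output_weights:
            # Tied (the default): logits are computed against the same
            # [vocab_size, hidden_size] matrix used for the input lookup.
            self.output_layer.weight = self.embedding.weight

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.embedding(token_ids)  # stand-in for the transformer stack
        return self.output_layer(hidden)


tied = TinyLM(vocab_size=8, hidden_size=4)
untied = TinyLM(vocab_size=8, hidden_size=4, untie_embeddings_and_output_weights=True)
assert tied.output_layer.weight is tied.embedding.weight
assert untied.output_layer.weight is not untied.embedding.weight
```

Keeping the default at False preserves the existing tied behavior and its parameter savings (one shared vocab_size x hidden_size matrix instead of two); setting the flag to True trades those savings for independent output-layer parameters.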