push rope back to encoder/decoder #208

Merged · 4 commits · Feb 12, 2025
Changes from all commits
eole/config/models.py: 11 changes (6 additions, 5 deletions)
@@ -259,11 +259,11 @@ class TransformerConfig(Config):

@model_validator(mode="after")
def _validate_transformer_config(self):
"""

if self.position_encoding_type == PositionEncodingType.Rotary:
if self.rope_config is None:
self.rope_config = RotaryPositionConfig()
"""

if self.add_qkvbias and "add_final_linear_bias" not in self.model_fields_set:
self.update(add_final_linear_bias=True)
return self
@@ -472,6 +472,7 @@ def default_architecture(cls, data: Any) -> Any:

def update_model_opts(self):
update_dict = {}
"""
if self.embeddings.position_encoding_type == PositionEncodingType.Rotary:
if not self.rope_config:
update_dict["rope_config"] = RotaryPositionConfig()
@@ -480,7 +481,7 @@ def update_model_opts(self):
rope_config = self.rope_config
else:
rope_config = None

"""
if self.embeddings is not None and self.embeddings.word_vec_size > 0:
update_dict["embeddings"] = {
"src_word_vec_size": self.embeddings.word_vec_size,
@@ -499,7 +500,7 @@
{
"position_encoding_type": self.embeddings.position_encoding_type,
"n_positions": self.embeddings.n_positions,
"rope_config": rope_config,
# "rope_config": rope_config,
}
)
update_dict["position_encoding_type"] = self.embeddings.position_encoding_type
@@ -513,7 +514,7 @@
{
"position_encoding_type": self.embeddings.position_encoding_type,
"n_positions": self.embeddings.n_positions,
"rope_config": rope_config,
# "rope_config": rope_config,
}
)
update_dict["position_encoding_type"] = self.embeddings.position_encoding_type
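The validator body and the `rope_config` entries above are fenced off inside docstrings and comments rather than deleted outright. As a reference for what the config layer stops doing after this change, here is a small self-contained sketch (plain dataclasses with illustrative field names, not the real eole configs) of the defaulting behaviour being disabled:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class PositionEncodingType(str, Enum):
    Rotary = "Rotary"


@dataclass
class RotaryPositionConfig:
    rotary_theta: int = 10000  # illustrative field, not taken from this diff


@dataclass
class TransformerConfigSketch:
    position_encoding_type: Optional[PositionEncodingType] = None
    rope_config: Optional[RotaryPositionConfig] = None

    def validate(self):
        # Pre-PR: the config layer auto-created a RotaryPositionConfig when the
        # position encoding was rotary. Post-PR this defaulting no longer happens
        # here; per the PR title, rope handling moves into the encoder/decoder.
        if self.position_encoding_type == PositionEncodingType.Rotary and self.rope_config is None:
            self.rope_config = RotaryPositionConfig()
        return self
```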
eole/decoders/cnn_decoder.py: 20 changes (10 additions, 10 deletions)
@@ -21,39 +21,39 @@ class CNNDecoder(DecoderBase):

def __init__(
self,
model_config,
decoder_config,
running_config=None,
with_cross_attn=False,
):
super(CNNDecoder, self).__init__()

self.cnn_kernel_width = model_config.cnn_kernel_width
self.cnn_kernel_width = decoder_config.cnn_kernel_width

# Decoder State
self.state = {}

input_size = model_config.hidden_size # we need embeddings.src_vec_size
self.linear = nn.Linear(input_size, model_config.hidden_size)
input_size = decoder_config.hidden_size # we need embeddings.src_vec_size
self.linear = nn.Linear(input_size, decoder_config.hidden_size)
self.conv_layers = nn.ModuleList(
[
GatedConv(
model_config.hidden_size,
model_config.cnn_kernel_width,
decoder_config.hidden_size,
decoder_config.cnn_kernel_width,
getattr(running_config, "dropout", [0.0])[0],
True,
)
for i in range(model_config.layers)
for i in range(decoder_config.layers)
]
)
self.attn_layers = nn.ModuleList(
[ConvMultiStepAttention(model_config.hidden_size) for i in range(model_config.layers)]
[ConvMultiStepAttention(decoder_config.hidden_size) for i in range(decoder_config.layers)]
)

@classmethod
def from_config(cls, model_config, running_config=None, with_cross_attn=False):
def from_config(cls, decoder_config, running_config=None, with_cross_attn=False):
"""Alternate constructor."""
return cls(
model_config,
decoder_config,
running_config,
with_cross_attn=False,
)
eole/decoders/decoder.py: 2 changes (1 addition, 1 deletion)
@@ -15,7 +15,7 @@ def __init__(self, attentional=True):
self.state = {}

@classmethod
def from_config(cls, model_config, running_config=None, with_cross_attn=False):
def from_config(cls, decoder_config, running_config=None, with_cross_attn=False):
"""Alternate constructor.

Subclasses should override this method.
eole/decoders/ensemble.py: 9 changes (6 additions, 3 deletions)
@@ -48,6 +48,8 @@ class EnsembleEncoder(EncoderBase):
def __init__(self, model_encoders):
super(EnsembleEncoder, self).__init__()
self.model_encoders = nn.ModuleList(model_encoders)
if hasattr(model_encoders[0], "rope"):
self.rope = model_encoders[0].rope

def forward(self, emb, pad_mask=None, **kwargs):
enc_out, enc_final_hs = zip(
@@ -77,6 +79,8 @@ def __init__(self, model_decoders):
attentional = any([dec.attentional for dec in model_decoders])
super(EnsembleDecoder, self).__init__(attentional)
self.model_decoders = model_decoders
if hasattr(model_decoders[0], "rope"):
self.rope = model_decoders[0].rope

def forward(self, emb, enc_out=None, src_len=None, step=None, **kwargs):
"""See :func:`eole.decoders.decoder.DecoderBase.forward()`."""
@@ -157,9 +161,9 @@ class EnsembleModel(EncoderDecoderModel):

def __init__(self, models, raw_probs=False):
src_emb = EnsembleSrcEmb([model.src_emb for model in models])
encoder = EnsembleEncoder(model.encoder for model in models)
encoder = EnsembleEncoder([model.encoder for model in models])
tgt_emb = EnsembleTgtEmb([model.tgt_emb for model in models])
decoder = EnsembleDecoder(model.decoder for model in models)
decoder = EnsembleDecoder([model.decoder for model in models])
hidden_size = models[0].hidden_size
super(EnsembleModel, self).__init__(
encoder=encoder,
@@ -170,7 +174,6 @@ def __init__(self, models, raw_probs=False):
)
self.generator = EnsembleGenerator([model.generator for model in models], raw_probs)
self.models = nn.ModuleList(models)
self.rope = models[0].rope


def load_test_model(config, device_id=0):
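Two of the changes above replace generator expressions with list comprehensions when building EnsembleEncoder and EnsembleDecoder. A short self-contained illustration of why indexing into the wrapped modules (for example, to pick up `.rope` from the first one) requires a list rather than a generator; this uses plain PyTorch modules, not eole code:

```python
import torch.nn as nn

models = [nn.Linear(4, 4), nn.Linear(4, 4)]

encoders = (m for m in models)      # generator expression: single-use, not indexable
# encoders[0]                       # would raise TypeError: 'generator' object is not subscriptable

encoders = [m for m in models]      # list comprehension: indexable and reusable
has_rope = hasattr(encoders[0], "rope")  # the hasattr(...) check in the diff needs this
wrapped = nn.ModuleList(encoders)        # and the modules can still be registered afterwards
```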
eole/decoders/rnn_decoder.py: 44 changes (22 additions, 22 deletions)
@@ -14,27 +14,27 @@ class RNNDecoderBase(DecoderBase):
and required by :class:`~eole.models.BaseModel`.

Args:
model_config (eole.config.DecoderConfig): full decoder config
decoder_config (eole.config.DecoderConfig): full decoder config
running_config (TrainingConfig / InferenceConfig)
"""

def __init__(
self,
model_config,
decoder_config,
running_config=None,
with_cross_attn=False,
):
super(RNNDecoderBase, self).__init__(
attentional=model_config.global_attention != "none" and model_config.global_attention is not None
attentional=decoder_config.global_attention != "none" and decoder_config.global_attention is not None
)

self.bidirectional_encoder = model_config.bidirectional_encoder
self.num_layers = model_config.layers
self.bidirectional_encoder = decoder_config.bidirectional_encoder
self.num_layers = decoder_config.layers
self.dropout = nn.Dropout(getattr(running_config, "dropout", [0.0])[0])

# Build the RNN.
self.rnn = self._build_rnn(
model_config.rnn_type,
decoder_config.rnn_type,
input_size=self._input_size,
hidden_size=self.hidden_size,
num_layers=self.num_layers,
@@ -43,35 +43,35 @@ def __init__(

# Set up the context gate.
self.context_gate = None
if model_config.context_gate is not None:
if decoder_config.context_gate is not None:
self.context_gate = context_gate_factory(
model_config.context_gate,
decoder_config.context_gate,
self._input_size,
self.hidden_size,
self.hidden_size,
self.hidden_size,
)

# Set up the standard attention.
self._coverage = model_config.coverage_attn
self._coverage = decoder_config.coverage_attn
if not self.attentional:
if self._coverage:
raise ValueError("Cannot use coverage term with no attention.")
self.attn = None
else:
self.attn = GlobalAttention(
self.hidden_size,
coverage=model_config.coverage_attn,
attn_type=model_config.global_attention,
attn_func=model_config.global_attention_function,
coverage=decoder_config.coverage_attn,
attn_type=decoder_config.global_attention,
attn_func=decoder_config.global_attention_function,
)

@classmethod
def from_config(cls, model_config, running_config=None, with_cross_attn=False):
def from_config(cls, decoder_config, running_config=None, with_cross_attn=False):
"""Alternate constructor."""
# config = opt.model.decoder # RnnDecoderConfig
return cls(
model_config,
decoder_config,
running_config=running_config,
with_cross_attn=False,
)
@@ -176,13 +176,13 @@ class StdRNNDecoder(RNNDecoderBase):

def __init__(
self,
model_config,
decoder_config,
running_config=None,
with_cross_attn=False,
):
self.hidden_size = model_config.hidden_size
self._input_size = model_config.tgt_word_vec_size
super(StdRNNDecoder, self).__init__(model_config, running_config)
self.hidden_size = decoder_config.hidden_size
self._input_size = decoder_config.tgt_word_vec_size
super(StdRNNDecoder, self).__init__(decoder_config, running_config)

def _run_forward_pass(self, emb, enc_out, src_len=None):
"""
@@ -256,13 +256,13 @@ class InputFeedRNNDecoder(RNNDecoderBase):

def __init__(
self,
model_config,
decoder_config,
running_config=None,
with_cross_attn=False,
):
self.hidden_size = model_config.hidden_size
self._input_size = model_config.tgt_word_vec_size + self.hidden_size
super(InputFeedRNNDecoder, self).__init__(model_config, running_config)
self.hidden_size = decoder_config.hidden_size
self._input_size = decoder_config.tgt_word_vec_size + self.hidden_size
super(InputFeedRNNDecoder, self).__init__(decoder_config, running_config)

def _run_forward_pass(self, emb, enc_out, src_len=None):
"""
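Beyond the model_config to decoder_config rename, the `_input_size` lines touched above reflect the usual input-feeding setup: InputFeedRNNDecoder concatenates the previous attentional hidden state with the current target embedding, hence the extra `hidden_size`. A minimal shape check with plain PyTorch and illustrative sizes (not eole code):

```python
import torch

batch, tgt_word_vec_size, hidden_size = 2, 8, 16

emb_t = torch.randn(batch, tgt_word_vec_size)        # current target-word embedding
prev_attn_h = torch.randn(batch, hidden_size)        # previous attentional hidden state
rnn_input = torch.cat([emb_t, prev_attn_h], dim=-1)  # input-feed concatenation

# StdRNNDecoder:        _input_size = tgt_word_vec_size
# InputFeedRNNDecoder:  _input_size = tgt_word_vec_size + hidden_size
assert rnn_input.shape == (batch, tgt_word_vec_size + hidden_size)
```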