Skip to content

Commit 17cb159

Browse files
authored
Merge pull request #85 from rhymes-ai/fix_output
fix: disable tie_embedding
2 parents 485a6f6 + 5b61ee8 commit 17cb159

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

aria/model/configuration_aria.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,13 @@ def __init__(
6666
},
6767
ignore_index=-100,
6868
image_token_index=32000,
69+
tie_word_embeddings=False,
6970
**kwargs,
7071
):
7172
super().__init__(**kwargs)
7273
self.ignore_index = ignore_index
7374
self.image_token_index = image_token_index
74-
75+
self.tie_word_embeddings = tie_word_embeddings
7576
attn_implementation = kwargs.pop("attn_implementation", None)
7677

7778
# Set the default attention implementation to flash_attention_2 if not specified

aria/model/modeling_aria.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -164,11 +164,11 @@ def get_input_embeddings(self) -> nn.Module:
164164
def set_input_embeddings(self, value):
165165
"""Set the input embeddings for the language model."""
166166
self.language_model.set_input_embeddings(value)
167-
167+
168168
def get_output_embeddings(self):
169169
"""Retrieve the output embeddings from the language model."""
170170
return self.language_model.get_output_embeddings()
171-
171+
172172
def set_output_embeddings(self, value):
173173
"""Set the output embeddings for the language model."""
174174
self.language_model.set_output_embeddings(value)

0 commit comments

Comments
 (0)