Skip to content

Commit 485a6f6

Browse files
authored
Merge pull request #84 from rhymes-ai/output_embeddings
add set/get_output_embeddings
2 parents 7cbf499 + d7687b6 commit 485a6f6

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

aria/model/modeling_aria.py

+8
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,14 @@ def get_input_embeddings(self) -> nn.Module:
164164
def set_input_embeddings(self, value):
165165
"""Set the input embeddings for the language model."""
166166
self.language_model.set_input_embeddings(value)
167+
168+
def get_output_embeddings(self):
169+
"""Retrieve the output embeddings from the language model."""
170+
return self.language_model.get_output_embeddings()
171+
172+
def set_output_embeddings(self, value):
173+
"""Set the output embeddings for the language model."""
174+
self.language_model.set_output_embeddings(value)
167175

168176
def set_moe_z_loss_coeff(self, value):
169177
"""

0 commit comments

Comments
 (0)