Skip to content

Commit 38a739c

Browse files
authored
Replace torch.aten.outer with corresponding math (#1084)
Outer is missing lowerings in `torch-mlir` right now. To unblock we just explicitly do the unsqueeze and mul operators.
1 parent f35bb06 commit 38a739c

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

sharktank/sharktank/layers/rotary_embedding.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -276,15 +276,15 @@ def _compute_rotary_embed_table(self, t):
276276
freqs = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
277277

278278
freqs = torch.cat((freqs, freqs), dim=-1)
279-
emb = torch.outer(t.float(), freqs.float())
279+
emb = t.unsqueeze(1).float() * freqs.unsqueeze(0).float()
280280
cos = torch.cos(emb).to(self.dtype)
281281
sin = torch.sin(emb).to(self.dtype)
282282
return (cos, sin)
283283

284284
freqs = 1.0 / (
285285
self.rope_freq_base ** ((torch.arange(0, dim) // 2).float() / dim * 2.0)
286286
)
287-
freqs = torch.outer(t, freqs).float()
287+
freqs = (t.unsqueeze(1) * freqs.unsqueeze(0)).float()
288288
return freqs
289289

290290
def _create_rotary_embed_table(self):

0 commit comments

Comments
 (0)