diff --git a/src/kernl/debugger/tl_lang.py b/src/kernl/debugger/tl_lang.py index 3787b5fd..d7774d5d 100644 --- a/src/kernl/debugger/tl_lang.py +++ b/src/kernl/debugger/tl_lang.py @@ -403,13 +403,13 @@ def reshape(self, input, shape): raise NotImplementedError() @_tensor_operation - def dot(self, input, other, trans_a=False, trans_b=False, allow_tf32=True): + def dot(self, input, other): assert input.dtype == other.dtype - if trans_a: - input = input.T - if trans_b: - other = other.T return torch.matmul(input=input, other=other) + + @_tensor_operation + def trans(self, input): + return input.T @_tensor_operation def atomic_cas(self, pointer, cmp, val):