Skip to content

Commit 6cce5db

Browse files
committed
Make full_like(a) transpose the created tensor if needed
1 parent 1b03288 commit 6cce5db

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

thunder/clang/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,17 @@ def full_like(
265265
device = devices.to_device(device) if device is not None else a.device
266266
dtype = dtype if dtype is not None else a.true_dtype
267267

268-
return full(a.shape, fill_value, device=device, dtype=dtype)
268+
is_stride_decreasing = all(x > y for x, y in zip(a.stride(), a.stride()[1:]))
269+
if is_stride_decreasing:
270+
return full(a.shape, fill_value, device=device, dtype=dtype)
271+
272+
permutation = [None] * len(a.stride())
273+
permuted_shape = [None] * len(a.stride())
274+
for i, s in enumerate(sorted(a.stride(), reverse=True)):
275+
permutation[a.stride().index(s)] = i
276+
permuted_shape[i] = a.shape[a.stride().index(s)]
277+
278+
return transpose(full(permuted_shape, fill_value, device=device, dtype=dtype), permutation)
269279

270280

271281
@clangop()

0 commit comments

Comments
 (0)