Skip to content

Commit

Permalink
delete _reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Dec 17, 2024
1 parent 0c0d7e9 commit 97ca27a
Showing 1 changed file with 0 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -440,17 +440,6 @@ def _tensor_product_cuda(
return FusedTensorProductOp4(descriptor, perm[:3], device, math_dtype)


@torch.jit.script
def _reshape(x: torch.Tensor, leading_shape: List[int]) -> torch.Tensor:
# Make x have shape (Z, x.shape[-1]) or (x.shape[-1],)
if prod(leading_shape) > 1 and prod(x.shape[:-1]) == 1:
return x.reshape((x.shape[-1],))
else:
return x.expand(leading_shape + (x.shape[-1],)).reshape(
(prod(leading_shape), x.shape[-1])
)


class FusedTensorProductOp3(torch.nn.Module):
def __init__(
self,
Expand Down

0 comments on commit 97ca27a

Please sign in to comment.