diff --git a/src/flag_gems/runtime/backend/_mthreads/ops/arange.py b/src/flag_gems/runtime/backend/_mthreads/ops/arange.py index 48c1c58a9d..a7268b5014 100644 --- a/src/flag_gems/runtime/backend/_mthreads/ops/arange.py +++ b/src/flag_gems/runtime/backend/_mthreads/ops/arange.py @@ -13,7 +13,7 @@ from flag_gems.utils import triton_lang_extension as tle logger = logging.getLogger( - f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}' + f"flag_gems.runtime.backend._mthreads.ops.{__name__.split('.')[-1]}" ) device_ = runtime.device @@ -140,6 +140,19 @@ def arange_start( else: device = torch.device(device) + # Handle int64 dtype with float parameters - convert to int + if dtype is torch.int64: + if ( + isinstance(start, float) + or isinstance(end, float) + or isinstance(step, float) + ): + start = int(start) if isinstance(start, float) else start + end = int(end) if isinstance(end, float) else end + step = int(step) if isinstance(step, float) else step + if step == 0: + raise RuntimeError("step must be nonzero") + is_float_dtype = torch.is_floating_point(torch.tensor(0, dtype=dtype)) use_int64 = dtype == torch.int64 size = _compute_size(start, end, step, is_float_dtype) diff --git a/src/flag_gems/runtime/backend/_mthreads/ops/repeat.py b/src/flag_gems/runtime/backend/_mthreads/ops/repeat.py index af05f2c38e..cf349f6960 100644 --- a/src/flag_gems/runtime/backend/_mthreads/ops/repeat.py +++ b/src/flag_gems/runtime/backend/_mthreads/ops/repeat.py @@ -381,7 +381,7 @@ def repeat(inp: torch.Tensor, sizes) -> torch.Tensor: assert ( sizes_shape[i] >= 0 ), f"the number of repetitions per dimension out of range (expected to >= 0) but got {sizes_shape[i]}" - if sizes_shape[i] == 0: + if in0_shape[i] * sizes_shape[i] == 0: is_empty = True out_shape.append(in0_shape[i] * sizes_shape[i]) diff --git a/src/flag_gems/runtime/backend/_mthreads/ops/repeat_interleave.py b/src/flag_gems/runtime/backend/_mthreads/ops/repeat_interleave.py index 3cc65e0abb..690fdc0265 100644 --- a/src/flag_gems/runtime/backend/_mthreads/ops/repeat_interleave.py +++ b/src/flag_gems/runtime/backend/_mthreads/ops/repeat_interleave.py @@ -12,7 +12,7 @@ from flag_gems.utils.tensor_wrapper import StridedBuffer logger = logging.getLogger( - f'flag_gems.runtime.backend._mthreads.ops.{__name__.split(".")[-1]}' + f"flag_gems.runtime.backend._mthreads.ops.{__name__.split('.')[-1]}" ) @@ -437,6 +437,9 @@ def fused_repeat_interleave_dim0(inp, repeats, dim): def repeat_interleave_self_tensor(inp, repeats, dim=None, *, output_size=None): logger.debug("GEMS_MTHREADS REPEAT_INTERLEAVE_SELF_TENSOR") + if repeats.numel() == 0: + return inp.clone() + if dim is None: inp = inp.flatten() dim = 0 diff --git a/src/flag_gems/runtime/backend/_mthreads/tune_configs.yaml b/src/flag_gems/runtime/backend/_mthreads/tune_configs.yaml index 8226b4861c..da2194f575 100644 --- a/src/flag_gems/runtime/backend/_mthreads/tune_configs.yaml +++ b/src/flag_gems/runtime/backend/_mthreads/tune_configs.yaml @@ -301,6 +301,26 @@ cross_entropy_loss_sum_and_scale: - 256 - 1024 +conj_physical: + - META: + BLOCK_SIZE: 64 + num_warps: 8 + - META: + BLOCK_SIZE: 128 + num_warps: 8 + - META: + BLOCK_SIZE: 256 + num_warps: 8 + - META: + BLOCK_SIZE: 512 + num_warps: 8 + - META: + BLOCK_SIZE: 1024 + num_warps: 8 + - META: + BLOCK_SIZE: 2048 + num_warps: 8 + upsample_nearest2d: - gen: true param_map: