diff --git a/.circleci/config.yml b/.circleci/config.yml index 3b57902c..68e22ac5 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -26,8 +26,7 @@ jobs: command: | pip install pre-commit brew install swift-format - pre-commit run --all - if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi + pre-commit run --all || (echo "Style checks failed, please install pre-commit and run pre-commit run --all and push the change"; echo ""; git --no-pager diff; exit 1) - run: name: Run Tests (Xcode, macOS) command: | diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 63dd1998..70b05a86 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,14 @@ repos: -- repo: https://github.com/slessans/pre-commit-swift-format - rev: "" + +- repo: local hooks: - id: swift-format - args: ["--configuration", ".swift-format"] + name: swift-format + language: system + entry: swift-format format --in-place --configuration .swift-format --recursive . + require_serial: true + types: [swift] + - repo: https://github.com/cheshirekow/cmake-format-precommit rev: v0.6.10 hooks: diff --git a/Package.swift b/Package.swift index 2ad79e1a..49c39b2b 100644 --- a/Package.swift +++ b/Package.swift @@ -74,8 +74,38 @@ let package = Package( "mlx/tests", // opt-out of these backends (using metal) - "mlx/mlx/backend/no_metal", + "mlx/mlx/backend/no_gpu", "mlx/mlx/backend/no_cpu", + "mlx/mlx/backend/metal/no_metal.cpp", + + // special handling for cuda -- we need to keep one file: + // mlx/mlx/backend/cuda/no_cuda.cpp + + "mlx/mlx/backend/cuda/allocator.cpp", + "mlx/mlx/backend/cuda/compiled.cpp", + "mlx/mlx/backend/cuda/conv.cpp", + "mlx/mlx/backend/cuda/cuda.cpp", + "mlx/mlx/backend/cuda/cudnn_utils.cpp", + "mlx/mlx/backend/cuda/custom_kernel.cpp", + "mlx/mlx/backend/cuda/device.cpp", + "mlx/mlx/backend/cuda/eval.cpp", + "mlx/mlx/backend/cuda/fence.cpp", + "mlx/mlx/backend/cuda/indexing.cpp", + "mlx/mlx/backend/cuda/jit_module.cpp", + "mlx/mlx/backend/cuda/matmul.cpp", + "mlx/mlx/backend/cuda/primitives.cpp", + "mlx/mlx/backend/cuda/slicing.cpp", + "mlx/mlx/backend/cuda/utils.cpp", + "mlx/mlx/backend/cuda/worker.cpp", + "mlx/mlx/backend/cuda/unary", + "mlx/mlx/backend/cuda/gemms", + "mlx/mlx/backend/cuda/steel", + "mlx/mlx/backend/cuda/reduce", + "mlx/mlx/backend/cuda/quantized", + "mlx/mlx/backend/cuda/conv", + "mlx/mlx/backend/cuda/copy", + "mlx/mlx/backend/cuda/device", + "mlx/mlx/backend/cuda/binary", // build variants (we are opting _out_ of these) "mlx/mlx/io/no_safetensors.cpp", @@ -89,6 +119,8 @@ let package = Package( // do not build distributed support (yet) "mlx/mlx/distributed/mpi/mpi.cpp", "mlx/mlx/distributed/ring/ring.cpp", + "mlx/mlx/distributed/nccl/nccl.cpp", + "mlx/mlx/distributed/nccl/nccl_stub", // bnns instead of simd (accelerate) "mlx/mlx/backend/cpu/gemms/simd_fp16.cpp", diff --git a/Source/Cmlx/include/mlx/c/fast.h b/Source/Cmlx/include/mlx/c/fast.h index 048ff6bb..02dd51dc 100644 --- a/Source/Cmlx/include/mlx/c/fast.h +++ b/Source/Cmlx/include/mlx/c/fast.h @@ -27,22 +27,68 @@ extern "C" { * \defgroup fast Fast custom operations */ /**@{*/ -int mlx_fast_affine_dequantize( - mlx_array* res, - const mlx_array w, - const mlx_array scales, - const mlx_array biases, - int group_size, - int bits, - const mlx_stream s); -int mlx_fast_affine_quantize( - mlx_array* res_0, - mlx_array* res_1, - mlx_array* res_2, - const mlx_array w, - int group_size, - int bits, - const mlx_stream 
s); + +typedef struct mlx_fast_cuda_kernel_config_ { + void* ctx; +} mlx_fast_cuda_kernel_config; +mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(); +void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls); + +int mlx_fast_cuda_kernel_config_add_output_arg( + mlx_fast_cuda_kernel_config cls, + const int* shape, + size_t size, + mlx_dtype dtype); +int mlx_fast_cuda_kernel_config_set_grid( + mlx_fast_cuda_kernel_config cls, + int grid1, + int grid2, + int grid3); +int mlx_fast_cuda_kernel_config_set_thread_group( + mlx_fast_cuda_kernel_config cls, + int thread1, + int thread2, + int thread3); +int mlx_fast_cuda_kernel_config_set_init_value( + mlx_fast_cuda_kernel_config cls, + float value); +int mlx_fast_cuda_kernel_config_set_verbose( + mlx_fast_cuda_kernel_config cls, + bool verbose); +int mlx_fast_cuda_kernel_config_add_template_arg_dtype( + mlx_fast_cuda_kernel_config cls, + const char* name, + mlx_dtype dtype); +int mlx_fast_cuda_kernel_config_add_template_arg_int( + mlx_fast_cuda_kernel_config cls, + const char* name, + int value); +int mlx_fast_cuda_kernel_config_add_template_arg_bool( + mlx_fast_cuda_kernel_config cls, + const char* name, + bool value); + +typedef struct mlx_fast_cuda_kernel_ { + void* ctx; +} mlx_fast_cuda_kernel; + +mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new( + const char* name, + const mlx_vector_string input_names, + const mlx_vector_string output_names, + const char* source, + const char* header, + bool ensure_row_contiguous, + bool atomic_outputs); + +void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls); +int mlx_fast_cuda_kernel_apply( + mlx_vector_array* outputs, + mlx_fast_cuda_kernel cls, + const mlx_vector_array inputs, + const mlx_fast_cuda_kernel_config config, + const mlx_stream stream); + int mlx_fast_layer_norm( mlx_array* res, const mlx_array x, @@ -103,6 +149,7 @@ mlx_fast_metal_kernel mlx_fast_metal_kernel_new( const char* header, bool ensure_row_contiguous, bool atomic_outputs); + void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls); int mlx_fast_metal_kernel_apply( mlx_vector_array* outputs, @@ -135,6 +182,7 @@ int mlx_fast_scaled_dot_product_attention( float scale, const char* mask_mode, const mlx_vector_array mask_arrs, + const mlx_array sinks /* may be null */, const mlx_stream s); /**@}*/ diff --git a/Source/Cmlx/include/mlx/c/fft.h b/Source/Cmlx/include/mlx/c/fft.h index 55f218a7..b7ef5e03 100644 --- a/Source/Cmlx/include/mlx/c/fft.h +++ b/Source/Cmlx/include/mlx/c/fft.h @@ -49,6 +49,12 @@ int mlx_fft_fftn( const int* axes, size_t axes_num, const mlx_stream s); +int mlx_fft_fftshift( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); int mlx_fft_ifft( mlx_array* res, const mlx_array a, @@ -71,6 +77,12 @@ int mlx_fft_ifftn( const int* axes, size_t axes_num, const mlx_stream s); +int mlx_fft_ifftshift( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); int mlx_fft_irfft( mlx_array* res, const mlx_array a, diff --git a/Source/Cmlx/include/mlx/c/linalg.h b/Source/Cmlx/include/mlx/c/linalg.h index 9142ca57..ac0b3237 100644 --- a/Source/Cmlx/include/mlx/c/linalg.h +++ b/Source/Cmlx/include/mlx/c/linalg.h @@ -43,12 +43,18 @@ int mlx_linalg_cross( const mlx_array b, int axis, const mlx_stream s); +int mlx_linalg_eig( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s); int mlx_linalg_eigh( mlx_array* res_0, mlx_array* res_1, const mlx_array a, const char* UPLO, const mlx_stream s); 
+int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_linalg_eigvalsh( mlx_array* res, const mlx_array a, diff --git a/Source/Cmlx/include/mlx/c/ops.h b/Source/Cmlx/include/mlx/c/ops.h index 4f470823..9835f92a 100644 --- a/Source/Cmlx/include/mlx/c/ops.h +++ b/Source/Cmlx/include/mlx/c/ops.h @@ -357,9 +357,10 @@ int mlx_dequantize( mlx_array* res, const mlx_array w, const mlx_array scales, - const mlx_array biases, + const mlx_array biases /* may be null */, int group_size, int bits, + const char* mode, const mlx_stream s); int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s); int mlx_diagonal( @@ -452,12 +453,13 @@ int mlx_gather_qmm( const mlx_array x, const mlx_array w, const mlx_array scales, - const mlx_array biases, + const mlx_array biases /* may be null */, const mlx_array lhs_indices /* may be null */, const mlx_array rhs_indices /* may be null */, bool transpose, int group_size, int bits, + const char* mode, bool sorted_indices, const mlx_stream s); int mlx_greater( @@ -747,22 +749,22 @@ int mlx_put_along_axis( int axis, const mlx_stream s); int mlx_quantize( - mlx_array* res_0, - mlx_array* res_1, - mlx_array* res_2, + mlx_vector_array* res, const mlx_array w, int group_size, int bits, + const char* mode, const mlx_stream s); int mlx_quantized_matmul( mlx_array* res, const mlx_array x, const mlx_array w, const mlx_array scales, - const mlx_array biases, + const mlx_array biases /* may be null */, bool transpose, int group_size, int bits, + const char* mode, const mlx_stream s); int mlx_radians(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_real(mlx_array* res, const mlx_array a, const mlx_stream s); @@ -868,6 +870,12 @@ int mlx_scatter_prod( const int* axes, size_t axes_num, const mlx_stream s); +int mlx_segmented_mm( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_array segments, + const mlx_stream s); int mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s); int mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s); diff --git a/Source/Cmlx/include/mlx/c/random.h b/Source/Cmlx/include/mlx/c/random.h index 04a735aa..5e9d216a 100644 --- a/Source/Cmlx/include/mlx/c/random.h +++ b/Source/Cmlx/include/mlx/c/random.h @@ -88,6 +88,15 @@ int mlx_random_multivariate_normal( mlx_dtype dtype, const mlx_array key /* may be null */, const mlx_stream s); +int mlx_random_normal_broadcast( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array loc /* may be null */, + const mlx_array scale /* may be null */, + const mlx_array key /* may be null */, + const mlx_stream s); int mlx_random_normal( mlx_array* res, const int* shape, diff --git a/Source/Cmlx/mlx b/Source/Cmlx/mlx index eaf709b8..ee18e1cb 160000 --- a/Source/Cmlx/mlx +++ b/Source/Cmlx/mlx @@ -1 +1 @@ -Subproject commit eaf709b83e559079e212699bfc9dd2f939d25c9a +Subproject commit ee18e1cbf0ab7937578d716cc1b62b3fb1725e27 diff --git a/Source/Cmlx/mlx-c b/Source/Cmlx/mlx-c index 9ebe1558..9e43e355 160000 --- a/Source/Cmlx/mlx-c +++ b/Source/Cmlx/mlx-c @@ -1 +1 @@ -Subproject commit 9ebe155864eab06d94ba18e01f9cb2666b2975a7 +Subproject commit 9e43e355ea8b7bdfbd7f8f82978b3cbbb42366f8 diff --git a/Source/Cmlx/mlx-generated/binary.cpp b/Source/Cmlx/mlx-generated/binary.cpp index fa430492..43b99fcc 100644 --- a/Source/Cmlx/mlx-generated/binary.cpp +++ b/Source/Cmlx/mlx-generated/binary.cpp @@ -10,59 +10,116 @@ template 
uint index [[thread_position_in_grid]]) { c[index] = Op()(a[0], b[0]); } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[0], b[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[0]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[0], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } } -template +template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } } template [[kernel]] void binary_g_nd1( diff --git a/Source/Cmlx/mlx-generated/binary_ops.cpp b/Source/Cmlx/mlx-generated/binary_ops.cpp index d23ce5f3..6754554f 100644 --- a/Source/Cmlx/mlx-generated/binary_ops.cpp +++ b/Source/Cmlx/mlx-generated/binary_ops.cpp @@ -199,6 +199,9 @@ struct Power { template 
metal::enable_if_t, T> operator()(T base, T exp) { T res = 1; + if (exp < 0) { + return 0; + } while (exp) { if (exp & 1) { res *= base; @@ -210,6 +213,13 @@ struct Power { } template <> complex64_t operator()(complex64_t x, complex64_t y) { + if (x.real == 0 && x.imag == 0) { + if (metal::isnan(y.real) || metal::isnan(y.imag)) { + auto nan = metal::numeric_limits::quiet_NaN(); + return {nan, nan}; + } + return {0.0, 0.0}; + } auto x_theta = metal::atan2(x.imag, x.real); auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); diff --git a/Source/Cmlx/mlx-generated/binary_two.cpp b/Source/Cmlx/mlx-generated/binary_two.cpp index 07a8138f..57778976 100644 --- a/Source/Cmlx/mlx-generated/binary_two.cpp +++ b/Source/Cmlx/mlx-generated/binary_two.cpp @@ -13,77 +13,146 @@ template c[index] = out[0]; d[index] = out[1]; } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[0], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[0]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[0], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& 
size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[0]); - c[offset] = out[0]; - d[offset] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } } template [[kernel]] void binary_g_nd1( diff --git a/Source/Cmlx/mlx-generated/conv.cpp b/Source/Cmlx/mlx-generated/conv.cpp index 3e1f1d60..732b2ecd 100644 --- a/Source/Cmlx/mlx-generated/conv.cpp +++ b/Source/Cmlx/mlx-generated/conv.cpp @@ -353,6 +353,7 @@ struct Conv2DWeightBlockLoader { const device T* src; const constant MLXConvParams<2>* params; int weight_hw; + int weight_step; const int read_n; const bool do_read; METAL_FUNC Conv2DWeightBlockLoader( @@ -371,6 +372,7 @@ struct Conv2DWeightBlockLoader { src(src_ + bi * src_ld + bj), params(params_), weight_hw(0), + weight_step(params->C / params->groups), read_n(offsets.y + bi), do_read(read_n + n_rows * TROWS <= gemm_params_->N) {} METAL_FUNC void load_unsafe() const { @@ -400,11 +402,11 @@ struct Conv2DWeightBlockLoader { } METAL_FUNC void next() { if (++weight_hw < (params->wS[1] * params->wS[0])) { - src += params->wt_strides[2]; + src += weight_step; return; } weight_hw = 0; - src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2]; + src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step; } }; } @@ -463,7 +465,7 @@ struct Conv2DInputBlockLoaderSmallChannels { threadgroup T* dst; const constant MLXConvParams<2>* params; const constant ImplicitGemmConv2DParams* gemm_params; - short weight_hw; + int weight_hw; const device T* src[n_rows]; int read_n[n_rows]; int read_ih[n_rows]; @@ -604,7 +606,7 @@ struct Conv2DWeightBlockLoaderSmallChannels { } return; } - const device T* curr_src = src + weight_hw * params->wt_strides[2]; + const device T* curr_src = src + weight_hw * (params->C / params->groups); if (BN != 8 || do_read) { #pragma clang loop unroll(full) for (short i = 0; i < BROWS; i += TROWS) { diff --git a/Source/Cmlx/mlx-generated/copy.cpp b/Source/Cmlx/mlx-generated/copy.cpp index 9ac729f1..260f6789 100644 --- a/Source/Cmlx/mlx-generated/copy.cpp +++ b/Source/Cmlx/mlx-generated/copy.cpp @@ -2,37 +2,75 @@ namespace mlx::core::metal { const char* copy() { return R"preamble( -template +template ::n> [[kernel]] void copy_s( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - 
dst[index] = static_cast(src[0]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[0]); + } + } } -template +template ::n> [[kernel]] void copy_v( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } } -template +template ::n> [[kernel]] void copy_s2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } } -template +template ::n> [[kernel]] void copy_v2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } } template [[kernel]] void copy_g_nd1( diff --git a/Source/Cmlx/mlx-generated/fft.cpp b/Source/Cmlx/mlx-generated/fft.cpp index aaac34cb..065a4f41 100644 --- a/Source/Cmlx/mlx-generated/fft.cpp +++ b/Source/Cmlx/mlx-generated/fft.cpp @@ -314,7 +314,7 @@ struct ReadWriter { return grid_index >= batch_size; } METAL_FUNC void load() const { - int batch_idx = elem.x * grid.y * n; + size_t batch_idx = size_t(elem.x * grid.y) * n; short tg_idx = elem.y * grid.z + elem.z; short max_index = grid.y * n - 2; constexpr int read_width = 2; @@ -333,7 +333,7 @@ struct ReadWriter { } } METAL_FUNC void write() const { - int batch_idx = elem.x * grid.y * n; + size_t batch_idx = size_t(elem.x * grid.y) * n; short tg_idx = elem.y * grid.z + elem.z; short max_index = grid.y * n - 2; constexpr int read_width = 2; @@ -352,7 +352,7 @@ struct ReadWriter { } } METAL_FUNC void load_padded(int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length + elem.y * length; + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; int fft_idx = elem.z; int m = grid.z; threadgroup float2* seq_buf = buf + elem.y * n; @@ -367,7 +367,7 @@ struct ReadWriter { } } METAL_FUNC void write_padded(int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length + elem.y * length; + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; int fft_idx = elem.z; int m = grid.z; float2 inv_factor = {1.0f / n, -1.0f / n}; @@ -437,7 +437,7 @@ METAL_FUNC bool ReadWriter::out_of_bounds() const { } template <> METAL_FUNC void ReadWriter::load() const { - 
int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2; + size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2; threadgroup float2* seq_buf = buf + elem.y * n; int grid_index = elem.x * grid.y + elem.y; short next_in = @@ -453,7 +453,8 @@ METAL_FUNC void ReadWriter::load() const { template <> METAL_FUNC void ReadWriter::write() const { short n_over_2 = (n / 2) + 1; - int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; int grid_index = elem.x * grid.y + elem.y; short next_out = @@ -480,7 +481,7 @@ template <> METAL_FUNC void ReadWriter::load_padded( int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; threadgroup float2* seq_buf = buf + elem.y * n; int grid_index = elem.x * grid.y + elem.y; short next_in = @@ -503,8 +504,8 @@ METAL_FUNC void ReadWriter::write_padded( int length, const device float2* w_k) const { int length_over_2 = (length / 2) + 1; - int batch_idx = - elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n + length - 1; int grid_index = elem.x * grid.y + elem.y; short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 @@ -540,7 +541,8 @@ METAL_FUNC bool ReadWriter::out_of_bounds() const { template <> METAL_FUNC void ReadWriter::load() const { short n_over_2 = (n / 2) + 1; - int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; int grid_index = elem.x * grid.y + elem.y; short next_in = @@ -588,8 +590,8 @@ METAL_FUNC void ReadWriter::load_padded( const device float2* w_k) const { int n_over_2 = (n / 2) + 1; int length_over_2 = (length / 2) + 1; - int batch_idx = - elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; int grid_index = elem.x * grid.y + elem.y; short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 @@ -627,7 +629,7 @@ template <> METAL_FUNC void ReadWriter::write_padded( int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; threadgroup float2* seq_buf = buf + elem.y * n + length - 1; int grid_index = elem.x * grid.y + elem.y; short next_out = diff --git a/Source/Cmlx/mlx-generated/fp4_quantized.cpp b/Source/Cmlx/mlx-generated/fp4_quantized.cpp new file mode 100644 index 00000000..56beca3b --- /dev/null +++ b/Source/Cmlx/mlx-generated/fp4_quantized.cpp @@ -0,0 +1,1607 @@ +namespace mlx::core::metal { + +const char* fp4_quantized() { + return R"preamble( +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; +using namespace metal; +static constant constexpr const int SIMD_SIZE = 32; +static constant constexpr const int QUAD_SIZE = 4; +template +inline constexpr short get_pack_factor() { + return wsize / 4; +} +template +inline 
constexpr short get_bytes_per_pack() { + return wsize / 8; +} +template +static inline T dequantize_scale(uint8_t s) { + using FOrI = union { + bfloat16_t f; + uint16_t i; + }; + FOrI out; + out.i = (s == 0 ? 0x40 : (static_cast(s) << 7)); + return static_cast(out.f); +} +template +inline void load_vector(const device T* x, thread U* x_thread) { + for (int i = 0; i < values_per_thread; i += 4) { + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; + } +} +template +inline void load_vector_safe(const device T* x, thread U* x_thread, int N) { + for (int i = 0; i < N; i += 4) { + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; + } + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } +} +constexpr constant static float MXFP4_LUT[16] = { + +0.0f, + +0.5f, + +1.0f, + +1.5f, + +2.0f, + +3.0f, + +4.0f, + +6.0f, + -0.0f, + -0.5f, + -1.0f, + -1.5f, + -2.0f, + -3.0f, + -4.0f, + -6.0f}; +template +void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) { + if (simd_gid == 0 && simd_lid < 16) { + lut[simd_lid] = static_cast(MXFP4_LUT[simd_lid]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); +} +template +inline U qdot( + const device uint8_t* w, + const thread U* x_thread, + U scale, + const threadgroup U* lut) { + U accum = 0; + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * lut[ws[i] & 0xf] + + x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] + + x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] + + x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]); + } + return scale * accum; +} +template +inline U qdot_safe( + const device uint8_t* w, + const thread U* x_thread, + U scale, + const threadgroup U* lut, + int N) { + U accum = 0; + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * lut[ws[i] & 0xf] + + x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] + + x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] + + x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]); + } + return scale * accum; +} +template +inline void qouter( + const thread uint8_t* w, + U x, + U scale, + thread U* result, + const threadgroup U* lut) { + for (int i = 0; i < (values_per_thread / 2); i++) { + result[2 * i] += x * scale * lut[w[i] & 0xf]; + result[2 * i + 1] += x * scale * lut[(w[i] >> 4) & 0xf]; + } +} +template +inline void dequantize( + const device uint8_t* w, + U scale, + threadgroup U* w_local, + const threadgroup U* lut) { + for (int i = 0; i < (N / 2); i++) { + w_local[2 * i] = scale * lut[w[i] & 0xf]; + w_local[2 * i + 1] = scale * lut[(w[i] >> 4) & 0xf]; + } +} +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size, + typename S> +struct QuantizedBlockLoader { + static_assert( + BCOLS <= group_size, + "The group size should be larger than the columns"); + static_assert( + group_size % BCOLS == 0, + "The group size should be divisible by the columns"); + static constant constexpr const short pack_factor = get_pack_factor<8>(); + static constant constexpr const short bytes_per_pack = get_bytes_per_pack(); + static constant constexpr const short BCOLS_PACKED = BCOLS / pack_factor; + static constant constexpr const short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 
1 : (BCOLS_PACKED * BROWS) / tgp_size; + static constant constexpr const short group_steps = group_size / BCOLS; + const int src_ld; + const int tile_stride; + short group_step_cnt; + const int group_stride; + const short thread_idx; + const short bi; + const short bj; + threadgroup T* dst; + const device uint8_t* src; + const device S* scales; + threadgroup T* lut; + QuantizedBlockLoader( + const device uint8_t* src_, + const device S* scales_, + const int src_ld_, + threadgroup T* dst_, + threadgroup T* lut_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_step_cnt(0), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size), + lut(lut_) { + load_mxfp4_lut(lut, simd_group_id, simd_lane_id); + } + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + T scale = dequantize_scale(*scales); + for (int i = 0; i < n_reads; i++) { + dequantize( + src + i * bytes_per_pack, scale, dst + i * pack_factor, lut); + } + } + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + T scale = dequantize_scale(*scales); + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i * bytes_per_pack), + scale, + dst + i * pack_factor, + lut); + } + } + void next() { + src += tile_stride; + if (reduction_dim == 1) { + if (group_steps > 1) { + group_step_cnt++; + if (group_step_cnt == group_steps) { + group_step_cnt = 0; + scales++; + } + } else { + scales++; + } + } else { + scales += group_stride; + } + } +}; +template +METAL_FUNC void mxfp4_qmv_quad_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup float* lut) { + constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE; + constexpr int pack_factor = 8; + constexpr int values_per_thread = D / QUAD_SIZE; + constexpr int packs_per_thread = values_per_thread / pack_factor; + constexpr int scale_step_per_thread = group_size / values_per_thread; + constexpr int results_per_quadgroup = 8; + typedef float U; + thread U x_thread[values_per_thread]; + thread U result[results_per_quadgroup] = {0}; + load_mxfp4_lut(lut, simd_gid, simd_lid); + const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid; + w += out_row * in_vec_size_w + quad_lid * packs_per_thread; + scales += out_row * 
in_vec_size_g + quad_lid / scale_step_per_thread; + x += tid.x * in_vec_size + quad_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + load_vector(x, x_thread); + for (int row = 0; row < results_per_quadgroup; row++) { + auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); + const device S* sl = scales + row * in_vec_size_g * quads_per_simd; + U s = dequantize_scale(sl[0]); + if (row * quads_per_simd + out_row < out_vec_size) { + result[row] += qdot(wl, x_thread, s, lut); + } + } + for (int row = 0; row < results_per_quadgroup; row++) { + result[row] = quad_sum(result[row]); + if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) { + y[row * quads_per_simd] = static_cast(result[row]); + } + } +} +template +METAL_FUNC void mxfp4_qmv_fast_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup float* lut) { + constexpr int packs_per_thread = 2; + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int pack_factor = get_pack_factor<32>(); + constexpr int bytes_per_pack = get_bytes_per_pack<32>(); + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + const device uint8_t* ws = (const device uint8_t*)w; + typedef float U; + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + load_mxfp4_lut(lut, simd_gid, simd_lid); + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + for (int k = 0; k < in_vec_size; k += block_size) { + load_vector(x, x_thread); + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + U s = dequantize_scale(sl[0]); + result[row] += qdot(wl, x_thread, s, lut); + } + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } +} +template +METAL_FUNC void mxfp4_qmv_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup float* lut) { + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int packs_per_thread = 1; + constexpr int pack_factor = get_pack_factor<32>(); + constexpr int bytes_per_pack = get_bytes_per_pack<32>(); + constexpr int values_per_thread = pack_factor * packs_per_thread; + 
constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + const device uint8_t* ws = (const device uint8_t*)w; + typedef float U; + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + load_mxfp4_lut(lut, simd_gid, simd_lid); + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); + if (out_row >= out_vec_size) { + return; + } + if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { + ws += + out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + load_vector(x, x_thread); + for (int row = 0; out_row + row < out_vec_size; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + S s = sl[0]; + result[row] += qdot(wl, x_thread, s, lut); + } + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + load_vector_safe(x, x_thread, remaining); + for (int row = 0; out_row + row < out_vec_size; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + U s = dequantize_scale(sl[0]); + result[row] += qdot(wl, x_thread, s, lut); + } + } + for (int row = 0; out_row + row < out_vec_size; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } + else { + ws += used_out_row * in_vec_size_w + + simd_lid * packs_per_thread * bytes_per_pack; + scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + used_out_row; + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + load_vector(x, x_thread); + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + U s = dequantize_scale(sl[0]); + result[row] += qdot(wl, x_thread, s, lut); + } + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + load_vector_safe(x, x_thread, remaining); + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + U s = dequantize_scale(sl[0]); + result[row] += + qdot_safe(wl, x_thread, s, lut, remaining); + } + } + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } +} +template +METAL_FUNC void mxfp4_qvm_impl( + const device uint32_t* w, + const 
device S* scales, + const device T* x, + device T* y, + const int in_vec_size, + const int out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup float* lut) { + constexpr int num_simdgroups = 2; + constexpr int pack_factor = get_pack_factor<32>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int tn = 32 / pack_factor; + constexpr int block_size = SIMD_SIZE; + using W_T = uint32_t; + const device W_T* ws = (const device W_T*)w; + typedef float U; + typedef struct { + W_T wi[tn * bytes_per_pack]; + } vec_w; + thread vec_w w_local; + thread U result[tn * pack_factor] = {0}; + thread U scale = 0; + thread U x_local = 0; + load_mxfp4_lut(lut, simd_gid, simd_lid); + const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; + const int out_vec_size_g = out_vec_size / group_size; + int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid); + ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; + scales += out_col / group_size + simd_lid * out_vec_size_g; + x += tid.x * in_vec_size + simd_lid; + y += tid.x * out_vec_size + out_col; + if (out_col >= out_vec_size) { + return; + } + int remaining = in_vec_size % block_size; + if (remaining == 0) { + for (int i = 0; i < in_vec_size; i += block_size) { + x_local = *x; + scale = dequantize_scale(*scales); + w_local = *((device vec_w*)ws); + qouter( + (thread uint8_t*)&w_local, x_local, scale, result, lut); + x += block_size; + scales += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; + } + } else { + for (int i = block_size; i < in_vec_size; i += block_size) { + x_local = *x; + scale = dequantize_scale(*scales); + w_local = *((device vec_w*)ws); + qouter( + (thread uint8_t*)&w_local, x_local, scale, result, lut); + x += block_size; + scales += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; + } + if (static_cast(simd_lid) < remaining) { + x_local = *x; + scale = dequantize_scale(*scales); + w_local = *((device vec_w*)ws); + } else { + x_local = 0; + scale = 0; + } + qouter( + (thread uint8_t*)&w_local, x_local, scale, result, lut); + } +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + result[k] = simd_sum(result[k]); + } + if (simd_lid == 0) { +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + y[k] = static_cast(result[k]); + } + } +} +template < + typename T, + const int group_size, + typename S, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void mxfp4_qmm_t_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup T* lut) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + (void)lid; + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = + 
mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size, + S>; + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + auto wl = (const device uint8_t*)w; + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + y += y_row * static_cast(N) + y_col; + const short num_els = min(BM, M - y_row); + const short num_outs = min(BN, N - y_col); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, K, Ws, lut, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + if (num_els < BM) { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM || num_outs < BN) { + mma_op.store_result_safe(y, N, short2(num_outs, num_els)); + } else { + mma_op.store_result(y, N); + } +} +template < + typename T, + const int group_size, + typename S, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void mxfp4_qmm_n_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup T* lut) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + (void)lid; + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = mlx::steel:: + BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size, + S>; + auto wl = (const device uint8_t*)w; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + x += y_row * static_cast(K); + wl += y_col * bytes_per_pack / pack_factor; + scales += y_col / group_size; + y += 
y_row * static_cast(N) + y_col; + const short num_els = min(BM, M - y_row); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, N, Ws, lut, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + if (num_els < BM) { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, num_els)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, BM)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM) { + mma_op.store_result_safe(y, N, short2(BN, num_els)); + } else { + mma_op.store_result(y, N); + } +} +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device S*& scales, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device S*& scales, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const 
constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} +template +[[kernel]] void mxfp4_qmv_quad( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + threadgroup float lut[16]; + mxfp4_qmv_quad_impl( + w, + scales, + x, + y, + in_vec_size, + out_vec_size, + tid, + quad_gid, + quad_lid, + simd_gid, + simd_lid, + lut); +} +template +[[kernel]] void mxfp4_qmv_fast( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + threadgroup float lut[16]; + mxfp4_qmv_fast_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} +template +[[kernel]] void mxfp4_qmv( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid 
[[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + threadgroup float lut[16]; + mxfp4_qmv_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} +template +[[kernel]] void mxfp4_qvm( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + threadgroup float lut[16]; + mxfp4_qvm_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} +template +[[kernel]] void mxfp4_qvm_split_k( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& final_block_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + int in_vec_size_adj = + tid.z % split_k == split_k - 1 ? 
final_block_size : in_vec_size; + threadgroup float lut[16]; + mxfp4_qvm_impl( + w, + scales, + x, + y, + in_vec_size_adj, + out_vec_size, + tid, + simd_gid, + simd_lid, + lut); +} +template < + typename T, + const int group_size, + typename S, + const bool aligned_N, + const bool batched, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void mxfp4_qmm_t( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + threadgroup T lut[16]; + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + mxfp4_qmm_t_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} +template < + typename T, + const int group_size, + typename S, + const bool batched, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void mxfp4_qmm_n( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + threadgroup T lut[16]; + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + mxfp4_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} +template +[[kernel]] void mxfp4_gather_qmv_fast( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = 
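For context, mxfp4_qvm_split_k above lets the last split-K partition reduce over a shorter block (final_block_size) than the others. The host-side launch logic is not part of this diff; the sizing below is only an illustrative assumption, written as plain C++, of how the two block lengths could relate.

    #include <cstdio>

    int main() {
      // Assumed sizing: K split into `split_k` partitions of `in_vec_size`
      // values, with the last partition picking up the remainder.
      int K = 1000, split_k = 3;
      int in_vec_size = K / split_k;                           // 333 per full partition
      int final_block_size = K - (split_k - 1) * in_vec_size;  // 334 for the last one
      for (int z = 0; z < split_k; ++z) {
        int in_vec_size_adj =
            (z % split_k == split_k - 1) ? final_block_size : in_vec_size;
        std::printf("partition %d reduces over %d values\n", z, in_vec_size_adj);
      }
      return 0;
    }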
x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + threadgroup float lut[16]; + mxfp4_qmv_fast_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} +template +[[kernel]] void mxfp4_gather_qmv( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + threadgroup float lut[16]; + mxfp4_qmv_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} +template +[[kernel]] void mxfp4_gather_qvm( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + threadgroup float lut[16]; + mxfp4_qvm_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} +template < + typename T, + const int group_size, + typename S, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void mxfp4_gather_qmm_t( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, 
+ const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + threadgroup T lut[16]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + mxfp4_qmm_t_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} +template < + typename T, + const int group_size, + typename S, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void mxfp4_gather_qmm_n( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + threadgroup T lut[16]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + mxfp4_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} +template < + typename T, + int group_size, + typename S, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose> +[[kernel]] void mxfp4_gather_qmm_rhs( + const device T* x, + const device uint32_t* w, + const device S* scales, + const device uint32_t* indices, + device T* y, + const constant int& M, + const constant int& N, + const constant int& K, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T lut[16]; + using mma_t = mlx::steel::BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + false, + transpose, + BK_padded, + transpose ? BK_padded : BN_padded>; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? 
BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size, + S>; + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + const int k_remain = K - K_it * BK; + const short2 tile_x = short2(k_remain, tgp_bm); + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? y_col_long * K_g : y_col / group_size; + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + thread mma_t mma_op(simd_group_id, simd_lane_id); + thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id); + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + transpose ? 
K : N, + Ws, + lut, + simd_group_id, + simd_lane_id); + if (align_M && align_N) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } else { + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } + else if (align_N || tgp_bn == BN) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + else if (align_M || tgp_bm == BM) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + else { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + } + } +} +)preamble"; +} + +} // namespace mlx::core::metal diff --git a/Source/Cmlx/mlx-generated/gather_front.cpp b/Source/Cmlx/mlx-generated/gather_front.cpp new file mode 100644 index 00000000..66e60bc5 --- /dev/null +++ b/Source/Cmlx/mlx-generated/gather_front.cpp @@ -0,0 +1,42 @@ +namespace mlx::core::metal { + +const char* gather_front() { + return R"preamble( +template +struct Indices { + const array buffers; + const constant int* shapes; + const constant int64_t* strides; + const constant bool* row_contiguous; + const int ndim; +}; +template +METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) { + if (is_unsigned_v) { + return idx; + } else { + return (idx < 0) ? 
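The while-loop in mxfp4_gather_qmm_rhs above walks the BM rows of a tile in runs of rows that share the same expert index, issuing one accumulation and one (possibly sliced) store per run. A scalar C++ sketch of just that run detection, with illustrative names:

    #include <cstdio>
    #include <vector>

    int main() {
      // Expert index for each row of a BM-row tile; consecutive equal indices
      // are handled together.
      std::vector<int> indices = {3, 3, 3, 1, 1, 7};
      int tgp_bm = static_cast<int>(indices.size());
      int n = 0;
      while (n < tgp_bm) {
        int start = n;
        int index = indices[n];
        while (n < tgp_bm && indices[n] == index) n++;
        // The kernel would run its GEMM loop here for rows [start, n) against
        // the weights/scales of expert `index`, slicing the store when the run
        // does not cover the full tile.
        std::printf("rows [%d, %d) -> expert %d\n", start, n, index);
      }
      return 0;
    }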
idx + size : idx; + } +} + +template +[[kernel]] void gather_front( + const device T* src, + const device IdxT* indices, + device T* out, + const constant int64_t& stride, + const constant int& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto idx = offset_neg_idx(indices[index.y], size); + LocT src_idx = static_cast(stride) * idx; + LocT out_idx = static_cast(stride) * index.y; + int s_idx = N * index.x; + for (int i = 0; i < N && s_idx < stride; ++i, ++s_idx) { + out[out_idx + s_idx] = src[src_idx + s_idx]; + } +} +)preamble"; +} + +} // namespace mlx::core::metal diff --git a/Source/Cmlx/mlx-generated/gemv_masked.cpp b/Source/Cmlx/mlx-generated/gemv_masked.cpp index c01337d7..ebcbe300 100644 --- a/Source/Cmlx/mlx-generated/gemv_masked.cpp +++ b/Source/Cmlx/mlx-generated/gemv_masked.cpp @@ -216,28 +216,29 @@ struct GEMVKernel { mat_mask_offset += mat_mask_step; vec_mask_offset += vec_mask_step; } - if (leftover > 0 && - (!has_operand_mask || - (bool(mat_mask[mat_mask_offset]) && - bool(vec_mask[vec_mask_offset])))) { - T block_scale{1}; - if (has_mul_operand_mask) { - block_scale = - T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); - } - load_safe(in_vec, v_coeff, bn, in_size); - if (has_mul_operand_mask) { + if (leftover > 0) { + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + load_safe(in_vec, v_coeff, bn, in_size); + if (has_mul_operand_mask) { #pragma clang loop unroll(full) - for (int tn = 0; tn < TN; tn++) { - v_coeff[tn] *= block_scale; + for (int tn = 0; tn < TN; tn++) { + v_coeff[tn] *= block_scale; + } } - } #pragma clang loop unroll(full) - for (int tm = 0; tm < TM; tm++) { - load_safe(&mat[tm * matrix_ld], inter, bn, in_size); + for (int tm = 0; tm < TM; tm++) { + load_safe(&mat[tm * matrix_ld], inter, bn, in_size); #pragma clang loop unroll(full) - for (int tn = 0; tn < TN; tn++) { - result[tm] += inter[tn] * v_coeff[tn]; + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } } } } @@ -413,27 +414,28 @@ struct GEMVTKernel { mat_mask_offset += mat_mask_step; vec_mask_offset += vec_mask_step; } - if (leftover > 0 && - (!has_operand_mask || - (bool(mat_mask[mat_mask_offset]) && - bool(vec_mask[vec_mask_offset])))) { - T block_scale{1}; - if (has_mul_operand_mask) { - block_scale = - T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); - } - for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { - v_coeff[tm] = static_cast(in_vec[bm + tm]); + if (leftover > 0) { + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; if (has_mul_operand_mask) { - v_coeff[tm] *= block_scale; + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); } + for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { + v_coeff[tm] = static_cast(in_vec[bm + tm]); + if (has_mul_operand_mask) { + v_coeff[tm] *= block_scale; + } #pragma clang loop unroll(full) - for (int tn = 0; tn < TN; tn++) { - inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; - } + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } #pragma clang loop unroll(full) - for (int tn = 0; tn < TN; tn++) { - result[tn] += v_coeff[tm] * inter[tn]; + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } } } } 
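offset_neg_idx in gather_front.cpp above normalizes signed indices Python-style while letting unsigned index types pass through. A host-side C++ mirror of that helper (the name is reused for clarity, not taken from any host API):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <type_traits>

    template <typename IdxT>
    size_t offset_neg_idx(IdxT idx, int size) {
      if constexpr (std::is_unsigned_v<IdxT>) {
        return idx;                           // unsigned indices pass through
      } else {
        return (idx < 0) ? idx + size : idx;  // Python-style wrap-around
      }
    }

    int main() {
      std::printf("%zu %zu\n",
                  offset_neg_idx<int32_t>(-1, 5),    // 4
                  offset_neg_idx<uint32_t>(3, 5));   // 3
      return 0;
    }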
diff --git a/Source/Cmlx/mlx-generated/hadamard.cpp b/Source/Cmlx/mlx-generated/hadamard.cpp index b0839df9..e2450ba5 100644 --- a/Source/Cmlx/mlx-generated/hadamard.cpp +++ b/Source/Cmlx/mlx-generated/hadamard.cpp @@ -22,7 +22,7 @@ METAL_FUNC void radix_func(thread float* x) { h <<= 1; } } -template +template [[kernel]] void hadamard_n( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], @@ -35,15 +35,22 @@ template constexpr short num_steps = logN / logR; constexpr short logFinal = logN % logR; constexpr short final_radix = 1 << (logFinal); - int batch_idx = elem.x * N; - short i = elem.y; + int batch_idx = elem.y * N * stride + elem.z; + short i = elem.x; threadgroup T buf[N]; + if (stride == 1) { #pragma clang loop unroll(full) - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; #pragma clang loop unroll(full) - for (short r = 0; r < read_width; r++) { - buf[index + r] = in[batch_idx + index + r]; + for (short r = 0; r < read_width; r++) { + buf[index + r] = in[batch_idx + index + r]; + } + } + } else { +#pragma clang loop unroll(full) + for (short j = 0; j < max_radix; j++) { + buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride]; } } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -83,12 +90,20 @@ template } threadgroup_barrier(mem_flags::mem_threadgroup); } + if (stride == 1) { #pragma clang loop unroll(full) - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; #pragma clang loop unroll(full) - for (short r = 0; r < read_width; r++) { - out[batch_idx + index + r] = T(buf[index + r] * scale); + for (short r = 0; r < read_width; r++) { + out[batch_idx + index + r] = T(buf[index + r] * scale); + } + } + } else { +#pragma clang loop unroll(full) + for (short j = 0; j < max_radix; j++) { + out[batch_idx + (j * num_threads + i) * stride] = + buf[j * num_threads + i]; } } } diff --git a/Source/Cmlx/mlx-generated/logsumexp.cpp b/Source/Cmlx/mlx-generated/logsumexp.cpp index 9c092cb2..d3d4cf3d 100644 --- a/Source/Cmlx/mlx-generated/logsumexp.cpp +++ b/Source/Cmlx/mlx-generated/logsumexp.cpp @@ -92,8 +92,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; @@ -121,11 +121,8 @@ template } threadgroup_barrier(mem_flags::mem_threadgroup); normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_group_id == 0) { - normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_lane_id == 0) { - out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); - } + if (lid == 0) { + out[gid] = isinf(maxval) ? 
T(maxval) : T(log(normalizer) + maxval); } } )preamble"; diff --git a/Source/Cmlx/mlx-generated/metal/arg_reduce.metal b/Source/Cmlx/mlx-generated/metal/arg_reduce.metal index 8c904de6..3cd95c52 100644 --- a/Source/Cmlx/mlx-generated/metal/arg_reduce.metal +++ b/Source/Cmlx/mlx-generated/metal/arg_reduce.metal @@ -80,9 +80,10 @@ template const constant size_t& ndim [[buffer(5)]], const constant int64_t& axis_stride [[buffer(6)]], const constant size_t& axis_size [[buffer(7)]], - uint gid [[thread_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], + uint3 gid [[thread_position_in_grid]], + uint3 gsize [[threads_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], uint simd_size [[threads_per_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { @@ -104,17 +105,18 @@ template // Compute the input/output index. There is one beginning and one output for // the whole threadgroup. - auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim); - auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim); + int64_t row_idx = gid.y + static_cast(gsize.y) * gid.z; + auto in_idx = elem_to_loc(row_idx, shape, in_strides, ndim); + auto out_idx = elem_to_loc(row_idx, shape, out_strides, ndim); IndexValPair best{0, Op::init}; threadgroup IndexValPair local_data[32]; // Loop over the reduction axis in lsize*N_READS buckets - for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) { + for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) { // Read the current value - uint32_t current_index = r * lsize * N_READS + lid * N_READS; + uint32_t current_index = r * lsize.x * N_READS + lid.x * N_READS; uint32_t offset = current_index; const device T* current_in = in + in_idx + current_index * axis_stride; T vals[N_READS]; @@ -144,7 +146,7 @@ template } // Read the appropriate value from local data and perform one simd reduction - uint simd_groups = ceildiv(lsize, simd_size); + uint simd_groups = ceildiv(lsize.x, simd_size); if (simd_lane_id < simd_groups) { best = local_data[simd_lane_id]; } @@ -154,7 +156,7 @@ template } // Finally write the output - if (lid == 0) { + if (lid.x == 0) { out[out_idx] = best.index; } } diff --git a/Source/Cmlx/mlx-generated/metal/binary.h b/Source/Cmlx/mlx-generated/metal/binary.h index 91a02c81..f1df8853 100644 --- a/Source/Cmlx/mlx-generated/metal/binary.h +++ b/Source/Cmlx/mlx-generated/metal/binary.h @@ -9,64 +9,121 @@ template c[index] = Op()(a[0], b[0]); } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[0], b[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[0]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } } -template +template ::n> [[kernel]] void binary_vv( device const 
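The logsumexp change above keeps the usual streaming recurrence (track a running max, rescale the running normalizer whenever the max grows) and only simplifies how the final value is reduced and written. A scalar C++ sketch of that recurrence, assuming float accumulation:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <limits>
    #include <vector>

    float logsumexp_stream(const std::vector<float>& x) {
      float maxval = -std::numeric_limits<float>::infinity();
      float normalizer = 0.0f;
      for (float v : x) {
        float prevmax = maxval;
        maxval = std::max(maxval, v);
        // Rescale the partial sum whenever the running max moves.
        normalizer = normalizer * std::exp(prevmax - maxval) + std::exp(v - maxval);
      }
      return std::isinf(maxval) ? maxval : std::log(normalizer) + maxval;
    }

    int main() {
      std::printf("%f\n", logsumexp_stream({1.0f, 2.0f, 3.0f}));  // ~3.407606
      return 0;
    }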
T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[0], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } } -template +template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } } template diff --git a/Source/Cmlx/mlx-generated/metal/binary_ops.h b/Source/Cmlx/mlx-generated/metal/binary_ops.h index 4aaf2b4d..cb3e8a37 100644 --- a/Source/Cmlx/mlx-generated/metal/binary_ops.h +++ b/Source/Cmlx/mlx-generated/metal/binary_ops.h @@ -223,6 +223,11 @@ struct Power { template metal::enable_if_t, T> operator()(T base, T exp) { T res = 1; + // Undefined to raise integer to negative power + if (exp < 0) { + return 0; + } + while (exp) { if (exp & 1) { res *= base; @@ -235,6 +240,13 @@ struct Power { template <> complex64_t operator()(complex64_t x, complex64_t y) { + if (x.real == 0 && x.imag == 0) { + if (metal::isnan(y.real) || metal::isnan(y.imag)) { + auto nan = metal::numeric_limits::quiet_NaN(); + return {nan, nan}; + } + return {0.0, 0.0}; + } auto x_theta = metal::atan2(x.imag, x.real); auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); diff --git a/Source/Cmlx/mlx-generated/metal/binary_two.h b/Source/Cmlx/mlx-generated/metal/binary_two.h index 8f6b3392..4455e4ca 100644 --- a/Source/Cmlx/mlx-generated/metal/binary_two.h +++ b/Source/Cmlx/mlx-generated/metal/binary_two.h @@ -12,82 +12,151 @@ template d[index] = out[1]; } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, 
device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[0], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[0]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[0], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[0]); - c[offset] = out[0]; - d[offset] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + int64_t offset = N * 
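The reworked binary, binary_two and copy kernels all follow the same pattern: each thread now handles N contiguous elements (apparently a WorkPerThread default, of which only "::n" survives in this listing) and only a thread whose block crosses `size` takes the bounds-checked path. A scalar C++ model of that pattern, with illustrative names:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct Add {
      float operator()(float x, float y) const { return x + y; }
    };

    // One "thread": handles N contiguous elements starting at thread_index * N.
    template <typename Op, int N = 4>
    void binary_vv_host(const float* a, const float* b, float* c,
                        uint32_t size, uint32_t thread_index) {
      uint32_t index = thread_index * N;
      if (index + N > size) {        // partial tail: bounds-check every element
        for (uint32_t i = index; i < size; ++i) c[i] = Op()(a[i], b[i]);
      } else {                       // full block: no per-element checks
        for (int i = 0; i < N; ++i) c[index + i] = Op()(a[index + i], b[index + i]);
      }
    }

    int main() {
      std::vector<float> a(10, 1.0f), b(10, 2.0f), c(10);
      uint32_t threads = (10 + 3) / 4;  // ceil(size / N)
      for (uint32_t t = 0; t < threads; ++t)
        binary_vv_host<Add>(a.data(), b.data(), c.data(), 10, t);
      std::printf("%g %g\n", c[0], c[9]);  // 3 3
      return 0;
    }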
(index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } } template diff --git a/Source/Cmlx/mlx-generated/metal/cexpf.h b/Source/Cmlx/mlx-generated/metal/cexpf.h new file mode 100644 index 00000000..b45fe6a2 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/cexpf.h @@ -0,0 +1,134 @@ +// Copyright © 2025 Apple Inc. +// Copyright © 2008-2013 NVIDIA Corporation +// Copyright © 2013 Filipe RNC Maia +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Forked from +// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h + +// TODO: We should use thrust::exp but the thrust header in old CUDA versions +// can not be used in JIT. + +#pragma once + +#include + +using ieee_float_shape_type = union { + float value; + uint32_t word; +}; + +inline void get_float_word(thread uint32_t& i, float d) { + ieee_float_shape_type gf_u; + gf_u.value = (d); + (i) = gf_u.word; +} + +inline void get_float_word(thread int32_t& i, float d) { + ieee_float_shape_type gf_u; + gf_u.value = (d); + (i) = gf_u.word; +} + +inline void set_float_word(thread float& d, uint32_t i) { + ieee_float_shape_type sf_u; + sf_u.word = (i); + (d) = sf_u.value; +} + +inline float frexp_expf(float x, thread int* expt) { + const uint32_t k = 235; + const float kln2 = 162.88958740F; + + float exp_x; + uint32_t hx; + + exp_x = metal::exp(x - kln2); + get_float_word(hx, exp_x); + *expt = (hx >> 23) - (0x7f + 127) + k; + set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23)); + return exp_x; +} + +inline complex64_t ldexp_cexpf(complex64_t z, int expt) { + float x, y, exp_x, scale1, scale2; + int ex_expt, half_expt; + + x = z.real; + y = z.imag; + exp_x = frexp_expf(x, &ex_expt); + expt += ex_expt; + + half_expt = expt / 2; + set_float_word(scale1, (0x7f + half_expt) << 23); + half_expt = expt - half_expt; + set_float_word(scale2, (0x7f + half_expt) << 23); + + return complex64_t{ + metal::cos(y) * exp_x * scale1 * scale2, + metal::sin(y) * exp_x * scale1 * scale2}; +} + +inline complex64_t cexpf(const thread complex64_t& z) { + float x, y, exp_x; + uint32_t hx, hy; + + const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074; + + x = z.real; + y = z.imag; + + get_float_word(hy, y); + hy &= 0x7fffffff; + + /* cexp(x + I 0) = exp(x) + I 0 */ + if (hy == 0) { + return complex64_t{metal::exp(x), y}; + } + get_float_word(hx, x); + /* cexp(0 + I y) = cos(y) + I sin(y) */ + if ((hx & 0x7fffffff) == 0) { + return complex64_t{metal::cos(y), metal::sin(y)}; + } + if (hy >= 0x7f800000) { + if ((hx & 0x7fffffff) != 0x7f800000) { + /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */ + return complex64_t{y - y, y - y}; + } else if (hx & 0x80000000) { + /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */ + 
return complex64_t{0.0, 0.0}; + } else { + /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */ + return complex64_t{x, y - y}; + } + } + + if (hx >= exp_ovfl && hx <= cexp_ovfl) { + /* + * x is between 88.7 and 192, so we must scale to avoid + * overflow in expf(x). + */ + return ldexp_cexpf(z, 0); + } else { + /* + * Cases covered here: + * - x < exp_ovfl and exp(x) won't overflow (common case) + * - x > cexp_ovfl, so exp(x) * s overflows for all s > 0 + * - x = +-Inf (generated by exp()) + * - x = NaN (spurious inexact exception from y) + */ + exp_x = metal::exp(x); + return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)}; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/conv.metal b/Source/Cmlx/mlx-generated/metal/conv.metal index be03d69c..a757243d 100644 --- a/Source/Cmlx/mlx-generated/metal/conv.metal +++ b/Source/Cmlx/mlx-generated/metal/conv.metal @@ -166,115 +166,6 @@ instantiate_naive_unfold_nd_dims(float32, float); instantiate_naive_unfold_nd_dims(float16, half); instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t); -/////////////////////////////////////////////////////////////////////////////// -/// Slow and naive conv2d kernels -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - const int BM, /* Threadgroup rows (in threads) */ - const int BN, /* Threadgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const int BC = 16> -[[kernel]] void naive_conv_2d( - const device T* in [[buffer(0)]], - const device T* wt [[buffer(1)]], - device T* out [[buffer(2)]], - const constant MLXConvParams<2>& params [[buffer(3)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)simd_gid; - (void)simd_lid; - - out += tid.z * params.out_strides[0]; - in += tid.z * params.in_strides[0]; - - int out_o = tid.y * BN * TN + lid.y * TN; - int out_hw = tid.x * BM * TM + lid.x * TM; - - int out_h[TM]; - int out_w[TN]; - - for (int m = 0; m < TM; ++m) { - int mm = (out_hw + m); - out_h[m] = mm / params.oS[1]; - out_w[m] = mm % params.oS[1]; - } - - T in_local[TM]; - T wt_local[TN]; - T out_local[TM * TN] = {T(0)}; - - for (int h = 0; h < params.wS[0]; ++h) { - for (int w = 0; w < params.wS[1]; ++w) { - for (int c = 0; c < params.C; ++c) { - // Local in - for (int m = 0; m < TM; m++) { - int i = out_h[m] * params.str[0] - params.pad[0] + h * params.kdil[0]; - int j = out_w[m] * params.str[1] - params.pad[1] + w * params.kdil[1]; - - bool valid = i >= 0 && i < params.iS[0] && j >= 0 && j < params.iS[1]; - in_local[m] = valid - ? in[i * params.in_strides[1] + j * params.in_strides[2] + c] - : T(0); - } - - // Load weight - for (int n = 0; n < TN; ++n) { - int o = out_o + n; - wt_local[n] = o < params.O - ? 
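The scaled path in cexpf.h above (the exp_ovfl/cexp_ovfl range) exists because expf(x) can overflow to inf even when the final product expf(x) * cos(y) is representable in float. A small host-side C++ check of that failure mode, not taken from the diff:

    #include <cmath>
    #include <cstdio>

    int main() {
      // With x = 100 and cos(y) ~ 1e-10, exp(x) * cos(y) ~ 2.7e33 fits in a
      // float, but expf(x) alone is inf, so an unscaled product would be inf.
      float x = 100.0f;
      double y = std::acos(-1.0) / 2 - 1e-10;               // cos(y) ~= 1e-10
      float naive_real = std::exp(x) * float(std::cos(y));  // inf
      double ref_real = std::exp(double(x)) * std::cos(y);  // ~2.7e33
      std::printf("naive: %g  reference: %g\n", naive_real, ref_real);
      return 0;
    }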
wt[o * params.wt_strides[0] + h * params.wt_strides[1] + - w * params.wt_strides[2] + c] - : T(0); - } - - // Accumulate - for (int m = 0; m < TM; ++m) { - for (int n = 0; n < TN; ++n) { - out_local[m * TN + n] += in_local[m] * wt_local[n]; - } - } - } - } - } - - for (int m = 0; m < TM; ++m) { - for (int n = 0; n < TN; ++n) { - if (out_h[m] < params.oS[0] && out_w[m] < params.oS[1] && - (out_o + n) < params.O) - out[out_h[m] * params.out_strides[1] + - out_w[m] * params.out_strides[2] + out_o + n] = - out_local[m * TN + n]; - } - } -} - -// Instantiations - -#define instantiate_naive_conv_2d(name, itype, bm, bn, tm, tn) \ - template [[host_name("naive_conv_2d_" #name "_bm" #bm "_bn" #bn "_tm" #tm \ - "_tn" #tn)]] [[kernel]] void \ - naive_conv_2d( \ - const device itype* in [[buffer(0)]], \ - const device itype* wt [[buffer(1)]], \ - device itype* out [[buffer(2)]], \ - const constant MLXConvParams<2>& params [[buffer(3)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint simd_gid [[simdgroup_index_in_threadgroup]], \ - uint simd_lid [[thread_index_in_simdgroup]]); - -#define instantiate_naive_conv_2d_blocks(name, itype) \ - instantiate_naive_conv_2d(name, itype, 16, 8, 4, 4) \ - instantiate_naive_conv_2d(name, itype, 16, 8, 2, 4) - -instantiate_naive_conv_2d_blocks(float32, float); -instantiate_naive_conv_2d_blocks(float16, half); -instantiate_naive_conv_2d_blocks(bfloat16, bfloat16_t); - /////////////////////////////////////////////////////////////////////////////// /// Depthwise convolution kernels /////////////////////////////////////////////////////////////////////////////// @@ -397,6 +288,40 @@ instantiate_depthconv2d(float32, float); instantiate_depthconv2d(float16, half); instantiate_depthconv2d(bfloat16, bfloat16_t); +template +[[kernel]] void depthwise_conv_1d( + const device T* in [[buffer(0)]], + const device T* w [[buffer(1)]], + device T* out [[buffer(2)]], + constant const IdxT strides[3], + constant const int& kernel_size, + uint3 tid [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + out += (tid.z * static_cast(grid_dim.y) + tid.y) * grid_dim.x + tid.x; + in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2]; + w += tid.x * kernel_size; + + float acc = 0.0; + for (int i = 0; i < kernel_size; ++i) { + acc += static_cast(in[0]) * w[i]; + in += strides[1]; + } + *out = static_cast(acc); +} + +#define instantiate_depthconv1d(iname, itype) \ + instantiate_kernel( \ + "depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \ + instantiate_kernel( \ + "depthwise_conv_1d_" #iname "_large", \ + depthwise_conv_1d, \ + itype, \ + int64_t) + +instantiate_depthconv1d(float32, float); +instantiate_depthconv1d(float16, half); +instantiate_depthconv1d(bfloat16, bfloat16_t); + /////////////////////////////////////////////////////////////////////////////// /// Winograd kernels /////////////////////////////////////////////////////////////////////////////// diff --git a/Source/Cmlx/mlx-generated/metal/copy.h b/Source/Cmlx/mlx-generated/metal/copy.h index b1367cf4..cf22347e 100644 --- a/Source/Cmlx/mlx-generated/metal/copy.h +++ b/Source/Cmlx/mlx-generated/metal/copy.h @@ -1,39 +1,77 @@ // Copyright © 2024 Apple Inc. 
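The new depthwise_conv_1d kernel computes, per (batch, output position, channel), a dot product of K input samples along the length stride with one length-K filter per channel. A CPU reference in C++, assuming a contiguous (batch, length, channels) input of length L_out + K - 1 (valid convolution, stride 1); the layout and names are assumptions for this sketch:

    #include <cstdio>
    #include <vector>

    // out[b, l, c] = sum_i in[b, l + i, c] * w[c, i]   (valid conv, stride 1)
    void depthwise_conv1d_ref(const std::vector<float>& in, const std::vector<float>& w,
                              std::vector<float>& out, int B, int L_out, int C, int K) {
      int L_in = L_out + K - 1;
      for (int b = 0; b < B; ++b)
        for (int l = 0; l < L_out; ++l)
          for (int c = 0; c < C; ++c) {
            float acc = 0.0f;
            for (int i = 0; i < K; ++i)
              acc += in[(b * L_in + l + i) * C + c] * w[c * K + i];
            out[(b * L_out + l) * C + c] = acc;
          }
    }

    int main() {
      // 1 batch, 2 channels, K = 2, L_out = 3 (so L_in = 4), (B, L, C) layout.
      std::vector<float> in = {1, 10, 2, 20, 3, 30, 4, 40};
      std::vector<float> w = {1, 1, 0.5f, 0.5f};  // one length-2 filter per channel
      std::vector<float> out(1 * 3 * 2);
      depthwise_conv1d_ref(in, w, out, 1, 3, 2, 2);
      for (float v : out) std::printf("%g ", v);  // 3 15 5 25 7 35
      std::printf("\n");
      return 0;
    }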
-template +template ::n> [[kernel]] void copy_s( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[0]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[0]); + } + } } -template +template ::n> [[kernel]] void copy_v( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } } -template +template ::n> [[kernel]] void copy_s2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } } -template +template ::n> [[kernel]] void copy_v2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } } template diff --git a/Source/Cmlx/mlx-generated/metal/fft/readwrite.h b/Source/Cmlx/mlx-generated/metal/fft/readwrite.h index 23231946..4459d36f 100644 --- a/Source/Cmlx/mlx-generated/metal/fft/readwrite.h +++ b/Source/Cmlx/mlx-generated/metal/fft/readwrite.h @@ -10,7 +10,7 @@ For many sizes, GPU FFTs are memory bandwidth bound so read/write performance is important. Where possible, we read 128 bits sequentially in each thread, -coalesced with accesses from adajcent threads for optimal performance. +coalesced with accesses from adjacent threads for optimal performance. 
We implement specialized reading/writing for: - FFT @@ -98,7 +98,7 @@ struct ReadWriter { } METAL_FUNC void load() const { - int batch_idx = elem.x * grid.y * n; + size_t batch_idx = size_t(elem.x * grid.y) * n; short tg_idx = elem.y * grid.z + elem.z; short max_index = grid.y * n - 2; @@ -121,7 +121,7 @@ struct ReadWriter { } METAL_FUNC void write() const { - int batch_idx = elem.x * grid.y * n; + size_t batch_idx = size_t(elem.x * grid.y) * n; short tg_idx = elem.y * grid.z + elem.z; short max_index = grid.y * n - 2; @@ -144,7 +144,7 @@ struct ReadWriter { // Padded IO for Bluestein's algorithm METAL_FUNC void load_padded(int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length + elem.y * length; + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; int fft_idx = elem.z; int m = grid.z; @@ -161,7 +161,7 @@ struct ReadWriter { } METAL_FUNC void write_padded(int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length + elem.y * length; + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; int fft_idx = elem.z; int m = grid.z; float2 inv_factor = {1.0f / n, -1.0f / n}; @@ -261,7 +261,7 @@ METAL_FUNC bool ReadWriter::out_of_bounds() const { template <> METAL_FUNC void ReadWriter::load() const { - int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2; + size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -283,7 +283,8 @@ template <> METAL_FUNC void ReadWriter::write() const { short n_over_2 = (n / 2) + 1; - int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; int grid_index = elem.x * grid.y + elem.y; @@ -317,7 +318,7 @@ template <> METAL_FUNC void ReadWriter::load_padded( int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -345,8 +346,8 @@ METAL_FUNC void ReadWriter::write_padded( int length, const device float2* w_k) const { int length_over_2 = (length / 2) + 1; - int batch_idx = - elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n + length - 1; int grid_index = elem.x * grid.y + elem.y; @@ -397,7 +398,8 @@ METAL_FUNC bool ReadWriter::out_of_bounds() const { template <> METAL_FUNC void ReadWriter::load() const { short n_over_2 = (n / 2) + 1; - int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -458,8 +460,8 @@ METAL_FUNC void ReadWriter::load_padded( int n_over_2 = (n / 2) + 1; int length_over_2 = (length / 2) + 1; - int batch_idx = - elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -503,7 +505,7 @@ template <> 
METAL_FUNC void ReadWriter::write_padded( int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; threadgroup float2* seq_buf = buf + elem.y * n + length - 1; int grid_index = elem.x * grid.y + elem.y; diff --git a/Source/Cmlx/mlx-generated/metal/fp4_quantized.h b/Source/Cmlx/mlx-generated/metal/fp4_quantized.h new file mode 100644 index 00000000..0b22dc1e --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/fp4_quantized.h @@ -0,0 +1,1791 @@ +// Copyright © 2025 Apple Inc. + +#include +#include + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; +MLX_MTL_CONST int QUAD_SIZE = 4; + +template +inline constexpr short get_pack_factor() { + return wsize / 4; +} + +template +inline constexpr short get_bytes_per_pack() { + return wsize / 8; +} + +template +static inline T dequantize_scale(uint8_t s) { + using FOrI = union { + bfloat16_t f; + uint16_t i; + }; + FOrI out; + out.i = (s == 0 ? 0x40 : (static_cast(s) << 7)); + return static_cast(out.f); +} + +template +inline void load_vector(const device T* x, thread U* x_thread) { + for (int i = 0; i < values_per_thread; i += 4) { + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; + } +} + +template +inline void load_vector_safe(const device T* x, thread U* x_thread, int N) { + for (int i = 0; i < N; i += 4) { + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; + } + + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } +} + +constexpr constant static float MXFP4_LUT[16] = { + +0.0f, + +0.5f, + +1.0f, + +1.5f, + +2.0f, + +3.0f, + +4.0f, + +6.0f, + -0.0f, + -0.5f, + -1.0f, + -1.5f, + -2.0f, + -3.0f, + -4.0f, + -6.0f}; + +template +void load_mxfp4_lut(threadgroup T* lut, uint simd_gid, uint simd_lid) { + if (simd_gid == 0 && simd_lid < 16) { + lut[simd_lid] = static_cast(MXFP4_LUT[simd_lid]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); +} + +template +inline U qdot( + const device uint8_t* w, + const thread U* x_thread, + U scale, + const threadgroup U* lut) { + U accum = 0; + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * lut[ws[i] & 0xf] + + x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] + + x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] + + x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]); + } + return scale * accum; +} + +template +inline U qdot_safe( + const device uint8_t* w, + const thread U* x_thread, + U scale, + const threadgroup U* lut, + int N) { + U accum = 0; + + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * lut[ws[i] & 0xf] + + x_thread[4 * i + 1] * lut[(ws[i] >> 4) & 0xf] + + x_thread[4 * i + 2] * lut[(ws[i] >> 8) & 0xf] + + x_thread[4 * i + 3] * lut[(ws[i] >> 12) & 0xf]); + } + return scale * accum; +} + +template +inline void qouter( + const thread uint8_t* w, + U x, + U scale, + thread U* result, + const threadgroup U* lut) { + for (int i = 0; i < (values_per_thread / 2); i++) { + result[2 * i] += x * scale * lut[w[i] & 0xf]; + result[2 * i + 1] += x * scale * 
lut[(w[i] >> 4) & 0xf]; + } +} + +template +inline void dequantize( + const device uint8_t* w, + U scale, + threadgroup U* w_local, + const threadgroup U* lut) { + for (int i = 0; i < (N / 2); i++) { + w_local[2 * i] = scale * lut[w[i] & 0xf]; + w_local[2 * i + 1] = scale * lut[(w[i] >> 4) & 0xf]; + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size, + typename S> +struct QuantizedBlockLoader { + static_assert( + BCOLS <= group_size, + "The group size should be larger than the columns"); + static_assert( + group_size % BCOLS == 0, + "The group size should be divisible by the columns"); + + MLX_MTL_CONST short pack_factor = get_pack_factor<8>(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short group_steps = group_size / BCOLS; + + const int src_ld; + const int tile_stride; + short group_step_cnt; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + const device S* scales; + threadgroup T* lut; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device S* scales_, + const int src_ld_, + threadgroup T* dst_, + threadgroup T* lut_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_step_cnt(0), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size), + lut(lut_) { + load_mxfp4_lut(lut, simd_group_id, simd_lane_id); + } + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = dequantize_scale(*scales); + for (int i = 0; i < n_reads; i++) { + dequantize( + src + i * bytes_per_pack, scale, dst + i * pack_factor, lut); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = dequantize_scale(*scales); + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i * bytes_per_pack), + scale, + dst + i * pack_factor, + lut); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + if (group_steps > 1) { + group_step_cnt++; + if (group_step_cnt == group_steps) { + group_step_cnt = 0; + scales++; + } + } else { + scales++; + } + } else { + scales += group_stride; + } + } +}; + +template +METAL_FUNC void mxfp4_qmv_quad_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid 
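The helpers above decode MXFP4: each 4-bit code indexes the 16-entry MXFP4_LUT and is multiplied by a per-group scale whose byte is dropped into a bfloat16 exponent field, i.e. scale = 2^(s - 127). A standalone C++ sketch of that element decode (the nibble order follows qdot above, low nibble first; the function name is illustrative):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    static const float kMxfp4Lut[16] = {
        +0.0f, +0.5f, +1.0f, +1.5f, +2.0f, +3.0f, +4.0f, +6.0f,
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

    // Decode one 4-bit element: `nibble` 0 is the low nibble (read first by the
    // kernels), 1 is the high nibble; `scale` is the shared per-group exponent byte.
    float decode_mxfp4(uint8_t packed, int nibble, uint8_t scale) {
      uint8_t code = (nibble == 0) ? (packed & 0xf) : (packed >> 4);
      return kMxfp4Lut[code] * std::ldexp(1.0f, int(scale) - 127);
    }

    int main() {
      // 0x27 packs codes 7 (+6.0) and 2 (+1.0); scale byte 128 means 2^1.
      std::printf("%g %g\n",
                  decode_mxfp4(0x27, 0, 128),   // 12
                  decode_mxfp4(0x27, 1, 128));  // 2
      return 0;
    }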
[[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup float* lut) { + constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE; + constexpr int pack_factor = 8; + constexpr int values_per_thread = D / QUAD_SIZE; + constexpr int packs_per_thread = values_per_thread / pack_factor; + constexpr int scale_step_per_thread = group_size / values_per_thread; + constexpr int results_per_quadgroup = 8; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_quadgroup] = {0}; + load_mxfp4_lut(lut, simd_gid, simd_lid); + + // Adjust positions + const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid; + + w += out_row * in_vec_size_w + quad_lid * packs_per_thread; + scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; + x += tid.x * in_vec_size + quad_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + load_vector(x, x_thread); + + for (int row = 0; row < results_per_quadgroup; row++) { + auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); + const device S* sl = scales + row * in_vec_size_g * quads_per_simd; + + U s = dequantize_scale(sl[0]); + if (row * quads_per_simd + out_row < out_vec_size) { + result[row] += qdot(wl, x_thread, s, lut); + } + } + + for (int row = 0; row < results_per_quadgroup; row++) { + result[row] = quad_sum(result[row]); + if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) { + y[row * quads_per_simd] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void mxfp4_qmv_fast_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup float* lut) { + constexpr int packs_per_thread = 2; + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int pack_factor = get_pack_factor<32>(); + constexpr int bytes_per_pack = get_bytes_per_pack<32>(); + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + const device uint8_t* ws = (const device uint8_t*)w; + + typedef float U; + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + load_mxfp4_lut(lut, simd_gid, simd_lid); + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + for (int k = 0; k < in_vec_size; k += block_size) { + load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U 
s = dequantize_scale(sl[0]); + result[row] += qdot(wl, x_thread, s, lut); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void mxfp4_qmv_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup float* lut) { + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int packs_per_thread = 1; + constexpr int pack_factor = get_pack_factor<32>(); + constexpr int bytes_per_pack = get_bytes_per_pack<32>(); + + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + const device uint8_t* ws = (const device uint8_t*)w; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + load_mxfp4_lut(lut, simd_gid, simd_lid); + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); + + if (out_row >= out_vec_size) { + return; + } + + // In this case we need to properly guard all our reads because there isn't + // even 1 tile in the matrix + if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { + ws += + out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + load_vector(x, x_thread); + + for (int row = 0; out_row + row < out_vec_size; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + S s = sl[0]; + result[row] += qdot(wl, x_thread, s, lut); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + load_vector_safe(x, x_thread, remaining); + + for (int row = 0; out_row + row < out_vec_size; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = dequantize_scale(sl[0]); + result[row] += qdot(wl, x_thread, s, lut); + } + } + + for (int row = 0; out_row + row < out_vec_size; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } + + // In this case the last tile is moved back to redo some output values + else { + ws += used_out_row * in_vec_size_w + + simd_lid * packs_per_thread * bytes_per_pack; + scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * 
in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + used_out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = dequantize_scale(sl[0]); + result[row] += qdot(wl, x_thread, s, lut); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + load_vector_safe(x, x_thread, remaining); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = dequantize_scale(sl[0]); + result[row] += + qdot_safe(wl, x_thread, s, lut, remaining); + } + } + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } +} + +template +METAL_FUNC void mxfp4_qvm_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const int in_vec_size, + const int out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup float* lut) { + constexpr int num_simdgroups = 2; + constexpr int pack_factor = get_pack_factor<32>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int tn = 32 / pack_factor; + constexpr int block_size = SIMD_SIZE; + + using W_T = uint32_t; + const device W_T* ws = (const device W_T*)w; + + typedef float U; + typedef struct { + W_T wi[tn * bytes_per_pack]; + } vec_w; + + thread vec_w w_local; + thread U result[tn * pack_factor] = {0}; + thread U scale = 0; + thread U x_local = 0; + + load_mxfp4_lut(lut, simd_gid, simd_lid); + + // Adjust positions + const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; + const int out_vec_size_g = out_vec_size / group_size; + int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid); + ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; + scales += out_col / group_size + simd_lid * out_vec_size_g; + x += tid.x * in_vec_size + simd_lid; + y += tid.x * out_vec_size + out_col; + + if (out_col >= out_vec_size) { + return; + } + + // Loop over in_vec in blocks of block_size + int remaining = in_vec_size % block_size; + if (remaining == 0) { + for (int i = 0; i < in_vec_size; i += block_size) { + x_local = *x; + scale = dequantize_scale(*scales); + w_local = *((device vec_w*)ws); + qouter( + (thread uint8_t*)&w_local, x_local, scale, result, lut); + + x += block_size; + scales += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; + } + } else { + for (int i = block_size; i < in_vec_size; i += block_size) { + x_local = *x; + scale = dequantize_scale(*scales); + w_local = *((device vec_w*)ws); + + qouter( + (thread uint8_t*)&w_local, x_local, scale, result, lut); + + x += block_size; + scales += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; + } + if (static_cast(simd_lid) < remaining) { + x_local = *x; + scale = dequantize_scale(*scales); + w_local = *((device vec_w*)ws); + } else { + x_local = 0; + scale = 0; + } + qouter( + 
(thread uint8_t*)&w_local, x_local, scale, result, lut); + } + +// Accumulate in the simdgroup +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + result[k] = simd_sum(result[k]); + } + + // Store the result + if (simd_lid == 0) { +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + y[k] = static_cast(result[k]); + } + } +} + +template < + typename T, + const int group_size, + typename S, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void mxfp4_qmm_t_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup T* lut) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size, + S>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + const short num_outs = min(BN, N - y_col); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, K, Ws, lut, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + 
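+    // (The barrier below synchronizes all simdgroups after the final K block before the
+    // accumulated tile is written out; store_result_safe clips the write to
+    // num_outs x num_els when the output block is only partially in bounds.)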
threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM || num_outs < BN) { + mma_op.store_result_safe(y, N, short2(num_outs, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template < + typename T, + const int group_size, + typename S, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void mxfp4_qmm_n_impl( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup T* lut) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = mlx::steel:: + BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size, + S>; + + auto wl = (const device uint8_t*)w; + + // Set the block + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + x += y_row * static_cast(K); + wl += y_col * bytes_per_pack / pack_factor; + scales += y_col / group_size; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, N, Ws, lut, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, num_els)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, BM)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + 
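+        // Fully aligned case: every BM x BK activation tile and BK x BN weight tile is
+        // complete, so both loaders take their unguarded path for the whole K loop.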
loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM) { + mma_op.store_result_safe(y, N, short2(BN, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device S*& scales, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device S*& scales, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + +template +[[kernel]] void mxfp4_qmv_quad( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + 
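+    // adjust_matrix_offsets (defined above) advances x, w, scales and y to the batch
+    // element selected by tid.z: a single stride when the batch is 1-D, otherwise an
+    // elem_to_loc-style walk over the batch shape/strides, roughly:
+    //   loc = 0; for (int i = ndim - 1; i >= 0; --i) { loc += (idx % shape[i]) * strides[i]; idx /= shape[i]; }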
adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + threadgroup float lut[16]; + mxfp4_qmv_quad_impl( + w, + scales, + x, + y, + in_vec_size, + out_vec_size, + tid, + quad_gid, + quad_lid, + simd_gid, + simd_lid, + lut); +} + +template +[[kernel]] void mxfp4_qmv_fast( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + threadgroup float lut[16]; + mxfp4_qmv_fast_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} + +template +[[kernel]] void mxfp4_qmv( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + threadgroup float lut[16]; + mxfp4_qmv_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} + +template +[[kernel]] void mxfp4_qvm( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + threadgroup float lut[16]; + mxfp4_qvm_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} + +template +[[kernel]] void mxfp4_qvm_split_k( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& 
w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& final_block_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + + // When (in_vec_size % split_k != 0) the final block needs to be smaller + int in_vec_size_adj = + tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size; + + threadgroup float lut[16]; + mxfp4_qvm_impl( + w, + scales, + x, + y, + in_vec_size_adj, + out_vec_size, + tid, + simd_gid, + simd_lid, + lut); +} + +template < + typename T, + const int group_size, + typename S, + const bool aligned_N, + const bool batched, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void mxfp4_qmm_t( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + threadgroup T lut[16]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + mxfp4_qmm_t_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template < + typename T, + const int group_size, + typename S, + const bool batched, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void mxfp4_qmm_n( + const device uint32_t* w, + const device S* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + threadgroup T lut[16]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + + mxfp4_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template +[[kernel]] void mxfp4_gather_qmv_fast( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device 
uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + threadgroup float lut[16]; + mxfp4_qmv_fast_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} + +template +[[kernel]] void mxfp4_gather_qmv( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + threadgroup float lut[16]; + mxfp4_qmv_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid, lut); +} + +template +[[kernel]] void mxfp4_gather_qvm( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + threadgroup float lut[16]; + mxfp4_qvm_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, 
simd_gid, simd_lid, lut); +} + +template < + typename T, + const int group_size, + typename S, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void mxfp4_gather_qmm_t( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + threadgroup T lut[16]; + + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + mxfp4_qmm_t_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template < + typename T, + const int group_size, + typename S, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void mxfp4_gather_qmm_n( + const device uint32_t* w, + const device S* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + threadgroup T lut[16]; + + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + mxfp4_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template < + typename T, + int group_size, + typename S, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose> +[[kernel]] void mxfp4_gather_qmm_rhs( + const device T* x, + const device uint32_t* w, + const device S* scales, + const device uint32_t* indices, + device T* y, + const constant int& M, + const constant int& N, + const constant int& K, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id 
[[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T lut[16]; + + using mma_t = mlx::steel::BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + false, + transpose, + BK_padded, + transpose ? BK_padded : BN_padded>; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size, + S>; + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_x = short2(k_remain, tgp_bm); + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? y_col_long * K_g : y_col / group_size; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id); + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + transpose ? 
K : N, + Ws, + lut, + simd_group_id, + simd_lane_id); + + // Matrices are all aligned check nothing + if (align_M && align_N) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } else { + // Tile aligned so check outside of the hot loop + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + } + } +} diff --git a/Source/Cmlx/mlx-generated/metal/gemv_masked.h b/Source/Cmlx/mlx-generated/metal/gemv_masked.h index 451ed728..c5723c93 100644 --- a/Source/Cmlx/mlx-generated/metal/gemv_masked.h +++ b/Source/Cmlx/mlx-generated/metal/gemv_masked.h @@ -262,36 +262,37 @@ struct GEMVKernel { vec_mask_offset += vec_mask_step; } - if (leftover > 0 && - (!has_operand_mask || - (bool(mat_mask[mat_mask_offset]) && - bool(vec_mask[vec_mask_offset])))) { - T block_scale{1}; - if (has_mul_operand_mask) { - block_scale = - T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); - } + if (leftover > 0) { + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } - load_safe(in_vec, v_coeff, bn, in_size); + load_safe(in_vec, v_coeff, bn, in_size); - // Apply scale - if (has_mul_operand_mask) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - v_coeff[tn] *= block_scale; + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + v_coeff[tn] *= block_scale; + } } - } - - // Per thread work loop - MLX_MTL_PRAGMA_UNROLL - for 
(int tm = 0; tm < TM; tm++) { - // Load for the row - load_safe(&mat[tm * matrix_ld], inter, bn, in_size); - // Accumulate results + // Per thread work loop MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tm] += inter[tn] * v_coeff[tn]; + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_safe(&mat[tm * matrix_ld], inter, bn, in_size); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } } } } @@ -544,31 +545,32 @@ struct GEMVTKernel { vec_mask_offset += vec_mask_step; } - if (leftover > 0 && - (!has_operand_mask || - (bool(mat_mask[mat_mask_offset]) && - bool(vec_mask[vec_mask_offset])))) { - T block_scale{1}; - if (has_mul_operand_mask) { - block_scale = - T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); - } - - for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { - v_coeff[tm] = static_cast(in_vec[bm + tm]); - + if (leftover > 0) { + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; if (has_mul_operand_mask) { - v_coeff[tm] *= block_scale; + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); } - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; - } + for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { + v_coeff[tm] = static_cast(in_vec[bm + tm]); - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tn] += v_coeff[tm] * inter[tn]; + if (has_mul_operand_mask) { + v_coeff[tm] *= block_scale; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } } } } diff --git a/Source/Cmlx/mlx-generated/metal/hadamard.h b/Source/Cmlx/mlx-generated/metal/hadamard.h index 8f2d8cc1..d6c08f17 100644 --- a/Source/Cmlx/mlx-generated/metal/hadamard.h +++ b/Source/Cmlx/mlx-generated/metal/hadamard.h @@ -26,7 +26,7 @@ METAL_FUNC void radix_func(thread float* x) { } } -template +template [[kernel]] void hadamard_n( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], @@ -46,18 +46,25 @@ template constexpr short logFinal = logN % logR; constexpr short final_radix = 1 << (logFinal); - int batch_idx = elem.x * N; - short i = elem.y; + int batch_idx = elem.y * N * stride + elem.z; + short i = elem.x; threadgroup T buf[N]; // Read values from device - STEEL_PRAGMA_UNROLL - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; + if (stride == 1) { STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - buf[index + r] = in[batch_idx + index + r]; + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + buf[index + r] = in[batch_idx + index + r]; + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix; j++) { + buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride]; } } @@ -113,12 +120,20 @@ template } // Write values to device - STEEL_PRAGMA_UNROLL - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; + if (stride == 1) { STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - out[batch_idx + index + r] = T(buf[index + r] * 
scale); + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + out[batch_idx + index + r] = T(buf[index + r] * scale); + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix; j++) { + out[batch_idx + (j * num_threads + i) * stride] = + buf[j * num_threads + i]; } } } diff --git a/Source/Cmlx/mlx-generated/metal/gather.h b/Source/Cmlx/mlx-generated/metal/indexing/gather.h similarity index 98% rename from Source/Cmlx/mlx-generated/metal/gather.h rename to Source/Cmlx/mlx-generated/metal/indexing/gather.h index 532c1a01..d99c46c6 100644 --- a/Source/Cmlx/mlx-generated/metal/gather.h +++ b/Source/Cmlx/mlx-generated/metal/indexing/gather.h @@ -2,7 +2,7 @@ #pragma once -#include "indexing.h" +#include "../indexing/indexing.h" template METAL_FUNC void gather_impl( diff --git a/Source/Cmlx/mlx-generated/metal/gather_axis.h b/Source/Cmlx/mlx-generated/metal/indexing/gather_axis.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/gather_axis.h rename to Source/Cmlx/mlx-generated/metal/indexing/gather_axis.h diff --git a/Source/Cmlx/mlx-generated/metal/indexing/gather_front.h b/Source/Cmlx/mlx-generated/metal/indexing/gather_front.h new file mode 100644 index 00000000..2cd6eb41 --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/indexing/gather_front.h @@ -0,0 +1,24 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "../indexing/indexing.h" + +template +[[kernel]] void gather_front( + const device T* src, + const device IdxT* indices, + device T* out, + const constant int64_t& stride, + const constant int& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto idx = offset_neg_idx(indices[index.y], size); + LocT src_idx = static_cast(stride) * idx; + LocT out_idx = static_cast(stride) * index.y; + + int s_idx = N * index.x; + for (int i = 0; i < N && s_idx < stride; ++i, ++s_idx) { + out[out_idx + s_idx] = src[src_idx + s_idx]; + } +} diff --git a/Source/Cmlx/mlx-generated/metal/indexing.h b/Source/Cmlx/mlx-generated/metal/indexing/indexing.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/indexing.h rename to Source/Cmlx/mlx-generated/metal/indexing/indexing.h diff --git a/Source/Cmlx/mlx-generated/metal/scatter.h b/Source/Cmlx/mlx-generated/metal/indexing/scatter.h similarity index 98% rename from Source/Cmlx/mlx-generated/metal/scatter.h rename to Source/Cmlx/mlx-generated/metal/indexing/scatter.h index 5792a2a4..99e65d20 100644 --- a/Source/Cmlx/mlx-generated/metal/scatter.h +++ b/Source/Cmlx/mlx-generated/metal/indexing/scatter.h @@ -2,7 +2,7 @@ #pragma once -#include "indexing.h" +#include "../indexing/indexing.h" template < typename T, diff --git a/Source/Cmlx/mlx-generated/metal/scatter_axis.h b/Source/Cmlx/mlx-generated/metal/indexing/scatter_axis.h similarity index 100% rename from Source/Cmlx/mlx-generated/metal/scatter_axis.h rename to Source/Cmlx/mlx-generated/metal/indexing/scatter_axis.h diff --git a/Source/Cmlx/mlx-generated/metal/layer_norm.metal b/Source/Cmlx/mlx-generated/metal/layer_norm.metal index 2a628d11..e1c862c9 100644 --- a/Source/Cmlx/mlx-generated/metal/layer_norm.metal +++ b/Source/Cmlx/mlx-generated/metal/layer_norm.metal @@ -9,7 +9,42 @@ using namespace metal; constant bool has_w [[function_constant(20)]]; -template +template +inline void initialize_buffer( + threadgroup float* xs, + uint simd_lane_id [[thread_index_in_simdgroup]], + 
uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + if (simd_group_id == 0) { + for (int i = 0; i < N; i++) { + xs[N * simd_lane_id + i] = 0; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); +} + +template +inline void threadgroup_sum( + thread float* x, + threadgroup float* xs, + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + for (int i = 0; i < N; i++) { + x[i] = simd_sum(x[i]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_lane_id == 0) { + for (int i = 0; i < N; i++) { + xs[N * simd_group_id + i] = x[i]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N; i++) { + x[i] = xs[N * simd_lane_id + i]; + x[i] = simd_sum(x[i]); + } +} + +template [[kernel]] void layer_norm_single_row( const device T* x, const device T* w, @@ -23,90 +58,71 @@ template uint lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - float sumx = 0; - float sumx2 = 0; - float thread_x[N_READS]; - constexpr int SIMD_SIZE = 32; - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; + // Initialize the registers and threadgroup memory + float thread_x[N_READS] = {0}; + threadgroup float local_buffer[SIMD_SIZE] = {0}; + initialize_buffer(local_buffer, simd_lane_id, simd_group_id); + // Advance the pointers x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; b += b_stride * lid * N_READS; + out += gid * size_t(axis_size) + lid * N_READS; + + // Compute some variables for reading writing etc + const bool safe = lid * N_READS + N_READS <= axis_size; + const int n = axis_size - lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + // Read the inputs + if (safe) { for (int i = 0; i < N_READS; i++) { thread_x[i] = x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumx += thread_x[i]; } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumx += thread_x[i]; - } + for (int i = 0; i < n; i++) { + thread_x[i] = x[i]; } } - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; + // Compute the mean + float mean = 0; + for (int i = 0; i < N_READS; i++) { + mean += thread_x[i]; } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; + + // Compute the normalizer + float normalizer = 0; + if (!safe) { + for (int i = n; i < N_READS; i++) { + thread_x[i] = mean; } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; + for (int i = 0; i < N_READS; i++) { + 
thread_x[i] -= mean; + normalizer += thread_x[i] * thread_x[i]; + } + threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id); + normalizer = metal::precise::rsqrt(normalizer / axis_size + eps); // Write the outputs - out += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + if (safe) { for (int i = 0; i < N_READS; i++) { - thread_x[i] = (thread_x[i] - mean) * normalizer; + thread_x[i] *= normalizer; out[i] = w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = (thread_x[i] - mean) * normalizer; - out[i] = - w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; - } + for (int i = 0; i < n; i++) { + thread_x[i] *= normalizer; + out[i] = w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; } } } -template +template [[kernel]] void layer_norm_looped( const device T* x, const device T* w, @@ -121,71 +137,52 @@ template uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - float sumx = 0; - float sumx2 = 0; - constexpr int SIMD_SIZE = 32; - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; + threadgroup float local_buffer[SIMD_SIZE]; + initialize_buffer(local_buffer, simd_lane_id, simd_group_id); x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; b += b_stride * lid * N_READS; + // Compute the mean + float mean = 0; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { - float xi = x[i + r]; - sumx2 += xi * xi; - sumx += xi; + mean += x[i + r]; } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { - float xi = x[i + r]; - sumx2 += xi * xi; - sumx += xi; + mean += x[i + r]; } } } } + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); + // Compute the normalizer + float normalizer = 0; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float t = x[i + r] - mean; + normalizer += t * t; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float t = x[i + r] - mean; + normalizer += t * t; + } + } } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; + threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id); + normalizer = 
metal::precise::rsqrt(normalizer / axis_size + eps); // Write the outputs out += gid * size_t(axis_size) + lid * N_READS; @@ -208,7 +205,7 @@ template } } -template +template [[kernel]] void vjp_layer_norm_single_row( const device T* x, const device T* w, @@ -222,133 +219,96 @@ template uint lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int SIMD_SIZE = 32; + // Advance the input pointers x += gid * size_t(axis_size) + lid * N_READS; g += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; - // Allocate registers for the computation and accumulators - float thread_x[N_READS]; - float thread_w[N_READS]; - float thread_g[N_READS]; - float sumx = 0; - float sumx2 = 0; - float sumwg = 0; - float sumwgx = 0; + // Initialize the registers and threadgroup memory + float thread_x[N_READS] = {0}; + float thread_w[N_READS] = {0}; + float thread_g[N_READS] = {0}; + threadgroup float local_buffer[3 * SIMD_SIZE]; + initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id); - constexpr int SIMD_SIZE = 32; - - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_sumwg[SIMD_SIZE]; - threadgroup float local_sumwgx[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; - threadgroup float local_meanwg[1]; - threadgroup float local_meanwgx[1]; + // Compute some variables for reading writing etc + const bool safe = lid * N_READS + N_READS <= axis_size; + const int n = axis_size - lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + // Read the inputs + if (safe) { for (int i = 0; i < N_READS; i++) { thread_x[i] = x[i]; - thread_w[i] = w[i * w_stride]; thread_g[i] = g[i]; - float wg = thread_w[i] * thread_g[i]; - sumx += thread_x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumwg += wg; - sumwgx += wg * thread_x[i]; + thread_w[i] = w[i * w_stride]; } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = x[i]; - thread_w[i] = w[i * w_stride]; - thread_g[i] = g[i]; - float wg = thread_w[i] * thread_g[i]; - sumx += thread_x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumwg += wg; - sumwgx += wg * thread_x[i]; - } + for (int i = 0; i < n; i++) { + thread_x[i] = x[i]; + thread_g[i] = g[i]; + thread_w[i] = w[i * w_stride]; } } - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - sumwg = simd_sum(sumwg); - sumwgx = simd_sum(sumwgx); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; - local_sumwg[simd_lane_id] = 0; - local_sumwgx[simd_lane_id] = 0; + // Compute the mean + float mean = 0; + for (int i = 0; i < N_READS; i++) { + mean += thread_x[i]; } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - local_sumwg[simd_group_id] = sumwg; - local_sumwgx[simd_group_id] = sumwgx; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - sumwg = simd_sum(local_sumwg[simd_lane_id]); - sumwgx = simd_sum(local_sumwgx[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / 
axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); - local_meanwg[0] = sumwg / axis_size; - local_meanwgx[0] = sumwgx / axis_size; + // Compute the neccesary scaling factors using the mean + if (!safe) { + for (int i = n; i < N_READS; i++) { + thread_x[i] = mean; } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; - float meanwg = local_meanwg[0]; - float meanwgxc = local_meanwgx[0] - meanwg * mean; - float normalizer2 = normalizer * normalizer; + float factors[3] = {0}; + constexpr int meanwg = 0; + constexpr int meanwgxc = 1; + constexpr int normalizer2 = 2; + for (int i = 0; i < N_READS; i++) { + thread_x[i] -= mean; + factors[meanwg] += thread_w[i] * thread_g[i]; + factors[meanwgxc] += thread_w[i] * thread_g[i] * thread_x[i]; + factors[normalizer2] += thread_x[i] * thread_x[i]; + } + threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id); + factors[meanwg] /= axis_size; + factors[meanwgxc] /= axis_size; + factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps); + float normalizer = metal::precise::sqrt(factors[normalizer2]); // Write the outputs gx += gid * size_t(axis_size) + lid * N_READS; gw += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + if (safe) { for (int i = 0; i < N_READS; i++) { - thread_x[i] = (thread_x[i] - mean) * normalizer; + thread_x[i] *= normalizer; gx[i] = static_cast( - normalizer * (thread_w[i] * thread_g[i] - meanwg) - - thread_x[i] * meanwgxc * normalizer2); + normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) - + thread_x[i] * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i] = static_cast(thread_g[i] * thread_x[i]); } } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = (thread_x[i] - mean) * normalizer; - gx[i] = static_cast( - normalizer * (thread_w[i] * thread_g[i] - meanwg) - - thread_x[i] * meanwgxc * normalizer2); - if (has_w) { - gw[i] = static_cast(thread_g[i] * thread_x[i]); - } + for (int i = 0; i < n; i++) { + thread_x[i] *= normalizer; + gx[i] = static_cast( + normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) - + thread_x[i] * factors[meanwgxc] * factors[normalizer2]); + if (has_w) { + gw[i] = static_cast(thread_g[i] * thread_x[i]); } } } } -template +template [[kernel]] void vjp_layer_norm_looped( const device T* x, const device T* w, @@ -363,102 +323,69 @@ template uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int SIMD_SIZE = 32; + // Advance the input pointers x += gid * size_t(axis_size) + lid * N_READS; g += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; - // Allocate registers for the accumulators - float sumx = 0; - float sumx2 = 0; - float sumwg = 0; - float sumwgx = 0; - - constexpr int SIMD_SIZE = 32; - - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_sumwg[SIMD_SIZE]; - threadgroup float local_sumwgx[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; - threadgroup float local_meanwg[1]; - threadgroup float local_meanwgx[1]; + threadgroup float local_buffer[3 * SIMD_SIZE]; + initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id); + // Compute the mean + float mean = 0; + 
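+  // Two-pass reduction: the first loop accumulates the per-thread partial sum of x and
+  // threadgroup_sum (defined above) reduces it to the row mean; the second loop then
+  // accumulates the centered statistics w*g, w*g*(x - mean) and (x - mean)^2 into
+  // factors[], so the variance comes from centered values rather than E[x^2] - mean^2.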
for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + mean += x[i + r]; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + mean += x[i + r]; + } + } + } + } + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; + + // Compute the neccesary scaling factors using the mean + float factors[3] = {0}; + constexpr int meanwg = 0; + constexpr int meanwgxc = 1; + constexpr int normalizer2 = 2; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { - float xi = x[i + r]; + float t = x[i + r] - mean; float wi = w[(i + r) * w_stride]; float gi = g[i + r]; float wg = wi * gi; - sumx += xi; - sumx2 += xi * xi; - sumwg += wg; - sumwgx += wg * xi; + factors[meanwg] += wg; + factors[meanwgxc] += wg * t; + factors[normalizer2] += t * t; } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { - float xi = x[i + r]; + float t = x[i + r] - mean; float wi = w[(i + r) * w_stride]; float gi = g[i + r]; float wg = wi * gi; - sumx += xi; - sumx2 += xi * xi; - sumwg += wg; - sumwgx += wg * xi; + factors[meanwg] += wg; + factors[meanwgxc] += wg * t; + factors[normalizer2] += t * t; } } } } - - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - sumwg = simd_sum(sumwg); - sumwgx = simd_sum(sumwgx); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; - local_sumwg[simd_lane_id] = 0; - local_sumwgx[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - local_sumwg[simd_group_id] = sumwg; - local_sumwgx[simd_group_id] = sumwgx; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - sumwg = simd_sum(local_sumwg[simd_lane_id]); - sumwgx = simd_sum(local_sumwgx[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); - local_meanwg[0] = sumwg / axis_size; - local_meanwgx[0] = sumwgx / axis_size; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; - float meanwg = local_meanwg[0]; - float meanwgxc = local_meanwgx[0] - meanwg * mean; - float normalizer2 = normalizer * normalizer; + threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id); + factors[meanwg] /= axis_size; + factors[meanwgxc] /= axis_size; + factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps); + float normalizer = metal::precise::sqrt(factors[normalizer2]); // Write the outputs gx += gid * size_t(axis_size) + lid * N_READS; @@ -470,7 +397,8 @@ template float wi = w[(i + r) * w_stride]; float gi = g[i + r]; gx[i + r] = static_cast( - normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); + normalizer * (wi * gi - factors[meanwg]) - + xi * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i + r] = static_cast(gi * xi); } @@ -482,7 +410,8 @@ template float wi = w[(i + r) * w_stride]; float gi = 
g[i + r]; gx[i + r] = static_cast( - normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); + normalizer * (wi * gi - factors[meanwg]) - + xi * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i + r] = static_cast(gi * xi); } diff --git a/Source/Cmlx/mlx-generated/metal/logsumexp.h b/Source/Cmlx/mlx-generated/metal/logsumexp.h index b6898e31..c746050b 100644 --- a/Source/Cmlx/mlx-generated/metal/logsumexp.h +++ b/Source/Cmlx/mlx-generated/metal/logsumexp.h @@ -103,8 +103,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; @@ -134,10 +134,7 @@ template threadgroup_barrier(mem_flags::mem_threadgroup); normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_group_id == 0) { - normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_lane_id == 0) { - out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); - } + if (lid == 0) { + out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); } } diff --git a/Source/Cmlx/mlx-generated/metal/quantized.h b/Source/Cmlx/mlx-generated/metal/quantized.h index b2b0d8d8..bf639814 100644 --- a/Source/Cmlx/mlx-generated/metal/quantized.h +++ b/Source/Cmlx/mlx-generated/metal/quantized.h @@ -14,11 +14,23 @@ using namespace metal; MLX_MTL_CONST int SIMD_SIZE = 32; MLX_MTL_CONST int QUAD_SIZE = 4; +template +inline constexpr short get_pack_factor() { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} + +template +inline constexpr short get_bytes_per_pack() { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 
5 : 3); +} + template inline U load_vector(const device T* x, thread U* x_thread) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; @@ -57,6 +69,21 @@ inline U load_vector(const device T* x, thread U* x_thread) { } } + else if (bits == 5) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + else if (bits == 6) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -80,8 +107,9 @@ inline U load_vector(const device T* x, thread U* x_thread) { template inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; @@ -121,6 +149,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { } } + else if (bits == 5) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + else if (bits == 6) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -153,8 +196,9 @@ inline U qdot( U bias, U sum) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; @@ -199,6 +243,26 @@ inline U qdot( } } + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++) { x_thread += 4 * i; @@ -234,8 +298,9 @@ inline U qdot_safe( U sum, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined 
for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; @@ -280,6 +345,26 @@ inline U qdot_safe( } } + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + else if (bits == 6) { for (int i = 0; i < (N / 4); i++) { x_thread += 4 * i; @@ -310,8 +395,9 @@ template inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; @@ -348,8 +434,31 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); } + } - } else if (bits == 6) { + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[5 * i]; + uint8_t w1 = w[5 * i + 1]; + uint8_t w2 = w[5 * i + 2]; + uint8_t w3 = w[5 * i + 3]; + uint8_t w4 = w[5 * i + 4]; + result[8 * i] += x * ((w0 & 0x1f) * scale + bias); + result[8 * i + 1] += + x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias); + result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias); + result[8 * i + 3] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias); + result[8 * i + 4] += + x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias); + result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias); + result[8 * i + 6] += + x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias); + result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias); + } + } + + else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++) { uint8_t w0 = w[3 * i]; uint8_t w1 = w[3 * i + 1]; @@ -375,8 +484,9 @@ template inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); if (bits == 2) { U s[4] = { @@ -416,11 +526,26 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { } } + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 5 * i; + + w_local[0] = (w[0] & 0x1f) * scale + bias; + w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + w_local[7] 
= ((w[4] & 0xf8) >> 3) * scale + bias; + } + } + else if (bits == 6) { for (int i = 0; i < (N / 4); i++) { w_local += 4 * i; w += 3 * i; - w_local[0] = (w[0] & 0x3f) * scale + bias; w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; @@ -452,11 +577,12 @@ struct QuantizedBlockLoader { group_size % BCOLS == 0, "The group size should be divisible by the columns"); static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; + MLX_MTL_CONST short pack_factor = get_pack_factor(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; @@ -517,14 +643,14 @@ struct QuantizedBlockLoader { return; } - if (reduction_dim == 1 && bi >= src_tile_dim.y) { + if (reduction_dim == 1 && bi >= src_tile_dim.x) { for (int i = 0; i < n_reads * pack_factor; i++) { dst[i] = T(0); } return; } - if (reduction_dim == 0 && bi >= src_tile_dim.x) { + if (reduction_dim == 0 && bi >= src_tile_dim.y) { for (int i = 0; i < n_reads * pack_factor; i++) { dst[i] = T(0); } @@ -632,12 +758,11 @@ METAL_FUNC void qmv_fast_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int packs_per_thread = bits == 2 ? 1 : 2; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; @@ -700,12 +825,12 @@ METAL_FUNC void qmv_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int packs_per_thread = 1; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; @@ -857,8 +982,9 @@ METAL_FUNC void qvm_impl( uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int tn = 32 / pack_factor; constexpr int block_size = SIMD_SIZE; @@ -981,9 +1107,10 @@ METAL_FUNC void qmm_t_impl( constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; // Instantiate the appropriate BlockMMA and Loader using mma_t = mlx::steel:: @@ -1008,11 +1135,11 @@ METAL_FUNC void qmm_t_impl( auto wl = (const device uint8_t*)w; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * K_w; scales += y_col * K_g; biases += y_col * K_g; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row); @@ -1106,11 +1233,11 @@ METAL_FUNC void qmm_n_impl( constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; // Instantiate the appropriate BlockMMA and Loader using mma_t = mlx::steel:: @@ -1132,11 +1259,11 @@ METAL_FUNC void qmm_n_impl( // Set the block const int y_row = tid.y * BM; const int y_col = tid.x * BN; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * bytes_per_pack / pack_factor; scales += y_col / group_size; biases += y_col / group_size; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row); @@ -1307,7 +1434,7 @@ METAL_FUNC void adjust_matrix_offsets( } template -[[kernel]] void qmv_quad( +[[kernel]] void affine_qmv_quad( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1359,7 +1486,7 @@ template } template -[[kernel]] void qmv_fast( +[[kernel]] void affine_qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1411,7 +1538,7 @@ template } template -[[kernel]] void qmv( +[[kernel]] void affine_qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1463,7 +1590,7 @@ template } template -[[kernel]] void qvm( +[[kernel]] void affine_qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1515,7 +1642,7 @@ template } template -[[kernel]] void qvm_split_k( +[[kernel]] void affine_qvm_split_k( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1579,7 +1706,7 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void qmm_t( +[[kernel]] void affine_qmm_t( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1637,7 +1764,7 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void qmm_n( +[[kernel]] void affine_qmm_n( const device uint32_t* w [[buffer(0)]], 
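
Across these quantized kernels the inline bit-width ternaries are replaced by the shared get_pack_factor / get_bytes_per_pack helpers, and a 5-bit path is added that packs 8 values into 5 bytes as a little-endian bit stream (the layout used by the new bits == 5 branches in qdot, qouter and dequantize above). The stand-alone C++ snippet below is an illustrative check of that arithmetic, not part of the patch; it mirrors the helpers as plain functions (the originals are templated on the bit width and word size) and round-trips one 5-bit group through the same masks and shifts.

// Sketch of the packing arithmetic behind get_pack_factor /
// get_bytes_per_pack, plus a round trip through the 5-bit layout
// (8 values per 40 bits, little-endian fields).
#include <cassert>
#include <cstdint>
#include <cstdio>

constexpr int pack_factor(int bits, int wsize = 8) {
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}

constexpr int bytes_per_pack(int bits, int wsize = 8) {
  bool power_of_2 = (bits & (bits - 1)) == 0;
  return power_of_2 ? (wsize / 8) : (bits == 5 ? 5 : 3);
}

int main() {
  // Pack 8 five-bit values into 5 bytes as a little-endian bit stream.
  uint8_t vals[8] = {3, 31, 0, 17, 9, 22, 5, 30};
  uint64_t packed = 0;
  for (int i = 0; i < 8; ++i)
    packed |= uint64_t(vals[i] & 0x1f) << (5 * i);
  uint8_t w[5];
  for (int i = 0; i < 5; ++i)
    w[i] = (packed >> (8 * i)) & 0xff;

  // Unpack with the same masks/shifts as the Metal 5-bit dequantize path.
  uint8_t out[8];
  out[0] = w[0] & 0x1f;
  out[1] = ((w[0] & 0xe0) >> 5) | ((w[1] & 0x03) << 3);
  out[2] = (w[1] & 0x7c) >> 2;
  out[3] = ((w[1] & 0x80) >> 7) | ((w[2] & 0x0f) << 1);
  out[4] = ((w[2] & 0xf0) >> 4) | ((w[3] & 0x01) << 4);
  out[5] = (w[3] & 0x3e) >> 1;
  out[6] = ((w[3] & 0xc0) >> 6) | ((w[4] & 0x07) << 2);
  out[7] = (w[4] & 0xf8) >> 3;
  for (int i = 0; i < 8; ++i)
    assert(out[i] == vals[i]);

  std::printf("bits=5: pack_factor=%d bytes_per_pack=%d\n",
              pack_factor(5), bytes_per_pack(5));
  return 0;
}

For 2-, 4- and 8-bit weights the fields land on byte or word boundaries, so the pack factor is simply wsize / bits and bytes_per_pack is wsize / 8; the 3-, 5- and 6-bit fields straddle bytes, which is why those paths use the explicit mask-and-shift sequences seen in the kernels.
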
const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1690,7 +1817,7 @@ template < } template -[[kernel]] void gather_qmv_fast( +[[kernel]] void affine_gather_qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1752,7 +1879,7 @@ template } template -[[kernel]] void gather_qmv( +[[kernel]] void affine_gather_qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1814,7 +1941,7 @@ template } template -[[kernel]] void gather_qvm( +[[kernel]] void affine_gather_qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1883,7 +2010,7 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void gather_qmm_t( +[[kernel]] void affine_gather_qmm_t( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1950,7 +2077,7 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void gather_qmm_n( +[[kernel]] void affine_gather_qmm_n( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -2011,92 +2138,6 @@ template < w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } -template -METAL_FUNC void gemm_loop_aligned( - threadgroup T* As, - threadgroup T* Bs, - thread mma_t& mma_op, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - const int k_iterations) { - for (int k = 0; k < k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load elements into threadgroup memory - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } -} - -template < - bool rows_aligned, - bool cols_aligned, - bool transpose, - typename T, - typename mma_t, - typename loader_a_t, - typename loader_b_t> -METAL_FUNC void gemm_loop_unaligned( - threadgroup T* As, - threadgroup T* Bs, - thread mma_t& mma_op, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - const int k_iterations, - const short tgp_bm, - const short tgp_bn, - const short tgp_bk) { - for (int k = 0; k < k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load elements into threadgroup memory - if (rows_aligned) { - loader_a.load_unsafe(); - } else { - loader_a.load_safe(short2(tgp_bk, tgp_bm)); - } - if (cols_aligned) { - loader_b.load_unsafe(); - } else { - loader_b.load_safe( - transpose ? 
short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk)); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } -} - -template -METAL_FUNC void gemm_loop_finalize( - threadgroup T* As, - threadgroup T* Bs, - thread mma_t& mma_op, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - const short2 tile_a, - const short2 tile_b) { - loader_a.load_safe(tile_a); - loader_b.load_safe(tile_b); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(As, Bs); -} - template < typename T, int group_size, @@ -2107,7 +2148,7 @@ template < int WM, int WN, bool transpose> -[[kernel]] void gather_qmm_rhs( +[[kernel]] void affine_gather_qmm_rhs( const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(1)]], const device T* scales [[buffer(2)]], @@ -2120,11 +2161,10 @@ template < uint3 tid [[threadgroup_position_in_grid]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) { - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; using mma_t = mlx::steel::BlockMMA< T, @@ -2305,13 +2345,13 @@ template constexpr float eps = 1e-7; constexpr int simd_size = 32; constexpr float n_bins = (1 << bits) - 1; - constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_reduce = group_size / simd_size; - constexpr int writes_per_reduce = packs_per_int / values_per_reduce; + constexpr int writes_per_reduce = pack_factor / values_per_reduce; constexpr int writes_per_pack = - writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int; + writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor; constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; static_assert( group_size % simd_size == 0, @@ -2354,8 +2394,8 @@ template biases[gindex] = static_cast(bias); } - // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t - uint32_t output = 0; + using OutType = metal::conditional_t; + OutType output = 0; #pragma clang loop unroll(full) for (int i = 0; i < values_per_reduce; i++) { @@ -2363,26 +2403,34 @@ template if (bits == 8) { output = val; } else { - output += val << (bits * (i % packs_per_int)); + output |= val << (bits * (i % pack_factor)); } - if (packs_per_int < values_per_reduce && - i % packs_per_int == packs_per_int - 1) { - out[out_index + i / packs_per_int] = output; + if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) { + out[out_index + i / pack_factor] = output; output = 0; } else { #pragma clang loop unroll(full) for (int j = 1; j < writes_per_reduce; j++) { uint8_t sval = simd_shuffle_down(val, j); - output += sval << (bits * (j * values_per_reduce + i)); + output |= static_cast(sval) + << (bits * (j * values_per_reduce + i)); } } } if (bits == 3 || bits == 6) { - if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + } + } else if (bits == 5) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { out[out_index] = output & 0xff; out[out_index + 1] = (output & 0xff00) >> 8; out[out_index + 2] = (output & 0xff0000) >> 16; + out[out_index + 3] = (output & 0xff000000) >> 24; + out[out_index + 4] = (output & 0xff00000000) >> 32; } } else { if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { @@ -2399,12 +2447,11 @@ template device T* out [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); size_t offset = index.x + grid_dim.x * size_t(index.y); - size_t oindex = offset * packs_per_int; + size_t oindex = offset * pack_factor; size_t gindex = oindex / group_size; T scale = scales[gindex]; T bias = biases[gindex]; @@ -2421,7 +2468,16 @@ template out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; out[6] = ((w[2] & 0x1c) >> 2) * scale + bias; out[7] = ((w[2] & 0xe0) >> 5) * scale + bias; - + } else if (bits == 5) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x1f) * scale + bias; + out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + out[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + out[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + out[7] = ((w[4] & 0xf8) >> 3) * scale + bias; } else if (bits == 6) { w += offset * bytes_per_pack; out[0] = (w[0] & 0x3f) * scale + bias; @@ -2431,7 +2487,7 @@ template } else { uint val = w[offset]; #pragma clang loop unroll(full) - for (int i = 0; i < packs_per_int; i++) { + for (int i = 0; i < pack_factor; i++) { uint8_t d; if (bits == 2) { d = (val >> (bits * i)) & 0x03; diff --git a/Source/Cmlx/mlx-generated/metal/quantized_utils.h b/Source/Cmlx/mlx-generated/metal/quantized_utils.h new file mode 100644 index 00000000..38253f8f --- /dev/null +++ b/Source/Cmlx/mlx-generated/metal/quantized_utils.h @@ -0,0 +1,90 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +template +METAL_FUNC void gemm_loop_aligned( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const int k_iterations) { + for (int k = 0; k < k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup memory + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } +} + +template < + bool rows_aligned, + bool cols_aligned, + bool transpose, + typename T, + typename mma_t, + typename loader_a_t, + typename loader_b_t> +METAL_FUNC void gemm_loop_unaligned( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const int k_iterations, + const short tgp_bm, + const short tgp_bn, + const short tgp_bk) { + for (int k = 0; k < k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup memory + if (rows_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(short2(tgp_bk, tgp_bm)); + } + if (cols_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe( + transpose ? 
short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } +} + +template +METAL_FUNC void gemm_loop_finalize( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const short2 tile_a, + const short2 tile_b) { + loader_a.load_safe(tile_a); + loader_b.load_safe(tile_b); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); +} diff --git a/Source/Cmlx/mlx-generated/metal/reduction/ops.h b/Source/Cmlx/mlx-generated/metal/reduction/ops.h index 68ed1198..11d8e83a 100644 --- a/Source/Cmlx/mlx-generated/metal/reduction/ops.h +++ b/Source/Cmlx/mlx-generated/metal/reduction/ops.h @@ -164,7 +164,15 @@ struct Min { DEFINE_SIMD_REDUCE() template - T simd_reduce_impl(T val) { + metal::enable_if_t, T> simd_reduce_impl(T val) { + return simd_min(val); + } + + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + if (simd_any(val != val)) { + return static_cast(NAN); + } return simd_min(val); } @@ -176,17 +184,52 @@ struct Min { } // Operator - U operator()(U a, U b) { + template + metal::enable_if_t, T> operator()(T a, T b) { return a < b ? a : b; } -}; + template + metal::enable_if_t, T> operator()(T a, T b) { + if (metal::isnan(a) || metal::isnan(b)) { + return static_cast(NAN); + } else { + return a < b ? a : b; + } + } + + template <> + complex64_t operator()(complex64_t a, complex64_t b) { + bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); + bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); + + if (!real_is_nan && !imag_is_nan) { + return a < b ? a : b; + } else if (real_is_nan && !imag_is_nan) { + return complex64_t( + static_cast(NAN), a.imag < b.imag ? a.imag : b.imag); + } else if (!real_is_nan && imag_is_nan) { + return complex64_t( + a.real < b.real ? a.real : b.real, static_cast(NAN)); + } else { + return complex64_t(static_cast(NAN), static_cast(NAN)); + } + }; +}; template struct Max { DEFINE_SIMD_REDUCE() template - T simd_reduce_impl(T val) { + metal::enable_if_t, T> simd_reduce_impl(T val) { + return simd_max(val); + } + + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + if (simd_any(val != val)) { + return static_cast(NAN); + } return simd_max(val); } @@ -198,7 +241,35 @@ struct Max { } // Operator - U operator()(U a, U b) { + template + metal::enable_if_t, T> operator()(T a, T b) { return a > b ? a : b; } + + template + metal::enable_if_t, T> operator()(T a, T b) { + if (metal::isnan(a) || metal::isnan(b)) { + return static_cast(NAN); + } else { + return a > b ? a : b; + } + } + + template <> + complex64_t operator()(complex64_t a, complex64_t b) { + bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); + bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); + + if (!real_is_nan && !imag_is_nan) { + return a > b ? a : b; + } else if (real_is_nan && !imag_is_nan) { + return complex64_t( + static_cast(NAN), a.imag > b.imag ? a.imag : b.imag); + } else if (!real_is_nan && imag_is_nan) { + return complex64_t( + a.real > b.real ? 
a.real : b.real, static_cast(NAN)); + } else { + return complex64_t(static_cast(NAN), static_cast(NAN)); + } + } }; diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h b/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h index c8973429..936d75bb 100644 --- a/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h +++ b/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h @@ -224,7 +224,7 @@ template < if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { // Simple loop over non_row_reductions and reduce the row in the thread. - IdxT out_idx = tid.x + tsize.y * IdxT(tid.y); + IdxT out_idx = tid.x + tsize.x * IdxT(tid.y); in += elem_to_loc(out_idx, shape, strides, ndim); for (uint r = 0; r < non_row_reductions; r++) { diff --git a/Source/Cmlx/mlx-generated/metal/rope.metal b/Source/Cmlx/mlx-generated/metal/rope.metal index 106d23f6..e010aee8 100644 --- a/Source/Cmlx/mlx-generated/metal/rope.metal +++ b/Source/Cmlx/mlx-generated/metal/rope.metal @@ -10,7 +10,7 @@ void rope_single_impl( constant const int& offset, const float inv_freq, constant const float& scale, - constant const size_t& stride, + constant const int64_t& stride, uint2 pos, uint2 grid) { float L = scale * static_cast(offset); @@ -52,7 +52,7 @@ template device T* out [[buffer(1)]], constant const int& offset, constant const float& scale, - constant const size_t& stride, + constant const int64_t& stride, constant const float& base [[buffer(10)]], uint2 pos [[thread_position_in_grid]], uint2 grid [[threads_per_grid]]) { @@ -68,9 +68,9 @@ template device T* out [[buffer(1)]], constant const int& offset, constant const float& scale, - constant const size_t& stride, + constant const int64_t& stride, const device float* freqs [[buffer(10)]], - constant const size_t& freq_stride [[buffer(11)]], + constant const int64_t& freq_stride [[buffer(11)]], uint2 pos [[thread_position_in_grid]], uint2 grid [[threads_per_grid]]) { float inv_freq = 1.0 / (freqs[freq_stride * pos.x]); @@ -82,15 +82,21 @@ template void rope_impl( const device T* in, device T* out, - constant const int& offset, + const device int* offset, const float inv_freq, constant const float& scale, - constant const size_t strides[3], - constant const size_t out_strides[3], - constant const size_t& n_batch, + constant const int64_t strides[3], + constant const int64_t out_strides[3], + constant const int64_t& offset_stride, + constant const int& n_head, uint3 pos, uint3 grid) { - float L = scale * static_cast(pos.y + offset); + auto n_head_up = N * ((n_head + N - 1) / N); + auto head_idx = static_cast((pos.z * N) % n_head_up); + auto batch_idx = (pos.z * N) / n_head_up; + auto batch_offset = offset[batch_idx * offset_stride]; + float L = scale * static_cast(pos.y + batch_offset); + auto mat_idx = batch_idx * n_head + head_idx; // Compute costheta, sintheta float theta = L * inv_freq; @@ -102,20 +108,19 @@ void rope_impl( size_t out_index_1, out_index_2; if (traditional) { out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + - N * pos.z * out_strides[0]; + mat_idx * out_strides[0]; out_index_2 = out_index_1 + 1; in_index_1 = - 2 * pos.x * strides[2] + pos.y * strides[1] + N * pos.z * strides[0]; + 2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0]; in_index_2 = in_index_1 + strides[2]; } else { out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + - N * pos.z * out_strides[0]; + mat_idx * out_strides[0]; out_index_2 = out_index_1 + grid.x * out_strides[2]; - in_index_1 = - pos.x * strides[2] + 
pos.y * strides[1] + N * pos.z * strides[0]; + in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0]; in_index_2 = in_index_1 + grid.x * strides[2]; } - for (int i = 0; i < N && pos.z * N + i < n_batch; ++i) { + for (int i = 0; i < N && head_idx + i < n_head; ++i) { // Read and write the output float x1 = static_cast(in[in_index_1]); float x2 = static_cast(in[in_index_2]); @@ -141,11 +146,12 @@ template [[kernel]] void rope( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], - constant const int& offset, + const device int* offset, constant const float& scale, - constant const size_t strides[3], - constant const size_t out_strides[3], - constant const size_t& n_batch, + constant const int64_t strides[3], + constant const int64_t out_strides[3], + constant const int64_t& offset_stride, + constant const int& n_head, constant const float& base [[buffer(10)]], uint3 pos [[thread_position_in_grid]], uint3 grid [[threads_per_grid]]) { @@ -159,7 +165,8 @@ template scale, strides, out_strides, - n_batch, + offset_stride, + n_head, pos, grid); } @@ -168,13 +175,14 @@ template [[kernel]] void rope_freqs( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], - constant const int& offset, + const device int* offset, constant const float& scale, - constant const size_t strides[3], - constant const size_t out_strides[3], - constant const size_t& n_batch, + constant const int64_t strides[3], + constant const int64_t out_strides[3], + constant const int64_t& offset_stride, + constant const int& n_head, const device float* freqs [[buffer(10)]], - constant const size_t& freq_stride [[buffer(11)]], + constant const int64_t& freq_stride [[buffer(11)]], uint3 pos [[thread_position_in_grid]], uint3 grid [[threads_per_grid]]) { float inv_freq = 1.0 / (freqs[freq_stride * pos.x]); @@ -186,61 +194,20 @@ template scale, strides, out_strides, - n_batch, + offset_stride, + n_head, pos, grid); } // clang-format off #define instantiate_rope_g(name, type, traditional, forward) \ - template [[host_name("rope_" #name)]] [[kernel]] void \ - rope( \ - const device type* in [[buffer(0)]], \ - device type* out [[buffer(1)]], \ - constant const int& offset, \ - constant const float& scale, \ - constant const size_t strides[3], \ - constant const size_t out_strides[3], \ - constant const size_t& n_batch, \ - constant const float& base [[buffer(10)]], \ - uint3 pos [[thread_position_in_grid]], \ - uint3 grid [[threads_per_grid]]); \ - template [[host_name("rope_freqs_" #name)]] \ - [[kernel]] void rope_freqs( \ - const device type* in [[buffer(0)]], \ - device type* out [[buffer(1)]], \ - constant const int& offset, \ - constant const float& scale, \ - constant const size_t strides[3], \ - constant const size_t out_strides[3], \ - constant const size_t& n_batch, \ - const device float* freqs [[buffer(10)]], \ - constant const size_t& freq_stride [[buffer(11)]], \ - uint3 pos [[thread_position_in_grid]], \ - uint3 grid [[threads_per_grid]]); + instantiate_kernel("rope_" #name, rope, type, traditional, forward) \ + instantiate_kernel("rope_freqs_" #name, rope_freqs, type, traditional, forward) -#define instantiate_rope_s(name, type, traditional, forward) \ - template [[host_name("rope_single_" #name)]] [[kernel]] void \ - rope_single( \ - const device type* in [[buffer(0)]], \ - device type* out [[buffer(1)]], \ - constant const int& offset, \ - constant const float& scale, \ - constant const size_t& stride, \ - constant const float& base [[buffer(10)]], \ - uint2 pos 
[[thread_position_in_grid]], \ - uint2 grid [[threads_per_grid]]); \ - template [[host_name("rope_single_freqs_" #name)]] \ - [[kernel]] void rope_single_freqs( \ - const device type* in [[buffer(0)]], \ - device type* out [[buffer(1)]], \ - constant const int& offset, \ - constant const float& scale, \ - constant const size_t& stride, \ - const device float* freqs [[buffer(10)]], \ - constant const size_t& freq_stride [[buffer(11)]], \ - uint2 pos [[thread_position_in_grid]], \ - uint2 grid [[threads_per_grid]]); +#define instantiate_rope_s(name, type, traditional, forward) \ + instantiate_kernel("rope_single_" #name, rope_single, type, traditional, forward) \ + instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type, traditional, forward) #define instantiate_rope(name, type, traditional, forward) \ instantiate_rope_s(name, type, traditional, forward) \ diff --git a/Source/Cmlx/mlx-generated/metal/sdpa_vector.h b/Source/Cmlx/mlx-generated/metal/sdpa_vector.h index c4c0f645..b7ded1a6 100644 --- a/Source/Cmlx/mlx-generated/metal/sdpa_vector.h +++ b/Source/Cmlx/mlx-generated/metal/sdpa_vector.h @@ -9,6 +9,7 @@ constant bool query_transposed [[function_constant(21)]]; constant bool do_causal [[function_constant(22)]]; constant bool bool_mask [[function_constant(23)]]; constant bool float_mask [[function_constant(24)]]; +constant bool has_sinks [[function_constant(25)]]; template [[kernel]] void sdpa_vector( @@ -31,6 +32,9 @@ template [[buffer(14), function_constant(has_mask)]], const constant int& mask_head_stride [[buffer(15), function_constant(has_mask)]], + const device T* sinks [[buffer(16), function_constant(has_sinks)]], + const constant int& num_q_heads + [[buffer(17), function_constant(has_sinks)]], uint3 tid [[threadgroup_position_in_grid]], uint3 tpg [[threadgroups_per_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -53,24 +57,24 @@ template threadgroup U sum_exp_scores[BN]; // Adjust positions - const int head_idx = tid.x; + const int q_batch_head_idx = tid.x; const int q_seq_idx = tid.y; - const int kv_head_idx = head_idx / gqa_factor; - const int o_offset = tpg.x * q_seq_idx + head_idx; + const int kv_head_idx = q_batch_head_idx / gqa_factor; + const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx; const int q_offset = - query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx; + query_transposed ? 
tpg.x * q_seq_idx + q_batch_head_idx : o_offset; queries += q_offset * D + simd_lid * qk_per_thread; keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride + simd_lid * qk_per_thread; values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride + simd_lid * v_per_thread; if (bool_mask) { - bmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride + - q_seq_idx * mask_q_seq_stride; + bmask += q_batch_head_idx * mask_head_stride + + simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; } if (float_mask) { - fmask += head_idx * mask_head_stride + simd_gid * mask_kv_seq_stride + - q_seq_idx * mask_q_seq_stride; + fmask += q_batch_head_idx * mask_head_stride + + simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; } out += o_offset * V + simd_gid * v_per_thread; @@ -85,6 +89,10 @@ template U max_score = -INFINITY; U sum_exp_score = 0; + if (has_sinks && simd_gid == 0) { + max_score = static_cast(sinks[q_batch_head_idx % num_q_heads]); + sum_exp_score = 1; + } // For each key for (int i = simd_gid; i < N; i += BN) { @@ -93,6 +101,8 @@ template use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); } else if (bool_mask) { use_key = bmask[0]; + } else if (float_mask) { + use_key = (fmask[0] >= Limits::finite_min); } if (use_key) { // Read the key @@ -107,13 +117,14 @@ template } score = simd_sum(score); if (float_mask) { - score += max(Limits::finite_min, static_cast(fmask[0])); + score += static_cast(fmask[0]); } // Update the accumulators U new_max = max(max_score, score); - U factor = fast::exp(max_score - new_max); - U exp_score = fast::exp(score - new_max); + bool is_neg_inf = new_max == -INFINITY; + U factor = is_neg_inf ? 1.0 : fast::exp(max_score - new_max); + U exp_score = is_neg_inf ? 0.0 : fast::exp(score - new_max); max_score = new_max; sum_exp_score = sum_exp_score * factor + exp_score; @@ -187,6 +198,9 @@ template [[buffer(16), function_constant(has_mask)]], const constant int& mask_head_stride [[buffer(17), function_constant(has_mask)]], + const device T* sinks [[buffer(18), function_constant(has_sinks)]], + const constant int& num_q_heads + [[buffer(19), function_constant(has_sinks)]], uint3 tid [[threadgroup_position_in_grid]], uint3 tpg [[threadgroups_per_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], @@ -211,12 +225,12 @@ template // Adjust positions const int block_idx = tid.z; - const int head_idx = tid.x; + const int q_batch_head_idx = tid.x; const int q_seq_idx = tid.y; - const int o_offset = tpg.x * q_seq_idx + head_idx; + const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx; const int q_offset = - query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx; - const int kv_head_idx = head_idx / gqa_factor; + query_transposed ? 
tpg.x * q_seq_idx + q_batch_head_idx : o_offset; + const int kv_head_idx = q_batch_head_idx / gqa_factor; queries += q_offset * D + simd_lid * qk_per_thread; keys += kv_head_idx * k_head_stride + @@ -225,12 +239,12 @@ template (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread; out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread; if (bool_mask) { - bmask += head_idx * mask_head_stride + + bmask += q_batch_head_idx * mask_head_stride + (block_idx * BN + simd_gid) * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; } if (float_mask) { - fmask += head_idx * mask_head_stride + + fmask += q_batch_head_idx * mask_head_stride + (block_idx * BN + simd_gid) * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; } @@ -245,8 +259,13 @@ template o[i] = 0; } - U max_score = -1e9; + U max_score = -INFINITY; U sum_exp_score = 0; + if (has_sinks && block_idx == 0 && simd_gid == 0) { + int q_head_idx = q_batch_head_idx % num_q_heads; + max_score = static_cast(sinks[q_head_idx]); + sum_exp_score = 1; + } // For each key for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { @@ -255,6 +274,8 @@ template use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); } else if (bool_mask) { use_key = bmask[0]; + } else if (float_mask) { + use_key = (fmask[0] >= Limits::finite_min); } if (use_key) { // Read the key @@ -268,6 +289,10 @@ template score += q[i] * k[i]; } score = simd_sum(score); + if (score < Limits::finite_min) { + continue; + } + if (float_mask) { score += fmask[0]; } @@ -358,8 +383,8 @@ template // Adjust positions const int head_idx = tid.x; const int q_seq_idx = tid.y; - const int n_heads = tpg.x; - const int q_offset = n_heads * q_seq_idx + head_idx; + const int q_offset = head_idx * tpg.y + q_seq_idx; + ; partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread; sums += q_offset * blocks; maxs += q_offset * blocks; diff --git a/Source/Cmlx/mlx-generated/metal/softmax.h b/Source/Cmlx/mlx-generated/metal/softmax.h index b36b73bd..6ea4ac73 100644 --- a/Source/Cmlx/mlx-generated/metal/softmax.h +++ b/Source/Cmlx/mlx-generated/metal/softmax.h @@ -128,8 +128,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; diff --git a/Source/Cmlx/mlx-generated/metal/sort.h b/Source/Cmlx/mlx-generated/metal/sort.h index b067150d..5823e430 100644 --- a/Source/Cmlx/mlx-generated/metal/sort.h +++ b/Source/Cmlx/mlx-generated/metal/sort.h @@ -45,7 +45,9 @@ struct ThreadSort { for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) { if (op(vals[j + 1], vals[j])) { thread_swap(vals[j + 1], vals[j]); - thread_swap(idxs[j + 1], idxs[j]); + if (ARG_SORT) { + thread_swap(idxs[j + 1], idxs[j]); + } } } } @@ -111,7 +113,9 @@ struct BlockMergeSort { bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a)); vals[i] = pred ? b : a; - idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx]; + if (ARG_SORT) { + idxs[i] = pred ? 
Bs_idx[b_idx] : As_idx[a_idx]; + } b_idx += short(pred); a_idx += short(!pred); diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h b/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h index 2e27ea06..7397039b 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h +++ b/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h @@ -11,6 +11,7 @@ constant bool align_K [[function_constant(201)]]; constant bool has_mask [[function_constant(300)]]; constant bool do_causal [[function_constant(301)]]; +constant bool has_sinks [[function_constant(302)]]; template struct TransformScale { @@ -82,6 +83,7 @@ template < const constant AttnParams* params [[buffer(4)]], const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], const device MaskType* mask [[buffer(6), function_constant(has_mask)]], + const device T* sinks [[buffer(7), function_constant(has_sinks)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], @@ -95,7 +97,7 @@ template < Q += tidl.z * params->Q_strides[0] + // Batch tidl.y * params->Q_strides[1] + // Head - tidl.x * BQ * params->Q_strides[2]; // Seqeunce + tidl.x * BQ * params->Q_strides[2]; // Sequence ulong kv_head_idx = int(tid.y) / params->gqa_factor; K += tidl.z * params->K_strides[0] + // Batch @@ -106,7 +108,7 @@ template < O += tidl.z * params->O_strides[0] + // Batch tidl.y * params->O_strides[1] + // Head - tidl.x * BQ * params->O_strides[2]; // Seqeunce + tidl.x * BQ * params->O_strides[2]; // Sequence if (has_mask) { mask += tidl.z * mask_params->M_strides[0] + // Batch @@ -169,7 +171,7 @@ template < VBlockLoader loader_v( V, params->V_strides[2], Vs, simd_group_id, simd_lane_id); - TransformScale ts(static_cast(params->scale * 1.44269504089)); + TransformScale ts(static_cast(params->scale * M_LOG2E_F)); // Prepare MMA tiles constexpr short kFragSize = 8; // MMAFrag size @@ -232,6 +234,14 @@ template < max_score[i] = Limits::finite_min; } + if (has_sinks) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = M_LOG2E_F * static_cast(sinks[tidl.y]); + sum_score[i] = 1; + } + } + int kb_lim = params->NK; if (do_causal) { @@ -350,7 +360,7 @@ template < Stile.frag_at(i, j)[jj] = mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf; } else { - Stile.frag_at(i, j)[jj] += 1.44269504089 * selem_t(mfrag[jj]); + Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]); } } } diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/loader.h b/Source/Cmlx/mlx-generated/metal/steel/attn/loader.h index 75d695e6..3b7c5166 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/attn/loader.h +++ b/Source/Cmlx/mlx-generated/metal/steel/attn/loader.h @@ -113,7 +113,7 @@ struct BlockLoader { tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; } - // Zero out uneeded values + // Zero out unneeded values STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); @@ -240,7 +240,7 @@ struct BlockLoaderT { tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; } - // Zero out uneeded values + // Zero out unneeded values STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { tmp_val[j] = tmp_idx[j] ? 
tmp_val[j] : T(0); diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h b/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h index 9261b871..c92fcf36 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h @@ -2,6 +2,8 @@ #include "../../../steel/conv/loaders/loader_general.h" +constant bool align_C [[function_constant(200)]]; + template < typename T, int BM, @@ -118,30 +120,65 @@ implicit_gemm_conv_2d_general( // Prepare threadgroup mma operation mma_t mma_op(simd_gid, simd_lid); - int gemm_k_iterations = - base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; + if (align_C) { + int gemm_k_iterations = + base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + } - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); + else { + for (int k = 1; k < gemm_params->gemm_k_iterations; k++) { + for (int j = 0; j < base_wh_size * base_ww_size; j++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); - // Prepare for next iteration - loader_a.next(); - loader_b.next(); + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + } + const short remaining_k = params->C % BK; + for (int j = 0; j < base_wh_size * base_ww_size; j++) { + // Load elements into threadgroup + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(remaining_k); + loader_b.load_safe(remaining_k); + threadgroup_barrier(mem_flags::mem_threadgroup); + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } } threadgroup_barrier(mem_flags::mem_none); // Store results to device memory { - // Adjust for simdgroup and thread locatio + // Adjust for simdgroup and thread location int offset_m = c_row + mma_op.sm; int offset_n = c_col + mma_op.sn; C += offset_n; diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h index 85a6d134..22eebe03 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h @@ -381,6 +381,7 @@ struct Conv2DWeightBlockLoader { const constant MLXConvParams<2>* params; int weight_hw; + int weight_step; const int read_n; const bool do_read; @@ -402,6 +403,7 @@ struct Conv2DWeightBlockLoader { src(src_ + bi * src_ld + bj), params(params_), weight_hw(0), + weight_step(params->C / params->groups), read_n(offsets.y + bi), do_read(read_n + n_rows * TROWS <= gemm_params_->N) {} @@ -435,15 
+437,15 @@ struct Conv2DWeightBlockLoader { /* Iteration helper */ METAL_FUNC void next() { if (++weight_hw < (params->wS[1] * params->wS[0])) { - src += params->wt_strides[2]; + src += weight_step; return; } weight_hw = 0; - src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2]; + src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step; } }; } // namespace steel -} // namespace mlx \ No newline at end of file +} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h index 2f12535f..1f37fb21 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h @@ -83,7 +83,7 @@ struct Conv2DInputBlockLoaderSmallChannels { const constant MLXConvParams<2>* params; const constant ImplicitGemmConv2DParams* gemm_params; - short weight_hw; + int weight_hw; const device T* src[n_rows]; @@ -272,7 +272,7 @@ struct Conv2DWeightBlockLoaderSmallChannels { return; } - const device T* curr_src = src + weight_hw * params->wt_strides[2]; + const device T* curr_src = src + weight_hw * (params->C / params->groups); if (BN != 8 || do_read) { STEEL_PRAGMA_UNROLL @@ -316,4 +316,4 @@ struct Conv2DWeightBlockLoaderSmallChannels { }; } // namespace steel -} // namespace mlx \ No newline at end of file +} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h index 3f5be762..9043a3c4 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h @@ -137,6 +137,52 @@ struct Conv2DInputBlockLoaderGeneral { } } + METAL_FUNC void load_safe(const short remaining_k) const { + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + + int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h; + int w_flip = params->flip ? 
params->wS[1] - weight_w - 1 : weight_w; + + int ih_dil = read_ih[i] + h_flip * params->kdil[0]; + int iw_dil = read_iw[i] + w_flip * params->kdil[1]; + + int ih = ih_dil / params->idil[0]; + int iw = iw_dil / params->idil[1]; + + size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2]; + + // Read from input if in bounds + if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) && + (iw_dil >= 0 && iw < params->iS[1])) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } + } else { + for (short j = 0; j < vec_size; ++j) { + if (bj + j < remaining_k) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } else { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + /* Iteration helper */ METAL_FUNC void next() { weight_w += jump_params->f_wgt_jump_w; @@ -262,6 +308,55 @@ struct Conv2DWeightBlockLoaderGeneral { } } + METAL_FUNC void load_safe(const short remaining_k) const { + const device T* curr_src = src + weight_h * params->wt_strides[1] + + weight_w * params->wt_strides[2]; + + if ((start_row + BN <= params->O)) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BN; i += TROWS) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + for (short j = 0; j < vec_size; j++) { + if (bj + j < remaining_k) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } else { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((start_row + i) < params->O) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + for (short j = 0; j < vec_size; j++) { + if (bj + j < remaining_k) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } else { + dst[i * dst_ld + j] = T(0); + } + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + /* Iteration helper */ METAL_FUNC void next() { weight_w += jump_params->f_wgt_jump_w; diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h index add495d9..85830872 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h +++ b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h @@ -33,8 +33,8 @@ template < device T* D [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], - const constant int* batch_shape [[buffer(6)]], - const constant int64_t* batch_strides [[buffer(7)]], + const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], + const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_segmented.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_segmented.h new file mode 100644 index 00000000..b915eb34 --- /dev/null +++ 
b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_segmented.h @@ -0,0 +1,266 @@ +// Copyright © 2025 Apple Inc. + +using namespace mlx::steel; + +constant bool segments_contiguous [[function_constant(199)]]; +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* segments [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Find the block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Move the pointers to the output tile + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + // Move the pointers to the start of the segment + uint32_t k_start, k_end; + if (segments_contiguous) { + k_start = segments[2 * tid.z]; + k_end = segments[2 * tid.z + 1]; + } else { + // We accept either contiguous (above) or weird strides where the beginning + // of the next one is the previous one. Basically the last two strides are + // both 1! + k_start = segments[tid.z]; + k_end = segments[tid.z + 1]; + } + A += transpose_a ? k_start * params->lda : k_start; + B += transpose_b ? k_start : k_start * params->ldb; + C += tid.z * params->batch_stride_d; + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Matrix level alignment so only check K + if (align_M && align_N) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result(C, params->ldd); + } else { + // Tile aligned do the same as above + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result(C, params->ldd); + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_safe( + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_safe( + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Nothing aligned so check both rows and cols + else { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_safe( + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); + loader_b.load_safe( + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/loader.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/loader.h index 1846e26d..cc79de86 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/loader.h +++ b/Source/Cmlx/mlx-generated/metal/steel/gemm/loader.h @@ -113,7 +113,7 @@ struct BlockLoader { tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; } - // Zero out uneeded values + // Zero out unneeded values STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); diff --git a/Source/Cmlx/mlx-generated/metal/ternary.h b/Source/Cmlx/mlx-generated/metal/ternary.h index 4b3adcc8..570f5e4d 100644 --- a/Source/Cmlx/mlx-generated/metal/ternary.h +++ b/Source/Cmlx/mlx-generated/metal/ternary.h @@ -1,25 +1,44 @@ // Copyright © 2024 Apple Inc. 
-template +template ::n> [[kernel]] void ternary_v( device const bool* a, device const T* b, device const T* c, device T* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - d[index] = Op()(a[index], b[index], c[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } + } } -template +template ::n> [[kernel]] void ternary_v2( device const bool* a, device const T* b, device const T* c, device T* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - d[offset] = Op()(a[offset], b[offset], c[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } + } } template diff --git a/Source/Cmlx/mlx-generated/metal/unary.h b/Source/Cmlx/mlx-generated/metal/unary.h index 69828599..649ba7f2 100644 --- a/Source/Cmlx/mlx-generated/metal/unary.h +++ b/Source/Cmlx/mlx-generated/metal/unary.h @@ -1,21 +1,40 @@ // Copyright © 2024 Apple Inc. -template +template ::n> [[kernel]] void unary_v( device const T* in, device U* out, + constant uint& size, uint index [[thread_position_in_grid]]) { - out[index] = Op()(in[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + out[index + i] = Op()(in[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + out[index + i] = Op()(in[index + i]); + } + } } -template +template ::n> [[kernel]] void unary_v2( device const T* in, device U* out, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - out[offset] = Op()(in[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + out[offset + i] = Op()(in[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + out[offset + i] = Op()(in[offset + i]); + } + } } template < diff --git a/Source/Cmlx/mlx-generated/metal/unary_ops.h b/Source/Cmlx/mlx-generated/metal/unary_ops.h index afe37aa1..eaf4fa78 100644 --- a/Source/Cmlx/mlx-generated/metal/unary_ops.h +++ b/Source/Cmlx/mlx-generated/metal/unary_ops.h @@ -5,6 +5,7 @@ #include #include +#include "cexpf.h" #include "erf.h" #include "expm1f.h" @@ -178,8 +179,7 @@ struct Exp { return metal::precise::exp(x); }; complex64_t operator()(complex64_t x) { - auto m = metal::precise::exp(x.real); - return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; + return cexpf(x); } }; diff --git a/Source/Cmlx/mlx-generated/metal/utils.h b/Source/Cmlx/mlx-generated/metal/utils.h index 8fd67b89..28840a5c 100644 --- a/Source/Cmlx/mlx-generated/metal/utils.h +++ b/Source/Cmlx/mlx-generated/metal/utils.h @@ -15,6 +15,14 @@ typedef half float16_t; +// Work per thread values for different types. 
The values here are expected to +// match get_work_per_thread in mlx/backend/metal/utils.h +template +struct WorkPerThread { + static_assert(sizeof(U) <= 8, "Type too large"); + static constexpr int constant n = 8 / sizeof(U); +}; + /////////////////////////////////////////////////////////////////////////////// // Type limits utils /////////////////////////////////////////////////////////////////////////////// diff --git a/Source/Cmlx/mlx-generated/quantized.cpp b/Source/Cmlx/mlx-generated/quantized.cpp index da1a4930..ddad18b2 100644 --- a/Source/Cmlx/mlx-generated/quantized.cpp +++ b/Source/Cmlx/mlx-generated/quantized.cpp @@ -8,11 +8,21 @@ constant bool align_K [[function_constant(202)]]; using namespace metal; static constant constexpr const int SIMD_SIZE = 32; static constant constexpr const int QUAD_SIZE = 4; +template +inline constexpr short get_pack_factor() { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} +template +inline constexpr short get_bytes_per_pack() { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); +} template inline U load_vector(const device T* x, thread U* x_thread) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; if (bits == 2) { for (int i = 0; i < values_per_thread; i += 4) { @@ -46,6 +56,20 @@ inline U load_vector(const device T* x, thread U* x_thread) { x_thread[i + 3] = x[i + 3] / 4096.0f; } } + else if (bits == 5) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } else if (bits == 6) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -66,8 +90,9 @@ inline U load_vector(const device T* x, thread U* x_thread) { template inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; if (bits == 2) { for (int i = 0; i < N; i += 4) { @@ -101,6 +126,20 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { x_thread[i + 3] = x[i + 3] / 4096.0f; } } + else if (bits == 5) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } else if (bits == 6) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -129,8 +168,9 @@ inline U qdot( U bias, U sum) { static_assert( - bits == 2 || bits == 3 || bits == 
4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; if (bits == 2) { for (int i = 0; i < (values_per_thread / 4); i++) { @@ -167,6 +207,24 @@ inline U qdot( x_thread[4 * i + 3] * (ws[i] & 0xf000)); } } + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++) { x_thread += 4 * i; @@ -195,8 +253,9 @@ inline U qdot_safe( U sum, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; if (bits == 2) { for (int i = 0; i < (N / 4); i++) { @@ -233,6 +292,24 @@ inline U qdot_safe( x_thread[4 * i + 3] * (ws[i] & 0xf000)); } } + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } else if (bits == 6) { for (int i = 0; i < (N / 4); i++) { x_thread += 4 * i; @@ -256,8 +333,9 @@ template inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; for (int i = 0; i < (values_per_thread / 4); i++) { @@ -290,7 +368,29 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); } - } else if (bits == 6) { + } + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[5 * i]; + uint8_t w1 = w[5 * i + 1]; + uint8_t w2 = w[5 * i + 2]; + uint8_t w3 = w[5 * i + 3]; + uint8_t w4 = w[5 * i + 4]; + result[8 * i] += x * ((w0 & 0x1f) * scale + bias); + result[8 * i + 1] += + x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias); + result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias); + result[8 * i + 3] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 
0xf) << 1)) * scale + bias); + result[8 * i + 4] += + x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias); + result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias); + result[8 * i + 6] += + x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias); + result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias); + } + } + else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++) { uint8_t w0 = w[3 * i]; uint8_t w1 = w[3 * i + 1]; @@ -313,8 +413,9 @@ template inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); if (bits == 2) { U s[4] = { scale, @@ -349,6 +450,20 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias; } } + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 5 * i; + w_local[0] = (w[0] & 0x1f) * scale + bias; + w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias; + } + } else if (bits == 6) { for (int i = 0; i < (N / 4); i++) { w_local += 4 * i; @@ -382,10 +497,11 @@ struct QuantizedBlockLoader { group_size % BCOLS == 0, "The group size should be divisible by the columns"); static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); - static constant constexpr const short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - static constant constexpr const short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + static constant constexpr const short pack_factor = get_pack_factor(); + static constant constexpr const short bytes_per_pack = get_bytes_per_pack(); static constant constexpr const short BCOLS_PACKED = BCOLS / pack_factor; static constant constexpr const short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; @@ -438,13 +554,13 @@ struct QuantizedBlockLoader { if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { return; } - if (reduction_dim == 1 && bi >= src_tile_dim.y) { + if (reduction_dim == 1 && bi >= src_tile_dim.x) { for (int i = 0; i < n_reads * pack_factor; i++) { dst[i] = T(0); } return; } - if (reduction_dim == 0 && bi >= src_tile_dim.x) { + if (reduction_dim == 0 && bi >= src_tile_dim.y) { for (int i = 0; i < n_reads * pack_factor; i++) { dst[i] = T(0); } @@ -539,12 +655,11 @@ METAL_FUNC void qmv_fast_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int packs_per_thread = bits == 2 ? 
1 : 2; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; @@ -595,12 +710,11 @@ METAL_FUNC void qmv_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int packs_per_thread = 1; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; @@ -727,8 +841,8 @@ METAL_FUNC void qvm_impl( uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int tn = 32 / pack_factor; constexpr int block_size = SIMD_SIZE; using W_T = @@ -833,9 +947,9 @@ METAL_FUNC void qmm_t_impl( (void)lid; constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; using mma_t = mlx::steel:: BlockMMA; using loader_x_t = @@ -854,11 +968,11 @@ METAL_FUNC void qmm_t_impl( const int y_row = tid.y * BM; const int y_col = tid.x * BN; auto wl = (const device uint8_t*)w; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * K_w; scales += y_col * K_g; biases += y_col * K_g; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; const short num_els = min(BM, M - y_row); const short num_outs = min(BN, N - y_col); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); @@ -943,11 +1057,10 @@ METAL_FUNC void qmm_n_impl( (void)lid; constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; using mma_t = mlx::steel:: BlockMMA; using loader_x_t = mlx::steel:: @@ -964,11 +1077,11 @@ METAL_FUNC void qmm_n_impl( auto wl = (const device uint8_t*)w; const int y_row = tid.y * BM; const int y_col = tid.x * BN; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * bytes_per_pack / pack_factor; scales += y_col / group_size; biases += y_col / group_size; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; const short num_els = min(BM, M - y_row); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid); @@ -1129,7 +1242,7 @@ METAL_FUNC void adjust_matrix_offsets( y += tid.z * output_stride; } template -[[kernel]] void qmv_quad( +[[kernel]] void affine_qmv_quad( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1180,7 +1293,7 @@ template quad_lid); } template -[[kernel]] void qmv_fast( +[[kernel]] void affine_qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1231,7 +1344,7 @@ template simd_lid); } template -[[kernel]] void qmv( +[[kernel]] void affine_qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1282,7 +1395,7 @@ template simd_lid); } template -[[kernel]] void qvm( +[[kernel]] void affine_qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1333,7 +1446,7 @@ template simd_lid); } template -[[kernel]] void qvm_split_k( +[[kernel]] void affine_qvm_split_k( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1393,7 +1506,7 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void qmm_t( +[[kernel]] void affine_qmm_t( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1447,7 +1560,7 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void qmm_n( +[[kernel]] void affine_qmm_n( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1495,7 +1608,7 @@ template < w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } template -[[kernel]] void gather_qmv_fast( +[[kernel]] void affine_gather_qmv_fast( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1556,7 +1669,7 @@ template simd_lid); } template -[[kernel]] void gather_qmv( +[[kernel]] void affine_gather_qmv( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1617,7 +1730,7 @@ template simd_lid); } template -[[kernel]] void gather_qvm( +[[kernel]] void affine_gather_qvm( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1685,7 +1798,7 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void gather_qmm_t( +[[kernel]] void affine_gather_qmm_t( const device uint32_t* w [[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1748,7 +1861,7 @@ template < const int BM = 32, const int BK = 32, const int BN = 32> -[[kernel]] void gather_qmm_n( +[[kernel]] void affine_gather_qmm_n( const device uint32_t* w 
[[buffer(0)]], const device T* scales [[buffer(1)]], const device T* biases [[buffer(2)]], @@ -1805,75 +1918,6 @@ template < qmm_n_impl( w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); } -template -METAL_FUNC void gemm_loop_aligned( - threadgroup T* As, - threadgroup T* Bs, - thread mma_t& mma_op, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - const int k_iterations) { - for (int k = 0; k < k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_a.load_unsafe(); - loader_b.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(As, Bs); - loader_a.next(); - loader_b.next(); - } -} -template < - bool rows_aligned, - bool cols_aligned, - bool transpose, - typename T, - typename mma_t, - typename loader_a_t, - typename loader_b_t> -METAL_FUNC void gemm_loop_unaligned( - threadgroup T* As, - threadgroup T* Bs, - thread mma_t& mma_op, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - const int k_iterations, - const short tgp_bm, - const short tgp_bn, - const short tgp_bk) { - for (int k = 0; k < k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - if (rows_aligned) { - loader_a.load_unsafe(); - } else { - loader_a.load_safe(short2(tgp_bk, tgp_bm)); - } - if (cols_aligned) { - loader_b.load_unsafe(); - } else { - loader_b.load_safe( - transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk)); - } - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(As, Bs); - loader_a.next(); - loader_b.next(); - } -} -template -METAL_FUNC void gemm_loop_finalize( - threadgroup T* As, - threadgroup T* Bs, - thread mma_t& mma_op, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - const short2 tile_a, - const short2 tile_b) { - loader_a.load_safe(tile_a); - loader_b.load_safe(tile_b); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(As, Bs); -} template < typename T, int group_size, @@ -1884,7 +1928,7 @@ template < int WM, int WN, bool transpose> -[[kernel]] void gather_qmm_rhs( +[[kernel]] void affine_gather_qmm_rhs( const device T* x [[buffer(0)]], const device uint32_t* w [[buffer(1)]], const device T* scales [[buffer(2)]], @@ -1897,11 +1941,10 @@ template < uint3 tid [[threadgroup_position_in_grid]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) { - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; using mma_t = mlx::steel::BlockMMA< T, T, @@ -2052,13 +2095,13 @@ template constexpr float eps = 1e-7; constexpr int simd_size = 32; constexpr float n_bins = (1 << bits) - 1; - constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_reduce = group_size / simd_size; - constexpr int writes_per_reduce = packs_per_int / values_per_reduce; + constexpr int writes_per_reduce = pack_factor / values_per_reduce; constexpr int writes_per_pack = - writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int; + writes_per_reduce > 1 ? 
1 : values_per_reduce / pack_factor; constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; static_assert( group_size % simd_size == 0, "Group size must be divisible by simd size."); @@ -2092,33 +2135,42 @@ template scales[gindex] = static_cast(scale); biases[gindex] = static_cast(bias); } - uint32_t output = 0; + using OutType = metal::conditional_t; + OutType output = 0; #pragma clang loop unroll(full) for (int i = 0; i < values_per_reduce; i++) { uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins); if (bits == 8) { output = val; } else { - output += val << (bits * (i % packs_per_int)); + output |= val << (bits * (i % pack_factor)); } - if (packs_per_int < values_per_reduce && - i % packs_per_int == packs_per_int - 1) { - out[out_index + i / packs_per_int] = output; + if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) { + out[out_index + i / pack_factor] = output; output = 0; } else { #pragma clang loop unroll(full) for (int j = 1; j < writes_per_reduce; j++) { uint8_t sval = simd_shuffle_down(val, j); - output += sval << (bits * (j * values_per_reduce + i)); + output |= static_cast(sval) + << (bits * (j * values_per_reduce + i)); } } } if (bits == 3 || bits == 6) { - if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { out[out_index] = output & 0xff; out[out_index + 1] = (output & 0xff00) >> 8; out[out_index + 2] = (output & 0xff0000) >> 16; } + } else if (bits == 5) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + out[out_index + 3] = (output & 0xff000000) >> 24; + out[out_index + 4] = (output & 0xff00000000) >> 32; + } } else { if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { out[out_index / writes_per_reduce] = output; @@ -2133,11 +2185,10 @@ template device T* out [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 
1 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); size_t offset = index.x + grid_dim.x * size_t(index.y); - size_t oindex = offset * packs_per_int; + size_t oindex = offset * pack_factor; size_t gindex = oindex / group_size; T scale = scales[gindex]; T bias = biases[gindex]; @@ -2152,6 +2203,16 @@ template out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; out[6] = ((w[2] & 0x1c) >> 2) * scale + bias; out[7] = ((w[2] & 0xe0) >> 5) * scale + bias; + } else if (bits == 5) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x1f) * scale + bias; + out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + out[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + out[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + out[7] = ((w[4] & 0xf8) >> 3) * scale + bias; } else if (bits == 6) { w += offset * bytes_per_pack; out[0] = (w[0] & 0x3f) * scale + bias; @@ -2161,7 +2222,7 @@ template } else { uint val = w[offset]; #pragma clang loop unroll(full) - for (int i = 0; i < packs_per_int; i++) { + for (int i = 0; i < pack_factor; i++) { uint8_t d; if (bits == 2) { d = (val >> (bits * i)) & 0x03; diff --git a/Source/Cmlx/mlx-generated/quantized_utils.cpp b/Source/Cmlx/mlx-generated/quantized_utils.cpp new file mode 100644 index 00000000..0eeafab9 --- /dev/null +++ b/Source/Cmlx/mlx-generated/quantized_utils.cpp @@ -0,0 +1,77 @@ +namespace mlx::core::metal { + +const char* quantized_utils() { + return R"preamble( +template +METAL_FUNC void gemm_loop_aligned( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const int k_iterations) { + for (int k = 0; k < k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_unsafe(); + loader_b.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + loader_a.next(); + loader_b.next(); + } +} +template < + bool rows_aligned, + bool cols_aligned, + bool transpose, + typename T, + typename mma_t, + typename loader_a_t, + typename loader_b_t> +METAL_FUNC void gemm_loop_unaligned( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const int k_iterations, + const short tgp_bm, + const short tgp_bn, + const short tgp_bk) { + for (int k = 0; k < k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (rows_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(short2(tgp_bk, tgp_bm)); + } + if (cols_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe( + transpose ? 
short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk)); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + loader_a.next(); + loader_b.next(); + } +} +template +METAL_FUNC void gemm_loop_finalize( + threadgroup T* As, + threadgroup T* Bs, + thread mma_t& mma_op, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + const short2 tile_a, + const short2 tile_b) { + loader_a.load_safe(tile_a); + loader_b.load_safe(tile_b); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); +} +)preamble"; +} + +} // namespace mlx::core::metal diff --git a/Source/Cmlx/mlx-generated/reduce.cpp b/Source/Cmlx/mlx-generated/reduce.cpp index 6785affb..ac05030e 100644 --- a/Source/Cmlx/mlx-generated/reduce.cpp +++ b/Source/Cmlx/mlx-generated/reduce.cpp @@ -574,7 +574,7 @@ template < int blocks = IdxT(row_size) / N_READS; int extra = IdxT(row_size) % N_READS; if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { - IdxT out_idx = tid.x + tsize.y * IdxT(tid.y); + IdxT out_idx = tid.x + tsize.x * IdxT(tid.y); in += elem_to_loc(out_idx, shape, strides, ndim); for (uint r = 0; r < non_row_reductions; r++) { row = in + loop.location(); diff --git a/Source/Cmlx/mlx-generated/reduce_utils.cpp b/Source/Cmlx/mlx-generated/reduce_utils.cpp index de28f912..542404ce 100644 --- a/Source/Cmlx/mlx-generated/reduce_utils.cpp +++ b/Source/Cmlx/mlx-generated/reduce_utils.cpp @@ -393,7 +393,14 @@ template struct Min { template = true> T simd_reduce(T val) { return simd_reduce_impl(val); } template = true> T simd_reduce(T val) { for (short i = simd_size / 2; i > 0; i /= 2) { val = operator()(val, simd_shuffle_down(val, i)); } return val; } template - T simd_reduce_impl(T val) { + metal::enable_if_t, T> simd_reduce_impl(T val) { + return simd_min(val); + } + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + if (simd_any(val != val)) { + return static_cast(NAN); + } return simd_min(val); } static constexpr constant U init = Limits::max; @@ -401,15 +408,47 @@ struct Min { void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { mlx_atomic_fetch_min_explicit(out, val, offset); } - U operator()(U a, U b) { + template + metal::enable_if_t, T> operator()(T a, T b) { return a < b ? a : b; } + template + metal::enable_if_t, T> operator()(T a, T b) { + if (metal::isnan(a) || metal::isnan(b)) { + return static_cast(NAN); + } else { + return a < b ? a : b; + } + } + template <> + complex64_t operator()(complex64_t a, complex64_t b) { + bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); + bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); + if (!real_is_nan && !imag_is_nan) { + return a < b ? a : b; + } else if (real_is_nan && !imag_is_nan) { + return complex64_t( + static_cast(NAN), a.imag < b.imag ? a.imag : b.imag); + } else if (!real_is_nan && imag_is_nan) { + return complex64_t( + a.real < b.real ? 
a.real : b.real, static_cast(NAN)); + } else { + return complex64_t(static_cast(NAN), static_cast(NAN)); + } + }; }; template struct Max { template = true> T simd_reduce(T val) { return simd_reduce_impl(val); } template = true> T simd_reduce(T val) { for (short i = simd_size / 2; i > 0; i /= 2) { val = operator()(val, simd_shuffle_down(val, i)); } return val; } template - T simd_reduce_impl(T val) { + metal::enable_if_t, T> simd_reduce_impl(T val) { + return simd_max(val); + } + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + if (simd_any(val != val)) { + return static_cast(NAN); + } return simd_max(val); } static constexpr constant U init = Limits::min; @@ -417,9 +456,34 @@ struct Max { void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { mlx_atomic_fetch_max_explicit(out, val, offset); } - U operator()(U a, U b) { + template + metal::enable_if_t, T> operator()(T a, T b) { return a > b ? a : b; } + template + metal::enable_if_t, T> operator()(T a, T b) { + if (metal::isnan(a) || metal::isnan(b)) { + return static_cast(NAN); + } else { + return a > b ? a : b; + } + } + template <> + complex64_t operator()(complex64_t a, complex64_t b) { + bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); + bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); + if (!real_is_nan && !imag_is_nan) { + return a > b ? a : b; + } else if (real_is_nan && !imag_is_nan) { + return complex64_t( + static_cast(NAN), a.imag > b.imag ? a.imag : b.imag); + } else if (!real_is_nan && imag_is_nan) { + return complex64_t( + a.real > b.real ? a.real : b.real, static_cast(NAN)); + } else { + return complex64_t(static_cast(NAN), static_cast(NAN)); + } + } }; )preamble"; } diff --git a/Source/Cmlx/mlx-generated/scan.cpp b/Source/Cmlx/mlx-generated/scan.cpp index 05cc891e..2ad22654 100644 --- a/Source/Cmlx/mlx-generated/scan.cpp +++ b/Source/Cmlx/mlx-generated/scan.cpp @@ -199,6 +199,9 @@ struct Power { template metal::enable_if_t, T> operator()(T base, T exp) { T res = 1; + if (exp < 0) { + return 0; + } while (exp) { if (exp & 1) { res *= base; @@ -210,6 +213,13 @@ struct Power { } template <> complex64_t operator()(complex64_t x, complex64_t y) { + if (x.real == 0 && x.imag == 0) { + if (metal::isnan(y.real) || metal::isnan(y.imag)) { + auto nan = metal::numeric_limits::quiet_NaN(); + return {nan, nan}; + } + return {0.0, 0.0}; + } auto x_theta = metal::atan2(x.imag, x.real); auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); diff --git a/Source/Cmlx/mlx-generated/softmax.cpp b/Source/Cmlx/mlx-generated/softmax.cpp index 8761da62..60f3e2ad 100644 --- a/Source/Cmlx/mlx-generated/softmax.cpp +++ b/Source/Cmlx/mlx-generated/softmax.cpp @@ -112,8 +112,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? 
AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; diff --git a/Source/Cmlx/mlx-generated/sort.cpp b/Source/Cmlx/mlx-generated/sort.cpp index 21bfc87f..7878dc23 100644 --- a/Source/Cmlx/mlx-generated/sort.cpp +++ b/Source/Cmlx/mlx-generated/sort.cpp @@ -33,7 +33,9 @@ struct ThreadSort { for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) { if (op(vals[j + 1], vals[j])) { thread_swap(vals[j + 1], vals[j]); - thread_swap(idxs[j + 1], idxs[j]); + if (ARG_SORT) { + thread_swap(idxs[j + 1], idxs[j]); + } } } } @@ -87,7 +89,9 @@ struct BlockMergeSort { auto b = Bs[b_idx]; bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a)); vals[i] = pred ? b : a; - idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx]; + if (ARG_SORT) { + idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx]; + } b_idx += short(pred); a_idx += short(!pred); } diff --git a/Source/Cmlx/mlx-generated/steel_conv_general.cpp b/Source/Cmlx/mlx-generated/steel_conv_general.cpp index 98e34d93..aa3d00ff 100644 --- a/Source/Cmlx/mlx-generated/steel_conv_general.cpp +++ b/Source/Cmlx/mlx-generated/steel_conv_general.cpp @@ -89,6 +89,42 @@ struct Conv2DInputBlockLoaderGeneral { } } else { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + METAL_FUNC void load_safe(const short remaining_k) const { +#pragma clang loop unroll(full) + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + int n = read_n[i]; + int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h; + int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w; + int ih_dil = read_ih[i] + h_flip * params->kdil[0]; + int iw_dil = read_iw[i] + w_flip * params->kdil[1]; + int ih = ih_dil / params->idil[0]; + int iw = iw_dil / params->idil[1]; + size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2]; + if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) && + (iw_dil >= 0 && iw < params->iS[1])) { + if (bj + vec_size <= remaining_k) { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } + } else { + for (short j = 0; j < vec_size; ++j) { + if (bj + j < remaining_k) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } else { + dst[is * dst_ld + j] = T(0); + } + } + } + } + else { #pragma clang loop unroll(full) for (short j = 0; j < vec_size; ++j) { dst[is * dst_ld + j] = T(0); @@ -184,6 +220,53 @@ struct Conv2DWeightBlockLoaderGeneral { dst[i * dst_ld + j] = curr_src[i * src_ld + j]; } } else { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + METAL_FUNC void load_safe(const short remaining_k) const { + const device T* curr_src = src + weight_h * params->wt_strides[1] + + weight_w * params->wt_strides[2]; + if ((start_row + BN <= params->O)) { +#pragma clang loop unroll(full) + for (short i = 0; i < BN; i += TROWS) { + if (bj + vec_size <= remaining_k) { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + for (short j = 0; j < vec_size; j++) { + if (bj + j < remaining_k) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } else { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((start_row + i) < params->O) { + if (bj + vec_size <= remaining_k) { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } 
else { + for (short j = 0; j < vec_size; j++) { + if (bj + j < remaining_k) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } else { + dst[i * dst_ld + j] = T(0); + } + } + } + } else { #pragma clang loop unroll(full) for (short j = 0; j < vec_size; j++) { dst[i * dst_ld + j] = T(0); @@ -209,6 +292,7 @@ struct Conv2DWeightBlockLoaderGeneral { } } +constant bool align_C [[function_constant(200)]]; template < typename T, int BM, @@ -302,16 +386,41 @@ implicit_gemm_conv_2d_general( simd_gid, simd_lid); mma_t mma_op(simd_gid, simd_lid); - int gemm_k_iterations = - base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_a.load_unsafe(); - loader_b.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(As, Bs); - loader_a.next(); - loader_b.next(); + if (align_C) { + int gemm_k_iterations = + base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_unsafe(); + loader_b.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + loader_a.next(); + loader_b.next(); + } + } + else { + for (int k = 1; k < gemm_params->gemm_k_iterations; k++) { + for (int j = 0; j < base_wh_size * base_ww_size; j++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_unsafe(); + loader_b.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + loader_a.next(); + loader_b.next(); + } + } + const short remaining_k = params->C % BK; + for (int j = 0; j < base_wh_size * base_ww_size; j++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(remaining_k); + loader_b.load_safe(remaining_k); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + loader_a.next(); + loader_b.next(); + } } threadgroup_barrier(mem_flags::mem_none); { diff --git a/Source/Cmlx/mlx-generated/steel_gemm_fused.cpp b/Source/Cmlx/mlx-generated/steel_gemm_fused.cpp index a3176dc2..6a829e9b 100644 --- a/Source/Cmlx/mlx-generated/steel_gemm_fused.cpp +++ b/Source/Cmlx/mlx-generated/steel_gemm_fused.cpp @@ -26,8 +26,8 @@ template < device T* D [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], - const constant int* batch_shape [[buffer(6)]], - const constant int64_t* batch_strides [[buffer(7)]], + const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], + const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], diff --git a/Source/Cmlx/mlx-generated/steel_gemm_segmented.cpp b/Source/Cmlx/mlx-generated/steel_gemm_segmented.cpp new file mode 100644 index 00000000..d6587840 --- /dev/null +++ b/Source/Cmlx/mlx-generated/steel_gemm_segmented.cpp @@ -0,0 +1,207 @@ +namespace mlx::core::metal { + +const char* steel_gemm_segmented() { + return R"preamble( +using namespace mlx::steel; +constant bool segments_contiguous [[function_constant(199)]]; +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> 
+[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* segments [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + uint32_t k_start, k_end; + if (segments_contiguous) { + k_start = segments[2 * tid.z]; + k_end = segments[2 * tid.z + 1]; + } else { + k_start = segments[tid.z]; + k_end = segments[tid.z + 1]; + } + A += transpose_a ? k_start * params->lda : k_start; + B += transpose_b ? k_start : k_start * params->ldb; + C += tid.z * params->batch_stride_d; + thread mma_t mma_op(simd_group_id, simd_lane_id); + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + if (align_M && align_N) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_unsafe(); + loader_b.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result(C, params->ldd); + } else { + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_unsafe(); + loader_b.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result(C, params->ldd); + } + else if (align_N || tgp_bn == BN) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe( + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); + loader_b.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + else if (align_M || tgp_bm == BM) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_unsafe(); + loader_b.load_safe( + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + else { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe( + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); + loader_b.load_safe( + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} +)preamble"; +} + +} // namespace mlx::core::metal diff --git a/Source/Cmlx/mlx-generated/ternary.cpp b/Source/Cmlx/mlx-generated/ternary.cpp index 143ee0d4..7e760273 100644 --- a/Source/Cmlx/mlx-generated/ternary.cpp +++ b/Source/Cmlx/mlx-generated/ternary.cpp @@ -2,25 +2,44 @@ namespace mlx::core::metal { const char* ternary() { return R"preamble( -template +template ::n> [[kernel]] void ternary_v( device const bool* a, device const T* b, device const T* c, device T* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - d[index] = Op()(a[index], b[index], c[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } + } } -template +template ::n> [[kernel]] void ternary_v2( device const bool* a, device const T* b, device const T* c, device T* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - d[offset] = Op()(a[offset], b[offset], c[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } + } } template [[kernel]] void ternary_g_nd1( diff --git a/Source/Cmlx/mlx-generated/unary.cpp b/Source/Cmlx/mlx-generated/unary.cpp index bb5a5867..c55daadd 100644 --- a/Source/Cmlx/mlx-generated/unary.cpp +++ b/Source/Cmlx/mlx-generated/unary.cpp @@ -2,21 +2,40 @@ namespace mlx::core::metal { const char* unary() { return R"preamble( -template +template ::n> [[kernel]] void unary_v( device const T* in, device U* out, + constant uint& size, uint index [[thread_position_in_grid]]) { - out[index] = Op()(in[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + out[index + i] = Op()(in[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + out[index + i] = Op()(in[index + i]); + } + } } -template +template ::n> [[kernel]] void unary_v2( device const T* in, device U* out, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - out[offset] = Op()(in[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + out[offset + i] = Op()(in[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + out[offset + i] = Op()(in[offset + i]); + } + } } template < typename T, diff --git a/Source/Cmlx/mlx-generated/unary_ops.cpp b/Source/Cmlx/mlx-generated/unary_ops.cpp index b63e96b7..20531425 100644 --- a/Source/Cmlx/mlx-generated/unary_ops.cpp +++ b/Source/Cmlx/mlx-generated/unary_ops.cpp @@ -2,6 +2,82 @@ namespace mlx::core::metal { const char* unary_ops() { return R"preamble( 
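+// Complex exponential support: the helpers below follow the classic libm-style
+// cexpf. frexp_expf computes exp(x) with its binary exponent factored out and
+// ldexp_cexpf re-applies that exponent as two power-of-two scales, so that
+// exp(x) * cos(y) and exp(x) * sin(y) can be formed without overflowing float.
+// cexpf special-cases zero, infinite, and NaN components, uses the scaled path
+// only inside the overflow window, and otherwise evaluates exp(x) directly.
+// Exp::operator() for complex64_t (further below) routes through cexpf.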
+using ieee_float_shape_type = union { + float value; + uint32_t word; +}; +inline void get_float_word(thread uint32_t& i, float d) { + ieee_float_shape_type gf_u; + gf_u.value = (d); + (i) = gf_u.word; +} +inline void get_float_word(thread int32_t& i, float d) { + ieee_float_shape_type gf_u; + gf_u.value = (d); + (i) = gf_u.word; +} +inline void set_float_word(thread float& d, uint32_t i) { + ieee_float_shape_type sf_u; + sf_u.word = (i); + (d) = sf_u.value; +} +inline float frexp_expf(float x, thread int* expt) { + const uint32_t k = 235; + const float kln2 = 162.88958740F; + float exp_x; + uint32_t hx; + exp_x = metal::exp(x - kln2); + get_float_word(hx, exp_x); + *expt = (hx >> 23) - (0x7f + 127) + k; + set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23)); + return exp_x; +} +inline complex64_t ldexp_cexpf(complex64_t z, int expt) { + float x, y, exp_x, scale1, scale2; + int ex_expt, half_expt; + x = z.real; + y = z.imag; + exp_x = frexp_expf(x, &ex_expt); + expt += ex_expt; + half_expt = expt / 2; + set_float_word(scale1, (0x7f + half_expt) << 23); + half_expt = expt - half_expt; + set_float_word(scale2, (0x7f + half_expt) << 23); + return complex64_t{ + metal::cos(y) * exp_x * scale1 * scale2, + metal::sin(y) * exp_x * scale1 * scale2}; +} +inline complex64_t cexpf(const thread complex64_t& z) { + float x, y, exp_x; + uint32_t hx, hy; + const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074; + x = z.real; + y = z.imag; + get_float_word(hy, y); + hy &= 0x7fffffff; + if (hy == 0) { + return complex64_t{metal::exp(x), y}; + } + get_float_word(hx, x); + if ((hx & 0x7fffffff) == 0) { + return complex64_t{metal::cos(y), metal::sin(y)}; + } + if (hy >= 0x7f800000) { + if ((hx & 0x7fffffff) != 0x7f800000) { + return complex64_t{y - y, y - y}; + } else if (hx & 0x80000000) { + return complex64_t{0.0, 0.0}; + } else { + return complex64_t{x, y - y}; + } + } + if (hx >= exp_ovfl && hx <= cexp_ovfl) { + return ldexp_cexpf(z, 0); + } else { + exp_x = metal::exp(x); + return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)}; + } +} float erf(float a) { float r, s, t, u; t = metal::abs(a); @@ -247,8 +323,7 @@ struct Exp { return metal::precise::exp(x); }; complex64_t operator()(complex64_t x) { - auto m = metal::precise::exp(x.real); - return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)}; + return cexpf(x); } }; struct Expm1 { diff --git a/Source/Cmlx/mlx-generated/utils.cpp b/Source/Cmlx/mlx-generated/utils.cpp index 73eebac4..e8a66b47 100644 --- a/Source/Cmlx/mlx-generated/utils.cpp +++ b/Source/Cmlx/mlx-generated/utils.cpp @@ -310,6 +310,11 @@ static constant constexpr int RMS_LOOPED_LIMIT = 4096; typedef half float16_t; template +struct WorkPerThread { + static_assert(sizeof(U) <= 8, "Type too large"); + static constexpr int constant n = 8 / sizeof(U); +}; +template struct Limits { static const constant U max = metal::numeric_limits::max(); static const constant U min = metal::numeric_limits::min(); diff --git a/Source/MLX/MLXArray+Bytes.swift b/Source/MLX/MLXArray+Bytes.swift index 4b969333..db41db25 100644 --- a/Source/MLX/MLXArray+Bytes.swift +++ b/Source/MLX/MLXArray+Bytes.swift @@ -285,8 +285,6 @@ extension MLXArray { /// - ``asArray(_:)`` /// - ``asData(access:)`` public func asMTLBuffer(device: any MTLDevice, noCopy: Bool = false) -> (any MTLBuffer)? { - let data = asData(access: noCopy ? 
.noCopyIfContiguous : .copy) - self.eval() if noCopy && self.contiguousToDimension() == 0 { diff --git a/Source/MLX/MLXArray+Indexing.swift b/Source/MLX/MLXArray+Indexing.swift index fe098a38..a56f3b29 100644 --- a/Source/MLX/MLXArray+Indexing.swift +++ b/Source/MLX/MLXArray+Indexing.swift @@ -853,7 +853,7 @@ func updateSlice( var strides = [Int32](repeating: 1, count: ndim) // If it's just a simple slice, just do a slice update and return - if operations.count == 1, case let .slice(slice) = operations[0] { + if operations.count == 1, case .slice(let slice) = operations[0] { let size = src.dim(0).int32 starts[0] = slice.start(size) ends[0] = slice.end(size) diff --git a/Source/MLX/MLXFast.swift b/Source/MLX/MLXFast.swift index 0fac1b19..4cb2e28f 100644 --- a/Source/MLX/MLXFast.swift +++ b/Source/MLX/MLXFast.swift @@ -75,10 +75,12 @@ public enum MLXFast { /// - values: values with shape `[B, N_kv, T_kv, D]` /// - scale: scale for queries, typically `1 / sqrt(q.dim(-1))` /// - mask: mask array + /// - sinks: optional array of attention sinks /// - memoryEfficientThreshold: unused /// - stream: stream to evaluate on public static func scaledDotProductAttention( queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, mask: MLXArray?, + sinks: MLXArray? = nil, memoryEfficientThreshold: Int? = nil, stream: StreamOrDevice = .default ) -> MLXArray { let masks = @@ -95,6 +97,7 @@ public enum MLXFast { &result, queries.ctx, keys.ctx, values.ctx, scale, "", masks, + (sinks ?? .mlxNone).ctx, stream.ctx) return MLXArray(result) } @@ -161,10 +164,13 @@ public enum MLXFast { /// - values: values with shape `[B, N_kv, T_kv, D]` /// - scale: scale for queries, typically `1 / sqrt(q.dim(-1))` /// - mask: a ``ScaledDotProductAttentionMaskMode`` + /// - sinks: optional array of attention sinks /// - stream: stream to evaluate on public static func scaledDotProductAttention( queries: MLXArray, keys: MLXArray, values: MLXArray, scale: Float, - mask: ScaledDotProductAttentionMaskMode, stream: StreamOrDevice = .default + mask: ScaledDotProductAttentionMaskMode, + sinks: MLXArray? = nil, + stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() @@ -180,6 +186,7 @@ public enum MLXFast { &result, queries.ctx, keys.ctx, values.ctx, scale, mask.mode, masks, + (sinks ?? .mlxNone).ctx, stream.ctx) return MLXArray(result) } diff --git a/Source/MLX/Ops.swift b/Source/MLX/Ops.swift index 12e94182..6f7adaff 100644 --- a/Source/MLX/Ops.swift +++ b/Source/MLX/Ops.swift @@ -997,21 +997,76 @@ public func degrees(_ array: MLXArray, stream: StreamOrDevice = .default) -> MLX return MLXArray(result) } +/// Quantization modes for weight compression in neural networks. +/// +/// Quantization reduces the precision of model weights to decrease memory usage and +/// potentially improve inference speed. Different modes use different strategies for +/// mapping full-precision values to lower-precision representations. +public enum QuantizationMode: String, Codable, Sendable { + /// Affine (linear) quantization with scale and bias parameters. + /// + /// This is the standard quantization approach where values are quantized using: + /// ``` + /// quantized_value = round((value - bias) / scale) + /// dequantized_value = quantized_value * scale + bias + /// ``` + /// + /// The `scale` and `bias` parameters are computed per group of elements (typically 64 or 128 elements) + /// to minimize quantization error. 
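+ ///
+ /// A rough sketch of the round trip through the ``quantized(_:groupSize:bits:mode:stream:)``
+ /// and ``dequantized(_:scales:biases:groupSize:bits:mode:stream:)`` functions defined in this file:
+ ///
+ /// ```swift
+ /// // w: a float weight matrix whose last dimension is a multiple of groupSize
+ /// let (wq, scales, biases) = quantized(w, groupSize: 64, bits: 4, mode: .affine)
+ /// let wHat = dequantized(
+ ///     wq, scales: scales, biases: biases, groupSize: 64, bits: 4, mode: .affine)
+ /// // wHat approximates w up to the per-group quantization error
+ /// ```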
This mode provides good compression with reasonable accuracy preservation + /// for most neural network weights. + /// + /// ### See Also + /// - ``dequantized(_:scales:biases:groupSize:bits:mode:stream:)`` + /// - ``quantized(_:groupSize:bits:mode:stream:)`` + /// - ``quantizedMatmul(_:_:scales:biases:transpose:groupSize:bits:mode:stream:)`` + case affine + + /// MX (Microscaling) FP4 quantization format. + /// + /// MXFP4 is a specialized 4-bit floating-point format designed for neural network inference. + /// It uses a shared 8-bit exponent across a block of values, with each element stored as a + /// 4-bit float (1 sign bit, 2 exponent bits, and 1 mantissa bit). This format can provide better + /// accuracy than standard 4-bit integer quantization for certain weight distributions commonly + /// found in transformer models. + /// + /// The format consists of: + /// - A shared 8-bit exponent per block + /// - A 4-bit (E2M1) element per value: 1 sign bit, 2 exponent bits, and 1 mantissa bit + /// + /// ### See Also + /// - ``dequantized(_:scales:biases:groupSize:bits:mode:stream:)`` + /// - ``quantized(_:groupSize:bits:mode:stream:)`` + /// - ``quantizedMatmul(_:_:scales:biases:transpose:groupSize:bits:mode:stream:)`` + case mxfp4 +} + /// Dequantize the matrix `w` using the provided `scales` and /// `biases` and the `group_size` and `bits` configuration. /// /// For details, please see /// [this documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.dequantize.html) /// +/// - Parameters: +/// - w: The quantized weight matrix to dequantize +/// - scales: Scaling factors used during quantization. Should have shape compatible with the quantized groups +/// - biases: Bias values used during quantization. Should have shape compatible with the quantized groups +/// - groupSize: The size of each quantization group. Elements are quantized in groups of this size. Default is 64 +/// - bits: The number of bits used per quantized element. Default is 4 +/// - mode: The quantization mode used. Either `.affine` for standard affine quantization or `.mxfp4` for MXFP4 format. Default is `.affine` +/// - stream: Stream or device to evaluate on +/// /// ### See Also /// - ``quantized(_:groupSize:bits:stream:)`` /// - ``quantizedMatmul(_:_:scales:biases:transpose:groupSize:bits:stream:)`` public func dequantized( - _ w: MLXArray, scales: MLXArray, biases: MLXArray, groupSize: Int = 64, bits: Int = 4, + _ w: MLXArray, scales: MLXArray, biases: MLXArray?, groupSize: Int = 64, bits: Int = 4, + mode: QuantizationMode = .affine, stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() - mlx_dequantize(&result, w.ctx, scales.ctx, biases.ctx, groupSize.int32, bits.int32, stream.ctx) + mlx_dequantize( + &result, w.ctx, scales.ctx, (biases ?? .mlxNone).ctx, groupSize.int32, bits.int32, + mode.rawValue, + stream.ctx) return MLXArray(result) } @@ -1261,22 +1316,39 @@ public func gatherMatmul( /// Note that ``scales`` and ``biases`` must have the same batch dimensions /// as ``w`` since they represent the same quantized matrix. /// +/// - Parameters: +/// - x: The input matrix +/// - w: The quantized weight matrix to be used in the matrix multiplication +/// - scales: The scales to use per `groupSize` elements of `w` +/// - biases: The biases to use per `groupSize` elements of `w` +/// - lhsIndices: Optional indices for gathering from the left-hand side matrix +/// - rhsIndices: Optional indices for gathering from the right-hand side matrix +/// - transpose: Whether to transpose the weight matrix `w`.
Default is `true` +/// - groupSize: The size of the group in `w` that shares a scale and bias. Default is `64` +/// - bits: The number of bits occupied by each element in `w`. Default is `4` +/// - mode: The quantization mode. Default is `.affine` +/// - sortedIndices: Whether the indices are sorted. Default is `false` +/// - stream: Stream or device to evaluate on +/// /// ### See Also /// - /// - ``quantizedMatmul(_:_:scales:biases:transpose:groupSize:bits:stream:)`` public func gatherQuantizedMatmul( - _ x: MLXArray, _ w: MLXArray, scales: MLXArray, biases: MLXArray, + _ x: MLXArray, _ w: MLXArray, scales: MLXArray, biases: MLXArray?, lhsIndices: MLXArray? = nil, rhsIndices: MLXArray? = nil, transpose: Bool = true, groupSize: Int = 64, bits: Int = 4, - sortedIndices: Bool = false, stream: StreamOrDevice = .default + mode: QuantizationMode = .affine, + sortedIndices: Bool = false, + stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_gather_qmm( &result, - x.ctx, w.ctx, scales.ctx, biases.ctx, (lhsIndices ?? .mlxNone).ctx, + x.ctx, w.ctx, scales.ctx, (biases ?? .mlxNone).ctx, (lhsIndices ?? .mlxNone).ctx, (rhsIndices ?? .mlxNone).ctx, transpose, - groupSize.int32, bits.int32, sortedIndices, stream.ctx) + groupSize.int32, bits.int32, mode.rawValue, sortedIndices, + stream.ctx) return MLXArray(result) } @@ -2036,6 +2108,14 @@ public func putAlong( /// /// > `quantized` currently only supports 2D inputs with dimensions which are multiples of 32 /// +/// - Parameters: +/// - w: Matrix to be quantized +/// - groupSize: The size of the group in `w` that shares a scale and bias. Default is `64` +/// - bits: The number of bits occupied by each element of `w` in the returned quantized matrix. Default is `4` +/// - mode: The quantization mode. Default is `.affine` +/// - stream: Stream or device to evaluate on +/// - Returns: A tuple containing the quantized weights (`wq`), scaling factors (`scales`), and bias values (`biases`) +/// /// For details, please see /// [this documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.quantize.html) /// @@ -2043,14 +2123,18 @@ public func putAlong( /// - ``dequantized(_:scales:biases:groupSize:bits:stream:)`` /// - ``quantizedMatmul(_:_:scales:biases:transpose:groupSize:bits:stream:)`` public func quantized( - _ w: MLXArray, groupSize: Int = 64, bits: Int = 4, stream: StreamOrDevice = .default + _ w: MLXArray, groupSize: Int = 64, bits: Int = 4, + mode: QuantizationMode = .affine, + stream: StreamOrDevice = .default ) -> (wq: MLXArray, scales: MLXArray, biases: MLXArray) { - var r1 = mlx_array_new() - var r2 = mlx_array_new() - var r3 = mlx_array_new() - mlx_quantize(&r1, &r2, &r3, w.ctx, groupSize.int32, bits.int32, stream.ctx) + var r = mlx_vector_array_new() + defer { mlx_vector_array_free(r) } + mlx_quantize( + &r, w.ctx, groupSize.int32, bits.int32, mode.rawValue, + stream.ctx) - return (MLXArray(r1), MLXArray(r2), MLXArray(r3)) + let arrays = mlx_vector_array_values(r) + return (arrays[0], arrays[1], arrays[2]) } /// Perform the matrix multiplication with the quantized matrix `w`. The @@ -2058,17 +2142,35 @@ public func quantized( /// elements. Each element in `w` takes `bits` bits and is packed in an /// unsigned 32 bit integer. 
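+///
+/// A minimal sketch of the typical use (input shapes shown as comments):
+///
+/// ```swift
+/// // x: activations of shape [1, 512], w: float weights of shape [1024, 512]
+/// let (wq, scales, biases) = quantized(w, groupSize: 64, bits: 4)
+/// // with transpose == true this computes x @ w.T using the packed 4-bit weights
+/// let y = quantizedMatmul(x, wq, scales: scales, biases: biases, transpose: true)
+/// // y has shape [1, 1024]
+/// ```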
/// +/// - Parameters: +/// - x: Input array +/// - w: Quantized matrix packed in unsigned integers +/// - scales: The scales to use per `groupSize` elements of `w` +/// - biases: The biases to use per `groupSize` elements of `w` +/// - transpose: Defines whether to multiply with the transposed `w` or not, +/// namely whether we are performing `x @ w.T` or `x @ w`. Default is `true` +/// - groupSize: The size of the group in `w` that shares a scale and bias. Default is `64` +/// - bits: The number of bits occupied by each element in `w`. Default is `4` +/// - mode: The quantization mode. Default is `.affine` +/// - stream: Stream or device to evaluate on +/// /// ### See Also /// - ``dequantized(_:scales:biases:groupSize:bits:stream:)`` /// - ``quantized(_:groupSize:bits:stream:)`` public func quantizedMatmul( - _ x: MLXArray, _ w: MLXArray, scales: MLXArray, biases: MLXArray, transpose: Bool = true, - groupSize: Int = 64, bits: Int = 4, stream: StreamOrDevice = .default + _ x: MLXArray, _ w: MLXArray, scales: MLXArray, biases: MLXArray?, + transpose: Bool = true, + groupSize: Int = 64, bits: Int = 4, + mode: QuantizationMode = .affine, + stream: StreamOrDevice = .default ) -> MLXArray { var result = mlx_array_new() mlx_quantized_matmul( &result, - x.ctx, w.ctx, scales.ctx, biases.ctx, transpose, groupSize.int32, bits.int32, stream.ctx + x.ctx, w.ctx, scales.ctx, (biases ?? .mlxNone).ctx, + transpose, groupSize.int32, bits.int32, + mode.rawValue, + stream.ctx ) return MLXArray(result) } diff --git a/Source/MLXNN/Module.swift b/Source/MLXNN/Module.swift index 54ade86d..2f7012a7 100644 --- a/Source/MLXNN/Module.swift +++ b/Source/MLXNN/Module.swift @@ -1582,8 +1582,9 @@ extension UpdateError: LocalizedError { "Unable to collect modules from container: \(path.joined(separator: ".")) in \(modules.joined(separator: "."))" case .mismatchedContainers(let base, let key): return "Mismatched containers: \(base) \(key)" - case let .mismatchedSize( - path, modules, expectedShape: expectedShape, actualShape: actualShape): + case .mismatchedSize( + let + path, let modules, let expectedShape, let actualShape): return "Mismatched parameter \(path.joined(separator: ".")) in \(modules.joined(separator: ".")) shape. Actual \(actualShape), expected \(expectedShape)" case .keyNotFound(let path, let modules): diff --git a/Tests/MLXTests/ExportTests.swift b/Tests/MLXTests/ExportTests.swift index 68626132..babcb56b 100644 --- a/Tests/MLXTests/ExportTests.swift +++ b/Tests/MLXTests/ExportTests.swift @@ -52,7 +52,7 @@ class ExportTests: XCTestCase { [arrays.dropFirst().reduce(arrays[0], +)] } - let x = MLXArray(1) + let x = MLXArray([1]) try exportFunctions(to: url, shapeless: true, f) { export in try export(x) diff --git a/Tests/MLXTests/ModuleTests.swift b/Tests/MLXTests/ModuleTests.swift index 6d3c4962..0059223e 100644 --- a/Tests/MLXTests/ModuleTests.swift +++ b/Tests/MLXTests/ModuleTests.swift @@ -555,7 +555,7 @@ class ModuleTests: XCTestCase { verify: .all) ) { error in guard let error = error as? UpdateError, - case let .keyNotFound(path, modules) = error + case .keyNotFound(let path, let modules) = error else { XCTFail("Expected to fail with UpdateError.keyNotFound, but got: \(error)") return @@ -586,8 +586,10 @@ class ModuleTests: XCTestCase { verify: .all) ) { error in guard let error = error as? 
UpdateError, - case let .mismatchedSize( - path, modules, expectedShape: expectedShape, actualShape: actualShape) = + case .mismatchedSize( + let + path, let modules, expectedShape: let expectedShape, + actualShape: let actualShape) = error else { XCTFail("Expected to fail with UpdateError.mismatchedSize, but got: \(error)") diff --git a/tools/update-mlx.sh b/tools/update-mlx.sh index d2e0a804..a52a191b 100755 --- a/tools/update-mlx.sh +++ b/tools/update-mlx.sh @@ -30,13 +30,16 @@ make \ conv \ copy \ fft \ + fp4_quantized \ gather \ gather_axis \ + gather_front \ gemm \ gemv_masked \ hadamard \ logsumexp \ quantized \ + quantized_utils \ reduce \ reduce_utils \ scan \ @@ -49,6 +52,7 @@ make \ steel_gemm_fused \ steel_gemm_gather \ steel_gemm_masked \ + steel_gemm_segmented \ steel_gemm_splitk \ ternary \ ternary_ops \