Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ jobs:
command: |
pip install pre-commit
brew install swift-format
pre-commit run --all
if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
pre-commit run --all || (echo "Style checks failed, please install pre-commit and run pre-commit run --all and push the change"; echo ""; git --no-pager diff; exit 1)
- run:
name: Run Tests (Xcode, macOS)
command: |
Expand Down
11 changes: 8 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
repos:
- repo: https://github.com/slessans/pre-commit-swift-format
rev: ""

- repo: local
hooks:
- id: swift-format
args: ["--configuration", ".swift-format"]
name: swift-format
language: system
entry: swift-format format --in-place --configuration .swift-format --recursive .
require_serial: true
types: [swift]

- repo: https://github.com/cheshirekow/cmake-format-precommit
rev: v0.6.10
hooks:
Expand Down
34 changes: 33 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,38 @@ let package = Package(
"mlx/tests",

// opt-out of these backends (using metal)
"mlx/mlx/backend/no_metal",
"mlx/mlx/backend/no_gpu",
"mlx/mlx/backend/no_cpu",
"mlx/mlx/backend/metal/no_metal.cpp",

// special handling for cuda -- we need to keep one file:
// mlx/mlx/backend/cuda/no_cuda.cpp
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a little more complicated than I wish, but we can't exclude the directory + include one file, so I need to just list them.


"mlx/mlx/backend/cuda/allocator.cpp",
"mlx/mlx/backend/cuda/compiled.cpp",
"mlx/mlx/backend/cuda/conv.cpp",
"mlx/mlx/backend/cuda/cuda.cpp",
"mlx/mlx/backend/cuda/cudnn_utils.cpp",
"mlx/mlx/backend/cuda/custom_kernel.cpp",
"mlx/mlx/backend/cuda/device.cpp",
"mlx/mlx/backend/cuda/eval.cpp",
"mlx/mlx/backend/cuda/fence.cpp",
"mlx/mlx/backend/cuda/indexing.cpp",
"mlx/mlx/backend/cuda/jit_module.cpp",
"mlx/mlx/backend/cuda/matmul.cpp",
"mlx/mlx/backend/cuda/primitives.cpp",
"mlx/mlx/backend/cuda/slicing.cpp",
"mlx/mlx/backend/cuda/utils.cpp",
"mlx/mlx/backend/cuda/worker.cpp",
"mlx/mlx/backend/cuda/unary",
"mlx/mlx/backend/cuda/gemms",
"mlx/mlx/backend/cuda/steel",
"mlx/mlx/backend/cuda/reduce",
"mlx/mlx/backend/cuda/quantized",
"mlx/mlx/backend/cuda/conv",
"mlx/mlx/backend/cuda/copy",
"mlx/mlx/backend/cuda/device",
"mlx/mlx/backend/cuda/binary",

// build variants (we are opting _out_ of these)
"mlx/mlx/io/no_safetensors.cpp",
Expand All @@ -89,6 +119,8 @@ let package = Package(
// do not build distributed support (yet)
"mlx/mlx/distributed/mpi/mpi.cpp",
"mlx/mlx/distributed/ring/ring.cpp",
"mlx/mlx/distributed/nccl/nccl.cpp",
"mlx/mlx/distributed/nccl/nccl_stub",

// bnns instead of simd (accelerate)
"mlx/mlx/backend/cpu/gemms/simd_fp16.cpp",
Expand Down
80 changes: 64 additions & 16 deletions Source/Cmlx/include/mlx/c/fast.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,68 @@ extern "C" {
* \defgroup fast Fast custom operations
*/
/**@{*/
int mlx_fast_affine_dequantize(
mlx_array* res,
const mlx_array w,
const mlx_array scales,
const mlx_array biases,
int group_size,
int bits,
const mlx_stream s);
int mlx_fast_affine_quantize(
mlx_array* res_0,
mlx_array* res_1,
mlx_array* res_2,
const mlx_array w,
int group_size,
int bits,
const mlx_stream s);

typedef struct mlx_fast_cuda_kernel_config_ {
void* ctx;
} mlx_fast_cuda_kernel_config;
mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new();
void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls);

int mlx_fast_cuda_kernel_config_add_output_arg(
mlx_fast_cuda_kernel_config cls,
const int* shape,
size_t size,
mlx_dtype dtype);
int mlx_fast_cuda_kernel_config_set_grid(
mlx_fast_cuda_kernel_config cls,
int grid1,
int grid2,
int grid3);
int mlx_fast_cuda_kernel_config_set_thread_group(
mlx_fast_cuda_kernel_config cls,
int thread1,
int thread2,
int thread3);
int mlx_fast_cuda_kernel_config_set_init_value(
mlx_fast_cuda_kernel_config cls,
float value);
int mlx_fast_cuda_kernel_config_set_verbose(
mlx_fast_cuda_kernel_config cls,
bool verbose);
int mlx_fast_cuda_kernel_config_add_template_arg_dtype(
mlx_fast_cuda_kernel_config cls,
const char* name,
mlx_dtype dtype);
int mlx_fast_cuda_kernel_config_add_template_arg_int(
mlx_fast_cuda_kernel_config cls,
const char* name,
int value);
int mlx_fast_cuda_kernel_config_add_template_arg_bool(
mlx_fast_cuda_kernel_config cls,
const char* name,
bool value);

typedef struct mlx_fast_cuda_kernel_ {
void* ctx;
} mlx_fast_cuda_kernel;

mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new(
const char* name,
const mlx_vector_string input_names,
const mlx_vector_string output_names,
const char* source,
const char* header,
bool ensure_row_contiguous,
bool atomic_outputs);

void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls);
int mlx_fast_cuda_kernel_apply(
mlx_vector_array* outputs,
mlx_fast_cuda_kernel cls,
const mlx_vector_array inputs,
const mlx_fast_cuda_kernel_config config,
const mlx_stream stream);

int mlx_fast_layer_norm(
mlx_array* res,
const mlx_array x,
Expand Down Expand Up @@ -103,6 +149,7 @@ mlx_fast_metal_kernel mlx_fast_metal_kernel_new(
const char* header,
bool ensure_row_contiguous,
bool atomic_outputs);

void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls);
int mlx_fast_metal_kernel_apply(
mlx_vector_array* outputs,
Expand Down Expand Up @@ -135,6 +182,7 @@ int mlx_fast_scaled_dot_product_attention(
float scale,
const char* mask_mode,
const mlx_vector_array mask_arrs,
const mlx_array sinks /* may be null */,
const mlx_stream s);
/**@}*/

Expand Down
12 changes: 12 additions & 0 deletions Source/Cmlx/include/mlx/c/fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ int mlx_fft_fftn(
const int* axes,
size_t axes_num,
const mlx_stream s);
int mlx_fft_fftshift(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
const mlx_stream s);
int mlx_fft_ifft(
mlx_array* res,
const mlx_array a,
Expand All @@ -71,6 +77,12 @@ int mlx_fft_ifftn(
const int* axes,
size_t axes_num,
const mlx_stream s);
int mlx_fft_ifftshift(
mlx_array* res,
const mlx_array a,
const int* axes,
size_t axes_num,
const mlx_stream s);
int mlx_fft_irfft(
mlx_array* res,
const mlx_array a,
Expand Down
6 changes: 6 additions & 0 deletions Source/Cmlx/include/mlx/c/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,18 @@ int mlx_linalg_cross(
const mlx_array b,
int axis,
const mlx_stream s);
int mlx_linalg_eig(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array a,
const mlx_stream s);
int mlx_linalg_eigh(
mlx_array* res_0,
mlx_array* res_1,
const mlx_array a,
const char* UPLO,
const mlx_stream s);
int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_linalg_eigvalsh(
mlx_array* res,
const mlx_array a,
Expand Down
20 changes: 14 additions & 6 deletions Source/Cmlx/include/mlx/c/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,9 +357,10 @@ int mlx_dequantize(
mlx_array* res,
const mlx_array w,
const mlx_array scales,
const mlx_array biases,
const mlx_array biases /* may be null */,
int group_size,
int bits,
const char* mode,
const mlx_stream s);
int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
int mlx_diagonal(
Expand Down Expand Up @@ -452,12 +453,13 @@ int mlx_gather_qmm(
const mlx_array x,
const mlx_array w,
const mlx_array scales,
const mlx_array biases,
const mlx_array biases /* may be null */,
const mlx_array lhs_indices /* may be null */,
const mlx_array rhs_indices /* may be null */,
bool transpose,
int group_size,
int bits,
const char* mode,
bool sorted_indices,
const mlx_stream s);
int mlx_greater(
Expand Down Expand Up @@ -747,22 +749,22 @@ int mlx_put_along_axis(
int axis,
const mlx_stream s);
int mlx_quantize(
mlx_array* res_0,
mlx_array* res_1,
mlx_array* res_2,
mlx_vector_array* res,
const mlx_array w,
int group_size,
int bits,
const char* mode,
const mlx_stream s);
int mlx_quantized_matmul(
mlx_array* res,
const mlx_array x,
const mlx_array w,
const mlx_array scales,
const mlx_array biases,
const mlx_array biases /* may be null */,
bool transpose,
int group_size,
int bits,
const char* mode,
const mlx_stream s);
int mlx_radians(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_real(mlx_array* res, const mlx_array a, const mlx_stream s);
Expand Down Expand Up @@ -868,6 +870,12 @@ int mlx_scatter_prod(
const int* axes,
size_t axes_num,
const mlx_stream s);
int mlx_segmented_mm(
mlx_array* res,
const mlx_array a,
const mlx_array b,
const mlx_array segments,
const mlx_stream s);
int mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s);
int mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s);
Expand Down
9 changes: 9 additions & 0 deletions Source/Cmlx/include/mlx/c/random.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,15 @@ int mlx_random_multivariate_normal(
mlx_dtype dtype,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_normal_broadcast(
mlx_array* res,
const int* shape,
size_t shape_num,
mlx_dtype dtype,
const mlx_array loc /* may be null */,
const mlx_array scale /* may be null */,
const mlx_array key /* may be null */,
const mlx_stream s);
int mlx_random_normal(
mlx_array* res,
const int* shape,
Expand Down
2 changes: 1 addition & 1 deletion Source/Cmlx/mlx
Submodule mlx updated 471 files
Loading