-
Notifications
You must be signed in to change notification settings - Fork 117
Using Swift 6.1 traits to conditionally compile for MLX Cuda backend #259
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Using Swift 6.1 traits to conditionally compile for MLX Cuda backend #259
Conversation
- Fixed indexing issue in reduce.cpp to correctly calculate output index. - Improved simd_reduce implementations in reduce_utils.cpp to handle integral and non-integral types, including complex numbers. - Added handling for zero input in complex power operation in scan.cpp. - Updated softmax.cpp to use finite minimum for values outside axis size. - Enhanced Conv2D input and weight loaders in steel_conv_general.cpp for better safety and performance. - Introduced segmented GEMM kernel in steel_gemm_segmented.cpp for improved matrix multiplication. - Modified unary and ternary operations to support batch processing in unary.cpp and ternary.cpp. - Added utility functions for complex exponential calculations in unary_ops.cpp. - Updated utils.cpp to define WorkPerThread struct for better thread utilization. - Updated update-mlx.sh to include new segmented GEMM kernel in build process.
This looks really good! The CMakeLists.txt might not work -- it references mlx-c, which references a different version of mlx. Let's see what CI says. I will try it out tomorrow as well. |
CI hit the |
Okay I ran |
Source/MLXNN/Module.swift
Outdated
} else { | ||
return value! | ||
} | ||
return value! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I dob't think we want this change -- there is a warning but I think (and I think there is a test) that this does actually do something.
For some reason that swift-format didn't do it, it is still failing in CI. I can try it later today.
Source/MLX/MLXArray+Bytes.swift
Outdated
/// - ``asData(access:)`` | ||
public func asMTLBuffer(device: any MTLDevice, noCopy: Bool = false) -> (any MTLBuffer)? { | ||
let data = asData(access: noCopy ? .noCopyIfContiguous : .copy) | ||
_ = asData(access: noCopy ? .noCopyIfContiguous : .copy) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks like a bug maybe -- perhaps this line should be deleted entirely.
Source/MLXNN/Module.swift
Outdated
} else { | ||
return module! | ||
} | ||
return module! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change should also be removed
Aha, swift-format needs to be run on the new Package.swift! Look at the two items in Module.swift -- I think you were fixing warnings but we need those, despite the warning it does nothing. In fact |
2. Reverted changes in Module.swift lines 1378 and 1487 3. Fixed the bug in MLXArray+Bytes.swift line 288
Great!
|
OK, seeing a failure here:
building this:
So I think the problem may be that the Package.swift (old one) needs to exclude the cuda files from the build -- it picks them up by default. It also makes me think we need a code 16.3 (and maybe linux) builder. The Xcode side of things can be set up like this:
but we want |
…pported CUDA backend files in the older Package.swift
Alright let's give this another try. |
Looks like
The next step after that is:
I found I had to change both Package.swift files:
Now I am hitting this issue: https://developer.apple.com/forums/thread/779744 (though I suspect it is not related to this PR). Per that thread I may try Xcode 16.4. Finally this runs in CI:
Surprisingly the latter build for me -- this picks up the mlx dependency transitively via mlx-c and I thought it might fail since the versions do not match, but it built ok. Anyway, I think these are the current problems:
|
steel_gemm_segmented
which generated a lot of new file changes inmlx-generated