In 0.25.3, we're getting reports that users of Dia (our new macOS app) on M1 and M2 devices are crashing in mlx::core::fast::ScaledDotProductAttention::eval::gpu.
With Metal shader validation enabled, the crash shows up as an assertion failure inside the call to dispatch_threadgroups in sdpa_vector:
validateComputeFunctionArguments:1056: failed assertion Compute Function(sdpa_vector_float_64_64_floatmask_qt_nc): missing buffer binding at index 11 for mask_head_stride[0].
With shader validation disabled, the crash happens later, as a fatal error caught in ErrorHandler.swift:
MLX/ErrorHandler.swift: 332: Fatal error: [metal::Device] Unable to load function sdpa_vector_float_64_64
Function sdpa_vector_float_64_64 was not found in the library at ..... mlx-c/mlx/c/transforms.cpp:73
In 0.25.2, the same M1 and M2 users (with shader validation disabled) were crashing even later, in mlx::core::metal::check_error.
For now we've rolled back to 0.23.1, which has fixed the crashes.
I suspect the underlying cause is the change that added float mask support to sdpa_vector, which may not play nicely with M1 and M2, and that this change then made us crash sooner on 0.25.3.