Crash triggered by mlx::core::fast::ScaledDotProductAttention::eval_gpu in 0.25.3 on M1 and M2 #244

@alijuma

Description

In 0.25.3, we're getting reports of Dia (our new macOS app) users on M1 and M2 devices crashing in mlx::core::fast::ScaledDotProductAttention::eval_gpu.

With Metal shader validation enabled, this is an assertion failure:

validateComputeFunctionArguments:1056: failed assertion Compute Function(sdpa_vector_float_64_64_floatmask_qt_nc): missing buffer binding at index 11 for mask_head_stride[0].

inside the call to dispatch_threadgroups in sdpa_vector.

With shader validation disabled, this is a later fatal error caught in ErrorHandler.swift:

MLX/ErrorHandler.swift: 332: Fatal error: [metal::Device] Unable to load function sdpa_vector_float_64_64

Function sdpa_vector_float_64_64 was not found in the library at ..... mlx-c/mlx/c/transforms.cpp:73

In 0.25.2, these same M1 and M2 users (with shader validation disabled) were crashing later, in mlx::core::metal::check_error.

For now we've rolled back to 0.23.1, which has fixed the crashes.

I suspect the underlying cause is the change that added float mask to sdpa vector not playing nicely on M1 and M2, and then this change making us crash sooner on 0.25.3.
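
For context on the suspected path, here is a minimal plain-Python sketch of attention for a single query vector with an additive float mask. It assumes the usual definition softmax(scale * q·k + mask)·v and assumes the vector kernel covers the short-query (decode) case. It is only a reference for the computation involved, not MLX's Metal implementation, and the name `sdpa_single_query` is hypothetical.

```python
import math


def sdpa_single_query(q, keys, values, mask=None, scale=None):
    """Reference attention for one query vector (hypothetical helper).

    q:      list of d floats
    keys:   list of n key vectors, each of length d
    values: list of n value vectors
    mask:   optional list of n floats, ADDED to the scores before softmax
            (0.0 keeps a position, float("-inf") removes it)
    """
    d = len(q)
    if scale is None:
        scale = 1.0 / math.sqrt(d)

    # Scaled dot product of the query with every key.
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]

    # A float mask is additive, unlike a boolean mask that selects positions.
    if mask is not None:
        scores = [s + m for s, m in zip(scores, mask)]

    # Numerically stable softmax (assumes at least one unmasked position).
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]

    # Weighted sum of the value vectors.
    dv = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dv)]


# Head dimension 64, matching the "64_64" in the kernel name
# (an assumption about what those numbers mean).
q = [1.0] * 64
keys = [[0.0] * 64, [1.0] * 64]
values = [[1.0] * 64, [2.0] * 64]
masked = sdpa_single_query(q, keys, values, mask=[0.0, float("-inf")])
unmasked = sdpa_single_query(q, keys, values)
```

With the second position masked out, the output equals the first value vector. Without the mask it is a blend of both.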
