@hdevalence
Selecting MPS as a backend when CUDA is unavailable does not initially work properly, as the code uses sparse tensors and the SparseMPS backend is not implemented. This commit changes the code so that when MPS is in use, sparse tensors are created on the CPU instead, and then moved back to the same device as the rest of the computation. Since Apple Silicon has unified memory, this "movement" should be an accounting change only and a no-op in terms of actual data movement.
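For readers following along, here is a minimal sketch of the workaround described above. It is not the code from this PR: `sparse_build_device` and `sparse_matmul_on` are hypothetical names. It also assumes that only the dense result of a sparse operation needs to land on the MPS device, which may differ from how the actual change moves tensors.

```python
import torch


def sparse_build_device(target: torch.device) -> torch.device:
    """Hypothetical helper: pick the device on which sparse tensors are built.

    Sparse tensors are built on the CPU when the target is MPS, because the
    SparseMPS backend is mostly unimplemented. Other devices are used as-is.
    """
    return torch.device("cpu") if target.type == "mps" else target


def sparse_matmul_on(indices, values, size, dense, device):
    """Hypothetical helper: sparse @ dense, with the result on `device`.

    Assumption: the sparse work runs on the build device and only the dense
    result is moved to the target device.
    """
    target = torch.device(device)
    build = sparse_build_device(target)
    sparse = torch.sparse_coo_tensor(
        indices.to(build), values.to(build), size
    ).coalesce()
    result = torch.sparse.mm(sparse, dense.to(build))  # dense output
    return result.to(target)
```

Calling `sparse_matmul_on(indices, values, (2, 3), dense, "mps")` on Apple Silicon would then keep the sparse tensor on the CPU and hand back a dense MPS tensor; on CUDA or CPU the helper changes nothing.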

I haven't fully tested this yet, but it provides a significant performance improvement on an M3 MBA, taking the time to execute all cells in the circuit_tracing_tutorial.ipynb from ~19min to <1min.

Note: this stacks on top of #4, cf #1

This allows the code to run (albeit slowly) on MacOS without a CUDA device.

I also experimented with using MPS, and got further than indicated here. As
noted in safety-research#1, there are issues with the SparseMPS backend, which is mostly not
implemented. Sparse tensors are used for e.g. transcoder activations. I think
it might not be terrible to do a workaround where MPS devices are used except
for the sparse computations, as unlike with CUDA, the cost of shuttling tensors
"between" devices on Apple Silicon is minimal (it's all unified memory).
However, I haven't actually got this working yet, so I couldn't test it.

In the meantime, the entire codebase now has consistent device selection
defaults and the notebooks will run on a Mac.
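
As an illustration of what a consistent device default could look like, here is a small sketch. The function name `default_device` and the `allow_mps` flag are hypothetical and are not taken from the codebase; the sketch assumes the default is CUDA when available and CPU otherwise, with MPS as an explicit opt-in.

```python
import torch


def default_device(allow_mps: bool = False) -> torch.device:
    """Hypothetical helper: one place to decide the default device.

    Prefers CUDA, optionally falls back to MPS, and otherwise uses the CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps is missing on old PyTorch builds, hence the getattr.
    mps = getattr(torch.backends, "mps", None)
    if allow_mps and mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```

Routing every module and notebook through a single helper like this is one way to keep the defaults from drifting apart.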
@mntss
Contributor

mntss commented Jun 7, 2025

Thank you for the contribution! Eager to get this working on MPS.

I can't test this myself at the moment. Could you confirm that:

  1. pytest tests/ passes
  2. All notebooks run end-to-end

I'd prefer to handle this as a single change, so closing #4 in favor of this.

@Tingel24
Tingel24 commented Jun 23, 2025

Running test_attributions_gemma.py::test_gemma_2_2b on

      System Version: macOS 15.5 (24F74)
      Kernel Version: Darwin 24.5.0
      Boot Volume: Macintosh HD
      Boot Mode: Normal
      Secure Virtual Memory: Enabled
      System Integrity Protection: Enabled

results in a SIGTRAP for me: full stacktrace
