@@ -27,8 +27,75 @@ defmodule EXLA.BackendTest do
2727 @ skip_mac_arm [ ]
2828 end
2929
30+ if EXLA.Client . default_name ( ) == :mps do
31+ @ skip_mps [
32+ # Missing support for "stablehlo.reduce_window".
33+ # Reported in https://github.com/google/jax/issues/21387
34+ window_max: 3 ,
35+ window_min: 3 ,
36+ window_sum: 3 ,
37+ window_product: 3 ,
38+ window_reduce: 5 ,
39+ window_scatter_min: 5 ,
40+ window_scatter_max: 5 ,
41+ window_mean: 3 ,
42+ # (edge case) Argmax/argmin return wrong value in case of NaN.
43+ # Reported in https://github.com/google/jax/issues/21821
44+ argmin: 2 ,
45+ argmax: 2 ,
46+ # Missing support for general "stablehlo.reduce". Some cases work
47+ # becuase they are special-cased.
48+ # Reported in https://github.com/google/jax/issues/21384
49+ reduce: 4 ,
50+ # Missing support for "stablehlo.popcnt", "stablehlo.count_leading_zeros",
51+ # "stablehlo.cbrt".
52+ # Reported in https://github.com/google/jax/issues/21389
53+ count_leading_zeros: 1 ,
54+ population_count: 1 ,
55+ cbrt: 1 ,
56+ # Matrix multiplication for integers is not supported
57+ dot: 2 ,
58+ dot: 4 ,
59+ dot: 6 ,
60+ covariance: 3 ,
61+ # (edge case) Put slice with overflowing slice, different behaviour.
62+ # Reported in https://github.com/google/jax/issues/21392
63+ put_slice: 3 ,
64+ # (edge case) Slice with overflowing index, different behaviour.
65+ # Reported in https://github.com/google/jax/issues/21393
66+ slice: 4 ,
67+ # (edge case) Top-k wrong behaviour with NaNs.
68+ # Reported in https://github.com/google/jax/issues/21397
69+ top_k: 2 ,
70+ # Missing support for complex numbers.
71+ # Tracked in https://github.com/google/jax/issues/16416
72+ complex: 2 ,
73+ conjugate: 1 ,
74+ conv: 3 ,
75+ fft: 2 ,
76+ fft2: 2 ,
77+ ifft: 2 ,
78+ ifft2: 2 ,
79+ imag: 1 ,
80+ is_infinity: 1 ,
81+ is_nan: 1 ,
82+ phase: 1 ,
83+ real: 1 ,
84+ sigil_MAT: 2 ,
85+ # Missing support for float-64.
86+ # Tracked in https://github.com/google/jax/issues/20938
87+ iota: 2 ,
88+ as_type: 2 ,
89+ atan2: 2 ,
90+ # Missing support for u2/s2
91+ bit_size: 1
92+ ]
93+ else
94+ @ skip_mps [ ]
95+ end
96+
3097 doctest Nx ,
31- except: [ :moduledoc ] ++ @ excluded_doctests ++ @ skip_mac_arm
98+ except: [ :moduledoc ] ++ @ excluded_doctests ++ @ skip_mac_arm ++ @ skip_mps
3299
33100 test "Nx.to_binary/1" do
34101 t = Nx . tensor ( [ 1 , 2 , 3 , 4 ] , backend: EXLA.Backend )
@@ -199,6 +266,8 @@ defmodule EXLA.BackendTest do
199266 end
200267
201268 describe "quantized types" do
269+ # TODO mising support for s2
270+ @ tag :skip
202271 test "s2" do
203272 tensor = Nx . s2 ( - 1 )
204273 assert << - 1 :: 2 - signed - native >> = Nx . to_binary ( tensor )
@@ -237,6 +306,8 @@ defmodule EXLA.BackendTest do
237306 assert 28 = Nx . bit_size ( tensor )
238307 end
239308
309+ # TODO mising support for u2
310+ @ tag :skip
240311 test "u2" do
241312 tensor = Nx . u2 ( 1 )
242313 assert << 1 :: 2 - native >> = Nx . to_binary ( tensor )
0 commit comments