|
30 | 30 | from jax._src import api_util
|
31 | 31 | from jax._src import checkify
|
32 | 32 | from jax._src import config
|
| 33 | +from jax._src import core as jax_core |
33 | 34 | from jax._src import dtypes
|
34 | 35 | from jax._src import test_util as jtu
|
35 | 36 | from jax._src.lax.control_flow.for_loop import for_loop
|
36 | 37 | from jax._src.pallas import core as pallas_core
|
| 38 | +from jax._src.pallas import pallas_call |
37 | 39 | from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
|
38 | 40 | from jax.experimental import pallas as pl
|
39 | 41 | import jax.numpy as jnp
|
40 | 42 | import numpy as np
|
41 | 43 |
|
42 | 44 | if sys.platform != "win32":
|
43 | 45 | from jax.experimental.pallas import tpu as pltpu
|
| 46 | + from jax.experimental.pallas import triton as plgpu |
44 | 47 | else:
|
45 | 48 | pltpu = None
|
| 49 | + plgpu = None |
46 | 50 |
|
47 | 51 |
|
48 | 52 | # TODO(sharadmv): Update signatures of pallas_call to correct inputs/outputs.
|
@@ -2361,5 +2365,47 @@ class PallasCallNamedGridInterpretTest(PallasCallNamedGridTest):
|
2361 | 2365 | INTERPRET = True
|
2362 | 2366 |
|
2363 | 2367 |
|
| 2368 | +def _find_pallas_call_in_jaxpr( |
| 2369 | + jaxpr: jax_core.Jaxpr) -> jax_core.JaxprEqn | None: |
| 2370 | + for eqn in jaxpr.eqns: |
| 2371 | + call_eqn = None |
| 2372 | + if eqn.primitive == pallas_call.pallas_call_p: |
| 2373 | + call_eqn = eqn |
| 2374 | + elif 'jaxpr' in eqn.params: |
| 2375 | + call_eqn = _find_pallas_call_in_jaxpr(eqn.params['jaxpr']) |
| 2376 | + if call_eqn is not None: |
| 2377 | + return call_eqn |
| 2378 | + return None |
| 2379 | + |
| 2380 | + |
| 2381 | +class PallasCompilerParamsTest(PallasBaseTest): |
| 2382 | + def test_triton_params_consistent_across_double_jit(self): |
| 2383 | + # Test for https://github.com/jax-ml/jax/issues/25714 |
| 2384 | + if not jtu.test_device_matches(["gpu"]): |
| 2385 | + self.skipTest("Triton backend only works on GPU.") |
| 2386 | + params = plgpu.TritonCompilerParams(num_warps=8) |
| 2387 | + |
| 2388 | + @jax.jit |
| 2389 | + @functools.partial( |
| 2390 | + self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32), |
| 2391 | + compiler_params=params) |
| 2392 | + def copy_kernel(x_ref, o_ref): |
| 2393 | + o_ref[...] = x_ref[...] |
| 2394 | + |
| 2395 | + @functools.partial(jax.jit, static_argnames=["z"]) |
| 2396 | + def plus_z(x, z): |
| 2397 | + return copy_kernel(x+z) |
| 2398 | + |
| 2399 | + x = 0. |
| 2400 | + extracted_params = _find_pallas_call_in_jaxpr( |
| 2401 | + plus_z.trace(x, 1).jaxpr).params["compiler_params"] |
| 2402 | + self.assertEqual(plus_z(0., 1.), 1.) |
| 2403 | + self.assertEqual(extracted_params["triton"]["num_warps"], 8) |
| 2404 | + extracted_params = _find_pallas_call_in_jaxpr( |
| 2405 | + plus_z.trace(x, 2).jaxpr).params["compiler_params"] |
| 2406 | + self.assertEqual(plus_z(0., 2.), 2.) |
| 2407 | + self.assertEqual(extracted_params["triton"]["num_warps"], 8) |
| 2408 | + |
| 2409 | + |
2364 | 2410 | if __name__ == "__main__":
|
2365 | 2411 | absltest.main()
|
0 commit comments