Skip to content

Commit 018314a

Browse files
justinjfu authored and Google-ML-Automation committed
[Pallas] Fix pallas_call lowering mutating compiler params during Triton lowering.
Addresses: #25714
PiperOrigin-RevId: 712566709
1 parent 18b193c commit 018314a

File tree

2 files changed

+49
-3
lines changed

2 files changed

+49
-3
lines changed

jax/_src/pallas/triton/pallas_call_registration.py

+3 -3
Original file line number | Diff line number | Diff line change
@@ -61,14 +61,14 @@ def pallas_call_lowering(
6161
"scalar prefetch not implemented in the Triton backend"
6262
)
6363
triton_params = compiler_params.get("triton", compiler_params)
64-
num_warps = triton_params.pop("num_warps", 4)
64+
num_warps = triton_params.get("num_warps", 4)
6565
num_warps = 4 if num_warps is None else num_warps
6666
[lowering_platform] = ctx.platforms or ctx.module_context.platforms
6767
if lowering_platform == "rocm":
68-
num_stages = triton_params.pop("num_stages", 1)
68+
num_stages = triton_params.get("num_stages", 1)
6969
num_stages = 1 if num_stages is None else num_stages
7070
else:
71-
num_stages = triton_params.pop("num_stages", 3)
71+
num_stages = triton_params.get("num_stages", 3)
7272
num_stages = 3 if num_stages is None else num_stages
7373

7474
if debug:

tests/pallas/pallas_test.py

+46
Original file line number | Diff line number | Diff line change
@@ -30,19 +30,23 @@
3030
from jax._src import api_util
3131
from jax._src import checkify
3232
from jax._src import config
33+
from jax._src import core as jax_core
3334
from jax._src import dtypes
3435
from jax._src import test_util as jtu
3536
from jax._src.lax.control_flow.for_loop import for_loop
3637
from jax._src.pallas import core as pallas_core
38+
from jax._src.pallas import pallas_call
3739
from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr
3840
from jax.experimental import pallas as pl
3941
import jax.numpy as jnp
4042
import numpy as np
4143

4244
if sys.platform != "win32":
4345
from jax.experimental.pallas import tpu as pltpu
46+
from jax.experimental.pallas import triton as plgpu
4447
else:
4548
pltpu = None
49+
plgpu = None
4650

4751

4852
# TODO(sharadmv): Update signatures of pallas_call to correct inputs/outputs.
@@ -2361,5 +2365,47 @@ class PallasCallNamedGridInterpretTest(PallasCallNamedGridTest):
23612365
INTERPRET = True
23622366

23632367

2368+
def _find_pallas_call_in_jaxpr(
2369+
jaxpr: jax_core.Jaxpr) -> jax_core.JaxprEqn | None:
2370+
for eqn in jaxpr.eqns:
2371+
call_eqn = None
2372+
if eqn.primitive == pallas_call.pallas_call_p:
2373+
call_eqn = eqn
2374+
elif 'jaxpr' in eqn.params:
2375+
call_eqn = _find_pallas_call_in_jaxpr(eqn.params['jaxpr'])
2376+
if call_eqn is not None:
2377+
return call_eqn
2378+
return None
2379+
2380+
2381+
class PallasCompilerParamsTest(PallasBaseTest):
2382+
def test_triton_params_consistent_across_double_jit(self):
2383+
# Test for https://github.com/jax-ml/jax/issues/25714
2384+
if not jtu.test_device_matches(["gpu"]):
2385+
self.skipTest("Triton backend only works on GPU.")
2386+
params = plgpu.TritonCompilerParams(num_warps=8)
2387+
2388+
@jax.jit
2389+
@functools.partial(
2390+
self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32),
2391+
compiler_params=params)
2392+
def copy_kernel(x_ref, o_ref):
2393+
o_ref[...] = x_ref[...]
2394+
2395+
@functools.partial(jax.jit, static_argnames=["z"])
2396+
def plus_z(x, z):
2397+
return copy_kernel(x+z)
2398+
2399+
x = 0.
2400+
extracted_params = _find_pallas_call_in_jaxpr(
2401+
plus_z.trace(x, 1).jaxpr).params["compiler_params"]
2402+
self.assertEqual(plus_z(0., 1.), 1.)
2403+
self.assertEqual(extracted_params["triton"]["num_warps"], 8)
2404+
extracted_params = _find_pallas_call_in_jaxpr(
2405+
plus_z.trace(x, 2).jaxpr).params["compiler_params"]
2406+
self.assertEqual(plus_z(0., 2.), 2.)
2407+
self.assertEqual(extracted_params["triton"]["num_warps"], 8)
2408+
2409+
23642410
if __name__ == "__main__":
23652411
absltest.main()

0 commit comments

Comments
 (0)