Skip to content

Commit 1926b99

Browse files
chr1sj0nesGoogle-ML-Automation
authored andcommitted
[pallas] Fix spelling of 'fusible'.
PiperOrigin-RevId: 747663692
1 parent 0ed0fb7 commit 1926b99

File tree

10 files changed

+108
-106
lines changed

10 files changed

+108
-106
lines changed

jax/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -721,7 +721,7 @@ pytype_strict_library(
721721
":pallas", # build_cleaner: keep
722722
"//jax/_src/pallas/fuser:block_spec",
723723
"//jax/_src/pallas/fuser:custom_evaluate",
724-
"//jax/_src/pallas/fuser:fusable",
724+
"//jax/_src/pallas/fuser:fusible",
725725
"//jax/_src/pallas/fuser:fusion",
726726
"//jax/_src/pallas/fuser:jaxpr_fusion",
727727
],

jax/_src/pallas/fuser/BUILD

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ pytype_strict_library(
3333
deps = [
3434
":block_spec",
3535
":custom_evaluate",
36-
":fusable",
36+
":fusible",
3737
":fusion",
3838
":jaxpr_fusion",
3939
],
@@ -58,9 +58,9 @@ pytype_strict_library(
5858
)
5959

6060
pytype_strict_library(
61-
name = "fusable",
61+
name = "fusible",
6262
srcs = [
63-
"fusable.py",
63+
"fusible.py",
6464
],
6565
deps = [
6666
":fusion",
@@ -91,8 +91,8 @@ pytype_strict_library(
9191
"jaxpr_fusion.py",
9292
],
9393
deps = [
94-
":fusable",
95-
":fusable_dtype",
94+
":fusible",
95+
":fusible_dtype",
9696
":fusion",
9797
"//jax",
9898
"//jax:api_util",
@@ -104,13 +104,13 @@ pytype_strict_library(
104104
)
105105

106106
pytype_strict_library(
107-
name = "fusable_dtype",
107+
name = "fusible_dtype",
108108
srcs = [
109-
"fusable_dtype.py",
109+
"fusible_dtype.py",
110110
],
111111
deps = [
112112
":block_spec",
113-
":fusable",
113+
":fusible",
114114
"//jax",
115115
"//jax:api_util",
116116
"//jax:core",

jax/_src/pallas/fuser/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@
1717
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
1818
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
1919
from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate
20-
from jax._src.pallas.fuser.fusable import fusable as fusable
20+
from jax._src.pallas.fuser.fusible import fusible as fusible
2121
from jax._src.pallas.fuser.fusion import Fusion as Fusion
2222
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse

jax/_src/pallas/fuser/fusable.py renamed to jax/_src/pallas/fuser/fusible.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Fusable primitive."""
15+
"""Fusible primitive."""
1616
from typing import Any
1717

1818
import jax
@@ -25,8 +25,8 @@
2525
from jax._src.interpreters import partial_eval as pe
2626
from jax._src.pallas.fuser import fusion as fusion_lib
2727

28-
fusable_p = jax_core.Primitive('fusable')
29-
fusable_p.multiple_results = True
28+
fusible_p = jax_core.Primitive('fusible')
29+
fusible_p.multiple_results = True
3030

3131

3232
def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion:
@@ -37,22 +37,22 @@ def _make_trivial_fusion(x: jax.Array) -> fusion_lib.Fusion:
3737
)
3838

3939

40-
def fusable(f=None, *, output_fusion_prefix: Any = True):
40+
def fusible(f=None, *, output_fusion_prefix: Any = True):
4141
def decorator(f):
4242
def wrapper(*args):
4343
def wrapped(*args):
4444
in_fusions = tree_util.tree_map(_make_trivial_fusion, args)
4545
return f(*in_fusions, None)
4646

4747
flat_args, in_tree = tree_util.tree_flatten(args)
48-
debug_info = api_util.debug_info('fusable', wrapped, args, {})
48+
debug_info = api_util.debug_info('fusible', wrapped, args, {})
4949
flat_fun, out_tree_thunk = api_util.flatten_fun_nokwargs(
5050
lu.wrap_init(wrapped, debug_info=debug_info), in_tree
5151
)
5252
flat_avals = [jax_core.get_aval(x) for x in flat_args]
5353
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
5454
out_tree = out_tree_thunk()
55-
out = fusable_p.bind(
55+
out = fusible_p.bind(
5656
*consts,
5757
*flat_args,
5858
jaxpr=jaxpr,
@@ -71,16 +71,16 @@ def wrapped(*args):
7171
return decorator
7272

7373

74-
@fusable_p.def_impl
74+
@fusible_p.def_impl
7575
def _(*consts_and_args, jaxpr, num_consts, **_):
7676
consts, args = util.split_list(consts_and_args, [num_consts])
7777
return jax_core.eval_jaxpr(jaxpr, consts, *args)
7878

7979

80-
mlir.register_lowering(fusable_p, mlir.lower_fun(fusable_p.impl))
80+
mlir.register_lowering(fusible_p, mlir.lower_fun(fusible_p.impl))
8181

8282

83-
@fusable_p.def_abstract_eval
83+
@fusible_p.def_abstract_eval
8484
def _(*args, jaxpr, **kwargs):
8585
del args, kwargs
8686
return [v.aval for v in jaxpr.outvars]

jax/_src/pallas/fuser/fusable_dtype.py renamed to jax/_src/pallas/fuser/fusible_dtype.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Custom fusable dtypes."""
15+
"""Custom fusible dtypes."""
1616

1717
import abc
1818
import dataclasses
@@ -34,7 +34,7 @@
3434
from jax._src.pallas import pallas_call
3535
from jax._src.pallas import primitives as pallas_primitives
3636
from jax._src.pallas.fuser import block_spec
37-
from jax._src.pallas.fuser.fusable import fusable_p
37+
from jax._src.pallas.fuser.fusible import fusible_p
3838
from jax._src.state import discharge as state_discharge
3939
from jax._src.state import primitives as state_primitives
4040
from jax._src.util import foreach
@@ -54,7 +54,7 @@
5454

5555
@pack_dtype_p.def_abstract_eval
5656
def pack_dtype_abstract_eval(*xs, dtype):
57-
if dtypes.issubdtype(dtype, FusableElementDType):
57+
if dtypes.issubdtype(dtype, fusibleElementDType):
5858
return dtype.abstract_pack(*xs)
5959
raise ValueError("Attempted to pack non-fusion dtype: {dtype}")
6060

@@ -69,7 +69,7 @@ def pack(*xs, dtype):
6969

7070
@unpack_dtype_p.def_abstract_eval
7171
def unpack_dtype_abstract_eval(x):
72-
if dtypes.issubdtype(x.dtype, FusableElementDType):
72+
if dtypes.issubdtype(x.dtype, fusibleElementDType):
7373
return x.dtype.abstract_unpack(x)
7474
elif isinstance(x.dtype, pallas_core.AbstractMemoryRef):
7575
raise NotImplementedError()
@@ -80,20 +80,20 @@ def unpack(x):
8080
return unpack_dtype_p.bind(x)
8181

8282

83-
class FusableElementDType(dtypes.extended):
84-
"""Scalar dtype for fusable dtypes."""
83+
class fusibleElementDType(dtypes.extended):
84+
"""Scalar dtype for fusible dtypes."""
8585

8686

87-
class FusableTyRules:
87+
class fusibleTyRules:
8888
allow_conversion: bool = False
8989

9090

9191
class FusionDType(dtypes.ExtendedDType, metaclass=abc.ABCMeta):
92-
"""Base class for fusable extended dtypes."""
92+
"""Base class for fusible extended dtypes."""
9393

9494
_op_registry = {}
95-
_rules = FusableTyRules
96-
type = FusableElementDType
95+
_rules = fusibleTyRules
96+
type = fusibleElementDType
9797

9898
@abc.abstractmethod
9999
def abstract_unpack(self, x) -> Sequence[Any]:
@@ -124,7 +124,7 @@ def pull_block_spec_one_step(self, *args, **kwargs):
124124

125125

126126
def physicalize(f):
127-
"""Runs a function that contains fusable extended dtypes."""
127+
"""Runs a function that contains fusible extended dtypes."""
128128

129129
def wrapper(*args, **kwargs):
130130
if kwargs:
@@ -203,7 +203,7 @@ class Context:
203203
def physicalize_interp(
204204
jaxpr: core.Jaxpr, consts: Sequence[core.Value], *args: core.Value
205205
):
206-
"""Physicalizes a jaxpr by replacing fusable dtypes with physical types."""
206+
"""Physicalizes a jaxpr by replacing fusible dtypes with physical types."""
207207
# TODO: Merge into JAX core.
208208
env: dict[core.Var, Any] = {}
209209

@@ -446,12 +446,12 @@ def _pack_dtype_pull_rule(
446446
return dtype.pull_block_spec_one_step(block_spec) # pytype: disable=attribute-error
447447

448448

449-
def _fusable_physicalize_rule(
449+
def _fusible_physicalize_rule(
450450
_, *consts_and_args, jaxpr, num_consts, in_tree, out_tree, func
451451
):
452452
consts, _ = util.split_list(consts_and_args, [num_consts])
453453
new_jaxpr = physicalize_closed_jaxpr(core.ClosedJaxpr(jaxpr, consts))
454-
return fusable_p.bind(
454+
return fusible_p.bind(
455455
*consts_and_args,
456456
jaxpr=new_jaxpr.jaxpr,
457457
num_consts=num_consts,
@@ -461,4 +461,4 @@ def _fusable_physicalize_rule(
461461
)
462462

463463

464-
_physicalize_rules[fusable_p] = _fusable_physicalize_rule
464+
_physicalize_rules[fusible_p] = _fusible_physicalize_rule

jax/_src/pallas/fuser/jaxpr_fusion.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,22 @@
2323
from jax._src import linear_util as lu
2424
from jax._src import tree_util
2525
from jax._src.interpreters import partial_eval as pe
26-
from jax._src.pallas.fuser import fusable_dtype
26+
from jax._src.pallas.fuser import fusible_dtype
2727
from jax._src.pallas.fuser import fusion as fusion_lib
28-
from jax._src.pallas.fuser.fusable import fusable_p
28+
from jax._src.pallas.fuser.fusible import fusible_p
2929

3030

3131
def fuse(f=None, *, physicalize: bool = False, debug: bool = False):
32-
"""Fuses a function into a single fusable.
32+
"""Fuses a function into a single fusible.
3333
3434
Args:
3535
f: The function to fuse.
3636
physicalize: (experimental) whether to physicalize the function.
3737
debug: Whether to print debug information.
3838
39-
There should be a single call to a `fusable` inside the body of `f`. `fuse`
39+
There should be a single call to a `fusible` inside the body of `f`. `fuse`
4040
returns a transformed function that will fuse the surrounding computation into
41-
the fusable and invoke it.
41+
the fusible and invoke it.
4242
"""
4343

4444
def decorator(f):
@@ -58,15 +58,15 @@ def wrapper(*args, **kwargs):
5858
return tree_util.tree_unflatten(out_tree, out_flat)
5959

6060
if physicalize:
61-
wrapper = fusable_dtype.physicalize(wrapper)
61+
wrapper = fusible_dtype.physicalize(wrapper)
6262
return wrapper
6363

6464
if f is not None:
6565
return decorator(f)
6666
return decorator
6767

6868

69-
_fusable: dict[jax_core.Primitive, Any] = {}
69+
_fusible: dict[jax_core.Primitive, Any] = {}
7070

7171

7272
def _construct_fusion_jaxpr(
@@ -148,11 +148,11 @@ def _construct_output_fusions(
148148
jaxpr,
149149
out_tree,
150150
fusion_eqn_index,
151-
fusion_eqn_outvars, # Flat list of vars output by the fusable eqn
152-
fusion_eqn_out_tree, # Tree structure of the fusable eqn outputs
151+
fusion_eqn_outvars, # Flat list of vars output by the fusible eqn
152+
fusion_eqn_out_tree, # Tree structure of the fusible eqn outputs
153153
output_fusion_prefix, # Pytree defining output groups
154154
):
155-
# 1. Create jaxpr_out: represents computation *after* the fusable
155+
# 1. Create jaxpr_out: represents computation *after* the fusible
156156
# Inputs: fusion_eqn_outvars
157157
# Outputs: jaxpr.outvars
158158
jaxpr_out, all_values, _, _, _ = _construct_fusion_jaxpr(
@@ -164,26 +164,26 @@ def _construct_output_fusions(
164164
tree_util.tree_unflatten(out_tree, jaxpr.outvars), # Original outputs
165165
tree_util.tree_unflatten(
166166
fusion_eqn_out_tree, fusion_eqn_outvars
167-
), # Fusable outputs as inputs
167+
), # Fusible outputs as inputs
168168
)
169169

170-
# 2. Group fusable outputs based on the mask
171-
unflat_fusable_outvars = jax.tree.unflatten(
170+
# 2. Group fusible outputs based on the mask
171+
unflat_fusible_outvars = jax.tree.unflatten(
172172
fusion_eqn_out_tree, fusion_eqn_outvars
173173
)
174174
partial_flat = jax.tree.structure(output_fusion_prefix).flatten_up_to(
175-
unflat_fusable_outvars
175+
unflat_fusible_outvars
176176
)
177177

178178
# 3. Calculate dependencies and check disjointness
179179
downstream_outputs_used_masks = [] # List of bool tuples, one per group
180180
already_used_final_outputs = set() # Indices of final outputs already claimed
181181
for outvars_group in partial_flat:
182182
# Identify vars in this group
183-
used_fusable_outvars = set(jax.tree.leaves(outvars_group))
183+
used_fusible_outvars = set(jax.tree.leaves(outvars_group))
184184
# Create mask for jaxpr_out inputs corresponding to this group
185185
in_used_mask = [
186-
True if v in used_fusable_outvars else False for v in jaxpr_out.invars
186+
True if v in used_fusible_outvars else False for v in jaxpr_out.invars
187187
]
188188
# Trace dependencies through jaxpr_out to find which final outputs are affected
189189
downstream_used_mask = _find_downstream(
@@ -257,25 +257,25 @@ def fuse_jaxpr(
257257

258258
# Collect input fusions
259259
for i, eqn in enumerate(jaxpr.eqns):
260-
if eqn.primitive is fusable_p:
260+
if eqn.primitive is fusible_p:
261261
fusion_eqn_index = i
262262
break
263263
if fusion_eqn_index is None:
264-
raise ValueError("No fusable eqn found")
264+
raise ValueError("No fusible eqn found")
265265
fusion_eqn = jaxpr.eqns[fusion_eqn_index]
266266

267267
# Now let's check if we need to do any fusion at all, e.g. do the outputs of
268268
# the jaxpr have any dependence on the fusion at all? We can DCE the jaxpr
269269
# with all the inputs and outputs to check if there is a dependence.
270270
dced_jaxpr, _ = pe.dce_jaxpr(jaxpr, [True] * len(jaxpr.outvars),
271271
instantiate=True)
272-
if not any(eqn.primitive is fusable_p for eqn in dced_jaxpr.eqns):
272+
if not any(eqn.primitive is fusible_p for eqn in dced_jaxpr.eqns):
273273
# Short circuit if there is nothing to fuse.
274274
return jax_core.eval_jaxpr(dced_jaxpr, consts, *args)
275275

276276
candidate_values = [*consts, *args]
277277

278-
# Construct fusions for non-constant inputs to the fusable.
278+
# Construct fusions for non-constant inputs to the fusible.
279279
in_fusions_flat = [
280280
construct_fusion(
281281
candidate_values,

jax/experimental/pallas/fuser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,6 @@
1919
from jax._src.pallas.fuser.block_spec import pull_block_spec as pull_block_spec
2020
from jax._src.pallas.fuser.block_spec import push_block_spec as push_block_spec
2121
from jax._src.pallas.fuser.custom_evaluate import evaluate as evaluate
22-
from jax._src.pallas.fuser.fusable import fusable as fusable
22+
from jax._src.pallas.fuser.fusible import fusible as fusible
2323
from jax._src.pallas.fuser.fusion import Fusion as Fusion
2424
from jax._src.pallas.fuser.jaxpr_fusion import fuse as fuse

tests/pallas/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -702,8 +702,8 @@ jax_multiplatform_test(
702702
)
703703

704704
jax_multiplatform_test(
705-
name = "tpu_fusable_matmul_test",
706-
srcs = ["tpu_fusable_matmul_test.py"],
705+
name = "tpu_fusible_matmul_test",
706+
srcs = ["tpu_fusible_matmul_test.py"],
707707
disable_configs = [
708708
"tpu_v3",
709709
"tpu_pjrt_c_api",

0 commit comments

Comments
 (0)