|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import collections |
15 | 16 | import contextlib |
16 | 17 | import dataclasses |
17 | 18 | import threading |
| 19 | +import typing as tp |
18 | 20 |
|
19 | 21 | import jax |
20 | 22 | from jax.sharding import PartitionSpec, NamedSharding |
21 | 23 | from flax.core import meta |
22 | 24 | from flax.typing import ( |
23 | 25 | LogicalRules, |
24 | | - Sharding, |
25 | 26 | ) |
26 | 27 |
|
27 | 28 | def get_pspec(sharding_names, sharding_rules = None) -> PartitionSpec: |
28 | 29 | """Given an `nnx.Variable`, return its `PartitionSpec`.""" |
29 | 30 | if get_logical_axis_rules() or sharding_rules: |
30 | | - context_rules = get_logical_axis_rules() |
31 | | - rules = composite_rules(context_rules, sharding_rules) |
32 | | - return PartitionSpec(*from_sharding_rules(sharding_names, rules)) |
| 31 | + sharding_names = logical_to_mesh_axes(sharding_names, sharding_rules) |
33 | 32 | return PartitionSpec(*sharding_names) |
34 | 33 |
|
35 | 34 |
|
@@ -105,10 +104,119 @@ def composite_rules(rule1, rule2): |
105 | 104 | return tuple(rules.items()) |
106 | 105 |
|
107 | 106 |
|
108 | | -def from_sharding_rules( |
109 | | - sharding: Sharding, sharding_rules: LogicalRules |
110 | | -) -> Sharding: |
111 | | - rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules} |
112 | | - return tuple( |
113 | | - rules[str(s)] if (s and str(s) in rules) else s for s in sharding |
| 107 | + |
| 108 | +class _UnassignedAxis: |
| 109 | + """Sentinel class for unassigned logical axis name.""" |
| 110 | + |
| 111 | + def __repr__(self): |
| 112 | + return 'UnassignedAxis' |
| 113 | + |
| 114 | + def __bool__(self): |
| 115 | + return False |
| 116 | + |
| 117 | + |
| 118 | +_unassigned_axis = _UnassignedAxis() |
| 119 | + |
| 120 | + |
| 121 | +def _mesh_assignment_free(new_assignment, existing_assignments): |
| 122 | + """Determines if a given mesh axis has already been assigned.""" |
| 123 | + new = set(jax.tree_util.tree_leaves(new_assignment)) |
| 124 | + existing = set(jax.tree_util.tree_leaves(existing_assignments)) |
| 125 | + if existing.intersection(new): |
| 126 | + return False |
| 127 | + return True |
| 128 | + |
| 129 | + |
| 130 | +def _logical_to_mesh_axes( |
| 131 | + array_dim_names: tp.Sequence[str | None] | None, |
| 132 | + rules: LogicalRules | None = None, |
| 133 | +) -> list[_UnassignedAxis | None | str | tuple[str, ...]] | None: |
| 134 | + """Same as logical_to_mesh_axes, but doesn't fill in _unassigned_axis.""" |
| 135 | + if array_dim_names is None: |
| 136 | + return None |
| 137 | + if rules is None: |
| 138 | + rules = get_logical_axis_rules() |
| 139 | + axis_name_counts = collections.Counter(array_dim_names) |
| 140 | + # None and special values such as PartitionSpec.UNCONSTRAINED can appear more |
| 141 | + # than once.
| 142 | + dups = tuple( |
| 143 | + k for k, v in axis_name_counts.items() if v > 1 and isinstance(k, str) |
114 | 144 | ) |
| 145 | + if dups: |
| 146 | + raise ValueError( |
| 147 | + f'Unsupported: Dimensions {dups} occur more than once in array names.' |
| 148 | + ) |
| 149 | + if not isinstance(rules, (tuple, list)): |
| 150 | + raise ValueError('Unknown axis rule specification type.') |
| 151 | + # We assign mesh axes using a priority based ruleset over logical axis names. |
| 152 | + result: list[_UnassignedAxis | None | str | tuple[str, ...]] |
| 153 | + result = [ |
| 154 | + (_unassigned_axis if isinstance(name, str) else name) |
| 155 | + for name in array_dim_names |
| 156 | + ] |
| 157 | + for rule_model_name, rule_mesh_names in rules: |
| 158 | + if rule_model_name in array_dim_names: |
| 159 | + pos = array_dim_names.index(rule_model_name) |
| 160 | + if ( |
| 161 | + _mesh_assignment_free(rule_mesh_names, result) |
| 162 | + and result[pos] == _unassigned_axis |
| 163 | + ): |
| 164 | + result[pos] = rule_mesh_names |
| 165 | + return result |
| 166 | + |
| 167 | + |
| 168 | +def logical_to_mesh_axes( |
| 169 | + array_dim_names: tp.Sequence[str | None] | None, |
| 170 | + rules: LogicalRules | None = None, |
| 171 | +) -> jax.sharding.PartitionSpec | None: |
| 172 | + """Compute layout for an array. |
| 173 | +
|
| 174 | + The rules are in order of precedence, and consist of pairs: |
| 175 | + ``(ArrayDimensionName, MeshDimensionName)``, meaning that the given array |
| 176 | + dimension (if present and unused) should be sharded across the given |
| 177 | + mesh dimension (if present and unused). |
| 178 | +
|
| 179 | + A Layout of an Array is expressed as a tuple with one element for each |
| 180 | + dimension in the Array. The element is either None, or is the name of a |
| 181 | + mesh-dimension, meaning that this dimension of the array is sharded across |
| 182 | + this dimension of the mesh. |
| 183 | +
|
| 184 | + For example, given an array with:: |
| 185 | +
|
| 186 | + array_dim_names = ('batch', 'length', 'heads', 'features') |
| 187 | +
|
| 188 | + and the layout rules are:: |
| 189 | +
|
| 190 | + rules = (('batch', 'X'), |
| 191 | + ('features', 'X'), |
| 192 | + ('heads', 'Y'), |
| 193 | + ('batch', 'Z')) |
| 194 | +
|
| 195 | + then this function will return:: |
| 196 | +
|
| 197 | + PartitionSpec('X', None, 'Y', None) |
| 198 | +
|
| 199 | + Args: |
| 200 | + array_dim_names: Tuple of array dimension names or None. |
| 201 | + rules: Optional logical to mesh rules override. Defaults to using the |
| 202 | + rules defined in the dynamic context set from the ``axis_rules`` function. |
| 203 | +
|
| 204 | + Returns: |
| 205 | + PartitionSpec for the parameter. |
| 206 | + """ |
| 207 | + result = _logical_to_mesh_axes(array_dim_names, rules) |
| 208 | + if result is None: |
| 209 | + return None |
| 210 | + # We default to None, i.e. unsharded along the dimension.
| 211 | + result = [None if x is _unassigned_axis else x for x in result] |
| 212 | + return jax.sharding.PartitionSpec(*result) |
| 213 | + |
| 214 | + |
| 215 | + |
| 216 | +# def from_sharding_rules( |
| 217 | +# sharding_names: Sharding, sharding_rules: LogicalRules |
| 218 | +# ) -> Sharding: |
| 219 | +# rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules} |
| 220 | +# return tuple( |
| 221 | +# rules[str(s)] if (s and str(s) in rules) else s for s in sharding_names |
| 222 | +# ) |
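
A minimal usage sketch of the new `logical_to_mesh_axes` helper, reproducing the example from its docstring. The `flax.nnx.spmd` import path is an assumption about where this file lives, and the dimension and mesh axis names are illustrative only, not part of this diff.

import jax
from flax.nnx import spmd  # assumed module path for the file in this diff

array_dim_names = ('batch', 'length', 'heads', 'features')
rules = (
    ('batch', 'X'),     # first match wins: 'batch' is sharded over mesh axis 'X'
    ('features', 'X'),  # skipped: mesh axis 'X' is already taken by 'batch'
    ('heads', 'Y'),
    ('batch', 'Z'),     # skipped: 'batch' already has an assignment
)
spec = spmd.logical_to_mesh_axes(array_dim_names, rules)
# 'length' and 'features' fall back to None, i.e. unsharded.
assert spec == jax.sharding.PartitionSpec('X', None, 'Y', None)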