Commit 4867822 (authored Mar 12, 2025)

[sharktank] Add toy size Flux transformer (#1075)

We don't have proper tests for a toy-size variant of the model, which is desirable for CI tests on every commit. Some of the tests fail during IREE buffer destruction, which is a known issue; see #1050.

1 parent: d3c462c
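For orientation, a minimal sketch of how the new toy-size model might be exercised in eager mode (helper names are taken from the diff below; the exact invocation is an assumption):

import torch
from sharktank.models.flux.flux import FluxModelV1
from sharktank.models.flux.testing import make_random_theta, make_toy_config

config = make_toy_config()                       # toy-size FluxParams (added below)
theta = make_random_theta(config, dtype=torch.float32)
model = FluxModelV1(theta=theta, params=config)  # runs params.validate()

# sample_inputs() returns (args, kwargs); args is empty for "forward".
args, kwargs = model.sample_inputs(batch_size=1)
output = model(*args, **kwargs)                  # eager forward pass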

File tree: 11 files changed, +248 -125 lines
.github/workflows/ci-sharktank.yml (+3)

@@ -144,6 +144,9 @@ jobs:
         --with-flux-data \
         --with-vae-data \
         --with-quark-data \
+        --iree-hal-target-device=hip \
+        --iree-hip-target=gfx942 \
+        --iree-device=hip://0 \
         sharktank/tests/models/clip/clip_test.py \
         sharktank/tests/models/t5/t5_test.py \
         sharktank/tests/models/flux/flux_test.py \

sharktank/sharktank/layers/mmdit.py (+3 -5)

@@ -162,11 +162,12 @@ def forward(


 class MMDITSingleBlock(ThetaLayer):
-    def __init__(self, theta, num_heads: int, hidden_size: int):
+    def __init__(self, theta, num_heads: int, hidden_size: int, mlp_ratio: float):
         super().__init__(theta)

         self.num_heads = num_heads
         self.hidden_size = hidden_size
+        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
         self.add_module("mod", ModulationLayer(theta("modulation"), double=False))
         self.add_module(
             "attn_norm_q",
@@ -179,9 +180,6 @@ def __init__(self, theta, num_heads: int, hidden_size: int):

         self.add_module("linear1", LinearLayer(theta("linear1")))
         self.add_module("linear2", LinearLayer(theta("linear2")))
-        # TODO: There should be a way to refactor out the following two constants and just reference model shapes
-        self.hidden_size = 3072
-        self.mlp_hidden_dim = 3072

     def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
         mod, _ = self.mod(vec)
@@ -191,7 +189,7 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
         x_mod = (1 + mod.scale) * x_norm + mod.shift
         x_lin = self.linear1(x_mod)
         qkv, mlp = torch.split(
-            x_lin, [3 * self.hidden_size, 4 * self.mlp_hidden_dim], dim=-1
+            x_lin, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
         )

         qkv_2 = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1)  #
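Note on the split change: for the full-size model the old and new split sizes coincide, because the old code hardcoded mlp_hidden_dim = 3072 while the new mlp_hidden_dim is hidden_size * mlp_ratio. A quick arithmetic check, with full-size defaults taken from the diff:

# Old: hidden_size = mlp_hidden_dim = 3072 (hardcoded),
# so the split was [3 * 3072, 4 * 3072] = [9216, 12288].
# New: mlp_hidden_dim = int(hidden_size * mlp_ratio).
hidden_size, mlp_ratio = 3072, 4.0
mlp_hidden_dim = int(hidden_size * mlp_ratio)  # 12288
assert [3 * hidden_size, mlp_hidden_dim] == [3 * 3072, 4 * 3072]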

sharktank/sharktank/layers/testing.py (+37 -36)

@@ -95,24 +95,22 @@ def make_latent_attention_block_theta(


 def make_mmdit_double_block_random_theta(
-    in_channels: int = 128,
     hidden_size: int = 3072,
+    num_heads: int = 24,
     mlp_ratio: float = 4.0,
     dtype: torch.dtype | None = None,
 ) -> Theta:
-    in_channels = 128
-    hidden_size = 3072
-    mlp_ratio = 4.0
-    mlp_hidden_size = int((mlp_ratio - 1) * hidden_size)
-    mlp_hidden_size2 = int(mlp_ratio * hidden_size)
-    mlp_hidden_size3 = int(2 * (mlp_ratio - 1) * hidden_size)
+    head_dim = hidden_size // num_heads
+    mlp_hidden_size = int(mlp_ratio * hidden_size)
+    qkv_out_size = 3 * hidden_size
+    modulation_size = hidden_size * 6
     return Theta(
         {
             "img_attn.norm.key_norm.scale": DefaultPrimitiveTensor(  #
-                data=make_rand_torch((in_channels,), dtype=dtype)
+                data=make_rand_torch((head_dim,), dtype=dtype)
             ),
             "img_attn.norm.query_norm.scale": DefaultPrimitiveTensor(  #
-                data=make_rand_torch((in_channels,), dtype=dtype)
+                data=make_rand_torch((head_dim,), dtype=dtype)
             ),
             "img_attn.proj.bias": DefaultPrimitiveTensor(
                 data=make_rand_torch((hidden_size,), dtype=dtype)
@@ -121,34 +119,34 @@ def make_mmdit_double_block_random_theta(
                 data=make_rand_torch((hidden_size, hidden_size), dtype=dtype)
             ),
             "img_attn.qkv.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size,), dtype=dtype)
+                data=make_rand_torch((qkv_out_size,), dtype=dtype)
             ),
             "img_attn.qkv.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
+                data=make_rand_torch((qkv_out_size, hidden_size), dtype=dtype)
             ),
             "img_mlp.0.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size2), dtype=dtype)
+                data=make_rand_torch((mlp_hidden_size), dtype=dtype)
             ),
             "img_mlp.0.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype)
+                data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
             ),
             "img_mlp.2.bias": DefaultPrimitiveTensor(
                 data=make_rand_torch((hidden_size), dtype=dtype)
             ),
             "img_mlp.2.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype)
+                data=make_rand_torch((hidden_size, mlp_hidden_size), dtype=dtype)
             ),
             "img_mod.lin.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size3,), dtype=dtype)
+                data=make_rand_torch((modulation_size,), dtype=dtype)
             ),
             "img_mod.lin.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype)
+                data=make_rand_torch((modulation_size, hidden_size), dtype=dtype)
             ),
             "txt_attn.norm.key_norm.scale": DefaultPrimitiveTensor(  #
-                data=make_rand_torch((in_channels,), dtype=dtype)
+                data=make_rand_torch((head_dim,), dtype=dtype)
             ),
             "txt_attn.norm.query_norm.scale": DefaultPrimitiveTensor(  #
-                data=make_rand_torch((in_channels,), dtype=dtype)
+                data=make_rand_torch((head_dim,), dtype=dtype)
             ),
             "txt_attn.proj.bias": DefaultPrimitiveTensor(
                 data=make_rand_torch((hidden_size,), dtype=dtype)
@@ -157,49 +155,50 @@ def make_mmdit_double_block_random_theta(
                 data=make_rand_torch((hidden_size, hidden_size), dtype=dtype)
             ),
             "txt_attn.qkv.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size,), dtype=dtype)
+                data=make_rand_torch((qkv_out_size,), dtype=dtype)
             ),
             "txt_attn.qkv.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
+                data=make_rand_torch((qkv_out_size, hidden_size), dtype=dtype)
             ),
             "txt_mlp.0.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size2), dtype=dtype)
+                data=make_rand_torch((mlp_hidden_size), dtype=dtype)
             ),
             "txt_mlp.0.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype)
+                data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
             ),
             "txt_mlp.2.bias": DefaultPrimitiveTensor(
                 data=make_rand_torch((hidden_size), dtype=dtype)
             ),
             "txt_mlp.2.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype)
+                data=make_rand_torch((hidden_size, mlp_hidden_size), dtype=dtype)
             ),
             "txt_mod.lin.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size3,), dtype=dtype)
+                data=make_rand_torch((modulation_size,), dtype=dtype)
             ),
             "txt_mod.lin.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype)
+                data=make_rand_torch((modulation_size, hidden_size), dtype=dtype)
             ),
         }
     )


 def make_mmdit_single_block_random_theta(
-    in_channels: int = 128,
     hidden_size: int = 3072,
+    num_heads: int = 24,
     mlp_ratio: float = 4.0,
     dtype: torch.dtype | None = None,
 ) -> Theta:
-    mlp_hidden_size = int((mlp_ratio - 1) * hidden_size)
-    mlp_hidden_size2 = int((mlp_ratio + 1) * hidden_size)
-    mlp_hidden_size3 = int((2 * mlp_ratio - 1) * hidden_size)
+    mlp_hidden_dim = int(hidden_size * mlp_ratio)
+    head_dim = hidden_size // num_heads
+    modulation_size = 3 * hidden_size
+    linear1_hidden_size = hidden_size * 3 + mlp_hidden_dim
     return Theta(
         {
             "norm.key_norm.scale": DefaultPrimitiveTensor(  #
-                data=make_rand_torch((in_channels,), dtype=dtype)
+                data=make_rand_torch((head_dim,), dtype=dtype)
             ),
             "norm.query_norm.scale": DefaultPrimitiveTensor(  #
-                data=make_rand_torch((in_channels,), dtype=dtype)
+                data=make_rand_torch((head_dim,), dtype=dtype)
             ),
             "attn.proj.bias": DefaultPrimitiveTensor(
                 data=make_rand_torch((hidden_size,), dtype=dtype)
@@ -208,22 +207,24 @@ def make_mmdit_single_block_random_theta(
                 data=make_rand_torch((hidden_size, hidden_size), dtype=dtype)
             ),
             "linear1.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size3,), dtype=dtype)
+                data=make_rand_torch((linear1_hidden_size,), dtype=dtype)
             ),
             "linear1.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype)
+                data=make_rand_torch((linear1_hidden_size, hidden_size), dtype=dtype)
             ),
             "linear2.bias": DefaultPrimitiveTensor(
                 data=make_rand_torch((hidden_size), dtype=dtype)
             ),
             "linear2.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype)
+                data=make_rand_torch(
+                    (hidden_size, hidden_size + mlp_hidden_dim), dtype=dtype
+                )
             ),
             "modulation.lin.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size,), dtype=dtype)
+                data=make_rand_torch((modulation_size,), dtype=dtype)
             ),
             "modulation.lin.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
+                data=make_rand_torch((modulation_size, hidden_size), dtype=dtype)
             ),
         }
     )

sharktank/sharktank/models/flux/export.py (+6 -1)

@@ -16,6 +16,9 @@
 from ...types import Dataset
 from ...utils.hf_datasets import get_dataset
 from sharktank.transforms.dataset import set_float_dtype
+from iree.turbine.aot import (
+    ExternalTensorTrait,
+)

 flux_transformer_default_batch_sizes = [1]

@@ -35,6 +38,8 @@ def export_flux_transformer_model_mlir(
     else:
         model = model_or_parameters_path

+    for t in model.theta.flatten().values():
+        ExternalTensorTrait(external_name=t.name, external_scope="").set(t.as_torch())
     export_static_model_mlir(model, output_path=output_path, batch_sizes=batch_sizes)


@@ -60,7 +65,7 @@ def export_flux_transformer(
 ):
     export_flux_transformer_iree_parameters(model, parameters_output_path)
     export_flux_transformer_model_mlir(
-        parameters_output_path, output_path=mlir_output_path, batch_sizes=batch_sizes
+        model, output_path=mlir_output_path, batch_sizes=batch_sizes
    )
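For context, ExternalTensorTrait is the iree-turbine mechanism for exporting weights as named external parameters, resolved at load time from the parameter (.irpa) file, rather than as inlined MLIR constants. A minimal sketch of the pattern, with an illustrative tensor name (the trait API is assumed from iree.turbine.aot):

import torch
from iree.turbine.aot import ExternalTensorTrait

weight = torch.randn(4, 4)
# Tag the tensor so AOT export emits a named external parameter
# ("my.weight" is a hypothetical name) instead of inlining the data.
ExternalTensorTrait(external_name="my.weight", external_scope="").set(weight)

This also appears to be why export_flux_transformer now passes the in-memory model instead of the parameters path: the tensors themselves carry the parameter names needed by the exported MLIR.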

sharktank/sharktank/models/flux/flux.py (+59 -27)

@@ -25,6 +25,7 @@
 from ... import ops

 __all__ = [
+    "FluxParams",
     "FluxModelV1",
 ]

@@ -49,6 +50,18 @@ class FluxParams:
     qkv_bias: bool
     guidance_embed: bool

+    time_dim: int = 256
+    txt_context_length: int = 512
+
+    # The allowed range of these values is dependent on the model size.
+    # They will not work for all variants, specifically toy-sized models.
+    output_img_height: int = 1024
+    output_img_width: int = 1024
+    output_img_channels: int = 3
+
+    # def __post_init__(self):
+    #     assert self.hidden_size == self.vec_in_dim * int(self.mlp_ratio)
+
     def to_hugging_face_properties(self) -> dict[str, Any]:
         hparams = {
             "in_channels": self.in_channels,
@@ -71,14 +84,12 @@ def from_hugging_face_properties(properties: dict[str, Any]) -> "FluxParams":
         vec_in_dim = p["pooled_projection_dim"]
         context_in_dim = p["joint_attention_dim"]
         mlp_ratio = 4.0
-        hidden_size = vec_in_dim * int(mlp_ratio)
+        hidden_size = int(vec_in_dim * mlp_ratio)
         num_heads = p["num_attention_heads"]
         depth = p["num_layers"]
         depth_single_blocks = p["num_single_layers"]

-        # TODO: figure out relation between hidden_size, num_heads and
-        # attention_head_dim.
-        # diffusers.FluxTransformer2DModel also hardcodes this.
+        # diffusers.FluxTransformer2DModel hardcodes this.
         axes_dim = [16, 56, 56]
         assert sum(axes_dim) == p["attention_head_dim"]

@@ -102,6 +113,29 @@ def from_hugging_face_properties(properties: dict[str, Any]) -> "FluxParams":
             guidance_embed=guidance_embed,
         )

+    def validate(self):
+        if self.in_channels % 4 != 0:
+            raise ValueError(f"In channels {self.in_channels} must be a multiple of 4")
+        if self.hidden_size != self.vec_in_dim * self.mlp_ratio:
+            raise ValueError(
+                "Equality hidden_size == vec_in_dim * mlp_ratio does not hold. "
+                f"{self.hidden_size} != {self.vec_in_dim} * {self.mlp_ratio}"
+            )
+        if self.hidden_size % self.num_heads != 0:
+            raise ValueError(
+                f"Hidden size {self.hidden_size} must be divisible by num_heads {self.num_heads}"
+            )
+        pe_dim = self.hidden_size // self.num_heads
+        if sum(self.axes_dim) != pe_dim:
+            raise ValueError(
+                f"axes_dim {self.axes_dim} must sum up to the positional embeddings"
+                f" dimension size {pe_dim}"
+            )
+        if any(d % 2 != 0 for d in self.axes_dim):
+            raise ValueError(
+                f"All elements of axes_dim {self.axes_dim} must be a multiple of 2"
+            )
+

 class FluxModelV1(ThetaLayer):
     """FluxModel adapted from Black Forest Lab's implementation."""
@@ -111,18 +145,11 @@ def __init__(self, theta: Theta, params: FluxParams):
             theta,
         )

+        params.validate()
         self.params = copy(params)
         self.in_channels = params.in_channels
         self.out_channels = self.in_channels
-        if params.hidden_size % params.num_heads != 0:
-            raise ValueError(
-                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
-            )
         pe_dim = params.hidden_size // params.num_heads
-        if sum(params.axes_dim) != pe_dim:
-            raise ValueError(
-                f"Got {params.axes_dim} but expected positional dim {pe_dim}"
-            )
         self.hidden_size = params.hidden_size
         self.num_heads = params.num_heads
         self.pe_embedder = EmbedND(
@@ -154,6 +181,7 @@ def __init__(self, theta: Theta, params: FluxParams):
                     theta("single_blocks", i),
                     num_heads=self.num_heads,
                     hidden_size=self.hidden_size,
+                    mlp_ratio=params.mlp_ratio,
                 )
                 for i in range(params.depth_single_blocks)
             ]
@@ -181,13 +209,15 @@ def forward(

         # running on sequences img
         img = self.img_in(img)
-        vec = self.time_in(timestep_embedding(timesteps, 256))
+        vec = self.time_in(timestep_embedding(timesteps, self.params.time_dim))
         if self.guidance:
             if guidance is None:
                 raise ValueError(
                     "Didn't get guidance strength for guidance distilled model."
                 )
-            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+            vec = vec + self.guidance_in(
+                timestep_embedding(guidance, self.params.time_dim)
+            )

         vec = vec + self.vector_in(y)

@@ -213,14 +243,13 @@ def sample_inputs(
         if not (function is None or function == "forward"):
             raise ValueError(f'Only function "forward" is supported. Got "{function}"')

-        # The allowed range of these values is dependent on the model size.
-        # They will not work for all variants, specifically toy-sized models.
-        output_img_height = 1024
-        output_img_width = 1024
-        output_img_channels = 3
+        output_img_channels = self.params.output_img_channels

         img = self._get_noise(
-            batch_size, output_img_height, output_img_width, self.dtype
+            batch_size,
+            self.params.output_img_height,
+            self.params.output_img_width,
+            self.dtype,
         )

         _, c, h, w = img.shape
@@ -233,16 +262,17 @@ def sample_inputs(
         img_ids = img_ids.repeat(batch_size, 1, 1)

         # T5 encoder output
-        txt_context_length = 512
-        txt_dims_per_token = 4096
-        txt = torch.rand([1, txt_context_length, txt_dims_per_token], dtype=self.dtype)
+        txt_dims_per_token = self.params.context_in_dim
+        txt = torch.rand(
+            [1, self.params.txt_context_length, txt_dims_per_token], dtype=self.dtype
+        )
         txt = txt.repeat(batch_size, 1, 1)
         txt_ids = torch.zeros(batch_size, txt.shape[1], output_img_channels)

         timesteps = torch.rand([batch_size], dtype=self.dtype)

         # CLIP text model output
-        y = make_rand_torch([1, 768], dtype=self.dtype)
+        y = make_rand_torch([1, self.params.vec_in_dim], dtype=self.dtype)
         y = y.repeat(batch_size, 1)

         args = tuple()
@@ -269,12 +299,14 @@ def _get_noise(
         width: int,
         dtype: torch.dtype,
     ):
+        assert self.params.in_channels % 4 == 0
+        channels = self.params.in_channels // 4
         return torch.randn(
             batch_size,
-            16,
+            channels,
             # allow for packing
-            2 * math.ceil(height / 16),
-            2 * math.ceil(width / 16),
+            2 * math.ceil(height / channels),
+            2 * math.ceil(width / channels),
             dtype=dtype,
         )
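The generalized packing arithmetic in _get_noise recovers the constants that were previously hardcoded when in_channels is 64 (the full-size Flux value, an assumption here), and scales down for the toy config. A small check:

import math

def noise_shape(batch_size: int, height: int, width: int, in_channels: int):
    # Mirrors _get_noise above: 4 latent values are packed per channel.
    assert in_channels % 4 == 0
    channels = in_channels // 4
    return (
        batch_size,
        channels,
        2 * math.ceil(height / channels),
        2 * math.ceil(width / channels),
    )

print(noise_shape(1, 1024, 1024, 64))  # (1, 16, 128, 128) -- the old hardcoded 16s
print(noise_shape(1, 18, 27, 36))      # (1, 9, 4, 6) -- toy config dimensions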

sharktank/sharktank/models/flux/testing.py (+45 -8)

@@ -10,7 +10,7 @@

 from .flux import FluxParams, FluxModelV1
 from .export import export_flux_transformer, flux_transformer_default_batch_sizes
-from ...types import DefaultPrimitiveTensor, Theta, save_load_theta
+from ...types import DefaultPrimitiveTensor, Theta
 from ...layers.testing import (
     make_rand_torch,
     make_mmdit_double_block_random_theta,
@@ -41,14 +41,11 @@ def convert_flux_transformer_input_for_hugging_face_model(


 def make_random_theta(config: FluxParams, dtype: torch.dtype):
-    # TODO: do not hardcode values.
-
     in_channels = config.in_channels
-    in_channels2 = 128
     hidden_size = config.hidden_size
     mlp_ratio = config.mlp_ratio
     context_in_dim = config.context_in_dim
-    time_dim = 256
+    time_dim = config.time_dim
     vec_dim = config.vec_in_dim
     patch_size = 1
     out_channels = config.out_channels
@@ -107,12 +104,18 @@ def make_random_theta(config: FluxParams, dtype: torch.dtype):

     for i in range(config.depth):
         tensor_dict[f"double_blocks.{i}"] = make_mmdit_double_block_random_theta(
-            in_channels=in_channels, hidden_size=hidden_size, mlp_ratio=mlp_ratio
+            hidden_size=hidden_size,
+            mlp_ratio=mlp_ratio,
+            num_heads=config.num_heads,
+            dtype=dtype,
         ).flatten()

     for i in range(config.depth_single_blocks):
         tensor_dict[f"single_blocks.{i}"] = make_mmdit_single_block_random_theta(
-            in_channels=in_channels2, hidden_size=hidden_size, mlp_ratio=mlp_ratio
+            hidden_size=hidden_size,
+            mlp_ratio=mlp_ratio,
+            num_heads=config.num_heads,
+            dtype=dtype,
         ).flatten()

     if config.guidance_embed:
@@ -141,7 +144,9 @@ def make_random_theta(config: FluxParams, dtype: torch.dtype):
             data=make_rand_torch((hidden_size,), dtype=dtype)
         )

-    return Theta(tensor_dict)
+    res = Theta(tensor_dict)
+    res.rename_tensors_to_paths()
+    return res


 def make_dev_single_layer_config():
@@ -162,6 +167,38 @@ def make_dev_single_layer_config():
     )


+def make_toy_config() -> FluxParams:
+    num_heads = 5
+    mlp_ratio = 2
+    axes_dim = [4 * 2, 4 * 3, 4 * 4]
+    in_channels = sum(axes_dim)
+    hidden_size = in_channels * num_heads
+    vec_in_dim = hidden_size // mlp_ratio
+    assert hidden_size == mlp_ratio * vec_in_dim
+    output_img_height = 2 * in_channels // 4
+    output_img_width = 3 * in_channels // 4
+    return FluxParams(
+        in_channels=in_channels,
+        out_channels=in_channels,
+        time_dim=13,
+        vec_in_dim=vec_in_dim,
+        context_in_dim=7,
+        txt_context_length=11,
+        hidden_size=hidden_size,
+        mlp_ratio=float(mlp_ratio),
+        num_heads=num_heads,
+        depth=3,
+        depth_single_blocks=2,
+        axes_dim=axes_dim,
+        theta=10_000,
+        qkv_bias=True,
+        guidance_embed=True,
+        output_img_height=output_img_height,
+        output_img_width=output_img_width,
+        output_img_channels=3,
+    )
+
+
 def export_dev_random_single_layer(
     dtype: torch.dtype,
     mlir_output_path: PathLike,
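As a sanity check, the toy values above satisfy every constraint in FluxParams.validate() (plain arithmetic, numbers copied from make_toy_config):

num_heads = 5
mlp_ratio = 2
axes_dim = [8, 12, 16]                 # 4*2, 4*3, 4*4
in_channels = sum(axes_dim)            # 36, a multiple of 4
hidden_size = in_channels * num_heads  # 180
vec_in_dim = hidden_size // mlp_ratio  # 90

assert in_channels % 4 == 0
assert hidden_size == vec_in_dim * mlp_ratio      # 180 == 90 * 2
assert hidden_size % num_heads == 0               # 180 % 5 == 0
assert sum(axes_dim) == hidden_size // num_heads  # 36 == pe_dim
assert all(d % 2 == 0 for d in axes_dim)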

sharktank/sharktank/utils/logging.py (+5)

@@ -5,8 +5,13 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

 import logging
+import torch

 from iree.turbine.support.logging import get_logger


 transform_logger: logging.Logger = get_logger("sharktank.transforms")
+
+
+def format_tensor_statistics(tensor: torch.Tensor):
+    return f"mean = {tensor.mean()}, median = {tensor.median()}, std dev = {tensor.std()}, min = {tensor.min()}, max = {tensor.max()}"

sharktank/sharktank/utils/testing.py (+3)

@@ -7,6 +7,7 @@
 from typing import Optional
 import contextlib
 from pathlib import Path
+import pytest
 from os import PathLike
 import os
 import shutil
@@ -21,6 +22,8 @@
 from ..types import *
 from .math import cosine_similarity

+is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'")
+
 # Range of torch.rand() is [0,1)
 # Range of torch.rand() * 2 - 1 is [-1, 1), includes negative values
 def make_rand_torch(shape: list[int], dtype: Optional[torch.dtype] = torch.float32):

sharktank/tests/evaluate/perplexity_iree_test.py (+1 -1)

@@ -10,8 +10,8 @@
 import numpy as np

 from sharktank.evaluate import perplexity_iree
+from sharktank.utils.testing import is_mi300x

-is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'")
 skipif_run_quick_llama_test = pytest.mark.skipif(
     'not config.getoption("run-nightly-llama-tests")',
     reason="Run large tests if --run-nightly-llama-tests is passed",

sharktank/tests/layers/mmdit_test.py (+8 -2)

@@ -56,10 +56,16 @@ def _(model, img, txt, vec, rot) -> torch.Tensor:
         asm = str(output.mlir_module)

     def testSingleExport(self):
-        theta = make_mmdit_single_block_random_theta(hidden_size=self.hidden_size)
+        mlp_ratio = 4.0
+        theta = make_mmdit_single_block_random_theta(
+            hidden_size=self.hidden_size, num_heads=self.num_heads, mlp_ratio=mlp_ratio
+        )
         theta = self.save_load_theta(theta)
         mmdit = MMDITSingleBlock(
-            theta=theta, num_heads=self.num_heads, hidden_size=self.hidden_size
+            theta=theta,
+            num_heads=self.num_heads,
+            hidden_size=self.hidden_size,
+            mlp_ratio=mlp_ratio,
         )

         inp = torch.rand([self.batch_size, 1024, self.hidden_size])

sharktank/tests/models/flux/flux_test.py (+78 -45)

@@ -21,20 +21,20 @@
 from sharktank.models.flux.testing import (
     convert_flux_transformer_input_for_hugging_face_model,
     export_dev_random_single_layer,
-    make_dev_single_layer_config,
+    make_toy_config,
     make_random_theta,
 )
 from sharktank.models.flux.flux import FluxModelV1, FluxParams
-from sharktank.utils.testing import TempDirTestBase
+from sharktank.utils.testing import TempDirTestBase, skip, is_mi300x
 from sharktank.utils.iree import (
-    get_iree_devices,
     load_iree_module,
     run_iree_module_function,
     prepare_iree_module_function_args,
     call_torch_module_function,
     flatten_for_iree_signature,
     iree_to_torch,
 )
+from sharktank.utils.logging import format_tensor_statistics
 from sharktank import ops
 from sharktank.transforms.dataset import set_float_dtype
 from sharktank.types import Dataset, Theta
@@ -44,8 +44,6 @@
 with_flux_data = pytest.mark.skipif("not config.getoption('with_flux_data')")

 iree_compile_flags = [
-    "--iree-hal-target-device=hip",
-    "--iree-hip-target=gfx942",
     "--iree-opt-const-eval=false",
     "--iree-opt-strip-assertions=true",
     "--iree-global-opt-propagate-transposes=true",
@@ -74,6 +72,15 @@ def convert_dtype_if_dtype(
     return t


+def convert_input_dtype(input: dict[str, torch.Tensor], dtype: torch.dtype):
+    always_float32_input_arg_names = set(["img_ids", "txt_ids"])
+    return OrderedDict(
+        (k, t if k in always_float32_input_arg_names else t.to(dtype=dtype))
+        for k, t in input.items()
+    )
+
+
+@pytest.mark.usefixtures("path_prefix", "get_iree_flags")
 class FluxTest(TempDirTestBase):
     def setUp(self):
         super().setUp()
@@ -96,6 +103,7 @@ def runCompareIreeAgainstTorchEager(
         target_theta = reference_model.theta.transform(
             functools.partial(set_float_dtype, dtype=target_dtype)
         )
+
         target_torch_model = FluxModelV1(
             theta=target_theta,
             params=reference_model.params,
@@ -115,30 +123,22 @@ def runCompareIreeAgainstTorchEager(

         iree_module_path = self._temp_dir / "model.vmfb"
         logger.info("Compiling MLIR file...")
+        compile_flags = iree_compile_flags + [
+            f"--iree-hal-target-device={self.iree_hal_target_device}",
+            f"--iree-hip-target={self.iree_hip_target}",
+        ]
         iree.compiler.compile_file(
             str(mlir_path),
             output_file=str(iree_module_path),
-            extra_args=iree_compile_flags,
+            extra_args=compile_flags,
         )

-        target_input_args, target_input_kwargs = target_torch_model.sample_inputs(
+        reference_input_args, reference_input_kwargs = reference_model.sample_inputs(
             batch_size
         )
-
-        reference_input_args = [
-            convert_dtype_if_dtype(
-                t, source_dtype=target_dtype, target_dtype=reference_model.dtype
-            )
-            for t in target_input_args
-        ]
-        reference_input_kwargs = OrderedDict(
-            (
-                k,
-                convert_dtype_if_dtype(
-                    t, source_dtype=target_dtype, target_dtype=reference_model.dtype
-                ),
-            )
-            for k, t in target_input_kwargs.items()
+        assert len(reference_input_args) == 0
+        target_input_kwargs = convert_input_dtype(
+            reference_input_kwargs, dtype=target_dtype
         )

         logger.info("Invoking reference torch function...")
@@ -150,15 +150,15 @@ def runCompareIreeAgainstTorchEager(
         )
         expected_outputs = flatten_for_iree_signature(reference_result_dict)

-        iree_devices = get_iree_devices(driver="hip", device_count=1)
+        iree_devices = [iree.runtime.get_device(self.iree_device)]
         logger.info("Loading IREE module...")
         iree_module, iree_vm_context, iree_vm_instance = load_iree_module(
             module_path=iree_module_path,
             devices=iree_devices,
             parameters_path=parameters_path,
         )
         iree_args = prepare_iree_module_function_args(
-            args=flatten_for_iree_signature([target_input_args, target_input_kwargs]),
+            args=flatten_for_iree_signature(target_input_kwargs),
             devices=iree_devices,
         )

@@ -177,9 +177,14 @@ def runCompareIreeAgainstTorchEager(
             for i in range(len(expected_outputs))
         ]
         logger.info("Comparing outputs...")
+        logger.info(f"Expected output {format_tensor_statistics(expected_outputs[0])}")
+        abs_diff = (actual_outputs[0] - expected_outputs[0]).abs()
+        logger.info(
+            f"Actual vs expected abs diff {format_tensor_statistics(abs_diff[0])}"
+        )
         torch.testing.assert_close(actual_outputs, expected_outputs, atol=atol, rtol=0)

-    def runTestCompareDevIreeAgainstHuggingFace(
+    def runTestCompareDevIreeAgainstEager(
         self, reference_dtype: torch.dtype, target_dtype: torch.dtype, atol: float
     ):
         parameters_output_path = self._temp_dir / "parameters.irpa"
@@ -211,21 +216,12 @@ def runTestCompareTorchEagerAgainstHuggingFace(
     ):
         target_input_args, target_input_kwargs = target_model.sample_inputs()

-        reference_input_args = [
-            convert_dtype_if_dtype(
-                t, source_dtype=target_model.dtype, target_dtype=reference_dtype
-            )
-            for t in target_input_args
-        ]
-        reference_input_kwargs = OrderedDict(
-            (
-                k,
-                convert_dtype_if_dtype(
-                    t, source_dtype=target_model.dtype, target_dtype=reference_dtype
-                ),
-            )
-            for k, t in target_input_kwargs.items()
+        assert len(target_input_args) == 0
+        reference_input_args = []
+        reference_input_kwargs = convert_input_dtype(
+            target_input_kwargs, dtype=reference_dtype
         )
+
         reference_input_kwargs = convert_flux_transformer_input_for_hugging_face_model(
             *reference_input_args, **reference_input_kwargs
         )
@@ -238,18 +234,55 @@ def runTestCompareTorchEagerAgainstHuggingFace(

         torch.testing.assert_close(target_output, reference_output, atol=atol, rtol=0)

+    def runTestCompareToyIreeAgainstEager(
+        self, reference_dtype: torch.dtype, target_dtype: torch.dtype, atol: float
+    ):
+        config = make_toy_config()
+        reference_theta = make_random_theta(config, dtype=reference_dtype)
+        reference_model = FluxModelV1(theta=reference_theta, params=config)
+        self.runCompareIreeAgainstTorchEager(
+            reference_model=reference_model, target_dtype=target_dtype, atol=atol
+        )
+
+    @is_mi300x
+    def testCompareToyIreeF32AgainstEagerF64(self):
+        """atol is apparently high because the expected output range is large.
+        Its absolute maximum is 3915. Observed atol is 0.036."""
+        self.runTestCompareToyIreeAgainstEager(
+            reference_dtype=torch.float64, target_dtype=torch.float32, atol=1e-1
+        )
+
+    @skip(
+        reason=(
+            "Sporadic segmentation fault during buffer destruction."
+            " See https://github.com/nod-ai/shark-ai/issues/1050"
+        )
+    )
+    @is_mi300x
+    def testCompareToyIreeBf16AgainstEagerF64(self):
+        """atol is apparently high because the expected output range is large.
+        Its absolute maximum is 3915. Observed atol is 260.6.
+        This is consistent with the expectation that bf16 atol should be worse by ~10^4
+        compared to f32. f32 can represent ~7 digits and bf16 can represent ~3."""
+        self.runTestCompareToyIreeAgainstEager(
+            reference_dtype=torch.float64, target_dtype=torch.bfloat16, atol=5e2
+        )
+
     @with_flux_data
-    def testCompareDevIreeF32AgainstHuggingFaceF32(self):
-        self.runTestCompareDevIreeAgainstHuggingFace(
+    def testCompareDevIreeF32AgainstEagerF32(self):
+        self.runTestCompareDevIreeAgainstEager(
             reference_dtype=torch.float32, target_dtype=torch.float32, atol=1e-2
         )

-    @pytest.mark.skip(
-        reason="Segmentation fault during output comparison. See https://github.com/nod-ai/shark-ai/issues/1050"
+    @skip(
+        reason=(
+            "Sporadic segmentation fault during buffer destruction."
+            " See https://github.com/nod-ai/shark-ai/issues/1050"
+        )
     )
     @with_flux_data
-    def testCompareDevIreeBf16AgainstHuggingFaceF32(self):
-        self.runTestCompareDevIreeAgainstHuggingFace(
+    def testCompareDevIreeBf16AgainstEagerF32(self):
+        self.runTestCompareDevIreeAgainstEager(
             reference_dtype=torch.float32, target_dtype=torch.bfloat16, atol=1
         )