Skip to content

Commit 2c99ce9

Browse files
bfloat16
1 parent 3ef797a commit 2c99ce9

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

predict.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def setup(self) -> None:
3333
self.pipe = FluxControlNetPipeline.from_pretrained(
3434
MODEL_CACHE,
3535
controlnet=controlnet,
36-
torch_dtype=torch.float16
36+
torch_dtype=torch.bfloat16
3737
).to("cuda")
3838

3939
# Quantize transformer
@@ -62,7 +62,7 @@ def setup(self) -> None:
6262
self.inpaint_pipe.transformer
6363
]:
6464
if hasattr(component, "to"):
65-
component.to(dtype=torch.float16)
65+
component.to(dtype=torch.bfloat16)
6666

6767
def predict(
6868
self,

0 commit comments

Comments
 (0)