ValueError: FlatParameter requires uniform dtype but got torch.bfloat16 and torch.float32 #82

@J-zin

When training SDXL with FSDP, the error is raised when calling:

self.model.feedforward_model, self.model.guidance_model = accelerator.prepare(
    self.model.feedforward_model, self.model.guidance_model
)

I think this happens because, in self.model.guidance_model, real_unet is in half precision (the error reports torch.bfloat16) while fake_unet is float32, though it is most likely related to the package versions.
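One way to check this hypothesis, plus one possible workaround, is sketched here. It assumes that FSDP puts real_unet and fake_unet into the same flattened unit, which is what the error message suggests. `dtypes_by_child`, `unify_param_dtype` and `ToyGuidanceModel` are hypothetical names, not part of DMD2 or accelerate, and the toy model uses `nn.Linear` layers as stand-ins for the two UNets. Casting everything to float32 before `accelerator.prepare(...)` gives FlatParameter a uniform dtype. bf16 compute would then have to come from the FSDP/accelerate mixed-precision settings, at the cost of storing real_unet in fp32.

```python
from collections import defaultdict

import torch
import torch.nn as nn


def dtypes_by_child(model: nn.Module) -> dict:
    """Map each direct child module name to the set of parameter dtypes it holds."""
    result = defaultdict(set)
    for name, param in model.named_parameters():
        child = name.split(".", 1)[0]  # e.g. "real_unet.weight" -> "real_unet"
        result[child].add(param.dtype)
    return dict(result)


def unify_param_dtype(model: nn.Module, dtype: torch.dtype = torch.float32) -> nn.Module:
    """Cast all floating-point parameters and buffers to a single dtype (in place).

    Assumption: FSDP in torch 2.0.x needs one dtype per FlatParameter, so a
    uniform dtype before wrapping should avoid the ValueError.
    """
    return model.to(dtype)  # integer buffers are left untouched by Module.to(dtype)


class ToyGuidanceModel(nn.Module):
    # Hypothetical stand-in: small Linear layers instead of the SDXL UNets.
    def __init__(self):
        super().__init__()
        self.real_unet = nn.Linear(4, 4).to(torch.bfloat16)  # frozen teacher in bf16
        self.fake_unet = nn.Linear(4, 4)                       # trainable copy in fp32


model = ToyGuidanceModel()
print(dtypes_by_child(model))  # mixed dtypes, which FSDP would reject
unify_param_dtype(model)
print(dtypes_by_child(model))  # every child now reports torch.float32
```

In the training script the cast would go right before the prepare call, e.g. `unify_param_dtype(self.model.guidance_model)`; whether the extra fp32 memory for real_unet is acceptable for SDXL has not been tested here.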

Here are my installed packages. Any ideas on how to fix this?

accelerate               0.23.0
aiofiles                 23.2.1
aiohappyeyeballs         2.4.4
aiohttp                  3.10.11
aiosignal                1.3.1
annotated-doc            0.0.4
annotated-types          0.7.0
anyio                    4.5.2
async-timeout            5.0.1
attrs                    25.3.0
certifi                  2025.11.12
charset-normalizer       3.4.4
clean-fid                0.1.35
click                    8.1.8
clip                     1.0
cmake                    4.2.0
contourpy                1.1.1
cycler                   0.12.1
datasets                 3.0.1
diffusers                0.28.2
dill                     0.3.8
DMD2                     0.0.1       /home/DMD2
eval_type_backport       0.3.1
evaluate                 0.4.6
exceptiongroup           1.3.1
fairscale                0.4.13
fastapi                  0.124.4
ffmpy                    0.5.0
filelock                 3.16.1
fonttools                4.57.0
frozenlist               1.5.0
fsspec                   2024.6.1
ftfy                     6.2.3
gitdb                    4.0.12
GitPython                3.1.45
gradio                   4.44.1
gradio_client            1.3.0
h11                      0.16.0
hf-xet                   1.2.0
httpcore                 1.0.9
httpx                    0.28.1
huggingface-hub          0.22.0
idna                     3.11
image-reward             1.5
imageio                  2.35.1
importlib_metadata       8.5.0
importlib_resources      6.4.5
Jinja2                   3.1.6
kiwisolver               1.4.7
lit                      18.1.8
lmdb                     1.7.5
markdown-it-py           3.0.0
MarkupSafe               2.1.5
matplotlib               3.7.5
mdurl                    0.1.2
mpmath                   1.3.0
multidict                6.1.0
multiprocess             0.70.16
networkx                 3.1
numpy                    1.24.4
nvidia-cublas-cu11       11.10.3.66
nvidia-cuda-cupti-cu11   11.7.101
nvidia-cuda-nvrtc-cu11   11.7.99
nvidia-cuda-runtime-cu11 11.7.99
nvidia-cudnn-cu11        8.5.0.96
nvidia-cufft-cu11        10.9.0.58
nvidia-curand-cu11       10.2.10.91
nvidia-cusolver-cu11     11.4.0.1
nvidia-cusparse-cu11     11.7.4.91
nvidia-nccl-cu11         2.14.3
nvidia-nvtx-cu11         11.7.91
open_clip_torch          2.32.0
opencv-python            4.12.0.88
orjson                   3.10.15
packaging                25.0
pandas                   2.0.3
peft                     0.12.0
pillow                   10.4.0
pip                      24.2
piq                      0.7.0
platformdirs             4.3.6
propcache                0.2.0
protobuf                 5.29.5
psutil                   7.1.3
pyarrow                  17.0.0
pydantic                 2.10.6
pydantic_core            2.27.2
pydub                    0.25.1
Pygments                 2.19.2
pyparsing                3.1.4
python-dateutil          2.9.0.post0
python-multipart         0.0.20
pytz                     2025.2
PyYAML                   6.0.3
regex                    2024.11.6
requests                 2.32.4
rich                     14.2.0
ruff                     0.14.10
safetensors              0.5.3
scipy                    1.10.1
semantic-version         2.10.0
sentry-sdk               2.48.0
setuptools               75.1.0
shellingham              1.5.4
six                      1.17.0
smmap                    5.0.2
sniffio                  1.3.1
starlette                0.44.0
sympy                    1.13.3
timm                     0.6.13
tokenizers               0.19.1
tomlkit                  0.12.0
torch                    2.0.1
torchvision              0.15.2
tqdm                     4.67.1
transformers             4.40.2
triton                   2.0.0
typer                    0.20.0
typing_extensions        4.13.2
tzdata                   2025.3
urllib3                  2.2.3
uvicorn                  0.33.0
wandb                    0.23.1
wcwidth                  0.2.14
websockets               12.0
wheel                    0.44.0
xxhash                   3.6.0
yarl                     1.15.2
zipp                     3.20.2
