Commit 68dbe45

RyanJDick authored and hipsterusername committed
Fix regression with FLUX diffusers LoRA models where LoRA keys were not given the expected prefix.
1 parent bd3d1dc commit 68dbe45

2 files changed: +6 -1 lines changed

invokeai/backend/lora/conversions/flux_diffusers_lora_conversion_utils.py

Lines changed: 4 additions & 1 deletion

@@ -2,6 +2,7 @@
 
 import torch
 
+from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
 from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
 from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
 from invokeai.backend.lora.layers.lora_layer import LoRALayer

@@ -189,7 +190,9 @@ def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None
     # Assert that all keys were processed.
     assert len(grouped_state_dict) == 0
 
-    return LoRAModelRaw(layers=layers)
+    layers_with_prefix = {f"{FLUX_KOHYA_TRANFORMER_PREFIX}{k}": v for k, v in layers.items()}
+
+    return LoRAModelRaw(layers=layers_with_prefix)
 
 
 def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
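The fix itself is the one-line re-keying shown above. A minimal sketch of the behavior, assuming a placeholder prefix value (the real FLUX_KOHYA_TRANFORMER_PREFIX is defined in flux_kohya_lora_conversion_utils.py and is not reproduced here):

# Minimal sketch of the re-keying, with an assumed placeholder prefix;
# the real FLUX_KOHYA_TRANFORMER_PREFIX lives in
# invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py.
FLUX_KOHYA_TRANFORMER_PREFIX = "lora_transformer_"  # placeholder, not the real value

# Hypothetical converted-layer dict as produced by the diffusers conversion path.
layers = {
    "single_transformer_blocks.0.attn.to_q": "<AnyLoRALayer>",
    "transformer_blocks.0.attn.to_out.0": "<AnyLoRALayer>",
}

# Same comprehension as the fix: prepend the shared transformer prefix to every
# layer key so diffusers-format LoRAs produce the same keys as the kohya path.
layers_with_prefix = {f"{FLUX_KOHYA_TRANFORMER_PREFIX}{k}": v for k, v in layers.items()}

assert all(k.startswith(FLUX_KOHYA_TRANFORMER_PREFIX) for k in layers_with_prefix)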

tests/backend/lora/conversions/test_flux_diffusers_lora_conversion_utils.py

Lines changed: 2 additions & 0 deletions

@@ -5,6 +5,7 @@
     is_state_dict_likely_in_flux_diffusers_format,
     lora_model_from_flux_diffusers_state_dict,
 )
+from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
     state_dict_keys as flux_diffusers_state_dict_keys,
 )

@@ -50,6 +51,7 @@ def test_lora_model_from_flux_diffusers_state_dict():
     concatenated_weights = ["to_k", "to_v", "proj_mlp", "add_k_proj", "add_v_proj"]
     expected_lora_layers = {k for k in expected_lora_layers if not any(w in k for w in concatenated_weights)}
     assert len(model.layers) == len(expected_lora_layers)
+    assert all(k.startswith(FLUX_KOHYA_TRANFORMER_PREFIX) for k in model.layers.keys())
 
 
 def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():
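To exercise the regression check locally, a standard pytest invocation against the updated test module should suffice:

pytest tests/backend/lora/conversions/test_flux_diffusers_lora_conversion_utils.py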
