Skip to content

Commit 807f458

Browse files
RyanJDick authored and hipsterusername committed
Move FLUX_LORA_TRANSFORMER_PREFIX and FLUX_LORA_CLIP_PREFIX to a shared location.
1 parent 68dbe45 commit 807f458

File tree

7 files changed

+18
-20
lines changed

7 files changed

+18
-20
lines changed

invokeai/app/invocations/flux_denoise.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
pack,
3131
unpack,
3232
)
33-
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
33+
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
3434
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
3535
from invokeai.backend.lora.lora_patcher import LoRAPatcher
3636
from invokeai.backend.model_manager.config import ModelFormat
@@ -209,7 +209,7 @@ def _run_diffusion(
209209
LoRAPatcher.apply_lora_patches(
210210
model=transformer,
211211
patches=self._lora_iterator(context),
212-
prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
212+
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
213213
cached_weights=cached_weights,
214214
)
215215
)
@@ -220,7 +220,7 @@ def _run_diffusion(
220220
LoRAPatcher.apply_lora_sidecar_patches(
221221
model=transformer,
222222
patches=self._lora_iterator(context),
223-
prefix=FLUX_KOHYA_TRANFORMER_PREFIX,
223+
prefix=FLUX_LORA_TRANSFORMER_PREFIX,
224224
dtype=inference_dtype,
225225
)
226226
)

invokeai/app/invocations/flux_text_encoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from invokeai.app.invocations.primitives import FluxConditioningOutput
1111
from invokeai.app.services.shared.invocation_context import InvocationContext
1212
from invokeai.backend.flux.modules.conditioner import HFEncoder
13-
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_CLIP_PREFIX
13+
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
1414
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
1515
from invokeai.backend.lora.lora_patcher import LoRAPatcher
1616
from invokeai.backend.model_manager.config import ModelFormat
@@ -101,7 +101,7 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
101101
LoRAPatcher.apply_lora_patches(
102102
model=clip_text_encoder,
103103
patches=self._clip_lora_iterator(context),
104-
prefix=FLUX_KOHYA_CLIP_PREFIX,
104+
prefix=FLUX_LORA_CLIP_PREFIX,
105105
cached_weights=cached_weights,
106106
)
107107
)

invokeai/backend/lora/conversions/flux_diffusers_lora_conversion_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44

5-
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
5+
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
66
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
77
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
88
from invokeai.backend.lora.layers.lora_layer import LoRALayer
@@ -190,7 +190,7 @@ def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None
190190
# Assert that all keys were processed.
191191
assert len(grouped_state_dict) == 0
192192

193-
layers_with_prefix = {f"{FLUX_KOHYA_TRANFORMER_PREFIX}{k}": v for k, v in layers.items()}
193+
layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
194194

195195
return LoRAModelRaw(layers=layers_with_prefix)
196196

invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55

6+
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_TRANSFORMER_PREFIX
67
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
78
from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict
89
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
@@ -23,11 +24,6 @@
2324
FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*"
2425

2526

26-
# Prefixes used to distinguish between transformer and CLIP text encoder keys in the InvokeAI LoRA format.
27-
FLUX_KOHYA_TRANFORMER_PREFIX = "lora_transformer-"
28-
FLUX_KOHYA_CLIP_PREFIX = "lora_clip-"
29-
30-
3127
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
3228
"""Checks if the provided state dict is likely in the Kohya FLUX LoRA format.
3329
@@ -67,9 +63,9 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
6763
# Create LoRA layers.
6864
layers: dict[str, AnyLoRALayer] = {}
6965
for layer_key, layer_state_dict in transformer_grouped_sd.items():
70-
layers[FLUX_KOHYA_TRANFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
66+
layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
7167
for layer_key, layer_state_dict in clip_grouped_sd.items():
72-
layers[FLUX_KOHYA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
68+
layers[FLUX_LORA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
7369

7470
# Create and return the LoRAModelRaw.
7571
return LoRAModelRaw(layers=layers)
invokeai/backend/lora/conversions/flux_lora_constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Prefixes used to distinguish between transformer and CLIP text encoder keys in the FLUX InvokeAI LoRA format.
2+
FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"
3+
FLUX_LORA_CLIP_PREFIX = "lora_clip-"

tests/backend/lora/conversions/test_flux_diffusers_lora_conversion_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
is_state_dict_likely_in_flux_diffusers_format,
66
lora_model_from_flux_diffusers_state_dict,
77
)
8-
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX
8+
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
99
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
1010
state_dict_keys as flux_diffusers_state_dict_keys,
1111
)
@@ -51,7 +51,7 @@ def test_lora_model_from_flux_diffusers_state_dict():
5151
concatenated_weights = ["to_k", "to_v", "proj_mlp", "add_k_proj", "add_v_proj"]
5252
expected_lora_layers = {k for k in expected_lora_layers if not any(w in k for w in concatenated_weights)}
5353
assert len(model.layers) == len(expected_lora_layers)
54-
assert all(k.startswith(FLUX_KOHYA_TRANFORMER_PREFIX) for k in model.layers.keys())
54+
assert all(k.startswith(FLUX_LORA_TRANSFORMER_PREFIX) for k in model.layers.keys())
5555

5656

5757
def test_lora_model_from_flux_diffusers_state_dict_extra_keys_error():

tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
from invokeai.backend.flux.model import Flux
66
from invokeai.backend.flux.util import params
77
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
8-
FLUX_KOHYA_CLIP_PREFIX,
9-
FLUX_KOHYA_TRANFORMER_PREFIX,
108
_convert_flux_transformer_kohya_state_dict_to_invoke_format,
119
is_state_dict_likely_in_flux_kohya_format,
1210
lora_model_from_flux_kohya_state_dict,
1311
)
12+
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX, FLUX_LORA_TRANSFORMER_PREFIX
1413
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
1514
state_dict_keys as flux_diffusers_state_dict_keys,
1615
)
@@ -95,8 +94,8 @@ def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
9594
expected_layer_keys: set[str] = set()
9695
for k in sd_keys:
9796
# Replace prefixes.
98-
k = k.replace("lora_unet_", FLUX_KOHYA_TRANFORMER_PREFIX)
99-
k = k.replace("lora_te1_", FLUX_KOHYA_CLIP_PREFIX)
97+
k = k.replace("lora_unet_", FLUX_LORA_TRANSFORMER_PREFIX)
98+
k = k.replace("lora_te1_", FLUX_LORA_CLIP_PREFIX)
10099
# Remove suffixes.
101100
k = k.replace(".lora_up.weight", "")
102101
k = k.replace(".lora_down.weight", "")

0 commit comments

Comments (0)