Skip to content

Commit 974074b

Browse files
feat(mm): add migration to flat model storage
1 parent c852187 commit 974074b

File tree

2 files changed

+239
-0
lines changed

2 files changed

+239
-0
lines changed

invokeai/app/services/shared/sqlite/sqlite_util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_19 import build_migration_19
2525
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_20 import build_migration_20
2626
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_21 import build_migration_21
27+
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_22 import build_migration_22
2728
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
2829

2930

@@ -65,6 +66,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
6566
migrator.register_migration(build_migration_19(app_config=config))
6667
migrator.register_migration(build_migration_20())
6768
migrator.register_migration(build_migration_21())
69+
migrator.register_migration(build_migration_22(app_config=config, logger=logger))
6870
migrator.run_migrations()
6971

7072
return db
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
import json
2+
import sqlite3
3+
from logging import Logger
4+
from pathlib import Path
5+
from typing import NamedTuple
6+
7+
from pydantic import ValidationError
8+
9+
from invokeai.app.services.config import InvokeAIAppConfig
10+
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
11+
from invokeai.backend.model_manager.config import AnyModelConfig, AnyModelConfigValidator
12+
13+
14+
class NormalizeResult(NamedTuple):
15+
new_relative_path: str | None
16+
rollback_ops: list[tuple[Path, Path]]
17+
18+
19+
class Migration22Callback:
    """Migration callback that moves Invoke-managed models into a flat ``<models_dir>/<key>`` layout.

    For every row of the ``models`` table the callback parses the stored config, moves the model's file or
    directory so that it lives under ``<models_dir>/<key>``, and rewrites the config's ``path`` to the new
    location relative to the models directory. Models that live outside the models directory are left alone.
    Each model is processed inside its own savepoint, so a failure only skips that one model.
    """

    def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
        """Store the app config and logger, and resolve the models directory once up front."""
        self._app_config = app_config
        self._logger = logger
        self._models_dir = app_config.models_path.resolve()

    def __call__(self, cursor: sqlite3.Cursor) -> None:
        """Run the migration.

        Args:
            cursor: Cursor used to read every model record and to write back updated configs.
        """
        # Grab all model records
        cursor.execute("SELECT id, config FROM models;")
        rows = cursor.fetchall()

        for model_id, config_json in rows:
            try:
                # Get the model config as a pydantic object
                config = self._load_model_config(config_json)
            except ValidationError:
                # This could happen if the config schema changed in a way that makes old configs invalid. Unlikely
                # for users, more likely for devs testing out migration paths.
                self._logger.warning("Skipping model %s: invalid config schema", model_id)
                continue
            except json.JSONDecodeError:
                # This should never happen, as we use pydantic to serialize the config to JSON.
                self._logger.warning("Skipping model %s: invalid config JSON", model_id)
                continue

            # We'll use a savepoint so we can roll back the database update if something goes wrong, and a simple
            # rollback of file operations if needed.
            cursor.execute("SAVEPOINT migrate_model")
            try:
                new_relative_path, rollback_ops = self._normalize_model_storage(
                    key=config.key,
                    path_value=config.path,
                )
            except Exception as err:
                # `_normalize_model_storage` has already undone its own file moves before re-raising.
                self._logger.error("Error normalizing model %s: %s", config.key, err)
                cursor.execute("ROLLBACK TO SAVEPOINT migrate_model")
                cursor.execute("RELEASE SAVEPOINT migrate_model")
                continue

            if new_relative_path is None:
                # Nothing to record for this model (skipped, user-managed or already up to date).
                cursor.execute("RELEASE SAVEPOINT migrate_model")
                continue

            config.path = new_relative_path
            try:
                cursor.execute(
                    "UPDATE models SET config = ? WHERE id = ?;",
                    (config.model_dump_json(), model_id),
                )
            except Exception as err:
                self._logger.error("Database update failed for model %s: %s", config.key, err)
                cursor.execute("ROLLBACK TO SAVEPOINT migrate_model")
                cursor.execute("RELEASE SAVEPOINT migrate_model")
                # The record still points at the old location, so put the files back there.
                self._rollback_file_ops(rollback_ops)
                continue

            cursor.execute("RELEASE SAVEPOINT migrate_model")

        self._prune_empty_directories()

    def _normalize_model_storage(self, key: str, path_value: str) -> NormalizeResult:
        """Move one model into ``<models_dir>/<key>`` and report the path that should be stored for it.

        Args:
            key: The model's key; used as the name of its destination directory.
            path_value: The path currently stored in the model's config (absolute, or relative to the models
                directory).

        Returns:
            A ``NormalizeResult`` whose ``new_relative_path`` is the POSIX-style path to store, or ``None`` when
            the stored path must stay as it is (model skipped, user-managed, or already correct), and whose
            ``rollback_ops`` lists the moves performed so the caller can undo them.

        Raises:
            Exception: Any error raised while moving files is re-raised after the moves made so far were undone.
        """
        models_dir = self._models_dir
        stored_path = Path(path_value)

        relative_path: Path
        if stored_path.is_absolute():
            # If the stored path is absolute, we need to check if it's inside the models directory, which means it
            # is an Invoke-managed model. If it's outside, it is user-managed and we leave it alone.
            try:
                relative_path = stored_path.resolve().relative_to(models_dir)
            except ValueError:
                self._logger.info("Leaving user-managed model %s at %s", key, stored_path)
                return NormalizeResult(new_relative_path=None, rollback_ops=[])
        else:
            # Relative paths are always relative to the models directory and thus Invoke-managed.
            relative_path = stored_path

        # If the relative path is empty, assume something is wrong. Warn and skip.
        if not relative_path.parts:
            self._logger.warning("Skipping model %s: empty relative path", key)
            return NormalizeResult(new_relative_path=None, rollback_ops=[])

        # Sanity check: the path is relative. It should be present in the models directory.
        absolute_path = (models_dir / relative_path).resolve()
        if not absolute_path.exists():
            self._logger.warning(
                "Skipping model %s: expected model files at %s but nothing was found",
                key,
                absolute_path,
            )
            return NormalizeResult(new_relative_path=None, rollback_ops=[])

        if relative_path.parts[0] == key:
            # Already normalized. Still ensure the stored path is relative.
            normalized_path = relative_path.as_posix()
            # If the stored path is already the normalized path, no change is needed.
            new_relative_path = normalized_path if stored_path.as_posix() != normalized_path else None
            return NormalizeResult(new_relative_path=new_relative_path, rollback_ops=[])

        # A relative path can still resolve to somewhere outside the models directory (".." segments, or a
        # symlink pointing elsewhere). Moving such a target would relocate files Invoke does not own, so treat
        # it like any other user-managed model and leave it alone.
        if not absolute_path.is_relative_to(models_dir):
            self._logger.info("Leaving user-managed model %s at %s", key, absolute_path)
            return NormalizeResult(new_relative_path=None, rollback_ops=[])

        # We'll store the file operations we perform so we can roll them back if needed.
        rollback_ops: list[tuple[Path, Path]] = []

        # Destination directory is models_dir/<key> - a flat directory structure.
        destination_dir = models_dir / key

        try:
            if absolute_path.is_file():
                destination_dir.mkdir(parents=True, exist_ok=True)
                dest_file = destination_dir / absolute_path.name
                # This really shouldn't happen.
                if dest_file.exists():
                    self._logger.warning(
                        "Destination for model %s already exists at %s; skipping move",
                        key,
                        dest_file,
                    )
                    return NormalizeResult(new_relative_path=None, rollback_ops=[])

                self._logger.info("Moving model file %s -> %s", absolute_path, dest_file)

                # `Path.rename()` effectively moves the file or directory.
                absolute_path.rename(dest_file)
                rollback_ops.append((dest_file, absolute_path))

                return NormalizeResult(
                    new_relative_path=(Path(key) / dest_file.name).as_posix(),
                    rollback_ops=rollback_ops,
                )

            if absolute_path.is_dir():
                dest_path = destination_dir
                # This really shouldn't happen.
                if dest_path.exists():
                    self._logger.warning(
                        "Destination directory %s already exists for model %s; skipping",
                        dest_path,
                        key,
                    )
                    return NormalizeResult(new_relative_path=None, rollback_ops=[])

                self._logger.info("Moving model directory %s -> %s", absolute_path, dest_path)

                # `Path.rename()` effectively moves the file or directory.
                absolute_path.rename(dest_path)
                rollback_ops.append((dest_path, absolute_path))

                return NormalizeResult(
                    new_relative_path=Path(key).as_posix(),
                    rollback_ops=rollback_ops,
                )

            # Maybe a broken symlink or something else weird?
            self._logger.warning("Skipping model %s: path %s is neither a file nor directory", key, absolute_path)
            return NormalizeResult(new_relative_path=None, rollback_ops=[])
        except Exception:
            self._rollback_file_ops(rollback_ops)
            raise

    def _rollback_file_ops(self, rollback_ops: list[tuple[Path, Path]]) -> None:
        """Undo previously performed moves, most recent first.

        Args:
            rollback_ops: ``(current_location, original_location)`` pairs in the order the moves were made.

        Failures are logged and do not stop the remaining moves from being undone.
        """
        # This is a super-simple rollback that just reverses the move operations we performed.
        for source, destination in reversed(rollback_ops):
            try:
                if source.exists():
                    source.rename(destination)
            except Exception as err:
                self._logger.error("Failed to rollback move %s -> %s: %s", source, destination, err)

    def _prune_empty_directories(self) -> None:
        """Remove empty directories below the models directory and log how many were removed.

        The ``model_images`` and ``.download_cache`` directories (and anything below them) are never removed,
        and symlinks are never touched.
        """
        # These directories are system directories we want to keep even if empty. Technically, the app should not
        # have any problems if these are removed, creating them as needed, but it's cleaner to just leave them alone.
        keep_names = {"model_images", ".download_cache"}
        keep_dirs = {self._models_dir / name for name in keep_names}
        removed_count = 0

        # Walk the models directory tree from the bottom up, removing empty directories. We sort by path length
        # descending to ensure we visit children before parents.
        candidates = sorted(self._models_dir.rglob("*"), key=lambda p: len(p.parts), reverse=True)
        for directory in candidates:
            # A symlink to a directory also reports `is_dir()`, but it is not a directory we created: `rmdir()`
            # on it either fails or removes the user's link, so skip symlinks entirely.
            if directory.is_symlink() or not directory.is_dir():
                continue
            if any(directory == keep or keep in directory.parents for keep in keep_dirs):
                continue

            try:
                is_empty = not any(directory.iterdir())
            except OSError:
                # Could not list the directory - leave it alone.
                continue
            if not is_empty:
                continue

            try:
                directory.rmdir()
            except OSError:
                # Directory not empty (or some other error) - move on to the next one.
                self._logger.warning("Failed to prune directory %s - not empty?", directory)
                continue

            removed_count += 1
            self._logger.debug("Removed empty directory %s", directory)

        self._logger.info("Pruned %d empty directories under %s", removed_count, self._models_dir)

    def _load_model_config(self, config_json: str) -> AnyModelConfig:
        """Parse a stored config JSON string into a model config object."""
        # The typing of the validator says it returns Unknown, but it's really a AnyModelConfig. This utility function
        # just makes that clear.
        return AnyModelConfigValidator.validate_json(config_json)
223+
224+
225+
def build_migration_22(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
    """Create the migration that takes the database from version 21 to version 22.

    The migration normalizes on-disk model storage: every model ends up in a directory
    named after its key inside the Invoke-managed models directory, and the database
    records are updated to reference the new relative paths.
    """
    flatten_model_storage = Migration22Callback(app_config=app_config, logger=logger)
    return Migration(from_version=21, to_version=22, callback=flatten_model_storage)

0 commit comments

Comments
 (0)