|
| 1 | +import json |
| 2 | +import sqlite3 |
| 3 | +from logging import Logger |
| 4 | +from pathlib import Path |
| 5 | +from typing import NamedTuple |
| 6 | + |
| 7 | +from pydantic import ValidationError |
| 8 | + |
| 9 | +from invokeai.app.services.config import InvokeAIAppConfig |
| 10 | +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration |
| 11 | +from invokeai.backend.model_manager.config import AnyModelConfig, AnyModelConfigValidator |
| 12 | + |
| 13 | + |
| 14 | +class NormalizeResult(NamedTuple): |
| 15 | + new_relative_path: str | None |
| 16 | + rollback_ops: list[tuple[Path, Path]] |
| 17 | + |
| 18 | + |
| 19 | +class Migration22Callback: |
| 20 | + def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None: |
| 21 | + self._app_config = app_config |
| 22 | + self._logger = logger |
| 23 | + self._models_dir = app_config.models_path.resolve() |
| 24 | + |
| 25 | + def __call__(self, cursor: sqlite3.Cursor) -> None: |
| 26 | + # Grab all model records |
| 27 | + cursor.execute("SELECT id, config FROM models;") |
| 28 | + rows = cursor.fetchall() |
| 29 | + |
| 30 | + for model_id, config_json in rows: |
| 31 | + try: |
| 32 | + # Get the model config as a pydantic object |
| 33 | + config = self._load_model_config(config_json) |
| 34 | + except ValidationError: |
| 35 | + # This could happen if the config schema changed in a way that makes old configs invalid. Unlikely |
| 36 | + # for users, more likely for devs testing out migration paths. |
| 37 | + self._logger.warning("Skipping model %s: invalid config schema", model_id) |
| 38 | + continue |
| 39 | + except json.JSONDecodeError: |
| 40 | + # This should never happen, as we use pydantic to serialize the config to JSON. |
| 41 | + self._logger.warning("Skipping model %s: invalid config JSON", model_id) |
| 42 | + continue |
| 43 | + |
| 44 | + # We'll use a savepoint so we can roll back the database update if something goes wrong, and a simple |
| 45 | + # rollback of file operations if needed. |
| 46 | + cursor.execute("SAVEPOINT migrate_model") |
| 47 | + try: |
| 48 | + new_relative_path, rollback_ops = self._normalize_model_storage( |
| 49 | + key=config.key, |
| 50 | + path_value=config.path, |
| 51 | + ) |
| 52 | + except Exception as err: |
| 53 | + self._logger.error("Error normalizing model %s: %s", config.key, err) |
| 54 | + cursor.execute("ROLLBACK TO SAVEPOINT migrate_model") |
| 55 | + cursor.execute("RELEASE SAVEPOINT migrate_model") |
| 56 | + continue |
| 57 | + |
| 58 | + if new_relative_path is None: |
| 59 | + cursor.execute("RELEASE SAVEPOINT migrate_model") |
| 60 | + continue |
| 61 | + |
| 62 | + config.path = new_relative_path |
| 63 | + try: |
| 64 | + cursor.execute( |
| 65 | + "UPDATE models SET config = ? WHERE id = ?;", |
| 66 | + (config.model_dump_json(), model_id), |
| 67 | + ) |
| 68 | + except Exception as err: |
| 69 | + self._logger.error("Database update failed for model %s: %s", config.key, err) |
| 70 | + cursor.execute("ROLLBACK TO SAVEPOINT migrate_model") |
| 71 | + cursor.execute("RELEASE SAVEPOINT migrate_model") |
| 72 | + self._rollback_file_ops(rollback_ops) |
| 73 | + continue |
| 74 | + |
| 75 | + cursor.execute("RELEASE SAVEPOINT migrate_model") |
| 76 | + |
| 77 | + self._prune_empty_directories() |
| 78 | + |
| 79 | + def _normalize_model_storage(self, key: str, path_value: str) -> NormalizeResult: |
| 80 | + models_dir = self._models_dir |
| 81 | + stored_path = Path(path_value) |
| 82 | + |
| 83 | + relative_path: Path | None |
| 84 | + if stored_path.is_absolute(): |
| 85 | + # If the stored path is absolute, we need to check if it's inside the models directory, which means it is |
| 86 | + # an Invoke-managed model. If it's outside, it is user-managed we leave it alone. |
| 87 | + try: |
| 88 | + relative_path = stored_path.resolve().relative_to(models_dir) |
| 89 | + except ValueError: |
| 90 | + self._logger.info("Leaving user-managed model %s at %s", key, stored_path) |
| 91 | + return NormalizeResult(new_relative_path=None, rollback_ops=[]) |
| 92 | + else: |
| 93 | + # Relative paths are always relative to the models directory and thus Invoke-managed. |
| 94 | + relative_path = stored_path |
| 95 | + |
| 96 | + # If the relative path is empty, assume something is wrong. Warn and skip. |
| 97 | + if not relative_path.parts: |
| 98 | + self._logger.warning("Skipping model %s: empty relative path", key) |
| 99 | + return NormalizeResult(new_relative_path=None, rollback_ops=[]) |
| 100 | + |
| 101 | + # Sanity check: the path is relative. It should be present in the models directory. |
| 102 | + absolute_path = (models_dir / relative_path).resolve() |
| 103 | + if not absolute_path.exists(): |
| 104 | + self._logger.warning( |
| 105 | + "Skipping model %s: expected model files at %s but nothing was found", |
| 106 | + key, |
| 107 | + absolute_path, |
| 108 | + ) |
| 109 | + return NormalizeResult(new_relative_path=None, rollback_ops=[]) |
| 110 | + |
| 111 | + if relative_path.parts[0] == key: |
| 112 | + # Already normalized. Still ensure the stored path is relative. |
| 113 | + normalized_path = relative_path.as_posix() |
| 114 | + # If the stored path is already the normalized path, no change is needed. |
| 115 | + new_relative_path = normalized_path if stored_path.as_posix() != normalized_path else None |
| 116 | + return NormalizeResult(new_relative_path=new_relative_path, rollback_ops=[]) |
| 117 | + |
| 118 | + # We'll store the file operations we perform so we can roll them back if needed. |
| 119 | + rollback_ops: list[tuple[Path, Path]] = [] |
| 120 | + |
| 121 | + # Destination directory is models_dir/<key> - a flat directory structure. |
| 122 | + destination_dir = models_dir / key |
| 123 | + |
| 124 | + try: |
| 125 | + if absolute_path.is_file(): |
| 126 | + destination_dir.mkdir(parents=True, exist_ok=True) |
| 127 | + dest_file = destination_dir / absolute_path.name |
| 128 | + # This really shouldn't happen. |
| 129 | + if dest_file.exists(): |
| 130 | + self._logger.warning( |
| 131 | + "Destination for model %s already exists at %s; skipping move", |
| 132 | + key, |
| 133 | + dest_file, |
| 134 | + ) |
| 135 | + return NormalizeResult(new_relative_path=None, rollback_ops=[]) |
| 136 | + |
| 137 | + self._logger.info("Moving model file %s -> %s", absolute_path, dest_file) |
| 138 | + |
| 139 | + # `Path.rename()` effectively moves the file or directory. |
| 140 | + absolute_path.rename(dest_file) |
| 141 | + rollback_ops.append((dest_file, absolute_path)) |
| 142 | + |
| 143 | + return NormalizeResult( |
| 144 | + new_relative_path=(Path(key) / dest_file.name).as_posix(), |
| 145 | + rollback_ops=rollback_ops, |
| 146 | + ) |
| 147 | + |
| 148 | + if absolute_path.is_dir(): |
| 149 | + dest_path = destination_dir |
| 150 | + # This really shouldn't happen. |
| 151 | + if dest_path.exists(): |
| 152 | + self._logger.warning( |
| 153 | + "Destination directory %s already exists for model %s; skipping", |
| 154 | + dest_path, |
| 155 | + key, |
| 156 | + ) |
| 157 | + return NormalizeResult(new_relative_path=None, rollback_ops=[]) |
| 158 | + |
| 159 | + self._logger.info("Moving model directory %s -> %s", absolute_path, dest_path) |
| 160 | + |
| 161 | + # `Path.rename()` effectively moves the file or directory. |
| 162 | + absolute_path.rename(dest_path) |
| 163 | + rollback_ops.append((dest_path, absolute_path)) |
| 164 | + |
| 165 | + return NormalizeResult( |
| 166 | + new_relative_path=Path(key).as_posix(), |
| 167 | + rollback_ops=rollback_ops, |
| 168 | + ) |
| 169 | + |
| 170 | + # Maybe a broken symlink or something else weird? |
| 171 | + self._logger.warning("Skipping model %s: path %s is neither a file nor directory", key, absolute_path) |
| 172 | + return NormalizeResult(new_relative_path=None, rollback_ops=[]) |
| 173 | + except Exception: |
| 174 | + self._rollback_file_ops(rollback_ops) |
| 175 | + raise |
| 176 | + |
| 177 | + def _rollback_file_ops(self, rollback_ops: list[tuple[Path, Path]]) -> None: |
| 178 | + # This is a super-simple rollback that just reverses the move operations we performed. |
| 179 | + for source, destination in reversed(rollback_ops): |
| 180 | + try: |
| 181 | + if source.exists(): |
| 182 | + source.rename(destination) |
| 183 | + except Exception as err: |
| 184 | + self._logger.error("Failed to rollback move %s -> %s: %s", source, destination, err) |
| 185 | + |
| 186 | + def _prune_empty_directories(self) -> None: |
| 187 | + # These directories are system directories we want to keep even if empty. Technically, the app should not |
| 188 | + # have any problems if these are removed, creating them as needed, but it's cleaner to just leave them alone. |
| 189 | + keep_names = {"model_images", ".download_cache"} |
| 190 | + keep_dirs = {self._models_dir / name for name in keep_names} |
| 191 | + removed_dirs: set[Path] = set() |
| 192 | + |
| 193 | + # Walk the models directory tree from the bottom up, removing empty directories. We sort by path length |
| 194 | + # descending to ensure we visit children before parents. |
| 195 | + for directory in sorted(self._models_dir.rglob("*"), key=lambda p: len(p.parts), reverse=True): |
| 196 | + if not directory.is_dir(): |
| 197 | + continue |
| 198 | + if directory == self._models_dir: |
| 199 | + continue |
| 200 | + if any(directory == keep or keep in directory.parents for keep in keep_dirs): |
| 201 | + continue |
| 202 | + |
| 203 | + try: |
| 204 | + next(directory.iterdir()) |
| 205 | + except StopIteration: |
| 206 | + try: |
| 207 | + directory.rmdir() |
| 208 | + removed_dirs.add(directory) |
| 209 | + self._logger.debug("Removed empty directory %s", directory) |
| 210 | + except OSError: |
| 211 | + # Directory not empty (or some other error) - bail out. |
| 212 | + self._logger.warning("Failed to prune directory %s - not empty?", directory) |
| 213 | + continue |
| 214 | + except OSError: |
| 215 | + continue |
| 216 | + |
| 217 | + self._logger.info("Pruned %d empty directories under %s", len(removed_dirs), self._models_dir) |
| 218 | + |
| 219 | + def _load_model_config(self, config_json: str) -> AnyModelConfig: |
| 220 | + # The typing of the validator says it returns Unknown, but it's really a AnyModelConfig. This utility function |
| 221 | + # just makes that clear. |
| 222 | + return AnyModelConfigValidator.validate_json(config_json) |
| 223 | + |
| 224 | + |
| 225 | +def build_migration_22(app_config: InvokeAIAppConfig, logger: Logger) -> Migration: |
| 226 | + """Builds the migration object for migrating from version 21 to version 22. |
| 227 | +
|
| 228 | + This migration normalizes on-disk model storage so that each model lives within |
| 229 | + a directory named by its key inside the Invoke-managed models directory, and |
| 230 | + updates database records to reference the new relative paths. |
| 231 | + """ |
| 232 | + |
| 233 | + return Migration( |
| 234 | + from_version=21, |
| 235 | + to_version=22, |
| 236 | + callback=Migration22Callback(app_config=app_config, logger=logger), |
| 237 | + ) |
0 commit comments