diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py
index 175197984a9..ed29b6589a9 100644
--- a/invokeai/backend/model_manager/__init__.py
+++ b/invokeai/backend/model_manager/__init__.py
@@ -5,6 +5,7 @@
     AnyModelConfig,
     BaseModelType,
     InvalidModelConfigException,
+    ModelConfigBase,
     ModelConfigFactory,
     ModelFormat,
     ModelRepoVariant,
@@ -32,4 +33,5 @@
     "ModelVariantType",
     "SchedulerPredictionType",
     "SubModelType",
+    "ModelConfigBase",
 ]
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 869536beba5..fd9ef9505dc 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -25,23 +25,26 @@
 import time
 from abc import ABC, abstractmethod
 from enum import Enum
-from functools import cached_property
 from inspect import isabstract
 from pathlib import Path
 from typing import ClassVar, Literal, Optional, TypeAlias, Union
 
 import diffusers
 import onnxruntime as ort
+import safetensors.torch
 import torch
 from diffusers.models.modeling_utils import ModelMixin
+from picklescan.scanner import scan_file_path
 from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
 from typing_extensions import Annotated, Any, Dict
 
 from invokeai.app.util.misc import uuid_string
 from invokeai.backend.model_hash.hash_validator import validate_hash
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
+from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
 from invokeai.backend.raw_model import RawModel
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
+from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 logger = logging.getLogger(__name__)
@@ -215,12 +218,37 @@ def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
         self.name = path.name
         self.hash_algo = hash_algo
 
-    @cached_property
     def hash(self):
         return ModelHash(algorithm=self.hash_algo).hash(self.path)
 
-    def lazy_load_state_dict(self) -> dict[str, torch.Tensor]:
-        raise NotImplementedError()
+    def size(self):
+        if self.format_type == ModelFormat.Checkpoint:
+            return self.path.stat().st_size
+        return sum(file.stat().st_size for file in self.path.rglob("*"))
+
+    def component_paths(self):
+        if self.format_type == ModelFormat.Checkpoint:
+            return {self.path}
+        extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
+        return {f for f in self.path.rglob("*") if f.suffix in extensions}
+
+    @staticmethod
+    def load_state_dict(path: Path):
+        with SilenceWarnings():
+            if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
+                scan_result = scan_file_path(path)
+                if scan_result.infected_files != 0 or scan_result.scan_err:
+                    raise RuntimeError(f"The model {path.stem} is potentially infected by malware. Aborting import.")
+                checkpoint = torch.load(path, map_location="cpu")
+            elif path.suffix.endswith(".gguf"):
+                checkpoint = gguf_sd_loader(path, compute_dtype=torch.float32)
+            elif path.suffix.endswith(".safetensors"):
+                checkpoint = safetensors.torch.load_file(path)
+            else:
+                raise ValueError(f"Unrecognized model extension: {path.suffix}")
+
+            state_dict = checkpoint.get("state_dict", checkpoint)
+            return state_dict
 
 
 class MatchSpeed(int, Enum):
@@ -343,7 +371,7 @@ def from_model_on_disk(cls, mod: ModelOnDisk, **overrides):
         fields["source"] = fields.get("source") or fields["path"]
         fields["source_type"] = fields.get("source_type") or ModelSourceType.Path
         fields["name"] = mod.name
-        fields["hash"] = fields.get("hash") or mod.hash
+        fields["hash"] = fields.get("hash") or mod.hash()
         fields.update(overrides)
 
         return cls(**fields)
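Illustrative sketch (not part of the diff) of how the new ModelOnDisk helpers above might be used; the model path is hypothetical.

    from pathlib import Path

    from invokeai.backend.model_manager.config import ModelOnDisk

    mod = ModelOnDisk(Path("/models/v1-5-pruned-emaonly.safetensors"))  # hash_algo defaults to "blake3_single"
    print(mod.name, mod.hash(), mod.size())  # hash() is now a plain method rather than a cached property

    for component in mod.component_paths():
        # load_state_dict picklescans .ckpt/.pt/.pth/.bin files and also handles .gguf and .safetensors
        state_dict = ModelOnDisk.load_state_dict(component)
        print(component.name, len(state_dict))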
diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py
index a758dc753ab..dded12e497a 100644
--- a/invokeai/backend/model_manager/legacy_probe.py
+++ b/invokeai/backend/model_manager/legacy_probe.py
@@ -3,10 +3,10 @@
 from pathlib import Path
 from typing import Any, Callable, Dict, Literal, Optional, Union
 
+import picklescan.scanner as pscan
 import safetensors.torch
 import spandrel
 import torch
-from picklescan.scanner import scan_file_path
 
 import invokeai.backend.util.logging as logger
 from invokeai.app.util.misc import uuid_string
@@ -483,7 +483,7 @@ def _scan_model(cls, model_name: str, checkpoint: Path) -> None:
         and option to exit if an infected file is identified.
         """
         # scan model
-        scan_result = scan_file_path(checkpoint)
+        scan_result = pscan.scan_file_path(checkpoint)
         if scan_result.infected_files != 0:
             raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")
         if scan_result.scan_err:
diff --git a/invokeai/backend/model_manager/util/model_util.py b/invokeai/backend/model_manager/util/model_util.py
index 1a74152882b..4fc54c34f19 100644
--- a/invokeai/backend/model_manager/util/model_util.py
+++ b/invokeai/backend/model_manager/util/model_util.py
@@ -4,9 +4,9 @@
 from pathlib import Path
 from typing import Dict, Optional, Union
 
+import picklescan.scanner as pscan
 import safetensors
 import torch
-from picklescan.scanner import scan_file_path
 
 from invokeai.backend.model_manager.config import ClipVariantType
 from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
@@ -57,7 +57,7 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str, torch.Tensor]:
         checkpoint = gguf_sd_loader(Path(path), compute_dtype=torch.float32)
     else:
         if scan:
-            scan_result = scan_file_path(path)
+            scan_result = pscan.scan_file_path(path)
             if scan_result.infected_files != 0:
                 raise Exception(f"The model at {path} is potentially infected by malware. Aborting import.")
             if scan_result.scan_err:
diff --git a/pyproject.toml b/pyproject.toml
index 58aac54adfa..1eaabbdfed0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -95,6 +95,7 @@ dependencies = [
     "semver~=3.0.1",
     "test-tube",
     "windows-curses; sys_platform=='win32'",
+    "humanize==4.12.1",
 ]
 
 [project.optional-dependencies]
@@ -103,6 +104,7 @@ dependencies = [
     "xformers>=0.0.28.post1; sys_platform!='darwin'",
     # torch 2.4+cu carries its own triton dependency
 ]
+
 "onnx" = ["onnxruntime"]
 "onnx-cuda" = ["onnxruntime-gpu"]
 "onnx-directml" = ["onnxruntime-directml"]
diff --git a/scripts/probe-model.py b/scripts/classify-model.py
similarity index 60%
rename from scripts/probe-model.py
rename to scripts/classify-model.py
index c04fd5c961d..5cc2e122327 100755
--- a/scripts/probe-model.py
+++ b/scripts/classify-model.py
@@ -7,7 +7,7 @@
 from typing import get_args
 
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
-from invokeai.backend.model_manager import InvalidModelConfigException, ModelProbe
+from invokeai.backend.model_manager import InvalidModelConfigException, ModelConfigBase, ModelProbe
 
 algos = ", ".join(set(get_args(HASHING_ALGORITHMS)))
 
@@ -25,9 +25,17 @@
 )
 args = parser.parse_args()
 
+
+def classify_with_fallback(path: Path, hash_algo: HASHING_ALGORITHMS):
+    try:
+        return ModelConfigBase.classify(path, hash_algo)
+    except InvalidModelConfigException:
+        return ModelProbe.probe(path, hash_algo=hash_algo)
+
+
 for path in args.model_path:
     try:
-        info = ModelProbe.probe(path, hash_algo=args.hash_algo)
-        print(f"{path}:{info.model_dump_json(indent=4)}")
-    except InvalidModelConfigException as exc:
-        print(exc)
+        config = classify_with_fallback(path, args.hash_algo)
+        print(f"{path}:{config.model_dump_json(indent=4)}")
+    except InvalidModelConfigException as e:
+        print(e)
diff --git a/scripts/strip_models.py b/scripts/strip_models.py
new file mode 100644
index 00000000000..7756cce736b
--- /dev/null
+++ b/scripts/strip_models.py
@@ -0,0 +1,115 @@
+"""
+Usage:
+    strip_models.py <models_input_dir> <stripped_output_dir>
+
+    Strips tensor data from model state_dicts while preserving metadata.
+    Used to create lightweight models for testing model classification.
+
+Parameters:
+    <models_input_dir>       Directory containing original models.
+    <stripped_output_dir>    Directory where stripped models will be saved.
+
+Options:
+    -h, --help               Show this help message and exit
+"""
+
+import argparse
+import json
+import shutil
+import sys
+from pathlib import Path
+
+import humanize
+import torch
+
+from invokeai.backend.model_manager.config import ModelFormat, ModelOnDisk
+from invokeai.backend.model_manager.search import ModelSearch
+
+
+def strip(v):
+    """Replace tensors with lightweight JSON-serializable stubs that keep only shape and dtype."""
+    match v:
+        case torch.Tensor():
+            return {"shape": v.shape, "dtype": str(v.dtype), "fakeTensor": True}
+        case dict():
+            return {k: strip(v) for k, v in v.items()}
+        case list() | tuple():
+            return [strip(x) for x in v]
+        case _:
+            return v
+
+
+STR_TO_DTYPE = {str(dtype): dtype for dtype in torch.__dict__.values() if isinstance(dtype, torch.dtype)}
+
+
+def dress(v):
+    """Inverse of strip: rebuild (uninitialized) tensors of the recorded shape and dtype."""
+    match v:
+        case {"shape": shape, "dtype": dtype_str, "fakeTensor": True}:
+            dtype = STR_TO_DTYPE[dtype_str]
+            return torch.empty(shape, dtype=dtype)
+        case dict():
+            return {k: dress(v) for k, v in v.items()}
+        case list() | tuple():
+            return [dress(x) for x in v]
+        case _:
+            return v
+
+
+def load_stripped_model(path: Path, *args, **kwargs):
+    with open(path, "r") as f:
+        contents = json.load(f)
+    return dress(contents)
+
+
+def create_stripped_model(original_model_path: Path, stripped_model_path: Path) -> ModelOnDisk:
+    original = ModelOnDisk(original_model_path)
+
+    if original.format_type == ModelFormat.Checkpoint:
+        shutil.copy2(original.path, stripped_model_path)
+    else:
+        shutil.copytree(original.path, stripped_model_path, dirs_exist_ok=True)
+
+    stripped = ModelOnDisk(stripped_model_path)
+    print(f"Created clone of {original.name} at {stripped.path}")
+
+    for component_path in stripped.component_paths():
+        original_state_dict = ModelOnDisk.load_state_dict(component_path)
+        stripped_state_dict = strip(original_state_dict)  # type: ignore
+        with open(component_path, "w") as f:
+            json.dump(stripped_state_dict, f, indent=4)
+
+    before_size = humanize.naturalsize(original.size())
+    after_size = humanize.naturalsize(stripped.size())
+    print(f"{original.name} before: {before_size}, after: {after_size}")
+
+    return stripped
+
+
+def parse_arguments():
+    class Parser(argparse.ArgumentParser):
+        def error(self, reason):
+            raise ValueError(reason)
+
+    parser = Parser()
+    parser.add_argument("models_input_dir", type=Path)
+    parser.add_argument("stripped_output_dir", type=Path)
+
+    try:
+        args = parser.parse_args()
+    except ValueError as e:
+        print(f"Error: {e}", file=sys.stderr)
+        print(__doc__, file=sys.stderr)
+        sys.exit(2)
+
+    if not args.models_input_dir.exists():
+        parser.error(f"Error: Input models directory '{args.models_input_dir}' does not exist.")
+    if not args.models_input_dir.is_dir():
+        parser.error(f"Error: '{args.models_input_dir}' is not a directory.")
+
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    model_paths = sorted(ModelSearch().search(args.models_input_dir))
+
+    for path in model_paths:
+        stripped_path = args.stripped_output_dir / path.name
+        create_stripped_model(path, stripped_path)
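A quick illustration (not part of the diff) of the strip/dress round trip: tensors survive with their shape and dtype, but not their values. The state-dict keys below are made up.

    import torch

    from scripts.strip_models import dress, strip

    sd = {"model.out.weight": torch.ones(4, 8, dtype=torch.float16), "step": 10}

    stripped = strip(sd)   # tensor -> {"shape": ..., "dtype": "torch.float16", "fakeTensor": True}
    restored = dress(stripped)

    assert restored["step"] == 10
    assert restored["model.out.weight"].shape == (4, 8)
    assert restored["model.out.weight"].dtype == torch.float16  # values are uninitialized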
diff --git a/tests/conftest.py b/tests/conftest.py
index a5489888da0..b112a4ff2e9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,9 +7,14 @@
 import logging
 import shutil
 from pathlib import Path
+from types import SimpleNamespace
 
+import picklescan.scanner
 import pytest
+import safetensors.torch
+import torch
 
+import invokeai.backend.quantization.gguf.loaders as gguf_loaders
 from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
 from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
 from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
@@ -20,6 +25,7 @@
 from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
 from invokeai.app.services.invoker import Invoker
 from invokeai.backend.util.logging import InvokeAILogger
+from scripts.strip_models import load_stripped_model
 from tests.backend.model_manager.model_manager_fixtures import *  # noqa: F403
 from tests.fixtures.sqlite_database import create_mock_sqlite_database  # noqa: F401
 from tests.test_nodes import TestEventService
@@ -73,3 +79,23 @@ def invokeai_root_dir(tmp_path_factory) -> Path:
     temp_dir: Path = tmp_path_factory.mktemp("data") / "invokeai_root"
     shutil.copytree(root_template, temp_dir)
     return temp_dir
+
+
+@pytest.fixture(scope="function")
+def override_model_loading(monkeypatch):
+    """The legacy model probe directly calls model loading functions (e.g. torch.load) and performs file scanning
+    via picklescan.scanner.scan_file_path. This fixture replaces those functions with test-friendly versions that
+    can read model files which have been 'stripped' to reduce their size (see scripts/strip_models.py).
+
+    Ideally, model loading would be injected as a dependency (e.g. via ModelOnDisk). To avoid modifying the legacy
+    probe, we monkeypatch these functions as a temporary workaround until the legacy probe is fully deprecated.
+    """
+    monkeypatch.setattr(torch, "load", load_stripped_model)
+    monkeypatch.setattr(safetensors.torch, "load", load_stripped_model)
+    monkeypatch.setattr(safetensors.torch, "load_file", load_stripped_model)
+    monkeypatch.setattr(gguf_loaders, "gguf_sd_loader", load_stripped_model)
+
+    def fake_scan(*args, **kwargs):
+        return SimpleNamespace(infected_files=0, scan_err=None)
+
+    monkeypatch.setattr(picklescan.scanner, "scan_file_path", fake_scan)
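Hypothetical test (not part of the diff) showing the fixture in use: with override_model_loading active, torch.load resolves to load_stripped_model, so a stripped JSON checkpoint comes back as a state dict of empty tensors. The "stripped.ckpt" file name and the datadir fixture are assumptions.

    import torch


    def test_loads_stripped_checkpoint(override_model_loading, datadir):
        state_dict = torch.load(datadir / "stripped.ckpt")
        weights = [v for v in state_dict.values() if isinstance(v, torch.Tensor)]
        assert weights  # shapes and dtypes are real, tensor data is fake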
""" + configs_with_tests = set() + model_paths = ModelSearch().search(datadir) for path in model_paths: legacy_config = new_config = None - probe_success = classify_success = True try: legacy_config = ModelProbe.probe(path) except InvalidModelConfigException: - probe_success = False + pass try: new_config = ModelConfigBase.classify(path) except InvalidModelConfigException: - classify_success = False + pass - if probe_success and classify_success: + if legacy_config and new_config: assert legacy_config == new_config - elif probe_success: + elif legacy_config: assert type(legacy_config) in ModelConfigBase._USING_LEGACY_PROBE - elif classify_success: + elif new_config: assert type(new_config) in ModelConfigBase._USING_CLASSIFY_API else: raise ValueError(f"Both probe and classify failed to classify model at path {path}.") + config_type = type(legacy_config or new_config) + configs_with_tests.add(config_type) + + untested_configs = ModelConfigBase.all_config_classes() - configs_with_tests + logger = InvokeAILogger.get_logger(__file__) + logger.warning(f"Function test_regression_against_model_probe missing test case for: {untested_configs}") + def create_fake_configs(config_cls, n): factory_args = {