Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion invokeai/app/api/routers/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
UnknownModelException,
)
from invokeai.app.util.suppress_output import SuppressOutput
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
from invokeai.backend.model_manager.configs.main import (
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Expand All @@ -38,6 +38,7 @@
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.model_manager.starter_models import (
STARTER_BUNDLES,
Expand Down Expand Up @@ -191,6 +192,40 @@ async def get_model_record(
raise HTTPException(status_code=404, detail=str(e))


@model_manager_router.post(
"/i/{key}/reidentify",
operation_id="reidentify_model",
responses={
200: {
"description": "The model configuration was retrieved successfully",
"content": {"application/json": {"example": example_model_config}},
},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
},
)
async def reidentify_model(
key: Annotated[str, Path(description="Key of the model to reidentify.")],
) -> AnyModelConfig:
"""Attempt to reidentify a model by re-probing its weights file."""
try:
config = ApiDependencies.invoker.services.model_manager.store.get_model(key)
models_path = ApiDependencies.invoker.services.configuration.models_path
if pathlib.Path(config.path).is_relative_to(models_path):
model_path = pathlib.Path(config.path)
else:
model_path = models_path / config.path
mod = ModelOnDisk(model_path)
result = ModelConfigFactory.from_model_on_disk(mod)
if result.config is None:
raise InvalidModelException("Unable to identify model format")
result.config.key = config.key # retain the same key
new_config = ApiDependencies.invoker.services.model_manager.store.replace_model(config.key, result.config)
return new_config
except UnknownModelException as e:
raise HTTPException(status_code=404, detail=str(e))


class FoundModel(BaseModel):
path: str = Field(description="Path to the model")
is_installed: bool = Field(description="Whether or not the model is already installed")
Expand Down
12 changes: 12 additions & 0 deletions invokeai/app/services/model_records/model_records_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,18 @@ def update_model(self, key: str, changes: ModelRecordChanges, allow_class_change
"""
pass

@abstractmethod
def replace_model(self, key: str, new_config: AnyModelConfig) -> AnyModelConfig:
"""
Replace the model record entirely, returning the new record.

This is used when we re-identify a model and have a new config object.

:param key: Unique key for the model to be updated.
:param new_config: The new model config to write.
"""
pass

@abstractmethod
def get_model(self, key: str) -> AnyModelConfig:
"""
Expand Down
17 changes: 17 additions & 0 deletions invokeai/app/services/model_records/model_records_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,23 @@ def update_model(self, key: str, changes: ModelRecordChanges, allow_class_change

return self.get_model(key)

def replace_model(self, key: str, new_config: AnyModelConfig) -> AnyModelConfig:
if key != new_config.key:
raise ValueError("key does not match new_config.key")
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE models
SET
config=?
WHERE id=?;
""",
(new_config.model_dump_json(), key),
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
return self.get_model(key)

def get_model(self, key: str) -> AnyModelConfig:
"""
Retrieve the ModelConfigBase instance for the indicated model.
Expand Down
6 changes: 5 additions & 1 deletion invokeai/backend/model_manager/configs/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,11 @@ def _validate_path_looks_like_model(path: Path) -> None:
# For directories, do a quick file count check with early exit
total_files = 0
# Ignore hidden files and directories
paths_to_check = (p for p in path.rglob("*") if not p.name.startswith("."))
paths_to_check = (
p
for p in path.rglob("*")
if not p.name.startswith(".") and not any(part.startswith(".") for part in p.parts)
)
for item in paths_to_check:
if item.is_file():
total_files += 1
Expand Down
5 changes: 5 additions & 0 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,11 @@
"clipLEmbed": "CLIP-L Embed",
"clipGEmbed": "CLIP-G Embed",
"config": "Config",
"reidentify": "Reidentify",
"reidentifyTooltip": "If a model didn't install correctly (e.g. it has the wrong type or doesn't work), you can try reidentifying it. This will reset any custom settings you may have applied.",
"reidentifySuccess": "Model reidentified successfully",
"reidentifyUnknown": "Unable to identify model",
"reidentifyError": "Error reidentifying model",
"convert": "Convert",
"convertingModelBegin": "Converting Model. Please wait.",
"convertToDiffusers": "Convert To Diffusers",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import { Button } from '@invoke-ai/ui-library';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiSparkleFill } from 'react-icons/pi';
import { useReidentifyModelMutation } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';

interface Props {
modelConfig: AnyModelConfig;
}

export const ModelReidentifyButton = memo(({ modelConfig }: Props) => {
const { t } = useTranslation();
const [reidentifyModel, { isLoading }] = useReidentifyModelMutation();

const onClick = useCallback(() => {
reidentifyModel({ key: modelConfig.key })
.unwrap()
.then(({ type }) => {
if (type === 'unknown') {
toast({
id: 'MODEL_REIDENTIFY_UNKNOWN',
title: t('modelManager.reidentifyUnknown'),
status: 'warning',
});
}
toast({
id: 'MODEL_REIDENTIFY_SUCCESS',
title: t('modelManager.reidentifySuccess'),
status: 'success',
});
})
.catch((_) => {
toast({
id: 'MODEL_REIDENTIFY_ERROR',
title: t('modelManager.reidentifyError'),
status: 'error',
});
});
}, [modelConfig.key, reidentifyModel, t]);

return (
<Button
onClick={onClick}
size="sm"
aria-label={t('modelManager.reidentifyTooltip')}
tooltip={t('modelManager.reidentifyTooltip')}
isLoading={isLoading}
flexShrink={0}
leftIcon={<PiSparkleFill />}
>
{t('modelManager.reidentify')}
</Button>
);
});

ModelReidentifyButton.displayName = 'ModelReidentifyButton';
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import type { AnyModelConfig } from 'services/api/types';
import { MainModelDefaultSettings } from './MainModelDefaultSettings/MainModelDefaultSettings';
import { ModelAttrView } from './ModelAttrView';
import { ModelFooter } from './ModelFooter';
import { ModelReidentifyButton } from './ModelReidentifyButton';
import { RelatedModels } from './RelatedModels';

type Props = {
Expand All @@ -21,6 +22,7 @@ type Props = {

export const ModelView = memo(({ modelConfig }: Props) => {
const { t } = useTranslation();

const withSettings = useMemo(() => {
if (modelConfig.type === 'main' && modelConfig.base !== 'sdxl-refiner') {
return true;
Expand All @@ -46,6 +48,7 @@ export const ModelView = memo(({ modelConfig }: Props) => {
<ModelConvertButton modelConfig={modelConfig} />
)}
<ModelEditButton />
<ModelReidentifyButton modelConfig={modelConfig} />
</ModelHeader>
<Divider />
<Flex flexDir="column" gap={4}>
Expand Down
35 changes: 35 additions & 0 deletions invokeai/frontend/web/src/services/api/endpoints/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,40 @@ export const modelsApi = api.injectEndpoints({
emptyModelCache: build.mutation<void, void>({
query: () => ({ url: buildModelsUrl('empty_model_cache'), method: 'POST' }),
}),
reidentifyModel: build.mutation<
paths['/api/v2/models/i/{key}/reidentify']['post']['responses']['200']['content']['application/json'],
{ key: string }
>({
query: ({ key }) => {
return {
url: buildModelsUrl(`i/${key}/reidentify`),
method: 'POST',
};
},
onQueryStarted: async (_, { dispatch, queryFulfilled }) => {
try {
const { data } = await queryFulfilled;

// Update the individual model query caches
dispatch(modelsApi.util.upsertQueryData('getModelConfig', data.key, data));

const { base, name, type } = data;
dispatch(modelsApi.util.upsertQueryData('getModelConfigByAttrs', { base, name, type }, data));

// Update the list query cache
dispatch(
modelsApi.util.updateQueryData('getModelConfigs', undefined, (draft) => {
modelConfigsAdapter.updateOne(draft, {
id: data.key,
changes: data,
});
})
);
} catch {
// no-op
}
},
}),
}),
});

Expand All @@ -321,6 +355,7 @@ export const {
useSetHFTokenMutation,
useResetHFTokenMutation,
useEmptyModelCacheMutation,
useReidentifyModelMutation,
} = modelsApi;

export const selectModelConfigsQuery = modelsApi.endpoints.getModelConfigs.select();
84 changes: 84 additions & 0 deletions invokeai/frontend/web/src/services/api/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,26 @@ export type paths = {
patch: operations["update_model_record"];
trace?: never;
};
"/api/v2/models/i/{key}/reidentify": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get?: never;
put?: never;
/**
* Reidentify Model
* @description Attempt to reidentify a model by re-probing its weights file.
*/
post: operations["reidentify_model"];
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v2/models/scan_folder": {
parameters: {
query?: never;
Expand Down Expand Up @@ -24655,6 +24675,70 @@ export interface operations {
};
};
};
reidentify_model: {
parameters: {
query?: never;
header?: never;
path: {
/** @description Key of the model to reidentify. */
key: string;
};
cookie?: never;
};
requestBody?: never;
responses: {
/** @description The model configuration was retrieved successfully */
200: {
headers: {
[name: string]: unknown;
};
content: {
/** @example {
* "path": "string",
* "name": "string",
* "base": "sd-1",
* "type": "main",
* "format": "checkpoint",
* "config_path": "string",
* "key": "string",
* "hash": "string",
* "file_size": 1,
* "description": "string",
* "source": "string",
* "converted_at": 0,
* "variant": "normal",
* "prediction_type": "epsilon",
* "repo_variant": "fp16",
* "upcast_attention": false
* } */
"application/json": components["schemas"]["Main_Diffusers_SD1_Config"] | components["schemas"]["Main_Diffusers_SD2_Config"] | components["schemas"]["Main_Diffusers_SDXL_Config"] | components["schemas"]["Main_Diffusers_SDXLRefiner_Config"] | components["schemas"]["Main_Diffusers_SD3_Config"] | components["schemas"]["Main_Diffusers_CogView4_Config"] | components["schemas"]["Main_Checkpoint_SD1_Config"] | components["schemas"]["Main_Checkpoint_SD2_Config"] | components["schemas"]["Main_Checkpoint_SDXL_Config"] | components["schemas"]["Main_Checkpoint_SDXLRefiner_Config"] | components["schemas"]["Main_Checkpoint_FLUX_Config"] | components["schemas"]["Main_BnBNF4_FLUX_Config"] | components["schemas"]["Main_GGUF_FLUX_Config"] | components["schemas"]["VAE_Checkpoint_SD1_Config"] | components["schemas"]["VAE_Checkpoint_SD2_Config"] | components["schemas"]["VAE_Checkpoint_SDXL_Config"] | components["schemas"]["VAE_Checkpoint_FLUX_Config"] | components["schemas"]["VAE_Diffusers_SD1_Config"] | components["schemas"]["VAE_Diffusers_SDXL_Config"] | components["schemas"]["ControlNet_Checkpoint_SD1_Config"] | components["schemas"]["ControlNet_Checkpoint_SD2_Config"] | components["schemas"]["ControlNet_Checkpoint_SDXL_Config"] | components["schemas"]["ControlNet_Checkpoint_FLUX_Config"] | components["schemas"]["ControlNet_Diffusers_SD1_Config"] | components["schemas"]["ControlNet_Diffusers_SD2_Config"] | components["schemas"]["ControlNet_Diffusers_SDXL_Config"] | components["schemas"]["ControlNet_Diffusers_FLUX_Config"] | components["schemas"]["LoRA_LyCORIS_SD1_Config"] | components["schemas"]["LoRA_LyCORIS_SD2_Config"] | components["schemas"]["LoRA_LyCORIS_SDXL_Config"] | components["schemas"]["LoRA_LyCORIS_FLUX_Config"] | components["schemas"]["LoRA_OMI_SDXL_Config"] | components["schemas"]["LoRA_OMI_FLUX_Config"] | components["schemas"]["LoRA_Diffusers_SD1_Config"] | components["schemas"]["LoRA_Diffusers_SD2_Config"] | components["schemas"]["LoRA_Diffusers_SDXL_Config"] | components["schemas"]["LoRA_Diffusers_FLUX_Config"] | components["schemas"]["ControlLoRA_LyCORIS_FLUX_Config"] | components["schemas"]["T5Encoder_T5Encoder_Config"] | components["schemas"]["T5Encoder_BnBLLMint8_Config"] | components["schemas"]["TI_File_SD1_Config"] | components["schemas"]["TI_File_SD2_Config"] | components["schemas"]["TI_File_SDXL_Config"] | components["schemas"]["TI_Folder_SD1_Config"] | components["schemas"]["TI_Folder_SD2_Config"] | components["schemas"]["TI_Folder_SDXL_Config"] | components["schemas"]["IPAdapter_InvokeAI_SD1_Config"] | components["schemas"]["IPAdapter_InvokeAI_SD2_Config"] | components["schemas"]["IPAdapter_InvokeAI_SDXL_Config"] | components["schemas"]["IPAdapter_Checkpoint_SD1_Config"] | components["schemas"]["IPAdapter_Checkpoint_SD2_Config"] | components["schemas"]["IPAdapter_Checkpoint_SDXL_Config"] | components["schemas"]["IPAdapter_Checkpoint_FLUX_Config"] | components["schemas"]["T2IAdapter_Diffusers_SD1_Config"] | components["schemas"]["T2IAdapter_Diffusers_SDXL_Config"] | components["schemas"]["Spandrel_Checkpoint_Config"] | components["schemas"]["CLIPEmbed_Diffusers_G_Config"] | components["schemas"]["CLIPEmbed_Diffusers_L_Config"] | components["schemas"]["CLIPVision_Diffusers_Config"] | components["schemas"]["SigLIP_Diffusers_Config"] | components["schemas"]["FLUXRedux_Checkpoint_Config"] | components["schemas"]["LlavaOnevision_Diffusers_Config"] | components["schemas"]["Unknown_Config"];
};
};
/** @description Bad request */
400: {
headers: {
[name: string]: unknown;
};
content?: never;
};
/** @description The model could not be found */
404: {
headers: {
[name: string]: unknown;
};
content?: never;
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
scan_for_models: {
parameters: {
query?: {
Expand Down