Skip to content

Commit 15dcfff

Browse files
authored
Merge branch 'main' into maryhipp/restore-list-queue-items
2 parents 43b3880 + 3cec06f commit 15dcfff

File tree

7 files changed

+202
-123
lines changed

7 files changed

+202
-123
lines changed

invokeai/app/invocations/baseinvocation.py

Lines changed: 83 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
from invokeai.app.invocations.fields import (
3737
FieldKind,
3838
Input,
39+
InputFieldJSONSchemaExtra,
40+
UIType,
41+
migrate_model_ui_type,
3942
)
4043
from invokeai.app.services.config.config_default import get_config
4144
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -256,7 +259,9 @@ def invoke_internal(self, context: InvocationContext, services: "InvocationServi
256259
is_intermediate: bool = Field(
257260
default=False,
258261
description="Whether or not this is an intermediate invocation.",
259-
json_schema_extra={"ui_type": "IsIntermediate", "field_kind": FieldKind.NodeAttribute},
262+
json_schema_extra=InputFieldJSONSchemaExtra(
263+
input=Input.Direct, field_kind=FieldKind.NodeAttribute, ui_type=UIType._IsIntermediate
264+
).model_dump(exclude_none=True),
260265
)
261266
use_cache: bool = Field(
262267
default=True,
@@ -445,6 +450,15 @@ class _Model(BaseModel):
445450
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
446451

447452

453+
def is_enum_member(value: Any, enum_class: type[Enum]) -> bool:
454+
"""Checks if a value is a member of an enum class."""
455+
try:
456+
enum_class(value)
457+
return True
458+
except ValueError:
459+
return False
460+
461+
448462
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
449463
"""
450464
Validates the fields of an invocation or invocation output:
@@ -456,51 +470,99 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
456470
"""
457471
for name, field in model_fields.items():
458472
if name in RESERVED_PYDANTIC_FIELD_NAMES:
459-
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')
473+
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved by pydantic)")
460474

461475
if not field.annotation:
462-
raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)')
476+
raise InvalidFieldError(f"{model_type}.{name}: Invalid field type (missing annotation)")
463477

464478
if not isinstance(field.json_schema_extra, dict):
465-
raise InvalidFieldError(
466-
f'Invalid field definition for "{name}" on "{model_type}" (missing json_schema_extra dict)'
467-
)
479+
raise InvalidFieldError(f"{model_type}.{name}: Invalid field definition (missing json_schema_extra dict)")
468480

469481
field_kind = field.json_schema_extra.get("field_kind", None)
470482

471483
# must have a field_kind
472-
if not isinstance(field_kind, FieldKind):
484+
if not is_enum_member(field_kind, FieldKind):
473485
raise InvalidFieldError(
474-
f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)'
486+
f"{model_type}.{name}: Invalid field definition for (maybe it's not an InputField or OutputField?)"
475487
)
476488

477-
if field_kind is FieldKind.Input and (
489+
if field_kind == FieldKind.Input.value and (
478490
name in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES or name in RESERVED_INPUT_FIELD_NAMES
479491
):
480-
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)')
492+
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved input field name)")
481493

482-
if field_kind is FieldKind.Output and name in RESERVED_OUTPUT_FIELD_NAMES:
483-
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)')
494+
if field_kind == FieldKind.Output.value and name in RESERVED_OUTPUT_FIELD_NAMES:
495+
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved output field name)")
484496

485-
if (field_kind is FieldKind.Internal) and name not in RESERVED_INPUT_FIELD_NAMES:
486-
raise InvalidFieldError(
487-
f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)'
488-
)
497+
if field_kind == FieldKind.Internal.value and name not in RESERVED_INPUT_FIELD_NAMES:
498+
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (internal field without reserved name)")
489499

490500
# node attribute fields *must* be in the reserved list
491501
if (
492-
field_kind is FieldKind.NodeAttribute
502+
field_kind == FieldKind.NodeAttribute.value
493503
and name not in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES
494504
and name not in RESERVED_OUTPUT_FIELD_NAMES
495505
):
496506
raise InvalidFieldError(
497-
f'Invalid field name "{name}" on "{model_type}" (node attribute field without reserved name)'
507+
f"{model_type}.{name}: Invalid field name (node attribute field without reserved name)"
498508
)
499509

500510
ui_type = field.json_schema_extra.get("ui_type", None)
501-
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
502-
logger.warning(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
503-
field.json_schema_extra.pop("ui_type")
511+
ui_model_base = field.json_schema_extra.get("ui_model_base", None)
512+
ui_model_type = field.json_schema_extra.get("ui_model_type", None)
513+
ui_model_variant = field.json_schema_extra.get("ui_model_variant", None)
514+
ui_model_format = field.json_schema_extra.get("ui_model_format", None)
515+
516+
if ui_type is not None:
517+
# There are 3 cases where we may need to take action:
518+
#
519+
# 1. The ui_type is a migratable, deprecated value. For example, ui_type=UIType.MainModel value is
520+
# deprecated and should be migrated to:
521+
# - ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]
522+
# - ui_model_type=[ModelType.Main]
523+
#
524+
# 2. ui_type was set in conjunction with any of the new ui_model_[base|type|variant|format] fields, which
525+
# is not allowed (they are mutually exclusive). In this case, we ignore ui_type and log a warning.
526+
#
527+
# 3. ui_type is a deprecated value that is not migratable. For example, ui_type=UIType.Image is deprecated;
528+
# Image fields are now automatically detected based on the field's type annotation. In this case, we
529+
# ignore ui_type and log a warning.
530+
#
531+
# The cases must be checked in this order to ensure proper handling.
532+
533+
# Easier to work with as an enum
534+
ui_type = UIType(ui_type)
535+
536+
# The enum member values are not always the same as their names - we want to log the name so the user can
537+
# easily review their code and see where the deprecated enum member is used.
538+
human_readable_name = f"UIType.{ui_type.name}"
539+
540+
# Case 1: migratable deprecated value
541+
did_migrate = migrate_model_ui_type(ui_type, field.json_schema_extra)
542+
543+
if did_migrate:
544+
logger.warning(
545+
f'{model_type}.{name}: Migrated deprecated "ui_type" "{human_readable_name}" to new ui_model_[base|type|variant|format] fields'
546+
)
547+
field.json_schema_extra.pop("ui_type")
548+
549+
# Case 2: mutually exclusive with new fields
550+
elif (
551+
ui_model_base is not None
552+
or ui_model_type is not None
553+
or ui_model_variant is not None
554+
or ui_model_format is not None
555+
):
556+
logger.warning(
557+
f'{model_type}.{name}: "ui_type" is mutually exclusive with "ui_model_[base|type|format|variant]", ignoring "ui_type"'
558+
)
559+
field.json_schema_extra.pop("ui_type")
560+
561+
# Case 3: deprecated value that is not migratable
562+
elif ui_type.startswith("DEPRECATED_"):
563+
logger.warning(f'{model_type}.{name}: Deprecated "ui_type" "{human_readable_name}", ignoring')
564+
field.json_schema_extra.pop("ui_type")
565+
504566
return None
505567

506568

invokeai/app/invocations/fields.py

Lines changed: 109 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
5454
# region Internal Field Types
5555
_Collection = "CollectionField"
5656
_CollectionItem = "CollectionItemField"
57+
_IsIntermediate = "IsIntermediate"
5758
# endregion
5859

5960
# region DEPRECATED
@@ -91,7 +92,6 @@ class UIType(str, Enum, metaclass=MetaEnum):
9192
CollectionItem = "DEPRECATED_CollectionItem"
9293
Enum = "DEPRECATED_Enum"
9394
WorkflowField = "DEPRECATED_WorkflowField"
94-
IsIntermediate = "DEPRECATED_IsIntermediate"
9595
BoardField = "DEPRECATED_BoardField"
9696
MetadataItem = "DEPRECATED_MetadataItem"
9797
MetadataItemCollection = "DEPRECATED_MetadataItemCollection"
@@ -423,6 +423,7 @@ class InputFieldJSONSchemaExtra(BaseModel):
423423
model_config = ConfigDict(
424424
validate_assignment=True,
425425
json_schema_serialization_defaults_required=True,
426+
use_enum_values=True,
426427
)
427428

428429

@@ -482,9 +483,114 @@ class OutputFieldJSONSchemaExtra(BaseModel):
482483
model_config = ConfigDict(
483484
validate_assignment=True,
484485
json_schema_serialization_defaults_required=True,
486+
use_enum_values=True,
485487
)
486488

487489

490+
def migrate_model_ui_type(ui_type: UIType | str, json_schema_extra: dict[str, Any]) -> bool:
491+
"""Migrate deprecated model-specifier ui_type values to new-style ui_model_[base|type|variant|format] in json_schema_extra."""
492+
if not isinstance(ui_type, UIType):
493+
ui_type = UIType(ui_type)
494+
495+
ui_model_type: list[ModelType] | None = None
496+
ui_model_base: list[BaseModelType] | None = None
497+
ui_model_format: list[ModelFormat] | None = None
498+
ui_model_variant: list[ClipVariantType | ModelVariantType] | None = None
499+
500+
match ui_type:
501+
case UIType.MainModel:
502+
ui_model_base = [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]
503+
ui_model_type = [ModelType.Main]
504+
case UIType.CogView4MainModel:
505+
ui_model_base = [BaseModelType.CogView4]
506+
ui_model_type = [ModelType.Main]
507+
case UIType.FluxMainModel:
508+
ui_model_base = [BaseModelType.Flux]
509+
ui_model_type = [ModelType.Main]
510+
case UIType.SD3MainModel:
511+
ui_model_base = [BaseModelType.StableDiffusion3]
512+
ui_model_type = [ModelType.Main]
513+
case UIType.SDXLMainModel:
514+
ui_model_base = [BaseModelType.StableDiffusionXL]
515+
ui_model_type = [ModelType.Main]
516+
case UIType.SDXLRefinerModel:
517+
ui_model_base = [BaseModelType.StableDiffusionXLRefiner]
518+
ui_model_type = [ModelType.Main]
519+
case UIType.VAEModel:
520+
ui_model_type = [ModelType.VAE]
521+
case UIType.FluxVAEModel:
522+
ui_model_base = [BaseModelType.Flux]
523+
ui_model_type = [ModelType.VAE]
524+
case UIType.LoRAModel:
525+
ui_model_type = [ModelType.LoRA]
526+
case UIType.ControlNetModel:
527+
ui_model_type = [ModelType.ControlNet]
528+
case UIType.IPAdapterModel:
529+
ui_model_type = [ModelType.IPAdapter]
530+
case UIType.T2IAdapterModel:
531+
ui_model_type = [ModelType.T2IAdapter]
532+
case UIType.T5EncoderModel:
533+
ui_model_type = [ModelType.T5Encoder]
534+
case UIType.CLIPEmbedModel:
535+
ui_model_type = [ModelType.CLIPEmbed]
536+
case UIType.CLIPLEmbedModel:
537+
ui_model_type = [ModelType.CLIPEmbed]
538+
ui_model_variant = [ClipVariantType.L]
539+
case UIType.CLIPGEmbedModel:
540+
ui_model_type = [ModelType.CLIPEmbed]
541+
ui_model_variant = [ClipVariantType.G]
542+
case UIType.SpandrelImageToImageModel:
543+
ui_model_type = [ModelType.SpandrelImageToImage]
544+
case UIType.ControlLoRAModel:
545+
ui_model_type = [ModelType.ControlLoRa]
546+
case UIType.SigLipModel:
547+
ui_model_type = [ModelType.SigLIP]
548+
case UIType.FluxReduxModel:
549+
ui_model_type = [ModelType.FluxRedux]
550+
case UIType.LlavaOnevisionModel:
551+
ui_model_type = [ModelType.LlavaOnevision]
552+
case UIType.Imagen3Model:
553+
ui_model_base = [BaseModelType.Imagen3]
554+
ui_model_type = [ModelType.Main]
555+
case UIType.Imagen4Model:
556+
ui_model_base = [BaseModelType.Imagen4]
557+
ui_model_type = [ModelType.Main]
558+
case UIType.ChatGPT4oModel:
559+
ui_model_base = [BaseModelType.ChatGPT4o]
560+
ui_model_type = [ModelType.Main]
561+
case UIType.Gemini2_5Model:
562+
ui_model_base = [BaseModelType.Gemini2_5]
563+
ui_model_type = [ModelType.Main]
564+
case UIType.FluxKontextModel:
565+
ui_model_base = [BaseModelType.FluxKontext]
566+
ui_model_type = [ModelType.Main]
567+
case UIType.Veo3Model:
568+
ui_model_base = [BaseModelType.Veo3]
569+
ui_model_type = [ModelType.Video]
570+
case UIType.RunwayModel:
571+
ui_model_base = [BaseModelType.Runway]
572+
ui_model_type = [ModelType.Video]
573+
case _:
574+
pass
575+
576+
did_migrate = False
577+
578+
if ui_model_type is not None:
579+
json_schema_extra["ui_model_type"] = [m.value for m in ui_model_type]
580+
did_migrate = True
581+
if ui_model_base is not None:
582+
json_schema_extra["ui_model_base"] = [m.value for m in ui_model_base]
583+
did_migrate = True
584+
if ui_model_format is not None:
585+
json_schema_extra["ui_model_format"] = [m.value for m in ui_model_format]
586+
did_migrate = True
587+
if ui_model_variant is not None:
588+
json_schema_extra["ui_model_variant"] = [m.value for m in ui_model_variant]
589+
did_migrate = True
590+
591+
return did_migrate
592+
593+
488594
def InputField(
489595
# copied from pydantic's Field
490596
# TODO: Can we support default_factory?
@@ -575,93 +681,6 @@ def InputField(
575681
field_kind=FieldKind.Input,
576682
)
577683

578-
if ui_type is not None:
579-
if (
580-
ui_model_base is not None
581-
or ui_model_type is not None
582-
or ui_model_variant is not None
583-
or ui_model_format is not None
584-
):
585-
logger.warning("InputField: Use either ui_type or ui_model_[base|type|variant|format]. Ignoring ui_type.")
586-
# Map old-style UIType to new-style ui_model_[base|type|variant|format]
587-
elif ui_type is UIType.MainModel:
588-
json_schema_extra_.ui_model_type = [ModelType.Main]
589-
elif ui_type is UIType.CogView4MainModel:
590-
json_schema_extra_.ui_model_base = [BaseModelType.CogView4]
591-
json_schema_extra_.ui_model_type = [ModelType.Main]
592-
elif ui_type is UIType.FluxMainModel:
593-
json_schema_extra_.ui_model_base = [BaseModelType.Flux]
594-
json_schema_extra_.ui_model_type = [ModelType.Main]
595-
elif ui_type is UIType.SD3MainModel:
596-
json_schema_extra_.ui_model_base = [BaseModelType.StableDiffusion3]
597-
json_schema_extra_.ui_model_type = [ModelType.Main]
598-
elif ui_type is UIType.SDXLMainModel:
599-
json_schema_extra_.ui_model_base = [BaseModelType.StableDiffusionXL]
600-
json_schema_extra_.ui_model_type = [ModelType.Main]
601-
elif ui_type is UIType.SDXLRefinerModel:
602-
json_schema_extra_.ui_model_base = [BaseModelType.StableDiffusionXLRefiner]
603-
json_schema_extra_.ui_model_type = [ModelType.Main]
604-
# Think this UIType is unused...?
605-
# elif ui_type is UIType.ONNXModel:
606-
# json_schema_extra_.ui_model_base =
607-
# json_schema_extra_.ui_model_type =
608-
elif ui_type is UIType.VAEModel:
609-
json_schema_extra_.ui_model_type = [ModelType.VAE]
610-
elif ui_type is UIType.FluxVAEModel:
611-
json_schema_extra_.ui_model_base = [BaseModelType.Flux]
612-
json_schema_extra_.ui_model_type = [ModelType.VAE]
613-
elif ui_type is UIType.LoRAModel:
614-
json_schema_extra_.ui_model_type = [ModelType.LoRA]
615-
elif ui_type is UIType.ControlNetModel:
616-
json_schema_extra_.ui_model_type = [ModelType.ControlNet]
617-
elif ui_type is UIType.IPAdapterModel:
618-
json_schema_extra_.ui_model_type = [ModelType.IPAdapter]
619-
elif ui_type is UIType.T2IAdapterModel:
620-
json_schema_extra_.ui_model_type = [ModelType.T2IAdapter]
621-
elif ui_type is UIType.T5EncoderModel:
622-
json_schema_extra_.ui_model_type = [ModelType.T5Encoder]
623-
elif ui_type is UIType.CLIPEmbedModel:
624-
json_schema_extra_.ui_model_type = [ModelType.CLIPEmbed]
625-
elif ui_type is UIType.CLIPLEmbedModel:
626-
json_schema_extra_.ui_model_type = [ModelType.CLIPEmbed]
627-
json_schema_extra_.ui_model_variant = [ClipVariantType.L]
628-
elif ui_type is UIType.CLIPGEmbedModel:
629-
json_schema_extra_.ui_model_type = [ModelType.CLIPEmbed]
630-
json_schema_extra_.ui_model_variant = [ClipVariantType.G]
631-
elif ui_type is UIType.SpandrelImageToImageModel:
632-
json_schema_extra_.ui_model_type = [ModelType.SpandrelImageToImage]
633-
elif ui_type is UIType.ControlLoRAModel:
634-
json_schema_extra_.ui_model_type = [ModelType.ControlLoRa]
635-
elif ui_type is UIType.SigLipModel:
636-
json_schema_extra_.ui_model_type = [ModelType.SigLIP]
637-
elif ui_type is UIType.FluxReduxModel:
638-
json_schema_extra_.ui_model_type = [ModelType.FluxRedux]
639-
elif ui_type is UIType.LlavaOnevisionModel:
640-
json_schema_extra_.ui_model_type = [ModelType.LlavaOnevision]
641-
elif ui_type is UIType.Imagen3Model:
642-
json_schema_extra_.ui_model_base = [BaseModelType.Imagen3]
643-
json_schema_extra_.ui_model_type = [ModelType.Main]
644-
elif ui_type is UIType.Imagen4Model:
645-
json_schema_extra_.ui_model_base = [BaseModelType.Imagen4]
646-
json_schema_extra_.ui_model_type = [ModelType.Main]
647-
elif ui_type is UIType.ChatGPT4oModel:
648-
json_schema_extra_.ui_model_base = [BaseModelType.ChatGPT4o]
649-
json_schema_extra_.ui_model_type = [ModelType.Main]
650-
elif ui_type is UIType.Gemini2_5Model:
651-
json_schema_extra_.ui_model_base = [BaseModelType.Gemini2_5]
652-
json_schema_extra_.ui_model_type = [ModelType.Main]
653-
elif ui_type is UIType.FluxKontextModel:
654-
json_schema_extra_.ui_model_base = [BaseModelType.FluxKontext]
655-
json_schema_extra_.ui_model_type = [ModelType.Main]
656-
elif ui_type is UIType.Veo3Model:
657-
json_schema_extra_.ui_model_base = [BaseModelType.Veo3]
658-
json_schema_extra_.ui_model_type = [ModelType.Video]
659-
elif ui_type is UIType.RunwayModel:
660-
json_schema_extra_.ui_model_base = [BaseModelType.Runway]
661-
json_schema_extra_.ui_model_type = [ModelType.Video]
662-
else:
663-
json_schema_extra_.ui_type = ui_type
664-
665684
if ui_component is not None:
666685
json_schema_extra_.ui_component = ui_component
667686
if ui_hidden is not None:
@@ -690,6 +709,8 @@ def InputField(
690709
json_schema_extra_.ui_model_format = ui_model_format
691710
else:
692711
json_schema_extra_.ui_model_format = [ui_model_format]
712+
if ui_type is not None:
713+
json_schema_extra_.ui_type = ui_type
693714

694715
"""
695716
There is a conflict between the typing of invocation definitions and the typing of an invocation's

0 commit comments

Comments
 (0)