Skip to content

Commit 7618b22

Browse files
authored
Merge branch 'main' into feat/workflow-library-ui-tweaks
2 parents 2a0885e + 3707c3b commit 7618b22

File tree

108 files changed

+3041
-3239
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

108 files changed

+3041
-3239
lines changed

invokeai/app/invocations/baseinvocation.py

Lines changed: 83 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
from invokeai.app.invocations.fields import (
3737
FieldKind,
3838
Input,
39+
InputFieldJSONSchemaExtra,
40+
UIType,
41+
migrate_model_ui_type,
3942
)
4043
from invokeai.app.services.config.config_default import get_config
4144
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -256,7 +259,9 @@ def invoke_internal(self, context: InvocationContext, services: "InvocationServi
256259
is_intermediate: bool = Field(
257260
default=False,
258261
description="Whether or not this is an intermediate invocation.",
259-
json_schema_extra={"ui_type": "IsIntermediate", "field_kind": FieldKind.NodeAttribute},
262+
json_schema_extra=InputFieldJSONSchemaExtra(
263+
input=Input.Direct, field_kind=FieldKind.NodeAttribute, ui_type=UIType._IsIntermediate
264+
).model_dump(exclude_none=True),
260265
)
261266
use_cache: bool = Field(
262267
default=True,
@@ -445,6 +450,15 @@ class _Model(BaseModel):
445450
RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}
446451

447452

453+
def is_enum_member(value: Any, enum_class: type[Enum]) -> bool:
454+
"""Checks if a value is a member of an enum class."""
455+
try:
456+
enum_class(value)
457+
return True
458+
except ValueError:
459+
return False
460+
461+
448462
def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
449463
"""
450464
Validates the fields of an invocation or invocation output:
@@ -456,51 +470,99 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
456470
"""
457471
for name, field in model_fields.items():
458472
if name in RESERVED_PYDANTIC_FIELD_NAMES:
459-
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')
473+
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved by pydantic)")
460474

461475
if not field.annotation:
462-
raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)')
476+
raise InvalidFieldError(f"{model_type}.{name}: Invalid field type (missing annotation)")
463477

464478
if not isinstance(field.json_schema_extra, dict):
465-
raise InvalidFieldError(
466-
f'Invalid field definition for "{name}" on "{model_type}" (missing json_schema_extra dict)'
467-
)
479+
raise InvalidFieldError(f"{model_type}.{name}: Invalid field definition (missing json_schema_extra dict)")
468480

469481
field_kind = field.json_schema_extra.get("field_kind", None)
470482

471483
# must have a field_kind
472-
if not isinstance(field_kind, FieldKind):
484+
if not is_enum_member(field_kind, FieldKind):
473485
raise InvalidFieldError(
474-
f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)'
486+
f"{model_type}.{name}: Invalid field definition for (maybe it's not an InputField or OutputField?)"
475487
)
476488

477-
if field_kind is FieldKind.Input and (
489+
if field_kind == FieldKind.Input.value and (
478490
name in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES or name in RESERVED_INPUT_FIELD_NAMES
479491
):
480-
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)')
492+
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved input field name)")
481493

482-
if field_kind is FieldKind.Output and name in RESERVED_OUTPUT_FIELD_NAMES:
483-
raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)')
494+
if field_kind == FieldKind.Output.value and name in RESERVED_OUTPUT_FIELD_NAMES:
495+
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (reserved output field name)")
484496

485-
if (field_kind is FieldKind.Internal) and name not in RESERVED_INPUT_FIELD_NAMES:
486-
raise InvalidFieldError(
487-
f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)'
488-
)
497+
if field_kind == FieldKind.Internal.value and name not in RESERVED_INPUT_FIELD_NAMES:
498+
raise InvalidFieldError(f"{model_type}.{name}: Invalid field name (internal field without reserved name)")
489499

490500
# node attribute fields *must* be in the reserved list
491501
if (
492-
field_kind is FieldKind.NodeAttribute
502+
field_kind == FieldKind.NodeAttribute.value
493503
and name not in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES
494504
and name not in RESERVED_OUTPUT_FIELD_NAMES
495505
):
496506
raise InvalidFieldError(
497-
f'Invalid field name "{name}" on "{model_type}" (node attribute field without reserved name)'
507+
f"{model_type}.{name}: Invalid field name (node attribute field without reserved name)"
498508
)
499509

500510
ui_type = field.json_schema_extra.get("ui_type", None)
501-
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
502-
logger.warning(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
503-
field.json_schema_extra.pop("ui_type")
511+
ui_model_base = field.json_schema_extra.get("ui_model_base", None)
512+
ui_model_type = field.json_schema_extra.get("ui_model_type", None)
513+
ui_model_variant = field.json_schema_extra.get("ui_model_variant", None)
514+
ui_model_format = field.json_schema_extra.get("ui_model_format", None)
515+
516+
if ui_type is not None:
517+
# There are 3 cases where we may need to take action:
518+
#
519+
# 1. The ui_type is a migratable, deprecated value. For example, ui_type=UIType.MainModel value is
520+
# deprecated and should be migrated to:
521+
# - ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]
522+
# - ui_model_type=[ModelType.Main]
523+
#
524+
# 2. ui_type was set in conjunction with any of the new ui_model_[base|type|variant|format] fields, which
525+
# is not allowed (they are mutually exclusive). In this case, we ignore ui_type and log a warning.
526+
#
527+
# 3. ui_type is a deprecated value that is not migratable. For example, ui_type=UIType.Image is deprecated;
528+
# Image fields are now automatically detected based on the field's type annotation. In this case, we
529+
# ignore ui_type and log a warning.
530+
#
531+
# The cases must be checked in this order to ensure proper handling.
532+
533+
# Easier to work with as an enum
534+
ui_type = UIType(ui_type)
535+
536+
# The enum member values are not always the same as their names - we want to log the name so the user can
537+
# easily review their code and see where the deprecated enum member is used.
538+
human_readable_name = f"UIType.{ui_type.name}"
539+
540+
# Case 1: migratable deprecated value
541+
did_migrate = migrate_model_ui_type(ui_type, field.json_schema_extra)
542+
543+
if did_migrate:
544+
logger.warning(
545+
f'{model_type}.{name}: Migrated deprecated "ui_type" "{human_readable_name}" to new ui_model_[base|type|variant|format] fields'
546+
)
547+
field.json_schema_extra.pop("ui_type")
548+
549+
# Case 2: mutually exclusive with new fields
550+
elif (
551+
ui_model_base is not None
552+
or ui_model_type is not None
553+
or ui_model_variant is not None
554+
or ui_model_format is not None
555+
):
556+
logger.warning(
557+
f'{model_type}.{name}: "ui_type" is mutually exclusive with "ui_model_[base|type|format|variant]", ignoring "ui_type"'
558+
)
559+
field.json_schema_extra.pop("ui_type")
560+
561+
# Case 3: deprecated value that is not migratable
562+
elif ui_type.startswith("DEPRECATED_"):
563+
logger.warning(f'{model_type}.{name}: Deprecated "ui_type" "{human_readable_name}", ignoring')
564+
field.json_schema_extra.pop("ui_type")
565+
504566
return None
505567

506568

invokeai/app/invocations/cogview4_model_loader.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
invocation,
66
invocation_output,
77
)
8-
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
8+
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
99
from invokeai.app.invocations.model import (
1010
GlmEncoderField,
1111
ModelIdentifierField,
@@ -14,6 +14,7 @@
1414
)
1515
from invokeai.app.services.shared.invocation_context import InvocationContext
1616
from invokeai.backend.model_manager.config import SubModelType
17+
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
1718

1819

1920
@invocation_output("cogview4_model_loader_output")
@@ -38,8 +39,9 @@ class CogView4ModelLoaderInvocation(BaseInvocation):
3839

3940
model: ModelIdentifierField = InputField(
4041
description=FieldDescriptions.cogview4_model,
41-
ui_type=UIType.CogView4MainModel,
4242
input=Input.Direct,
43+
ui_model_base=BaseModelType.CogView4,
44+
ui_model_type=ModelType.Main,
4345
)
4446

4547
def invoke(self, context: InvocationContext) -> CogView4ModelLoaderOutput:

invokeai/app/invocations/controlnet.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
ImageField,
1717
InputField,
1818
OutputField,
19-
UIType,
2019
)
2120
from invokeai.app.invocations.model import ModelIdentifierField
2221
from invokeai.app.invocations.primitives import ImageOutput
@@ -28,6 +27,7 @@
2827
heuristic_resize_fast,
2928
)
3029
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
30+
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
3131

3232

3333
class ControlField(BaseModel):
@@ -63,13 +63,17 @@ class ControlOutput(BaseInvocationOutput):
6363
control: ControlField = OutputField(description=FieldDescriptions.control)
6464

6565

66-
@invocation("controlnet", title="ControlNet - SD1.5, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3")
66+
@invocation(
67+
"controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3"
68+
)
6769
class ControlNetInvocation(BaseInvocation):
6870
"""Collects ControlNet info to pass to other nodes"""
6971

7072
image: ImageField = InputField(description="The control image")
7173
control_model: ModelIdentifierField = InputField(
72-
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
74+
description=FieldDescriptions.controlnet_model,
75+
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2, BaseModelType.StableDiffusionXL],
76+
ui_model_type=ModelType.ControlNet,
7377
)
7478
control_weight: Union[float, List[float]] = InputField(
7579
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"

0 commit comments

Comments
 (0)