36
36
from invokeai .app .invocations .fields import (
37
37
FieldKind ,
38
38
Input ,
39
+ InputFieldJSONSchemaExtra ,
40
+ UIType ,
41
+ migrate_model_ui_type ,
39
42
)
40
43
from invokeai .app .services .config .config_default import get_config
41
44
from invokeai .app .services .shared .invocation_context import InvocationContext
@@ -256,7 +259,9 @@ def invoke_internal(self, context: InvocationContext, services: "InvocationServi
256
259
is_intermediate : bool = Field (
257
260
default = False ,
258
261
description = "Whether or not this is an intermediate invocation." ,
259
- json_schema_extra = {"ui_type" : "IsIntermediate" , "field_kind" : FieldKind .NodeAttribute },
262
+ json_schema_extra = InputFieldJSONSchemaExtra (
263
+ input = Input .Direct , field_kind = FieldKind .NodeAttribute , ui_type = UIType ._IsIntermediate
264
+ ).model_dump (exclude_none = True ),
260
265
)
261
266
use_cache : bool = Field (
262
267
default = True ,
@@ -445,6 +450,15 @@ class _Model(BaseModel):
445
450
RESERVED_PYDANTIC_FIELD_NAMES = {m [0 ] for m in inspect .getmembers (_Model ())}
446
451
447
452
453
+ def is_enum_member (value : Any , enum_class : type [Enum ]) -> bool :
454
+ """Checks if a value is a member of an enum class."""
455
+ try :
456
+ enum_class (value )
457
+ return True
458
+ except ValueError :
459
+ return False
460
+
461
+
448
462
def validate_fields (model_fields : dict [str , FieldInfo ], model_type : str ) -> None :
449
463
"""
450
464
Validates the fields of an invocation or invocation output:
@@ -456,51 +470,99 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
456
470
"""
457
471
for name , field in model_fields .items ():
458
472
if name in RESERVED_PYDANTIC_FIELD_NAMES :
459
- raise InvalidFieldError (f'Invalid field name " { name } " on " { model_type } " (reserved by pydantic)' )
473
+ raise InvalidFieldError (f" { model_type } . { name } : Invalid field name (reserved by pydantic)" )
460
474
461
475
if not field .annotation :
462
- raise InvalidFieldError (f'Invalid field type " { name } " on " { model_type } " (missing annotation)' )
476
+ raise InvalidFieldError (f" { model_type } . { name } : Invalid field type (missing annotation)" )
463
477
464
478
if not isinstance (field .json_schema_extra , dict ):
465
- raise InvalidFieldError (
466
- f'Invalid field definition for "{ name } " on "{ model_type } " (missing json_schema_extra dict)'
467
- )
479
+ raise InvalidFieldError (f"{ model_type } .{ name } : Invalid field definition (missing json_schema_extra dict)" )
468
480
469
481
field_kind = field .json_schema_extra .get ("field_kind" , None )
470
482
471
483
# must have a field_kind
472
- if not isinstance (field_kind , FieldKind ):
484
+ if not is_enum_member (field_kind , FieldKind ):
473
485
raise InvalidFieldError (
474
- f' Invalid field definition for " { name } " on " { model_type } " (maybe it\ ' s not an InputField or OutputField?)'
486
+ f" { model_type } . { name } : Invalid field definition for (maybe it's not an InputField or OutputField?)"
475
487
)
476
488
477
- if field_kind is FieldKind .Input and (
489
+ if field_kind == FieldKind .Input . value and (
478
490
name in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES or name in RESERVED_INPUT_FIELD_NAMES
479
491
):
480
- raise InvalidFieldError (f'Invalid field name " { name } " on " { model_type } " (reserved input field name)' )
492
+ raise InvalidFieldError (f" { model_type } . { name } : Invalid field name (reserved input field name)" )
481
493
482
- if field_kind is FieldKind .Output and name in RESERVED_OUTPUT_FIELD_NAMES :
483
- raise InvalidFieldError (f'Invalid field name " { name } " on " { model_type } " (reserved output field name)' )
494
+ if field_kind == FieldKind .Output . value and name in RESERVED_OUTPUT_FIELD_NAMES :
495
+ raise InvalidFieldError (f" { model_type } . { name } : Invalid field name (reserved output field name)" )
484
496
485
- if (field_kind is FieldKind .Internal ) and name not in RESERVED_INPUT_FIELD_NAMES :
486
- raise InvalidFieldError (
487
- f'Invalid field name "{ name } " on "{ model_type } " (internal field without reserved name)'
488
- )
497
+ if field_kind == FieldKind .Internal .value and name not in RESERVED_INPUT_FIELD_NAMES :
498
+ raise InvalidFieldError (f"{ model_type } .{ name } : Invalid field name (internal field without reserved name)" )
489
499
490
500
# node attribute fields *must* be in the reserved list
491
501
if (
492
- field_kind is FieldKind .NodeAttribute
502
+ field_kind == FieldKind .NodeAttribute . value
493
503
and name not in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES
494
504
and name not in RESERVED_OUTPUT_FIELD_NAMES
495
505
):
496
506
raise InvalidFieldError (
497
- f'Invalid field name " { name } " on " { model_type } " (node attribute field without reserved name)'
507
+ f" { model_type } . { name } : Invalid field name (node attribute field without reserved name)"
498
508
)
499
509
500
510
ui_type = field .json_schema_extra .get ("ui_type" , None )
501
- if isinstance (ui_type , str ) and ui_type .startswith ("DEPRECATED_" ):
502
- logger .warning (f'"UIType.{ ui_type .split ("_" )[- 1 ]} " is deprecated, ignoring' )
503
- field .json_schema_extra .pop ("ui_type" )
511
+ ui_model_base = field .json_schema_extra .get ("ui_model_base" , None )
512
+ ui_model_type = field .json_schema_extra .get ("ui_model_type" , None )
513
+ ui_model_variant = field .json_schema_extra .get ("ui_model_variant" , None )
514
+ ui_model_format = field .json_schema_extra .get ("ui_model_format" , None )
515
+
516
+ if ui_type is not None :
517
+ # There are 3 cases where we may need to take action:
518
+ #
519
+ # 1. The ui_type is a migratable, deprecated value. For example, ui_type=UIType.MainModel value is
520
+ # deprecated and should be migrated to:
521
+ # - ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]
522
+ # - ui_model_type=[ModelType.Main]
523
+ #
524
+ # 2. ui_type was set in conjunction with any of the new ui_model_[base|type|variant|format] fields, which
525
+ # is not allowed (they are mutually exclusive). In this case, we ignore ui_type and log a warning.
526
+ #
527
+ # 3. ui_type is a deprecated value that is not migratable. For example, ui_type=UIType.Image is deprecated;
528
+ # Image fields are now automatically detected based on the field's type annotation. In this case, we
529
+ # ignore ui_type and log a warning.
530
+ #
531
+ # The cases must be checked in this order to ensure proper handling.
532
+
533
+ # Easier to work with as an enum
534
+ ui_type = UIType (ui_type )
535
+
536
+ # The enum member values are not always the same as their names - we want to log the name so the user can
537
+ # easily review their code and see where the deprecated enum member is used.
538
+ human_readable_name = f"UIType.{ ui_type .name } "
539
+
540
+ # Case 1: migratable deprecated value
541
+ did_migrate = migrate_model_ui_type (ui_type , field .json_schema_extra )
542
+
543
+ if did_migrate :
544
+ logger .warning (
545
+ f'{ model_type } .{ name } : Migrated deprecated "ui_type" "{ human_readable_name } " to new ui_model_[base|type|variant|format] fields'
546
+ )
547
+ field .json_schema_extra .pop ("ui_type" )
548
+
549
+ # Case 2: mutually exclusive with new fields
550
+ elif (
551
+ ui_model_base is not None
552
+ or ui_model_type is not None
553
+ or ui_model_variant is not None
554
+ or ui_model_format is not None
555
+ ):
556
+ logger .warning (
557
+ f'{ model_type } .{ name } : "ui_type" is mutually exclusive with "ui_model_[base|type|format|variant]", ignoring "ui_type"'
558
+ )
559
+ field .json_schema_extra .pop ("ui_type" )
560
+
561
+ # Case 3: deprecated value that is not migratable
562
+ elif ui_type .startswith ("DEPRECATED_" ):
563
+ logger .warning (f'{ model_type } .{ name } : Deprecated "ui_type" "{ human_readable_name } ", ignoring' )
564
+ field .json_schema_extra .pop ("ui_type" )
565
+
504
566
return None
505
567
506
568
0 commit comments