@@ -1,72 +1,75 @@
-from enum import Enum
+from itertools import zip_longest
 from pathlib import Path
 from typing import Literal
 
 import numpy as np
 import torch
 from PIL import Image
-from pydantic import BaseModel, Field
-from transformers import AutoProcessor
+from pydantic import BaseModel, Field, model_validator
 from transformers.models.sam import SamModel
 from transformers.models.sam.processing_sam import SamProcessor
+from transformers.models.sam2 import Sam2Model
+from transformers.models.sam2.processing_sam2 import Sam2Processor
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
 from invokeai.app.invocations.primitives import MaskOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
+from invokeai.backend.image_util.segment_anything.segment_anything_2_pipeline import SegmentAnything2Pipeline
 from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
-
-SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
+from invokeai.backend.image_util.segment_anything.shared import SAMInput, SAMPoint
+
+SegmentAnythingModelKey = Literal[
+    "segment-anything-base",
+    "segment-anything-large",
+    "segment-anything-huge",
+    "segment-anything-2-tiny",
+    "segment-anything-2-small",
+    "segment-anything-2-base",
+    "segment-anything-2-large",
+]
 SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
     "segment-anything-base": "facebook/sam-vit-base",
     "segment-anything-large": "facebook/sam-vit-large",
     "segment-anything-huge": "facebook/sam-vit-huge",
+    "segment-anything-2-tiny": "facebook/sam2.1-hiera-tiny",
+    "segment-anything-2-small": "facebook/sam2.1-hiera-small",
+    "segment-anything-2-base": "facebook/sam2.1-hiera-base-plus",
+    "segment-anything-2-large": "facebook/sam2.1-hiera-large",
 }
 
 
-class SAMPointLabel(Enum):
-    negative = -1
-    neutral = 0
-    positive = 1
-
-
-class SAMPoint(BaseModel):
-    x: int = Field(..., description="The x-coordinate of the point")
-    y: int = Field(..., description="The y-coordinate of the point")
-    label: SAMPointLabel = Field(..., description="The label of the point")
-
-
 class SAMPointsField(BaseModel):
-    points: list[SAMPoint] = Field(..., description="The points of the object")
+    points: list[SAMPoint] = Field(..., description="The points of the object", min_length=1)
 
-    def to_list(self) -> list[list[int]]:
+    def to_list(self) -> list[list[float]]:
         return [[point.x, point.y, point.label.value] for point in self.points]
 
 
 @invocation(
     "segment_anything",
     title="Segment Anything",
-    tags=["prompt", "segmentation"],
+    tags=["prompt", "segmentation", "sam", "sam2"],
     category="segmentation",
-    version="1.2.0",
+    version="1.3.0",
 )
 class SegmentAnythingInvocation(BaseInvocation):
-    """Runs a Segment Anything Model."""
+    """Runs a Segment Anything Model (SAM or SAM2)."""
 
     # Reference:
     # - https://arxiv.org/pdf/2304.02643
     # - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
     # - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
 
-    model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
+    model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use (SAM or SAM2).")
     image: ImageField = InputField(description="The image to segment.")
     bounding_boxes: list[BoundingBoxField] | None = InputField(
-        default=None, description="The bounding boxes to prompt the SAM model with."
+        default=None, description="The bounding boxes to prompt the model with."
     )
     point_lists: list[SAMPointsField] | None = InputField(
         default=None,
-        description="The list of point lists to prompt the SAM model with. Each list of points represents a single object.",
+        description="The list of point lists to prompt the model with. Each list of points represents a single object.",
     )
     apply_polygon_refinement: bool = InputField(
         description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
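In the hunk above, `SAMPoint` and its label enum move out of this file into the shared module, and `to_list()` now advertises float coordinates. As a minimal sketch of what `to_list()` produces, assuming the relocated `SAMPoint` and `SAMPointLabel` keep the shape they had here (x, y, and a label valued -1/0/1; the `SAMPointLabel` import from `shared` is an assumption, since this diff only imports `SAMInput` and `SAMPoint`):

    # Assumes SAMPointLabel also lives in the shared module (not shown in this diff).
    from invokeai.backend.image_util.segment_anything.shared import SAMPoint, SAMPointLabel

    field = SAMPointsField(
        points=[
            SAMPoint(x=120, y=85, label=SAMPointLabel.positive),
            SAMPoint(x=40, y=200, label=SAMPointLabel.negative),
        ]
    )
    # Each point flattens to [x, y, label.value]:
    print(field.to_list())  # [[120, 85, 1], [40, 200, -1]]

The new `min_length=1` constraint also means an empty `points` list is now rejected by pydantic at validation time instead of silently producing an empty prompt.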
@@ -77,14 +80,18 @@ class SegmentAnythingInvocation(BaseInvocation):
         default="all",
     )
 
+    @model_validator(mode="after")
+    def validate_points_and_boxes_len(self):
+        if self.point_lists is not None and self.bounding_boxes is not None:
+            if len(self.point_lists) != len(self.bounding_boxes):
+                raise ValueError("If both point_lists and bounding_boxes are provided, they must have the same length.")
+        return self
+
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> MaskOutput:
         # The models expect a 3-channel RGB image.
         image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
 
-        if self.point_lists is not None and self.bounding_boxes is not None:
-            raise ValueError("Only one of point_lists or bounding_box can be provided.")
-
         if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
             not self.point_lists or len(self.point_lists) == 0
         ):
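The `model_validator` added above replaces the old either/or check deleted in this hunk: bounding boxes and point lists may now be supplied together, as long as they pair up one-to-one per object. The pairing itself happens in `_segment()` (next hunk) via `zip_longest`. A quick stdlib illustration of that pairing, with stand-in values:

    from itertools import zip_longest

    boxes = ["box_a", "box_b"]  # stand-ins for BoundingBoxField values
    point_lists = []            # no point prompts supplied for this invocation

    # Mirrors the loop in _segment(): each object becomes a
    # (bounding_box, points) pair, with None filling whichever
    # prompt type was not supplied.
    print(list(zip_longest(boxes, point_lists, fillvalue=None)))
    # [('box_a', None), ('box_b', None)]

When both lists are present, the validator guarantees equal lengths, so `zip_longest` degenerates to a plain pairwise zip.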
@@ -111,26 +118,38 @@ def _load_sam_model(model_path: Path):
             # model, and figure out how to make it work in the pipeline.
             # torch_dtype=TorchDevice.choose_torch_dtype(),
         )
-
-        sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
-        assert isinstance(sam_processor, SamProcessor)
+        sam_processor = SamProcessor.from_pretrained(model_path, local_files_only=True)
         return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
 
+    @staticmethod
+    def _load_sam_2_model(model_path: Path):
+        sam2_model = Sam2Model.from_pretrained(model_path, local_files_only=True)
+        sam2_processor = Sam2Processor.from_pretrained(model_path, local_files_only=True)
+        return SegmentAnything2Pipeline(sam2_model=sam2_model, sam2_processor=sam2_processor)
+
     def _segment(self, context: InvocationContext, image: Image.Image) -> list[torch.Tensor]:
-        """Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
-        # Convert the bounding boxes to the SAM input format.
-        sam_bounding_boxes = (
-            [[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes] if self.bounding_boxes else None
-        )
-        sam_points = [p.to_list() for p in self.point_lists] if self.point_lists else None
+        """Use Segment Anything (SAM or SAM2) to generate masks given an image + a set of prompts (bounding boxes and/or points)."""
+
+        source = SEGMENT_ANYTHING_MODEL_IDS[self.model]
+        inputs: list[SAMInput] = []
+        for bbox_field, point_field in zip_longest(self.bounding_boxes or [], self.point_lists or [], fillvalue=None):
+            inputs.append(
+                SAMInput(
+                    bounding_box=bbox_field,
+                    points=point_field.points if point_field else None,
+                )
+            )
 
-        with (
-            context.models.load_remote_model(
-                source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
-            ) as sam_pipeline,
-        ):
-            assert isinstance(sam_pipeline, SegmentAnythingPipeline)
-            masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes, point_lists=sam_points)
+        if "sam2" in source:
+            loader = SegmentAnythingInvocation._load_sam_2_model
+            with context.models.load_remote_model(source=source, loader=loader) as pipeline:
+                assert isinstance(pipeline, SegmentAnything2Pipeline)
+                masks = pipeline.segment(image=image, inputs=inputs)
+        else:
+            loader = SegmentAnythingInvocation._load_sam_model
+            with context.models.load_remote_model(source=source, loader=loader) as pipeline:
+                assert isinstance(pipeline, SegmentAnythingPipeline)
+                masks = pipeline.segment(image=image, inputs=inputs)
 
         masks = self._process_masks(masks)
         if self.apply_polygon_refinement:
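One detail worth noting in the dispatch above: the `"sam2" in source` test keys off the Hugging Face repo id, not the invocation's model key (the SAM2 keys contain `anything-2`, not `sam2`, while the repo ids contain `sam2.1`). A small sketch of how the mapping and loader selection play out:

    # Sketch only; the keys and repo ids come from SEGMENT_ANYTHING_MODEL_IDS above.
    for key in ("segment-anything-huge", "segment-anything-2-large"):
        source = SEGMENT_ANYTHING_MODEL_IDS[key]
        loader = (
            SegmentAnythingInvocation._load_sam_2_model
            if "sam2" in source
            else SegmentAnythingInvocation._load_sam_model
        )
        print(key, "->", source, "->", loader.__name__)
    # segment-anything-huge -> facebook/sam-vit-huge -> _load_sam_model
    # segment-anything-2-large -> facebook/sam2.1-hiera-large -> _load_sam_2_model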