Skip to content

Commit ffe5e9c

Browse files
authored
Merge branch 'main' into fix_rocm_documentation
2 parents 851d44e + b0aa48d commit ffe5e9c

File tree

86 files changed

+5200
-2220
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

86 files changed

+5200
-2220
lines changed

invokeai/app/api/routers/session_queue.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from invokeai.app.api.dependencies import ApiDependencies
88
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
99
from invokeai.app.services.session_queue.session_queue_common import (
10-
QUEUE_ORDER_BY,
1110
Batch,
1211
BatchStatus,
1312
CancelAllExceptCurrentResult,
@@ -92,21 +91,18 @@ async def list_all_queue_items(
9291

9392
@session_queue_router.get(
9493
"/{queue_id}/item_ids",
95-
operation_id="get_queue_itemIds",
94+
operation_id="get_queue_item_ids",
9695
responses={
9796
200: {"model": ItemIdsResult},
9897
},
9998
)
10099
async def get_queue_item_ids(
101100
queue_id: str = Path(description="The queue id to perform this operation on"),
102-
order_by: QUEUE_ORDER_BY = Query(default="created_at", description="The sort field"),
103101
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
104102
) -> ItemIdsResult:
105103
"""Gets all queue item ids that match the given parameters"""
106104
try:
107-
return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(
108-
queue_id=queue_id, order_by=order_by, order_dir=order_dir
109-
)
105+
return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(queue_id=queue_id, order_dir=order_dir)
110106
except Exception as e:
111107
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}")
112108

@@ -130,7 +126,9 @@ async def get_queue_items_by_item_ids(
130126
queue_items: list[SessionQueueItem] = []
131127
for item_id in item_ids:
132128
try:
133-
queue_item = session_queue_service.get_queue_item(item_id)
129+
queue_item = session_queue_service.get_queue_item(item_id=item_id)
130+
if queue_item.queue_id != queue_id: # Auth protection for items from other queues
131+
continue
134132
queue_items.append(queue_item)
135133
except Exception:
136134
# Skip missing queue items - they may have been deleted between item id fetch and queue item fetch
@@ -376,7 +374,10 @@ async def get_queue_item(
376374
) -> SessionQueueItem:
377375
"""Gets a queue item"""
378376
try:
379-
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
377+
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(item_id=item_id)
378+
if queue_item.queue_id != queue_id:
379+
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
380+
return queue_item
380381
except SessionQueueItemNotFoundError:
381382
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
382383
except Exception as e:

invokeai/app/invocations/fields.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from enum import Enum
22
from typing import Any, Callable, Optional, Tuple
33

4-
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, model_validator
4+
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
55
from pydantic.fields import _Unset
66
from pydantic_core import PydanticUndefined
77

88
from invokeai.app.util.metaenum import MetaEnum
9+
from invokeai.backend.image_util.segment_anything.shared import BoundingBox
910
from invokeai.backend.util.logging import InvokeAILogger
1011

1112
logger = InvokeAILogger.get_logger()
@@ -331,14 +332,9 @@ class ConditioningField(BaseModel):
331332
)
332333

333334

334-
class BoundingBoxField(BaseModel):
335+
class BoundingBoxField(BoundingBox):
335336
"""A bounding box primitive value."""
336337

337-
x_min: int = Field(ge=0, description="The minimum x-coordinate of the bounding box (inclusive).")
338-
x_max: int = Field(ge=0, description="The maximum x-coordinate of the bounding box (exclusive).")
339-
y_min: int = Field(ge=0, description="The minimum y-coordinate of the bounding box (inclusive).")
340-
y_max: int = Field(ge=0, description="The maximum y-coordinate of the bounding box (exclusive).")
341-
342338
score: Optional[float] = Field(
343339
default=None,
344340
ge=0.0,
@@ -347,21 +343,6 @@ class BoundingBoxField(BaseModel):
347343
"when the bounding box was produced by a detector and has an associated confidence score.",
348344
)
349345

350-
@model_validator(mode="after")
351-
def check_coords(self):
352-
if self.x_min > self.x_max:
353-
raise ValueError(f"x_min ({self.x_min}) is greater than x_max ({self.x_max}).")
354-
if self.y_min > self.y_max:
355-
raise ValueError(f"y_min ({self.y_min}) is greater than y_max ({self.y_max}).")
356-
return self
357-
358-
def tuple(self) -> Tuple[int, int, int, int]:
359-
"""
360-
Returns the bounding box as a tuple suitable for use with PIL's `Image.crop()` method.
361-
This method returns a tuple of the form (left, upper, right, lower) == (x_min, y_min, x_max, y_max).
362-
"""
363-
return (self.x_min, self.y_min, self.x_max, self.y_max)
364-
365346

366347
class MetadataField(RootModel[dict[str, Any]]):
367348
"""

invokeai/app/invocations/segment_anything.py

Lines changed: 63 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,75 @@
1-
from enum import Enum
1+
from itertools import zip_longest
22
from pathlib import Path
33
from typing import Literal
44

55
import numpy as np
66
import torch
77
from PIL import Image
8-
from pydantic import BaseModel, Field
9-
from transformers import AutoProcessor
8+
from pydantic import BaseModel, Field, model_validator
109
from transformers.models.sam import SamModel
1110
from transformers.models.sam.processing_sam import SamProcessor
11+
from transformers.models.sam2 import Sam2Model
12+
from transformers.models.sam2.processing_sam2 import Sam2Processor
1213

1314
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
1415
from invokeai.app.invocations.fields import BoundingBoxField, ImageField, InputField, TensorField
1516
from invokeai.app.invocations.primitives import MaskOutput
1617
from invokeai.app.services.shared.invocation_context import InvocationContext
1718
from invokeai.backend.image_util.segment_anything.mask_refinement import mask_to_polygon, polygon_to_mask
19+
from invokeai.backend.image_util.segment_anything.segment_anything_2_pipeline import SegmentAnything2Pipeline
1820
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
19-
20-
SegmentAnythingModelKey = Literal["segment-anything-base", "segment-anything-large", "segment-anything-huge"]
21+
from invokeai.backend.image_util.segment_anything.shared import SAMInput, SAMPoint
22+
23+
SegmentAnythingModelKey = Literal[
24+
"segment-anything-base",
25+
"segment-anything-large",
26+
"segment-anything-huge",
27+
"segment-anything-2-tiny",
28+
"segment-anything-2-small",
29+
"segment-anything-2-base",
30+
"segment-anything-2-large",
31+
]
2132
SEGMENT_ANYTHING_MODEL_IDS: dict[SegmentAnythingModelKey, str] = {
2233
"segment-anything-base": "facebook/sam-vit-base",
2334
"segment-anything-large": "facebook/sam-vit-large",
2435
"segment-anything-huge": "facebook/sam-vit-huge",
36+
"segment-anything-2-tiny": "facebook/sam2.1-hiera-tiny",
37+
"segment-anything-2-small": "facebook/sam2.1-hiera-small",
38+
"segment-anything-2-base": "facebook/sam2.1-hiera-base-plus",
39+
"segment-anything-2-large": "facebook/sam2.1-hiera-large",
2540
}
2641

2742

28-
class SAMPointLabel(Enum):
29-
negative = -1
30-
neutral = 0
31-
positive = 1
32-
33-
34-
class SAMPoint(BaseModel):
35-
x: int = Field(..., description="The x-coordinate of the point")
36-
y: int = Field(..., description="The y-coordinate of the point")
37-
label: SAMPointLabel = Field(..., description="The label of the point")
38-
39-
4043
class SAMPointsField(BaseModel):
41-
points: list[SAMPoint] = Field(..., description="The points of the object")
44+
points: list[SAMPoint] = Field(..., description="The points of the object", min_length=1)
4245

43-
def to_list(self) -> list[list[int]]:
46+
def to_list(self) -> list[list[float]]:
4447
return [[point.x, point.y, point.label.value] for point in self.points]
4548

4649

4750
@invocation(
4851
"segment_anything",
4952
title="Segment Anything",
50-
tags=["prompt", "segmentation"],
53+
tags=["prompt", "segmentation", "sam", "sam2"],
5154
category="segmentation",
52-
version="1.2.0",
55+
version="1.3.0",
5356
)
5457
class SegmentAnythingInvocation(BaseInvocation):
55-
"""Runs a Segment Anything Model."""
58+
"""Runs a Segment Anything Model (SAM or SAM2)."""
5659

5760
# Reference:
5861
# - https://arxiv.org/pdf/2304.02643
5962
# - https://huggingface.co/docs/transformers/v4.43.3/en/model_doc/grounding-dino#grounded-sam
6063
# - https://github.com/NielsRogge/Transformers-Tutorials/blob/a39f33ac1557b02ebfb191ea7753e332b5ca933f/Grounding%20DINO/GroundingDINO_with_Segment_Anything.ipynb
6164

62-
model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use.")
65+
model: SegmentAnythingModelKey = InputField(description="The Segment Anything model to use (SAM or SAM2).")
6366
image: ImageField = InputField(description="The image to segment.")
6467
bounding_boxes: list[BoundingBoxField] | None = InputField(
65-
default=None, description="The bounding boxes to prompt the SAM model with."
68+
default=None, description="The bounding boxes to prompt the model with."
6669
)
6770
point_lists: list[SAMPointsField] | None = InputField(
6871
default=None,
69-
description="The list of point lists to prompt the SAM model with. Each list of points represents a single object.",
72+
description="The list of point lists to prompt the model with. Each list of points represents a single object.",
7073
)
7174
apply_polygon_refinement: bool = InputField(
7275
description="Whether to apply polygon refinement to the masks. This will smooth the edges of the masks slightly and ensure that each mask consists of a single closed polygon (before merging).",
@@ -77,14 +80,18 @@ class SegmentAnythingInvocation(BaseInvocation):
7780
default="all",
7881
)
7982

83+
@model_validator(mode="after")
84+
def validate_points_and_boxes_len(self):
85+
if self.point_lists is not None and self.bounding_boxes is not None:
86+
if len(self.point_lists) != len(self.bounding_boxes):
87+
raise ValueError("If both point_lists and bounding_boxes are provided, they must have the same length.")
88+
return self
89+
8090
@torch.no_grad()
8191
def invoke(self, context: InvocationContext) -> MaskOutput:
8292
# The models expect a 3-channel RGB image.
8393
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
8494

85-
if self.point_lists is not None and self.bounding_boxes is not None:
86-
raise ValueError("Only one of point_lists or bounding_box can be provided.")
87-
8895
if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
8996
not self.point_lists or len(self.point_lists) == 0
9097
):
@@ -111,26 +118,38 @@ def _load_sam_model(model_path: Path):
111118
# model, and figure out how to make it work in the pipeline.
112119
# torch_dtype=TorchDevice.choose_torch_dtype(),
113120
)
114-
115-
sam_processor = AutoProcessor.from_pretrained(model_path, local_files_only=True)
116-
assert isinstance(sam_processor, SamProcessor)
121+
sam_processor = SamProcessor.from_pretrained(model_path, local_files_only=True)
117122
return SegmentAnythingPipeline(sam_model=sam_model, sam_processor=sam_processor)
118123

124+
@staticmethod
125+
def _load_sam_2_model(model_path: Path):
126+
sam2_model = Sam2Model.from_pretrained(model_path, local_files_only=True)
127+
sam2_processor = Sam2Processor.from_pretrained(model_path, local_files_only=True)
128+
return SegmentAnything2Pipeline(sam2_model=sam2_model, sam2_processor=sam2_processor)
129+
119130
def _segment(self, context: InvocationContext, image: Image.Image) -> list[torch.Tensor]:
120-
"""Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes."""
121-
# Convert the bounding boxes to the SAM input format.
122-
sam_bounding_boxes = (
123-
[[bb.x_min, bb.y_min, bb.x_max, bb.y_max] for bb in self.bounding_boxes] if self.bounding_boxes else None
124-
)
125-
sam_points = [p.to_list() for p in self.point_lists] if self.point_lists else None
131+
"""Use Segment Anything (SAM or SAM2) to generate masks given an image + a set of bounding boxes."""
132+
133+
source = SEGMENT_ANYTHING_MODEL_IDS[self.model]
134+
inputs: list[SAMInput] = []
135+
for bbox_field, point_field in zip_longest(self.bounding_boxes or [], self.point_lists or [], fillvalue=None):
136+
inputs.append(
137+
SAMInput(
138+
bounding_box=bbox_field,
139+
points=point_field.points if point_field else None,
140+
)
141+
)
126142

127-
with (
128-
context.models.load_remote_model(
129-
source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
130-
) as sam_pipeline,
131-
):
132-
assert isinstance(sam_pipeline, SegmentAnythingPipeline)
133-
masks = sam_pipeline.segment(image=image, bounding_boxes=sam_bounding_boxes, point_lists=sam_points)
143+
if "sam2" in source:
144+
loader = SegmentAnythingInvocation._load_sam_2_model
145+
with context.models.load_remote_model(source=source, loader=loader) as pipeline:
146+
assert isinstance(pipeline, SegmentAnything2Pipeline)
147+
masks = pipeline.segment(image=image, inputs=inputs)
148+
else:
149+
loader = SegmentAnythingInvocation._load_sam_model
150+
with context.models.load_remote_model(source=source, loader=loader) as pipeline:
151+
assert isinstance(pipeline, SegmentAnythingPipeline)
152+
masks = pipeline.segment(image=image, inputs=inputs)
134153

135154
masks = self._process_masks(masks)
136155
if self.apply_polygon_refinement:

invokeai/app/services/bulk_download/bulk_download_default.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,4 +150,15 @@ def get_path(self, bulk_download_item_name: str) -> str:
150150
def _is_valid_path(self, path: Union[str, Path]) -> bool:
151151
"""Validates the path given for a bulk download."""
152152
path = path if isinstance(path, Path) else Path(path)
153-
return path.exists()
153+
154+
# Resolve the path to handle any path traversal attempts (e.g., ../)
155+
resolved_path = path.resolve()
156+
157+
# The path may not traverse out of the bulk downloads folder or its subfolders
158+
does_not_traverse = resolved_path.parent == self._bulk_downloads_folder.resolve()
159+
160+
# The path must exist and be a .zip file
161+
does_exist = resolved_path.exists()
162+
is_zip_file = resolved_path.suffix == ".zip"
163+
164+
return does_exist and is_zip_file and does_not_traverse

invokeai/app/services/session_queue/session_queue_base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import Any, Coroutine, Optional
33

44
from invokeai.app.services.session_queue.session_queue_common import (
5-
QUEUE_ORDER_BY,
65
Batch,
76
BatchStatus,
87
CancelAllExceptCurrentResult,
@@ -149,15 +148,14 @@ def list_all_queue_items(
149148
def get_queue_item_ids(
150149
self,
151150
queue_id: str,
152-
order_by: QUEUE_ORDER_BY = "created_at",
153151
order_dir: SQLiteDirection = SQLiteDirection.Descending,
154152
) -> ItemIdsResult:
155153
"""Gets all queue item ids that match the given parameters"""
156154
pass
157155

158156
@abstractmethod
159157
def get_queue_item(self, item_id: int) -> SessionQueueItem:
160-
"""Gets a session queue item by ID"""
158+
"""Gets a session queue item by ID for a given queue"""
161159
pass
162160

163161
@abstractmethod

invokeai/app/services/session_queue/session_queue_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,6 @@ def validate_graph(cls, v: Graph):
174174

175175
DEFAULT_QUEUE_ID = "default"
176176

177-
QUEUE_ORDER_BY = Literal["created_at", "completed_at"]
178177
QUEUE_ITEM_STATUS = Literal["pending", "in_progress", "completed", "failed", "canceled"]
179178

180179

invokeai/app/services/session_queue/session_queue_sqlite.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from invokeai.app.services.session_queue.session_queue_common import (
1111
DEFAULT_QUEUE_ID,
1212
QUEUE_ITEM_STATUS,
13-
QUEUE_ORDER_BY,
1413
Batch,
1514
BatchStatus,
1615
CancelAllExceptCurrentResult,
@@ -623,15 +622,14 @@ def list_all_queue_items(
623622
def get_queue_item_ids(
624623
self,
625624
queue_id: str,
626-
order_by: QUEUE_ORDER_BY = "created_at",
627625
order_dir: SQLiteDirection = SQLiteDirection.Descending,
628626
) -> ItemIdsResult:
629627
with self._db.transaction() as cursor_:
630628
query = f"""--sql
631629
SELECT item_id
632630
FROM session_queue
633631
WHERE queue_id = ?
634-
ORDER BY {order_by} {order_dir.value}
632+
ORDER BY created_at {order_dir.value}
635633
"""
636634
query_params = [queue_id]
637635

0 commit comments

Comments
 (0)