Commit 17bf2a7

tidy(backend) cleanup sam pipelines
Parent: 17cc219

File tree

2 files changed: +14 −22 lines


invokeai/backend/image_util/segment_anything/segment_anything_2_pipeline.py

Lines changed: 4 additions & 10 deletions

@@ -41,25 +41,19 @@ def segment(
         image: Image.Image,
         inputs: list[SAMInput],
     ) -> torch.Tensor:
-        """Segment an image using the SAM2 model.
-
-        Either bounding_boxes or point_lists must be provided. If both are provided, bounding_boxes will be used and
-        point_lists will be ignored.
+        """Segment the image using the provided inputs.
 
         Args:
-            image (Image.Image): The image to segment.
-            bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
-                [xmin, ymin, xmax, ymax].
-            point_lists (list[list[list[int]]]): The points prompts. Each point is in the format [x, y, label].
-                `label` is an integer where -1 is background, 0 is neutral, and 1 is foreground.
+            image: The image to segment.
+            inputs: A list of SAMInput objects containing bounding boxes and/or point lists.
 
         Returns:
             torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
         """
+        input_boxes: list[list[float]] = []
         input_points: list[list[list[float]]] = []
         input_labels: list[list[int]] = []
-        input_boxes: list[list[float]] = []
 
         for i in inputs:
             box: list[float] | None = None
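
For callers, the net effect is that prompts now travel through a single `inputs` list instead of parallel `bounding_boxes` and `point_lists` arguments. A minimal calling sketch follows; the `SAMInput` field names (`bounding_box`, `points`) and the `pipeline` instance are assumptions for illustration, since the diff only shows that each input can carry a box and/or a point list:

    from PIL import Image

    image = Image.open("photo.png").convert("RGB")

    # Hypothetical field names -- the diff only guarantees that a SAMInput
    # bundles an optional bounding box with an optional point list.
    # Per the old docstring: box format is [xmin, ymin, xmax, ymax]; point
    # format is [x, y, label], where label is -1 (background), 0 (neutral),
    # or 1 (foreground).
    inputs = [
        SAMInput(bounding_box=[10, 20, 200, 240], points=None),
        SAMInput(bounding_box=None, points=[[50, 60, 1], [80, 90, -1]]),
    ]

    masks = pipeline.segment(image=image, inputs=inputs)
    # masks: torch.bool tensor, shape [num_masks, channels, height, width]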

invokeai/backend/image_util/segment_anything/segment_anything_pipeline.py

Lines changed: 10 additions & 12 deletions

@@ -33,17 +33,11 @@ def segment(
         image: Image.Image,
         inputs: list[SAMInput],
     ) -> torch.Tensor:
-        """Run the SAM model.
-
-        Either bounding_boxes or point_lists must be provided. If both are provided, bounding_boxes will be used and
-        point_lists will be ignored.
+        """Segment the image using the provided inputs.
 
         Args:
-            image (Image.Image): The image to segment.
-            bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
-                [xmin, ymin, xmax, ymax].
-            point_lists (list[list[list[int]]]): The points prompts. Each point is in the format [x, y, label].
-                `label` is an integer where -1 is background, 0 is neutral, and 1 is foreground.
+            image: The image to segment.
+            inputs: A list of SAMInput objects containing bounding boxes and/or point lists.
 
         Returns:
             torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
@@ -80,11 +74,15 @@ def segment(
             if labels is not None:
                 input_labels.append(labels)
 
+        batched_input_boxes = [input_boxes] if input_boxes else None
+        batched_input_points = input_points if input_points else None
+        batched_input_labels = input_labels if input_labels else None
+
         processed_inputs = self._sam_processor(
             images=image,
-            input_boxes=[input_boxes] if input_boxes else None,
-            input_points=input_points if input_points else None,
-            input_labels=input_labels if input_labels else None,
+            input_boxes=batched_input_boxes,
+            input_points=batched_input_points,
+            input_labels=batched_input_labels,
             return_tensors="pt",
         ).to(self._sam_model.device)
         outputs = self._sam_model(**processed_inputs)
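
The second hunk is mechanical: the inline conditionals feeding the processor are hoisted into named temporaries, with no behavioral change. Two details are worth noting, sketched standalone below with illustrative values (not the pipeline's actual data): the empty-list-to-None normalization matters because the Hugging Face processor treats None as "no prompts of this kind", and the extra list wrapped around `input_boxes` appears to add the batch dimension matching the single image passed via `images=image`.

    # Standalone illustration of the hoisted normalization.
    input_boxes = [[10.0, 20.0, 200.0, 240.0]]     # [num_boxes, 4]; may be empty
    input_points = [[[50.0, 60.0], [80.0, 90.0]]]  # per-object [x, y] lists
    input_labels = [[1, -1]]                       # per-object point labels

    # Empty prompt lists become None so the processor skips that prompt type;
    # [input_boxes] wraps the boxes in a batch dimension for the single image.
    batched_input_boxes = [input_boxes] if input_boxes else None  # [1, num_boxes, 4]
    batched_input_points = input_points if input_points else None
    batched_input_labels = input_labels if input_labels else None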
