@@ -33,17 +33,11 @@ def segment(
        image: Image.Image,
        inputs: list[SAMInput],
    ) -> torch.Tensor:
-        """Run the SAM model.
-
-        Either bounding_boxes or point_lists must be provided. If both are provided, bounding_boxes will be used and
-        point_lists will be ignored.
+        """Segment the image using the provided inputs.
 
        Args:
-            image (Image.Image): The image to segment.
-            bounding_boxes (list[list[int]]): The bounding box prompts. Each bounding box is in the format
-                [xmin, ymin, xmax, ymax].
-            point_lists (list[list[list[int]]]): The points prompts. Each point is in the format [x, y, label].
-                `label` is an integer where -1 is background, 0 is neutral, and 1 is foreground.
+            image: The image to segment.
+            inputs: A list of SAMInput objects containing bounding boxes and/or point lists.
 
        Returns:
            torch.Tensor: The segmentation masks. dtype: torch.bool. shape: [num_masks, channels, height, width].
@@ -80,11 +74,15 @@ def segment(
            if labels is not None:
                input_labels.append(labels)
 
+        batched_input_boxes = [input_boxes] if input_boxes else None
+        batched_input_points = input_points if input_points else None
+        batched_input_labels = input_labels if input_labels else None
+
        processed_inputs = self._sam_processor(
            images=image,
-            input_boxes=[input_boxes] if input_boxes else None,
-            input_points=input_points if input_points else None,
-            input_labels=input_labels if input_labels else None,
+            input_boxes=batched_input_boxes,
+            input_points=batched_input_points,
+            input_labels=batched_input_labels,
            return_tensors="pt",
        ).to(self._sam_model.device)
        outputs = self._sam_model(**processed_inputs)
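For context, a minimal usage sketch of the refactored interface is below. The `SAMInput` field names (`bounding_box`, `point_list`) and the `segmenter` instance are assumptions for illustration; the diff only shows that `segment` now takes a `list[SAMInput]` in place of the separate `bounding_boxes` and `point_lists` arguments.

```python
# Usage sketch, not the repo's actual API. Assumed: SAMInput exposes
# bounding_box and point_list fields, and `segmenter` is an instance of the
# class that defines segment().
from PIL import Image

image = Image.open("example.jpg")  # any RGB image

inputs = [
    # One prompt per desired mask. Box format per the old docstring: [xmin, ymin, xmax, ymax].
    SAMInput(bounding_box=[10, 20, 200, 220]),
    # Point format per the old docstring: [x, y, label], where label is
    # -1 = background, 0 = neutral, 1 = foreground.
    SAMInput(point_list=[[50, 60, 1], [75, 80, -1]]),
]

masks = segmenter.segment(image, inputs)
# masks: torch.bool tensor of shape [num_masks, channels, height, width]
```

Note that the refactor also names the processor arguments: the extra list around the boxes (`[input_boxes]`) adds the per-image batch dimension that Transformers' `SamProcessor` expects, and the batch here always contains exactly one image.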