
Commit 4d6a868

Upload video_object_segmenting_mapper and video_depth_estimation_mapper

1 parent c07c2d7

File tree: 8 files changed (+525, -1 lines)

data_juicer/ops/mapper/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -79,9 +79,11 @@
 from .video_captioning_from_frames_mapper import VideoCaptioningFromFramesMapper
 from .video_captioning_from_summarizer_mapper import VideoCaptioningFromSummarizerMapper
 from .video_captioning_from_video_mapper import VideoCaptioningFromVideoMapper
+from .video_depth_estimation_mapper import VideoDepthEstimationMapper
 from .video_extract_frames_mapper import VideoExtractFramesMapper
 from .video_face_blur_mapper import VideoFaceBlurMapper
 from .video_ffmpeg_wrapped_mapper import VideoFFmpegWrappedMapper
+from .video_object_segmenting_mapper import VideoObjectSegmentingMapper
 from .video_remove_watermark_mapper import VideoRemoveWatermarkMapper
 from .video_resize_aspect_ratio_mapper import VideoResizeAspectRatioMapper
 from .video_resize_resolution_mapper import VideoResizeResolutionMapper
@@ -168,9 +170,11 @@
     "VideoCaptioningFromFramesMapper",
     "VideoCaptioningFromSummarizerMapper",
     "VideoCaptioningFromVideoMapper",
+    "VideoDepthEstimationMapper",
     "VideoExtractFramesMapper",
     "VideoFFmpegWrappedMapper",
     "VideoFaceBlurMapper",
+    "VideoObjectSegmentingMapper",
     "VideoRemoveWatermarkMapper",
     "VideoResizeAspectRatioMapper",
     "VideoResizeResolutionMapper",
data_juicer/ops/mapper/video_depth_estimation_mapper.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
import os
import sys

import numpy as np

from data_juicer.utils.cache_utils import DATA_JUICER_ASSETS_CACHE
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, TAGGING_OPS, UNFORKABLE, Mapper
from ..op_fusion import LOADED_VIDEOS

OP_NAME = "video_depth_estimation_mapper"

cv2 = LazyLoader("cv2", "opencv-python")
torch = LazyLoader("torch")
# Video-Depth-Anything is not published as an installable package, so it is
# cloned into the assets cache at init time instead of being lazy-loaded:
# https://github.com/DepthAnything/Video-Depth-Anything
open3d = LazyLoader("open3d")


@TAGGING_OPS.register_module(OP_NAME)
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
@LOADED_VIDEOS.register_module(OP_NAME)
class VideoDepthEstimationMapper(Mapper):
    """Perform per-frame depth estimation on videos with Video-Depth-Anything."""

    _accelerator = "cuda"

    def __init__(
        self,
        video_depth_model_path: str = "video_depth_anything_vitb.pth",
        point_cloud_dir_for_metric: str = DATA_JUICER_ASSETS_CACHE,
        max_res: int = 1280,
        torch_dtype: str = "fp16",
        if_save_visualization: bool = False,
        save_visualization_dir: str = DATA_JUICER_ASSETS_CACHE,
        grayscale: bool = False,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param video_depth_model_path: The path to the Video-Depth-Anything
            model. If the model is a 'metric' model, the op automatically
            switches to metric mode and writes per-frame point clouds to
            'point_cloud_dir_for_metric'.
        :param point_cloud_dir_for_metric: The directory for storing point
            clouds (only used with a 'metric' model).
        :param max_res: The maximum resolution for input videos; videos
            exceeding this threshold are resized.
        :param torch_dtype: The floating-point type used for model inference.
            One of ['fp32', 'fp16'].
        :param if_save_visualization: Whether to save visualization results.
        :param save_visualization_dir: The directory for saving visualization
            results.
        :param grayscale: If True, the colorful palette is not applied to the
            depth visualization.
        """
        super().__init__(*args, **kwargs)

        # Clone the Video-Depth-Anything repo into the assets cache on first
        # use, then import its video I/O helpers.
        video_depth_anything_repo_path = os.path.join(DATA_JUICER_ASSETS_CACHE, "Video-Depth-Anything")
        if not os.path.exists(video_depth_anything_repo_path):
            os.system(
                f"git clone https://github.com/DepthAnything/Video-Depth-Anything.git {video_depth_anything_repo_path}"
            )
        sys.path.append(video_depth_anything_repo_path)
        from utils.dc_utils import read_video_frames, save_video

        # 'metric' checkpoints predict metric (absolute) depth.
        self.metric = "metric" in video_depth_model_path

        self.read_video_frames = read_video_frames
        self.save_video = save_video

        self.tag_field_name = MetaKeys.video_depth_tags
        self.max_res = max_res
        self.torch_dtype = torch_dtype
        self.point_cloud_dir_for_metric = point_cloud_dir_for_metric
        self.if_save_visualization = if_save_visualization
        self.save_visualization_dir = save_visualization_dir
        self.grayscale = grayscale
        self.model_key = prepare_model(model_type="video_depth_anything", model_path=video_depth_model_path)
    def process_single(self, sample=None, rank=None):
        # check if it's generated already
        if self.tag_field_name in sample[Fields.meta]:
            return sample

        # there is no video in this sample
        if self.video_key not in sample or not sample[self.video_key]:
            sample[Fields.meta][self.tag_field_name] = {"depth_data": [], "fps": -1}
            return sample

        video_depth_anything_model = get_model(model_key=self.model_key, rank=rank, use_cuda=self.use_cuda())

        frames, target_fps = self.read_video_frames(sample[self.video_key][0], -1, -1, self.max_res)
        depths, fps = video_depth_anything_model.infer_video_depth(
            frames,
            target_fps,
            input_size=518,
            device="cuda" if self.use_cuda() else "cpu",
            fp32=self.torch_dtype != "fp16",
        )

        if self.if_save_visualization:
            video_name = os.path.basename(sample[self.video_key][0])
            os.makedirs(self.save_visualization_dir, exist_ok=True)
            processed_video_path = os.path.join(
                self.save_visualization_dir, os.path.splitext(video_name)[0] + "_src.mp4"
            )
            depth_vis_path = os.path.join(self.save_visualization_dir, os.path.splitext(video_name)[0] + "_vis.mp4")
            self.save_video(frames, processed_video_path, fps=fps)
            self.save_video(depths, depth_vis_path, fps=fps, is_depths=True, grayscale=self.grayscale)

        if self.metric:
            # Back-project each metric depth map into a colored point cloud
            # with a simple pinhole model: X = (u - W/2) * z / f,
            # Y = (v - H/2) * z / f, Z = z, with a fixed focal length of
            # f = 470.4 pixels.
            width, height = depths[0].shape[-1], depths[0].shape[-2]
            x, y = np.meshgrid(np.arange(width), np.arange(height))
            x = (x - width / 2) / 470.4
            y = (y - height / 2) / 470.4

            for i, (color_image, depth) in enumerate(zip(frames, depths)):
                z = np.array(depth)
                points = np.stack((np.multiply(x, z), np.multiply(y, z), z), axis=-1).reshape(-1, 3)
                colors = np.array(color_image).reshape(-1, 3) / 255.0

                pcd = open3d.geometry.PointCloud()
                pcd.points = open3d.utility.Vector3dVector(points)
                pcd.colors = open3d.utility.Vector3dVector(colors)
                open3d.io.write_point_cloud(
                    os.path.join(self.point_cloud_dir_for_metric, "point" + str(i).zfill(4) + ".ply"), pcd
                )

        sample[Fields.meta][self.tag_field_name] = {"depth_data": depths, "fps": fps}

        return sample
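
A minimal usage sketch for this op (hypothetical video path; assumes the base op's default video_key of "videos" and locally available Video-Depth-Anything weights):

from data_juicer.utils.constant import Fields, MetaKeys

# "./demo.mp4" is a placeholder path, not part of the commit.
op = VideoDepthEstimationMapper(
    video_depth_model_path="video_depth_anything_vitb.pth",
    torch_dtype="fp16",
)
sample = {"videos": ["./demo.mp4"], Fields.meta: {}}
sample = op.process_single(sample)

depth_tags = sample[Fields.meta][MetaKeys.video_depth_tags]
print(len(depth_tags["depth_data"]), depth_tags["fps"])

With a 'metric' checkpoint, the same call additionally writes point0000.ply, point0001.ply, ... into point_cloud_dir_for_metric.
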
data_juicer/ops/mapper/video_object_segmenting_mapper.py

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
import os
import random
from datetime import datetime

import numpy as np

from data_juicer.utils.cache_utils import DATA_JUICER_ASSETS_CACHE
from data_juicer.utils.constant import Fields, MetaKeys
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import check_model, get_model, prepare_model

from ..base_op import OPERATORS, TAGGING_OPS, UNFORKABLE, Mapper
from ..op_fusion import LOADED_VIDEOS

OP_NAME = "video_object_segmenting_mapper"

cv2 = LazyLoader("cv2", "opencv-python")
ultralytics = LazyLoader("ultralytics")
torch = LazyLoader("torch")
transformers = LazyLoader("transformers")


@TAGGING_OPS.register_module(OP_NAME)
@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
@LOADED_VIDEOS.register_module(OP_NAME)
class VideoObjectSegmentingMapper(Mapper):
    """Text-guided semantic segmentation of the prompted objects throughout
    a video (YOLOE + SAM2).

    The op reads its text prompts from the sample's 'main_character_list'
    field.
    """

    _accelerator = "cuda"

    def __init__(
        self,
        sam2_hf_model: str = "facebook/sam2.1-hiera-tiny",
        yoloe_path: str = "yoloe-11l-seg.pt",
        yoloe_conf: float = 0.5,
        torch_dtype: str = "bf16",
        if_binarize: bool = True,
        if_save_visualization: bool = False,
        save_visualization_dir: str = DATA_JUICER_ASSETS_CACHE,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param sam2_hf_model: Hugging Face model ID of SAM2.
        :param yoloe_path: The path to the YOLOE model.
        :param yoloe_conf: Confidence threshold for YOLOE object detection.
        :param torch_dtype: The floating-point type used for model inference.
            One of ['fp32', 'fp16', 'bf16'].
        :param if_binarize: Whether the final masks require binarization.
            If 'if_save_visualization' is True, 'if_binarize' is forced to
            True.
        :param if_save_visualization: Whether to save visualization results.
        :param save_visualization_dir: The directory for saving visualization
            results.
        """
        super().__init__(*args, **kwargs)

        # Requires the weights for YOLOE and mobileclip_blt.
        self.yoloe_model = ultralytics.YOLO(check_model(yoloe_path))
        torch_dtype_dict = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
        self.torch_dtype = torch_dtype_dict[torch_dtype]
        self.sam2_model_key = prepare_model(
            model_type="huggingface", torch_dtype=self.torch_dtype, pretrained_model_name_or_path=sam2_hf_model
        )

        self.tag_field_name = MetaKeys.video_object_segment_tags
        self.yoloe_conf = yoloe_conf
        self.if_save_visualization = if_save_visualization
        self.save_visualization_dir = save_visualization_dir
        # Visualization renders binarized masks, so force binarization on.
        self.if_binarize = True if if_save_visualization else if_binarize
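
    # Overview of process_single below:
    #   1. read the first frame and run YOLOE on it, prompted with the
    #      sample's "main_character_list" text labels, to get seed boxes;
    #   2. hand those boxes to a SAM2 video session on frame 0;
    #   3. propagate the masks through the whole video and store them, along
    #      with class names and confidences, in the sample's meta.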
    def process_single(self, sample=None, rank=None):
        # check if it's generated already
        if self.tag_field_name in sample[Fields.meta]:
            return sample

        empty_tags = {
            "segment_data": [],
            "cls_id_dict": [],
            "object_cls_list": [],
            "yoloe_conf_list": [],
        }

        # there is no video in this sample
        if self.video_key not in sample or not sample[self.video_key]:
            sample[Fields.meta][self.tag_field_name] = empty_tags
            return sample

        sam2_model, sam2_processor = get_model(model_key=self.sam2_model_key, rank=rank, use_cuda=self.use_cuda())

        # Perform semantic segmentation on the first frame using YOLOE
        video_capture = cv2.VideoCapture(sample[self.video_key][0])
        success, initial_frame = video_capture.read()
        video_capture.release()
        random_num_str = str(random.randint(10000, 99999))
        now_time_str = str(datetime.now())
        if success:
            os.makedirs(DATA_JUICER_ASSETS_CACHE, exist_ok=True)

            # Dump the first frame to a temp image so YOLOE can read it.
            temp_video_name = os.path.splitext(os.path.basename(sample[self.video_key][0]))[0]
            temp_initial_frame_path = os.path.join(
                DATA_JUICER_ASSETS_CACHE,
                f"{temp_video_name}_initial_frame_{now_time_str}_{random_num_str}.jpg",
            )
            cv2.imwrite(temp_initial_frame_path, initial_frame)
        else:
            # Failed to load the initial frame
            sample[Fields.meta][self.tag_field_name] = empty_tags
            return sample

        # Prompt YOLOE with the sample's main-character text labels.
        self.yoloe_model.set_classes(
            sample["main_character_list"], self.yoloe_model.get_text_pe(sample["main_character_list"])
        )
        results = self.yoloe_model.predict(temp_initial_frame_path, verbose=False, conf=self.yoloe_conf)
        yoloe_bboxes = results[0].boxes.xyxy.tolist()
        bboxes_cls = [int(x) for x in results[0].boxes.cls.tolist()]
        cls_id_dict = results[0].names
        yoloe_conf_list = results[0].boxes.conf.tolist()

        obj_ids = []
        object_cls_list = []
        input_boxes = []
        for temp_cls, temp_box in zip(bboxes_cls, yoloe_bboxes):
            obj_ids.append(len(obj_ids))
            object_cls_list.append(temp_cls)
            input_boxes.append([int(x) for x in temp_box])

        # SAM2 expects one box list per video in the batch.
        input_boxes = [input_boxes]
        os.remove(temp_initial_frame_path)

        if len(obj_ids) == 0:
            # YOLOE detected nothing above the confidence threshold.
            sample[Fields.meta][self.tag_field_name] = empty_tags
            return sample
        # Track objects with SAM2
        video_frames, _ = transformers.video_utils.load_video(sample[self.video_key][0])

        inference_session = sam2_processor.init_video_session(
            video=video_frames,
            inference_device="cuda" if self.use_cuda() else "cpu",
            dtype=self.torch_dtype,
        )

        ann_frame_idx = 0
        sam2_processor.add_inputs_to_inference_session(
            inference_session=inference_session,
            frame_idx=ann_frame_idx,
            obj_ids=obj_ids,
            input_boxes=input_boxes,
        )

        # Get masks for all objects on the first frame
        outputs = sam2_model(
            inference_session=inference_session,
            frame_idx=ann_frame_idx,
        )
        video_res_masks = sam2_processor.post_process_masks(
            [outputs.pred_masks],
            original_sizes=[[inference_session.video_height, inference_session.video_width]],
            binarize=False,
        )[0]

        # Propagate all objects through the video
        video_segments = []
        for sam2_video_output in sam2_model.propagate_in_video_iterator(inference_session):
            video_res_masks = sam2_processor.post_process_masks(
                [sam2_video_output.pred_masks],
                original_sizes=[[inference_session.video_height, inference_session.video_width]],
                binarize=self.if_binarize,
            )[0]
            video_segments.append([video_res_masks[i].tolist() for i, obj_id in enumerate(inference_session.obj_ids)])

        sample[Fields.meta][self.tag_field_name] = {
            "segment_data": video_segments,
            "cls_id_dict": [cls_id_dict[key] for key in cls_id_dict],
            "object_cls_list": object_cls_list,
            "yoloe_conf_list": yoloe_conf_list,
        }

        if self.if_save_visualization:
            os.makedirs(self.save_visualization_dir, exist_ok=True)

            # Write each object's binarized mask for every frame as a
            # white-on-black image.
            for temp_frame_masks_id, temp_frame_masks in enumerate(
                sample[Fields.meta][self.tag_field_name]["segment_data"]
            ):
                for temp_obj_id, temp_mask in enumerate(temp_frame_masks):
                    temp_img = np.zeros((initial_frame.shape[0], initial_frame.shape[1], 3), np.uint8)
                    temp_mask = np.squeeze(np.array(temp_mask))
                    temp_img[temp_mask] = [225, 225, 225]

                    temp_mask_path = os.path.join(
                        self.save_visualization_dir,
                        f"{temp_video_name}_mask_{temp_obj_id}_{temp_frame_masks_id}_{now_time_str}_{random_num_str}.jpg",
                    )
                    cv2.imwrite(temp_mask_path, temp_img)

        return sample
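
A corresponding usage sketch (hypothetical path and prompts; the caller must supply the "main_character_list" field that YOLOE is prompted with):

from data_juicer.utils.constant import Fields, MetaKeys

op = VideoObjectSegmentingMapper(
    sam2_hf_model="facebook/sam2.1-hiera-tiny",
    yoloe_path="yoloe-11l-seg.pt",
    yoloe_conf=0.5,
)
sample = {
    "videos": ["./demo.mp4"],                  # placeholder path
    "main_character_list": ["person", "dog"],  # YOLOE text prompts
    Fields.meta: {},
}
sample = op.process_single(sample)

seg_tags = sample[Fields.meta][MetaKeys.video_object_segment_tags]
print(len(seg_tags["segment_data"]), seg_tags["cls_id_dict"])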

data_juicer/utils/constant.py

Lines changed: 4 additions & 0 deletions
@@ -60,6 +60,10 @@ class MetaKeys(object):
     video_audio_tags = "video_audio_tags"
     # # video frames
     video_frames = "video_frames"
+    # # object segment info in video
+    video_object_segment_tags = "video_object_segment_tags"
+    # # depth info in video
+    video_depth_tags = "video_depth_tags"
     # # image tags
     image_tags = "image_tags"
     # # bounding box tag
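
For reference, the payload each new key carries after the mappers above have run (schema taken from the two process_single implementations):

# sample[Fields.meta][MetaKeys.video_depth_tags]:
#     {"depth_data": [...], "fps": <float>}   # [] / -1 when the sample has no video
# sample[Fields.meta][MetaKeys.video_object_segment_tags]:
#     {"segment_data": [...], "cls_id_dict": [...],
#      "object_cls_list": [...], "yoloe_conf_list": [...]}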
