
Commit 7218cf8

Add DetailMaster ops. (#795)
* upload DetailMaster ops
* update Operators.md
* update Operators.md
* update according to gemini's comments
* Fix bugs.
* Fix a typo.
1 parent 7870825 commit 7218cf8

11 files changed (+821, −7 lines)

README.md

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ Besides, our paper is also updated to [v3](https://arxiv.org/abs/2309.02033).
 through the [sandbox laboratory](docs/Sandbox.md), and providing features such as feedback loops and visualization, so that you can better understand and improve your data and models. Many effect-proven datasets and models have been derived from DJ, in scenarios such as pre-training, text-to-video and image-to-text generation.
 ![Data-in-the-loop](https://img.alicdn.com/imgextra/i2/O1CN017U7Zz31Y7XtCJ5GOz_!!6000000003012-0-tps-3640-1567.jpg)
 
-## Doucmentation
+## Documentation
 
 - Tutorial
   - [DJ-Cookbook](docs/tutorial/DJ-Cookbook.md)

data_juicer/ops/mapper/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -12,6 +12,9 @@
 from .clean_html_mapper import CleanHtmlMapper
 from .clean_ip_mapper import CleanIpMapper
 from .clean_links_mapper import CleanLinksMapper
+from .detect_character_attributes_mapper import DetectCharacterAttributesMapper
+from .detect_character_locations_mapper import DetectCharacterLocationsMapper
+from .detect_main_character_mapper import DetectMainCharacterMapper
 from .dialog_intent_detection_mapper import DialogIntentDetectionMapper
 from .dialog_sentiment_detection_mapper import DialogSentimentDetectionMapper
 from .dialog_sentiment_intensity_mapper import DialogSentimentIntensityMapper
@@ -101,6 +104,9 @@
     "CleanHtmlMapper",
     "CleanIpMapper",
     "CleanLinksMapper",
+    "DetectCharacterAttributesMapper",
+    "DetectCharacterLocationsMapper",
+    "DetectMainCharacterMapper",
     "DialogIntentDetectionMapper",
     "DialogSentimentDetectionMapper",
     "DialogSentimentIntensityMapper",
data_juicer/ops/mapper/detect_character_attributes_mapper.py

Lines changed: 303 additions & 0 deletions
@@ -0,0 +1,303 @@
import json
import os
import random
from typing import Dict, Optional

from PIL import Image

import data_juicer
from data_juicer.ops.load import load_ops
from data_juicer.utils.cache_utils import DATA_JUICER_ASSETS_CACHE
from data_juicer.utils.constant import Fields

from ..base_op import OPERATORS, TAGGING_OPS, UNFORKABLE, Mapper
from ..op_fusion import LOADED_IMAGES

OP_NAME = "detect_character_attributes_mapper"


@UNFORKABLE.register_module(OP_NAME)
@TAGGING_OPS.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class DetectCharacterAttributesMapper(Mapper):
    """Takes an image, a caption, and main character names as input to extract the characters' attributes."""

    _accelerator = "cuda"

    def __init__(
        self,
        detect_character_locations_mapper_args: Optional[Dict] = {},
        *args,
        **kwargs,
    ):
"""
35+
Initialization method.
36+
37+
:param detect_character_locations_mapper_args: Arguments for detect_character_locations_mapper_args.
38+
Controls the threshold for locating the main character.
39+
Default empty dict will use fixed values: default mllm_mapper_args,
40+
default image_text_matching_filter_args, yoloe_path="yoloe-11l-seg.pt",
41+
iou_threshold=0.7, matching_score_threshold=0.4,
42+
43+
"""
        super().__init__(*args, **kwargs)

        self.FIXED_ARGS = {}
        self.FIXED_ARGS["detect_character_locations_mapper"] = {
            "mllm_mapper_args": {
                "max_new_tokens": 256,
                "temperature": 0.2,
                "top_p": None,
                "num_beams": 1,
                "hf_model": "llava-hf/llava-v1.6-vicuna-7b-hf",
            },
            "image_text_matching_filter_args": {
                "min_score": 0,
                "max_score": 1.0,
                "hf_blip": "Salesforce/blip-itm-base-coco",
                "num_proc": 1,
            },
            "yoloe_path": "yoloe-11l-seg.pt",
            "iou_threshold": 0.7,
            "matching_score_threshold": 0.4,
        }

        self.detect_character_locations_mapper_args = self._prepare_op_args(
            "detect_character_locations_mapper", detect_character_locations_mapper_args
        )

        self.fused_op_list = [{"detect_character_locations_mapper": self.detect_character_locations_mapper_args}]
        self.fused_ops = load_ops(self.fused_op_list)

        accelerator_methods = set([op.accelerator for op in self.fused_ops])
        if "cuda" in accelerator_methods:
            self.accelerator = "cuda"

        # update num_proc with the min num_proc of all fusible filters
        self.num_proc = min([op.runtime_np() for op in self.fused_ops]) if self.fused_ops else 1

    def _prepare_op_args(self, op_name, args_dict):
        for key in self.FIXED_ARGS[op_name]:
            if key not in args_dict:
                args_dict[key] = self.FIXED_ARGS[op_name][key]
        args_dict["accelerator"] = self.accelerator
        return args_dict

    def process_single(self, samples, rank=None):
        if Fields.meta not in samples:
            samples[Fields.meta] = {}

        detect_location_dataset = data_juicer.core.NestedDataset.from_list(
            [{"main_character_list": samples["main_character_list"], "images": samples["images"]}]
        )

        character_locations = detect_location_dataset.map(
            self.fused_ops[0].process, num_proc=1, with_rank=True
        ).to_list()
        character_locations = character_locations[0][Fields.meta]["main_character_locations_list"]

        character_to_characteristics = {}
        character_to_cls = {}

        for temp_character in samples["main_character_list"]:
            # detect class
            prompt = (
                'Please classify the character "'
                + temp_character
                + "\" into the following categories: ['object', 'animal', 'person', 'text', 'other']. Only reply with the most fitting single category."
            )
            mllm_sample = {"text": prompt, "images": samples["images"]}
            output_text = self.fused_ops[0].fused_ops[0].process(mllm_sample)["text"][0].split("ASSISTANT:")[-1].strip()
            character_to_cls[temp_character] = output_text

            # detect feature
            prompt = (
                'I will provide you with the corresponding description of an image, as follows: "'
                + samples["text"]
                + "\" Please extract all descriptions of the features related to '"
                + temp_character
                + '\' from this text, which may include color, material, action, and other typical features, and compile them into a list of phrase string. Formatted like: ["in a blue shirt", "sitting on a nearby fence", "with flame decals"]. Return only the phrase string list.'
            )
            mllm_sample = {"text": prompt, "images": samples["images"]}
            output_text = self.fused_ops[0].fused_ops[0].process(mllm_sample)["text"][0].split("ASSISTANT:")[-1].strip()
            try:
                character_to_characteristics[temp_character] = json.loads(output_text)
            except json.JSONDecodeError:
                character_to_characteristics[temp_character] = [output_text]

        image = Image.open(samples["images"][0])
        valid_character_in_bbox_dict = {}
        for temp_character_with_bbox_idx, temp_character_with_bbox in enumerate(character_locations):
            crop_img = image.crop(temp_character_with_bbox["bbox"])

            cache_img_name = (
                "temp_"
                + str(random.randint(0, 9999))
                + "_"
                + str(temp_character_with_bbox_idx)
                + samples["images"][0].split("/")[-1]
            )
            cache_img_path = os.path.join(
                DATA_JUICER_ASSETS_CACHE,
                cache_img_name,
            )
            crop_img.save(cache_img_path)

            try:
                temp_character_cls = character_to_cls[temp_character_with_bbox["main_character"]]
            except Exception:
                os.remove(cache_img_path)
                continue

            if "object" in temp_character_cls:
                prompt = (
                    "Please analyze the key characteristics of the main object in this image, specifically the '"
                    + temp_character_with_bbox["main_character"]
                    + "', which may include color, material, shape, and other typical features. Currently identified characteristics include \""
                    + str(temp_character_cls)
                    + '". Please expand this list and respond in an identically formatted phrase string list.'
                )
                mllm_sample = {"text": prompt, "images": [cache_img_path]}
                output_text = (
                    self.fused_ops[0].fused_ops[0].process(mllm_sample)["text"][0].split("ASSISTANT:")[-1].strip()
                )

            elif "animal" in temp_character_cls:
                prompt = (
                    "Please analyze the key characteristics of the primary animal in this image, specifically the '"
                    + temp_character_with_bbox["main_character"]
                    + "', which may include color, action, and other typical features. Currently identified characteristics include \""
                    + str(temp_character_cls)
                    + '". Please expand this list and respond in an identically formatted phrase string list.'
                )
                mllm_sample = {"text": prompt, "images": [cache_img_path]}
                output_text = (
                    self.fused_ops[0].fused_ops[0].process(mllm_sample)["text"][0].split("ASSISTANT:")[-1].strip()
                )

            elif "person" in temp_character_cls:
                prompt = (
                    "Please analyze the key characteristics of the primary person in this image, specifically the '"
                    + temp_character_with_bbox["main_character"]
                    + "', which may include clothing, ages, and other typical features. Currently identified characteristics include \""
                    + str(temp_character_cls)
                    + '". Please expand this list and respond in an identically formatted phrase string list.'
                )
                mllm_sample = {"text": prompt, "images": [cache_img_path]}
                output_text = (
                    self.fused_ops[0].fused_ops[0].process(mllm_sample)["text"][0].split("ASSISTANT:")[-1].strip()
                )

            else:
                prompt = (
                    "Please analyze the key characteristics of the primary character in this image, specifically the '"
                    + temp_character_with_bbox["main_character"]
                    + "'. Currently identified characteristics include \""
                    + str(temp_character_cls)
                    + '". Please expand this list and respond in an identically formatted phrase string list.'
                )
                mllm_sample = {"text": prompt, "images": [cache_img_path]}
                output_text = (
                    self.fused_ops[0].fused_ops[0].process(mllm_sample)["text"][0].split("ASSISTANT:")[-1].strip()
                )

            final_characteristic_list = []
            # filter
            try:
                characteristic_list = json.loads(output_text)
            except json.JSONDecodeError:
                characteristic_list = output_text

            if isinstance(characteristic_list, list):
                if len(characteristic_list) == 1:
                    characteristic_list = characteristic_list[0].replace("_", " ").split(", ")

                try:
                    for temp_characteristic in characteristic_list:
                        prompt = (
                            'Please analyze the main character in this image, specifically the "'
                            + temp_character_with_bbox["main_character"]
                            + '". Is "'
                            + temp_characteristic
                            + "\" one of its features? Only respond with 'yes' if it is a perfect match. Please only respond with 'yes' or 'no'."
                        )
                        mllm_sample = {"text": prompt, "images": [cache_img_path]}
                        output_text = (
                            self.fused_ops[0]
                            .fused_ops[0]
                            .process(mllm_sample)["text"][0]
                            .split("ASSISTANT:")[-1]
                            .strip()
                        )

                        if "yes" in output_text:
                            final_characteristic_list.append(temp_characteristic)
                except Exception:
                    os.remove(cache_img_path)
                    continue
            else:
                try:
                    characteristic_list = output_text.split("\n")
                    if len(characteristic_list) == 1:
                        characteristic_list = characteristic_list[0].replace("_", " ").split(", ")

                    for temp_characteristic in characteristic_list:
                        prompt = (
                            'Please analyze the main character in this image, specifically the "'
                            + temp_character_with_bbox["main_character"]
                            + '". Is "'
                            + temp_characteristic
                            + "\" one of its features? Only respond with 'yes' if it is a perfect match. Please only respond with 'yes' or 'no'."
                        )
                        mllm_sample = {"text": prompt, "images": [cache_img_path]}
                        output_text = (
                            self.fused_ops[0]
                            .fused_ops[0]
                            .process(mllm_sample)["text"][0]
                            .split("ASSISTANT:")[-1]
                            .strip()
                        )

                        if "yes" in output_text:
                            final_characteristic_list.append(temp_characteristic)
                except Exception:
                    os.remove(cache_img_path)
                    continue

            valid_character_in_bbox_dict[temp_character_with_bbox["main_character"]] = {}
            valid_character_in_bbox_dict[temp_character_with_bbox["main_character"]]["bbox"] = temp_character_with_bbox[
                "bbox"
            ]
            valid_character_in_bbox_dict[temp_character_with_bbox["main_character"]][
                "final_characteristic_list"
            ] = final_characteristic_list

            os.remove(cache_img_path)

        new_character_list = []
        for temp_character in samples["main_character_list"]:
            temp_character_json = {}
            temp_character_json["main_character"] = temp_character
            if temp_character in valid_character_in_bbox_dict:
                temp_character_json["bbox"] = valid_character_in_bbox_dict[temp_character]["bbox"]

                if len(valid_character_in_bbox_dict[temp_character]["final_characteristic_list"]) == 0:
                    temp_character_json["characteristic_list"] = character_to_characteristics[temp_character]
                else:
                    temp_character_json["characteristic_list"] = valid_character_in_bbox_dict[temp_character][
                        "final_characteristic_list"
                    ]

            else:
                temp_character_json["bbox"] = []
                temp_character_json["characteristic_list"] = character_to_characteristics[temp_character]

            new_character_list.append(temp_character_json)

        samples[Fields.meta]["main_character_attributes_list"] = new_character_list

        return samples
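For reference, a hedged usage sketch of the op defined above. The sample fields ("images", "text", "main_character_list") and the output key under Fields.meta mirror process_single(); the image path, caption, and character names are hypothetical placeholders, and actually running this requires the LLaVA, BLIP, and YOLOE checkpoints named in the default arguments.

# Hedged usage sketch for DetectCharacterAttributesMapper. The sample
# fields mirror process_single() above; the image path, caption, and
# character names are hypothetical placeholders.
from data_juicer.ops.mapper import DetectCharacterAttributesMapper
from data_juicer.utils.constant import Fields

mapper = DetectCharacterAttributesMapper(
    # optional override of the fused locator's defaults (see FIXED_ARGS)
    detect_character_locations_mapper_args={"matching_score_threshold": 0.5},
)

sample = {
    "images": ["/path/to/image.jpg"],  # hypothetical image path
    "text": "A boy in a blue shirt sits on a fence next to a brown dog.",
    "main_character_list": ["boy", "dog"],
}
sample = mapper.process_single(sample)

# Each entry carries "main_character", "bbox", and "characteristic_list".
print(sample[Fields.meta]["main_character_attributes_list"])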
