Skip to content

Commit c9ffbf1

Browse files
committed
--author=Kanzhi Cheng <827023266@qq.com>
support qwen2.5VL
1 parent 71510f3 commit c9ffbf1

File tree

11 files changed

+502
-55
lines changed

11 files changed

+502
-55
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,15 +140,15 @@ import torch
140140

141141
from qwen_vl_utils import process_vision_info
142142
from datasets import load_dataset
143-
from transformers import Qwen2VLProcessor
143+
from transformers import AutoProcessor
144144
from gui_actor.constants import chat_template
145145
from gui_actor.modeling import Qwen2VLForConditionalGenerationWithPointer
146146
from gui_actor.inference import inference
147147

148148

149149
# load model
150150
model_name_or_path = "microsoft/GUI-Actor-7B-Qwen2-VL"
151-
data_processor = Qwen2VLProcessor.from_pretrained(model_name_or_path)
151+
data_processor = AutoProcessor.from_pretrained(model_name_or_path)
152152
tokenizer = data_processor.tokenizer
153153
model = Qwen2VLForConditionalGenerationWithPointer.from_pretrained(
154154
model_name_or_path,

eval/screenSpot.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55

66
from tqdm import tqdm
77
from datasets import load_dataset
8-
from transformers import Qwen2VLProcessor
8+
from transformers import AutoProcessor
99

1010
from gui_actor.constants import chat_template
1111
from gui_actor.modeling import Qwen2VLForConditionalGenerationWithPointer
12+
from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer
1213
from gui_actor.inference import inference, ForceFollowTokensLogitsProcessor
1314
from gui_actor.utils import do_boxes_overlap
14-
from gui_actor.constants import DEFAULT_POINTER_PAD_TOKEN, DEFAULT_POINTER_END_TOKEN, grounding_system_message
15+
from gui_actor.constants import DEFAULT_POINTER_PAD_TOKEN, DEFAULT_POINTER_END_TOKEN
1516

1617
IMAGE_PATCH_SIZE =14
1718

@@ -27,19 +28,33 @@ def normalize_bbox(bbox_x1y1x2y2, img_width, img_height):
2728
y2 = y2 / img_height
2829
return x1, y1, x2, y2
2930

30-
def evaluate(model_name_or_path, use_placeholder, topk):
31+
def evaluate(model_name_or_path, model_type, use_placeholder, topk):
3132
# initialize model
32-
data_processor = Qwen2VLProcessor.from_pretrained(model_name_or_path)
33+
data_processor = AutoProcessor.from_pretrained(model_name_or_path)
3334
tokenizer = data_processor.tokenizer
3435
for k, v in tokenizer.added_tokens_encoder.items():
3536
print(v, k)
3637

37-
model = Qwen2VLForConditionalGenerationWithPointer.from_pretrained(
38-
model_name_or_path,
39-
torch_dtype=torch.bfloat16,
40-
device_map="cuda:0",
41-
attn_implementation="flash_attention_2"
42-
).eval()
38+
if model_type == "qwen2vl":
39+
print(f"Loading model with Qwen2-VL backbone from {model_name_or_path}")
40+
model = Qwen2VLForConditionalGenerationWithPointer.from_pretrained(
41+
model_name_or_path,
42+
torch_dtype=torch.bfloat16,
43+
device_map="cuda:0",
44+
attn_implementation="flash_attention_2"
45+
).eval()
46+
grounding_system_message = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task."
47+
elif model_type == "qwen25vl":
48+
print(f"Loading model with Qwen2.5-VL backbone from {model_name_or_path}")
49+
model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
50+
model_name_or_path,
51+
torch_dtype=torch.bfloat16,
52+
device_map="cuda:0",
53+
attn_implementation="flash_attention_2"
54+
).eval()
55+
grounding_system_message = "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click(<your_special_token_here>)."
56+
else:
57+
raise ValueError(f"Invalid model type: {model_type}")
4358
print(f"Loaded model from {model_name_or_path}")
4459

4560
logits_processor_pointer = ForceFollowTokensLogitsProcessor(
@@ -248,7 +263,8 @@ def format_cell(cell):
248263
"""
249264
if __name__ == "__main__":
250265
parser = argparse.ArgumentParser()
251-
parser.add_argument("--model_name_or_path", type=str, default="microsoft/GUI-Actor-2B-Qwen2-VL")
266+
parser.add_argument("--model_type", type=str, default="qwen25vl", choices=["qwen2vl", "qwen25vl"])
267+
parser.add_argument("--model_name_or_path", type=str, default="qianhuiwu/GUI-Actor-3B-Qwen-2.5-VL")
252268
parser.add_argument("--save_path", type=str, default="./")
253269
parser.add_argument('--topk', type=int, default=3, help='Topk')
254270
parser.add_argument('--no-placeholder', dest='use_placeholder', action='store_false', help='Disable the placeholder')
@@ -271,7 +287,7 @@ def format_cell(cell):
271287
results = json.load(f)
272288
else:
273289
print(f"Evaluating {args.model_name_or_path}...")
274-
results = evaluate(args.model_name_or_path, args.use_placeholder, args.topk)
290+
results = evaluate(args.model_name_or_path, args.model_type, args.use_placeholder, args.topk)
275291
with open(pred_path, "w") as f:
276292
json.dump(results, f)
277293
print(f"Saved {len(results)} predictions to {pred_path}")

eval/screenSpot_pro.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55

66
from tqdm import tqdm
77
from datasets import load_dataset
8-
from transformers import Qwen2VLProcessor
8+
from transformers import AutoProcessor
99
from PIL import Image
1010
from gui_actor.constants import chat_template
1111
from gui_actor.modeling import Qwen2VLForConditionalGenerationWithPointer
12+
from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer
1213
from gui_actor.inference import inference, ForceFollowTokensLogitsProcessor
1314
from gui_actor.utils import do_boxes_overlap
14-
from gui_actor.constants import DEFAULT_POINTER_PAD_TOKEN, DEFAULT_POINTER_END_TOKEN, grounding_system_message
15+
from gui_actor.constants import DEFAULT_POINTER_PAD_TOKEN, DEFAULT_POINTER_END_TOKEN
1516

1617
IMAGE_PATCH_SIZE =14
1718

@@ -27,19 +28,33 @@ def normalize_bbox(bbox_x1y1x2y2, img_width, img_height):
2728
y2 = y2 / img_height
2829
return x1, y1, x2, y2
2930

30-
def evaluate(model_name_or_path, data_fn, image_dir, use_placeholder, topk, resize_to_pixels=None):
31+
def evaluate(model_name_or_path, model_type, data_fn, image_dir, use_placeholder, topk, resize_to_pixels=None):
3132
# initialize model
32-
data_processor = Qwen2VLProcessor.from_pretrained(model_name_or_path)
33+
data_processor = AutoProcessor.from_pretrained(model_name_or_path)
3334
tokenizer = data_processor.tokenizer
3435
for k, v in tokenizer.added_tokens_encoder.items():
3536
print(v, k)
3637

37-
model = Qwen2VLForConditionalGenerationWithPointer.from_pretrained(
38-
model_name_or_path,
39-
torch_dtype=torch.bfloat16,
40-
device_map="cuda:0",
41-
attn_implementation="flash_attention_2"
42-
).eval()
38+
if model_type == "qwen2vl":
39+
print(f"Loading model with Qwen2-VL backbone from {model_name_or_path}")
40+
model = Qwen2VLForConditionalGenerationWithPointer.from_pretrained(
41+
model_name_or_path,
42+
torch_dtype=torch.bfloat16,
43+
device_map="cuda:0",
44+
attn_implementation="flash_attention_2"
45+
).eval()
46+
grounding_system_message = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task."
47+
elif model_type == "qwen25vl":
48+
print(f"Loading model with Qwen2.5-VL backbone from {model_name_or_path}")
49+
model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
50+
model_name_or_path,
51+
torch_dtype=torch.bfloat16,
52+
device_map="cuda:0",
53+
attn_implementation="flash_attention_2"
54+
).eval()
55+
grounding_system_message = "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click(<your_special_token_here>)."
56+
else:
57+
raise ValueError(f"Invalid model type: {model_type}")
4358
print(f"Loaded model from {model_name_or_path}")
4459

4560
logits_processor_pointer = ForceFollowTokensLogitsProcessor(
@@ -137,6 +152,8 @@ def evaluate(model_name_or_path, data_fn, image_dir, use_placeholder, topk, resi
137152
results.append(ele)
138153

139154
return results
155+
156+
140157
def get_metric(list_of_examples,
141158
groups=["Dev", "Creative", "CAD", "Scientific", "Office", "OS"],
142159
ui_types=["text", "icon"]):
@@ -247,13 +264,15 @@ def format_cell(cell):
247264
print(metric_info)
248265
return metric_info
249266

267+
250268
"""
251269
# cd to project root directory
252270
python eval/screenSpot_pro.py --save_path <path_to_save_results> --data_path <path_to_data>
253271
"""
254272
if __name__ == "__main__":
255273
parser = argparse.ArgumentParser()
256-
parser.add_argument("--model_name_or_path", type=str, default="microsoft/GUI-Actor-2B-Qwen2-VL")
274+
parser.add_argument("--model_type", type=str, default="qwen25vl", choices=["qwen2vl", "qwen25vl"])
275+
parser.add_argument("--model_name_or_path", type=str, default="microsoft/GUI-Actor-7B-Qwen2.5-VL")
257276
parser.add_argument("--save_path", type=str, default="./")
258277
parser.add_argument("--data_path", type=str, default="/mnt/data/ScreenSpot-Pro")
259278
parser.add_argument("--resize_to_pixels", type=int, default=3200*1800, help="If set to <0, will not resize the image.")
@@ -281,7 +300,7 @@ def format_cell(cell):
281300
results = json.load(f)
282301
else:
283302
print(f"Evaluating {args.model_name_or_path}...")
284-
results = evaluate(args.model_name_or_path, data_fn, image_dir, args.use_placeholder, args.topk, resize_to_pixels)
303+
results = evaluate(args.model_name_or_path, args.model_type, data_fn, image_dir, args.use_placeholder, args.topk, resize_to_pixels)
285304
with open(pred_path, "w") as f:
286305
json.dump(results, f)
287306
print(f"Saved {len(results)} predictions to {pred_path}")

eval/screenSpot_v2.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55

66
from tqdm import tqdm
77
from datasets import load_dataset
8-
from transformers import Qwen2VLProcessor
8+
from transformers import AutoProcessor
99

1010
from gui_actor.constants import chat_template
1111
from gui_actor.modeling import Qwen2VLForConditionalGenerationWithPointer
12+
from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer
1213
from gui_actor.inference import inference, ForceFollowTokensLogitsProcessor
1314
from gui_actor.utils import do_boxes_overlap
14-
from gui_actor.constants import DEFAULT_POINTER_PAD_TOKEN, DEFAULT_POINTER_END_TOKEN, grounding_system_message
15+
from gui_actor.constants import DEFAULT_POINTER_PAD_TOKEN, DEFAULT_POINTER_END_TOKEN
1516

1617
IMAGE_PATCH_SIZE =14
1718

@@ -27,19 +28,33 @@ def normalize_bbox(bbox_x1y1x2y2, img_width, img_height):
2728
y2 = y2 / img_height
2829
return x1, y1, x2, y2
2930

30-
def evaluate(model_name_or_path, use_placeholder, topk):
31+
def evaluate(model_name_or_path, model_type, use_placeholder, topk):
3132
# initialize model
32-
data_processor = Qwen2VLProcessor.from_pretrained(model_name_or_path)
33+
data_processor = AutoProcessor.from_pretrained(model_name_or_path)
3334
tokenizer = data_processor.tokenizer
3435
for k, v in tokenizer.added_tokens_encoder.items():
3536
print(v, k)
3637

37-
model = Qwen2VLForConditionalGenerationWithPointer.from_pretrained(
38-
model_name_or_path,
39-
torch_dtype=torch.bfloat16,
40-
device_map="cuda:0",
41-
attn_implementation="flash_attention_2"
42-
).eval()
38+
if model_type == "qwen2vl":
39+
print(f"Loading model with Qwen2-VL backbone from {model_name_or_path}")
40+
model = Qwen2VLForConditionalGenerationWithPointer.from_pretrained(
41+
model_name_or_path,
42+
torch_dtype=torch.bfloat16,
43+
device_map="cuda:0",
44+
attn_implementation="flash_attention_2"
45+
).eval()
46+
grounding_system_message = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task."
47+
elif model_type == "qwen25vl":
48+
print(f"Loading model with Qwen2.5-VL backbone from {model_name_or_path}")
49+
model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
50+
model_name_or_path,
51+
torch_dtype=torch.bfloat16,
52+
device_map="cuda:0",
53+
attn_implementation="flash_attention_2"
54+
).eval()
55+
grounding_system_message = "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click(<your_special_token_here>)."
56+
else:
57+
raise ValueError(f"Invalid model type: {model_type}")
4358
print(f"Loaded model from {model_name_or_path}")
4459

4560
logits_processor_pointer = ForceFollowTokensLogitsProcessor(
@@ -248,6 +263,7 @@ def format_cell(cell):
248263
"""
249264
if __name__ == "__main__":
250265
parser = argparse.ArgumentParser()
266+
parser.add_argument("--model_type", type=str, default="qwen2vl", choices=["qwen2vl", "qwen25vl"])
251267
parser.add_argument("--model_name_or_path", type=str, default="microsoft/GUI-Actor-2B-Qwen2-VL")
252268
parser.add_argument("--save_path", type=str, default="./")
253269
parser.add_argument('--topk', type=int, default=3, help='Topk')
@@ -271,7 +287,7 @@ def format_cell(cell):
271287
results = json.load(f)
272288
else:
273289
print(f"Evaluating {args.model_name_or_path}...")
274-
results = evaluate(args.model_name_or_path, args.use_placeholder, args.topk)
290+
results = evaluate(args.model_name_or_path, args.model_type, args.use_placeholder, args.topk)
275291
with open(pred_path, "w") as f:
276292
json.dump(results, f)
277293
print(f"Saved {len(results)} predictions to {pred_path}")

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@ dependencies = [
1414
"accelerate==1.1.1",
1515
"qwen-vl-utils==0.0.8",
1616
"deepspeed==0.16.0",
17-
"transformers==4.50.0",
17+
"transformers==4.51.3",
1818
"flash-attn",
19-
"wandb==0.18.3"
19+
"wandb==0.18.3",
20+
"datasets>=2.18.0"
2021
]
2122
requires-python = ">=3.10,<3.13"
2223
readme = "README.md"

scripts/train.sh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
#!/bin/bash
2-
llm_model="./checkpoints/qwen2vl_warmup"
3-
output_dir="./checkpoints/qwen2vl_sft"
2+
# model_type: qwen2vl or qwen25vl
3+
model_type="qwen2vl"
4+
llm_model="./checkpoints/${model_type}_warmup"
5+
output_dir="./checkpoints/${model_type}_sft"
46

57
# === Training Command ===
68
torchrun --nproc_per_node=4 train.py \
79
--deepspeed ./scripts/zero3.json \
810
--data_path data/data_config.yaml \
911
--image_folder "" \
12+
--model_type ${model_type} \
1013
--model_name_or_path ${llm_model} \
1114
--group_by_modality_length True \
1215
--bf16 True \

scripts/warmup.sh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
#!/bin/bash
2-
llm_model="Qwen/Qwen2-VL-2B-Instruct"
3-
output_dir="./checkpoints/qwen2vl_warmup"
2+
# model_type: qwen2vl or qwen25vl
3+
model_type="qwen25vl"
4+
llm_model="Qwen/Qwen2.5-VL-3B-Instruct"
5+
output_dir="./checkpoints/${model_type}_warmup"
46

57
# === Training Command ===
68
torchrun --nproc_per_node=4 train.py \
79
--deepspeed ./scripts/zero3.json \
810
--data_path data/data_config.yaml \
911
--image_folder "" \
12+
--model_type ${model_type} \
1013
--model_name_or_path ${llm_model} \
1114
--group_by_modality_length True \
1215
--bf16 True \

src/gui_actor/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# UNMASK_TOKEN_IDS = [198, 151644, 151645]
1616

1717
# System Message
18-
grounding_system_message = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task."
18+
grounding_system_message = "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click(<your_special_token_here>)."
1919

2020
# Chat Template
2121
chat_template = "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"

src/gui_actor/modeling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def forward(self,
166166
if inputs_embeds is None:
167167
inputs_embeds = self.model.embed_tokens(input_ids) # shape: (batch_size, seq_len, d_model)
168168
if pixel_values is not None:
169-
pixel_values = pixel_values.type(self.visual.get_dtype())
169+
pixel_values = pixel_values.type(self.visual.dtype)
170170
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
171171
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
172172
n_image_features = image_embeds.shape[0]
@@ -184,7 +184,7 @@ def forward(self,
184184
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
185185

186186
if pixel_values_videos is not None:
187-
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
187+
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
188188
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
189189
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
190190
n_video_features = video_embeds.shape[0]

0 commit comments

Comments
 (0)