Commit 7ac9e184 authored by 陈威志's avatar 陈威志

init

parents
## PG-Agent: An Agent Powered by Page Graph
[![Paper](http://img.shields.io/badge/Paper-arxiv.2509.03536-99D4C8.svg)](https://arxiv.org/abs/2509.03536)
This is the source code for **page graph construction** and **multi-agent workflow**.
### Data Preparation
------
The open-source datasets we use are from the following repositories:
- AITW & Mind2Web: [here](https://github.com/njucckevin/SeeClick/blob/main/agent_tasks/readme_agent.md)
- GUI Odyssey: [here](https://github.com/OpenGVLab/GUI-Odyssey/blob/master/README.md)
### Page Graph Construction
------
You can run the following code to construct the corresponding page graph.
```
cd document_construction
sh pre.sh
```
- AITW
```
python aitw_document/main.py
```
- Mind2Web
```
python mind2web_document/main.py
```
- GUI Odyssey
```
python odyssey_document/main.py
```
### Multi-agent Workflow
------
You can run the following code to evaluate the agent on the following benchmarks with the corresponding page graphs.
```
cd workflow
sh pre.sh
```
- AITW
```
python aitw/aitw_test.py
```
- Mind2Web
```
python mind2web/mind2web_test.py
```
- GUI Odyssey
```
python odyssey/odyssey_test.py
```
### Citation
------
```
@misc{chen2025pgagentagentpoweredpage,
title={PG-Agent: An Agent Powered by Page Graph},
author={Weizhi Chen and Ziwei Wang and Leyang Yang and Sheng Zhou and Xiaoxuan Tang and Jiajun Bu and Yong Li and Wei Jiang},
year={2025},
eprint={2509.03536},
archivePrefix={arXiv},
primaryClass={cs.AI},
url={https://arxiv.org/abs/2509.03536},
}
```
\ No newline at end of file
# from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
# from qwen_vl_utils import process_vision_info
import random
# import cv2
import copy
from pathlib import Path
from tqdm import tqdm
import requests
from urllib.parse import quote
import json
from tqdm import tqdm
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain_huggingface import HuggingFaceEmbeddings
from PIL import Image
import prompts
url = "http://localhost:8000/v1/chat/completions"
headers = {
"Content-Type": "application/json"
}
def chat(img_url_list=(), query: str = '') -> str:
    """Send a multimodal chat request to the local vLLM OpenAI-compatible server.

    Args:
        img_url_list: Iterable of image URLs to attach before the text prompt.
            (The original annotation said `str`; it is iterated as a sequence
            of URLs.)
        query: The text prompt appended after the images.

    Returns:
        The assistant message content of the first choice (a string; the
        original `-> dict` annotation was wrong).
    """
    content = []
    for img_url in img_url_list:
        # Percent-encode everything except '/' and ':' so the URL survives JSON transport.
        img_url = quote(img_url, safe='/:')
        content.append({"type": "image_url", "image_url": {"url": img_url}})
    content.append({"type": "text", "text": query})
    data = {
        "model": "Qwen2.5-VL-72B-Instruct",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": content}
        ],
        # Deterministic decoding so graph construction is reproducible.
        'temperature': 0
    }
    response = requests.post(url, headers=headers, data=json.dumps(data))
    return response.json()['choices'][0]['message']['content']
def action2description(step):
    """Render an AITW step as a textual action description.

    Click actions report pixel coordinates (denormalized against the
    screenshot size); type actions report the typed text; any other action
    type is returned as the bare action name.

    Args:
        step: AITW step dict with 'img_filename', 'action_type_text',
            'touch', 'lift' and 'type_text' keys.

    Returns:
        A description string starting with '### Action' for click/type, or
        the plain action name otherwise.
    """
    # Touch/lift coordinates are stored normalized; scale by the image size.
    w, h = Image.open('AITW_simplified/aitw_images/' + step['img_filename'] + '.png').size
    action = step['action_type_text']
    coord1_x, coord1_y = step['touch'][0] * w, step['touch'][1] * h
    coord2_x, coord2_y = step['lift'][0] * w, step['lift'][1] * h
    text = step['type_text']
    # Fixed typo: the original spelled the local variable 'descrpition'.
    if action == 'click':
        description = f'### Action type: {action}\n### Coordinates: ({coord1_x},{coord1_y})'
    elif action == 'type':
        description = f'### Action type: {action}\n### Content: {text}'
    else:
        description = action
    return description
def check_repeat_item(img_path, page_summary, search_document, embedding_model):
    """Look for an existing page in `search_document` matching the current screen.

    Retrieves the most similar stored page summaries via a FAISS vector
    search, asks the VLM to pick the best-fitting description, then
    double-checks the match by comparing the two screenshots directly.

    Args:
        img_path: URL of the current screenshot served over HTTP.
        page_summary: One-sentence summary of the current screen.
        search_document: List of langchain `Document`s, one per known page.
        embedding_model: Embedding model used to build the FAISS index.

    Returns:
        (summary, index) of the matching page, or (None, None) when the
        screen is new.
    """
    if len(search_document) == 0:
        return None, None
    # The index is rebuilt from scratch on every call.
    vectorstore = FAISS.from_documents(search_document, embedding_model)
    search_res = vectorstore.similarity_search(page_summary)
    old_description = ""
    for i, res in enumerate(search_res):
        old_description += f'{i+1}. ' + res.page_content + '\n'
    check_repeat_prompt = prompts.check_repeat.format(old_description=old_description)
    check_repeat_res = chat([img_path], check_repeat_prompt)
    # .strip() keeps this copy consistent with the mind2web/odyssey variants:
    # without it, trailing whitespace breaks both the 'None' check and int().
    sample_index = check_repeat_res.split('### Index: ')[1].strip()
    if sample_index == 'None':
        return None, None
    else:
        sample_index = int(sample_index) - 1
        old_img_path = search_res[sample_index].metadata['img_path']
        # Double check: show both screenshots and ask whether they match.
        double_check_res = chat([old_img_path, img_path], prompts.check_repeat_2)
        double_check_res = double_check_res.split('### Conclusion: ')[1].strip()
        assert double_check_res in ['Yes', 'No']
        if double_check_res == 'No':
            return None, None
        repeat_index = search_res[sample_index].metadata['index']
        new_summary = search_res[sample_index].page_content
        return new_summary, repeat_index
def create_new_item(img_path, knowledge_library, search_document, embedding_model):
    """Summarize the current screen and register it in the page graph.

    If the screen duplicates a known page, that page's summary is refreshed
    and its entry returned; otherwise a brand-new entry is created.
    """
    summary = chat([img_path], prompts.page_summary)
    refreshed, existing_idx = check_repeat_item(img_path, summary, search_document, embedding_model)
    if existing_idx is not None:
        # Known page: update the stored summary everywhere it lives.
        knowledge_library[existing_idx]['page_summary'] = refreshed
        search_document[existing_idx].page_content = refreshed
        return knowledge_library[existing_idx]
    # Unseen page: allocate the next index and record it for future lookups.
    new_idx = len(knowledge_library)
    item = {
        'index': new_idx,
        'page_summary': summary,
        'original_image': [],
        'next_page_list': [{'actions': [], 'page_index': None}],
    }
    knowledge_library[new_idx] = item
    search_document.append(
        Document(page_content=summary, metadata={"index": new_idx, "img_path": img_path})
    )
    return item
def get_item(img_path, last_img_path, last_action_summary, last_page_idx, knowledge_library, search_document, embedding_model):
    """Resolve the page-graph node for the current screenshot.

    The first screenshot of an episode always creates/looks up a node.
    Afterwards a VLM judges whether the previous action redirected to a new
    page; if not, the previous node is reused.
    """
    if last_page_idx is None:
        # Start of an episode: no previous page to compare against.
        item = create_new_item(img_path, knowledge_library, search_document, embedding_model)
        redirected = True
    else:
        verdict = chat(
            [last_img_path, img_path],
            prompts.redirection_judge.format(action=last_action_summary),
        )
        verdict = verdict.split('### Conclusion: ')[1].strip()
        assert verdict in ['Yes', 'No']
        redirected = verdict == 'Yes'
        if redirected:
            item = create_new_item(img_path, knowledge_library, search_document, embedding_model)
        else:
            item = knowledge_library[last_page_idx]
    # Remember the raw filename (strip the local HTTP prefix) on the node.
    item['original_image'].append(img_path.split('http://localhost:6666/aitw_images/')[1])
    return item, redirected
# ---- AITW page-graph construction driver ----
# For each AITW subset, sample 10% of the training episodes and walk through
# their screenshots, linking consecutive pages in the knowledge library.
aitw_train_data = json.load(open('aitw_annots/aitw_data_train.json','r'))
aitw_data_type_list = [ 'install','googleapps','general','single','webshopping']
embedding_model_name = "bge-m3"
embedding_model = HuggingFaceEmbeddings(model_name = embedding_model_name,model_kwargs={'device': 'cuda:0'})
for aitw_data_type in aitw_data_type_list:
    # A separate page graph is built (and saved) per AITW subset.
    knowledge_library = {}
    search_document = []
    selected_episode = random.sample(aitw_train_data[aitw_data_type], len(aitw_train_data[aitw_data_type]) // 10)
    for episode in tqdm(selected_episode):
        last_page_idx = None
        last_img_path = None
        last_action_summary = None
        for i in range(len(episode)):
            # Screenshots are served by a local static file server.
            img_path = 'http://localhost:6666/aitw_images/'+episode[i]['img_filename']+'.png'
            if last_page_idx is not None:
                # Summarize the action performed on the previous screen; only
                # click/type descriptions (prefixed '### Action') need the VLM.
                action_description = action2description(episode[i-1])
                if action_description[:10] == '### Action':
                    last_action_summary = chat([last_img_path], prompts.action_summary.format(action_description=action_description))
                else:
                    last_action_summary = action_description
            knowledge_item, redirection_flag = get_item(img_path, last_img_path, last_action_summary, last_page_idx, knowledge_library, search_document, embedding_model)
            if last_page_idx is not None:
                # Record the action (and goal) on the edge leaving the previous page.
                knowledge_library[last_page_idx]['next_page_list'][-1]['actions'].append(last_action_summary)
                knowledge_library[last_page_idx]['next_page_list'][-1]['goal'] = episode[i]['goal']
                if redirection_flag:
                    # The action led to a new page: close this edge and open a fresh one.
                    knowledge_library[last_page_idx]['next_page_list'][-1]['page_index'] = knowledge_item['index']
                    knowledge_library[last_page_idx]['next_page_list'].append({'actions':[],'page_index':None})
            last_page_idx = knowledge_item['index']
            last_img_path = img_path
    f_json = open(f'{aitw_data_type}_library.json', 'w')
    json.dump(knowledge_library, f_json, ensure_ascii=False, indent=4)
    f_json.close()
# Prompt templates for AITW page-graph construction; imported as `prompts`
# by the construction script.

# Ask the VLM for a one-sentence description of a screenshot.
page_summary = 'Please describe this screen containing following content with one full sentence, including \
the type of page, the function of page and the key components of the screen.'
# Summarize a performed action as a verb phrase, given its type and parameters.
action_summary = 'An operation has now been performed on the screen. \
Here is the type of the operation and relevant parameters:\n\
{action_description}\n\
You are required to summarize this operation with a verb phrase that begins with the given operation type.'
# Judge (from before/after screenshots) whether an action navigated to a new page.
redirection_judge = 'You will receive the images of screens before and after operation \'{action}\'. \
You need to determine whether this operation leads to a new page, or it is just an in-page operation. \
You are required to output with the following format:\n\
### Thought: <Generate your thinking process briefly>\n\
### Conclusion: <\'Yes\' or \'No\'>\n\
Do not output anything else.'
# Pick which stored page description (if any) matches the current screen.
check_repeat = 'You are a professional GUI agent. You will be given a screen and some descriptions. \
Your task is to find one description that best fits the current page.\n\
Here are the descriptions:\n\
{old_description}\
You should answer with the following format:\n\
### Thought: <Generate your thinking process briefly>\n\
### Index: <The index of chosen description, or \'None\' if none of them fits>\n\
Do not output anything else.'
# Double-check whether two screenshots show the same page.
check_repeat_2 = 'Are these two screens similar? You should consider the type, layout, and content of the pages comprehensively.\n\
You are required to output with the following format:\n\
### Thought: <Generate your thinking process briefly>\n\
### Conclusion: <\'Yes\' or \'No\'>\n\
Do not output anything else.'
\ No newline at end of file
# from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
# from qwen_vl_utils import process_vision_info
import random
# import cv2
import copy
import os
from pathlib import Path
from tqdm import tqdm
import requests
from urllib.parse import quote
import json
from tqdm import tqdm
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain_huggingface import HuggingFaceEmbeddings
import prompts
url = "http://localhost:8000/v1/chat/completions"
headers = {
"Content-Type": "application/json"
}
def chat(img_url_list=(), query: str = '') -> str:
    """Send a multimodal chat request to the local vLLM OpenAI-compatible server.

    Args:
        img_url_list: Iterable of image URLs to attach before the text prompt.
            (The original annotation said `str`; it is iterated as a sequence
            of URLs.)
        query: The text prompt appended after the images.

    Returns:
        The assistant message content of the first choice (a string; the
        original `-> dict` annotation was wrong).
    """
    content = []
    for img_url in img_url_list:
        # Percent-encode everything except '/' and ':' so the URL survives JSON transport.
        img_url = quote(img_url, safe='/:')
        content.append({"type": "image_url", "image_url": {"url": img_url}})
    content.append({"type": "text", "text": query})
    data = {
        "model": "Qwen2.5-VL-72B-Instruct",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": content}
        ],
        # Deterministic decoding so graph construction is reproducible.
        'temperature': 0
    }
    response = requests.post(url, headers=headers, data=json.dumps(data))
    return response.json()['choices'][0]['message']['content']
def get_action_summary(img_path, action):
    """Summarize a Mind2Web action (CLICK/TYPE/SELECT) via the VLM.

    Args:
        img_path: URL of the screenshot the action was performed on.
        action: Mind2Web action dict with 'operation' ({'op', 'value'}) and
            'bbox' ({'x', 'y', 'width', 'height'}) entries.

    Returns:
        A verb phrase describing the action, without a trailing period.
    """
    action_type = action['operation']['op']
    assert action_type in ['CLICK', 'TYPE', 'SELECT']
    # Convert (x, y, width, height) into an [x1, y1, x2, y2] pixel box.
    bbox = [int(action["bbox"]["x"]), int(action["bbox"]["y"]),
            int(action["bbox"]["x"] + action["bbox"]["width"]),
            int(action["bbox"]["y"] + action["bbox"]["height"])]
    bbox_str = f'[{bbox[0]}, {bbox[1]}, {bbox[2]}, {bbox[3]}]'
    if action_type == 'CLICK':
        query = prompts.click_action_summary.format(bbox=bbox_str)
    elif action_type == 'TYPE':
        query = prompts.type_action_summary.format(content=action['operation']['value'], bbox=bbox_str)
    elif action_type == 'SELECT':
        query = prompts.select_action_summary.format(content=action['operation']['value'], bbox=bbox_str)
    action_summary = chat([img_path], query)
    # endswith() also handles an empty VLM reply; the original `[-1]`
    # indexing raised IndexError on an empty string.
    if action_summary.endswith('.'):
        action_summary = action_summary[:-1]
    return action_summary
def check_repeat_item(domain, img_path, page_summary, search_document, embedding_model):
    """Search the per-domain page index for a page matching the current screen.

    Returns (summary, index) of the matched page, or (None, None) when the
    screen is new to this domain.
    """
    docs = search_document[domain]
    if not docs:
        return None, None
    # Build a fresh FAISS index over this domain's stored summaries.
    vectorstore = FAISS.from_documents(docs, embedding_model)
    candidates = vectorstore.similarity_search(page_summary)
    numbered = ''.join(f'{rank + 1}. ' + doc.page_content + '\n' for rank, doc in enumerate(candidates))
    reply = chat([img_path], prompts.check_repeat.format(old_description=numbered))
    choice = reply.split('### Index: ')[1].strip()
    if choice == 'None':
        return None, None
    chosen = candidates[int(choice) - 1]
    # Double-check visually: show the stored screenshot next to the new one.
    verdict = chat([chosen.metadata['img_path'], img_path], prompts.check_repeat_2)
    verdict = verdict.split('### Conclusion: ')[1].strip()
    assert verdict in ['Yes', 'No']
    if verdict == 'No':
        return None, None
    return chosen.page_content, chosen.metadata['index']
def create_new_item(domain, img_path, knowledge_library, search_document, embedding_model):
    """Summarize the screen and insert (or refresh) its node in the domain's page graph."""
    summary = chat([img_path], prompts.page_summary)
    refreshed, existing_idx = check_repeat_item(domain, img_path, summary, search_document, embedding_model)
    if existing_idx is not None:
        # Duplicate page: refresh the stored summary in both structures.
        knowledge_library[domain][existing_idx]['page_summary'] = refreshed
        search_document[domain][existing_idx].page_content = refreshed
        return knowledge_library[domain][existing_idx]
    # New page for this domain: allocate the next index.
    new_idx = len(knowledge_library[domain])
    item = {
        'index': new_idx,
        'page_summary': summary,
        'original_image': [],
        'next_page_list': [{'actions': [], 'page_index': None}],
    }
    knowledge_library[domain][new_idx] = item
    search_document[domain].append(
        Document(page_content=summary, metadata={"index": new_idx, "img_path": img_path})
    )
    return item
def get_item(domain, img_path, last_img_path, last_action_summary, last_page_idx, knowledge_library, search_document, embedding_model):
    """Resolve the page-graph node for the current screenshot within a domain.

    The first screenshot of an episode always creates/looks up a node; later
    ones reuse the previous node unless the VLM judges the last action to
    have redirected to a new page.
    """
    if last_page_idx is None:
        # Start of an episode: no previous page to compare against.
        item = create_new_item(domain, img_path, knowledge_library, search_document, embedding_model)
        redirected = True
    else:
        verdict = chat(
            [last_img_path, img_path],
            prompts.redirection_judge.format(action=last_action_summary),
        )
        verdict = verdict.split('### Conclusion: ')[1].strip()
        assert verdict in ['Yes', 'No']
        redirected = verdict == 'Yes'
        if redirected:
            item = create_new_item(domain, img_path, knowledge_library, search_document, embedding_model)
        else:
            item = knowledge_library[domain][last_page_idx]
    # Remember the raw filename (strip the local HTTP prefix) on the node.
    item['original_image'].append(img_path.split('http://localhost:6667/mind2web_images/')[1])
    return item, redirected
# ---- Mind2Web page-graph construction driver ----
# Sample 10% of the training episodes and build one page graph per website
# domain; all domains are saved into a single JSON file.
mind2web_train_data = json.load(open('mind2web_annots/mind2web_data_train.json','r'))
embedding_model_name = "bge-m3"
embedding_model = HuggingFaceEmbeddings(model_name = embedding_model_name,model_kwargs={'device': 'cuda:0'})
knowledge_library = {}
search_document = {}
selected_episode = random.sample(mind2web_train_data, len(mind2web_train_data) // 10)
for episode in tqdm(selected_episode):
    last_page_idx = None
    last_img_path = None
    last_action_summary = None
    domain = episode['domain']
    if domain not in list(knowledge_library.keys()):
        # First episode of this domain: start a fresh per-domain graph.
        knowledge_library[domain] = {}
        search_document[domain] = []
    goal = episode['confirmed_task']
    episode_id = episode['annotation_id']
    action_list = episode['actions']
    terminate_flag = False
    for i in range(len(action_list)):
        # Screenshots are served by a local static file server.
        img_path = 'http://localhost:6667/mind2web_images/'+episode_id+'-'+action_list[i]['action_uid']+'.jpg'
        if not os.path.exists('mind2web_images/'+episode_id+'-'+action_list[i]['action_uid']+'.jpg'):
            # Skip the whole episode when any screenshot is missing on disk.
            terminate_flag = True
            print('IMAGE NOT FOUND')
            print(episode_id+'-'+action_list[i]['action_uid'])
            break
        if last_page_idx is not None:
            last_action_summary = get_action_summary(last_img_path, action_list[i-1])
        knowledge_item, redirection_flag = get_item(domain, img_path, last_img_path, last_action_summary, last_page_idx, knowledge_library, search_document, embedding_model)
        if last_page_idx is not None:
            # Record the action (and goal) on the edge leaving the previous page.
            knowledge_library[domain][last_page_idx]['next_page_list'][-1]['actions'].append(last_action_summary)
            knowledge_library[domain][last_page_idx]['next_page_list'][-1]['goal'] = goal
            if redirection_flag:
                # The action led to a new page: close this edge and open a fresh one.
                knowledge_library[domain][last_page_idx]['next_page_list'][-1]['page_index'] = knowledge_item['index']
                knowledge_library[domain][last_page_idx]['next_page_list'].append({'actions':[],'page_index':None})
        last_page_idx = knowledge_item['index']
        last_img_path = img_path
    if terminate_flag:
        continue
    # The final action has no following screenshot, so record it on the last
    # page's open edge here. NOTE(review): single-action episodes are skipped
    # by the `> 1` check — confirm this is intentional.
    if len(action_list) > 1:
        last_action_summary = get_action_summary(last_img_path, action_list[-1])
        knowledge_library[domain][last_page_idx]['next_page_list'][-1]['actions'].append(last_action_summary)
        knowledge_library[domain][last_page_idx]['next_page_list'][-1]['goal'] = goal
f_json = open(f'mind2web_library.json', 'w')
json.dump(knowledge_library, f_json, ensure_ascii=False, indent=4)
f_json.close()
\ No newline at end of file
# Prompt templates for Mind2Web page-graph construction; imported as
# `prompts` by the construction script.

# Ask the VLM for a one-sentence description of a webpage screenshot.
page_summary = 'Please describe this screen containing following content with one full sentence, including \
the type of page, the function of page and the key components of the screen.'
# Summaries for the three Mind2Web action types; each receives a pixel bbox.
click_action_summary = 'This is a page of website. The user clicks the item at coordinates {bbox}. You are required to summarize this operation beginning with \"click\". Do not mention original coordinates.'
type_action_summary = 'This is a page of website. The user types the content \"{content}\" at coordinates {bbox}. You are required to summarize this operation beginning with \"type\". Do not mention original coordinates.'
select_action_summary = 'This is a page of website. The user opens a \"Select Menu\" or \"Dropdown List\" at coordinates {bbox}, and select the option \"{content}\". You are required to summarize this operation beginning with \"select\". Do not mention original coordinates.'
# Judge (from before/after screenshots) whether an action navigated to a new page.
redirection_judge = 'You will receive the images of screens before and after operation \'{action}\'. \
You need to determine whether this operation leads to a new page, or it is just an in-page operation. \
You are required to output with the following format:\n\
### Thought: <Generate your thinking process briefly>\n\
### Conclusion: <\'Yes\' or \'No\'>\n\
Do not output anything else.'
# Pick which stored page description (if any) matches the current webpage.
check_repeat = 'You are a professional GUI agent. You will be given a webpage and some descriptions. \
Your task is to find one description that best fits the current webpage.\n\
Here are the descriptions:\n\
{old_description}\
You should answer with the following format:\n\
### Thought: <Generate your thinking process briefly>\n\
### Index: <The index of chosen description, or \'None\' if none of them fits>\n\
Do not output anything else.'
# Double-check whether two screenshots show the same page.
check_repeat_2 = 'Are these two screens similar? You should consider the type, layout, and content of the pages comprehensively.\n\
You are required to output with the following format:\n\
### Thought: <Generate your thinking process briefly>\n\
### Conclusion: <\'Yes\' or \'No\'>\n\
Do not output anything else.'
# from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
# from qwen_vl_utils import process_vision_info
import random
# import cv2
import copy
import os
from pathlib import Path
from tqdm import tqdm
import requests
from urllib.parse import quote
import json
from tqdm import tqdm
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain_huggingface import HuggingFaceEmbeddings
from PIL import Image
import numpy as np
import prompts
url = "http://localhost:8000/v1/chat/completions"
headers = {
"Content-Type": "application/json"
}
def chat(img_url_list=(), query: str = '') -> str:
    """Send a multimodal chat request to the local vLLM OpenAI-compatible server.

    Args:
        img_url_list: Iterable of image URLs to attach before the text prompt.
            (The original annotation said `str`; it is iterated as a sequence
            of URLs.)
        query: The text prompt appended after the images.

    Returns:
        The assistant message content of the first choice (a string; the
        original `-> dict` annotation was wrong).
    """
    content = []
    for img_url in img_url_list:
        # Percent-encode everything except '/' and ':' so the URL survives JSON transport.
        img_url = quote(img_url, safe='/:')
        content.append({"type": "image_url", "image_url": {"url": img_url}})
    content.append({"type": "text", "text": query})
    data = {
        "model": "Qwen2.5-VL-72B-Instruct",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": content}
        ],
        # Deterministic decoding so graph construction is reproducible.
        'temperature': 0
    }
    response = requests.post(url, headers=headers, data=json.dumps(data))
    return response.json()['choices'][0]['message']['content']
def get_action_summary(img_path, step):
    """Turn a GUI-Odyssey step into a short natural-language action summary.

    CLICK/LONG_PRESS on special keys map to fixed phrases; coordinate clicks
    are summarized by the VLM; SCROLL direction is derived from the
    start/end points; TEXT becomes 'type <content>'.

    Args:
        img_path: URL of the screenshot the action was performed on.
        step: Step dict with 'action' and 'info' keys (plus 'screenshot'
            for coordinate clicks).

    Returns:
        A verb phrase describing the action.
    """
    action = step['action']
    info = step['info']
    assert action in ['CLICK', 'TEXT', 'SCROLL', 'LONG_PRESS']
    if action == 'CLICK' or action == "LONG_PRESS":
        if info == 'KEY_HOME':
            gt = 'press home to go to the home screen'
        elif info == 'KEY_BACK':
            gt = 'press back to go to the previous screen'
        elif info == 'KEY_APPSELECT':
            gt = 'go to the previous App'
        elif isinstance(info, list):  # isinstance instead of type(...) == list
            # Coordinates are normalized to [0, 1000]; convert to pixels.
            w, h = Image.open('GUI-Odyssey-master/data/screenshots/' + step['screenshot']).size
            bbox_str = f'[{int(info[0][0]/1000*w)}, {int(info[0][1]/1000*h)}]'
            query = prompts.click_action_summary.format(bbox=bbox_str)
            gt = chat([img_path], query)
            # endswith() also handles an empty VLM reply; the original
            # `gt[-1]` indexing raised IndexError on an empty string.
            if gt.endswith('.'):
                gt = gt[:-1]
        else:
            raise ValueError(f'Unknown click action {info}')
    elif action == 'SCROLL':
        # The axis with the larger displacement determines the direction.
        start = np.array(info[0])
        end = np.array(info[1])
        delta = end - start
        delta_abs = np.abs(delta)
        lr = 'left' if delta[0] < 0 else 'right'
        ud = 'up' if delta[1] < 0 else 'down'
        if delta_abs[0] > delta_abs[1]:
            gt = f"scroll {lr}"
        else:
            gt = f"scroll {ud}"
    elif action == 'TEXT':
        gt = f'type {info}'
    return gt
def check_repeat_item(domain, img_path, page_summary, search_document, embedding_model):
    """Search the per-domain page index for a page matching the current screen.

    Returns (summary, index) of the matched page, or (None, None) when the
    screen is new to this domain.
    """
    docs = search_document[domain]
    if not docs:
        return None, None
    # Build a fresh FAISS index over this domain's stored summaries.
    vectorstore = FAISS.from_documents(docs, embedding_model)
    candidates = vectorstore.similarity_search(page_summary)
    numbered = ''.join(f'{rank + 1}. ' + doc.page_content + '\n' for rank, doc in enumerate(candidates))
    reply = chat([img_path], prompts.check_repeat.format(old_description=numbered))
    choice = reply.split('### Index: ')[1].strip()
    if choice == 'None':
        return None, None
    chosen = candidates[int(choice) - 1]
    # Double-check visually: show the stored screenshot next to the new one.
    verdict = chat([chosen.metadata['img_path'], img_path], prompts.check_repeat_2)
    verdict = verdict.split('### Conclusion: ')[1].strip()
    assert verdict in ['Yes', 'No']
    if verdict == 'No':
        return None, None
    return chosen.page_content, chosen.metadata['index']
def create_new_item(domain, img_path, knowledge_library, search_document, embedding_model):
    """Summarize the screen and insert (or refresh) its node in the domain's page graph."""
    summary = chat([img_path], prompts.page_summary)
    refreshed, existing_idx = check_repeat_item(domain, img_path, summary, search_document, embedding_model)
    if existing_idx is not None:
        # Duplicate page: refresh the stored summary in both structures.
        knowledge_library[domain][existing_idx]['page_summary'] = refreshed
        search_document[domain][existing_idx].page_content = refreshed
        return knowledge_library[domain][existing_idx]
    # New page for this domain: allocate the next index.
    new_idx = len(knowledge_library[domain])
    item = {
        'index': new_idx,
        'page_summary': summary,
        'original_image': [],
        'next_page_list': [{'actions': [], 'page_index': None}],
    }
    knowledge_library[domain][new_idx] = item
    search_document[domain].append(
        Document(page_content=summary, metadata={"index": new_idx, "img_path": img_path})
    )
    return item
def get_item(domain, img_path, last_img_path, last_action_summary, last_page_idx, knowledge_library, search_document, embedding_model):
    """Resolve the page-graph node for the current screenshot within a domain.

    The first screenshot of an episode always creates/looks up a node; later
    ones reuse the previous node unless the VLM judges the last action to
    have redirected to a new page.
    """
    if last_page_idx is None:
        # Start of an episode: no previous page to compare against.
        item = create_new_item(domain, img_path, knowledge_library, search_document, embedding_model)
        redirected = True
    else:
        verdict = chat(
            [last_img_path, img_path],
            prompts.redirection_judge.format(action=last_action_summary),
        )
        verdict = verdict.split('### Conclusion: ')[1].strip()
        assert verdict in ['Yes', 'No']
        redirected = verdict == 'Yes'
        if redirected:
            item = create_new_item(domain, img_path, knowledge_library, search_document, embedding_model)
        else:
            item = knowledge_library[domain][last_page_idx]
    # Remember the raw filename (strip the local HTTP prefix) on the node.
    item['original_image'].append(img_path.split('http://localhost:6668/')[1])
    return item, redirected
# ---- GUI Odyssey page-graph construction driver ----
# Sample 1/50 of the training episodes and build one page graph per task
# category ("domain"); all categories are saved into a single JSON file.
odyssey_data = json.load(open('data/splits/splits_random_split.json','r'))
annotations_path = 'data/annotations/'
imgs_path = 'data/screenshots/'
embedding_model_name = "bge-m3"
embedding_model = HuggingFaceEmbeddings(model_name = embedding_model_name,model_kwargs={'device': 'cuda:0'})
knowledge_library = {}
search_document = {}
selected_episode_idx = random.sample(odyssey_data['train'], len(odyssey_data['train']) // 50)
for train_idx in tqdm(selected_episode_idx):
    # Each split entry names a per-episode annotation JSON file.
    episode = json.load(open(annotations_path + train_idx,'r'))
    last_page_idx = None
    last_img_path = None
    last_action_summary = None
    domain = episode['task_info']['category']
    if domain not in list(knowledge_library.keys()):
        # First episode of this category: start a fresh per-domain graph.
        knowledge_library[domain] = {}
        search_document[domain] = []
    goal = episode['task_info']['instruction']
    action_list = episode['steps']
    for i in range(len(action_list)):
        # Screenshots are served by a local static file server.
        img_path = 'http://localhost:6668/'+action_list[i]['screenshot']
        if last_page_idx is not None:
            # Summarize the action taken on the previous screen.
            last_action_summary = get_action_summary(last_img_path, action_list[i-1])
        knowledge_item, redirection_flag = get_item(domain, img_path, last_img_path, last_action_summary, last_page_idx, knowledge_library, search_document, embedding_model)
        if last_page_idx is not None:
            # Record the action (and goal) on the edge leaving the previous page.
            knowledge_library[domain][last_page_idx]['next_page_list'][-1]['actions'].append(last_action_summary)
            knowledge_library[domain][last_page_idx]['next_page_list'][-1]['goal'] = goal
            if redirection_flag:
                # The action led to a new page: close this edge and open a fresh one.
                knowledge_library[domain][last_page_idx]['next_page_list'][-1]['page_index'] = knowledge_item['index']
                knowledge_library[domain][last_page_idx]['next_page_list'].append({'actions':[],'page_index':None})
        last_page_idx = knowledge_item['index']
        last_img_path = img_path
f_json = open(f'odyssey_library.json', 'w')
json.dump(knowledge_library, f_json, ensure_ascii=False, indent=4)
f_json.close()
# Prompt templates for GUI-Odyssey page-graph construction; imported as
# `prompts` by the construction script.

# Ask the VLM for a one-sentence description of a screenshot.
page_summary = 'Please describe this screen containing following content with one full sentence, including \
the type of page, the function of page and the key components of the screen.'
# Summarize a coordinate click as a verb phrase; receives a pixel point.
click_action_summary = 'The user clicks the item at coordinates {bbox}. You are required to summarize this operation with a verb phrase that begins with \"click\". Do not mention original coordinates.'
# Judge (from before/after screenshots) whether an action navigated to a new page.
redirection_judge = 'You will receive the images of screens before and after operation \'{action}\'. \
You need to determine whether this operation leads to a new page, or it is just an in-page operation. \
You are required to output with the following format:\n\
### Thought: <Generate your thinking process briefly>\n\
### Conclusion: <\'Yes\' or \'No\'>\n\
Do not output anything else.'
# Pick which stored page description (if any) matches the current screen.
check_repeat = 'You are a professional GUI agent. You will be given a screen and some descriptions. \
Your task is to find one description that best fits the current page.\n\
Here are the descriptions:\n\
{old_description}\
You should answer with the following format:\n\
### Thought: <Generate your thinking process briefly>\n\
### Index: <The index of chosen description, or \'None\' if none of them fits>\n\
Do not output anything else.'
# Double-check whether two screenshots show the same page.
check_repeat_2 = 'Are these two screens similar? You should consider the type, layout, and content of the pages comprehensively.\n\
You are required to output with the following format:\n\
### Thought: <Generate your thinking process briefly>\n\
### Conclusion: <\'Yes\' or \'No\'>\n\
Do not output anything else.'
# Environment setup for page-graph construction and evaluation.
pip install langchain
# The GPU build of FAISS is distributed on the pytorch conda channel.
conda install -c pytorch faiss-gpu
pip install -U langchain-community
pip install sentence-transformers
# NOTE(review): numpy is pinned — presumably for jax/faiss compatibility; confirm.
pip install numpy==1.23.2
pip install -U langchain-huggingface
pip install jax
pip install jaxlib
pip install --upgrade vllm
# Serve Qwen2.5-VL-72B-Instruct behind an OpenAI-compatible API across 4 GPUs,
# allowing up to 2 images per prompt (the scripts send at most two screenshots).
python -m vllm.entrypoints.openai.api_server --served-model-name Qwen2.5-VL-72B-Instruct --model Qwen2.5-VL-72B-Instruct -tp 4 --limit_mm_per_prompt image=2
'''
Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild
'''
import jax
import jax.numpy as jnp
import numpy as np
import os
import action_type as action_type_lib
from PIL import Image
# Thresholds used when deciding whether two AITW actions match.
# Max tap-to-tap distance, as a fraction of the screen, for two taps to match.
_TAP_DISTANCE_THRESHOLD = 0.14  # Fraction of the screen
# Fractions by which annotation boxes are grown before containment tests.
ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4
ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4
# Interval determining if an action is a tap or a swipe.
_SWIPE_DISTANCE_THRESHOLD = 0.04
def _yx_in_bounding_boxes(
    yx, bounding_boxes
):
    """Check whether the (y, x) point lies inside each bounding box.

    Args:
        yx: The (y, x) coordinate in pixels of the point.
        bounding_boxes: Int array of shape (num_bboxes, 4); each row is
            (y_top_left, x_top_left, box_height, box_width). Containment is
            inclusive of the box edges.

    Returns:
        1D bool array: element i is True iff the point lies inside box i.
    """
    y, x = yx
    # Take the four columns of the (n, 4) box array as (n,) vectors.
    top = bounding_boxes[:, 0]
    left = bounding_boxes[:, 1]
    # The y-axis is inverted for AndroidEnv, so bottom = top + height.
    bottom = top + bounding_boxes[:, 2]
    right = left + bounding_boxes[:, 3]
    inside_vertically = jnp.logical_and(y >= top, y <= bottom)
    inside_horizontally = jnp.logical_and(x >= left, x <= right)
    return inside_vertically & inside_horizontally
def _resize_annotation_bounding_boxes(
        annotation_positions, annotation_width_augment_fraction,
        annotation_height_augment_fraction):
    """Grow each bounding box by the given width/height fractions.

    Args:
        annotation_positions: Array of shape (N, 4) holding
            (y, x, height, width) per box in normalized coordinates.
        annotation_width_augment_fraction: Fraction added to each width,
            e.g. 1.4 == 240% total increase.
        annotation_height_augment_fraction: Same as for width, but for box
            height.

    Returns:
        (N, 4) array of resized boxes, clamped to the [0, 1] screen.
    """
    dh = annotation_height_augment_fraction * annotation_positions[:, 2]
    dw = annotation_width_augment_fraction * annotation_positions[:, 3]
    # Expand symmetrically around the box origin, limited to the screen.
    new_y = jnp.maximum(0, annotation_positions[:, 0] - (dh / 2))
    new_x = jnp.maximum(0, annotation_positions[:, 1] - (dw / 2))
    new_h = jnp.minimum(1, annotation_positions[:, 2] + dh)
    new_w = jnp.minimum(1, annotation_positions[:, 3] + dw)
    return jnp.stack([new_y, new_x, new_h, new_w], axis=1)
def is_tap_action(normalized_start_yx,
                  normalized_end_yx):
    """Return True when touch and lift are close enough to count as a tap.

    A gesture whose start/end distance exceeds _SWIPE_DISTANCE_THRESHOLD is
    treated as a swipe instead.
    """
    start = jnp.array(normalized_start_yx)
    end = jnp.array(normalized_end_yx)
    return jnp.linalg.norm(start - end) <= _SWIPE_DISTANCE_THRESHOLD
def _is_non_dual_point_action(action_type):
    """Return True when `action_type` is anything other than DUAL_POINT."""
    return jnp.not_equal(action_type, action_type_lib.ActionType.DUAL_POINT)
def _check_tap_actions_match(
        tap_1_yx,
        tap_2_yx,
        annotation_positions,
        matching_tap_distance_threshold_screen_percentage,
        annotation_width_augment_fraction,
        annotation_height_augment_fraction,
):
    """Determine whether two tap actions are considered the same.

    Taps match when both land inside the same (augmented) UI-element box, or
    when the two points are within the given screen-percentage distance.
    """
    grown_boxes = _resize_annotation_bounding_boxes(
        annotation_positions,
        annotation_width_augment_fraction,
        annotation_height_augment_fraction,
    )
    # Shared-box test: some augmented element contains both taps.
    in_box_1 = _yx_in_bounding_boxes(tap_1_yx, grown_boxes)
    in_box_2 = _yx_in_bounding_boxes(tap_2_yx, grown_boxes)
    share_a_box = jnp.max(in_box_1 & in_box_2)
    # Fallback when the taps do not share a box (or fall outside all boxes):
    # compare the raw Euclidean distance between the two points.
    close_enough = (
        jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx))
        <= matching_tap_distance_threshold_screen_percentage
    )
    return jnp.logical_or(share_a_box, close_enough)
def _check_drag_actions_match(
drag_1_touch_yx,
drag_1_lift_yx,
drag_2_touch_yx,
drag_2_lift_yx,
):
"""Determines if two drag actions are the same."""
# Store drag deltas (the change in the y and x coordinates from touch to
# lift), magnitudes, and the index of the main axis, which is the axis with
# the greatest change in coordinate value (e.g. a drag starting at (0, 0) and
# ending at (0.3, 0.5) has a main axis index of 1).
drag_1_deltas = drag_1_lift_yx - drag_1_touch_yx
drag_1_magnitudes = jnp.abs(drag_1_deltas)
drag_1_main_axis = np.argmax(drag_1_magnitudes)
drag_2_deltas = drag_2_lift_yx - drag_2_touch_yx
drag_2_magnitudes = jnp.abs(drag_2_deltas)
drag_2_main_axis = np.argmax(drag_2_magnitudes)
return jnp.equal(drag_1_main_axis, drag_2_main_axis) #只判断滑动的方向
def check_actions_match(
    action_1_touch_yx,
    action_1_lift_yx,
    action_1_action_type,
    action_2_touch_yx,
    action_2_lift_yx,
    action_2_action_type,
    annotation_positions,
    tap_distance_threshold = _TAP_DISTANCE_THRESHOLD,
    annotation_width_augment_fraction = ANNOTATION_WIDTH_AUGMENT_FRACTION,
    annotation_height_augment_fraction = ANNOTATION_HEIGHT_AUGMENT_FRACTION,
):
  """Determines if two actions are considered to be the same.
  Two actions being "the same" is defined here as two actions that would result
  in a similar screen state.
  Args:
    action_1_touch_yx: The (y, x) coordinates of the first action's touch.
    action_1_lift_yx: The (y, x) coordinates of the first action's lift.
    action_1_action_type: The action type of the first action.
    action_2_touch_yx: The (y, x) coordinates of the second action's touch.
    action_2_lift_yx: The (y, x) coordinates of the second action's lift.
    action_2_action_type: The action type of the second action.
    annotation_positions: The positions of the UI annotations for the screen. It
      is A 2D int array of shape (num_bboxes, 4), where each row represents a
      bounding box: (y_top_left, x_top_left, box_height, box_width). Note that
      containment is inclusive of the bounding box edges.
    tap_distance_threshold: The threshold that determines if two taps result in
      a matching screen state if they don't fall the same bounding boxes.
    annotation_width_augment_fraction: The fraction to increase the width of the
      bounding box by.
    annotation_height_augment_fraction: The fraction to increase the height of
      of the bounding box by.
  Returns:
    A boolean representing whether the two given actions are the same or not.
  """
  action_1_touch_yx = jnp.asarray(action_1_touch_yx)
  action_1_lift_yx = jnp.asarray(action_1_lift_yx)
  action_2_touch_yx = jnp.asarray(action_2_touch_yx)
  action_2_lift_yx = jnp.asarray(action_2_lift_yx)
  # Checks if at least one of the actions is global (i.e. not DUAL_POINT),
  # because if that is the case, only the actions' types need to be compared.
  has_non_dual_point_action = jnp.logical_or(
      _is_non_dual_point_action(action_1_action_type),
      _is_non_dual_point_action(action_2_action_type),
  )
  #print("non dual point: "+str(has_non_dual_point_action))
  # A tap can never match a swipe: flag when exactly one side is a tap.
  different_dual_point_types = jnp.logical_xor(
      is_tap_action(action_1_touch_yx, action_1_lift_yx),
      is_tap_action(action_2_touch_yx, action_2_lift_yx),
  )
  #print("different dual type: "+str(different_dual_point_types))
  is_tap = jnp.logical_and(
      is_tap_action(action_1_touch_yx, action_1_lift_yx),
      is_tap_action(action_2_touch_yx, action_2_lift_yx),
  )
  #print("is tap: "+str(is_tap))
  taps_match = _check_tap_actions_match(
      action_1_touch_yx,
      action_2_touch_yx,
      annotation_positions,
      tap_distance_threshold,
      annotation_width_augment_fraction,
      annotation_height_augment_fraction,
  )
  #print("tap match: "+str(taps_match))
  # The tap comparison only counts when both gestures actually are taps.
  taps_match = jnp.logical_and(is_tap, taps_match)
  #print("tap match: "+str(taps_match))
  drags_match = _check_drag_actions_match(
      action_1_touch_yx, action_1_lift_yx, action_2_touch_yx, action_2_lift_yx
  )
  # Symmetrically, the drag comparison only counts for two swipes.
  drags_match = jnp.where(is_tap, False, drags_match)
  #print("drag match: "+str(drags_match))
  # Decision tree: global actions match by type; mixed tap/swipe never
  # matches; otherwise either the tap test or the drag test must pass.
  return jnp.where(
      has_non_dual_point_action,
      jnp.equal(action_1_action_type, action_2_action_type),
      jnp.where(
          different_dual_point_types,
          False,
          jnp.logical_or(taps_match, drags_match),
      ),
  )
def action_2_format(step_data):
    """Convert a ground-truth AITW test step into the canonical
    action-matching format.

    Args:
      step_data: Raw step dict with keys "action_type_id",
        "action_type_text", and (depending on the type) "touch"/"lift"
        or "type_text".

    Returns:
      Dict with "action_type", "touch_point", "lift_point" (swapped from
      (y, x) to (x, y) storage) and lower-cased "typed_text".
    """
    action_type = step_data["action_type_id"]
    # Initialize with sentinel coordinates so that non-dual-point actions
    # (TYPE/BACK/HOME/...) and unrecognized scroll texts can never leave
    # touch_point/lift_point unassigned (the original code could hit a
    # NameError here depending on which branch was taken).
    touch_point = [-1.0, -1.0]
    lift_point = [-1.0, -1.0]
    if action_type == 4:
        if step_data["action_type_text"] == 'click':
            touch_point = step_data["touch"]
            lift_point = step_data["lift"]
        # Scrolls are encoded as fixed synthetic swipes in (y, x) order.
        elif step_data["action_type_text"] == 'scroll down':
            touch_point = [0.5, 0.8]
            lift_point = [0.5, 0.2]
        elif step_data["action_type_text"] == 'scroll up':
            touch_point = [0.5, 0.2]
            lift_point = [0.5, 0.8]
        elif step_data["action_type_text"] == 'scroll left':
            touch_point = [0.2, 0.5]
            lift_point = [0.8, 0.5]
        elif step_data["action_type_text"] == 'scroll right':
            touch_point = [0.8, 0.5]
            lift_point = [0.2, 0.5]
    typed_text = step_data["type_text"] if action_type == 3 else ""
    action = {"action_type": action_type,
              # Points are stored swapped, i.e. as (x, y).
              "touch_point": [touch_point[1], touch_point[0]],
              "lift_point": [lift_point[1], lift_point[0]],
              "typed_text": typed_text.lower()}
    return action
def pred_2_format(step_data):
    """Convert a predicted action into the canonical action-matching format.

    Scroll predictions (types 0/1/8/9) are remapped to dual-point (type 4)
    swipes with fixed touch/lift points; other non-click actions keep their
    type and get sentinel coordinates. Points are finally swapped from
    (y, x) to (x, y) and typed text is lower-cased.
    """
    # Fixed (touch, lift) swipe endpoints for the four scroll types,
    # in (y, x) order: 0=down, 1=up, 8=left, 9=right.
    scroll_points = {
        0: ([0.5, 0.8], [0.5, 0.2]),
        1: ([0.5, 0.2], [0.5, 0.8]),
        8: ([0.2, 0.5], [0.8, 0.5]),
        9: ([0.8, 0.5], [0.2, 0.5]),
    }
    action_type = step_data["action_type"]
    typed_text = ""
    if action_type == 4:
        # Click: touch and lift collapse onto the predicted click point.
        action_type_new = 4
        touch_point = step_data["click_point"]
        lift_point = step_data["click_point"]
    elif action_type in scroll_points:
        action_type_new = 4
        touch_point, lift_point = scroll_points[action_type]
    else:
        action_type_new = action_type
        touch_point = [-1.0, -1.0]
        lift_point = [-1.0, -1.0]
        if action_type_new == 3:
            typed_text = step_data["typed_text"]
    return {
        "action_type": action_type_new,
        "touch_point": [touch_point[1], touch_point[0]],
        "lift_point": [lift_point[1], lift_point[0]],
        "typed_text": typed_text.lower(),
    }
def pred_2_format_4_mpgui(step_data,img_filename=''):
    """Convert a predicted action (MP-GUI style, possibly expressed in pixel
    coordinates) into the canonical action-matching format.

    When `img_filename` is given, pixel coordinates (values > 1) are
    normalised by the real screenshot size; otherwise a 1000x1000 canvas is
    assumed. Output layout matches pred_2_format.
    """
    # Fixed (touch, lift) swipe endpoints in (y, x) order:
    # 0=down, 1=up, 8=left, 9=right.
    scroll_points = {
        0: ([0.5, 0.8], [0.5, 0.2]),
        1: ([0.5, 0.2], [0.5, 0.8]),
        8: ([0.2, 0.5], [0.8, 0.5]),
        9: ([0.8, 0.5], [0.2, 0.5]),
    }
    if img_filename != '':
        img_path = 'AITW_simplified/aitw_images/' + img_filename
        w, h = Image.open(img_path).size
    else:
        w, h = 1000, 1000
    action_type = step_data["action_type"]
    typed_text = ""
    if action_type == 4:
        action_type_new = 4
        if 'click_point' in step_data:
            touch_point = step_data["click_point"]
            lift_point = step_data["click_point"]
            # MP-GUI may emit pixel coordinates; normalise them to [0, 1].
            if touch_point[0] > 1.:
                touch_point = [touch_point[0]/w, touch_point[1]/h]
            if lift_point[0] > 1:
                lift_point = [lift_point[0]/w, lift_point[1]/h]
        else:
            print(f'$$ error pred step: {step_data}')
            touch_point = [0., 0.]
            lift_point = [0., 0.]
    elif action_type in scroll_points:
        action_type_new = 4
        touch_point, lift_point = scroll_points[action_type]
    else:
        action_type_new = action_type
        touch_point = [-1.0, -1.0]
        lift_point = [-1.0, -1.0]
        if action_type_new == 3:
            typed_text = step_data["typed_text"]
    return {
        "action_type": action_type_new,
        "touch_point": [touch_point[1], touch_point[0]],
        "lift_point": [lift_point[1], lift_point[0]],
        "typed_text": typed_text.lower(),
    }
def convert_qwen_format(response):
    """Parse a model response string into an AITW action dict.

    Recognised patterns (checked in this order): Click(x,y),
    Scroll("up"/"down"/"left"/"right"), Type("text"), Complete, Back,
    Home, Enter. Anything else maps to the error action type 2.
    """
    pred_action = response
    # pred_action = response.split('### Action ###')[-1].strip()
    if 'Click' in pred_action:
        try:
            x_str, y_str = pred_action.split('(')[-1].split(')')[0].split(',')
            x, y = int(x_str), int(y_str)
        except:
            x, y = 0, 0
        return {'action_type': 4, 'click_point': (x, y)}
    # Scroll direction -> AITW action id.
    for token, action_id in (('Scroll("up")', 1), ('Scroll("down")', 0),
                             ('Scroll("left")', 8), ('Scroll("right")', 9)):
        if token in pred_action:
            return {'action_type': action_id}
    if 'Type' in pred_action:
        text = pred_action.split('("')[-1].split('")')[0]
        return {'action_type': 3, 'typed_text': text}
    # Argument-free actions, keyword -> action id.
    for token, action_id in (('Complete', 10), ('Back', 5),
                             ('Home', 6), ('Enter', 7)):
        if token in pred_action:
            return {'action_type': action_id}
    return {'action_type': 2}  # unparseable response -> error action
# def convert_qwen_format_mind2web(response):
# pred_action = response#.split('### Action')[-1].strip()
# item = {}
# if 'Click' in pred_action:
# try:
# x, y = pred_action.split('(')[-1].split(')')[0].split(',')
# x, y = int(x), int(y)
# click_point = (x, y)
# except:
# x,y = 0, 0
# click_point = (x, y)
# item = {"action_type": 4, "click_point": click_point}
# elif 'Type' in pred_action:
# try:
# # Type(x,y,"typed_text")
# s = pred_action.split('(')[-1]
# x, y, tp_txt = s.split(',')
# x, y = int(x), int(y)
# click_point = (x, y)
# select_value = tp_txt.replace('"','').replace(')', '')
# except:
# click_point = (0,0)
# select_value = ''
# item = {"action_type": 3, "click_point": click_point, "value": select_value}
# elif 'Select' in pred_action:
# try:
# s = pred_action.split('(')[-1]
# x, y, tp_txt = s.split(',')
# x, y = int(x), int(y)
# click_point = (x, y)
# select_value = tp_txt.replace('"','').replace(')', '')
# except:
# click_point = (0,0)
# select_value = ''
# item = {"action_type": 3, "click_point": click_point, "value": select_value}
# else:
# item = {"action_type": 0, "click_point": (0,0)}
# return item
def convert_qwen_format_mind2web(response):
    """Parse a Mind2Web model response into an action dict.

    Click(x,y) -> type 4; Type(x,y,"text") -> type 3 (text in "value");
    Select(x,y,"text") -> type 2 (text in "value"); anything else -> type 0.
    """
    pred_action = response

    def _inner(text):
        # Argument substring between the first '(' and the last ')'.
        return text[text.find('(') + 1:text.rfind(')')]

    def _point_and_text(text):
        # Parse "x, y, "quoted text"" — text may itself contain commas.
        parts = _inner(text).split(',')
        point = (int(parts[0]), int(parts[1]))
        raw = ','.join(parts[2:])
        return point, raw[raw.find('"') + 1:raw.rfind('"')]

    if 'Click' in pred_action:
        try:
            parts = _inner(pred_action).split(',')
            click_point = (int(parts[0]), int(parts[1]))
        except:
            click_point = (0, 0)
        return {"action_type": 4, "click_point": click_point}
    if 'Type' in pred_action:
        try:
            click_point, value = _point_and_text(pred_action)
        except:
            click_point, value = (0, 0), ''
        return {"action_type": 3, "click_point": click_point, "value": value}
    if 'Select' in pred_action:
        try:
            click_point, value = _point_and_text(pred_action)
        except:
            click_point, value = (0, 0), ''
        return {"action_type": 2, "click_point": click_point, "value": value}
    return {"action_type": 0, "click_point": (0, 0)}
def convert_qwen_format_mind2web_InternVL(response):
    """Parse an InternVL-style Mind2Web response where targets are given as
    bounding boxes (x1,y1,x2,y2); the click point is the box centre.

    Click -> type 4; Type(...,"text") -> type 3; Select(...,"text") -> type
    2 (both with the quoted text in "value"); anything else -> type 0.
    """
    pred_action = response

    def _inner(text):
        # Argument substring between the first '(' and the last ')'.
        return text[text.find('(') + 1:text.rfind(')')]

    def _centre_and_text(text):
        # Parse "x1,y1,x2,y2,"quoted text"" — text may contain commas.
        parts = _inner(text).split(',')
        x1, y1, x2, y2 = parts[0], parts[1], parts[2], parts[3]
        centre = ((int(x1) + int(x2)) / 2, (int(y1) + int(y2)) / 2)
        raw = ','.join(parts[4:])
        return centre, raw[raw.find('"') + 1:raw.rfind('"')]

    if 'Click' in pred_action:
        try:
            [x1, y1, x2, y2] = _inner(pred_action).split(',')
            click_point = ((int(x1) + int(x2)) / 2, (int(y1) + int(y2)) / 2)
        except:
            click_point = (0, 0)
        return {"action_type": 4, "click_point": click_point}
    if 'Type' in pred_action:
        try:
            click_point, value = _centre_and_text(pred_action)
        except:
            click_point, value = (0, 0), ''
        return {"action_type": 3, "click_point": click_point, "value": value}
    if 'Select' in pred_action:
        try:
            click_point, value = _centre_and_text(pred_action)
        except:
            click_point, value = (0, 0), ''
        return {"action_type": 2, "click_point": click_point, "value": value}
    return {"action_type": 0, "click_point": (0, 0)}
def simple_decode(gt, img_path=None):
    """Split a GUI-Odyssey action string "ACTION: info" into its parts.

    Args:
      gt: Action string, e.g. "CLICK: (100, 200)" or "COMPLETE".
      img_path: Optional screenshot path relative to the GUI-Odyssey
        screenshots directory; when given, CLICK/LONG_PRESS coordinates are
        rescaled from pixels to a 0-1000 canvas using the real image size.

    Returns:
      {"action": <ACTION>, "info": <info>} where info is a coordinate tuple
      for CLICK/LONG_PRESS and the raw string otherwise.

    Raises:
      ValueError/SyntaxError: If a CLICK/LONG_PRESS payload is not a valid
        Python literal.
    """
    import ast

    idx = gt.find(':')
    if idx == -1:
        # No payload, e.g. "COMPLETE".
        action = gt
        info = ""
    else:
        action = gt[:idx].strip()
        info = gt[idx+1:].strip()
    if action in ['CLICK', "LONG_PRESS"]:
        # literal_eval instead of eval: this string comes straight from
        # model output, so evaluating it as arbitrary Python is unsafe.
        info = ast.literal_eval(info)
        if img_path is not None:
            img_path = 'GUI-Odyssey-master/data/screenshots/' + img_path
            w, h = Image.open(img_path).size
            info = (info[0] / w * 1000, info[1] / h * 1000)
    return {"action": action, "info": info}
# Minimum ANLS similarity (1 - normalised Levenshtein distance) for two
# typed texts to count as matching.
TEXT_ANLS_THRESHOLD = 0.5
# Maximum Euclidean distance (after rescaling the 0-1000 canvas to 0-1)
# for two click points to count as matching.
CLICK_COORD_THRESHOLD = 0.14
def levenshtein_distance(s1, s2):
    """Return the Levenshtein (edit) distance between two strings.

    Classic O(len(s1) * len(s2)) dynamic programme using a single rolling
    row; the shorter string is kept as the row to minimise memory.
    """
    if len(s1) > len(s2):
        s1, s2 = s2, s1
    previous_row = list(range(len(s1) + 1))
    for row_idx, ch2 in enumerate(s2, start=1):
        current_row = [row_idx]
        for col_idx, ch1 in enumerate(s1):
            substitution = previous_row[col_idx] + (ch1 != ch2)
            deletion = previous_row[col_idx + 1] + 1
            insertion = current_row[-1] + 1
            current_row.append(min(substitution, deletion, insertion))
        previous_row = current_row
    return previous_row[-1]
def text_matching(gt, pred):
    """Fuzzy text comparison for TYPE actions.

    Returns True when one (stripped) string contains the other, or when
    the ANLS similarity (1 - normalised edit distance) reaches
    TEXT_ANLS_THRESHOLD.
    """
    gt, pred = gt.strip(), pred.strip()
    # Substring containment in either direction is an immediate match.
    if gt in pred or pred in gt:
        return True
    longest = max(len(gt), len(pred))
    ratio = 0.0 if longest == 0 else float(levenshtein_distance(gt, pred)) / float(longest)
    return (1 - ratio) >= TEXT_ANLS_THRESHOLD
def click_matching(gt_info, pred_info):
    """Return True when predicted and ground-truth click points match.

    Both points are given on a 0-1000 canvas (as pairs or as their string
    representation, e.g. "(100, 200)") and are compared after rescaling to
    0-1; they match when their Euclidean distance is at most
    CLICK_COORD_THRESHOLD.
    """
    import ast

    # literal_eval instead of eval: these strings originate from model
    # output, so evaluating them as arbitrary Python would be unsafe.
    if type(pred_info) == str:
        pred_info = ast.literal_eval(pred_info)
    if type(gt_info) == str:
        gt_info = ast.literal_eval(gt_info)
    pred = np.asarray(pred_info) / 1000
    gt = np.asarray(gt_info) / 1000
    return np.linalg.norm(pred - gt) <= CLICK_COORD_THRESHOLD
def action_matching(pred_action, pred_info, gt_action, gt_info):
    """Compare a predicted GUI-Odyssey action against the ground truth.

    Returns a dict {'is_correct': 'yes'/'no', 'info': <reason>} where the
    reason tags the success/failure mode (action/type/scroll/click).
    """
    pred_action = pred_action.strip()
    gt_action = gt_action.strip()
    if type(pred_info) == str:
        pred_info = pred_info.strip()
    if type(gt_info) == str:
        gt_info = gt_info.strip()

    # Different action verbs can never match.
    if pred_action != gt_action:
        return {'is_correct': 'no', 'info': 'action_fail'}
    # Argument-free actions (e.g. HOME, BACK, COMPLETE) match by verb alone.
    if gt_action not in ['SCROLL', 'CLICK', 'TYPE', 'LONG_PRESS']:
        return {'is_correct': 'yes', 'info': 'action_correct'}
    if gt_action == 'TYPE':
        if text_matching(gt_info, pred_info):
            return {'is_correct': 'yes', 'info': 'type_correct'}
        return {'is_correct': 'no', 'info': 'type_fail'}
    if gt_action == 'SCROLL':
        if gt_info.lower() == pred_info.lower():
            return {'is_correct': 'yes', 'info': 'scroll_correct'}
        return {'is_correct': 'no', 'info': 'scroll_fail'}
    # Only CLICK and LONG_PRESS remain: match by coordinates.
    if click_matching(gt_info, pred_info):
        return {'is_correct': 'yes', 'info': 'click_correct'}
    return {'is_correct': 'no', 'info': 'click_fail'}
def stat_result(eval_dict, metric):
    """Aggregate per-step evaluation results into summary metrics.

    Args:
      eval_dict: List of per-step dicts produced by action matching; each
        carries 'is_correct' ('yes'/'no'), 'info' (success/failure tag) and
        'more_info' (with 'category', used for micro averaging).
      metric: 'macro' (averaged over all steps) or 'micro' (averaged over
        task categories).

    Returns:
      Dict with AMS (action-matching score), SR (success rate), the sample
      total, and formatted type/text accuracy strings.

    Raises:
      ValueError: If `metric` is neither 'macro' nor 'micro'.
    """
    text_correct = sum([1 for _ in eval_dict if _['info'] == 'type_correct'])
    type_correct = sum([1 for _ in eval_dict if _['info'] != 'action_fail'])
    text_total = sum([1 for _ in eval_dict if _['info'].startswith('type_')])
    if metric == 'macro':
        action_correct = sum([1 for _ in eval_dict if _['is_correct'] == 'yes'])
        AMS = round(action_correct / len(eval_dict) * 100, 2)
        SR_cnt, SR_tot, SR = check_SR(eval_dict)
    elif metric == 'micro':
        # Group steps by task category and average per category so that
        # large categories do not swamp small ones.
        task_cate_dict = {}
        acc_list = []
        SR_list = []
        for sample in eval_dict:
            cat = sample['more_info']['category']
            if cat not in task_cate_dict:
                task_cate_dict[cat] = []
            task_cate_dict[cat].append(sample)
        # assert len(task_cate_dict) == 6  # GUI Odyssey has 6 categories;
        # left disabled so partial runs still work.
        for k, v in task_cate_dict.items():
            SR_cnt, SR_tot, SR = check_SR(v)
            SR_list.append(SR)
            acc = round(sum([1 for x in v if x['is_correct'] == 'yes']) / len(v) * 100, 2)
            acc_list.append(acc)
            print(f'category: {k}, AMS: {acc}, SR: {SR}')
        AMS = np.round(np.mean(acc_list), 2)
        SR = np.round(np.mean(SR_list), 2)
    else:
        raise ValueError(f'No metric {metric} found.')
    # Guard against division by zero when the split contains no TYPE steps.
    text_acc = text_correct / text_total * 100 if text_total else 0.0
    info = {
        'AMS': AMS,
        'SR': SR,
        'total': len(eval_dict),
        'action_type': '{} / {} = {:.2f}'.format(type_correct, len(eval_dict), type_correct / len(eval_dict) * 100),
        'text': '{} / {} = {:.2f}'.format(text_correct, text_total, text_acc),
    }
    return info
def check_SR(eval_dict):
    """Compute the episode-level success rate (SR).

    Steps are grouped into episodes by stripping the trailing step suffix
    from the screenshot filename ("<episode>_<step>.png"); an episode
    succeeds only when every one of its steps is correct.

    Args:
      eval_dict: Per-step dicts with 'is_correct', 'more_info'
        ('step_length') and an image reference under 'img'/'image' or
        embedded in 'question'.

    Returns:
      (successful_episodes, total_episodes, SR_percentage); SR is 0.0 when
      no complete episode exists (instead of dividing by zero).
    """
    episode_dict = {}
    steps_map = {}
    for data in eval_dict:
        # Locate the screenshot reference; its basename encodes the episode.
        if 'img' in data: img = data['img']
        elif 'image' in data: img = data['image']
        else: img = data['question'].split('</img>')[0].split('<img>')[1]
        img = os.path.basename(img)
        tail = img.split('_')[-1]
        episode = img.replace(f'_{tail}', '')
        if episode not in episode_dict:
            episode_dict[episode] = []
        else:
            # All steps of one episode must agree on the episode length.
            assert steps_map[episode] == data['more_info']['step_length']
        episode_dict[episode].append(data['is_correct'])
        steps_map[episode] = data['more_info']['step_length']
    cnt, tot = 0, 0
    for k, v in episode_dict.items():
        if len(v) != steps_map[k]:
            # Incomplete episodes (e.g. partial runs) are skipped entirely.
            print(f'step length of {k} does not match.')
            continue
        tot += 1
        if set(v) == {'yes'}:
            cnt += 1
    # Guard against division by zero when no complete episode was seen.
    SR = round(cnt / tot * 100, 2) if tot else 0.0
    print(f'total episode: {tot}, successful episode: {cnt}, SR: {SR}')
    return cnt, tot, SR
def odyssey_action_matching_evaluation(pred_output, metric='macro'):
    """Evaluate GUI-Odyssey predictions against ground truth.

    For each sample the prediction and ground truth strings are decoded
    with simple_decode, compared with action_matching, and the per-step
    verdicts are aggregated by stat_result.

    Args:
      pred_output: List of samples with 'question', 'pred', 'gt',
        'more_info' and 'img' keys.
      metric: Aggregation mode forwarded to stat_result
        ('macro' or 'micro').

    Returns:
      {"info": <aggregated metrics>, "pred": <per-step eval records>}.
    """
    eval_dict = []
    for idx, sample in enumerate(pred_output):
        question, pred, gt, more_info = sample['question'], sample['pred'], sample['gt'], sample['more_info']
        # sample_eval_dict = {'question': question, 'pred': str(pred), 'gt': str(gt), 'more_info': more_info}
        sample_eval_dict = {'question': question, 'pred': str(pred), 'gt': str(gt), 'more_info': more_info, 'img':sample['img']}
        gt_simple_info = simple_decode(gt)
        gt_action = gt_simple_info['action']
        gt_info = gt_simple_info['info']
        try:
            # The img path is passed so pixel coordinates are rescaled to
            # the 0-1000 canvas used for matching.
            pred_simple_info = simple_decode(pred, sample['img'])
            # print('pred_simple_info:', pred_simple_info)
            pred_action = pred_simple_info['action']
            pred_info = pred_simple_info['info']
        except:
            # Malformed prediction strings count as failures, not crashes.
            # print('### eval err:', idx, pred)
            log_info = {'is_correct': 'no', 'info': 'decode invalid'}
            sample_eval_dict.update(log_info)
            eval_dict.append(sample_eval_dict)
            continue
        try:
            check_match = action_matching(pred_action, pred_info, gt_action, gt_info)
        except Exception as exc:
            print('$$$ eval err:', gt, pred, exc)
            check_match = {'is_correct': 'no', 'info': 'match invalid'}
        sample_eval_dict.update(check_match)
        eval_dict.append(sample_eval_dict)
    # print('===== ',eval_dict)
    info = stat_result(eval_dict, metric)
    metrics = {"info": info, "pred": eval_dict}
    return metrics
\ No newline at end of file
# evaluation on aitw
import requests
import os
import random
import torch
import json
from collections import deque
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from peft import AutoPeftModelForCausalLM
from transformers.generation import GenerationConfig
import re
import logging
import ast
import argparse
from PIL import Image
import numpy as np
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain_huggingface import HuggingFaceEmbeddings
import time
from prompts import AITW_GLOBAL_PLANNING_PROMT, AITW_OBSERVATION_PROMT, AITW_PLANNING_PROMT, AITW_EXECUTION_PROMT, PAGE_SUMMARY_PROMPT, REFERENCE_FORMAT, ACTION_SUMMARY_PROMPT
import action_matching
logging.basicConfig(level=logging.INFO)
def get_global_plan(img_path, global_plan, previous_actions, goal):
    """Ask the VLM for a global plan (2-3 sub-goals) for `goal` given the
    current screenshot.

    NOTE(review): `global_plan` and `previous_actions`/`last_action` are
    prepared here but never injected into the prompt (only <goal> is
    substituted), so the old plan and the history are currently unused —
    confirm whether this is intentional.
    """
    if global_plan == '': global_plan = 'No old global plan.'
    if len(previous_actions) == 1:
        last_action = 'No action taken before.'
    else:
        last_action = previous_actions[-2]
    global_plan_prompt = AITW_GLOBAL_PLANNING_PROMT.replace('<goal>',goal)
    global_plan = chat([img_path], global_plan_prompt)
    # Keep only the content after the section header emitted by the model.
    return global_plan.split('### Global Plan ###')[-1].strip()
def get_execution(img_path, action_plan, reference_actions):
    """Turn a local plan into one concrete action via the execution prompt.

    NOTE(review): `reference_actions` is accepted but ignored — the
    <reference> slot is filled with an empty string; confirm whether the
    references were deliberately dropped from the execution step.
    """
    exec_prompt = AITW_EXECUTION_PROMT.replace('<action_plan>',action_plan)
    exec_prompt = exec_prompt.replace('<reference>',"")
    execution = chat([img_path], exec_prompt)
    # Keep only the content after the "### Action ###" header.
    execution = execution.split('### Action ###')[-1].strip()
    return execution
def get_observation(img_path, goal, previous_step):
    """Ask the VLM to describe the current screen and extract clues useful
    for reaching `goal`, given the recent action history."""
    if previous_step == '':
        previous_step = 'No previous step has been taken.'
    obs_prompt = AITW_OBSERVATION_PROMT.replace('<goal>',goal).replace('<history>',previous_step)
    observation = chat([img_path], obs_prompt)
    return observation
def get_plan_action(img_path, goal, observations, global_plan, reference_actions,previous_step):
    """Fill the planning prompt with goal/observation/global plan/reference/
    history and ask the VLM for the next local plan (candidate actions)."""
    plan_prompt = AITW_PLANNING_PROMT.replace('<goal>',goal)
    plan_prompt = plan_prompt.replace('<observation>',observations)
    plan_prompt = plan_prompt.replace('<global_plan>',global_plan)
    plan_prompt = plan_prompt.replace('<reference>',reference_actions)
    plan_prompt = plan_prompt.replace('<history>',previous_step)
    plan_action = chat([img_path], plan_prompt)
    return plan_action
def bfs_goals(goal_list, idx, search_document):
    """Breadth-first walk of the page graph from page `idx`, collecting the
    goals reachable within 3 hops into `goal_list` (in place, deduplicated).

    Args:
      goal_list: Mutable list of goal strings, extended in place.
      idx: Index of the start page, or None for a no-op.
      search_document: Sequence of documents whose metadata carries
        'next_page_list' edges ({'actions', 'goal', 'page_index'}).
    """
    if idx is None:
        return
    frontier = deque([(idx, 0)])
    seen = {idx}
    while frontier:
        page, depth = frontier.popleft()
        if depth >= 3:
            continue  # cap the walk at 3 hops from the start page
        for edge in search_document[page].metadata['next_page_list']:
            if edge['actions'] == []:
                continue  # edges without actions carry no usable goal
            if edge['goal'] not in goal_list:
                goal_list.append(edge['goal'])
            nxt = edge['page_index']
            if nxt is not None and nxt not in seen:
                seen.add(nxt)
                frontier.append((nxt, depth + 1))
def get_reference_actions(img_path, goal, search_document, embedding_model):
    """Retrieve up to 10 reference action chains from the page graph for the
    current screen.

    The screen is summarised by the VLM, the summary is used for a vector
    similarity search over the page library, and for every retrieved page
    the outgoing action chains plus the goals reachable within 3 hops
    (bfs_goals) are formatted with REFERENCE_FORMAT.

    NOTE(review): the FAISS index is rebuilt from scratch on every call —
    likely expensive; consider caching per task.
    """
    reference_actions = ''
    page_summary = chat([img_path], PAGE_SUMMARY_PROMPT)
    vectorstore = FAISS.from_documents(search_document, embedding_model)
    search_res = vectorstore.similarity_search(page_summary)
    max_count = 10
    count = 0
    for res in search_res:
        for actions_chain in res.metadata['next_page_list']:
            # Skip edges with no recorded actions.
            if len(actions_chain['actions'])==0: continue
            count=count+1
            # Comma-join the actions of this chain.
            action_string = ''
            for one_action in actions_chain['actions']:
                action_string += ', ' + one_action
            action_string = action_string[2:]
            # Collect this chain's goal plus goals reachable from its page.
            goal_list = [actions_chain['goal']]
            bfs_goals(goal_list, actions_chain['page_index'], search_document)
            goals_string = ''
            for one_goal in goal_list:
                goals_string += ', ' + one_goal
            goals_string = goals_string[2:]
            one_reference = REFERENCE_FORMAT.format(idx = count, actions = action_string, goals = goals_string)
            reference_actions += one_reference
            if count == max_count: break
        if count == max_count: break
    return reference_actions
def document_transform(raw_document):
    """Wrap each page entry of the raw library dict into a langchain
    Document whose text is the page summary and whose metadata is the full
    entry (including the 'next_page_list' edges)."""
    return [
        Document(page_content=raw_document[key]['page_summary'],
                 metadata=raw_document[key])
        for key in raw_document
    ]
def action2step_for_Qwen(step_data, img_path, img_filename):
    """Summarise a ground-truth step as a short natural-language action
    description for the history shown to the model.

    Clicks are described by asking the VLM to summarise a click at the
    pixel coordinates; scrolls and other actions use their raw text.

    NOTE(review): the click point is rescaled using the image under the
    local 'aitw_images/' directory while the prompt itself uses `img_path`
    (an HTTP URL at the call site) — confirm the two refer to the same
    screenshot.
    """
    action_type = step_data["action_type_id"]
    if action_type == 4:
        if step_data["action_type_text"] == 'click': # for click action, we calculate midpoint of touch and lift as the click point
            touch_point = step_data["touch"]
            lift_point = step_data["lift"]
            action_type_new = 4
            click_point = [(touch_point[0] + lift_point[0]) / 2, (touch_point[1] + lift_point[1]) / 2]
            click_point = [f"{item:.2f}" for item in click_point]
            # Convert normalised coordinates to pixels of the real image.
            w, h = Image.open('aitw_images/' + img_filename).size
            click_point = "({},{})".format(int(float(click_point[0])*w), int(float(click_point[1])*h))
            action_des = ACTION_SUMMARY_PROMPT.format(coordinates = click_point)
            action = chat([img_path], action_des)
        else:
            action = step_data["action_type_text"]
    elif action_type == 3:
        typed_text = step_data["type_text"]
        action = f'type {typed_text}'
    else:
        action = step_data["action_type_text"]
    return action
# Fix all RNG seeds so evaluation runs are reproducible.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(0)
# Local OpenAI-compatible endpoint serving the VLM (used by chat()).
url = "http://localhost:8000/v1/chat/completions"
headers = {
    "Content-Type": "application/json"
}
def chat(img_url_list: 'list[str]' = '', query: str = '') -> str:
    """Query the locally served VLM through its OpenAI-compatible API.

    Args:
      img_url_list: URLs of screenshot images to attach to the user turn.
        (The '' default only works because iterating an empty string
        attaches nothing — callers always pass a list of URLs.)
      query: The text prompt.

    Returns:
      The assistant message content as a string.
    """
    content = []
    for img_url in img_url_list:
        content.append({"type": "image_url", "image_url": {"url": img_url}})
    content.append({"type": "text", "text": query})
    data = {
        "model": "Qwen2.5-VL-72B-Instruct",
        "messages": [
            {"role": "system", "content": "You are a powerful agent that is trained to perform some basic tasks on a smartphone."},
            {"role": "user", "content": content}
        ],
        # temperature 0 keeps the evaluation deterministic.
        "temperature":0}
    response = requests.post(url, headers=headers, data=json.dumps(data))
    response = response.json()
    response = response['choices'][0]['message']['content']
    return response
# ---- AITW evaluation driver: iterates tasks/episodes/steps, runs the
# multi-agent workflow for each step and scores the predicted action
# against the ground truth with the AITW action-matching rules.
embedding_model_name = "bge-m3"
embedding_model = HuggingFaceEmbeddings(model_name = embedding_model_name,model_kwargs={'device': 'cuda:0'})
aitw_imgs_dir = "aitw_images"
aitw_test = json.load(open('aitw_annots/aitw_data_test.json', 'r'))
score_average = 0
test_record = {}
time_count = 0
time_total = 0
for task, episodes in aitw_test.items():
    print("Task: " + task)
    testing_type = task
    # Load the page-graph library built for this task and wrap it for search.
    raw_document = json.load(open(f'{testing_type}_library.json', 'r'))
    search_document = document_transform(raw_document)
    test_record[task] = []
    # Per-task counters for the various accuracy breakdowns.
    corr_action = 0
    corr_type = 0
    num_text = 0
    corr_text = 0
    num_scroll = 0
    corr_scroll = 0
    num_click = 0
    corr_click = 0
    num_both_click = 0
    corr_both_click = 0
    num_wrong_format = 0
    num = 0
    for j, episode in enumerate(episodes):
        test_record[task].append([])
        previous_actions = []
        global_plan = ''
        # tag flags whether the global plan has been generated for this
        # episode (it is produced once, on the first step).
        tag = 0
        for st, step in enumerate(episode):
            one_step_record = {}
            one_step_record['ep_id'] = step['ep_id']
            one_step_record['step'] = step['step']
            one_step_record['img_filename'] = step['img_filename']
            one_step_record['goal'] = step['goal']
            img_filename = step["img_filename"] + '.png'
            # NOTE(review): the local path is immediately overwritten with
            # the HTTP URL of the image server, so aitw_imgs_dir is unused.
            img_path = os.path.join(aitw_imgs_dir, img_filename)
            img_path = 'http://localhost:6666/aitw_images/' + img_filename
            goal = step["goal"]
            # History is the (at most) last 4 ground-truth action summaries.
            previous_step = ""
            for i, action in enumerate(previous_actions[-4:]):
                previous_step += 'Step' + str(i+1) + ': ' + action + ". \n"
            action_step = action2step_for_Qwen(step, img_path, img_filename)
            previous_actions.append(action_step)
            action_ref = action_matching.action_2_format(step)
            t_start = time.time()
            # Multi-agent workflow: observe -> retrieve references ->
            # (global plan once) -> local plan -> execute.
            observations = get_observation(img_path, goal, previous_step)
            reference_actions = get_reference_actions(img_path, goal, search_document, embedding_model)
            if tag == 0:
                global_plan = get_global_plan(img_path, global_plan, previous_actions, goal)
                tag = 1
            try:
                plan_action = get_plan_action(img_path, goal, observations, global_plan, reference_actions,previous_step)
                response_ = get_execution(img_path, plan_action, reference_actions)
            except:
                print('==== ERROR ====')
                continue
            response = action_matching.convert_qwen_format(response_)
            time_total += time.time() - t_start
            time_count += 1
            print('average inference time: ',time_total,time_count,time_total / time_count)
            num += 1
            try:
                action_pred = action_matching.pred_2_format_4_mpgui(response,img_filename)
                # Reshape the flat annotation list into (num_bboxes, 4).
                annot_position = np.array(
                    [step["annot_position"][i:i + 4] for i in range(0, len(step["annot_position"]), 4)])
                check_match = action_matching.check_actions_match(action_pred["touch_point"], action_pred["lift_point"],
                                                                  action_pred["action_type"], action_ref["touch_point"],
                                                                  action_ref["lift_point"], action_ref["action_type"],
                                                                  annot_position)
                print(f'-------eposide:{j+1}/step:{st+1}----------')
                print('Goal: ', goal)
                print('Correct: ', check_match)
                print('Img: ', img_filename)
                print('History: ', previous_step)
                print('gt: ', action_ref,step['action_addition'])
                print('Observation: ', observations)
                print('pred: ', action_pred)
                print('Global Planning: \n', global_plan)
                print('Loacl Planning: \n', plan_action)
                print('Decision: \n', response_)
                print('---------------------------------------------')
                # step accuracy
                if check_match == True:
                    corr_action += 1
                    match_label = 1
                    # logging.info("Step: " + str(j) + " right")
                else:
                    match_label = 0
                    # logging.info("Step: " + str(j) + " wrong")
                # type accuracy
                if action_pred["action_type"] == action_ref["action_type"]:
                    corr_type += 1
                # text accuracy
                if action_ref["action_type"] == 3:
                    num_text += 1
                    if (action_pred["typed_text"] == action_ref["typed_text"]) or (
                            action_pred["typed_text"] in action_ref["typed_text"]) or (
                            action_ref["typed_text"] in action_pred["typed_text"]):
                        corr_text += 1
                if action_ref["action_type"] == 4:
                    # click accuracy
                    if action_matching.is_tap_action(action_ref["touch_point"], action_ref["lift_point"]):
                        num_click += 1
                        if match_label:
                            corr_click += 1
                    # scroll accuracy
                    else:
                        num_scroll += 1
                        if match_label:
                            corr_scroll += 1
                    # "both click": ground truth and prediction are taps.
                    if (action_pred["action_type"] == 4) and action_matching.is_tap_action(action_ref["touch_point"],
                                                                                           action_ref[
                                                                                               "lift_point"]) and action_matching.is_tap_action(
                        action_pred["touch_point"], action_pred["lift_point"]):
                        num_both_click += 1
                        if match_label:
                            corr_both_click += 1
                one_step_record['action_label'] = action_ref
                one_step_record['action_predict'] = action_pred
                one_step_record['is_match'] = match_label
                test_record[task][-1].append(one_step_record)
                # Checkpoint the full record after every step.
                f_json = open(f'aitw_record.json', 'w')
                json.dump(test_record, f_json, ensure_ascii=False, indent=4)
                f_json.close()
            except:
                num_wrong_format += 1
                print("Step: " + str(j) + " wrong format")
    score_average += corr_action / num
    print("Action Acc: " + str(corr_action / num))
    print("Type Acc: " + str(corr_type / num))
    print("Text Acc: " + str(corr_text / num_text))
    print("Click Acc: " + str(corr_click / num_click))
    print("Scroll Acc: " + str(corr_scroll / num_scroll))
    print("Both Click Acc: " + str(corr_both_click / num_both_click))
    print("Num Both Click: " + str(num_both_click))
    print("Num wrong format: " + str(num_wrong_format))
# NOTE(review): assumes exactly 5 AITW task splits; a partial run would
# report a misleading average.
print("Average score: " + str(score_average / 5))
# Action space description injected into the execution prompt: the 7
# primitive actions the agent may emit for AITW.
AITW_ACTION_SPACE = '''
1. Click(x, y): An action of click a coordinate point on the smartphone screen and x,y is the position of the coordinate point on the screen.
Your click location should be a UI element or text on the screen.
A simple use case could be Click(100,238), which means you click the UI element at (100,238) on the current screen.
2. Type("typed_text"): An action of typing a piece of text.
A simple use case can be text("Hello, world!"), which inserts the string "Hello, world!" into the input area on the smartphone screen.
3. Scroll("direction"): This function is used to scroll the screen to a specific direction.
"direction" is a string that represents one of the four directions: "up", "down", "left", "right".
A simple use case could be Scroll("up"), which means you take a scroll up action on the current screen.
4. Back(): The action for returning to the previous step.
5. Home(): The action for returning to the homepage.
6. Enter(): The action of pressing the ENTER key to submit input content.
7. Complete: It means you think the task is complete.
'''
# Observation agent prompt. Placeholders <goal> and <history> are filled
# via str.replace in get_observation.
AITW_OBSERVATION_PROMT = f"""
You are a smart GUI agent, capable of comprehensively understanding the GUI interface as well as the user's intentions.
You will be given user's ultimate purpose and the previous actions that you have taken.
Your task is to carefully observe the screen, descripe it and conclude some useful clues in one sentence.
Now you can start to observe:
### User's purpose ###
<goal>
### History trajectory ###
History trajectory can remind you of the operations that have been executed before, thus avoiding repetitive actions.
<history>
### Observation ###
"""
# Local-planning agent prompt. Placeholders <goal>, <observation>,
# <global_plan>, <history> and <reference> are filled in get_plan_action.
AITW_PLANNING_PROMT = f"""
You are a smart GUI agent, capable of comprehensively understanding the GUI interface as well as the user's intentions.
Your task is to plan the next action to complete user's purpose with the help of references.
I will give you several important information:
### User's purpose ###
This is the user's global purpose, and your goal is to complete it:
<goal>
### Observation ###
This is the observation of the screen and some useful clues that help you plan:
<observation>
### Global Plan ###
This is the global plan for completing user's purpose:
<global_plan>
### History trajectory ###
History trajectory can remind you of the operations that have been executed before, thus avoiding repetitive actions.
<history>
### Reference ###
There are some reference actions that you can follow:
<reference>
Based on given information, you are required to output with following format:
1. <Please decide which sub-goal in the \"### Global Plan ###\" should be executed based on the screen image>
2. <Check if the user's global purpose has been completed. If the current screen state matches the user's global purpose, directly suggest that the task has been completed>
3. <If the global purpose is not completed: Inspired by \"### Reference ###\", you can list some actions than can possibly push the task progress or complete the goal>
"""
# Execution agent prompt; embeds AITW_ACTION_SPACE at definition time.
# Placeholders <action_plan> and <reference> are filled in get_execution.
AITW_EXECUTION_PROMT = f"""
You are a smart GUI agent, capable of comprehensively understanding the GUI interface.
You will be given a smartphone screenshot and a plan that you decide to take.
Before you start, I will explain the data format:
### Plan ###
This is your plan:
<action_plan>
### Action Space ###
These are the functions to interact with the phone:
{AITW_ACTION_SPACE}
### Reference ###
There are some reference actions that you can follow:
<reference>
Now please choose one action in \"### Action Space ###\" for the current screen state based on \"### Plan ###\" and \"### Reference ###\".
You should output with following format:
### Thought ###
According to \"### Plan ###\", you should first determine weather the purpose has been completed. If not, think step-by-step and output the action that should be taken currently.
### Action ###
The action you finally choose from \"### Action Space ###\". Do not output anything else.
"""
# Global-planning agent prompt; only <goal> is substituted (get_global_plan).
AITW_GLOBAL_PLANNING_PROMT = f'''
You are an agent that is trained to complete certain tasks on a smartphone. You will be given a screenshot of a smartphone app.
The global task you should complete is:
\"<goal>\"
Now, carefully analyze all the above content and provide your output in the following format:
### Global Plan ###
Please break down the overall task into 2~3 simple sub-goals.
Note that since you can’t see future phone screenshots, each sub-goal should be abstract, high-level, and not involve interacting with specific UI elements.
'''
# One-sentence page summary used as the vector-search query over the
# page-graph library (get_reference_actions).
PAGE_SUMMARY_PROMPT = 'Please describe this screen containing following content with one full sentence: \
the type of page, the function of page and a few key components of the screen.'
# Template for one retrieved reference entry; filled with str.format
# (fields: idx, actions, goals) in get_reference_actions.
REFERENCE_FORMAT = '''{idx}.
You can take following action: {actions}.
This can help you achieve goals like: {goals}.
'''
# Asks the VLM to summarise a ground-truth click as a verb phrase for the
# history shown to the agent (action2step_for_Qwen); field: coordinates.
ACTION_SUMMARY_PROMPT = 'A click operation has now been performed at coordinates {coordinates}. \
You are required to summarize this operation with a verb phrase.'
'''
Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild
'''
import jax
import jax.numpy as jnp
import numpy as np
import os
import action_type as action_type_lib
from PIL import Image
# Two taps outside any shared annotation box still match if their Euclidean
# distance is within this fraction of the screen.
_TAP_DISTANCE_THRESHOLD = 0.14  # Fraction of the screen
# Annotation boxes are grown by these fractions before the containment test
# (1.4 means each dimension gains 140% of its original size).
ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4
ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4
# Interval determining if an action is a tap or a swipe.
_SWIPE_DISTANCE_THRESHOLD = 0.04
def _yx_in_bounding_boxes(
yx, bounding_boxes
):
"""Check if the (y,x) point is contained in each bounding box.
Args:
yx: The (y, x) coordinate in pixels of the point.
bounding_boxes: A 2D int array of shape (num_bboxes, 4), where each row
represents a bounding box: (y_top_left, x_top_left, box_height,
box_width). Note: containment is inclusive of the bounding box edges.
Returns:
is_inside: A 1D bool array where each element specifies if the point is
contained within the respective box.
"""
y, x = yx
# `bounding_boxes` has shape (n_elements, 4); we extract each array along the
# last axis into shape (n_elements, 1), then squeeze unneeded dimension.
top, left, height, width = [
jnp.squeeze(v, axis=-1) for v in jnp.split(bounding_boxes, 4, axis=-1)
]
# The y-axis is inverted for AndroidEnv, so bottom = top + height.
bottom, right = top + height, left + width
return jnp.logical_and(y >= top, y <= bottom) & jnp.logical_and(
x >= left, x <= right)
def _resize_annotation_bounding_boxes(
annotation_positions, annotation_width_augment_fraction,
annotation_height_augment_fraction):
"""Resize the bounding boxes by the given fractions.
Args:
annotation_positions: Array of shape (N, 4), where each row represents the
(y, x, height, width) of the bounding boxes.
annotation_width_augment_fraction: The fraction to augment the box widths,
E.g., 1.4 == 240% total increase.
annotation_height_augment_fraction: Same as described for width, but for box
height.
Returns:
Resized bounding box.
"""
height_change = (
annotation_height_augment_fraction * annotation_positions[:, 2])
width_change = (
annotation_width_augment_fraction * annotation_positions[:, 3])
# Limit bounding box positions to the screen.
resized_annotations = jnp.stack([
jnp.maximum(0, annotation_positions[:, 0] - (height_change / 2)),
jnp.maximum(0, annotation_positions[:, 1] - (width_change / 2)),
jnp.minimum(1, annotation_positions[:, 2] + height_change),
jnp.minimum(1, annotation_positions[:, 3] + width_change),
],
axis=1)
return resized_annotations
def is_tap_action(normalized_start_yx,
                  normalized_end_yx):
  """A dual-point gesture is a tap when start and end are within the swipe threshold."""
  start = jnp.array(normalized_start_yx)
  end = jnp.array(normalized_end_yx)
  gesture_length = jnp.linalg.norm(start - end)
  return gesture_length <= _SWIPE_DISTANCE_THRESHOLD
def _is_non_dual_point_action(action_type):
  """True for every action type except DUAL_POINT (i.e. taps and drags)."""
  return jnp.logical_not(
      jnp.equal(action_type, action_type_lib.ActionType.DUAL_POINT))
def _check_tap_actions_match(
    tap_1_yx,
    tap_2_yx,
    annotation_positions,
    matching_tap_distance_threshold_screen_percentage,
    annotation_width_augment_fraction,
    annotation_height_augment_fraction,
):
  """Two taps match when they hit the same (augmented) box or are close enough."""
  grown_boxes = _resize_annotation_bounding_boxes(
      annotation_positions,
      annotation_width_augment_fraction,
      annotation_height_augment_fraction,
  )
  # Do both taps land inside at least one common annotation box?
  hits_1 = _yx_in_bounding_boxes(tap_1_yx, grown_boxes)
  hits_2 = _yx_in_bounding_boxes(tap_2_yx, grown_boxes)
  share_a_box = jnp.max(hits_1 & hits_2)
  # Otherwise (outside all boxes, or in different boxes) fall back to plain
  # Euclidean proximity on the screen.
  close_enough = (
      jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx))
      <= matching_tap_distance_threshold_screen_percentage
  )
  return jnp.logical_or(share_a_box, close_enough)
def _check_drag_actions_match(
drag_1_touch_yx,
drag_1_lift_yx,
drag_2_touch_yx,
drag_2_lift_yx,
):
"""Determines if two drag actions are the same."""
# Store drag deltas (the change in the y and x coordinates from touch to
# lift), magnitudes, and the index of the main axis, which is the axis with
# the greatest change in coordinate value (e.g. a drag starting at (0, 0) and
# ending at (0.3, 0.5) has a main axis index of 1).
drag_1_deltas = drag_1_lift_yx - drag_1_touch_yx
drag_1_magnitudes = jnp.abs(drag_1_deltas)
drag_1_main_axis = np.argmax(drag_1_magnitudes)
drag_2_deltas = drag_2_lift_yx - drag_2_touch_yx
drag_2_magnitudes = jnp.abs(drag_2_deltas)
drag_2_main_axis = np.argmax(drag_2_magnitudes)
return jnp.equal(drag_1_main_axis, drag_2_main_axis) #只判断滑动的方向
def check_actions_match(
    action_1_touch_yx,
    action_1_lift_yx,
    action_1_action_type,
    action_2_touch_yx,
    action_2_lift_yx,
    action_2_action_type,
    annotation_positions,
    tap_distance_threshold = _TAP_DISTANCE_THRESHOLD,
    annotation_width_augment_fraction = ANNOTATION_WIDTH_AUGMENT_FRACTION,
    annotation_height_augment_fraction = ANNOTATION_HEIGHT_AUGMENT_FRACTION,
):
  """Determines if two actions are considered to be the same.

  Two actions being "the same" is defined here as two actions that would result
  in a similar screen state.

  Args:
    action_1_touch_yx: The (y, x) coordinates of the first action's touch.
    action_1_lift_yx: The (y, x) coordinates of the first action's lift.
    action_1_action_type: The action type of the first action.
    action_2_touch_yx: The (y, x) coordinates of the second action's touch.
    action_2_lift_yx: The (y, x) coordinates of the second action's lift.
    action_2_action_type: The action type of the second action.
    annotation_positions: The positions of the UI annotations for the screen. It
      is A 2D int array of shape (num_bboxes, 4), where each row represents a
      bounding box: (y_top_left, x_top_left, box_height, box_width). Note that
      containment is inclusive of the bounding box edges.
    tap_distance_threshold: The threshold that determines if two taps result in
      a matching screen state if they don't fall the same bounding boxes.
    annotation_width_augment_fraction: The fraction to increase the width of the
      bounding box by.
    annotation_height_augment_fraction: The fraction to increase the height of
      of the bounding box by.

  Returns:
    A boolean representing whether the two given actions are the same or not.
  """
  action_1_touch_yx = jnp.asarray(action_1_touch_yx)
  action_1_lift_yx = jnp.asarray(action_1_lift_yx)
  action_2_touch_yx = jnp.asarray(action_2_touch_yx)
  action_2_lift_yx = jnp.asarray(action_2_lift_yx)
  # Checks if at least one of the actions is global (i.e. not DUAL_POINT),
  # because if that is the case, only the actions' types need to be compared.
  has_non_dual_point_action = jnp.logical_or(
      _is_non_dual_point_action(action_1_action_type),
      _is_non_dual_point_action(action_2_action_type),
  )
  # A tap and a swipe can never match, even though both are DUAL_POINT.
  different_dual_point_types = jnp.logical_xor(
      is_tap_action(action_1_touch_yx, action_1_lift_yx),
      is_tap_action(action_2_touch_yx, action_2_lift_yx),
  )
  is_tap = jnp.logical_and(
      is_tap_action(action_1_touch_yx, action_1_lift_yx),
      is_tap_action(action_2_touch_yx, action_2_lift_yx),
  )
  # Taps are compared via shared annotation boxes / proximity; the result is
  # only valid when both gestures actually are taps.
  taps_match = _check_tap_actions_match(
      action_1_touch_yx,
      action_2_touch_yx,
      annotation_positions,
      tap_distance_threshold,
      annotation_width_augment_fraction,
      annotation_height_augment_fraction,
  )
  taps_match = jnp.logical_and(is_tap, taps_match)
  # Drags (swipes) only need to agree on their dominant direction.
  drags_match = _check_drag_actions_match(
      action_1_touch_yx, action_1_lift_yx, action_2_touch_yx, action_2_lift_yx
  )
  drags_match = jnp.where(is_tap, False, drags_match)
  # Decision tree: global actions -> compare types; mixed tap/swipe -> False;
  # otherwise either the taps or the drags must match.
  return jnp.where(
      has_non_dual_point_action,
      jnp.equal(action_1_action_type, action_2_action_type),
      jnp.where(
          different_dual_point_types,
          False,
          jnp.logical_or(taps_match, drags_match),
      ),
  )
def action_2_format(step_data):
    """Convert a test-set AITW action record into the matching-score format.

    Dual-point actions (id 4) keep their recorded touch/lift for clicks and
    receive canonical endpoints for the four scroll directions; every other
    action id gets the (-1, -1) sentinel points. The recorded (y, x) pairs
    are swapped to (x, y) and typed text is lower-cased.
    """
    action_type = step_data["action_type_id"]
    if action_type == 4:
        if step_data["action_type_text"] == 'click':
            touch_point = step_data["touch"]
            lift_point = step_data["lift"]
        else:
            # Canonical gesture endpoints for each scroll direction.
            direction = step_data["action_type_text"]
            if direction == 'scroll down':
                touch_point, lift_point = [0.5, 0.8], [0.5, 0.2]
            elif direction == 'scroll up':
                touch_point, lift_point = [0.5, 0.2], [0.5, 0.8]
            elif direction == 'scroll left':
                touch_point, lift_point = [0.2, 0.5], [0.8, 0.5]
            elif direction == 'scroll right':
                touch_point, lift_point = [0.8, 0.5], [0.2, 0.5]
    else:
        touch_point = [-1.0, -1.0]
        lift_point = [-1.0, -1.0]
    typed_text = step_data["type_text"] if action_type == 3 else ""
    return {
        "action_type": action_type,
        "touch_point": [touch_point[1], touch_point[0]],
        "lift_point": [lift_point[1], lift_point[0]],
        "typed_text": typed_text.lower(),
    }
def pred_2_format(step_data):
    """Convert a model prediction into the action-matching format.

    Scroll predictions (ids 0/1/8/9) become dual-point gestures (id 4) with
    canonical endpoints; clicks use the predicted point for both touch and
    lift; all other ids keep their type with (-1, -1) sentinel points.
    Points are swapped from (y, x) to (x, y) and typed text is lower-cased.
    """
    action_type = step_data["action_type"]
    typed_text = ""
    if action_type == 4:  # click: touch and lift coincide at the predicted point
        action_type_new = 4
        touch_point = step_data["click_point"]
        lift_point = step_data["click_point"]
    elif action_type == 0:  # scroll down
        action_type_new = 4
        touch_point, lift_point = [0.5, 0.8], [0.5, 0.2]
    elif action_type == 1:  # scroll up
        action_type_new = 4
        touch_point, lift_point = [0.5, 0.2], [0.5, 0.8]
    elif action_type == 8:  # scroll left
        action_type_new = 4
        touch_point, lift_point = [0.2, 0.5], [0.8, 0.5]
    elif action_type == 9:  # scroll right
        action_type_new = 4
        touch_point, lift_point = [0.8, 0.5], [0.2, 0.5]
    else:
        action_type_new = action_type
        touch_point = [-1.0, -1.0]
        lift_point = [-1.0, -1.0]
    if action_type_new == 3:
        typed_text = step_data["typed_text"]
    return {
        "action_type": action_type_new,
        "touch_point": [touch_point[1], touch_point[0]],
        "lift_point": [lift_point[1], lift_point[0]],
        "typed_text": typed_text.lower(),
    }
def pred_2_format_4_mpgui(step_data, img_filename=''):
    """Convert an MP-GUI prediction into the action-matching format.

    Same mapping as pred_2_format, except pixel-space click coordinates
    (values > 1) are normalized by the screenshot size; when no filename is
    given a 1000x1000 screen is assumed.
    """
    action_type = step_data["action_type"]
    if img_filename != '':
        # NOTE(review): assumes screenshots live under this fixed directory.
        img_path = 'AITW_simplified/aitw_images/' + img_filename
        w, h = Image.open(img_path).size
    else:
        w, h = 1000, 1000
    typed_text = ""
    if action_type == 4:  # click
        action_type_new = 4
        if 'click_point' in step_data:
            touch_point = step_data["click_point"]
            lift_point = step_data["click_point"]
            # MP-GUI may emit pixel coordinates; normalize them to [0, 1].
            if touch_point[0] > 1.:
                touch_point = [touch_point[0] / w, touch_point[1] / h]
            if lift_point[0] > 1:
                lift_point = [lift_point[0] / w, lift_point[1] / h]
        else:
            print(f'$$ error pred step: {step_data}')
            touch_point = [0., 0.]
            lift_point = [0., 0.]
    elif action_type == 0:  # scroll down
        action_type_new = 4
        touch_point, lift_point = [0.5, 0.8], [0.5, 0.2]
    elif action_type == 1:  # scroll up
        action_type_new = 4
        touch_point, lift_point = [0.5, 0.2], [0.5, 0.8]
    elif action_type == 8:  # scroll left
        action_type_new = 4
        touch_point, lift_point = [0.2, 0.5], [0.8, 0.5]
    elif action_type == 9:  # scroll right
        action_type_new = 4
        touch_point, lift_point = [0.8, 0.5], [0.2, 0.5]
    else:
        action_type_new = action_type
        touch_point = [-1.0, -1.0]
        lift_point = [-1.0, -1.0]
    if action_type_new == 3:
        typed_text = step_data["typed_text"]
    return {
        "action_type": action_type_new,
        "touch_point": [touch_point[1], touch_point[0]],
        "lift_point": [lift_point[1], lift_point[0]],
        "typed_text": typed_text.lower(),
    }
def convert_qwen_format(response):
    """Parse a Qwen-style AITW action string into the numeric action dict.

    Args:
      response: Model output such as 'Click(120, 250)', 'Scroll("up")',
        'Type("hello")', 'Complete', 'Back', 'Home' or 'Enter'.

    Returns:
      Dict with 'action_type' (AITW id) plus 'click_point' / 'typed_text'
      where applicable; action_type 2 marks an unrecognized response.

    The substring checks are ordered (Click first, then scrolls, Type, and the
    parameterless actions); keep that order when editing.
    """
    pred_action = response
    item = {}
    if 'Click' in pred_action:
        try:
            x, y = pred_action.split('(')[-1].split(')')[0].split(',')
            x, y = int(x), int(y)
        except Exception:  # was a bare except; malformed coordinates -> (0, 0)
            x, y = 0, 0
        item = {
            'action_type': 4,
            'click_point': (x, y)
        }
    elif 'Scroll("up")' in pred_action:
        item = {'action_type': 1}
    elif 'Scroll("down")' in pred_action:
        item = {'action_type': 0}
    elif 'Scroll("left")' in pred_action:
        item = {'action_type': 8}
    elif 'Scroll("right")' in pred_action:
        item = {'action_type': 9}
    elif 'Type' in pred_action:
        text = pred_action.split('("')[-1].split('")')[0]
        item = {'action_type': 3, 'typed_text': text}
    elif 'Complete' in pred_action:
        item = {'action_type': 10}
    elif 'Back' in pred_action:
        item = {'action_type': 5}
    elif 'Home' in pred_action:
        item = {'action_type': 6}
    elif 'Enter' in pred_action:
        item = {'action_type': 7}
    else:
        item = {'action_type': 2}  # error / unrecognized
    return item
# def convert_qwen_format_mind2web(response):
# pred_action = response#.split('### Action')[-1].strip()
# item = {}
# if 'Click' in pred_action:
# try:
# x, y = pred_action.split('(')[-1].split(')')[0].split(',')
# x, y = int(x), int(y)
# click_point = (x, y)
# except:
# x,y = 0, 0
# click_point = (x, y)
# item = {"action_type": 4, "click_point": click_point}
# elif 'Type' in pred_action:
# try:
# # Type(x,y,"typed_text")
# s = pred_action.split('(')[-1]
# x, y, tp_txt = s.split(',')
# x, y = int(x), int(y)
# click_point = (x, y)
# select_value = tp_txt.replace('"','').replace(')', '')
# except:
# click_point = (0,0)
# select_value = ''
# item = {"action_type": 3, "click_point": click_point, "value": select_value}
# elif 'Select' in pred_action:
# try:
# s = pred_action.split('(')[-1]
# x, y, tp_txt = s.split(',')
# x, y = int(x), int(y)
# click_point = (x, y)
# select_value = tp_txt.replace('"','').replace(')', '')
# except:
# click_point = (0,0)
# select_value = ''
# item = {"action_type": 3, "click_point": click_point, "value": select_value}
# else:
# item = {"action_type": 0, "click_point": (0,0)}
# return item
def convert_qwen_format_mind2web(response):
    """Parse a Qwen-style Mind2Web action string into the numeric action dict.

    Recognizes Click(x,y), Type(x,y,"text") and Select(x,y,"value"); anything
    else maps to {"action_type": 0, "click_point": (0, 0)}. Quoted text may
    itself contain commas, so everything after the second comma is rejoined
    before extracting the quoted payload.
    """
    pred_action = response
    item = {}
    if 'Click' in pred_action:
        try:
            s = pred_action[pred_action.find('(') + 1:pred_action.rfind(')')]
            parts = s.split(',')
            click_point = (int(parts[0]), int(parts[1]))
        except Exception:  # was a bare except; malformed coordinates -> (0, 0)
            click_point = (0, 0)
        item = {"action_type": 4, "click_point": click_point}
    elif 'Type' in pred_action:
        try:
            s = pred_action[pred_action.find('(') + 1:pred_action.rfind(')')]
            parts = s.split(',')
            click_point = (int(parts[0]), int(parts[1]))
            tp_txt = ','.join(parts[2:])
            typed_text = tp_txt[tp_txt.find('"') + 1:tp_txt.rfind('"')]
        except Exception:
            click_point = (0, 0)
            typed_text = ''
        item = {"action_type": 3, "click_point": click_point, "value": typed_text}
    elif 'Select' in pred_action:
        try:
            s = pred_action[pred_action.find('(') + 1:pred_action.rfind(')')]
            parts = s.split(',')
            click_point = (int(parts[0]), int(parts[1]))
            tp_txt = ','.join(parts[2:])
            select_value = tp_txt[tp_txt.find('"') + 1:tp_txt.rfind('"')]
        except Exception:
            click_point = (0, 0)
            select_value = ''
        item = {"action_type": 2, "click_point": click_point, "value": select_value}
    else:
        item = {"action_type": 0, "click_point": (0, 0)}
    return item
def convert_qwen_format_mind2web_InternVL(response):
    """Parse an InternVL-style Mind2Web action whose coordinates are a bbox.

    Click/Type/Select carry (x1,y1,x2,y2[, "text"]); the click point is the
    bbox center. Unrecognized input maps to {"action_type": 0, ...}.
    """
    pred_action = response
    item = {}
    if 'Click' in pred_action:
        try:
            s = pred_action[pred_action.find('(') + 1:pred_action.rfind(')')]
            [x1, y1, x2, y2] = s.split(',')
            click_point = ((int(x1) + int(x2)) / 2, (int(y1) + int(y2)) / 2)
        except Exception:  # was a bare except; malformed bbox -> (0, 0)
            click_point = (0, 0)
        item = {"action_type": 4, "click_point": click_point}
    elif 'Type' in pred_action:
        try:
            s = pred_action[pred_action.find('(') + 1:pred_action.rfind(')')]
            parts = s.split(',')
            x1, y1, x2, y2 = parts[0], parts[1], parts[2], parts[3]
            click_point = ((int(x1) + int(x2)) / 2, (int(y1) + int(y2)) / 2)
            tp_txt = ','.join(parts[4:])
            typed_text = tp_txt[tp_txt.find('"') + 1:tp_txt.rfind('"')]
        except Exception:
            click_point = (0, 0)
            typed_text = ''
        item = {"action_type": 3, "click_point": click_point, "value": typed_text}
    elif 'Select' in pred_action:
        try:
            s = pred_action[pred_action.find('(') + 1:pred_action.rfind(')')]
            parts = s.split(',')
            x1, y1, x2, y2 = parts[0], parts[1], parts[2], parts[3]
            click_point = ((int(x1) + int(x2)) / 2, (int(y1) + int(y2)) / 2)
            tp_txt = ','.join(parts[4:])
            select_value = tp_txt[tp_txt.find('"') + 1:tp_txt.rfind('"')]
        except Exception:
            click_point = (0, 0)
            select_value = ''
        item = {"action_type": 2, "click_point": click_point, "value": select_value}
    else:
        item = {"action_type": 0, "click_point": (0, 0)}
    return item
def simple_decode(gt, img_path=None):
    """Parse a GUI-Odyssey action string "ACTION: info" into its parts.

    Args:
      gt: Action string, e.g. "CLICK: (512, 340)" or "PRESS_HOME".
      img_path: Optional screenshot filename; when given, CLICK/LONG_PRESS
        coordinates are rescaled from pixels to the 0-1000 space using the
        real image size.

    Returns:
      {"action": <verb>, "info": <raw string, or an (x, y) tuple for clicks>}.
    """
    import ast  # local import: safe parsing of the "(x, y)" literal
    sep = gt.find(':')
    if sep == -1:
        action = gt
        info = ""
    else:
        action = gt[:sep].strip()
        info = gt[sep + 1:].strip()
    if action in ['CLICK', "LONG_PRESS"]:
        # literal_eval only accepts Python literals — unlike eval, arbitrary
        # model output cannot execute code here.
        info = ast.literal_eval(info)
        if img_path is not None:
            img_path = 'GUI-Odyssey-master/data/screenshots/' + img_path
            w, h = Image.open(img_path).size
            info = (info[0] / w * 1000, info[1] / h * 1000)
    return {"action": action, "info": info}
# Minimum ANLS (normalized edit-distance similarity) for TYPE text to match.
TEXT_ANLS_THRESHOLD = 0.5
# Maximum normalized distance (coordinates scaled to [0, 1]) for a click match.
CLICK_COORD_THRESHOLD = 0.14
def levenshtein_distance(s1, s2):
    """Classic dynamic-programming edit distance (insert/delete/substitute, unit costs)."""
    # Keep the shorter string as the DP row to minimize memory.
    if len(s1) > len(s2):
        s1, s2 = s2, s1
    previous_row = list(range(len(s1) + 1))
    for row_idx, ch2 in enumerate(s2, start=1):
        current_row = [row_idx]
        for col_idx, ch1 in enumerate(s1):
            if ch1 == ch2:
                current_row.append(previous_row[col_idx])
            else:
                current_row.append(1 + min(previous_row[col_idx],
                                           previous_row[col_idx + 1],
                                           current_row[-1]))
        previous_row = current_row
    return previous_row[-1]
def text_matching(gt, pred):
    """Fuzzy text match for TYPE actions: substring hit, or ANLS >= threshold."""
    gt = gt.strip()
    pred = pred.strip()
    # Either string containing the other counts as an immediate match.
    if gt in pred or pred in gt:
        return True
    dist = levenshtein_distance(gt, pred)
    longest = max(len(gt), len(pred))
    similarity = 1.0 if longest == 0 else 1 - float(dist) / float(longest)
    return similarity >= TEXT_ANLS_THRESHOLD
def click_matching(gt_info, pred_info):
    """Return whether predicted and ground-truth click points are close enough.

    Args:
      gt_info: Ground-truth (x, y) pair, as a sequence or a string like "(512, 340)".
      pred_info: Predicted point in the same forms; coordinates are on the
        0-1000 scale and normalized to [0, 1] before comparison.

    Returns:
      True when the Euclidean distance is within CLICK_COORD_THRESHOLD.
    """
    import ast  # safe literal parsing of "(x, y)" — eval would execute arbitrary model output
    if isinstance(pred_info, str):
        pred_info = ast.literal_eval(pred_info)
    if isinstance(gt_info, str):
        gt_info = ast.literal_eval(gt_info)
    pred = np.asarray(pred_info) / 1000
    gt = np.asarray(gt_info) / 1000
    return np.linalg.norm(pred - gt) <= CLICK_COORD_THRESHOLD
def action_matching(pred_action, pred_info, gt_action, gt_info):
    """Compare a predicted (action, info) pair against the ground truth.

    Returns {'is_correct': 'yes'|'no', 'info': <reason tag>}. Actions without
    parameters match on the verb alone; TYPE uses fuzzy text matching, SCROLL
    compares directions case-insensitively, CLICK / LONG_PRESS compare points.
    """
    pred_action = pred_action.strip()
    if isinstance(pred_info, str):
        pred_info = pred_info.strip()
    gt_action = gt_action.strip()
    if isinstance(gt_info, str):
        gt_info = gt_info.strip()
    if pred_action != gt_action:
        return {'is_correct': 'no', 'info': 'action_fail'}
    if gt_action not in ['SCROLL', 'CLICK', 'TYPE', 'LONG_PRESS']:
        return {'is_correct': 'yes', 'info': 'action_correct'}
    if gt_action == 'TYPE':
        if text_matching(gt_info, pred_info):
            return {'is_correct': 'yes', 'info': 'type_correct'}
        return {'is_correct': 'no', 'info': 'type_fail'}
    if gt_action == 'SCROLL':
        if gt_info.lower() == pred_info.lower():
            return {'is_correct': 'yes', 'info': 'scroll_correct'}
        return {'is_correct': 'no', 'info': 'scroll_fail'}
    # Only CLICK / LONG_PRESS remain at this point.
    if click_matching(gt_info, pred_info):
        return {'is_correct': 'yes', 'info': 'click_correct'}
    return {'is_correct': 'no', 'info': 'click_fail'}
def stat_result(eval_dict, metric):
    """Aggregate per-step evaluation records into AMS / SR summary metrics.

    Args:
      eval_dict: List of per-step dicts with 'is_correct' ('yes'/'no') and
        'info' (e.g. 'type_correct', 'action_fail'); 'micro' mode also needs
        more_info['category'].
      metric: 'macro' (averaged over all steps) or 'micro' (mean of
        per-category scores).

    Returns:
      Dict with 'AMS', 'SR', 'total' and formatted per-field accuracy strings.

    Raises:
      ValueError: If metric is neither 'macro' nor 'micro'.
    """
    text_correct = sum(1 for r in eval_dict if r['info'] == 'type_correct')
    # Any verdict other than 'action_fail' means the action *type* matched.
    type_correct = sum(1 for r in eval_dict if r['info'] != 'action_fail')
    text_total = sum(1 for r in eval_dict if r['info'].startswith('type_'))
    if metric == 'macro':
        action_correct = sum(1 for r in eval_dict if r['is_correct'] == 'yes')
        AMS = round(action_correct / len(eval_dict) * 100, 2)
        SR_cnt, SR_tot, SR = check_SR(eval_dict)
    elif metric == 'micro':
        task_cate_dict = {}
        acc_list = []
        SR_list = []
        for sample in eval_dict:
            cat = sample['more_info']['category']
            task_cate_dict.setdefault(cat, []).append(sample)
        for k, v in task_cate_dict.items():
            SR_cnt, SR_tot, SR = check_SR(v)
            SR_list.append(SR)
            acc = round(sum(1 for x in v if x['is_correct'] == 'yes') / len(v) * 100, 2)
            acc_list.append(acc)
            print(f'category: {k}, AMS: {acc}, SR: {SR}')
        AMS = np.round(np.mean(acc_list), 2)
        SR = np.round(np.mean(SR_list), 2)
    else:
        raise ValueError(f'No metric {metric} found.')
    # Guard the ratio strings: runs without TYPE actions (or an empty input)
    # previously raised ZeroDivisionError here.
    type_pct = type_correct / len(eval_dict) * 100 if eval_dict else 0.0
    text_pct = text_correct / text_total * 100 if text_total else 0.0
    info = {
        'AMS': AMS,
        'SR': SR,
        'total': len(eval_dict),
        'action_type': '{} / {} = {:.2f}'.format(type_correct, len(eval_dict), type_pct),
        'text': '{} / {} = {:.2f}'.format(text_correct, text_total, text_pct),
    }
    return info
def check_SR(eval_dict):
    """Compute episode-level success rate: an episode succeeds only when every
    one of its steps was judged correct.

    Episode ids are recovered from the screenshot filename by stripping its
    trailing '_<step>' suffix; more_info['step_length'] gives the expected
    step count, and episodes with missing steps are skipped.

    Returns:
        (successful_episodes, total_episodes, SR_percentage).
    """
    episode_dict = {}
    steps_map = {}
    for data in eval_dict:
        # The screenshot key may live under 'img', 'image', or be embedded in
        # the question's <img>...</img> tag.
        if 'img' in data: img = data['img']
        elif 'image' in data: img = data['image']
        else: img = data['question'].split('</img>')[0].split('<img>')[1]
        img = os.path.basename(img)
        # Episode id = filename minus the trailing '_<step>' chunk.
        tail = img.split('_')[-1]
        episode = img.replace(f'_{tail}', '')
        if episode not in episode_dict:
            episode_dict[episode] = []
        else:
            # Every step of one episode must report the same step_length.
            assert steps_map[episode] == data['more_info']['step_length']
        info = data['is_correct']
        episode_dict[episode].append(info)
        steps_map[episode] = data['more_info']['step_length']
    cnt, tot = 0, 0
    for k, v in episode_dict.items():
        # Skip episodes whose recorded steps don't cover the full trajectory.
        if len(v) != steps_map[k]:
            print(f'step length of {k} does not match.')
            continue
        tot += 1
        v = list(set(v))
        # Success only when the deduplicated verdict set is exactly {'yes'}.
        if len(v) == 1 and v[0] == 'yes':
            cnt += 1
    SR = round(cnt / tot * 100, 2)
    print(f'total episode: {tot}, successful episode: {cnt}, SR: {SR}')
    return cnt, tot, SR
def odyssey_action_matching_evaluation(pred_output, metric='macro'):
    """Score GUI-Odyssey predictions step by step and aggregate AMS / SR.

    Args:
      pred_output: List of samples, each with 'question', 'pred', 'gt',
        'more_info' and 'img' keys.
      metric: Aggregation mode forwarded to stat_result ('macro' or 'micro').

    Returns:
      {"info": summary metrics dict, "pred": per-step evaluation records}.
    """
    eval_dict = []
    for sample in pred_output:
        question, pred, gt, more_info = sample['question'], sample['pred'], sample['gt'], sample['more_info']
        sample_eval_dict = {'question': question, 'pred': str(pred), 'gt': str(gt), 'more_info': more_info, 'img': sample['img']}
        gt_simple_info = simple_decode(gt)
        gt_action = gt_simple_info['action']
        gt_info = gt_simple_info['info']
        try:
            pred_simple_info = simple_decode(pred, sample['img'])
            pred_action = pred_simple_info['action']
            pred_info = pred_simple_info['info']
        except Exception:  # narrowed from bare except: keep Ctrl-C working
            log_info = {'is_correct': 'no', 'info': 'decode invalid'}
            sample_eval_dict.update(log_info)
            eval_dict.append(sample_eval_dict)
            continue
        try:
            check_match = action_matching(pred_action, pred_info, gt_action, gt_info)
        except Exception as exc:
            print('$$$ eval err:', gt, pred, exc)
            check_match = {'is_correct': 'no', 'info': 'match invalid'}
        sample_eval_dict.update(check_match)
        eval_dict.append(sample_eval_dict)
    info = stat_result(eval_dict, metric)
    metrics = {"info": info, "pred": eval_dict}
    return metrics
\ No newline at end of file
# evaluation on mind2web
import os
import random
import torch
import json
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import AutoPeftModelForCausalLM
from transformers.generation import GenerationConfig
import re
import logging
import ast
import argparse
from PIL import Image
import numpy as np
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain_huggingface import HuggingFaceEmbeddings
from collections import deque
from prompts import MIND2WEB_GLOBAL_PLANNING_PROMT, MIND2WEB_OBSERVATION_PROMT, MIND2WEB_PLANNING_PROMT, MIND2WEB_EXECUTION_PROMT, PAGE_SUMMARY_PROMPT, REFERENCE_FORMAT, ACTION_SUMMARY_PROMPT
import action_matching
import requests
# logging.basicConfig(level=print)
def get_global_plan(img_path, goal):
    """Ask the VLM to break the overall goal into high-level sub-goals."""
    prompt = MIND2WEB_GLOBAL_PLANNING_PROMT.replace('<goal>', goal)
    raw_reply = chat([img_path], prompt)
    # Keep only the text after the last section marker.
    return raw_reply.split('### Global Plan ###')[-1].strip()
def get_execution(img_path, action_plan, reference_actions):
    """Run the execution prompt; returns (thought, action) parsed from the reply."""
    prompt = (MIND2WEB_EXECUTION_PROMT
              .replace('<action_plan>', action_plan)
              .replace('<reference>', reference_actions))
    reply = chat([img_path], prompt)
    execution = reply.split('### Action ###')[-1].strip()
    thought = reply.split('### Thought ###')[-1].split('### Action ###')[0].strip()
    return thought, execution
def get_observation(img_path, goal, previous_step):
    """Query the VLM for an observation of the current page state."""
    history = '<No previous step has been taken.>' if previous_step == '' else previous_step
    prompt = MIND2WEB_OBSERVATION_PROMT.replace('<goal>', goal).replace('<history>', history)
    return chat([img_path], prompt)
def get_plan_action(img_path, goal, observations, global_plan, reference_actions, previous_step):
    """Fill the planning prompt with all context pieces and query the VLM."""
    history = '<No previous step has been taken.>' if previous_step == '' else previous_step
    substitutions = {
        '<goal>': goal,
        '<observation>': observations,
        '<global_plan>': global_plan,
        '<reference>': reference_actions,
        '<history>': history,
    }
    prompt = MIND2WEB_PLANNING_PROMT
    for placeholder, value in substitutions.items():
        prompt = prompt.replace(placeholder, value)
    return chat([img_path], prompt)
def bfs_goals(goal_list, idx, search_document):
    """Breadth-first walk (depth < 3) over the page graph starting at ``idx``,
    appending each newly seen successor goal to ``goal_list`` in place.

    Successors with an empty action chain are skipped; ``idx`` may be None,
    in which case nothing is collected.
    """
    if idx is None:
        return
    visited = {idx}
    frontier = deque([(idx, 0)])
    while frontier:
        page, depth = frontier.popleft()
        if depth >= 3:
            continue
        for successor in search_document[page].metadata['next_page_list']:
            if successor['actions'] == []:
                continue
            if successor['goal'] not in goal_list:
                goal_list.append(successor['goal'])
            succ_idx = successor['page_index']
            if succ_idx is not None and succ_idx not in visited:
                visited.add(succ_idx)
                frontier.append((succ_idx, depth + 1))
def get_reference_actions(img_path, domain, goal, search_document, embedding_model):
    """Retrieve up to ~10 reference action chains for the current page.

    The current screenshot is summarized by the VLM, the summary is used as a
    similarity query against the per-domain FAISS index, and each retrieved
    chain is rendered with REFERENCE_FORMAT together with the goals reachable
    from it (gathered via bfs_goals).

    NOTE(review): the ``goal`` parameter is currently unused.
    """
    reference_actions = ''
    page_summary = chat([img_path], PAGE_SUMMARY_PROMPT)
    # Search only the matching domain when it has its own index; otherwise
    # spread the budget over every domain.
    if domain in list(search_document.keys()):
        search_keys = [domain]
    else:
        search_keys = search_document.keys()
    max_count_final = 10
    count = 0
    max_count = 0
    for search_key in search_keys:
        # Even share of the overall reference budget per searched domain.
        max_count += int(max_count_final // len(search_keys))
        vectorstore = FAISS.from_documents(search_document[search_key], embedding_model)
        search_res = vectorstore.similarity_search(page_summary)
        for res in search_res:
            for actions_chain in res.metadata['next_page_list']:
                if len(actions_chain['actions'])==0: continue
                count=count+1
                # Join the chain's actions into "a, b, c" (trailing dots trimmed).
                action_string = ''
                for one_action in actions_chain['actions']:
                    if one_action[-1] == '.':
                        one_action = one_action[:-1]
                    action_string += ', ' + one_action
                action_string = action_string[2:]
                # Collect goals reachable from the chain's target page.
                goal_list = [actions_chain['goal']]
                bfs_goals(goal_list, actions_chain['page_index'], search_document[search_key])
                goals_string = ''
                for one_goal in goal_list:#random.sample(goal_list, min(10,len(goal_list)))#
                    if one_goal[-1] == '.':
                        one_goal = one_goal[:-1]
                    goals_string += '; ' + one_goal
                goals_string = goals_string[2:]
                one_reference = REFERENCE_FORMAT.format(idx = count, actions = action_string, goals = goals_string)
                reference_actions += one_reference
                if count == max_count: break
            if count == max_count: break
    return reference_actions
def document_transform(raw_document):
    """Wrap each page record of the raw library as a langchain Document,
    keyed by page type, with the page summary as the searchable content."""
    search_document = {}
    for type_name, pages in raw_document.items():
        search_document[type_name] = [
            Document(page_content=pages[page_id]['page_summary'], metadata=pages[page_id])
            for page_id in pages
        ]
    return search_document
# convert action to prediction format (and return the groundtruth bbox)
def action2description(action, img_path):
action_type = action['operation']['op']
assert action_type in ['CLICK', 'TYPE', 'SELECT']
bbox = [int(action["bbox"]["x"]), int(action["bbox"]["y"]), int(action["bbox"]["x"] + action["bbox"]["width"]),
int(action["bbox"]["y"] + action["bbox"]["height"])]
bbox_str = f'[{bbox[0]}, {bbox[1]}, {bbox[2]}, {bbox[3]}]'
if action_type == 'CLICK':
query = ACTION_SUMMARY_PROMPT['click_action_summary'].format(bbox=bbox_str)
elif action_type == 'TYPE':
query = ACTION_SUMMARY_PROMPT['type_action_summary'].format(content=action['operation']['value'],bbox=bbox_str)
elif action_type == 'SELECT':
query = ACTION_SUMMARY_PROMPT['select_action_summary'].format(content=action['operation']['value'],bbox=bbox_str)
action_summary = chat([img_path], query)
if action_summary[-1] == '.':
action_summary = action_summary[:-1]
return action_summary
def action2step(action, image_size, return_bbox=False):
    """Render a Mind2Web annotation as the textual prediction-format string.

    The bbox center is normalized by image_size, rounded to two decimals and
    scaled to the 0-1000 integer space. CLICK/HOVER/ENTER map to action_type
    4, SELECT to 2 and TYPE to 3. With return_bbox=True the 0-1000-scaled
    ground-truth box is returned as well.
    """
    op = action["operation"]["original_op"]
    assert op in ['CLICK', 'TYPE', 'SELECT', 'HOVER', 'ENTER']
    box = action["bbox"]
    center_x = box["x"] + (box["width"] / 2)
    center_y = box["y"] + (box["height"] / 2)
    rel = [center_x / image_size[0], center_y / image_size[1]]
    rel = [round(v, 3) for v in rel]
    rel = [f"{v:.2f}" for v in rel]
    click_point = "({},{})".format(int(float(rel[0]) * 1000), int(float(rel[1]) * 1000))
    if return_bbox:
        bbox = [box["x"], box["y"], box["x"] + box["width"], box["y"] + box["height"]]
        bbox = [bbox[0] / image_size[0], bbox[1] / image_size[1],
                bbox[2] / image_size[0], bbox[3] / image_size[1]]
        bbox = [round(v, 3) * 1000 for v in bbox]
    if op in ['CLICK', 'HOVER', 'ENTER']:
        action_step = "{{\"action_type\": {}, \"click_point\": {}}}".format(4, click_point)
    elif op == 'SELECT':
        action_step = "{{\"action_type\": {}, \"click_point\": {}, \"value\": \"{}\"}}".format(
            2, click_point, action["operation"]["value"])
    elif op == 'TYPE':
        action_step = "{{\"action_type\": {}, \"click_point\": {}, \"value\": \"{}\"}}".format(
            3, click_point, action["operation"]["value"])
    if return_bbox:
        return action_step, bbox
    return action_step
# calculate action f1 following mind2web
def calculate_f1(pred, label):
pred = set(pred.strip().split())
label = set(label.strip().split())
if len(pred) == 0 and len(label) == 0:
return 1
if len(pred) == 0 or len(label) == 0:
return 0
tp = len(pred & label)
fp = len(pred - label)
fn = len(label - pred)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
if precision == 0 or recall == 0:
return 0
f1 = 2 * precision * recall / (precision + recall)
return f1
# Deterministic setup: seed every RNG once. (The original file repeated this
# entire block twice back to back; the duplicate is removed.)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(0)

# Local OpenAI-compatible chat-completions endpoint serving the VLM.
url = "http://localhost:8000/v1/chat/completions"
headers = {
    "Content-Type": "application/json"
}
def chat(img_url_list=(), query: str = '') -> str:
    """Send a multimodal chat request to the local OpenAI-compatible server.

    Args:
        img_url_list: iterable of image URLs, attached before the text query.
            (The original annotation typed this as ``str`` and the return as
            ``dict``; both were wrong — the value is always iterated as a
            sequence of URLs and a plain string is returned.)
        query: the text prompt appended after the images.

    Returns:
        The assistant message content (string) of the first choice.
    """
    content = []
    for img_url in img_url_list:
        content.append({"type": "image_url", "image_url": {"url": img_url}})
    content.append({"type": "text", "text": query})
    data = {
        "model": "Qwen2.5-VL-72B-Instruct",
        "messages": [
            {"role": "system", "content": "You are a powerful agent that is trained to perform some basic tasks on the web page."},
            {"role": "user", "content": content}
        ],
        # temperature 0 -> deterministic decoding for evaluation.
        "temperature": 0}
    response = requests.post(url, headers=headers, data=json.dumps(data))
    response = response.json()
    response = response['choices'][0]['message']['content']
    return response
# ---- CLI / evaluation setup -------------------------------------------------
parser = argparse.ArgumentParser()
# Mind2Web test split to evaluate; selects the annotation file loaded below
# (mind2web_data_test_<task>.json) -- exact accepted values depend on which
# split files were downloaded (TODO confirm: likely "domain"/"task"/"website").
parser.add_argument('--task', type=str, required=True)
args = parser.parse_args()
# Embedding model used to retrieve reference actions from the page graph.
embedding_model_name = "bge-m3"
embedding_model = HuggingFaceEmbeddings(model_name = embedding_model_name,model_kwargs={'device': 'cuda:0'})
# Local directory holding the Mind2Web screenshots.
mind2web_imgs_dir = 'mind2web_images/'
mind2web_test = json.load(open('mind2web_annots/mind2web_data_test_' + args.task + '.json', 'r'))
# Page-graph "library" built offline; document_transform() (defined elsewhere
# in this file) reshapes it into the structure used for retrieval.
raw_document = json.load(open('mind2web_library.json', 'r'))
search_document = document_transform(raw_document)
# Accumulates one list of per-step results per episode.
results = []
# ---- Main evaluation loop: one episode = one annotated Mind2Web task --------
for episode in tqdm(mind2web_test):
    domain = episode['domain']
    goal = episode["confirmed_task"]
    annot_id = episode["annotation_id"]
    previous_actions = []       # natural-language descriptions of executed steps
    results_actions = []        # per-step result dicts for this episode
    global_plan = ''            # produced once, on the first valid step
    flag = 0
    for j, step in enumerate(episode["actions"]):
        # Steps without a grounded bbox cannot be scored; skip them.
        if "bbox" not in step:
            print("action not found")
            continue
        filename = annot_id + '-' + step["action_uid"] + '.jpg'
        img_path = os.path.join(mind2web_imgs_dir, filename)
        # Screenshots are also exposed via a local HTTP server for the VLM API.
        img_path_server = 'http://localhost:6667/mind2web_images/' + filename
        if not os.path.exists(img_path):
            print("img not found")
            continue
        image = Image.open(img_path)  # NOTE(review): appears unused below
        # Build the history prompt from (at most) the last 4 executed steps.
        previous_step = ""
        for i, action in enumerate(previous_actions[-4:]):
            previous_step += 'Step' + str(i+1) + ': ' + action + ". \n"
        action_step = action2description(step, img_path_server)
        previous_actions.append(action_step)
        # Ground truth in a JSON-like string plus the element bbox; the
        # [1000,1000] image size scales everything to the 0-1000 range.
        action_step_ref, bbox_ref = action2step(step, [1000,1000], return_bbox=True)
        try:
            action_step_ref = ast.literal_eval(action_step_ref)
        except:
            print('# error action_step_ref')
            continue
        # Multi-agent workflow: observe -> retrieve references from the page
        # graph -> (once) global plan -> local plan -> execute.
        observations = get_observation(img_path_server, goal, previous_step)
        reference_actions = get_reference_actions(img_path_server, domain, goal, search_document, embedding_model)
        if flag == 0:
            global_plan = get_global_plan(img_path_server, goal)
            flag = 1
        plan_action = get_plan_action(img_path_server, goal, observations, global_plan, reference_actions, previous_step)
        thought, response = get_execution(img_path_server, plan_action, reference_actions)
        step_result = {"annot_id": annot_id, "step" : j+1, "img_path": img_path, "instruction": goal, "sentence": response,
                       "Op_match": False, "Ele_match": False, "Op_F1": [0, action_step_ref["action_type"]]}
        # Always-true guard (apparently a leftover from a try/except wrapper).
        if 0 < 1:
            action_pred = action_matching.convert_qwen_format_mind2web(response)
            if action_pred["action_type"] == action_step_ref["action_type"]:
                step_result["Op_match"] = True
            # Element match: predicted click falls inside the reference bbox.
            click_point = action_pred["click_point"]
            if (bbox_ref[0] <= click_point[0] <= bbox_ref[2]) and (bbox_ref[1] <= click_point[1] <= bbox_ref[3]):
                step_result["Ele_match"] = True
            # Op F1 compares "type value" strings; types 3 (TYPE) and 2 (SELECT)
            # carry a text value, other types are compared on the id alone.
            pred_str = str(action_pred["action_type"])
            if action_pred["action_type"] == 3 or action_pred["action_type"] == 2:
                pred_str += ' '
                pred_str += action_pred["value"].lower()
            ref_str = str(action_step_ref["action_type"])
            if action_step_ref["action_type"] == 3 or action_step_ref["action_type"] == 2:
                ref_str += ' '
                ref_str += action_step_ref["value"].lower()
            op_f1 = calculate_f1(pred_str, ref_str)
            step_result["Op_F1"][0] = op_f1
            print(f'-------step:{j+1}----------')
            print('Goal: ', goal)
            print('Img: ', img_path)
            print('History: ', previous_step)
            print('gt: ', step['operation']['op'],step['operation']['value'],bbox_ref)
            print('Observation: ', observations)
            print('pred: ', action_pred)
            print('Global Planning: \n', global_plan)
            print('References: \n',reference_actions)
            print('Loacl Planning: \n', plan_action)
            print('Thought: ',thought)
            print('Decision: \n', response)
            print('---------------------------------------------')
        results_actions.append(step_result)
    results.append(results_actions)
    # Checkpoint all results after every episode (file is rewritten in full).
    f_json = open(f'mind2web_record.json', 'w')
    json.dump(results, f_json, ensure_ascii=False, indent=4)
    f_json.close()
# calculate metrics
# Micro counters (over all steps) and per-episode (macro) accumulators.
num_step = 0
num_episode = 0
num_op = 0    # steps whose predicted action type matched
num_ele = 0   # steps whose click hit the reference element bbox
# Keys follow action2step's encoding: 4=click/hover/enter, 2=select, 3=type.
op_f1 = {4: [], 2: [], 3: []}
macro_ele_acc = {}
macro_step_acc = {}
macro_action_f1 = {}
num_step_success = 0
num_episode_success = 0
for i, item in enumerate(results):
    macro_ele_acc[i] = []
    macro_step_acc[i] = []
    macro_action_f1[i] = []
    num_episode += 1
    episode_success = True
    for step_result in item:
        num_step += 1
        if step_result["Op_match"]:
            num_op += 1
        if step_result["Ele_match"]:
            num_ele += 1
            macro_ele_acc[i].append(1)
        else:
            macro_ele_acc[i].append(0)
        # Op_F1 is [score, gt_action_type]; bucket the score by gt type.
        if step_result["Op_F1"][1] in op_f1:
            op_f1[step_result["Op_F1"][1]].append(step_result["Op_F1"][0])
        macro_action_f1[i].append(step_result["Op_F1"][0])
        # A step succeeds only with perfect Op F1 AND an element match.
        if step_result["Op_F1"][0] == 1.0 and step_result["Ele_match"]:
            num_step_success += 1
            macro_step_acc[i].append(1)
        else:
            macro_step_acc[i].append(0)
            episode_success = False
    if episode_success:
        num_episode_success += 1
# Mean of per-action-type F1 means ("marco" sic in the original variable name).
marco_op_f1 = np.mean([np.mean(x) for x in op_f1.values()])
print("Operation F1: " + str(marco_op_f1))
print("Element Acc: " + str(num_ele / num_step))
print("Step Success: " + str(num_step_success / num_step))
print("Episode Success: " + str(num_episode_success / num_episode))
print("Operation F1 cate: " + str([np.mean(x) for x in op_f1.values()]))
# Macro variants: average per episode first, then across episodes.
macro_ele_acc = np.mean([np.mean(x) for x in macro_ele_acc.values()])
macro_step_acc = np.mean([np.mean(x) for x in macro_step_acc.values()])
macro_action_f1 = np.mean([np.mean(x) for x in macro_action_f1.values()])
print("Macro Ele Acc: " + str(macro_ele_acc))
print("Macro Op F1: " + str(macro_action_f1))
print("Macro Step SR: " + str(macro_step_acc))
# Action space description injected into the execution prompt below.
# ("PROMT" spelling is kept: these names are referenced elsewhere.)
MIND2WEB_ACTION_SPACE='''
1. Click(x,y): An action of clicking a coordinate point on the web screen and x,y is the position of the coordinate point on the screen.
Your click location should be a UI element or text on the screen.
A simple use case could be Click(100,238), which means you click the UI element at (100,238) on the current screen.
2. Type(x,y,"typed_text"): An action of typing a piece of text at the positon with coordinates x and y.
A simple use case could be Type(340,212,"Where was Obama born?"), which inputs the string "Where was Obama born?" into the input area at the cordinates (340,212) on the web screen.
3. Select(x,y,"option"): An action of opening a \"Select Menu\" or \"Dropdown List\" located at coordinates (x, y) and choose an option you specify.
A simple use case could be Select(679,437,"female"), which opens the list at the coordinates (679,437) and select the option "female" from the list.
'''
# Observation agent prompt: <goal>/<history> placeholders are replaced later.
MIND2WEB_OBSERVATION_PROMT = f"""
You are a smart GUI agent, capable of comprehensively understanding the GUI interface as well as the user's intentions.
You will be given user's ultimate purpose and the previous actions that you have taken.
Your task is to carefully observe the screen, descripe it and conclude some useful clues in one sentence.
Now you can start to observe:
### User's purpose ###
<goal>
### History trajectory ###
History trajectory can remind you of the operations that have been executed before, thus avoiding repetitive actions.
<history>
### Observation ###
"""
# Global-planning agent prompt: asks for 2~3 abstract sub-goals, once per task.
MIND2WEB_GLOBAL_PLANNING_PROMT = f'''
You are an agent that is trained to complete certain tasks on the webpage. You will be given a screenshot of a website.
The global task you should complete is:
\"<goal>\"
Now, carefully analyze all the above content and provide your output in the following format:
### Global Plan ###
Please break down the overall task into 2~3 simple sub-goals.
Note that since you can’t see future webpages, each sub-goal should be abstract, high-level, and not involve interacting with specific UI elements.
'''
# Local-planning agent prompt: combines observation, global plan, history and
# the page-graph references into a next-action plan.
MIND2WEB_PLANNING_PROMT = f"""
You are a smart GUI agent, capable of comprehensively understanding the GUI interface as well as the user's intentions.
Your task is to plan the next action to complete user's purpose with the help of references.
I will give you several important information:
### User's purpose ###
This is the user's global purpose, and your goal is to complete it:
<goal>
### Observation ###
This is the observation of the screen and some useful clues that help you plan:
<observation>
### Global Plan ###
This is the global plan for completing user's purpose:
<global_plan>
### History trajectory ###
History trajectory can remind you of the operations that have been executed before, thus avoiding repetitive actions.
<history>
### Reference ###
There are some reference actions that you can follow:
<reference>
Based on given information, you are required to output with following format:
1. <Please decide which sub-goal in the \"### Global Plan ###\" should be executed based on the screen image>
2. <Inspired by \"### Reference ###\", you can list some actions than can possibly push the task progress or complete the goal>
"""
# Execution agent prompt: turns the plan into exactly one grounded action
# (note the f-string interpolates MIND2WEB_ACTION_SPACE defined above).
MIND2WEB_EXECUTION_PROMT = f"""
You are a smart GUI agent, capable of comprehensively understanding the GUI interface.
You will be given a screenshot of a website and a plan that you decide to take.
Before you start, I will explain the data format:
### Plan ###
This is your plan:
<action_plan>
### Reference ###
There are some reference actions that you can follow:
<reference>
### Action Space ###
These are the functions to interact with the webpage:
{MIND2WEB_ACTION_SPACE}
Now please choose one action in \"### Action Space ###\" for the current webpage based on \"### Plan ###\" and \"### Reference ###\".
You should output with following format:
### Thought ###
Think step-by-step and output the action that should be taken currently.
### Action ###
Output only one action you finally choose from \"### Action Space ###\". Do not output anything else.
"""
# Prompts used during page-graph construction to summarize recorded actions.
ACTION_SUMMARY_PROMPT = {
    'click_action_summary' : 'This is a page of website. The user clicks the item at coordinates {bbox}. You are required to summarize this operation beginning with \"click\". Do not mention original coordinates.',
    'type_action_summary' : 'This is a page of website. The user types the content \"{content}\" at coordinates {bbox}. You are required to summarize this operation beginning with \"type\". Do not mention original coordinates.',
    'select_action_summary' : 'This is a page of website. The user opens a \"Select Menu\" or \"Dropdown List\" at coordinates {bbox}, and select the option \"{content}\". You are required to summarize this operation beginning with \"select\". Do not mention original coordinates.'
}
# Prompt used to produce one-sentence page descriptions for graph nodes.
PAGE_SUMMARY_PROMPT = 'Please describe this screen containing following content with one full sentence, including \
the type of page, the function of page and the key components of the screen.'
# Template used to render one retrieved page-graph entry into the prompt.
REFERENCE_FORMAT = '''{idx}.
You can take following action: {actions}.
This can help you achieve goals like: {goals}.
'''
\ No newline at end of file
'''
Adapted from https://github.com/google-research/google-research/tree/master/android_in_the_wild
'''
import jax
import jax.numpy as jnp
import numpy as np
import os
import action_type as action_type_lib
from PIL import Image
_TAP_DISTANCE_THRESHOLD = 0.14  # Fraction of the screen
# Fractions by which annotation boxes are enlarged when matching taps
# to UI elements (see _resize_annotation_bounding_boxes).
ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4
ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4
# Interval determining if an action is a tap or a swipe.
_SWIPE_DISTANCE_THRESHOLD = 0.04
def _yx_in_bounding_boxes(
yx, bounding_boxes
):
"""Check if the (y,x) point is contained in each bounding box.
Args:
yx: The (y, x) coordinate in pixels of the point.
bounding_boxes: A 2D int array of shape (num_bboxes, 4), where each row
represents a bounding box: (y_top_left, x_top_left, box_height,
box_width). Note: containment is inclusive of the bounding box edges.
Returns:
is_inside: A 1D bool array where each element specifies if the point is
contained within the respective box.
"""
y, x = yx
# `bounding_boxes` has shape (n_elements, 4); we extract each array along the
# last axis into shape (n_elements, 1), then squeeze unneeded dimension.
top, left, height, width = [
jnp.squeeze(v, axis=-1) for v in jnp.split(bounding_boxes, 4, axis=-1)
]
# The y-axis is inverted for AndroidEnv, so bottom = top + height.
bottom, right = top + height, left + width
return jnp.logical_and(y >= top, y <= bottom) & jnp.logical_and(
x >= left, x <= right)
def _resize_annotation_bounding_boxes(
annotation_positions, annotation_width_augment_fraction,
annotation_height_augment_fraction):
"""Resize the bounding boxes by the given fractions.
Args:
annotation_positions: Array of shape (N, 4), where each row represents the
(y, x, height, width) of the bounding boxes.
annotation_width_augment_fraction: The fraction to augment the box widths,
E.g., 1.4 == 240% total increase.
annotation_height_augment_fraction: Same as described for width, but for box
height.
Returns:
Resized bounding box.
"""
height_change = (
annotation_height_augment_fraction * annotation_positions[:, 2])
width_change = (
annotation_width_augment_fraction * annotation_positions[:, 3])
# Limit bounding box positions to the screen.
resized_annotations = jnp.stack([
jnp.maximum(0, annotation_positions[:, 0] - (height_change / 2)),
jnp.maximum(0, annotation_positions[:, 1] - (width_change / 2)),
jnp.minimum(1, annotation_positions[:, 2] + height_change),
jnp.minimum(1, annotation_positions[:, 3] + width_change),
],
axis=1)
return resized_annotations
def is_tap_action(normalized_start_yx,
                  normalized_end_yx):
  """A gesture is a tap when touch and lift are (almost) the same point."""
  start = jnp.array(normalized_start_yx)
  end = jnp.array(normalized_end_yx)
  travel = jnp.linalg.norm(start - end)
  return travel <= _SWIPE_DISTANCE_THRESHOLD
def _is_non_dual_point_action(action_type):
  """True for every action type other than DUAL_POINT (tap/swipe)."""
  dual_point = action_type_lib.ActionType.DUAL_POINT
  return jnp.not_equal(action_type, dual_point)
def _check_tap_actions_match(
    tap_1_yx,
    tap_2_yx,
    annotation_positions,
    matching_tap_distance_threshold_screen_percentage,
    annotation_width_augment_fraction,
    annotation_height_augment_fraction,
):
  """Two taps match if they hit a common (enlarged) UI box or lie close by."""
  grown_boxes = _resize_annotation_bounding_boxes(
      annotation_positions,
      annotation_width_augment_fraction,
      annotation_height_augment_fraction,
  )
  # Same-element test: both taps fall inside at least one shared enlarged box.
  hits_1 = _yx_in_bounding_boxes(tap_1_yx, grown_boxes)
  hits_2 = _yx_in_bounding_boxes(tap_2_yx, grown_boxes)
  share_a_box = jnp.max(hits_1 & hits_2)
  # Fallback when boxes do not decide it: raw Euclidean distance between
  # the two tap points, compared to a screen-relative threshold.
  gap = jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx))
  close_enough = gap <= matching_tap_distance_threshold_screen_percentage
  return jnp.logical_or(share_a_box, close_enough)
def _check_drag_actions_match(
drag_1_touch_yx,
drag_1_lift_yx,
drag_2_touch_yx,
drag_2_lift_yx,
):
"""Determines if two drag actions are the same."""
# Store drag deltas (the change in the y and x coordinates from touch to
# lift), magnitudes, and the index of the main axis, which is the axis with
# the greatest change in coordinate value (e.g. a drag starting at (0, 0) and
# ending at (0.3, 0.5) has a main axis index of 1).
drag_1_deltas = drag_1_lift_yx - drag_1_touch_yx
drag_1_magnitudes = jnp.abs(drag_1_deltas)
drag_1_main_axis = np.argmax(drag_1_magnitudes)
drag_2_deltas = drag_2_lift_yx - drag_2_touch_yx
drag_2_magnitudes = jnp.abs(drag_2_deltas)
drag_2_main_axis = np.argmax(drag_2_magnitudes)
return jnp.equal(drag_1_main_axis, drag_2_main_axis) #只判断滑动的方向
def check_actions_match(
    action_1_touch_yx,
    action_1_lift_yx,
    action_1_action_type,
    action_2_touch_yx,
    action_2_lift_yx,
    action_2_action_type,
    annotation_positions,
    tap_distance_threshold = _TAP_DISTANCE_THRESHOLD,
    annotation_width_augment_fraction = ANNOTATION_WIDTH_AUGMENT_FRACTION,
    annotation_height_augment_fraction = ANNOTATION_HEIGHT_AUGMENT_FRACTION,
):
  """Decide whether two actions would produce a similar screen state.

  Args:
    action_1_touch_yx / action_1_lift_yx: (y, x) touch and lift coordinates
      of the first action.
    action_1_action_type: action type of the first action.
    action_2_touch_yx / action_2_lift_yx / action_2_action_type: same for
      the second action.
    annotation_positions: (num_bboxes, 4) int array of UI boxes, rows are
      (y_top_left, x_top_left, box_height, box_width); containment is
      inclusive of the edges.
    tap_distance_threshold: max tap-to-tap distance (screen fraction) that
      still counts as a match when the taps do not share a bounding box.
    annotation_width_augment_fraction: fraction to widen the boxes by.
    annotation_height_augment_fraction: fraction to heighten the boxes by.

  Returns:
    Boolean (JAX scalar): whether the two actions are considered the same.
  """
  touch_1 = jnp.asarray(action_1_touch_yx)
  lift_1 = jnp.asarray(action_1_lift_yx)
  touch_2 = jnp.asarray(action_2_touch_yx)
  lift_2 = jnp.asarray(action_2_lift_yx)
  # If either action is global (not DUAL_POINT), only the types matter.
  any_non_dual_point = jnp.logical_or(
      _is_non_dual_point_action(action_1_action_type),
      _is_non_dual_point_action(action_2_action_type),
  )
  tap_1 = is_tap_action(touch_1, lift_1)
  tap_2 = is_tap_action(touch_2, lift_2)
  # One tap and one swipe can never match.
  mixed_gesture_kinds = jnp.logical_xor(tap_1, tap_2)
  both_are_taps = jnp.logical_and(tap_1, tap_2)
  taps_match = _check_tap_actions_match(
      touch_1,
      touch_2,
      annotation_positions,
      tap_distance_threshold,
      annotation_width_augment_fraction,
      annotation_height_augment_fraction,
  )
  taps_match = jnp.logical_and(both_are_taps, taps_match)
  drags_match = _check_drag_actions_match(touch_1, lift_1, touch_2, lift_2)
  drags_match = jnp.where(both_are_taps, False, drags_match)
  return jnp.where(
      any_non_dual_point,
      jnp.equal(action_1_action_type, action_2_action_type),
      jnp.where(
          mixed_gesture_kinds,
          False,
          jnp.logical_or(taps_match, drags_match),
      ),
  )
def action_2_format(step_data):
    """Convert a ground-truth test step into the matching-score format.

    Output coordinates are (x, y) — swapped from the dataset's (y, x) — and
    typed text is lowercased. Non-dual-point actions get (-1, -1) sentinels.
    """
    action_type = step_data["action_type_id"]
    if action_type == 4:
        if step_data["action_type_text"] == 'click':
            # Tap: take the recorded touch/lift points directly.
            touch_point = step_data["touch"]
            lift_point = step_data["lift"]
        else:
            # Directional swipe: canonical synthetic endpoints per direction.
            direction = step_data["action_type_text"]
            if direction == 'scroll down':
                touch_point, lift_point = [0.5, 0.8], [0.5, 0.2]
            elif direction == 'scroll up':
                touch_point, lift_point = [0.5, 0.2], [0.5, 0.8]
            elif direction == 'scroll left':
                touch_point, lift_point = [0.2, 0.5], [0.8, 0.5]
            elif direction == 'scroll right':
                touch_point, lift_point = [0.8, 0.5], [0.2, 0.5]
    else:
        touch_point = [-1.0, -1.0]
        lift_point = [-1.0, -1.0]
    typed_text = step_data["type_text"] if action_type == 3 else ""
    return {
        "action_type": action_type,
        "touch_point": [touch_point[1], touch_point[0]],
        "lift_point": [lift_point[1], lift_point[0]],
        "typed_text": typed_text.lower(),
    }
def pred_2_format(step_data):
    """Convert the model's predicted action into the action-matching format.

    Scroll predictions (ids 0/1/8/9) are rewritten as dual-point swipes
    (type 4) with canonical endpoints; output coordinates are (x, y).
    """
    # scroll id -> (touch, lift): 0=down, 1=up, 8=left, 9=right.
    scroll_gestures = {
        0: ([0.5, 0.8], [0.5, 0.2]),
        1: ([0.5, 0.2], [0.5, 0.8]),
        8: ([0.2, 0.5], [0.8, 0.5]),
        9: ([0.8, 0.5], [0.2, 0.5]),
    }
    action_type = step_data["action_type"]
    if action_type == 4:
        # Tap: touch and lift at the same predicted point.
        action_type_new = 4
        touch_point = step_data["click_point"]
        lift_point = step_data["click_point"]
    elif action_type in scroll_gestures:
        action_type_new = 4
        touch_point, lift_point = scroll_gestures[action_type]
    else:
        # Global actions carry no coordinates.
        action_type_new = action_type
        touch_point = [-1.0, -1.0]
        lift_point = [-1.0, -1.0]
    typed_text = step_data["typed_text"] if action_type_new == 3 else ""
    return {
        "action_type": action_type_new,
        "touch_point": [touch_point[1], touch_point[0]],
        "lift_point": [lift_point[1], lift_point[0]],
        "typed_text": typed_text.lower(),
    }
def pred_2_format_4_mpgui(step_data, img_filename=''):
    """Convert an MP-GUI prediction to the action-matching format.

    MP-GUI may emit pixel coordinates; any coordinate > 1 is rescaled to
    [0, 1] using the screenshot size (or 1000x1000 when no image is given).
    Output coordinates are (x, y).
    """
    if img_filename != '':
        img_path = 'AITW_simplified/aitw_images/' + img_filename
        w, h = Image.open(img_path).size
    else:
        w, h = 1000, 1000
    # scroll id -> (touch, lift): 0=down, 1=up, 8=left, 9=right.
    scroll_gestures = {
        0: ([0.5, 0.8], [0.5, 0.2]),
        1: ([0.5, 0.2], [0.5, 0.8]),
        8: ([0.2, 0.5], [0.8, 0.5]),
        9: ([0.8, 0.5], [0.2, 0.5]),
    }
    action_type = step_data["action_type"]
    if action_type == 4:
        action_type_new = 4
        if 'click_point' in step_data:
            touch_point = step_data["click_point"]
            lift_point = step_data["click_point"]
            # Pixel-space output: rescale to the normalized range.
            if touch_point[0] > 1.:
                touch_point = [touch_point[0]/w, touch_point[1]/h]
            if lift_point[0] > 1:
                lift_point = [lift_point[0]/w, lift_point[1]/h]
        else:
            print(f'$$ error pred step: {step_data}')
            touch_point = [0., 0.]
            lift_point = [0., 0.]
    elif action_type in scroll_gestures:
        action_type_new = 4
        touch_point, lift_point = scroll_gestures[action_type]
    else:
        action_type_new = action_type
        touch_point = [-1.0, -1.0]
        lift_point = [-1.0, -1.0]
    typed_text = step_data["typed_text"] if action_type_new == 3 else ""
    return {
        "action_type": action_type_new,
        "touch_point": [touch_point[1], touch_point[0]],
        "lift_point": [lift_point[1], lift_point[0]],
        "typed_text": typed_text.lower(),
    }
def convert_qwen_format(response):
    """Parse the model's textual action (AITW action space) into a dict.

    Recognized forms: Click(x,y), Scroll("up"/"down"/"left"/"right"),
    Type("text"), Complete, Back, Home, Enter. Anything else maps to the
    error sentinel {'action_type': 2}.
    """
    pred_action = response
    if 'Click' in pred_action:
        # Click(x,y) — unparsable coordinates fall back to (0, 0).
        try:
            x_str, y_str = pred_action.split('(')[-1].split(')')[0].split(',')
            x, y = int(x_str), int(y_str)
        except:
            x, y = 0, 0
        return {'action_type': 4, 'click_point': (x, y)}
    # Scroll direction tokens -> AITW scroll ids.
    scroll_codes = (
        ('Scroll("up")', 1),
        ('Scroll("down")', 0),
        ('Scroll("left")', 8),
        ('Scroll("right")', 9),
    )
    for token, code in scroll_codes:
        if token in pred_action:
            return {'action_type': code}
    if 'Type' in pred_action:
        text = pred_action.split('("')[-1].split('")')[0]
        return {'action_type': 3, 'typed_text': text}
    # Parameter-free global actions.
    for token, code in (('Complete', 10), ('Back', 5), ('Home', 6), ('Enter', 7)):
        if token in pred_action:
            return {'action_type': code}
    return {'action_type': 2}  # error / unrecognized
# def convert_qwen_format_mind2web(response):
# pred_action = response#.split('### Action')[-1].strip()
# item = {}
# if 'Click' in pred_action:
# try:
# x, y = pred_action.split('(')[-1].split(')')[0].split(',')
# x, y = int(x), int(y)
# click_point = (x, y)
# except:
# x,y = 0, 0
# click_point = (x, y)
# item = {"action_type": 4, "click_point": click_point}
# elif 'Type' in pred_action:
# try:
# # Type(x,y,"typed_text")
# s = pred_action.split('(')[-1]
# x, y, tp_txt = s.split(',')
# x, y = int(x), int(y)
# click_point = (x, y)
# select_value = tp_txt.replace('"','').replace(')', '')
# except:
# click_point = (0,0)
# select_value = ''
# item = {"action_type": 3, "click_point": click_point, "value": select_value}
# elif 'Select' in pred_action:
# try:
# s = pred_action.split('(')[-1]
# x, y, tp_txt = s.split(',')
# x, y = int(x), int(y)
# click_point = (x, y)
# select_value = tp_txt.replace('"','').replace(')', '')
# except:
# click_point = (0,0)
# select_value = ''
# item = {"action_type": 3, "click_point": click_point, "value": select_value}
# else:
# item = {"action_type": 0, "click_point": (0,0)}
# return item
def convert_qwen_format_mind2web(response):
    """Parse the model's textual action (Mind2Web action space) into a dict.

    Recognized forms: Click(x,y), Type(x,y,"text"), Select(x,y,"option").
    Parse failures fall back to (0, 0) / '' values; unrecognized responses
    map to {"action_type": 0, "click_point": (0, 0)}.
    """
    pred_action = response

    def _call_args(text):
        # Everything between the first '(' and the last ')'.
        return text[text.find('(') + 1:text.rfind(')')]

    def _quoted(text):
        # Content between the first and last double quote.
        return text[text.find('"') + 1:text.rfind('"')]

    if 'Click' in pred_action:
        try:
            parts = _call_args(pred_action).split(',')
            click_point = (int(parts[0]), int(parts[1]))
        except:
            click_point = (0, 0)
        return {"action_type": 4, "click_point": click_point}
    if 'Type' in pred_action:
        try:
            parts = _call_args(pred_action).split(',')
            click_point = (int(parts[0]), int(parts[1]))
            # Rejoin the tail so commas inside the quoted text survive.
            value = _quoted(','.join(parts[2:]))
        except:
            click_point = (0, 0)
            value = ''
        return {"action_type": 3, "click_point": click_point, "value": value}
    if 'Select' in pred_action:
        try:
            parts = _call_args(pred_action).split(',')
            click_point = (int(parts[0]), int(parts[1]))
            value = _quoted(','.join(parts[2:]))
        except:
            click_point = (0, 0)
            value = ''
        return {"action_type": 2, "click_point": click_point, "value": value}
    return {"action_type": 0, "click_point": (0, 0)}
def convert_qwen_format_mind2web_InternVL(response):
    """Parse InternVL-style Mind2Web actions that ground on a bounding box.

    Recognized forms: Click(x1,y1,x2,y2), Type(x1,y1,x2,y2,"text"),
    Select(x1,y1,x2,y2,"option"); the click point is the box center.
    """
    pred_action = response

    def _call_args(text):
        # Everything between the first '(' and the last ')'.
        return text[text.find('(') + 1:text.rfind(')')]

    def _quoted(text):
        # Content between the first and last double quote.
        return text[text.find('"') + 1:text.rfind('"')]

    def _center(x1, y1, x2, y2):
        return ((int(x1) + int(x2)) / 2, (int(y1) + int(y2)) / 2)

    if 'Click' in pred_action:
        try:
            x1, y1, x2, y2 = _call_args(pred_action).split(',')
            click_point = _center(x1, y1, x2, y2)
        except:
            click_point = (0, 0)
        return {"action_type": 4, "click_point": click_point}
    if 'Type' in pred_action:
        try:
            parts = _call_args(pred_action).split(',')
            click_point = _center(parts[0], parts[1], parts[2], parts[3])
            # Rejoin the tail so commas inside the quoted text survive.
            value = _quoted(','.join(parts[4:]))
        except:
            click_point = (0, 0)
            value = ''
        return {"action_type": 3, "click_point": click_point, "value": value}
    if 'Select' in pred_action:
        try:
            parts = _call_args(pred_action).split(',')
            click_point = _center(parts[0], parts[1], parts[2], parts[3])
            value = _quoted(','.join(parts[4:]))
        except:
            click_point = (0, 0)
            value = ''
        return {"action_type": 2, "click_point": click_point, "value": value}
    return {"action_type": 0, "click_point": (0, 0)}
def simple_decode(gt, img_path=None):
    """Split an Odyssey action string 'ACTION: info' into its two parts.

    For CLICK/LONG_PRESS the info is parsed into coordinates; when img_path
    is given, pixel coordinates are rescaled into the 0-1000 range using the
    screenshot's real size.
    """
    sep = gt.find(':')
    if sep == -1:
        action, info = gt, ""
    else:
        action = gt[:sep].strip()
        info = gt[sep + 1:].strip()
    if action in ['CLICK', "LONG_PRESS"]:
        # NOTE(review): eval() on a model-produced string is unsafe in
        # general; kept here for behavior-compatibility with the original.
        info = eval(info)
        if img_path is not None:
            img_path = 'GUI-Odyssey-master/data/screenshots/' + img_path
            w, h = Image.open(img_path).size
            info = (info[0] / w * 1000, info[1] / h * 1000)
    return {"action": action, "info": info}
# ANLS-style similarity threshold above which typed text counts as correct.
TEXT_ANLS_THRESHOLD = 0.5
# Max normalized (0-1) distance for a predicted click to count as correct.
CLICK_COORD_THRESHOLD = 0.14
def levenshtein_distance(s1, s2):
    """Classic DP edit distance (insert/delete/substitute, each cost 1)."""
    # Keep s1 as the shorter string so the rolling DP row stays small.
    if len(s1) > len(s2):
        s1, s2 = s2, s1
    row = list(range(len(s1) + 1))
    for j, ch2 in enumerate(s2):
        next_row = [j + 1]
        for i, ch1 in enumerate(s1):
            if ch1 == ch2:
                next_row.append(row[i])
            else:
                next_row.append(1 + min(row[i], row[i + 1], next_row[-1]))
        row = next_row
    return row[-1]
def text_matching(gt, pred):
    """True when pred is close enough to gt (substring or ANLS >= threshold)."""
    gt = gt.strip()
    pred = pred.strip()
    # Substring containment in either direction is an immediate match.
    if gt in pred or pred in gt:
        return True
    longest = max(len(gt), len(pred))
    # 1 - normalized edit distance == ANLS-style similarity.
    if longest == 0:
        similarity = 1.0
    else:
        similarity = 1 - levenshtein_distance(gt, pred) / longest
    return similarity >= TEXT_ANLS_THRESHOLD
def click_matching(gt_info, pred_info):
    """True when the predicted click lies within CLICK_COORD_THRESHOLD of the
    ground truth.

    Coordinates arrive in the 0-1000 range and are normalized to 0-1 before
    the Euclidean distance test. Either argument may be a sequence or the
    string repr of one.
    """
    # SECURITY NOTE(review): eval() on a model-produced string is unsafe in
    # general; ast.literal_eval would be the safe drop-in. Kept for
    # behavior-compatibility, but flagged.
    if isinstance(pred_info, str):
        pred_info = eval(pred_info)
    if isinstance(gt_info, str):
        gt_info = eval(gt_info)
    pred = np.asarray(pred_info) / 1000
    gt = np.asarray(gt_info) / 1000
    return np.linalg.norm(pred - gt) <= CLICK_COORD_THRESHOLD
def action_matching(pred_action, pred_info, gt_action, gt_info):
    """Compare a predicted (action, info) pair against the ground truth.

    Returns a dict with 'is_correct' ('yes'/'no') and a diagnostic 'info'
    tag naming the matched/failed component.
    """
    pred_action = pred_action.strip()
    gt_action = gt_action.strip()
    if type(pred_info) == str:
        pred_info = pred_info.strip()
    if type(gt_info) == str:
        gt_info = gt_info.strip()
    # Action types must agree before any payload comparison.
    if pred_action != gt_action:
        return {'is_correct': 'no', 'info': 'action_fail'}
    # Parameter-free actions (BACK, HOME, COMPLETE, ...) match on type alone.
    if gt_action not in ['SCROLL', 'CLICK', 'TYPE', 'LONG_PRESS']:
        return {'is_correct': 'yes', 'info': 'action_correct'}
    if gt_action == 'TYPE':
        ok = text_matching(gt_info, pred_info)
        return {'is_correct': 'yes', 'info': 'type_correct'} if ok else \
               {'is_correct': 'no', 'info': 'type_fail'}
    if gt_action == 'SCROLL':
        ok = gt_info.lower() == pred_info.lower()
        return {'is_correct': 'yes', 'info': 'scroll_correct'} if ok else \
               {'is_correct': 'no', 'info': 'scroll_fail'}
    if gt_action in ('CLICK', 'LONG_PRESS'):
        ok = click_matching(gt_info, pred_info)
        return {'is_correct': 'yes', 'info': 'click_correct'} if ok else \
               {'is_correct': 'no', 'info': 'click_fail'}
    raise ValueError('Invalid action type')  # unreachable by construction
def stat_result(eval_dict, metric):
    """Aggregate step-level match results into AMS / SR metrics.

    Args:
        eval_dict: list of per-step dicts produced by the matching loop;
            each has 'is_correct' ('yes'/'no') and an 'info' tag, and for
            the 'micro' metric a more_info['category'] entry.
        metric: 'macro' (averaged over all steps/episodes) or 'micro'
            (averaged per task category, then across categories).

    Returns:
        dict with 'AMS', 'SR', 'total' and breakdown strings.

    Raises:
        ValueError: for an unknown metric name.
    """
    text_correct = sum(1 for _ in eval_dict if _['info'] == 'type_correct')
    type_correct = sum(1 for _ in eval_dict if _['info'] != 'action_fail')
    text_total = sum(1 for _ in eval_dict if _['info'].startswith('type_'))
    if metric == 'macro':
        action_correct = sum(1 for _ in eval_dict if _['is_correct'] == 'yes')
        AMS = round(action_correct / len(eval_dict) * 100, 2)
        SR_cnt, SR_tot, SR = check_SR(eval_dict)
    elif metric == 'micro':
        # Group steps by task category, score each group, then average.
        task_cate_dict = {}
        acc_list = []
        SR_list = []
        for sample in eval_dict:
            cat = sample['more_info']['category']
            if cat not in task_cate_dict:
                task_cate_dict[cat] = []
            task_cate_dict[cat].append(sample)
        # GUI Odyssey defines 6 task categories; partial runs may have fewer.
        for k, v in task_cate_dict.items():
            SR_cnt, SR_tot, SR = check_SR(v)
            SR_list.append(SR)
            acc = round(sum(1 for x in v if x['is_correct'] == 'yes') / len(v) * 100, 2)
            acc_list.append(acc)
            print(f'category: {k}, AMS: {acc}, SR: {SR}')
        AMS = np.round(np.mean(acc_list), 2)
        SR = np.round(np.mean(SR_list), 2)
    else:
        raise ValueError(f'No metric {metric} found.')
    # Guard against ZeroDivisionError when the run contains no TYPE steps
    # (the original formatting expression crashed on such partial runs).
    text_rate = text_correct / text_total * 100 if text_total else 0.0
    info = {
        'AMS': AMS,
        'SR': SR,
        'total': len(eval_dict),
        'action_type': '{} / {} = {:.2f}'.format(type_correct, len(eval_dict), type_correct / len(eval_dict) * 100),
        'text': '{} / {} = {:.2f}'.format(text_correct, text_total, text_rate),
    }
    return info
def check_SR(eval_dict):
    """Compute the episode-level Success Rate (SR).

    Steps are grouped into episodes by their screenshot filename (everything
    before the final '_<step>' suffix). An episode counts as successful only
    when all of its steps are present and every step is correct.

    Args:
        eval_dict: list of per-step dicts carrying 'is_correct',
            more_info['step_length'], and one of 'img' / 'image' /
            'question' (with an embedded <img>...</img> path) to identify
            the episode.

    Returns:
        (successful_episodes, total_complete_episodes, SR_percentage)
    """
    episode_dict = {}
    steps_map = {}
    for data in eval_dict:
        if 'img' in data:
            img = data['img']
        elif 'image' in data:
            img = data['image']
        else:
            # Fall back to the image path embedded in the question text.
            img = data['question'].split('</img>')[0].split('<img>')[1]
        img = os.path.basename(img)
        tail = img.split('_')[-1]
        episode = img.replace(f'_{tail}', '')
        if episode not in episode_dict:
            episode_dict[episode] = []
        else:
            # All steps of one episode must agree on the episode length.
            assert steps_map[episode] == data['more_info']['step_length']
        episode_dict[episode].append(data['is_correct'])
        steps_map[episode] = data['more_info']['step_length']
    cnt, tot = 0, 0
    for k, v in episode_dict.items():
        if len(v) != steps_map[k]:
            # Incomplete episodes are excluded from the SR denominator.
            print(f'step length of {k} does not match.')
            continue
        tot += 1
        if all(flag == 'yes' for flag in v):
            cnt += 1
    # Guard against ZeroDivisionError when no complete episode exists.
    SR = round(cnt / tot * 100, 2) if tot else 0.0
    print(f'total episode: {tot}, successful episode: {cnt}, SR: {SR}')
    return cnt, tot, SR
def odyssey_action_matching_evaluation(pred_output, metric='macro'):
    """Evaluate predicted actions against ground truth for GUI Odyssey.

    Each prediction is decoded with simple_decode() and compared to the
    ground truth via action_matching(); unparseable predictions and
    matching failures are recorded as incorrect rather than aborting.

    Args:
        pred_output: list of sample dicts with 'question', 'pred', 'gt',
            'more_info' and 'img' keys.
        metric: 'macro' or 'micro', forwarded to stat_result().

    Returns:
        dict with aggregated 'info' metrics and per-sample 'pred' records.
    """
    eval_dict = []
    for sample in pred_output:
        question, pred, gt, more_info = sample['question'], sample['pred'], sample['gt'], sample['more_info']
        sample_eval_dict = {'question': question, 'pred': str(pred), 'gt': str(gt), 'more_info': more_info, 'img': sample['img']}
        gt_simple_info = simple_decode(gt)
        gt_action = gt_simple_info['action']
        gt_info = gt_simple_info['info']
        try:
            pred_simple_info = simple_decode(pred, sample['img'])
            pred_action = pred_simple_info['action']
            pred_info = pred_simple_info['info']
        except Exception:
            # An unparseable prediction counts as incorrect. (Was a bare
            # `except:`, which also swallowed KeyboardInterrupt/SystemExit.)
            sample_eval_dict.update({'is_correct': 'no', 'info': 'decode invalid'})
            eval_dict.append(sample_eval_dict)
            continue
        try:
            check_match = action_matching(pred_action, pred_info, gt_action, gt_info)
        except Exception as exc:
            print('$$$ eval err:', gt, pred, exc)
            check_match = {'is_correct': 'no', 'info': 'match invalid'}
        sample_eval_dict.update(check_match)
        eval_dict.append(sample_eval_dict)
    info = stat_result(eval_dict, metric)
    metrics = {"info": info, "pred": eval_dict}
    return metrics
\ No newline at end of file
# evaluation on odyssey
import os
import random
import torch
import json
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import AutoPeftModelForCausalLM
from transformers.generation import GenerationConfig
import re
import logging
import ast
import argparse
from PIL import Image
import numpy as np
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain_huggingface import HuggingFaceEmbeddings
from collections import deque
import requests
from prompts import ODYSSEY_GLOBAL_PLANNING_PROMT, ODYSSEY_OBSERVATION_PROMT, ODYSSEY_PLANNING_PROMT, ODYSSEY_EXECUTION_PROMT, PAGE_SUMMARY_PROMPT, REFERENCE_FORMAT, ACTION_SUMMARY_PROMPT
import action_matching
# Fix all RNG seeds so evaluation runs are reproducible.
# (The original file repeated this stanza twice verbatim; once is enough.)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(0)

# OpenAI-compatible endpoint of the locally served vLLM model.
url = "http://localhost:8000/v1/chat/completions"
headers = {
    "Content-Type": "application/json"
}
def get_global_plan(img_path, goal):
    """Query the VLM once for a 2-3 sub-goal decomposition of the task."""
    prompt = ODYSSEY_GLOBAL_PLANNING_PROMT.replace('<goal>', goal)
    reply = chat([img_path], prompt)
    marker = '### Global Plan ###'
    return reply.split(marker)[-1].strip()
def get_execution(img_path, action_plan, reference_actions):
    """Ask the VLM to pick one concrete action for the given plan.

    Returns:
        (thought, action): the reasoning text and the chosen action string,
        both parsed out of the model's sectioned response.
    """
    prompt = ODYSSEY_EXECUTION_PROMT.replace('<action_plan>', action_plan)
    prompt = prompt.replace('<reference>', reference_actions)
    reply = chat([img_path], prompt)
    action = reply.split('### Action ###')[-1].strip()
    thought = reply.split('### Thought ###')[-1].split('### Action ###')[0].strip()
    return thought, action
def get_observation(img_path, goal, previous_step):
    """Ask the VLM to describe the current screen w.r.t. the goal/history."""
    history = '<No previous step has been taken.>' if previous_step == '' else previous_step
    prompt = ODYSSEY_OBSERVATION_PROMT.replace('<goal>', goal).replace('<history>', history)
    return chat([img_path], prompt)
def get_plan_action(img_path, goal, observations, global_plan, reference_actions, previous_step):
    """Produce the local plan: which sub-goal to pursue and candidate actions."""
    history = '<No previous step has been taken.>' if previous_step == '' else previous_step
    # Fill every placeholder of the planning template in one pass.
    substitutions = {
        '<goal>': goal,
        '<observation>': observations,
        '<global_plan>': global_plan,
        '<reference>': reference_actions,
        '<history>': history,
    }
    prompt = ODYSSEY_PLANNING_PROMT
    for placeholder, value in substitutions.items():
        prompt = prompt.replace(placeholder, value)
    return chat([img_path], prompt)
def bfs_goals(goal_list, idx, search_document):
    """Breadth-first collect goals reachable within 3 hops of page `idx`.

    Appends unseen goals to `goal_list` in place (returns None). Edges
    whose 'actions' list is empty are skipped; a None start index is a
    no-op.
    """
    if idx is None:
        return
    frontier = deque([(idx, 0)])
    seen = {idx}
    while frontier:
        node, depth = frontier.popleft()
        if depth >= 3:
            continue
        for edge in search_document[node].metadata['next_page_list']:
            if edge['actions'] == []:
                continue
            if edge['goal'] not in goal_list:
                goal_list.append(edge['goal'])
            nxt = edge['page_index']
            if nxt is not None and nxt not in seen:
                seen.add(nxt)
                frontier.append((nxt, depth + 1))
def get_reference_actions(img_path, search_key, goal, search_document, embedding_model):
    """Retrieve up to 10 reference action chains similar to the current page.

    The screenshot is summarized by the VLM, the summary is matched against
    the page-graph documents for `search_key` via FAISS similarity search,
    and each non-empty action chain is rendered together with the goals
    reachable from its target page (3-hop BFS).
    """
    page_summary = chat([img_path], PAGE_SUMMARY_PROMPT)
    vectorstore = FAISS.from_documents(search_document[search_key], embedding_model)
    matches = vectorstore.similarity_search(page_summary)
    references = []
    count = 0
    max_count = 10
    for doc in matches:
        for chain in doc.metadata['next_page_list']:
            if len(chain['actions']) == 0:
                continue
            count += 1
            action_string = ', '.join(chain['actions'])
            goal_list = [chain['goal']]
            bfs_goals(goal_list, chain['page_index'], search_document[search_key])
            # Strip a trailing period from each goal before joining.
            goals_string = '; '.join(g[:-1] if g[-1] == '.' else g for g in goal_list)
            references.append(REFERENCE_FORMAT.format(idx=count, actions=action_string, goals=goals_string))
            if count == max_count:
                break
        if count == max_count:
            break
    return ''.join(references)
def get_action_summary(step, img_path):
    """Render one ground-truth step as a short natural-language phrase.

    Used to build the textual action history fed back into the prompts.
    For coordinate clicks the VLM is asked to phrase the click; other
    actions are rendered with fixed templates.
    """
    action = step['action']
    info = step['info']
    assert action in ['CLICK', 'TEXT', 'SCROLL', 'LONG_PRESS', 'COMPLETE', 'INCOMPLETE']
    if action in ('CLICK', 'LONG_PRESS'):
        if info == 'KEY_HOME':
            return 'press home to go to the home screen'
        if info == 'KEY_BACK':
            return 'press back to go to the previous screen'
        if info == 'KEY_APPSELECT':
            return 'go to the previous App'
        if type(info) == list:
            # Coordinates are annotated on a 0-1000 scale; rescale to pixels.
            w, h = Image.open('data/screenshots/' + step['screenshot']).size
            bbox_str = f'[{int(info[0][0]/1000*w)}, {int(info[0][1]/1000*h)}]'
            phrase = chat([img_path], ACTION_SUMMARY_PROMPT.format(bbox=bbox_str))
            return phrase[:-1] if phrase[-1] == '.' else phrase
        raise ValueError(f'Unknown click action {info}')
    if action == 'SCROLL':
        # Direction is taken from the dominant axis of the swipe vector.
        dx = info[1][0] - info[0][0]
        dy = info[1][1] - info[0][1]
        if abs(dx) > abs(dy):
            return 'scroll left' if dx < 0 else 'scroll right'
        return 'scroll up' if dy < 0 else 'scroll down'
    if action == 'TEXT':
        return f'type {info}'
    if action == 'COMPLETE':
        return action
    if action == 'INCOMPLETE':
        return 'IMPOSSIBLE'
    raise ValueError(f'Unknown action {action}')
def decode_action(action, info):
    """Convert a raw (action, info) annotation into canonical action text.

    e.g. ('CLICK', [[x, y], ...]) -> 'CLICK: (x, y)',
         ('TEXT', 'abc') -> 'TYPE: abc',
         ('INCOMPLETE', _) -> 'IMPOSSIBLE'.

    Raises:
        ValueError: for an unrecognized action or click payload.
    """
    if action in ('CLICK', 'LONG_PRESS'):
        if info == 'KEY_HOME':
            return 'PRESS_HOME'
        if info == 'KEY_BACK':
            return 'PRESS_BACK'
        if info == 'KEY_APPSELECT':
            return 'PRESS_RECENT'
        if type(info) == list:
            # First element is the touch-down point.
            return f"{action}: {tuple(info[0])}"
        raise ValueError(f'Unknown click action {info}')
    if action == 'SCROLL':
        # Direction is taken from the dominant axis of the swipe vector.
        dx = info[1][0] - info[0][0]
        dy = info[1][1] - info[0][1]
        if abs(dx) > abs(dy):
            return 'SCROLL: LEFT' if dx < 0 else 'SCROLL: RIGHT'
        return 'SCROLL: UP' if dy < 0 else 'SCROLL: DOWN'
    if action == 'TEXT':
        return f'TYPE: {info}'
    if action == 'COMPLETE':
        return action
    if action == 'INCOMPLETE':
        return 'IMPOSSIBLE'
    raise ValueError(f'Unknown action {action}')
def document_transform(raw_document):
    """Wrap each page entry of the raw page-graph JSON in a langchain Document.

    Returns a dict mapping each category name to a list of Documents whose
    content is the page summary and whose metadata is the full page entry.
    """
    search_document = {}
    for type_name, pages in raw_document.items():
        search_document[type_name] = [
            Document(page_content=entry['page_summary'], metadata=entry)
            for entry in pages.values()
        ]
    return search_document
def chat(img_url_list: list = '', query: str = '') -> str:
    """Send a multimodal chat request to the local vLLM OpenAI-style server.

    Args:
        img_url_list: iterable of image URLs to attach to the message.
            (The default '' simply yields no images when iterated.)
        query: the text prompt appended after the images.

    Returns:
        The assistant message content of the first choice.
    """
    content = []
    for img_url in img_url_list:
        content.append({"type": "image_url", "image_url": {"url": img_url}})
    content.append({"type": "text", "text": query})
    data = {
        "model": "Qwen2.5-VL-72B-Instruct",
        "messages": [
            {"role": "system", "content": "You are a powerful agent that is trained to perform some basic tasks on the web page."},
            {"role": "user", "content": content}
        ],
        # Deterministic decoding for reproducible evaluation.
        "temperature":0}
    # NOTE(review): no timeout is set — this call can hang indefinitely if
    # the server stalls; consider requests.post(..., timeout=...).
    response = requests.post(url, headers=headers, data=json.dumps(data))
    response = response.json()
    response = response['choices'][0]['message']['content']
    return response
if __name__ == '__main__':
    # Episode splits, per-episode annotations and screenshots of GUI Odyssey.
    odyssey_data = json.load(open('data/splits/splits_random_split.json', 'r'))
    annotations_path = 'data/annotations/'
    imgs_path = 'data/screenshots/'
    embedding_model_name = "bge-m3"
    embedding_model = HuggingFaceEmbeddings(model_name=embedding_model_name, model_kwargs={'device': 'cuda:0'})
    # Page graph produced by the document-construction stage.
    raw_document = json.load(open('odyssey_library.json', 'r'))
    search_document = document_transform(raw_document)
    outputs = []
    for test_idx in tqdm(odyssey_data['test']):
        episode = json.load(open(annotations_path + test_idx, 'r'))
        domain = episode['task_info']['category']
        goal = episode['task_info']['instruction']
        previous_actions = []
        global_plan = ''
        flag = 0
        for step in episode["steps"]:
            # Screenshots are served over a local static file server.
            img_path = 'http://localhost:6668/' + step['screenshot']
            gt = decode_action(step['action'], step['info'])
            # Textual history built from (at most) the last 4 ground-truth actions.
            previous_step = ""
            for i, action in enumerate(previous_actions[-4:]):
                previous_step += 'Step' + str(i + 1) + ': ' + action + ". \n"
            action_step = get_action_summary(step, img_path)
            previous_actions.append(action_step)
            observations = get_observation(img_path, goal, previous_step)
            reference_actions = get_reference_actions(img_path, domain, goal, search_document, embedding_model)
            if flag == 0:
                # The global plan is produced once, on the first step only.
                global_plan = get_global_plan(img_path, goal)
                flag = 1
            plan_action = get_plan_action(img_path, goal, observations, global_plan, reference_actions, previous_step)
            thought, pred = get_execution(img_path, plan_action, reference_actions)
            more_info = {'category': domain, 'step_length': episode['step_length']}
            outputs.append({
                'question': goal,
                'pred': str(pred),
                'gt': gt,
                'more_info': more_info,
                'img': img_path.split('/')[-1]
            })
            print('-------step:{}----------'.format(step['step']))
            print('Goal: ', goal)
            print('Img: ', img_path)
            print('History: ', previous_step)
            print('gt: ', gt)
            print('Observation: ', observations)
            print('Global Planning: \n', global_plan)
            print('References: \n', reference_actions)
            # Typo fix: was 'Loacl Planning'.
            print('Local Planning: \n', plan_action)
            print('Thought: ', thought)
            print('Decision: \n', pred)
            print('---------------------------------------------')
        # Checkpoint predictions after each episode so a crash loses little work.
        savefile = 'odyssey_record.json'
        json.dump(outputs, open(savefile, 'w'), indent=4, ensure_ascii=False)
    print(f"Saving predict result ...")
    savefile = 'odyssey_record.json'
    json.dump(outputs, open(savefile, 'w'), indent=4, ensure_ascii=False)
    print(f"Evaluating ...")
    metrics = action_matching.odyssey_action_matching_evaluation(outputs, metric='micro')
    savefile2 = 'odyssey_eval.json'
    json.dump(metrics, open(savefile2, 'w'), indent=4, ensure_ascii=False)
# Prompt templates for the GUI Odyssey multi-agent workflow.
# Placeholders such as <goal>, <history>, <observation>, <global_plan>,
# <reference> and <action_plan> are filled via str.replace() by the workflow
# code; {idx}/{actions}/{goals}/{bbox} are filled via str.format().
# NOTE(review): some prompt strings contain typos ("descripe", "weather",
# "than can possibly"); they are left byte-identical here because the prompt
# text is runtime input to the model and changing it changes behavior.

# Description of the action space, embedded into the execution prompt below.
ODYSSEY_ACTION_SPACE = '''
1. 'CLICK: (x,y)': An action of clicking a coordinate point on the smartphone screen and x,y is the position of the coordinate point on the screen.
Your click location should be a UI element or text on the screen.
A simple use case could be 'CLICK: (100,238)', which means you click the UI element at (100,238) on the current screen.
2. 'TYPE: typed_text': An action of typing a piece of text.
A simple use case can be 'TYPE: Hello, world!', which inserts the string "Hello, world!" into the input area on the smartphone screen.
3. 'SCROLL: direction': This function is used to scroll an UI element shown on the smartphone screen, usually a scroll view or a slide bar.
"direction" is a string that represents one of the four directions: UP, DOWN, LEFT, RIGHT.
A simple use case could be 'SCROLL: UP', which means you take a scroll up action on the current screen.
4. 'PRESS_BACK': The action for returning to the previous screen.
5. 'PRESS_HOME': The action for returning to the homepage.
6. 'PRESS_RECENT': The action to go to the previous App.
7. 'COMPLETE': It means you think the task has been completed based on current screen.
8. 'IMPOSSIBLE': It means you think the task cannot be completed based on current screen.
9. 'LONG_PRESS: (x,y)': An action of pressing a coordinate point on the smartphone screen for a long time to copy texts or download images, where x and y is the position of the coordinate point on the screen.
'''

# Observation agent: summarize the screen into one sentence of clues.
ODYSSEY_OBSERVATION_PROMT = f"""
You are a smart GUI agent, capable of comprehensively understanding the GUI interface as well as the user's intentions.
You will be given user's ultimate purpose and the previous actions that you have taken.
Your task is to carefully observe the screen, descripe it and conclude some useful clues in one sentence.
Now you can start to observe:
### User's purpose ###
<goal>
### History trajectory ###
History trajectory can remind you of the operations that have been executed before, thus avoiding repetitive actions.
<history>
### Observation ###
"""

# Global-planning agent: decompose the task into 2-3 abstract sub-goals
# (run once per episode, on the first step).
ODYSSEY_GLOBAL_PLANNING_PROMT = f'''
You are an agent that is trained to complete certain tasks on a smartphone. You will be given a screenshot of a smartphone app.
The global task you should complete is:
\"<goal>\"
Now, carefully analyze all the above content and provide your output in the following format:
### Global Plan ###
Please break down the overall task into 2~3 simple sub-goals.
Note that since you can’t see future phone screenshots, each sub-goal should be abstract, high-level, and not involve interacting with specific UI elements.
'''

# Local-planning agent: pick the current sub-goal and candidate actions,
# guided by observation, global plan, history, and page-graph references.
ODYSSEY_PLANNING_PROMT = f"""
You are a smart GUI agent, capable of comprehensively understanding the GUI interface as well as the user's intentions.
Your task is to plan the next action to complete user's purpose with the help of references.
I will give you several important information:
### User's purpose ###
This is the user's global purpose, and your goal is to complete it:
<goal>
### Observation ###
This is the observation of the screen and some useful clues that help you plan:
<observation>
### Global Plan ###
This is the global plan for completing user's purpose:
<global_plan>
### History trajectory ###
History trajectory can remind you of the operations that have been executed before, thus avoiding repetitive actions.
<history>
### Reference ###
There are some reference actions that you can follow:
<reference>
Based on given information, you are required to output with following format:
1. <Please decide which sub-goal in the \"### Global Plan ###\" should be executed based on the screen image>
2. <Check if the user's global purpose has been completed. If the current screen state matches the user's global purpose, directly suggest that the task has been completed>
3. <If the global purpose is not completed: Inspired by \"### Reference ###\", you can list some actions than can possibly push the task progress or complete the goal>
"""

# Execution agent: turn the plan into exactly one action from the action space.
ODYSSEY_EXECUTION_PROMT = f"""
You are a smart GUI agent, capable of comprehensively understanding the GUI interface.
You will be given a smartphone screenshot and a plan that you decide to take.
Before you start, I will explain the data format:
### Plan ###
This is your plan:
<action_plan>
### Action Space ###
These are the functions to interact with the phone:
{ODYSSEY_ACTION_SPACE}
### Reference ###
There are some reference actions that you can follow:
<reference>
Now please choose one action in \"### Action Space ###\" for the current screen state based on \"### Plan ###\" and \"### Reference ###\".
You should output with following format:
### Thought ###
According to \"### Plan ###\", you should first determine weather the purpose has been completed. If not, think step-by-step and output the action that should be taken currently.
### Action ###
The action you finally choose from \"### Action Space ###\". Do not output anything else.
"""

# Template for one retrieved reference entry (filled via str.format()).
REFERENCE_FORMAT = '''{idx}.
You can take following action: {actions}.
This can help you achieve goals like: {goals}.
'''

# Prompt used to summarize a screenshot for page-graph similarity search.
PAGE_SUMMARY_PROMPT = 'Please describe this screen containing following content with one full sentence, including \
the type of page, the function of page and the key components of the screen.'

# Prompt used to phrase a ground-truth click as a verb phrase for the history.
ACTION_SUMMARY_PROMPT = 'The user clicks the item at coordinates {bbox}. You are required to summarize this operation with a verb phrase that begins with \"click\". Do not mention original coordinates.'
# Extra Python dependencies for the multi-agent workflow
# (RAG vector store, embeddings, and similarity search).
pip install langchain_community
pip install langchain_huggingface
pip install jax
pip install jaxlib
pip install faiss-gpu
pip install sentence-transformers
# Serve Qwen2.5-VL-72B-Instruct through vLLM's OpenAI-compatible API,
# sharded over 4 GPUs (-tp 4), allowing up to 2 images per prompt.
python -m vllm.entrypoints.openai.api_server --served-model-name Qwen2.5-VL-72B-Instruct --model Qwen2.5-VL-72B-Instruct -tp 4 --limit_mm_per_prompt image=2
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment