
Commit 23d9058

Merge pull request #125 from OthersideAI/SoM
Add set-of-mark prompting
2 parents: 8bd5d76 + fa05718

File tree

22 files changed: +722 −316 lines changed

.gitignore

Lines changed: 2 additions & 10 deletions
@@ -162,13 +162,5 @@ cython_debug/
 .DS_Store

 # Avoid sending testing screenshots up
-screenshot.png
-screenshot_with_grid.png
-screenshot_with_labeled_grid.png
-screenshot_mini.png
-screenshot_mini_with_grid.png
-grid_screenshot.png
-grid_reflection_screenshot.png
-reflection_screenshot.png
-summary_screenshot.png
-operate/screenshots/
+*.png
+operate/screenshots/

operate/actions/api_interactions.py renamed to operate/actions.py

Lines changed: 179 additions & 42 deletions
@@ -3,36 +3,65 @@
 import json
 import base64
 import re
+import io
+import asyncio
+import aiohttp
+
 from PIL import Image
+from ultralytics import YOLO
 import google.generativeai as genai
-from operate.config.settings import Config
-from operate.exceptions.exceptions import ModelNotRecognizedException
-from operate.utils.screenshot_util import capture_screen_with_cursor, add_grid_to_image, capture_mini_screenshot_with_cursor
-from operate.utils.action_util import get_last_assistant_message
-from operate.utils.prompt_util import format_vision_prompt, format_accurate_mode_vision_prompt,format_summary_prompt
+from operate.settings import Config
+from operate.exceptions import ModelNotRecognizedException
+from operate.utils.screenshot import (
+    capture_screen_with_cursor,
+    add_grid_to_image,
+    capture_mini_screenshot_with_cursor,
+)
+from operate.utils.os import get_last_assistant_message
+from operate.prompts import (
+    format_vision_prompt,
+    format_accurate_mode_vision_prompt,
+    format_summary_prompt,
+    format_decision_prompt,
+    format_label_prompt,
+)
+
+
+from operate.utils.label import (
+    add_labels,
+    parse_click_content,
+    get_click_position_in_percent,
+    get_label_coordinates,
+)
+from operate.utils.style import (
+    ANSI_GREEN,
+    ANSI_RED,
+    ANSI_RESET,
+)
+

 # Load configuration
 config = Config()
+
 client = config.initialize_openai_client()

+yolo_model = YOLO("./operate/model/weights/best.pt")  # Load your trained model

-def get_next_action(model, messages, objective, accurate_mode):
-    if model == "gpt-4-vision-preview":
-        content = get_next_action_from_openai(
-            messages, objective, accurate_mode)
-        return content
+
+async def get_next_action(model, messages, objective):
+    if model == "gpt-4":
+        return call_gpt_4_v(messages, objective)
+    if model == "gpt-4-with-som":
+        return await call_gpt_4_v_labeled(messages, objective)
     elif model == "agent-1":
         return "coming soon"
     elif model == "gemini-pro-vision":
-        content = get_next_action_from_gemini_pro_vision(
-            messages, objective
-        )
-        return content
+        return call_gemini_pro_vision(messages, objective)

     raise ModelNotRecognizedException(model)


-def get_next_action_from_openai(messages, objective, accurate_mode):
+def call_gpt_4_v(messages, objective):
     """
     Get the next action for Self-Operating Computer
     """
@@ -95,32 +124,14 @@ def get_next_action_from_openai(messages, objective, accurate_mode):

         content = response.choices[0].message.content

-        if accurate_mode:
-            if content.startswith("CLICK"):
-                # Adjust pseudo_messages to include the accurate_mode_message
-
-                click_data = re.search(r"CLICK \{ (.+) \}", content).group(1)
-                click_data_json = json.loads(f"{{{click_data}}}")
-                prev_x = click_data_json["x"]
-                prev_y = click_data_json["y"]
-
-                if config.debug:
-                    print(
-                        f"Previous coords before accurate tuning: prev_x {prev_x} prev_y {prev_y}"
-                    )
-                content = accurate_mode_double_check(
-                    "gpt-4-vision-preview", pseudo_messages, prev_x, prev_y
-                )
-                assert content != "ERROR", "ERROR: accurate_mode_double_check failed"
-
         return content

     except Exception as e:
         print(f"Error parsing JSON: {e}")
         return "Failed take action after looking at the screenshot"


-def get_next_action_from_gemini_pro_vision(messages, objective):
+def call_gemini_pro_vision(messages, objective):
     """
     Get the next action for Self-Operating Computer using Gemini Pro Vision
     """
@@ -172,14 +183,13 @@ def get_next_action_from_gemini_pro_vision(messages, objective):
         return "Failed take action after looking at the screenshot"


+# This function is not used. `-accurate` mode was removed for now until a new PR fixes it.
 def accurate_mode_double_check(model, pseudo_messages, prev_x, prev_y):
     """
     Reprompt OAI with additional screenshot of a mini screenshot centered around the cursor for further finetuning of clicked location
     """
-    print("[get_next_action_from_gemini_pro_vision] accurate_mode_double_check")
     try:
-        screenshot_filename = os.path.join(
-            "screenshots", "screenshot_mini.png")
+        screenshot_filename = os.path.join("screenshots", "screenshot_mini.png")
         capture_mini_screenshot_with_cursor(
             file_path=screenshot_filename, x=prev_x, y=prev_y
         )
@@ -191,8 +201,7 @@ def accurate_mode_double_check(model, pseudo_messages, prev_x, prev_y):
         with open(new_screenshot_filename, "rb") as img_file:
             img_base64 = base64.b64encode(img_file.read()).decode("utf-8")

-        accurate_vision_prompt = format_accurate_mode_vision_prompt(
-            prev_x, prev_y)
+        accurate_vision_prompt = format_accurate_mode_vision_prompt(prev_x, prev_y)

         accurate_mode_message = {
             "role": "user",
@@ -234,7 +243,7 @@ def summarize(model, messages, objective):
         capture_screen_with_cursor(screenshot_filename)

         summary_prompt = format_summary_prompt(objective)
-
+
         if model == "gpt-4-vision-preview":
             with open(screenshot_filename, "rb") as img_file:
                 img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
@@ -266,7 +275,135 @@ def summarize(model, messages, objective):
         )
         content = summary_message.text
         return content
-
+
     except Exception as e:
         print(f"Error in summarize: {e}")
-        return "Failed to summarize the workflow"
+        return "Failed to summarize the workflow"
+
+
+async def call_gpt_4_v_labeled(messages, objective):
+    time.sleep(1)
+    try:
+        screenshots_dir = "screenshots"
+        if not os.path.exists(screenshots_dir):
+            os.makedirs(screenshots_dir)
+
+        screenshot_filename = os.path.join(screenshots_dir, "screenshot.png")
+        # Call the function to capture the screen with the cursor
+        capture_screen_with_cursor(screenshot_filename)
+
+        with open(screenshot_filename, "rb") as img_file:
+            img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
+
+        previous_action = get_last_assistant_message(messages)
+
+        img_base64_labeled, img_base64_original, label_coordinates = add_labels(
+            img_base64, yolo_model
+        )
+
+        decision_prompt = format_decision_prompt(objective, previous_action)
+        labeled_click_prompt = format_label_prompt(objective)
+
+        click_message = {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": labeled_click_prompt},
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/jpeg;base64,{img_base64_labeled}"
+                    },
+                },
+            ],
+        }
+        decision_message = {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": decision_prompt},
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/jpeg;base64,{img_base64_original}"
+                    },
+                },
+            ],
+        }
+
+        click_messages = messages.copy()
+        click_messages.append(click_message)
+        decision_messages = messages.copy()
+        decision_messages.append(decision_message)
+
+        click_future = fetch_openai_response_async(click_messages)
+        decision_future = fetch_openai_response_async(decision_messages)
+
+        click_response, decision_response = await asyncio.gather(
+            click_future, decision_future
+        )
+
+        # Extracting the message content from the ChatCompletionMessage object
+        click_content = click_response.get("choices")[0].get("message").get("content")
+
+        decision_content = (
+            decision_response.get("choices")[0].get("message").get("content")
+        )
+
+        if not decision_content.startswith("CLICK"):
+            return decision_content
+
+        label_data = parse_click_content(click_content)
+
+        if label_data and "label" in label_data:
+            coordinates = get_label_coordinates(label_data["label"], label_coordinates)
+            image = Image.open(
+                io.BytesIO(base64.b64decode(img_base64))
+            )  # Load the image to get its size
+            image_size = image.size  # Get the size of the image (width, height)
+            click_position_percent = get_click_position_in_percent(
+                coordinates, image_size
+            )
+            if not click_position_percent:
+                print(
+                    f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] Failed to get click position in percent. Trying another method {ANSI_RESET}"
+                )
+                return call_gpt_4_v(messages, objective)
+
+            x_percent = f"{click_position_percent[0]:.2f}%"
+            y_percent = f"{click_position_percent[1]:.2f}%"
+            click_action = f'CLICK {{ "x": "{x_percent}", "y": "{y_percent}", "description": "{label_data["decision"]}", "reason": "{label_data["reason"]}" }}'
+
+        else:
+            print(
+                f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] No label found. Trying another method {ANSI_RESET}"
+            )
+            return call_gpt_4_v(messages, objective)
+
+        return click_action
+
+    except Exception as e:
+        print(
+            f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] Something went wrong. Trying another method {ANSI_RESET}"
+        )
+        return call_gpt_4_v(messages, objective)
+
+
+async def fetch_openai_response_async(messages):
+    url = "https://api.openai.com/v1/chat/completions"
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {config.openai_api_key}",
+    }
+    data = {
+        "model": "gpt-4-vision-preview",
+        "messages": messages,
+        "frequency_penalty": 1,
+        "presence_penalty": 1,
+        "temperature": 0.7,
+        "max_tokens": 300,
+    }
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(
+            url, headers=headers, data=json.dumps(data)
+        ) as response:
+            return await response.json()
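
The set-of-mark path above relies on helpers from operate.utils.label (add_labels, get_label_coordinates, get_click_position_in_percent) whose implementations are not part of this diff. As a rough illustration of the final coordinate step only, here is a minimal, hypothetical sketch of turning a detected label's pixel bounding box into the percent-based coordinates that call_gpt_4_v_labeled formats into a CLICK action. The (x, y, width, height) box layout and the "~34" label key are assumptions made for the example, not the repository's actual data format.

# Hypothetical sketch, not the repo's operate/utils/label.py.
# Assumes label_coordinates maps a label key to a pixel box (x, y, width, height).

def get_click_position_in_percent_sketch(coordinates, image_size):
    """Return the center of a pixel bounding box as (x %, y %) of the image."""
    if not coordinates or not image_size:
        return None
    x, y, width, height = coordinates
    img_width, img_height = image_size
    # Click the center of the detected element, then normalize to percentages.
    center_x = x + width / 2
    center_y = y + height / 2
    return (center_x / img_width * 100, center_y / img_height * 100)


if __name__ == "__main__":
    # Example: a detection labeled "~34" on a 1920x1080 screenshot.
    label_coordinates = {"~34": (860, 500, 200, 40)}
    position = get_click_position_in_percent_sketch(label_coordinates["~34"], (1920, 1080))
    x_percent, y_percent = f"{position[0]:.2f}%", f"{position[1]:.2f}%"
    print(f'CLICK {{ "x": "{x_percent}", "y": "{y_percent}", "description": "...", "reason": "..." }}')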

operate/actions/__init__.py

Whitespace-only changes.

operate/config/__init__.py

Whitespace-only changes.
