 import json
 import base64
 import re
+import io
+import asyncio
+import aiohttp
+
 from PIL import Image
+from ultralytics import YOLO
 import google.generativeai as genai
-from operate.config.settings import Config
-from operate.exceptions.exceptions import ModelNotRecognizedException
-from operate.utils.screenshot_util import capture_screen_with_cursor, add_grid_to_image, capture_mini_screenshot_with_cursor
-from operate.utils.action_util import get_last_assistant_message
-from operate.utils.prompt_util import format_vision_prompt, format_accurate_mode_vision_prompt, format_summary_prompt
+from operate.settings import Config
+from operate.exceptions import ModelNotRecognizedException
+from operate.utils.screenshot import (
+    capture_screen_with_cursor,
+    add_grid_to_image,
+    capture_mini_screenshot_with_cursor,
+)
+from operate.utils.os import get_last_assistant_message
+from operate.prompts import (
+    format_vision_prompt,
+    format_accurate_mode_vision_prompt,
+    format_summary_prompt,
+    format_decision_prompt,
+    format_label_prompt,
+)
+
+
+from operate.utils.label import (
+    add_labels,
+    parse_click_content,
+    get_click_position_in_percent,
+    get_label_coordinates,
+)
+from operate.utils.style import (
+    ANSI_GREEN,
+    ANSI_RED,
+    ANSI_RESET,
+)
+
 
 # Load configuration
 config = Config()
+
 client = config.initialize_openai_client()
 
+yolo_model = YOLO("./operate/model/weights/best.pt")  # Load your trained model
 
-def get_next_action(model, messages, objective, accurate_mode):
-    if model == "gpt-4-vision-preview":
-        content = get_next_action_from_openai(
-            messages, objective, accurate_mode)
-        return content
+
+async def get_next_action(model, messages, objective):
+    if model == "gpt-4":
+        return call_gpt_4_v(messages, objective)
+    if model == "gpt-4-with-som":
+        return await call_gpt_4_v_labeled(messages, objective)
     elif model == "agent-1":
         return "coming soon"
     elif model == "gemini-pro-vision":
-        content = get_next_action_from_gemini_pro_vision(
-            messages, objective
-        )
-        return content
+        return call_gemini_pro_vision(messages, objective)
 
     raise ModelNotRecognizedException(model)
 
 
-def get_next_action_from_openai(messages, objective, accurate_mode):
+def call_gpt_4_v(messages, objective):
     """
     Get the next action for Self-Operating Computer
     """
@@ -95,32 +124,14 @@ def get_next_action_from_openai(messages, objective, accurate_mode):
 
         content = response.choices[0].message.content
 
-        if accurate_mode:
-            if content.startswith("CLICK"):
-                # Adjust pseudo_messages to include the accurate_mode_message
-
-                click_data = re.search(r"CLICK \{ (.+) \}", content).group(1)
-                click_data_json = json.loads(f"{{{click_data}}}")
-                prev_x = click_data_json["x"]
-                prev_y = click_data_json["y"]
-
-                if config.debug:
-                    print(
-                        f"Previous coords before accurate tuning: prev_x {prev_x} prev_y {prev_y}"
-                    )
-                content = accurate_mode_double_check(
-                    "gpt-4-vision-preview", pseudo_messages, prev_x, prev_y
-                )
-                assert content != "ERROR", "ERROR: accurate_mode_double_check failed"
-
         return content
 
     except Exception as e:
         print(f"Error parsing JSON: {e}")
         return "Failed take action after looking at the screenshot"
 
 
-def get_next_action_from_gemini_pro_vision(messages, objective):
+def call_gemini_pro_vision(messages, objective):
     """
     Get the next action for Self-Operating Computer using Gemini Pro Vision
     """
@@ -172,14 +183,13 @@ def get_next_action_from_gemini_pro_vision(messages, objective):
         return "Failed take action after looking at the screenshot"
 
 
+# This function is not used. `-accurate` mode was removed for now until a new PR fixes it.
 def accurate_mode_double_check(model, pseudo_messages, prev_x, prev_y):
     """
     Reprompt OAI with additional screenshot of a mini screenshot centered around the cursor for further finetuning of clicked location
     """
-    print("[get_next_action_from_gemini_pro_vision] accurate_mode_double_check")
     try:
-        screenshot_filename = os.path.join(
-            "screenshots", "screenshot_mini.png")
+        screenshot_filename = os.path.join("screenshots", "screenshot_mini.png")
         capture_mini_screenshot_with_cursor(
             file_path=screenshot_filename, x=prev_x, y=prev_y
         )
@@ -191,8 +201,7 @@ def accurate_mode_double_check(model, pseudo_messages, prev_x, prev_y):
         with open(new_screenshot_filename, "rb") as img_file:
            img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
 
-        accurate_vision_prompt = format_accurate_mode_vision_prompt(
-            prev_x, prev_y)
+        accurate_vision_prompt = format_accurate_mode_vision_prompt(prev_x, prev_y)
 
         accurate_mode_message = {
             "role": "user",
@@ -234,7 +243,7 @@ def summarize(model, messages, objective):
         capture_screen_with_cursor(screenshot_filename)
 
         summary_prompt = format_summary_prompt(objective)
-
+
         if model == "gpt-4-vision-preview":
             with open(screenshot_filename, "rb") as img_file:
                 img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
@@ -266,7 +275,135 @@ def summarize(model, messages, objective):
         )
         content = summary_message.text
         return content
-
+
     except Exception as e:
         print(f"Error in summarize: {e}")
-        return "Failed to summarize the workflow"
+        return "Failed to summarize the workflow"
+
+
+async def call_gpt_4_v_labeled(messages, objective):
+    time.sleep(1)
+    try:
+        screenshots_dir = "screenshots"
+        if not os.path.exists(screenshots_dir):
+            os.makedirs(screenshots_dir)
+
+        screenshot_filename = os.path.join(screenshots_dir, "screenshot.png")
+        # Call the function to capture the screen with the cursor
+        capture_screen_with_cursor(screenshot_filename)
+
+        with open(screenshot_filename, "rb") as img_file:
+            img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
+
+        previous_action = get_last_assistant_message(messages)
+
+        img_base64_labeled, img_base64_original, label_coordinates = add_labels(
+            img_base64, yolo_model
+        )
+
+        decision_prompt = format_decision_prompt(objective, previous_action)
+        labeled_click_prompt = format_label_prompt(objective)
+
+        click_message = {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": labeled_click_prompt},
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/jpeg;base64,{img_base64_labeled}"
+                    },
+                },
+            ],
+        }
+        decision_message = {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": decision_prompt},
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/jpeg;base64,{img_base64_original}"
+                    },
+                },
+            ],
+        }
+
+        click_messages = messages.copy()
+        click_messages.append(click_message)
+        decision_messages = messages.copy()
+        decision_messages.append(decision_message)
+
+        click_future = fetch_openai_response_async(click_messages)
+        decision_future = fetch_openai_response_async(decision_messages)
+
+        click_response, decision_response = await asyncio.gather(
+            click_future, decision_future
+        )
+
+        # Extracting the message content from the ChatCompletionMessage object
+        click_content = click_response.get("choices")[0].get("message").get("content")
+
+        decision_content = (
+            decision_response.get("choices")[0].get("message").get("content")
+        )
+
+        if not decision_content.startswith("CLICK"):
+            return decision_content
+
+        label_data = parse_click_content(click_content)
+
+        if label_data and "label" in label_data:
+            coordinates = get_label_coordinates(label_data["label"], label_coordinates)
+            image = Image.open(
+                io.BytesIO(base64.b64decode(img_base64))
+            )  # Load the image to get its size
+            image_size = image.size  # Get the size of the image (width, height)
+            click_position_percent = get_click_position_in_percent(
+                coordinates, image_size
+            )
+            if not click_position_percent:
+                print(
+                    f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] Failed to get click position in percent. Trying another method {ANSI_RESET}"
+                )
+                return call_gpt_4_v(messages, objective)
+
+            x_percent = f"{click_position_percent[0]:.2f}%"
+            y_percent = f"{click_position_percent[1]:.2f}%"
+            click_action = f'CLICK {{ "x": "{x_percent}", "y": "{y_percent}", "description": "{label_data["decision"]}", "reason": "{label_data["reason"]}" }}'
+
+        else:
+            print(
+                f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] No label found. Trying another method {ANSI_RESET}"
+            )
+            return call_gpt_4_v(messages, objective)
+
+        return click_action
+
+    except Exception as e:
+        print(
+            f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] Something went wrong. Trying another method {ANSI_RESET}"
+        )
+        return call_gpt_4_v(messages, objective)
+
+
+async def fetch_openai_response_async(messages):
+    url = "https://api.openai.com/v1/chat/completions"
+    headers = {
+        "Content-Type": "application/json",
+        "Authorization": f"Bearer {config.openai_api_key}",
+    }
+    data = {
+        "model": "gpt-4-vision-preview",
+        "messages": messages,
+        "frequency_penalty": 1,
+        "presence_penalty": 1,
+        "temperature": 0.7,
+        "max_tokens": 300,
+    }
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(
+            url, headers=headers, data=json.dumps(data)
+        ) as response:
+            return await response.json()
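The CLICK percentages built in `call_gpt_4_v_labeled` come from `get_click_position_in_percent`, which is imported from `operate.utils.label` and not shown in this diff. A plausible sketch of the conversion, assuming `coordinates` is an (x, y) pixel point such as the center of a detected box:

def get_click_position_in_percent_sketch(coordinates, image_size):
    """Convert a pixel coordinate to percentages of the image dimensions.

    Hypothetical re-implementation for illustration only; the real helper
    lives in operate/utils/label.py and may differ.
    """
    if not coordinates or not image_size:
        return None
    x_pixels, y_pixels = coordinates
    width, height = image_size  # PIL's Image.size is (width, height)
    # Scale each axis to 0-100 so the click survives resolution changes.
    return (x_pixels / width * 100, y_pixels / height * 100)


# Example: a detection centered at (960, 540) on a 1920x1080 screenshot maps
# to (50.0, 50.0), formatted as x "50.00%", y "50.00%" in the CLICK action.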