Skip to content

Commit e116564

Browse files
authored
Merge pull request #180 from roywei/claude_support
add claude support
2 parents 0deab80 + 60c5f3d commit e116564

File tree

4 files changed

+222
-7
lines changed

4 files changed

+222
-7
lines changed

operate/config.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
from dotenv import load_dotenv
44
from openai import OpenAI
5+
import anthropic
56
from prompt_toolkit.shortcuts import input_dialog
67
import google.generativeai as genai
78

@@ -33,6 +34,10 @@ def __init__(self):
3334
self.google_api_key = (
3435
None # instance variables are backups in case saving to a `.env` fails
3536
)
37+
self.anthropic_api_key = (
38+
None # instance variables are backups in case saving to a `.env` fails
39+
)
40+
3641

3742
def initialize_openai(self):
3843
if self.verbose:
@@ -71,6 +76,14 @@ def initialize_google(self):
7176
model = genai.GenerativeModel("gemini-pro-vision")
7277

7378
return model
79+
80+
def initialize_anthropic(self):
    """Build an Anthropic client, preferring the in-memory backup key.

    Falls back to the ANTHROPIC_API_KEY environment variable when no
    instance-level key was stored.
    """
    # `self.anthropic_api_key` is the backup kept in case saving to `.env` failed.
    api_key = self.anthropic_api_key or os.getenv("ANTHROPIC_API_KEY")
    return anthropic.Anthropic(api_key=api_key)
86+
7487

7588
def validation(self, model, voice_mode):
7689
"""
@@ -87,6 +100,9 @@ def validation(self, model, voice_mode):
87100
self.require_api_key(
88101
"GOOGLE_API_KEY", "Google API key", model == "gemini-pro-vision"
89102
)
103+
self.require_api_key(
104+
"ANTHROPIC_API_KEY", "Anthropic API key", model == "claude-3-with-ocr"
105+
)
90106

91107
def require_api_key(self, key_name, key_description, is_required):
92108
key_exists = bool(os.environ.get(key_name))

operate/models/apis.py

Lines changed: 197 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,16 @@ async def get_next_action(model, messages, objective, session_id):
4848
if model == "gpt-4-with-ocr":
4949
operation = await call_gpt_4_vision_preview_ocr(messages, objective, model)
5050
return operation, None
51-
elif model == "agent-1":
51+
if model == "agent-1":
5252
return "coming soon"
53-
elif model == "gemini-pro-vision":
53+
if model == "gemini-pro-vision":
5454
return call_gemini_pro_vision(messages, objective), None
55-
elif model == "llava":
56-
operation = call_ollama_llava(messages), None
57-
return operation
58-
55+
if model == "llava":
56+
operation = call_ollama_llava(messages)
57+
return operation, None
58+
if model == "claude-3":
59+
operation = await call_claude_3_with_ocr(messages, objective, model)
60+
return operation, None
5961
raise ModelNotRecognizedException(model)
6062

6163

@@ -528,6 +530,195 @@ def call_ollama_llava(messages):
528530
return call_ollama_llava(messages)
529531

530532

533+
async def call_claude_3_with_ocr(messages, objective, model):
    """Ask Claude 3 (Opus) for the next operations, resolving click targets via OCR.

    Captures a screenshot, downsizes it (Anthropic enforces a 5MB image limit),
    sends it with the user prompt, and parses the JSON operations the model
    returns. For each "click" operation the target text is located on the
    screenshot with EasyOCR and converted to screen coordinates.

    Args:
        messages: running conversation list; messages[0] holds the system prompt.
        objective: the user's overall objective (used to refresh the system prompt).
        model: model identifier string, used for logging and the fallback call.

    Returns:
        A list of operation dicts (click operations gain "x"/"y" coordinates),
        or the result of the GPT-4 fallback path if anything fails.
    """
    import asyncio  # local import so the module-level import list is untouched

    if config.verbose:
        print("[call_claude_3_with_ocr]")

    try:
        # Throttle politely without stalling the event loop — time.sleep() here
        # would block every other coroutine for a full second.
        await asyncio.sleep(1)
        client = config.initialize_anthropic()

        confirm_system_prompt(messages, objective, model)
        screenshots_dir = "screenshots"
        if not os.path.exists(screenshots_dir):
            os.makedirs(screenshots_dir)

        screenshot_filename = os.path.join(screenshots_dir, "screenshot.png")
        capture_screen_with_cursor(screenshot_filename)

        # downsize screenshot due to 5MB size limit
        with open(screenshot_filename, "rb") as img_file:
            img = Image.open(img_file)

            # Calculate the new dimensions while maintaining the aspect ratio
            original_width, original_height = img.size
            aspect_ratio = original_width / original_height
            new_width = 2560  # Adjust this value to achieve the desired file size
            new_height = int(new_width / aspect_ratio)
            if config.verbose:
                print("[call_claude_3_with_ocr] resizing claude")

            # Resize the image
            img_resized = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Save the resized image to a BytesIO object
            img_buffer = io.BytesIO()
            img_resized.save(img_buffer, format="PNG")
            img_buffer.seek(0)

            # Encode the resized image as base64
            img_data = base64.b64encode(img_buffer.getvalue()).decode("utf-8")

        if len(messages) == 1:
            user_prompt = get_user_first_message_prompt()
        else:
            user_prompt = get_user_prompt()

        vision_message = {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/png",
                        "data": img_data,
                    },
                },
                {
                    "type": "text",
                    "text": user_prompt
                    + "**REMEMBER** Only output json format, do not append any other text.",
                },
            ],
        }
        messages.append(vision_message)

        # The Anthropic API expects the system prompt as a separate argument,
        # not as the first message in the list.
        response = client.messages.create(
            model="claude-3-opus-20240229",
            max_tokens=3000,
            system=messages[0]["content"],
            messages=messages[1:],
        )

        content = response.content[0].text
        content = clean_json(content)
        content_str = content
        try:
            content = json.loads(content)
        # rework for json mode output
        except json.JSONDecodeError as e:
            if config.verbose:
                print(
                    f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_RED}[Error] JSONDecodeError: {e} {ANSI_RESET}"
                )
            # One repair round-trip: ask the model to fix its own invalid JSON.
            response = client.messages.create(
                model="claude-3-opus-20240229",
                max_tokens=3000,
                system=f"This json string is not valid, when using with json.loads(content) \
                it throws the following error: {e}, return correct json string. \
                **REMEMBER** Only output json format, do not append any other text.",
                messages=[{"role": "user", "content": content}],
            )
            content = response.content[0].text
            content = clean_json(content)
            content_str = content
            content = json.loads(content)

        if config.verbose:
            print(
                f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_BRIGHT_MAGENTA}[{model}] content: {content} {ANSI_RESET}"
            )
        processed_content = []

        # OCR state is loop-invariant: build the reader and read the screenshot
        # at most once, no matter how many click operations the response holds.
        ocr_result = None

        for operation in content:
            if operation.get("operation") == "click":
                text_to_click = operation.get("text")
                if config.verbose:
                    print(
                        "[call_claude_3_ocr][click] text_to_click",
                        text_to_click,
                    )
                if ocr_result is None:
                    # Initialize EasyOCR Reader and read the screenshot once
                    reader = easyocr.Reader(["en"])
                    ocr_result = reader.readtext(screenshot_filename)

                # limit the text to extract has a higher success rate
                text_element_index = get_text_element(
                    ocr_result, text_to_click[:3], screenshot_filename
                )
                coordinates = get_text_coordinates(
                    ocr_result, text_element_index, screenshot_filename
                )

                # add `coordinates`` to `content`
                operation["x"] = coordinates["x"]
                operation["y"] = coordinates["y"]

                if config.verbose:
                    print(
                        "[call_claude_3_ocr][click] text_element_index",
                        text_element_index,
                    )
                    print(
                        "[call_claude_3_ocr][click] coordinates",
                        coordinates,
                    )
                    print(
                        "[call_claude_3_ocr][click] final operation",
                        operation,
                    )
                processed_content.append(operation)

            else:
                processed_content.append(operation)

        assistant_message = {"role": "assistant", "content": content_str}
        messages.append(assistant_message)

        return processed_content

    except Exception as e:
        # Best-effort fallback: any failure (API error, bad JSON, OCR miss)
        # re-routes the conversation through the GPT-4 path.
        print(
            f"{ANSI_GREEN}[Self-Operating Computer]{ANSI_BRIGHT_MAGENTA}[{model}] That did not work. Trying another method {ANSI_RESET}"
        )
        if config.verbose:
            print("[Self-Operating Computer][Operate] error", e)
            traceback.print_exc()
        print("message before convertion ", messages)

        # Convert the messages to the GPT-4 format
        gpt4_messages = [messages[0]]  # Include the system message
        for message in messages[1:]:
            if message["role"] == "user":
                # Update the image type format from "source" to "url"
                updated_content = []
                for item in message["content"]:
                    if isinstance(item, dict) and "type" in item:
                        if item["type"] == "image":
                            updated_content.append(
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:image/png;base64,{item['source']['data']}"
                                    },
                                }
                            )
                        else:
                            updated_content.append(item)

                gpt4_messages.append({"role": "user", "content": updated_content})
            elif message["role"] == "assistant":
                gpt4_messages.append(
                    {"role": "assistant", "content": message["content"]}
                )

        return gpt_4_fallback(gpt4_messages, objective, model)
531722
def get_last_assistant_message(messages):
532723
"""
533724
Retrieve the last message from the assistant in the messages array.

operate/models/prompts.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,13 @@ def get_system_prompt(model, objective):
238238
os_search_str=os_search_str,
239239
operating_system=operating_system,
240240
)
241+
elif model == "claude-3-with-ocr":
242+
prompt = SYSTEM_PROMPT_OCR.format(
243+
objective=objective,
244+
cmd_string=cmd_string,
245+
os_search_str=os_search_str,
246+
operating_system=operating_system,
247+
)
241248
else:
242249
prompt = SYSTEM_PROMPT_STANDARD.format(
243250
objective=objective,

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,5 @@ google-generativeai==0.3.0
5151
aiohttp==3.9.1
5252
ultralytics==8.0.227
5353
easyocr==1.7.1
54-
ollama==0.1.6
54+
ollama==0.1.6
55+
anthropic

0 commit comments

Comments (0)