Skip to content

Commit 21ab16e

Browse files
committed
wip
wip wip wip
1 parent a6e9047 commit 21ab16e

18 files changed

+435
-73
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,4 +159,6 @@ cython_debug/
159159
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
160160
#.idea/
161161

162-
.vscode/
162+
.vscode/
163+
164+
gpt.yml

Dockerfile

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
FROM python:3.9.17-alpine
2+
WORKDIR /app
3+
4+
RUN apk add build-base linux-headers clang
5+
COPY requirements.txt requirements_docker.txt ./
6+
RUN --mount=type=cache,target=/root/.cache/pip pip install -r requirements.txt
7+
RUN --mount=type=cache,target=/root/.cache/pip pip install -r requirements_docker.txt
8+
9+
COPY . .
10+
11+
RUN adduser -D gpt
12+
USER gpt
13+
RUN mkdir -p $HOME/.config/gpt-cli
14+
RUN cp /app/gpt.yml $HOME/.config/gpt-cli/gpt.yml
15+
16+
ENV GPTCLI_ALLOW_CODE_EXECUTION=1
17+
ENTRYPOINT ["python", "gpt.py"]

gpt.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
#!/usr/bin/env python
22
import os
3+
import subprocess
4+
import tempfile
5+
import traceback
36
from typing import cast
47
import openai
8+
import random
59
import argparse
610
import sys
711
import logging
8-
import google.generativeai as genai
12+
13+
# import google.generativeai as genai
914
import gptcli.anthropic
1015
from gptcli.assistant import (
1116
Assistant,
@@ -24,6 +29,7 @@
2429
choose_config_file,
2530
read_yaml_config,
2631
)
32+
from gptcli.interpreter import CodeInterpreterListener
2733
from gptcli.llama import init_llama_models
2834
from gptcli.logging import LoggingChatListener
2935
from gptcli.cost import PriceChatListener
@@ -89,6 +95,12 @@ def parse_args(config: GptCliConfig):
8995
default=config.log_file,
9096
help="The file to write logs to",
9197
)
98+
parser.add_argument(
99+
"--history_file",
100+
type=str,
101+
default=config.history_file,
102+
help="The file to write chat history to",
103+
)
92104
parser.add_argument(
93105
"--log_level",
94106
type=str,
@@ -166,8 +178,8 @@ def main():
166178
if config.anthropic_api_key:
167179
gptcli.anthropic.api_key = config.anthropic_api_key
168180

169-
if config.google_api_key:
170-
genai.configure(api_key=config.google_api_key)
181+
# if config.google_api_key:
182+
# genai.configure(api_key=config.google_api_key)
171183

172184
if config.llama_models is not None:
173185
init_llama_models(config.llama_models)
@@ -215,16 +227,24 @@ def __init__(self, assistant: Assistant, markdown: bool, show_price: bool):
215227
if show_price:
216228
listeners.append(PriceChatListener(assistant))
217229

230+
if os.environ.get("GPTCLI_ALLOW_CODE_EXECUTION") == "1":
231+
listeners.append(CodeInterpreterListener("python_eval"))
232+
218233
listener = CompositeChatListener(listeners)
219-
super().__init__(assistant, listener)
234+
super().__init__(
235+
assistant,
236+
listener,
237+
)
220238

221239

222240
def run_interactive(args, assistant):
223241
logger.info("Starting a new chat session. Assistant config: %s", assistant.config)
224242
session = CLIChatSession(
225243
assistant=assistant, markdown=args.markdown, show_price=args.show_price
226244
)
227-
history_filename = os.path.expanduser("~/.config/gpt-cli/history")
245+
history_filename = args.history_file or os.path.expanduser(
246+
"~/.config/gpt-cli/history"
247+
)
228248
os.makedirs(os.path.dirname(history_filename), exist_ok=True)
229249
input_provider = CLIUserInputProvider(history_filename=history_filename)
230250
session.loop(input_provider)

gpt.yml.template

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
markdown: True
2+
openai_api_key: <YOUR_OPENAI_API_KEY_HERE>
3+
log_file: /mnt/gpt.log
4+
log_level: DEBUG
5+
history_file: /mnt/history
6+
assistants:
7+
python:
8+
model: gpt-4-0613
9+
enable_code_execution: True
10+
messages:
11+
- { role: "system", content: "You are a helpful assistant. You have access to a Python environment. You can install missing packages. You have access to the internet. The user can see the code you are executing and its output: do not repeat them to the user verbatim. Pre-installed packages: numpy, matplotlib, ipython, ipykernel." }
12+

gptcli/anthropic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
from typing import Iterator, List
3-
import anthropic
3+
4+
# import anthropic
45

56
from gptcli.completion import CompletionProvider, Message
67

gptcli/assistant.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import platform
55
from typing import Any, Dict, Iterator, Optional, TypedDict, List
66

7-
from gptcli.completion import CompletionProvider, ModelOverrides, Message
7+
from gptcli.completion import Completion, CompletionProvider, ModelOverrides, Message
88
from gptcli.google import GoogleCompletionProvider
99
from gptcli.llama import LLaMACompletionProvider
1010
from gptcli.openai import OpenAICompletionProvider
@@ -16,12 +16,14 @@ class AssistantConfig(TypedDict, total=False):
1616
model: str
1717
temperature: float
1818
top_p: float
19+
enable_code_execution: bool
1920

2021

2122
CONFIG_DEFAULTS = {
2223
"model": "gpt-3.5-turbo",
2324
"temperature": 0.7,
2425
"top_p": 1.0,
26+
"enable_code_execution": False,
2527
}
2628

2729
DEFAULT_ASSISTANTS: Dict[str, AssistantConfig] = {
@@ -89,7 +91,7 @@ def init_messages(self) -> List[Message]:
8991
return self.config.get("messages", [])[:]
9092

9193
def supported_overrides(self) -> List[str]:
92-
return ["model", "temperature", "top_p"]
94+
return ["model", "temperature", "top_p", "enable_code_execution"]
9395

9496
def _param(self, param: str, override_params: ModelOverrides) -> Any:
9597
# If the param is in the override_params, use that value
@@ -101,9 +103,15 @@ def _param(self, param: str, override_params: ModelOverrides) -> Any:
101103

102104
def complete_chat(
103105
self, messages, override_params: ModelOverrides = {}, stream: bool = True
104-
) -> Iterator[str]:
106+
) -> Iterator[Completion]:
105107
model = self._param("model", override_params)
106108
completion_provider = get_completion_provider(model)
109+
110+
enable_code_execution = (
111+
bool(self._param("enable_code_execution", override_params))
112+
and os.environ.get("GPTCLI_ALLOW_CODE_EXECUTION") == "1"
113+
)
114+
107115
return completion_provider.complete(
108116
messages,
109117
{
@@ -112,6 +120,7 @@ def complete_chat(
112120
"top_p": float(self._param("top_p", override_params)),
113121
},
114122
stream,
123+
enable_code_execution,
115124
)
116125

117126

gptcli/cli.py

Lines changed: 63 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
import base64
2+
import logging
13
import re
4+
import json
5+
from imgcat import imgcat
26
from prompt_toolkit import PromptSession
37
from prompt_toolkit.history import FileHistory
48
from openai import OpenAIError, InvalidRequestError
@@ -9,6 +13,7 @@
913
from typing import Any, Dict, Optional, Tuple
1014

1115
from rich.text import Text
16+
from gptcli.completion import FunctionCall, Message, merge_dicts
1217
from gptcli.session import (
1318
ALL_COMMANDS,
1419
COMMAND_CLEAR,
@@ -32,7 +37,7 @@
3237
class StreamingMarkdownPrinter:
3338
def __init__(self, console: Console, markdown: bool):
3439
self.console = console
35-
self.current_text = ""
40+
self.current_message = {}
3641
self.markdown = markdown
3742
self.live: Optional[Live] = None
3843

@@ -44,15 +49,50 @@ def __enter__(self) -> "StreamingMarkdownPrinter":
4449
self.live.__enter__()
4550
return self
4651

47-
def print(self, text: str):
48-
self.current_text += text
52+
def _format_function_call(self, function_call: FunctionCall) -> str:
53+
text = ""
54+
if function_call.get("name") == "python_eval":
55+
source = function_call.get("arguments", "")
56+
try:
57+
source = json.loads(source).get("source", "")
58+
except:
59+
source = source + '"}'
60+
try:
61+
source = json.loads(source).get("source", "")
62+
except:
63+
source = ""
64+
65+
text += "\n\nExecuting Python code:\n"
66+
text += f"```python\n{source}\n```"
67+
else:
68+
function_name = function_call.get("name", "?")
69+
function_arguments = function_call.get("arguments", {})
70+
text += f"""\n
71+
Calling function:
72+
73+
```
74+
{function_name}({function_arguments})
75+
```"""
76+
return text
77+
78+
def print(self, message_delta: Message):
79+
self.current_message = merge_dicts(self.current_message, message_delta)
80+
4981
if self.markdown:
5082
assert self.live
51-
content = Markdown(self.current_text, style="green")
83+
text = self.current_message.get("content", "")
84+
85+
function_call = self.current_message.get("function_call")
86+
if function_call:
87+
text += self._format_function_call(function_call)
88+
89+
content = Markdown(text, style="green")
5290
self.live.update(content)
5391
self.live.refresh()
5492
else:
55-
self.console.print(Text(text, style="green"), end="")
93+
self.console.print(
94+
Text(message_delta.get("content", ""), style="green"), end=""
95+
)
5696

5797
def __exit__(self, *args):
5898
if self.markdown:
@@ -66,17 +106,29 @@ def __init__(self, console: Console, markdown: bool):
66106
self.console = console
67107
self.markdown = markdown
68108
self.printer = StreamingMarkdownPrinter(self.console, self.markdown)
69-
self.first_token = True
70109

71110
def __enter__(self):
72111
self.printer.__enter__()
73112
return self
74113

75-
def on_next_token(self, token: str):
76-
if self.first_token and token.startswith(" "):
77-
token = token[1:]
78-
self.first_token = False
79-
self.printer.print(token)
114+
def on_message_delta(self, message_delta: Message):
115+
self.printer.print(message_delta)
116+
117+
def on_function_result(self, result: dict):
118+
self.console.print(Text("Function result:", style="yellow"))
119+
if "image/png" in result:
120+
image_base64 = result["image/png"]
121+
image_bytes = base64.b64decode(image_base64)
122+
imgcat(image_bytes)
123+
if "text/plain" in result:
124+
text = result["text/plain"]
125+
if self.markdown:
126+
content = Markdown(
127+
f"```\n{text}\n```",
128+
)
129+
else:
130+
content = Text(text, style="yellow")
131+
self.console.print(content)
80132

81133
def __exit__(self, *args):
82134
self.printer.__exit__(*args)

gptcli/completion.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,57 @@
11
from abc import abstractmethod
2-
from typing import Iterator, List, TypedDict
2+
from typing import Iterator, List, Optional, TypedDict
3+
from typing_extensions import Required
34

45

5-
class Message(TypedDict):
6-
role: str
7-
content: str
6+
class FunctionCall(TypedDict, total=False):
7+
name: str
8+
arguments: str
9+
10+
11+
class Message(TypedDict, total=False):
12+
role: Required[str]
13+
content: Optional[str]
14+
name: str
15+
function_call: FunctionCall
16+
17+
18+
def merge_dicts(a: dict, b: dict):
19+
"""
20+
Given two nested dicts with string values, merge dict `b` into dict `a`, concatenating
21+
string values.
22+
"""
23+
for key, value in b.items():
24+
if isinstance(value, dict):
25+
a[key] = merge_dicts(a.get(key, {}), value)
26+
elif value is not None:
27+
a[key] = a.get(key, "") + value
28+
return a
829

930

1031
class ModelOverrides(TypedDict, total=False):
1132
model: str
1233
temperature: float
1334
top_p: float
35+
enable_code_execution: bool
36+
37+
38+
class CompletionDelta(TypedDict):
39+
content: Optional[str]
40+
function_call: Optional[FunctionCall]
41+
42+
43+
class Completion(TypedDict):
44+
delta: Message
45+
finish_reason: Optional[str]
1446

1547

1648
class CompletionProvider:
1749
@abstractmethod
1850
def complete(
19-
self, messages: List[Message], args: dict, stream: bool = False
20-
) -> Iterator[str]:
51+
self,
52+
messages: List[Message],
53+
args: dict,
54+
stream: bool = False,
55+
enable_code_execution: bool = False,
56+
) -> Iterator[Completion]:
2157
pass

gptcli/composite.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from gptcli.session import ChatListener, ResponseStreamer
33

44

5-
from typing import List
5+
from typing import List, Optional
66

77

88
class CompositeResponseStreamer(ResponseStreamer):
@@ -14,9 +14,13 @@ def __enter__(self):
1414
streamer.__enter__()
1515
return self
1616

17-
def on_next_token(self, token: str):
17+
def on_message_delta(self, message_delta: Message):
1818
for streamer in self.streamers:
19-
streamer.on_next_token(token)
19+
streamer.on_message_delta(message_delta)
20+
21+
def on_function_result(self, result: dict):
22+
for streamer in self.streamers:
23+
streamer.on_function_result(result)
2024

2125
def __exit__(self, *args):
2226
for streamer in self.streamers:
@@ -57,3 +61,9 @@ def on_chat_response(
5761
):
5862
for listener in self.listeners:
5963
listener.on_chat_response(messages, response, overrides)
64+
65+
def on_function_call(self, function_name: str, **kwargs) -> Optional[str]:
66+
for listener in self.listeners:
67+
result = listener.on_function_call(function_name, **kwargs)
68+
if result is not None:
69+
return result

0 commit comments

Comments
 (0)