1+ import base64
2+ import logging
13import re
4+ import json
5+ from imgcat import imgcat
26from prompt_toolkit import PromptSession
37from prompt_toolkit .history import FileHistory
48from openai import OpenAIError , InvalidRequestError
913from typing import Any , Dict , Optional , Tuple
1014
1115from rich .text import Text
16+ from gptcli .completion import FunctionCall , Message , merge_dicts
1217from gptcli .session import (
1318 ALL_COMMANDS ,
1419 COMMAND_CLEAR ,
3237class StreamingMarkdownPrinter :
3338 def __init__ (self , console : Console , markdown : bool ):
3439 self .console = console
35- self .current_text = ""
40+ self .current_message = {}
3641 self .markdown = markdown
3742 self .live : Optional [Live ] = None
3843
@@ -44,15 +49,50 @@ def __enter__(self) -> "StreamingMarkdownPrinter":
4449 self .live .__enter__ ()
4550 return self
4651
47- def print (self , text : str ):
48- self .current_text += text
52+ def _format_function_call (self , function_call : FunctionCall ) -> str :
53+ text = ""
54+ if function_call .get ("name" ) == "python_eval" :
55+ source = function_call .get ("arguments" , "" )
56+ try :
57+ source = json .loads (source ).get ("source" , "" )
58+ except :
59+ source = source + '"}'
60+ try :
61+ source = json .loads (source ).get ("source" , "" )
62+ except :
63+ source = ""
64+
65+ text += "\n \n Executing Python code:\n "
66+ text += f"```python\n { source } \n ```"
67+ else :
68+ function_name = function_call .get ("name" , "?" )
69+ function_arguments = function_call .get ("arguments" , {})
70+ text += f"""\n
71+ Calling function:
72+
73+ ```
74+ { function_name } ({ function_arguments } )
75+ ```"""
76+ return text
77+
78+ def print (self , message_delta : Message ):
79+ self .current_message = merge_dicts (self .current_message , message_delta )
80+
4981 if self .markdown :
5082 assert self .live
51- content = Markdown (self .current_text , style = "green" )
83+ text = self .current_message .get ("content" , "" )
84+
85+ function_call = self .current_message .get ("function_call" )
86+ if function_call :
87+ text += self ._format_function_call (function_call )
88+
89+ content = Markdown (text , style = "green" )
5290 self .live .update (content )
5391 self .live .refresh ()
5492 else :
55- self .console .print (Text (text , style = "green" ), end = "" )
93+ self .console .print (
94+ Text (message_delta .get ("content" , "" ), style = "green" ), end = ""
95+ )
5696
5797 def __exit__ (self , * args ):
5898 if self .markdown :
@@ -66,17 +106,29 @@ def __init__(self, console: Console, markdown: bool):
66106 self .console = console
67107 self .markdown = markdown
68108 self .printer = StreamingMarkdownPrinter (self .console , self .markdown )
69- self .first_token = True
70109
71110 def __enter__ (self ):
72111 self .printer .__enter__ ()
73112 return self
74113
75- def on_next_token (self , token : str ):
76- if self .first_token and token .startswith (" " ):
77- token = token [1 :]
78- self .first_token = False
79- self .printer .print (token )
114+ def on_message_delta (self , message_delta : Message ):
115+ self .printer .print (message_delta )
116+
117+ def on_function_result (self , result : dict ):
118+ self .console .print (Text ("Function result:" , style = "yellow" ))
119+ if "image/png" in result :
120+ image_base64 = result ["image/png" ]
121+ image_bytes = base64 .b64decode (image_base64 )
122+ imgcat (image_bytes )
123+ if "text/plain" in result :
124+ text = result ["text/plain" ]
125+ if self .markdown :
126+ content = Markdown (
127+ f"```\n { text } \n ```" ,
128+ )
129+ else :
130+ content = Text (text , style = "yellow" )
131+ self .console .print (content )
80132
81133 def __exit__ (self , * args ):
82134 self .printer .__exit__ (* args )
0 commit comments