Commit d2a43c6

Add option to disable response streaming (TheR1D#290)
1 parent 30fad64 commit d2a43c6

File tree: 5 files changed, +20 −3 lines

Diff for: README.md (+2)

@@ -313,6 +313,8 @@ DEFAULT_COLOR=magenta
 SYSTEM_ROLES=false
 # When in --shell mode, default to "Y" for no input.
 DEFAULT_EXECUTE_SHELL_CMD=false
+# Disable streaming of responses
+DISABLE_STREAMING=false
 ```
 Possible options for `DEFAULT_COLOR`: black, red, green, yellow, blue, magenta, cyan, white, bright_black, bright_red, bright_green, bright_yellow, bright_blue, bright_magenta, bright_cyan, bright_white.
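
One detail worth noting about how this flag is read (see the client and handler changes further down): the code compares the value against the literal string "false", so streaming stays on only for that exact value. A quick standalone check of that behaviour:

```python
# Mirrors the check added in this commit: stream = DISABLE_STREAMING == "false".
# Only the exact lowercase string "false" keeps streaming enabled.
for value in ("false", "true", "False", "0", "no"):
    stream = value == "false"
    print(f"DISABLE_STREAMING={value!r} -> streaming {'on' if stream else 'off'}")
```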

Diff for: sgpt/client.py (+8 −2)

@@ -10,6 +10,7 @@
 CACHE_LENGTH = int(cfg.get("CACHE_LENGTH"))
 CACHE_PATH = Path(cfg.get("CACHE_PATH"))
 REQUEST_TIMEOUT = int(cfg.get("REQUEST_TIMEOUT"))
+DISABLE_STREAMING = str(cfg.get("DISABLE_STREAMING"))
 
 
 class OpenAIClient:
@@ -37,12 +38,13 @@ def _request(
         :param top_probability: Float in 0.0 - 1.0 range.
         :return: Response body JSON.
         """
+        stream = DISABLE_STREAMING == "false"
         data = {
             "messages": messages,
             "model": model,
             "temperature": temperature,
             "top_p": top_probability,
-            "stream": True,
+            "stream": stream,
         }
         endpoint = f"{self.api_host}/v1/chat/completions"
         response = requests.post(
@@ -54,11 +56,15 @@
             },
             json=data,
             timeout=REQUEST_TIMEOUT,
-            stream=True,
+            stream=stream,
         )
         response.raise_for_status()
         # TODO: Optimise.
         # https://github.com/openai/openai-python/blob/237448dc072a2c062698da3f9f512fae38300c1c/openai/api_requestor.py#L98
+        if not stream:
+            data = response.json()
+            yield data["choices"][0]["message"]["content"]  # type: ignore
+            return
         for line in response.iter_lines():
             data = line.lstrip(b"data: ").decode("utf-8")
             if data == "[DONE]":  # type: ignore

Diff for: sgpt/config.py (+1)

@@ -29,6 +29,7 @@
     "ROLE_STORAGE_PATH": os.getenv("ROLE_STORAGE_PATH", str(ROLE_STORAGE_PATH)),
     "SYSTEM_ROLES": os.getenv("SYSTEM_ROLES", "false"),
     "DEFAULT_EXECUTE_SHELL_CMD": os.getenv("DEFAULT_EXECUTE_SHELL_CMD", "false"),
+    "DISABLE_STREAMING": os.getenv("DISABLE_STREAMING", "false")
     # New features might add their own config variables here.
 }
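
The defaults above are all plain strings sourced from environment variables, and conversion happens at the call site, as seen in client.py (`int(cfg.get("REQUEST_TIMEOUT"))` for numbers, `== "false"` for this flag). A minimal sketch of that pattern, with illustrative values rather than sgpt's real defaults:

```python
import os

# Illustrative only: every option stays a string, so call sites convert on use.
defaults = {
    "REQUEST_TIMEOUT": os.getenv("REQUEST_TIMEOUT", "60"),        # placeholder default
    "DISABLE_STREAMING": os.getenv("DISABLE_STREAMING", "false"),
}

request_timeout = int(defaults["REQUEST_TIMEOUT"])     # numeric option via int()
stream = defaults["DISABLE_STREAMING"] == "false"      # flag option via string compare
print(request_timeout, stream)
```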

Diff for: sgpt/handlers/handler.py (+4 −1)

@@ -27,8 +27,11 @@ def get_completion(self, **kwargs: Any) -> Generator[str, None, None]:
     def handle(self, prompt: str, **kwargs: Any) -> str:
         messages = self.make_messages(self.make_prompt(prompt))
         full_completion = ""
+        stream = cfg.get("DISABLE_STREAMING") == "false"
+        if not stream:
+            typer.echo("Loading...\r", nl=False)
         for word in self.get_completion(messages=messages, **kwargs):
             typer.secho(word, fg=self.color, bold=True, nl=False)
             full_completion += word
-        typer.echo()
+        typer.echo("\033[K" if not stream else "")  # Overwrite "loading..."
         return full_completion
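
The loading indicator added above relies on two terminal conventions rather than anything typer-specific: a bare carriage return parks the cursor at column 0 without starting a new line, and the ANSI "erase to end of line" sequence `\033[K` wipes whatever is left of the placeholder once the real answer has been printed over it. A standalone sketch, assuming an ANSI-capable terminal:

```python
import sys
import time

# "\r" returns the cursor to column 0, so the next write lands on top of
# "Loading...". "\033[K" erases from the cursor to the end of the line,
# removing any leftover placeholder characters if the answer is shorter.
sys.stdout.write("Loading...\r")
sys.stdout.flush()
time.sleep(1)  # stand-in for waiting on the non-streaming API response
sys.stdout.write("\033[K")
print("Here is the completion text.")  # cleanly replaces the indicator
```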

Diff for: tests/test_integration.py (+5)

@@ -30,6 +30,11 @@
 
 
 class TestShellGpt(TestCase):
+    @classmethod
+    def setUpClass(cls):
+        # Response streaming should be enabled for these tests.
+        assert cfg.get("DISABLE_STREAMING") == "false"
+
     def setUp(self) -> None:
         # Just to not spam the API.
         sleep(1)
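
Because `setUpClass` runs once before any test in the class, a failing assertion there stops the suite up front (unittest reports a single setUpClass error and skips the class's tests) instead of letting every streamed-output test fail individually. A minimal illustration, independent of sgpt's fixtures:

```python
import os
import unittest


class TestStreamingPrecondition(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Runs once for the whole class; if this raises, the class's tests
        # are not executed and unittest reports one setUpClass error.
        assert os.getenv("DISABLE_STREAMING", "false") == "false"

    def test_placeholder(self) -> None:
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```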
