Commit cd7e106

Replace janky LLaMA implementation with llama-cpp-python
1 parent 67f64ed

File tree: 4 files changed (+48 −82 lines)

gpt.py

Lines changed: 2 additions & 4 deletions
@@ -169,10 +169,8 @@ def main():
     if config.google_api_key:
         genai.configure(api_key=config.google_api_key)
 
-    if config.llama_config is not None:
-        init_llama_models(
-            config.llama_config["llama_cpp_dir"], config.llama_config["models"]
-        )
+    if config.llama_models is not None:
+        init_llama_models(config.llama_models)
 
     assistant = init_assistant(cast(AssistantGlobalArgs, args), config.assistants)

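Note: with the config flattened, the call site no longer threads a llama.cpp checkout directory through. A minimal sketch of the new entry point (the model path is hypothetical):

    from gptcli.llama import init_llama_models

    # Model names must start with "llama" and each path must be an existing
    # file; init_llama_models validates both and exits on failure.
    init_llama_models({"llama-7b": "/models/llama-7b/ggml-model-q4_0.bin"})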
gptcli/config.py

Lines changed: 1 addition & 6 deletions
@@ -11,11 +11,6 @@
 ]
 
 
-class LLaMAConfig(TypedDict):
-    llama_cpp_dir: str
-    models: Dict[str, str]  # name -> path
-
-
 @dataclass
 class GptCliConfig:
     default_assistant: str = "general"
@@ -29,7 +24,7 @@ class GptCliConfig:
     log_level: str = "INFO"
     assistants: Dict[str, AssistantConfig] = {}
     interactive: Optional[bool] = None
-    llama_config: Optional[LLaMAConfig] = None
+    llama_models: Optional[Dict[str, str]] = None
 
 
 def choose_config_file(paths: List[str]) -> str:
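
The nested LLaMAConfig block is gone; llama_models is now a plain name-to-path mapping on GptCliConfig (in the YAML config this corresponds to a top-level llama_models mapping). A minimal sketch, assuming keyword construction of the dataclass (paths are hypothetical):

    from gptcli.config import GptCliConfig

    # Before: llama_config={"llama_cpp_dir": "...", "models": {...}}
    # After: just the name -> path mapping.
    config = GptCliConfig(
        llama_models={
            "llama-7b": "/models/llama-7b/ggml-model-q4_0.bin",
            "llama-13b": "/models/llama-13b/ggml-model-q4_0.bin",
        }
    )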

gptcli/llama.py

Lines changed: 44 additions & 72 deletions
@@ -1,18 +1,15 @@
-import logging
 import os
-import signal
-import subprocess
 from pathlib import Path
 import sys
 from typing import Iterator, List, Optional
+from llama_cpp import Llama
 
 from gptcli.completion import CompletionProvider, Message
 
-LLAMA_DIR: Optional[Path] = None
-LLAMA_MODELS: Optional[dict[str, Path]] = None
+LLAMA_MODELS: Optional[dict[str, str]] = None
 
 
-def init_llama_models(llama_cpp_dir: str, model_paths: dict[str, str]):
+def init_llama_models(model_paths: dict[str, str]):
     for name, path in model_paths.items():
         if not os.path.isfile(path):
             print(f"LLaMA model {name} not found at {path}.")
@@ -21,9 +18,8 @@ def init_llama_models(llama_cpp_dir: str, model_paths: dict[str, str]):
             print(f"LLaMA model names must start with `llama`, but got `{name}`.")
             sys.exit(1)
 
-    global LLAMA_DIR, LLAMA_MODELS
-    LLAMA_DIR = Path(llama_cpp_dir)
-    LLAMA_MODELS = {name: Path(path) for name, path in model_paths.items()}
+    global LLAMA_MODELS
+    LLAMA_MODELS = model_paths
 
 
 def role_to_name(role: str) -> str:
@@ -50,75 +46,51 @@ class LLaMACompletionProvider(CompletionProvider):
     def complete(
         self, messages: List[Message], args: dict, stream: bool = False
     ) -> Iterator[str]:
-        assert LLAMA_DIR, "LLaMA models not initialized"
         assert LLAMA_MODELS, "LLaMA models not initialized"
 
+        with suppress_stderr():
+            llm = Llama(
+                model_path=LLAMA_MODELS[args["model"]],
+                n_ctx=2048,
+                verbose=False,
+                use_mlock=True,
+            )
         prompt = make_prompt(messages)
 
-        extra_args = []
+        extra_args = {}
         if "temperature" in args:
-            extra_args += ["--temp", str(args["temperature"])]
+            extra_args["temperature"] = args["temperature"]
         if "top_p" in args:
-            extra_args += ["--top_p", str(args["top_p"])]
-
-        process = subprocess.Popen(
-            [
-                LLAMA_DIR / "main",
-                "--model",
-                LLAMA_MODELS[args["model"]],
-                "-n",
-                "4096",
-                "-r",
-                "### Human:",
-                "-p",
-                prompt,
-                *extra_args,
-            ],
-            stdin=subprocess.PIPE,
-            stdout=subprocess.PIPE,
-            stderr=subprocess.PIPE,
-            shell=False,
-            text=True,
+            extra_args["top_p"] = args["top_p"]
+
+        gen = llm.create_completion(
+            prompt,
+            max_tokens=1024,
+            stop=END_SEQ,
+            stream=stream,
+            echo=False,
+            **extra_args,
         )
-
         if stream:
-            return self._read_stream(process, prompt)
+            for x in gen:
+                yield x["choices"][0]["text"]
         else:
-            return self._read(process, prompt)
-
-    def _read_stream(self, process: subprocess.Popen, prompt: str) -> Iterator[str]:
-        assert process.stdout, "LLaMA stdout not set"
-        assert process.stderr, "LLaMA stderr not set"
-
-        buffer = ""
-        num_read = 0
-        char = process.stdout.read(1)
-
-        try:
-            while char := process.stdout.read(1):
-                num_read += len(char)
-                if num_read <= len(prompt):
-                    continue
-
-                buffer += char
-                if not buffer.startswith("#") or (buffer != END_SEQ[: len(buffer)]):
-                    yield buffer
-                    buffer = ""
-                elif buffer.endswith(END_SEQ):
-                    yield buffer[: -len(END_SEQ)]
-                    buffer = ""
-                    process.terminate()
-                    break
-        except KeyboardInterrupt:
-            os.kill(process.pid, signal.SIGINT)
-            raise
-        finally:
-            process.wait()
-            stderr = "".join(process.stderr.readlines())
-            logging.debug(f"LLaMA stderr: {stderr}")
-
-    def _read(self, process: subprocess.Popen, prompt: str) -> Iterator[str]:
-        result = ""
-        for token in self._read_stream(process, prompt):
-            result += token
-        yield result
+            yield gen["choices"][0]["text"]
+
+
+# https://stackoverflow.com/a/50438156
+class suppress_stderr(object):
+    def __enter__(self):
+        self.errnull_file = open(os.devnull, "w")
+        self.old_stderr_fileno_undup = sys.stderr.fileno()
+        self.old_stderr_fileno = os.dup(sys.stderr.fileno())
+        self.old_stderr = sys.stderr
+        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
+        sys.stderr = self.errnull_file
+        return self
+
+    def __exit__(self, *_):
+        sys.stderr = self.old_stderr
+        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
+        os.close(self.old_stderr_fileno)
+        self.errnull_file.close()

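For reference, a minimal standalone sketch of the llama-cpp-python API the new provider builds on (model path and prompt are hypothetical, and the chunk shape is as documented for the 0.1.x releases pinned in requirements.txt):

    from llama_cpp import Llama

    # verbose=False quiets the Python-side logging; llama.cpp itself still
    # writes to the C-level stderr, which is why the provider wraps model
    # loading in suppress_stderr() above.
    llm = Llama(
        model_path="/models/llama-7b/ggml-model-q4_0.bin",
        n_ctx=2048,
        verbose=False,
    )

    # With stream=True, create_completion yields OpenAI-style chunks:
    # {"choices": [{"text": ...}], ...}
    for chunk in llm.create_completion(
        "### Human: Hello\n### Assistant:", max_tokens=64, stream=True
    ):
        print(chunk["choices"][0]["text"], end="", flush=True)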
requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 anthropic==0.2.8
 black==23.1.0
 google-generativeai==0.1.0rc2
+llama-cpp-python==0.1.57
 openai==0.27.2
 prompt-toolkit==3.0.38
 pytest==7.3.1
