
Commit dddeaa7

Support custom prompts for LLaMA models
1 parent cd7e106 commit dddeaa7


2 files changed, +35 -24 lines changed


gptcli/config.py

Lines changed: 3 additions & 1 deletion
@@ -4,6 +4,8 @@
 import yaml
 
 from gptcli.assistant import AssistantConfig
+from gptcli.llama import LLaMAModelConfig
+
 
 CONFIG_FILE_PATHS = [
     os.path.join(os.path.expanduser("~"), ".config", "gpt-cli", "gpt.yml"),
@@ -24,7 +26,7 @@ class GptCliConfig:
     log_level: str = "INFO"
     assistants: Dict[str, AssistantConfig] = {}
     interactive: Optional[bool] = None
-    llama_models: Optional[Dict[str, str]] = None
+    llama_models: Optional[Dict[str, LLaMAModelConfig]] = None
 
 
 def choose_config_file(paths: List[str]) -> str:
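
With this change, each entry in llama_models is a full LLaMAModelConfig rather than a bare path string. A minimal sketch of what callers might now pass to init_llama_models (the model name, file path, and prompt strings below are made up for illustration; only the keys come from the commit):

from gptcli.llama import LLaMAModelConfig, init_llama_models

# Hypothetical entry; the keys mirror the new LLaMAModelConfig TypedDict.
llama_models: dict[str, LLaMAModelConfig] = {
    "llama-vicuna-13b": {  # names must start with "llama" (checked in init_llama_models)
        "path": "/models/vicuna-13b.ggml.bin",  # illustrative local path
        "human_prompt": "### Human:",           # matches the old hard-coded defaults
        "assistant_prompt": "### Assistant:",
    },
}
init_llama_models(llama_models)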

gptcli/llama.py

Lines changed: 32 additions & 23 deletions
@@ -1,61 +1,70 @@
 import os
-from pathlib import Path
 import sys
-from typing import Iterator, List, Optional
-from llama_cpp import Llama
+from typing import Iterator, List, Optional, TypedDict, cast
+from llama_cpp import Completion, CompletionChunk, Llama
 
 from gptcli.completion import CompletionProvider, Message
 
-LLAMA_MODELS: Optional[dict[str, str]] = None
 
+class LLaMAModelConfig(TypedDict):
+    path: str
+    human_prompt: str
+    assistant_prompt: str
 
-def init_llama_models(model_paths: dict[str, str]):
-    for name, path in model_paths.items():
-        if not os.path.isfile(path):
-            print(f"LLaMA model {name} not found at {path}.")
+
+LLAMA_MODELS: Optional[dict[str, LLaMAModelConfig]] = None
+
+
+def init_llama_models(models: dict[str, LLaMAModelConfig]):
+    for name, model_config in models.items():
+        if not os.path.isfile(model_config["path"]):
+            print(f"LLaMA model {name} not found at {model_config['path']}.")
             sys.exit(1)
         if not name.startswith("llama"):
            print(f"LLaMA model names must start with `llama`, but got `{name}`.")
            sys.exit(1)
 
     global LLAMA_MODELS
-    LLAMA_MODELS = model_paths
+    LLAMA_MODELS = models
 
 
-def role_to_name(role: str) -> str:
+def role_to_name(role: str, model_config: LLaMAModelConfig) -> str:
     if role == "system" or role == "user":
-        return "### Human: "
+        return model_config["human_prompt"]
     elif role == "assistant":
-        return "### Assistant: "
+        return model_config["assistant_prompt"]
     else:
         raise ValueError(f"Unknown role: {role}")
 
 
-def make_prompt(messages: List[Message]) -> str:
+def make_prompt(messages: List[Message], model_config: LLaMAModelConfig) -> str:
     prompt = "\n".join(
-        [f"{role_to_name(message['role'])}{message['content']}" for message in messages]
+        [
+            f"{role_to_name(message['role'], model_config)} {message['content']}"
+            for message in messages
+        ]
     )
-    prompt += "### Assistant:"
+    prompt += f"\n{model_config['assistant_prompt']}"
     return prompt
 
 
-END_SEQ = "### Human:"
-
-
 class LLaMACompletionProvider(CompletionProvider):
     def complete(
         self, messages: List[Message], args: dict, stream: bool = False
     ) -> Iterator[str]:
         assert LLAMA_MODELS, "LLaMA models not initialized"
 
+        model_config = LLAMA_MODELS[args["model"]]
+
         with suppress_stderr():
             llm = Llama(
-                model_path=LLAMA_MODELS[args["model"]],
+                model_path=model_config["path"],
                 n_ctx=2048,
                 verbose=False,
                 use_mlock=True,
             )
-        prompt = make_prompt(messages)
+        prompt = make_prompt(messages, model_config)
+        print(prompt)
 
         extra_args = {}
         if "temperature" in args:
@@ -66,16 +75,16 @@ def complete(
         gen = llm.create_completion(
             prompt,
             max_tokens=1024,
-            stop=END_SEQ,
+            stop=model_config["human_prompt"],
             stream=stream,
             echo=False,
             **extra_args,
         )
         if stream:
-            for x in gen:
+            for x in cast(Iterator[CompletionChunk], gen):
                 yield x["choices"][0]["text"]
         else:
-            yield gen["choices"][0]["text"]
+            yield cast(Completion, gen)["choices"][0]["text"]
 
 
 # https://stackoverflow.com/a/50438156
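
To see the effect of the configurable prompts end to end, here is a small usage sketch (the path and prompt strings are invented; the join and termination behaviour follows make_prompt as changed above):

from gptcli.llama import LLaMAModelConfig, make_prompt

# Hypothetical config with custom prompts; "path" is not used by make_prompt.
config: LLaMAModelConfig = {
    "path": "/models/example.ggml.bin",
    "human_prompt": "USER:",
    "assistant_prompt": "ASSISTANT:",
}
messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there!"},
    {"role": "user", "content": "What can you do?"},
]
print(make_prompt(messages, config))
# USER: Hello
# ASSISTANT: Hi there!
# USER: What can you do?
# ASSISTANT: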
