Skip to content

Commit

Permalink
fastchat interface nearly finished, waiting for implementations of other parts
Browse files Browse the repository at this point in the history
  • Loading branch information
zenglingqi647 committed Dec 12, 2023
1 parent d544533 commit cb61b24
Showing 1 changed file with 5 additions and 28 deletions.
33 changes: 5 additions & 28 deletions FastChat/fastchat/serve/fastchat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ def prompt_for_input(self, prompt: str) -> str:
except EOFError as e:
break
return "\n".join(prompt_data)

def prompt_for_output(self, role: str):
    """Announce which *role* is about to produce output.

    Intentionally a no-op in this interface: output is streamed and
    printed elsewhere, so no role banner is emitted here.
    """
    return None

def stream_output(self, output_stream):
pre = 0
for outputs in output_stream:
Expand All @@ -84,19 +84,6 @@ def stream_output(self, output_stream):
if now > pre:
pre = now
return " ".join(output_text)

# def stream_output(self, output_stream):
# pre = 0
# txt = ""
# for outputs in output_stream:
# output_text = outputs["text"]
# output_text = output_text.strip().split(" ")
# now = len(output_text) - 1
# if now > pre:
# txt += " ".join(output_text[pre:now])
# pre = now
# txt += " ".join(output_text[pre:])
# return txt

def print_output(self, text: str):
    """Echo *text* verbatim to standard output (trailing newline added by print)."""
    print(text)
Expand All @@ -107,8 +94,6 @@ def chat_loop(
tokenizer,
model_path: str,
device: str,
conv_template: Optional[str],
conv_system_msg: Optional[str],
temperature: float,
repetition_penalty: float,
max_new_tokens: int,
Expand All @@ -132,12 +117,7 @@ def chat_loop(

# Chat
def new_chat():
    """Start a fresh conversation from the vicuna_v1.1 template.

    Returns:
        A new Conversation object obtained from the template registry.
    """
    # BUG FIX: the template name was written as '"vicuna_v1.1"' — a string
    # whose first and last characters are literal double quotes. The template
    # registry is keyed by the bare name, so the extra quotes would raise a
    # KeyError. Pass the unquoted key instead.
    conv = get_conv_template("vicuna_v1.1")
    return conv

conv = None
Expand Down Expand Up @@ -178,6 +158,7 @@ def new_chat():
)
outputs = chatio.stream_output(output_stream)
conv.update_last_message(outputs.strip())
# conv.messages[-1][1] is the return

except KeyboardInterrupt:
print("stopped generation.")
Expand Down Expand Up @@ -247,8 +228,6 @@ def query(args, model, tokenizer, prompt, chatio):
tokenizer,
args.model_path,
args.device,
args.conv_template,
args.conv_system_msg,
args.temperature,
args.repetition_penalty,
args.max_new_tokens,
Expand All @@ -262,8 +241,6 @@ def query(args, model, tokenizer, prompt, chatio):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
add_model_args(parser)
parser.add_argument("--conv-template", type=str, default=None, help="Conversation prompt template.")
parser.add_argument("--conv-system-msg", type=str, default=None, help="Conversation system message.")
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--repetition_penalty", type=float, default=1.0)
parser.add_argument("--max-new-tokens", type=int, default=512)
Expand All @@ -289,4 +266,4 @@ def query(args, model, tokenizer, prompt, chatio):
prompt = "hello world"
chatio = SimpleChatIO(args.multiline)
model, tokenizer = setup(args)
query(args, model, tokenizer, prompt, chatio)
query(args, model, tokenizer, prompt, chatio)

0 comments on commit cb61b24

Please sign in to comment.