Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
"langchain==0.0.113",
"langchain==0.0.121",
"openai",
"tiktoken",
"pinecone-client",
Expand Down
7 changes: 5 additions & 2 deletions summ/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
class Options(BaseModel):
    """CLI-wide options carried on the click context (``ctx.obj``).

    Populated once in the top-level ``cli`` command and read by
    subcommands (e.g. ``query`` passes ``debug`` and ``model_name``
    through to ``Summ.query``).
    """

    # Enable debug output in the underlying chains.
    debug: bool
    # Enable verbose logging (also mirrored onto ``langchain.verbose``).
    verbose: bool
    # OpenAI chat model identifier forwarded to ChatOpenAI
    # (default set by the --model-name click option, "gpt-3.5-turbo").
    model_name: str


class CLI:
Expand Down Expand Up @@ -71,9 +72,10 @@ def handler(signum, frame):
@click.option("--debug/--no-debug", default=True)
@click.option("--verbose/--no-verbose", default=False)
@click.option("-n", default=3)
@click.option("--model-name", default="gpt-3.5-turbo")
@click.pass_context
def cli(ctx, debug: bool, verbose: bool, n: int):
ctx.obj = Options(debug=debug, verbose=verbose)
def cli(ctx, debug: bool, verbose: bool, model_name: str, n: int):
ctx.obj = Options(debug=debug, verbose=verbose, model_name=model_name)
langchain.verbose = verbose
summ.n = n

Expand Down Expand Up @@ -111,6 +113,7 @@ def query(ctx: click.Context, query: str, classes: list[Classes]):
classes=classes,
corpus=list(pipe.corpus()),
debug=ctx.obj.debug,
model_name=ctx.obj.model_name
)
click.echo("\n")
click.secho(response)
Expand Down
28 changes: 19 additions & 9 deletions summ/cli/widgets/home.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def watch_in_progress(self) -> None:
class Home(Static):
in_progress = reactive(False, init=False)
question: reactive[str] = reactive("", init=False)
model_name: reactive[str] = reactive("", init=False)

def __init__(self, summ: Summ, pipe: Pipeline, **kwargs) -> None:
self.summ = summ
Expand All @@ -69,6 +70,12 @@ def compose(self) -> ComposeResult:
id="question",
placeholder="What type of animal is Cronutt?",
),
InputWithLabel(
name="Model Name",
id="model_name",
placeholder="gpt-3.5-turbo",
value="gpt-3.5-turbo",
),
Container(
Button("Query", variant="success", id="query", disabled=True),
Button("Populate", variant="warning", id="populate"),
Expand All @@ -84,7 +91,7 @@ def compose(self) -> ComposeResult:

def action_query(self):
return self.summ.query(
self.question, classes=[], corpus=list(self.pipe.corpus()), debug=True
self.question, classes=[], corpus=list(self.pipe.corpus()), debug=True, model_name=self.model_name
)

def action_populate(self):
Expand All @@ -108,11 +115,14 @@ async def on_button_pressed(self, event: Button.Pressed):
self.in_progress = False

def on_input_changed(self, event: Input.Changed):
self.question = event.value
self.query_one(OutputTree).question = event.value
if event.value:
self.query_one("#query", Button).disabled = False
self.query_one("#populate", Button).disabled = True
else:
self.query_one("#query", Button).disabled = True
self.query_one("#populate", Button).disabled = False
if event.sender.id == "question":
self.question = event.value
self.query_one(OutputTree).question = event.value
if event.value:
self.query_one("#query", Button).disabled = False
self.query_one("#populate", Button).disabled = True
else:
self.query_one("#query", Button).disabled = True
self.query_one("#populate", Button).disabled = False
elif event.sender.id == "model_name":
self.model_name = event.value
4 changes: 2 additions & 2 deletions summ/query/querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ class Querier(Chain):
),
)

def __init__(self, index: str, debug: bool = False):
super().__init__(debug=debug)
def __init__(self, index: str, debug: bool = False, model_name: str = "gpt-3.5-turbo"):
super().__init__(debug=debug, model_name=model_name)
self.index_name = index
self.embeddings = OpenAIEmbeddings()
self.summarizer = Summarizer()
Expand Down
7 changes: 4 additions & 3 deletions summ/shared/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from langchain.chains import TransformChain
from langchain.chains.base import Chain as LChain
from langchain.docstore.document import Document
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from openai.error import RateLimitError
from pydantic import BaseModel
from retry import retry
Expand Down Expand Up @@ -255,11 +255,12 @@ def increment_n_tokens(cls, n: int):
def tokens_used(cls) -> int:
return cls._n_tokens

def __init__(self, debug: bool = False, verbose: bool = False):
self.llm = OpenAI(temperature=0.0)
def __init__(self, debug: bool = False, verbose: bool = False, model_name: str = "gpt-3.5-turbo"):
self.llm = ChatOpenAI(temperature=0.0, model_name=model_name)
self.pool = Parallel(n_jobs=-1, prefer="threads", verbose=10 if verbose else 0)
self.verbose = verbose
self.debug = debug
self.model_name = model_name

def spawn(self, cls: Type[T], **kwargs) -> T:
instance = cls(debug=self.debug, verbose=self.verbose, **kwargs)
Expand Down
3 changes: 2 additions & 1 deletion summ/summ.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def query(
classes: list[Classes] = [],
corpus: list[Document] = [],
debug: bool = True,
model_name: str = "gpt-3.5-turbo",
) -> str:
"""
Query a pre-populated model with a given question.
Expand All @@ -70,5 +71,5 @@ def query(
raise Exception(
f"Index {self.index} not found! Please run `summ populate` first."
)
querier = Querier(index=self.index, debug=debug)
querier = Querier(index=self.index, debug=debug, model_name=model_name)
return querier.query(question, n=self.n, classes=classes, corpus=corpus)