diff --git a/pyproject.toml b/pyproject.toml
index d99b24f..81a8c27 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ classifiers = [
     "Programming Language :: Python :: Implementation :: PyPy",
 ]
 dependencies = [
-    "langchain==0.0.113",
+    "langchain==0.0.121",
     "openai",
     "tiktoken",
     "pinecone-client",
diff --git a/summ/cli/cli.py b/summ/cli/cli.py
index 8ec0b14..70be30b 100644
--- a/summ/cli/cli.py
+++ b/summ/cli/cli.py
@@ -19,6 +19,7 @@
 class Options(BaseModel):
     debug: bool
     verbose: bool
+    model_name: str
 
 
 class CLI:
@@ -71,9 +72,10 @@ def handler(signum, frame):
         @click.option("--debug/--no-debug", default=True)
         @click.option("--verbose/--no-verbose", default=False)
         @click.option("-n", default=3)
+        @click.option("--model-name", default="gpt-3.5-turbo")
         @click.pass_context
-        def cli(ctx, debug: bool, verbose: bool, n: int):
-            ctx.obj = Options(debug=debug, verbose=verbose)
+        def cli(ctx, debug: bool, verbose: bool, model_name: str, n: int):
+            ctx.obj = Options(debug=debug, verbose=verbose, model_name=model_name)
             langchain.verbose = verbose
             summ.n = n
 
@@ -111,6 +113,7 @@ def query(ctx: click.Context, query: str, classes: list[Classes]):
                 classes=classes,
                 corpus=list(pipe.corpus()),
                 debug=ctx.obj.debug,
+                model_name=ctx.obj.model_name
             )
             click.echo("\n")
             click.secho(response)
diff --git a/summ/cli/widgets/home.py b/summ/cli/widgets/home.py
index 7066cc8..6a475c3 100644
--- a/summ/cli/widgets/home.py
+++ b/summ/cli/widgets/home.py
@@ -44,6 +44,7 @@ def watch_in_progress(self) -> None:
 class Home(Static):
     in_progress = reactive(False, init=False)
     question: reactive[str] = reactive("", init=False)
+    model_name: reactive[str] = reactive("", init=False)
 
     def __init__(self, summ: Summ, pipe: Pipeline, **kwargs) -> None:
         self.summ = summ
@@ -69,6 +70,12 @@ def compose(self) -> ComposeResult:
                 id="question",
                 placeholder="What type of animal is Cronutt?",
             ),
+            InputWithLabel(
+                name="Model Name",
+                id="model_name",
+                placeholder="gpt-3.5-turbo",
+                value="gpt-3.5-turbo",
+            ),
             Container(
                 Button("Query", variant="success", id="query", disabled=True),
                 Button("Populate", variant="warning", id="populate"),
@@ -84,7 +91,7 @@ def compose(self) -> ComposeResult:
 
     def action_query(self):
         return self.summ.query(
-            self.question, classes=[], corpus=list(self.pipe.corpus()), debug=True
+            self.question, classes=[], corpus=list(self.pipe.corpus()), debug=True, model_name=self.model_name
         )
 
     def action_populate(self):
@@ -108,11 +115,14 @@ async def on_button_pressed(self, event: Button.Pressed):
         self.in_progress = False
 
     def on_input_changed(self, event: Input.Changed):
-        self.question = event.value
-        self.query_one(OutputTree).question = event.value
-        if event.value:
-            self.query_one("#query", Button).disabled = False
-            self.query_one("#populate", Button).disabled = True
-        else:
-            self.query_one("#query", Button).disabled = True
-            self.query_one("#populate", Button).disabled = False
+        if event.sender.id == "question":
+            self.question = event.value
+            self.query_one(OutputTree).question = event.value
+            if event.value:
+                self.query_one("#query", Button).disabled = False
+                self.query_one("#populate", Button).disabled = True
+            else:
+                self.query_one("#query", Button).disabled = True
+                self.query_one("#populate", Button).disabled = False
+        elif event.sender.id == "model_name":
+            self.model_name = event.value
diff --git a/summ/query/querier.py b/summ/query/querier.py
index 223d275..8364147 100644
--- a/summ/query/querier.py
+++ b/summ/query/querier.py
@@ -63,8 +63,8 @@ class Querier(Chain):
         ),
     )
 
-    def __init__(self, index: str, debug: bool = False):
-        super().__init__(debug=debug)
+    def __init__(self, index: str, debug: bool = False, model_name: str = "gpt-3.5-turbo"):
+        super().__init__(debug=debug, model_name=model_name)
         self.index_name = index
         self.embeddings = OpenAIEmbeddings()
         self.summarizer = Summarizer()
diff --git a/summ/shared/chain.py b/summ/shared/chain.py
index 9535d32..865e626 100644
--- a/summ/shared/chain.py
+++ b/summ/shared/chain.py
@@ -31,7 +31,7 @@
 from langchain.chains import TransformChain
 from langchain.chains.base import Chain as LChain
 from langchain.docstore.document import Document
-from langchain.llms import OpenAI
+from langchain.chat_models import ChatOpenAI
 from openai.error import RateLimitError
 from pydantic import BaseModel
 from retry import retry
@@ -255,11 +255,12 @@ def increment_n_tokens(cls, n: int):
     def tokens_used(cls) -> int:
         return cls._n_tokens
 
-    def __init__(self, debug: bool = False, verbose: bool = False):
-        self.llm = OpenAI(temperature=0.0)
+    def __init__(self, debug: bool = False, verbose: bool = False, model_name: str = "gpt-3.5-turbo"):
+        self.llm = ChatOpenAI(temperature=0.0, model_name=model_name)
         self.pool = Parallel(n_jobs=-1, prefer="threads", verbose=10 if verbose else 0)
         self.verbose = verbose
         self.debug = debug
+        self.model_name = model_name
 
     def spawn(self, cls: Type[T], **kwargs) -> T:
         instance = cls(debug=self.debug, verbose=self.verbose, **kwargs)
diff --git a/summ/summ.py b/summ/summ.py
index 59ee187..ba61493 100644
--- a/summ/summ.py
+++ b/summ/summ.py
@@ -56,6 +56,7 @@ def query(
         classes: list[Classes] = [],
         corpus: list[Document] = [],
         debug: bool = True,
+        model_name: str = "gpt-3.5-turbo",
     ) -> str:
         """
         Query a pre-populated model with a given question.
@@ -70,5 +71,5 @@ def query(
             raise Exception(
                 f"Index {self.index} not found! Please run `summ populate` first."
             )
-        querier = Querier(index=self.index, debug=debug)
+        querier = Querier(index=self.index, debug=debug, model_name=model_name)
         return querier.query(question, n=self.n, classes=classes, corpus=corpus)
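
With this change, a caller-chosen chat model flows from the CLI and TUI through Summ.query and Querier into ChatOpenAI, defaulting to "gpt-3.5-turbo" at every layer. A minimal sketch of the new surface (the index name below is hypothetical, and Summ(index=...) is assumed to match the existing constructor):

    from summ import Summ

    # Hypothetical index name; `summ populate` must have been run against it first.
    s = Summ(index="interviews")

    # model_name is the keyword added by this diff; it defaults to
    # "gpt-3.5-turbo" and is threaded through Querier to ChatOpenAI.
    answer = s.query("What type of animal is Cronutt?", model_name="gpt-4")
    print(answer)

From the command line the equivalent would look something like `summ --model-name gpt-4 query "What type of animal is Cronutt?"`; note that --model-name is declared on the cli group, so it precedes the subcommand.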