activate/deactivate scorer and some other fixes #27

Open · wants to merge 2 commits into main
2 changes: 1 addition & 1 deletion prompt_selection.ipynb
@@ -334,7 +334,7 @@
" suffix = rl_chain.ToSelectFrom(['0']))\n",
"\n",
"vw_chain.metrics.to_pandas()['score'].plot(label=\"vw\")\n",
"rnd_chain.metrics.to_pandas()['score'].plot(label=\"slates\")\n",
"rnd_chain.metrics.to_pandas()['score'].plot(label=\"random\")\n",
"plt.legend()"
]
}
68 changes: 53 additions & 15 deletions rl_chain.ipynb
@@ -30,10 +30,10 @@
" MealPlanner(name=\"One-Pan Beef Enchiladas Verdes with Mexican Cheese Blend & Hot Sauce Crema\", difficulty=\"Easy\", tags=\"Spicy, Easy Cleanup, Easy Prep\", desc=\"When it comes to Mexican-style cuisine, burritos typically get all the glory. In our humble opinion, enchiladas are an unsung dinner hero. They’re technically easier-to-assemble burritos that get smothered in a delicious sauce, but they’re really so much more than that! Ours start with spiced beef and charred green pepper that get rolled up in warm tortillas. This winning combo gets topped with tangy salsa verde and cheese, then baked until bubbly and melty. Hear that? That’s the sound of the dinner bell!\"),\n",
" MealPlanner(name=\"Chicken & Mushroom Flatbreads with Gouda Cream Sauce & Parmesan\", difficulty=\"Easy\", tags=\"\", desc=\"Yes we love our simple cheese pizza with red sauce but tonight, move over, marinara—there’s a new sauce in town. In this recipe, crispy flatbreads are slathered with a rich, creamy gouda-mustard sauce we just can’t get enough of. We top that off with a pile of caramelized onion and earthy cremini mushrooms. Shower with Parmesan, and that’s it. Simple, satisfying, and all in 30 minutes–a dinner idea you can’t pass up!\"),\n",
" MealPlanner(name=\"Sweet Potato & Pepper Quesadillas with Southwest Crema & Tomato Salsa\", difficulty=\"Easy\", tags=\"Veggie\", desc=\"This quesadilla is jam-packed with flavorful roasted sweet potato and green pepper, plus two types of gooey, melty cheese (how could we choose just one?!). Of course, we’d never forget the toppings—there’s a fresh tomato salsa and dollops of spiced lime crema. Now for the fun part: piling on a little bit of everything to construct the perfect bite!\"),\n",
" MealPlanner(name=\"One-Pan Trattoria Tortelloni Bake with a Crispy Parmesan Panko Topping\", difficulty=\"Easy\", tags=\"Veggie, Easy Cleanup, Easy Prep\", desc=\"Think a cheesy stuffed pasta can’t get any better? What about baking it in a creamy sauce with a crispy topping? In this recipe, we toss cheese-stuffed tortelloni in an herby tomato cream sauce, then top with Parmesan and panko breadcrumbs. Once broiled, it turns into a showstopping topping that’ll earn you plenty of oohs and aahs from your lucky fellow diners.\"),\n",
" MealPlanner(name=\"One-Pan Trattoria Tortelloni Bake with a Crispy vegan cheese Panko Topping\", difficulty=\"Easy\", tags=\"Veggie, Easy Cleanup, Easy Prep\", desc=\"Think a cheesy stuffed pasta can’t get any better? What about baking it in a creamy sauce with a crispy topping? In this recipe, we toss cheese-stuffed tortelloni in an herby tomato cream sauce, then top with vegan cheese and panko breadcrumbs. Once broiled, it turns into a showstopping topping that’ll earn you plenty of oohs and aahs from your lucky fellow diners.\"),\n",
"]\n",
"\n",
"meals = [f'title={action.name.replace(\":\", \"\").replace(\"|\", \"\")}' for action in actions]"
"meals = [[f'{action.name.replace(\":\", \"\").replace(\"|\", \"\")}', f' {action.tags}'] for action in actions]"
]
},
{
@@ -89,7 +89,8 @@
" input_variables=[\"meal\", \"text_to_personalize\"], template=_PROMPT_TEMPLATE\n",
")\n",
"\n",
"chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)\n"
"chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT, metrics_step=1)\n",
"random_chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT, metrics_step=1, policy=rl_chain.PickBestRandomPolicy)\n"
]
},
{
@@ -98,16 +99,54 @@
"metadata": {},
"outputs": [],
"source": [
"response = chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" User = rl_chain.BasedOn(\"Tom Hanks\"),\n",
" preference = rl_chain.BasedOn(\"Vegetarian, regular dairy is ok\"),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
")\n",
"\n",
"print(response[\"response\"])\n",
"rr = response[\"selection_metadata\"]\n",
"print(f\"score: {rr.selected.score}, selection index: {rr.selected.index}, probability: {rr.selected.probability}, \")"
"for i in range(2):\n",
" try:\n",
" if i % 2:\n",
" print(\"Tom\")\n",
" response = chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" User = rl_chain.BasedOn(\"Tom\"),\n",
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" random_chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" User = rl_chain.BasedOn(\"Tom\"),\n",
" preference = rl_chain.BasedOn([\"Vegetarian\", \"regular dairy is ok\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" else:\n",
" print(\"Anna\")\n",
" response = chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" User = rl_chain.BasedOn(\"Anna\"),\n",
" preference = rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
" random_chain.run(\n",
" meal = rl_chain.ToSelectFrom(meals),\n",
" User = rl_chain.BasedOn(\"Anna\"),\n",
" preference = rl_chain.BasedOn([\"Loves meat\", \"especially beef\"]),\n",
" text_to_personalize = \"This is the weeks specialty dish, our master chefs believe you will love it!\",\n",
" )\n",
"\n",
" print(response[\"response\"])\n",
" rr = response[\"selection_metadata\"]\n",
" print(f\"score: {rr.selected.score}, selection index: {rr.selected.index}, probability: {rr.selected.probability}, \")\n",
" except Exception as e:\n",
" print(e)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from matplotlib import pyplot as plt\n",
+ "chain.metrics.to_pandas()['score'].plot(label=\"vw\")\n",
+ "random_chain.metrics.to_pandas()['score'].plot(label=\"random\")\n",
+ "plt.legend()"
+ ]
+ },
{
Expand Down Expand Up @@ -330,8 +369,7 @@
"class CustomSelectionScorer(rl_chain.SelectionScorer):\n",
" #grade or score the response\n",
" def score_response(\n",
" self, inputs, llm_response: str\n",
" ) -> float:\n",
" self, inputs, llm_response: str, event: rl_chain.PickBest.Event) -> float:\n",
" # do whatever you want here, use whatever inputs you supplied and return reward\n",
" reward = 1.0\n",
" return reward\n",
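The signature change above means custom scorers now receive the full selection `event` alongside the raw LLM output. A minimal sketch of a scorer written against the new signature — the reward heuristic is invented for illustration, and it is assumed the chain accepts the scorer via `selection_scorer=`:

```python
import rl_chain

class TagAwareScorer(rl_chain.SelectionScorer):
    # Illustrative heuristic only: reward responses that mention "Veggie".
    def score_response(
        self, inputs, llm_response: str, event: rl_chain.PickBest.Event
    ) -> float:
        del event  # now available for action-aware scoring; unused in this sketch
        return 1.0 if "Veggie" in llm_response else 0.0
```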
2 changes: 1 addition & 1 deletion rl_chain/__init__.py
@@ -1,4 +1,4 @@
- from .pick_best_chain import PickBest
+ from .pick_best_chain import PickBest, PickBestRandomPolicy
from .slates_chain import (
SlatesPersonalizerChain,
SlatesRandomPolicy,
18 changes: 17 additions & 1 deletion rl_chain/pick_best_chain.py
@@ -7,6 +7,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
+ import random
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from sentence_transformers import SentenceTransformer
@@ -83,6 +84,21 @@ def format(self, event: PickBest.Event) -> str:
return example_string[:-1]


+ class PickBestRandomPolicy(base.Policy):
+ def __init__(self, feature_embedder: base.Embedder, *_, **__):
+ self.feature_embedder = feature_embedder

+ def predict(self, event: PickBest.Event) -> List[Tuple[int, float]]:
+ num_items = len(event.to_select_from)
+ return [(i, 1.0 / num_items) for i in range(num_items)]

+ def learn(self, event: PickBest.Event) -> Any:
+ pass

+ def log(self, event: PickBest.Event) -> Any:
+ pass


class PickBest(base.RLChain):
"""
PickBest class that utilizes the Vowpal Wabbit (VW) model for personalization.
@@ -153,7 +169,7 @@ def __init__(
"--quiet",
"--interactions=::",
"--coin",
"--epsilon=0.2",
"--squarecb",
]
else:
if "--cb_explore_adf" not in vw_cmd:
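With these changes the default VW command line explores with SquareCB (`--squarecb`) rather than epsilon-greedy (`--epsilon=0.2`), and the newly exported `PickBestRandomPolicy` gives a non-learning, uniform-random baseline to compare against. A sketch of the pairing the notebook uses; `llm` and `PROMPT` are assumed to be defined as in rl_chain.ipynb:

```python
import rl_chain

# Learned policy: picks up the new --squarecb exploration default.
vw_chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT, metrics_step=1)

# Baseline: predict() returns [(i, 1/n), ...] over the n candidate actions,
# and learn()/log() are no-ops, so its score curve should stay flat.
random_chain = rl_chain.PickBest.from_llm(
    llm=llm, prompt=PROMPT, metrics_step=1, policy=rl_chain.PickBestRandomPolicy
)
```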
111 changes: 62 additions & 49 deletions rl_chain/rl_chain_base.py
@@ -204,7 +204,9 @@ class SelectionScorer(ABC, BaseModel):
"""Abstract method to grade the chosen selection or the response of the llm"""

@abstractmethod
- def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
+ def score_response(
+ self, inputs: Dict[str, Any], llm_response: str, event: Event
+ ) -> float:
pass


@@ -222,7 +224,7 @@ def get_default_system_prompt() -> SystemMessagePromptTemplate:

@staticmethod
def get_default_prompt() -> ChatPromptTemplate:
- human_template = 'Given this based_on "{rl_chain_selected_based_on}" as the most important attribute, rank how good or bad this text is: "{llm_response}".'
+ human_template = 'Given this based_on "{rl_chain_selected_based_on}" as the most important attribute, rank how good or bad this text is: "{rl_chain_selected}".'
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
chat_prompt = ChatPromptTemplate.from_messages(
@@ -249,7 +251,9 @@ def set_prompt_and_llm_chain(cls, values):
values["llm_chain"] = LLMChain(llm=llm, prompt=prompt)
return values

- def score_response(self, inputs: Dict[str, Any], llm_response: str) -> float:
+ def score_response(
+ self, inputs: Dict[str, Any], llm_response: str, event: Event
+ ) -> float:
ranking = self.llm_chain.predict(llm_response=llm_response, **inputs)
ranking = ranking.strip()
try:
@@ -283,10 +287,11 @@ class RLChain(Chain):
prompt: BasePromptTemplate
selection_scorer: Union[SelectionScorer, None]
policy: Optional[Policy]
- auto_embed: bool = True
+ auto_embed: bool = False
+ selection_scorer_activated: bool = True
- metrics: Optional[MetricsTracker] = None
selected_input_key = "rl_chain_selected"
selected_based_on_input_key = "rl_chain_selected_based_on"
+ metrics: Optional[MetricsTracker] = None

def __init__(
self,
@@ -336,6 +341,42 @@ def output_keys(self) -> List[str]:
"""
return [self.output_key]

+ def update_with_delayed_score(
+ self, score: float, event: Event, force_score=False
+ ) -> None:
+ """
+ Learn will be called with the score specified and the actions/embeddings/etc stored in event

+ Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
+ """
+ if self._can_use_selection_scorer() and not force_score:
+ raise RuntimeError(
+ "The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function."
+ )
+ self.metrics.on_feedback(score)
+ self._call_after_scoring_before_learning(event=event, score=score)
+ self.policy.learn(event=event)
+ self.policy.log(event=event)
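A sketch of the delayed-feedback flow this enables, expressed in the notebook's terms; the score of 1.0 stands in for whatever real feedback eventually arrives:

```python
# Run now, keep the event, score later.
response = chain.run(
    meal=rl_chain.ToSelectFrom(meals),
    User=rl_chain.BasedOn("Tom"),
    text_to_personalize="This is the weeks specialty dish!",
)
event = response["selection_metadata"]

# force_score=True is required when the chain also has an active
# selection scorer, per the guard above.
chain.update_with_delayed_score(score=1.0, event=event, force_score=True)
```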

+ def deactivate_selection_scorer(self) -> None:
+ """
+ Deactivates the selection scorer, meaning that the chain will no longer attempt to use the selection scorer to score responses.
+ """
+ self.selection_scorer_activated = False

+ def activate_selection_scorer(self) -> None:
+ """
+ Activates the selection scorer, meaning that the chain will attempt to use the selection scorer to score responses.
+ """
+ self.selection_scorer_activated = True
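These toggles pair naturally with `update_with_delayed_score`: switch automatic scoring off while real feedback is being gathered, then switch it back on. A minimal sketch:

```python
chain.deactivate_selection_scorer()  # run() now leaves the score unset
response = chain.run(
    meal=rl_chain.ToSelectFrom(meals),
    User=rl_chain.BasedOn("Anna"),
    text_to_personalize="This is the weeks specialty dish!",
)
# With the scorer deactivated, no force_score override is needed here.
chain.update_with_delayed_score(score=0.0, event=response["selection_metadata"])
chain.activate_selection_scorer()  # resume automatic scoring
```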

+ def save_progress(self) -> None:
+ """
+ This function should be called whenever there is a need to save the progress of the VW (Vowpal Wabbit) model within the chain. It saves the current state of the VW model to a file.

+ """
+ self.policy.save()

def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
super()._validate_inputs(inputs)
if (
@@ -346,6 +387,12 @@ def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
f"The rl chain does not accept '{self.selected_input_key}' or '{self.selected_based_on_input_key}' as input keys, they are reserved for internal use during auto reward."
)

+ def _can_use_selection_scorer(self) -> bool:
+ """
+ Returns whether the chain can use the selection scorer to score responses or not.
+ """
+ return self.selection_scorer is not None and self.selection_scorer_activated

@abstractmethod
def _call_before_predict(self, inputs: Dict[str, Any]) -> Event:
pass
@@ -368,32 +415,6 @@ def _call_after_scoring_before_learning(
) -> Event:
pass

- def update_with_delayed_score(
- self, score: float, event: Event, force_score=False
- ) -> None:
- """
- Learn will be called with the score specified and the actions/embeddings/etc stored in event

- Will raise an error if selection_scorer is set, and force_score=True was not provided during the method call
- """
- if self.selection_scorer and not force_score:
- raise RuntimeError(
- "The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function."
- )
- self.metrics.on_feedback(score)
- self._call_after_scoring_before_learning(event=event, score=score)
- self.policy.learn(event=event)
- self.policy.log(event=event)

- def set_auto_embed(self, auto_embed: bool) -> None:
- """
- Set whether the chain should auto embed the inputs or not. If set to False, the inputs will not be embedded and the user will need to embed the inputs themselves before calling run.

- Args:
- auto_embed (bool): Whether the chain should auto embed the inputs or not.
- """
- self.auto_embed = auto_embed

def _call(
self,
inputs: Dict[str, Any],
@@ -429,13 +450,13 @@ def _call(

score = None
try:
- if self.selection_scorer:
+ if self._can_use_selection_scorer():
score = self.selection_scorer.score_response(
- inputs=next_chain_inputs, llm_response=output
+ inputs=next_chain_inputs, llm_response=output, event=event
)
except Exception as e:
logger.info(
f"The LLM was not able to rank and the chain was not able to adjust to this response, error: {e}"
f"The selection scorer was not able to rank and the chain was not able to adjust to this response, error: {e}"
)
self.metrics.on_feedback(score)
event = self._call_after_scoring_before_learning(score=score, event=event)
Expand All @@ -444,21 +465,6 @@ def _call(

return {self.output_key: {"response": output, "selection_metadata": event}}

- def save_progress(self) -> None:
- """
- This function should be called whenever there is a need to save the progress of the VW (Vowpal Wabbit) model within the chain. It saves the current state of the VW model to a file.

- File Naming Convention:
- The file will be named using the pattern `model-<checkpoint>.vw`, where `<checkpoint>` is a monotonically increasing number. The numbering starts from 1, and increments by 1 for each subsequent save. If there are already saved checkpoints, the number used for `<checkpoint>` will be the next in the sequence.

- Example:
- If there are already two saved checkpoints, `model-1.vw` and `model-2.vw`, the next time this function is called, it will save the model as `model-3.vw`.

- Note:
- Be cautious when deleting or renaming checkpoint files manually, as this could cause the function to reuse checkpoint numbers.
- """
- self.policy.save()

@property
def _chain_type(self) -> str:
return "llm_personalizer_chain"
@@ -517,6 +523,13 @@ def embed_list_type(
for embed_item in item:
if isinstance(embed_item, dict):
ret_list.append(embed_dict_type(embed_item, model))
+ elif isinstance(embed_item, list):
+ item_embedding = embed_list_type(embed_item, model, namespace)
+ # Get the first key from the first dictionary
+ first_key = next(iter(item_embedding[0]))
+ # Group the values under that key
+ grouping = {first_key: [item[first_key] for item in item_embedding]}
+ ret_list.append(grouping)
else:
ret_list.append(embed_string_type(embed_item, model, namespace))
return ret_list
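This new branch is what supports the notebook's updated `meals` format, where each action is itself a list of feature strings (name plus tags): each part is processed individually, then the values are regrouped under a single namespace key so the parts stay together as one action. A sketch of the intended shape; the module path, namespace name, and plain-string pass-through behavior are assumptions:

```python
from sentence_transformers import SentenceTransformer
from rl_chain.rl_chain_base import embed_list_type  # import path assumed

model = SentenceTransformer("all-mpnet-base-v2")  # model choice illustrative
action_parts = ["One-Pan Beef Enchiladas Verdes", " Spicy, Easy Cleanup"]

# Before: each part became its own single-key dict, splitting one action in two:
#   [{"meal": "One-Pan Beef Enchiladas Verdes"}, {"meal": " Spicy, Easy Cleanup"}]
# After: the values are grouped under the first key, keeping the action whole:
#   [{"meal": ["One-Pan Beef Enchiladas Verdes", " Spicy, Easy Cleanup"]}]
print(embed_list_type([action_parts], model, "meal"))
```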