Commit 2fe050c

Merge pull request #789 from expectedparrot/turbo_mode

Turbo mode

apostolosfilippas authored Jul 21, 2024
2 parents 4401659 + ef2074f commit 2fe050c

Showing 11 changed files with 228 additions and 34 deletions.
9 changes: 4 additions & 5 deletions edsl/agents/Invigilator.py
@@ -74,15 +74,14 @@ def _format_raw_response(
This cleans up the raw response to make it suitable to pass to AgentResponseDict.
"""
# not actually used, but this removes the temptation to delete agent from the signature
_ = agent
try:
response = question._validate_answer(raw_response)
except Exception as e:
"""If the response is invalid, remove it from the cache and raise the exception."""
self._remove_from_cache(raw_response)
raise e

# breakpoint()
question_dict = self.survey.question_names_to_questions()
for other_question, answer in self.current_answers.items():
if other_question in question_dict:
@@ -95,12 +94,10 @@ def _format_raw_response(
question_dict[new_question].comment = answer

combined_dict = {**question_dict, **scenario}
# print("combined_dict: ", combined_dict)
# print("response: ", response)
# breakpoint()
answer = question._translate_answer_code_to_answer(
response["answer"], combined_dict
)
#breakpoint()
data = {
"answer": answer,
"comment": response.get(
@@ -111,6 +108,8 @@ def _format_raw_response(
"cached_response": raw_response.get("cached_response", None),
"usage": raw_response.get("usage", {}),
"raw_model_response": raw_model_response,
"cache_used": raw_response.get("cache_used", False),
"cache_key": raw_response.get("cache_key", None),
}
return AgentResponseDict(**data)

16 changes: 13 additions & 3 deletions edsl/data/Cache.py
@@ -41,6 +41,7 @@ def __init__(
data: Optional[Union["SQLiteDict", dict]] = None,
immediate_write: bool = True,
method=None,
verbose = False
):
"""
Create two dictionaries to store the cache data.
@@ -59,6 +60,7 @@ def __init__(
self.new_entries = {}
self.new_entries_to_write_later = {}
self.coop = None
self.verbose = verbose

self.filename = filename
if filename and data:
@@ -122,7 +124,7 @@ def fetch(
system_prompt: str,
user_prompt: str,
iteration: int,
) -> Union[None, str]:
) -> tuple(Union[None, str], str):
"""
Fetch a value (LLM output) from the cache.
@@ -135,7 +137,7 @@ def fetch(
Return None if the response is not found.
>>> c = Cache()
>>> c.fetch(model="gpt-3", parameters="default", system_prompt="Hello", user_prompt="Hi", iteration=1) is None
>>> c.fetch(model="gpt-3", parameters="default", system_prompt="Hello", user_prompt="Hi", iteration=1)[0] is None
True
@@ -151,8 +153,13 @@ def fetch(
)
entry = self.data.get(key, None)
if entry is not None:
if self.verbose:
print(f"Cache hit for key: {key}")
self.fetched_data[key] = entry
return None if entry is None else entry.output
else:
if self.verbose:
print(f"Cache miss for key: {key}")
return None if entry is None else entry.output, key

def store(
self,
@@ -354,6 +361,9 @@ def __exit__(self, exc_type, exc_value, traceback):
for key, entry in self.new_entries_to_write_later.items():
self.data[key] = entry

if self.filename:
self.write(self.filename)

####################
# DUNDER / USEFUL
####################
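
Taken together, the Cache changes above mean fetch() now returns an (output, key) pair, verbose=True prints hit/miss messages, and a file-backed cache is written out when the context manager exits. Below is a minimal sketch of the new fetch/store round trip, based on the doctest above and on the store() call visible in LanguageModel.py further down; the import path and the response payload are assumptions, not taken from the diff.

    from edsl.data.Cache import Cache

    c = Cache(verbose=True)            # verbose=True now prints cache hits and misses

    # fetch() now returns (output, key) instead of a bare output
    output, key = c.fetch(
        model="gpt-3",
        parameters="default",
        system_prompt="Hello",
        user_prompt="Hi",
        iteration=1,
    )
    assert output is None              # nothing stored yet, so this is a miss

    # store a (placeholder) response under the same parameters
    stored_key = c.store(
        model="gpt-3",
        parameters="default",
        system_prompt="Hello",
        user_prompt="Hi",
        response={"example": "response"},
        iteration=1,
    )

    # the same fetch is now a hit, and the key matches what store() returned
    output, key = c.fetch(
        model="gpt-3",
        parameters="default",
        system_prompt="Hello",
        user_prompt="Hi",
        iteration=1,
    )
    assert key == stored_key and output is not None

Because __exit__ now calls self.write(self.filename), a cache opened with Cache(filename=...) inside a with-block should also persist its new entries to disk on exit.
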
4 changes: 4 additions & 0 deletions edsl/data_transfer_models.py
@@ -17,6 +17,8 @@ def __init__(
cached_response=None,
raw_model_response=None,
simple_model_raw_response=None,
cache_used=None,
cache_key=None,
):
"""Initialize the AgentResponseDict object."""
usage = usage or {"prompt_tokens": 0, "completion_tokens": 0}
@@ -30,5 +32,7 @@ def __init__(
"cached_response": cached_response,
"raw_model_response": raw_model_response,
"simple_model_raw_response": simple_model_raw_response,
"cache_used": cache_used,
"cache_key": cache_key,
}
)
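
A sketch of the expanded AgentResponseDict, assuming every field can be passed as a keyword the way Invigilator does with AgentResponseDict(**data); the values are placeholders, and any required fields hidden by the collapsed hunks above would need to be supplied as well.

    from edsl.data_transfer_models import AgentResponseDict

    response = AgentResponseDict(
        answer="yes",
        comment="No comment.",
        usage={"prompt_tokens": 12, "completion_tokens": 3},
        cached_response=None,
        raw_model_response={"placeholder": True},
        cache_used=True,        # new: whether the answer was served from the cache
        cache_key="a1b2c3d4",   # new: the cache entry the answer corresponds to
    )

    # downstream code reads these fields dict-style, e.g. QuestionTaskCreator
    # calls results.get("cache_used", False) to decide whether to enter turbo mode
    print(response.get("cache_used", False))
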
10 changes: 10 additions & 0 deletions edsl/jobs/buckets/ModelBuckets.py
@@ -24,6 +24,16 @@ def __add__(self, other: "ModelBuckets"):
requests_bucket=self.requests_bucket + other.requests_bucket,
tokens_bucket=self.tokens_bucket + other.tokens_bucket,
)

def turbo_mode_on(self):
"""Set the refill rate to infinity for both buckets."""
self.requests_bucket.turbo_mode_on()
self.tokens_bucket.turbo_mode_on()

def turbo_mode_off(self):
"""Restore the refill rate to its original value for both buckets."""
self.requests_bucket.turbo_mode_off()
self.tokens_bucket.turbo_mode_off()

@classmethod
def infinity_bucket(cls, model_name: str = "not_specified") -> "ModelBuckets":
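
The new ModelBuckets methods simply forward to the underlying requests and tokens buckets. A sketch of toggling them, with the TokenBucket constructor arguments assumed from the attribute names in the next file and the import paths assumed from the file layout:

    from edsl.jobs.buckets.TokenBucket import TokenBucket
    from edsl.jobs.buckets.ModelBuckets import ModelBuckets

    buckets = ModelBuckets(
        requests_bucket=TokenBucket(
            bucket_name="gpt-4", bucket_type="requests", capacity=10, refill_rate=2
        ),
        tokens_bucket=TokenBucket(
            bucket_name="gpt-4", bucket_type="tokens", capacity=100_000, refill_rate=10_000
        ),
    )

    buckets.turbo_mode_on()    # both buckets now have infinite capacity and refill rate
    # ... drain cached work without throttling ...
    buckets.turbo_mode_off()   # both buckets fall back to their saved limits

infinity_bucket() already behaves like permanent turbo mode; turbo_mode_on/off is the reversible version for buckets that normally enforce real limits.
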
23 changes: 20 additions & 3 deletions edsl/jobs/buckets/TokenBucket.py
@@ -17,12 +17,29 @@ def __init__(
self.bucket_name = bucket_name
self.bucket_type = bucket_type
self.capacity = capacity # Maximum number of tokens
self._old_capacity = capacity
self.tokens = capacity # Current number of available tokens
self.refill_rate = refill_rate # Rate at which tokens are refilled
self._old_refill_rate = refill_rate
self.last_refill = time.monotonic() # Last refill time

self.log: List[Any] = []

self.turbo_mode = False

def turbo_mode_on(self):
"""Set the refill rate to infinity."""
if self.turbo_mode:
pass
else:
self.turbo_mode = True
self.capacity=float("inf")
self.refill_rate=float("inf")

def turbo_mode_off(self):
"""Restore the refill rate to its original value."""
self.turbo_mode = False
self.capacity = self._old_capacity
self.refill_rate = self._old_refill_rate

def __add__(self, other) -> "TokenBucket":
"""Combine two token buckets.
Expand Down Expand Up @@ -98,7 +115,7 @@ async def get_tokens(self, amount: Union[int, float] = 1) -> None:
raise ValueError(msg)
while self.tokens < amount:
self.refill()
await asyncio.sleep(0.1) # Sleep briefly to prevent busy waiting
await asyncio.sleep(0.01) # Sleep briefly to prevent busy waiting
self.tokens -= amount

now = time.monotonic()
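
turbo_mode_on() swaps float("inf") into capacity and refill_rate while keeping the originals in _old_capacity and _old_refill_rate, and turbo_mode_off() restores them; the wait loop in get_tokens also now polls every 0.01 s instead of 0.1 s, so throttled consumers wake up sooner. A small sketch of the state transitions, with the constructor arguments assumed as in the previous example:

    from edsl.jobs.buckets.TokenBucket import TokenBucket

    bucket = TokenBucket(
        bucket_name="gpt-4", bucket_type="tokens", capacity=100, refill_rate=10
    )

    bucket.turbo_mode_on()
    assert bucket.capacity == float("inf")
    assert bucket.refill_rate == float("inf")

    bucket.turbo_mode_on()       # idempotent: a second call is a no-op
    bucket.turbo_mode_off()
    assert (bucket.capacity, bucket.refill_rate) == (100, 10)

Only __init__ records the old values, so turbo_mode_off() always restores the rates the bucket was constructed with, no matter how many times turbo mode was toggled.
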
13 changes: 7 additions & 6 deletions edsl/jobs/interviews/Interview.py
@@ -30,14 +30,14 @@ class Interview(InterviewStatusMixin, InterviewTaskBuildingMixin):

def __init__(
self,
agent: Agent,
survey: Survey,
scenario: Scenario,
model: Type[LanguageModel],
debug: bool = False,
agent: 'Agent',
survey: 'Survey',
scenario: 'Scenario',
model: Type['LanguageModel'],
debug: Optional[bool] = False,
iteration: int = 0,
cache: "Cache" = None,
sidecar_model: LanguageModel = None,
sidecar_model: 'LanguageModel' = None,
):
"""Initialize the Interview instance.
Expand Down Expand Up @@ -99,6 +99,7 @@ async def async_conduct_interview(
if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
model_buckets = ModelBuckets.infinity_bucket()

# FOR TESTING
# model_buckets = ModelBuckets.infinity_bucket()

## build the tasks using the InterviewTaskBuildingMixin
2 changes: 2 additions & 0 deletions edsl/jobs/runners/JobsRunnerAsyncio.py
@@ -88,6 +88,7 @@ def _populate_total_interviews(self, n: int = 1) -> None:
self.total_interviews.append(interview)

async def run_async(self, cache=None) -> Results:
from edsl.results.Results import Results
if cache is None:
self.cache = Cache()
else:
@@ -98,6 +99,7 @@ async def run_async(self, cache=None) -> Results:
return Results(survey=self.jobs.survey, data=data)

def simple_run(self):
from edsl.results.Results import Results
data = asyncio.run(self.run_async())
return Results(survey=self.jobs.survey, data=data)

21 changes: 15 additions & 6 deletions edsl/jobs/tasks/QuestionTaskCreator.py
@@ -144,17 +144,26 @@ async def _run_focal_task(self, debug: bool) -> Answers:
self.task_status = TaskStatus.FAILED
raise e

if "cached_response" in results:
if results["cached_response"]:
# Gives back the tokens b/c the API was not called.
self.tokens_bucket.add_tokens(requested_tokens)
self.requests_bucket.add_tokens(1)
self.from_cache = True
## This isn't working
#breakpoint()
if results.get('cache_used', False):
self.tokens_bucket.add_tokens(requested_tokens)
self.requests_bucket.add_tokens(1)
self.from_cache = True
#print("Turning on turbo!")
self.tokens_bucket.turbo_mode_on()
self.requests_bucket.turbo_mode_on()
else:
#breakpoint()
#print("Turning off turbo!")
self.tokens_bucket.turbo_mode_off()
self.requests_bucket.turbo_mode_off()

_ = results.pop("cached_response", None)

tracker = self.cached_token_usage if self.from_cache else self.new_token_usage


# TODO: This is hacky. The 'func' call should return an object that definitely has a 'usage' key.
usage = results.get("usage", {"prompt_tokens": 0, "completion_tokens": 0})
prompt_tokens = usage.get("prompt_tokens", 0)
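
The task creator now keys off the explicit cache_used flag rather than the presence of a cached_response entry: a cache hit refunds the requested budget and flips both buckets into turbo mode, and the first uncached answer flips them back. A condensed, standalone sketch of that policy (the function name is ours, not from the codebase):

    def apply_cache_policy(results, tokens_bucket, requests_bucket, requested_tokens):
        """Refund budget and toggle turbo mode based on whether the answer was cached."""
        if results.get("cache_used", False):
            tokens_bucket.add_tokens(requested_tokens)   # give the token budget back
            requests_bucket.add_tokens(1)                # and the request slot
            tokens_bucket.turbo_mode_on()
            requests_bucket.turbo_mode_on()
            return True                                  # from_cache
        tokens_bucket.turbo_mode_off()
        requests_bucket.turbo_mode_off()
        return False

The net effect is that a rerun served entirely from the cache is not rate-limited at all, while the first real API call immediately restores the normal limits.
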
9 changes: 4 additions & 5 deletions edsl/language_models/LanguageModel.py
@@ -323,12 +323,10 @@ async def async_get_raw_response(
image_hash = hashlib.md5(encoded_image.encode()).hexdigest()
cache_call_params["user_prompt"] = f"{user_prompt} {image_hash}"

cached_response = cache.fetch(**cache_call_params)

cached_response, cache_key = cache.fetch(**cache_call_params)
if cached_response:
response = json.loads(cached_response)
cache_used = True
cache_key = None
else:
remote_call = hasattr(self, "remote") and self.remote
f = (
@@ -340,14 +338,15 @@ async def async_get_raw_response(
if encoded_image:
params["encoded_image"] = encoded_image
response = await f(**params)
cache_key = cache.store(
new_cache_key = cache.store(
user_prompt=user_prompt,
model=str(self.model),
parameters=self.parameters,
system_prompt=system_prompt,
response=response,
iteration=iteration,
)
assert new_cache_key == cache_key
cache_used = False

return response, cache_used, cache_key
@@ -412,7 +411,7 @@ async def async_get_response(

dict_response.update(
{
"cached_used": cache_used,
"cache_used": cache_used,
"cache_key": cache_key,
"usage": raw_response.get("usage", {}),
"raw_model_response": raw_response,
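
The raw-response path now threads the cache key through both branches: fetch() returns it even on a miss, a hit keeps it (previously it was reset to None), and a miss stores the fresh response and asserts that store() derived the same key; the response dict also gets the corrected "cache_used" spelling. A condensed sketch of the flow, with the actual model call replaced by a stub coroutine:

    import json

    async def cached_model_call(cache, call_model, **cache_call_params):
        """Return (response, cache_used, cache_key) the way async_get_raw_response now does."""
        cached_response, cache_key = cache.fetch(**cache_call_params)
        if cached_response:
            return json.loads(cached_response), True, cache_key   # hit: reuse stored JSON
        response = await call_model()                              # miss: call the API
        new_cache_key = cache.store(response=response, **cache_call_params)
        assert new_cache_key == cache_key   # fetch and store agree on the key
        return response, False, new_cache_key
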