fix context switch and logger (#454)
* fix the context switch issue

* add config example
dongyuanjushi authored Mar 3, 2025
1 parent acfca99 commit 933dbff
Showing 12 changed files with 276 additions and 65 deletions.
64 changes: 64 additions & 0 deletions aios/config/config.yaml.example
@@ -0,0 +1,64 @@
# Global Configuration for AIOS

# API Keys Configuration
api_keys:
  openai: "" # OpenAI API key
  gemini: "" # Google Gemini API key
  groq: "" # Groq API key
  anthropic: "" # Anthropic API key
  huggingface:
    auth_token: "" # HuggingFace auth token
    home: "" # Optional: HuggingFace models path

# LLM Configuration
llms:
  models:
    # - name: "gpt-4o-mini"
    #   backend: "openai"
    #   max_new_tokens: 1024
    #   temperature: 1.0

    - name: "gemini-1.5-flash"
      backend: "google"
      max_new_tokens: 1024
      temperature: 1.0

    # - name: "gpt-4o-mini"
    #   backend: "openai"
    #   max_new_tokens: 1024
    #   temperature: 1.0

    # - name: "qwen2.5:7b"
    #   backend: "ollama"
    #   max_new_tokens: 1024
    #   temperature: 1.0
    #   hostname: "http://localhost:11434" # Make sure to run ollama server

    #
    # - name: "meta-llama/Meta-Llama-3.1-8B-Instruct"
    #   backend: "huggingface"
    #   max_new_tokens: 1024
    #   temperature: 1.0

  log_mode: "console"
  # use_context_manager: false
  use_context_manager: false # set to true to enable context interrupt and switch

memory:
  memory_limit: 524288 # 512KB
  eviction_k: 3

storage:
  root_dir: "root"
  use_vector_db: true

scheduler:
  log_mode: "console"

agent_factory:
  log_mode: "console"
  max_workers: 64

server:
  host: "localhost"
  port: 8000
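For reference, the sketch below shows one way a consumer could read this example config with PyYAML and pick up the new `use_context_manager` flag. The file path and loader are illustrative assumptions rather than AIOS's actual config loader; only the key layout mirrors the example above.

```python
# Illustrative only: assumes PyYAML is installed and the example file has been
# copied to aios/config/config.yaml; AIOS's real loader may differ.
import yaml

with open("aios/config/config.yaml", "r") as f:
    config = yaml.safe_load(f)

llm_settings = config.get("llms", {})
models = llm_settings.get("models", [])
use_context_manager = llm_settings.get("use_context_manager", False)

print("Configured models:", [m["name"] for m in models])
print("Context interrupt/switch enabled:", use_context_manager)
```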
38 changes: 32 additions & 6 deletions aios/context/simple_context.py
@@ -5,6 +5,8 @@

from litellm import completion

import time

class SimpleContextManager(BaseContextManager):
    def __init__(self):
        BaseContextManager.__init__(self)
@@ -13,31 +15,55 @@ def __init__(self):
    def start(self):
        pass

    def save_context(self, model, messages, temperature, pid, time_limit):
    def save_context(self, model, messages, tools, temperature, pid, time_limit):
        if isinstance(model, str):
            response = completion(
                model=model,
                messages=messages,
                # tools=tools,
                temperature=temperature,
                stream=True
            )
            start_time = time.time()
            completed_response = ""

            finished = True

            for part in response:
                completed_response += part.choices[0].delta.content or ""
                if time.time() - start_time > time_limit:
                    if part.choices[0].finish_reason is None:
                        finished = False
                    break

            self.context_dict[str(pid)] = completed_response
            return completed_response
            return completed_response, finished

        else:
            pass

    def load_context(self, pid):
        return self.context_dict[str(pid)]
    def load_context(self, pid, model, tokenizer=None):
        context = self.check_context(pid)

        if context is None:
            return ""

        # Add type checking
        if isinstance(context, str) and not isinstance(model, str):
            raise TypeError("When context is string type, model must also be string type")
        if not isinstance(context, str) and isinstance(model, str):
            raise TypeError("When model is string type, context must also be string type")

        if isinstance(model, str):
            return context
        else:
            # For local models that return tensors, decode using tokenizer
            if tokenizer:
                return tokenizer.decode(context)
            return context

    def check_context(self, pid):
        # return os.path.exists(os.path.join(self.context_dir, f"process-{pid}.pt"))
        return str(pid) in self.context_dict.keys()
        return self.context_dict.get(str(pid), None)

    def clear_context(self, pid):
        self.context_dict.pop(pid)
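The flow that the new `save_context`/`load_context` pair enables looks roughly like the sketch below: stream a completion until a time limit, snapshot the partial text under the process id, and splice it back in as an assistant turn when the process is rescheduled. The pid, model name, and time limit are made-up values, and a configured litellm API key is assumed.

```python
# Hypothetical round trip through the updated SimpleContextManager;
# assumes litellm is configured with credentials for the chosen model.
from aios.context.simple_context import SimpleContextManager

manager = SimpleContextManager()
messages = [{"role": "user", "content": "Summarize the history of operating systems."}]

# Stream for at most 2 seconds, then snapshot whatever text has arrived.
partial_text, finished = manager.save_context(
    model="gpt-4o-mini",
    messages=messages,
    tools=None,
    temperature=1.0,
    pid=1,
    time_limit=2,
)

if not finished:
    # On the next scheduling turn, restore the partial generation and let the
    # model continue from it as an assistant message.
    restored = manager.load_context(pid=1, model="gpt-4o-mini")
    messages.append({"role": "assistant", "content": restored})
```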
42 changes: 41 additions & 1 deletion aios/hooks/modules/scheduler.py
@@ -22,7 +22,7 @@
from aios.hooks.utils.validate import validate
from aios.hooks.stores import queue as QueueStore, processes as ProcessStore
from aios.scheduler.fifo_scheduler import FIFOScheduler
# from aios.scheduler.rr_scheduler import RRScheduler
from aios.scheduler.rr_scheduler import RRScheduler


@validate(SchedulerParams)
@@ -158,4 +158,44 @@ def fifo_scheduler_nonblock(params: SchedulerParams):

    scheduler = FIFOScheduler(**params.model_dump())

    return scheduler

@validate(SchedulerParams)
def rr_scheduler_nonblock(params: SchedulerParams):
"""
A context manager that starts and stops a FIFO scheduler.
Args:
params (SchedulerParams): The parameters for the scheduler.
"""
if params.get_llm_syscall is None:
from aios.hooks.stores._global import global_llm_req_queue_get_message
params.get_llm_syscall = global_llm_req_queue_get_message

if params.get_memory_syscall is None:
from aios.hooks.stores._global import global_memory_req_queue_get_message
params.get_memory_syscall = global_memory_req_queue_get_message

if params.get_storage_syscall is None:
from aios.hooks.stores._global import global_storage_req_queue_get_message
params.get_storage_syscall = global_storage_req_queue_get_message

if params.get_tool_syscall is None:
from aios.hooks.stores._global import global_tool_req_queue_get_message
params.get_tool_syscall = global_tool_req_queue_get_message

# if params.llm_request_queue is None:
# params.llm_request_queue = LLMRequestQueue

# if params.memory_request_queue is None:
# params.memory_request_queue = MemoryRequestQueue

# if params.storage_request_queue is None:
# params.storage_request_queue = StorageRequestQueue

# if params.tool_request_queue is None:
# params.tool_request_queue = ToolRequestQueue

scheduler = RRScheduler(**params.model_dump())

return scheduler
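A usage sketch for the new non-blocking hook follows. The `SchedulerParams` import path and the `start()`/`stop()` calls are assumptions inferred from the surrounding hooks; the syscall getters can be left as None so the function falls back to the global request queues, as the body above shows.

```python
# Hypothetical wiring of the round-robin scheduler; the field names follow the
# checks in rr_scheduler_nonblock above, the rest is illustrative.
from aios.hooks.modules.scheduler import rr_scheduler_nonblock
from aios.hooks.types.scheduler import SchedulerParams  # assumed import path

params = SchedulerParams(
    log_mode="console",
    get_llm_syscall=None,      # defaults to global_llm_req_queue_get_message
    get_memory_syscall=None,   # defaults to global_memory_req_queue_get_message
    get_storage_syscall=None,  # defaults to global_storage_req_queue_get_message
    get_tool_syscall=None,     # defaults to global_tool_req_queue_get_message
)

scheduler = rr_scheduler_nonblock(params)
scheduler.start()  # assumed API: the non-blocking variant leaves start/stop to the caller
# ... submit syscalls via the global request queues ...
scheduler.stop()   # assumed API
```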
1 change: 1 addition & 0 deletions aios/hooks/types/llm.py
@@ -13,3 +13,4 @@
class LLMParams(BaseModel):
    llm_configs: List[Dict[str, Any]]
    log_mode: str = ("console",)
    use_context_manager: bool = False
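The added field means context interrupt and switch stays off unless a caller opts in. A minimal construction is sketched below; the model entry simply mirrors config.yaml.example above and is illustrative.

```python
# Minimal example of the extended LLMParams; values are illustrative.
from aios.hooks.types.llm import LLMParams

params = LLMParams(
    llm_configs=[{
        "name": "gemini-1.5-flash",
        "backend": "google",
        "max_new_tokens": 1024,
        "temperature": 1.0,
    }],
    log_mode="console",
    use_context_manager=True,  # opt in to context interrupt and switch
)
```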
110 changes: 81 additions & 29 deletions aios/llm_core/adapter.py
@@ -5,7 +5,7 @@
from cerebrum.llm.apis import LLMQuery, LLMResponse
from litellm import completion
import json
from .utils import tool_calling_input_format, parse_json_format, parse_tool_calls
from .utils import tool_calling_input_format, parse_json_format, parse_tool_calls, pre_process_tools
from typing import Dict, Optional, Any, List, Union
import time
import re
@@ -311,17 +311,38 @@ def execute_llm_syscall(
llm_syscall.set_status("executing")
llm_syscall.set_start_time(time.time())

messages = self._prepare_messages(llm_syscall, messages, tools)
# breakpoint()

model_idxs = self.strategy.get_model_idxs(selected_llms)
model = self.llms[model_idxs[0]]


if tools:
tools = pre_process_tools(tools)

messages = self._prepare_messages(
llm_syscall=llm_syscall,
model=model,
messages=messages,
tools=tools
)

try:
response = self._get_model_response(model, messages, temperature, llm_syscall)
completed_response, finished = self._get_model_response(
model=model,
messages=messages,
tools=tools,
temperature=temperature,
llm_syscall=llm_syscall
)
except Exception as e:
return self._handle_completion_error(e)

return self._process_response(response, tools, ret_type)
return self._process_response(
completed_response=completed_response,
finished=finished,
tools=tools,
ret_type=ret_type
)

except Exception as e:
return LLMResponse(
@@ -331,7 +352,7 @@ def execute_llm_syscall(
                status_code=500
            )

    def _prepare_messages(self, llm_syscall, messages: List[Dict], tools: Optional[List] = None) -> List[Dict]:
    def _prepare_messages(self, llm_syscall, model, messages: List[Dict], tools: Optional[List] = None) -> List[Dict]:
        """
        Prepare messages for the LLM, including context restoration and tool formatting.
@@ -363,14 +384,17 @@ def _prepare_messages(self, llm_syscall, messages: List[Dict], tools: Optional[L
"""
if self.context_manager:
pid = llm_syscall.get_pid()
if self.context_manager.check_restoration(pid):
restored_context = self.context_manager.gen_recover(pid)
if restored_context:
messages += [{
"role": "assistant",
"content": "" + restored_context,
}]

restored_context = self.context_manager.load_context(
pid=pid,
model=model,
# tokenizer=tokenizer # TODO: Add tokenizer
)
messages += [{
"role": "assistant",
"content": "" + restored_context,
}]

# if not isinstance(model, str):
if tools:
tools = pre_process_tools(tools)
messages = tool_calling_input_format(messages, tools)
@@ -381,6 +405,7 @@ def _get_model_response(
        self,
        model: Union[str, HfLocalBackend, VLLMLocalBackend, OllamaBackend],
        messages: List[Dict],
        tools: Optional[List],
        temperature: float,
        llm_syscall
    ) -> Any:
@@ -416,25 +441,52 @@ def _get_model_response(
        }
        ```
        """
        pid = llm_syscall.get_pid() if self.use_context_manager else None

        if isinstance(model, str):
            completion_response = (
                self.context_manager.save_context(model, messages, temperature, pid)
                if self.use_context_manager
                else completion(model=model, messages=messages, temperature=temperature)
            )
            return completion_response.choices[0].message.content
            if self.use_context_manager:

                pid = llm_syscall.get_pid()
                time_limit = llm_syscall.get_time_limit()
                completed_response, finished = self.context_manager.save_context(
                    model=model,
                    messages=messages,
                    tools=tools,
                    temperature=temperature,
                    pid=pid,
                    time_limit=time_limit
                )

                return completed_response, finished

            else:
                completed_response = completion(
                    model=model,
                    messages=messages,
                    tools=tools,
                    temperature=temperature
                )

                return completed_response.choices[0].message.content, True
        else:
            return (
                self.context_manager.save_context(model, messages, temperature, pid)
                self.context_manager.save_context(
                    model=model,
                    messages=messages,
                    tools=tools,
                    temperature=temperature,
                    pid=pid,
                    time_limit=time_limit
                )
                if self.use_context_manager
                else model(messages=messages, temperature=temperature)
                else model(
                    messages=messages,
                    temperature=temperature
                )
            )

    def _process_response(
        self,
        response: str,
        completed_response: str,
        finished: bool,
        tools: Optional[List] = None,
        ret_type: Optional[str] = None
    ) -> LLMResponse:
@@ -484,14 +536,14 @@
        ```
        """
        if tools:
            if tool_calls := parse_tool_calls(response):
            if tool_calls := parse_tool_calls(completed_response):
                return LLMResponse(
                    response_message=None,
                    tool_calls=tool_calls,
                    finished=True
                    finished=finished
                )

        if ret_type == "json":
            response = parse_json_format(response)
            completed_response = parse_json_format(completed_response)

        return LLMResponse(response_message=response, finished=True)
        return LLMResponse(response_message=completed_response, finished=finished)
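On the caller side, the propagated `finished` flag is what distinguishes a completed generation from one that was cut off at its time limit. Below is a hedged sketch of how a consumer might react; only the `LLMResponse` fields (`response_message`, `tool_calls`, `finished`) come from the diff above, the rest is hypothetical.

```python
# Hypothetical consumer of the adapter's LLMResponse; the resubmit callback is
# illustrative, only the response fields are taken from the code above.
def handle_llm_response(response, resubmit_syscall):
    if response.tool_calls:
        return response.tool_calls  # dispatch tool calls as before
    if not response.finished:
        # The generation hit its time limit; the partial output is already
        # saved by the context manager under the syscall's pid, so resubmitting
        # lets the next turn resume from the restored context.
        return resubmit_syscall()
    return response.response_message
```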