From 111feb3bd3ebd7c35c50292d29bc99df86cf56c2 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Wed, 29 Nov 2023 19:45:28 +0000 Subject: [PATCH 1/3] separate oai agnostic info from client wrapper --- autogen/oai/client.py | 175 +++++++++++++++++++++++------------------- 1 file changed, 98 insertions(+), 77 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index b4a139401ad2..8efa397e16a5 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -29,13 +29,94 @@ _ch.setFormatter(logger_formatter) logger.addHandler(_ch) +def template_formatter( + template: str | Callable | None, + context: Optional[Dict] = None, + allow_format_str_template: Optional[bool] = False, + ): + if not context or template is None: + return template + if isinstance(template, str): + return template.format(**context) if allow_format_str_template else template + return template(context) + +class ResponseCreator: + cache_path_root: str = ".cache" + extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context"} + + def construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> Dict: + """Prime the create_config with additional_kwargs.""" + # Validate the config + prompt = create_config.get("prompt") + messages = create_config.get("messages") + if (prompt is None) == (messages is None): + raise ValueError("Either prompt or messages should be in create config but not both.") + context = extra_kwargs.get("context") + if context is None: + # No need to instantiate if no context is provided. + return create_config + # Instantiate the prompt or messages + allow_format_str_template = extra_kwargs.get("allow_format_str_template", False) + # Make a copy of the config + params = create_config.copy() + if prompt is not None: + # Instantiate the prompt + params["prompt"] = self.instantiate(prompt, context, allow_format_str_template) + elif context: + # Instantiate the messages + params["messages"] = [ + { + **m, + "content": self.instantiate(m["content"], context, allow_format_str_template), + } + if m.get("content") + else m + for m in messages + ] + return params + + def create(self, client, completions_create, is_last, create_config: Dict, extra_kwargs: Dict): + # construct the create params + params = self.construct_create_params(create_config, extra_kwargs) + # get the cache_seed, filter_func and context + cache_seed = extra_kwargs.get("cache_seed", 41) + filter_func = extra_kwargs.get("filter_func") + context = extra_kwargs.get("context") + with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache: + if cache_seed is not None: + # Try to get the response from cache + key = get_key(params) + response = cache.get(key, None) + if response is not None: + # check the filter + pass_filter = filter_func is None or filter_func(context=context, response=response) + if pass_filter or is_last: + # Return the response if it passes the filter or it is the last client + response.pass_filter = pass_filter + # TODO: add response.cost + return response + + response = completions_create(client, params) + + if cache_seed is not None: + # Cache the response + with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache: + cache.set(key, response) + + # check the filter + pass_filter = filter_func is None or filter_func(context=context, response=response) + if pass_filter or is_last: + # Return the response if it passes the filter or it is the last client + response.pass_filter = pass_filter + return response + return None + class OpenAIWrapper: """A wrapper class 
for openai client.""" - cache_path_root: str = ".cache" - extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version"} openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs) + extra_kwargs = {"api_version"} def __init__(self, *, config_list: List[Dict] = None, **base_config): """ @@ -69,6 +150,8 @@ def __init__(self, *, config_list: List[Dict] = None, **base_config): base_config: base config. It can contain both keyword arguments for openai client and additional kwargs. """ + self.response_creator = ResponseCreator() + self.extra_kwargs.update(self.response_creator.extra_kwargs) openai_config, extra_kwargs = self._separate_openai_config(base_config) if type(config_list) is list and len(config_list) == 0: logger.warning("openai client was provided with an empty config_list, which may not be intended.") @@ -145,42 +228,7 @@ def instantiate( context: Optional[Dict] = None, allow_format_str_template: Optional[bool] = False, ): - if not context or template is None: - return template - if isinstance(template, str): - return template.format(**context) if allow_format_str_template else template - return template(context) - - def _construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> Dict: - """Prime the create_config with additional_kwargs.""" - # Validate the config - prompt = create_config.get("prompt") - messages = create_config.get("messages") - if (prompt is None) == (messages is None): - raise ValueError("Either prompt or messages should be in create config but not both.") - context = extra_kwargs.get("context") - if context is None: - # No need to instantiate if no context is provided. - return create_config - # Instantiate the prompt or messages - allow_format_str_template = extra_kwargs.get("allow_format_str_template", False) - # Make a copy of the config - params = create_config.copy() - if prompt is not None: - # Instantiate the prompt - params["prompt"] = self.instantiate(prompt, context, allow_format_str_template) - elif context: - # Instantiate the messages - params["messages"] = [ - { - **m, - "content": self.instantiate(m["content"], context, allow_format_str_template), - } - if m.get("content") - else m - for m in messages - ] - return params + return template_formatter(template, context, allow_format_str_template) def create(self, **config): """Make a completion for a given config using openai's clients. 
@@ -220,50 +268,23 @@ def yes_or_no_filter(context, response): create_config, extra_kwargs = self._separate_create_config(full_config) # process for azure self._process_for_azure(create_config, extra_kwargs, "extra") - # construct the create params - params = self._construct_create_params(create_config, extra_kwargs) - # get the cache_seed, filter_func and context - cache_seed = extra_kwargs.get("cache_seed", 41) - filter_func = extra_kwargs.get("filter_func") - context = extra_kwargs.get("context") - - # Try to load the response from cache - if cache_seed is not None: - with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache: - # Try to get the response from cache - key = get_key(params) - response = cache.get(key, None) - if response is not None: - # check the filter - pass_filter = filter_func is None or filter_func(context=context, response=response) - if pass_filter or i == last: - # Return the response if it passes the filter or it is the last client - response.config_id = i - response.pass_filter = pass_filter - response.cost = self.cost(response) - return response - continue # filter is not passed; try the next config try: - response = self._completions_create(client, params) + response = self.response_creator.create( + client=client, + is_last=(i == last), + completions_create=self._completions_create, + create_config=create_config, + extra_kwargs=extra_kwargs, + ) + if response is None: + continue # filter is not passed; try the next config + response.config_id = i + response.cost = self.cost(response) + return response except APIError: logger.debug(f"config {i} failed", exc_info=1) if i == last: raise - else: - if cache_seed is not None: - # Cache the response - with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache: - cache.set(key, response) - - # check the filter - pass_filter = filter_func is None or filter_func(context=context, response=response) - if pass_filter or i == last: - # Return the response if it passes the filter or it is the last client - response.config_id = i - response.pass_filter = pass_filter - response.cost = self.cost(response) - return response - continue # filter is not passed; try the next config def cost(self, response: Union[ChatCompletion, Completion]) -> float: """Calculate the cost of the response.""" From 6ce9a6e7232831ab690ba40019adb3e86f042696 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 30 Nov 2023 20:43:44 +0000 Subject: [PATCH 2/3] rl client which uses transformers --- autogen/oai/client.py | 271 +++++++++++++++++++++++++++--------------- 1 file changed, 174 insertions(+), 97 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 8efa397e16a5..ad5f2d80ad23 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -6,6 +6,7 @@ import logging import inspect from flaml.automl.logger import logger_formatter +from types import SimpleNamespace from autogen.oai.openai_utils import get_key, oai_price1k from autogen.token_count_utils import count_token @@ -22,6 +23,13 @@ except ImportError: ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.") OpenAI = object + +try: + from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig + ERROR = None +except ImportError: + ERROR = ImportError("Please install transformers and diskcache to use autogen.RLClientWrapper.") + logger = logging.getLogger(__name__) if not logger.handlers: # Add the console handler. 
@@ -29,16 +37,18 @@ _ch.setFormatter(logger_formatter) logger.addHandler(_ch) + def template_formatter( - template: str | Callable | None, - context: Optional[Dict] = None, - allow_format_str_template: Optional[bool] = False, - ): - if not context or template is None: - return template - if isinstance(template, str): - return template.format(**context) if allow_format_str_template else template - return template(context) + template: str | Callable | None, + context: Optional[Dict] = None, + allow_format_str_template: Optional[bool] = False, +): + if not context or template is None: + return template + if isinstance(template, str): + return template.format(**context) if allow_format_str_template else template + return template(context) + class ResponseCreator: cache_path_root: str = ".cache" @@ -75,7 +85,7 @@ def construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> Di ] return params - def create(self, client, completions_create, is_last, create_config: Dict, extra_kwargs: Dict): + def create(self, client, client_id, is_last, create_config: Dict, extra_kwargs: Dict): # construct the create params params = self.construct_create_params(create_config, extra_kwargs) # get the cache_seed, filter_func and context @@ -96,7 +106,9 @@ def create(self, client, completions_create, is_last, create_config: Dict, extra # TODO: add response.cost return response - response = completions_create(client, params) + response = client.create(params) + if response is None: + return None if cache_seed is not None: # Cache the response @@ -108,10 +120,152 @@ def create(self, client, completions_create, is_last, create_config: Dict, extra if pass_filter or is_last: # Return the response if it passes the filter or it is the last client response.pass_filter = pass_filter + response.config_id = client_id + response.cost = client.cost(response) return response return None +class RLClient: + def __init__(self, config: Dict): + import torch + + self.device = ( + ("cuda" if torch.cuda.is_available() else "cpu") if config.get("device", None) is None else config["device"] + ) + self.tokenizer = AutoTokenizer.from_pretrained(config["local_model"], load_in_8bit=True, use_fast=False) + self.model = AutoModelForCausalLM.from_pretrained(config["local_model"]).to(self.device) + # get max_length from config or set to 1000 + self.max_length = config.get("max_length", 1000) + self.gen_config_params = config.get("params", {}) + # correct max_length in self.params + self.gen_config_params["max_length"] = self.max_length + self.gen_config_params["eos_token_id"] = self.tokenizer.eos_token_id + self.gen_config_params["pad_token_id"] = self.tokenizer.eos_token_id + print(f"Loaded model {config['local_model']} to {self.device}") + + def create(self, params): + if params.get("stream", False) and "messages" in params and "functions" not in params: + raise NotImplementedError("Local models do not support streaming or functions") + else: + response_contents = [""] * params.get("n", 1) + finish_reasons = [""] * params.get("n", 1) + completion_tokens = 0 + + response = SimpleNamespace() + inputs = self.tokenizer.apply_chat_template( + params["messages"], return_tensors="pt", add_generation_prompt=True + ).to(self.device) + + inputs_length = inputs.shape[-1] + # copy gen config params + gen_config_params = self.gen_config_params.copy() + # add inputs_length to max_length + gen_config_params["max_length"] += inputs_length + generation_config = GenerationConfig(**gen_config_params) + + + response.choices = [] + + for _ in 
range(len(response_contents)): + outputs = self.model.generate(inputs, generation_config=generation_config) + # Decode only the newly generated text, excluding the prompt + text = self.tokenizer.decode(outputs[0, inputs_length:], skip_special_tokens=True) + choice = SimpleNamespace() + choice.message = SimpleNamespace() + choice.message.content = text + choice.message.function_call = None + response.choices.append(choice) + + return response + + def cost(self, response) -> float: + """Calculate the cost of the response.""" + return 0 + + +class OpenAIClient: + def __init__(self, config: Dict): + self.client = OpenAI(**config) + + def create(self, params): + completions = self.client.chat.completions if "messages" in params else self.client.completions + # If streaming is enabled, has messages, and does not have functions, then + # iterate over the chunks of the response + if params.get("stream", False) and "messages" in params and "functions" not in params: + response_contents = [""] * params.get("n", 1) + finish_reasons = [""] * params.get("n", 1) + completion_tokens = 0 + + # Set the terminal text color to green + print("\033[32m", end="") + + # Send the chat completion request to OpenAI's API and process the response in chunks + for chunk in completions.create(**params): + if chunk.choices: + for choice in chunk.choices: + content = choice.delta.content + finish_reasons[choice.index] = choice.finish_reason + # If content is present, print it to the terminal and update response variables + if content is not None: + print(content, end="", flush=True) + response_contents[choice.index] += content + completion_tokens += 1 + else: + print() + + # Reset the terminal text color + print("\033[0m\n") + + # Prepare the final ChatCompletion object based on the accumulated data + model = chunk.model.replace("gpt-35", "gpt-3.5") # hack for Azure API + prompt_tokens = count_token(params["messages"], model) + response = ChatCompletion( + id=chunk.id, + model=chunk.model, + created=chunk.created, + object="chat.completion", + choices=[], + usage=CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + for i in range(len(response_contents)): + response.choices.append( + Choice( + index=i, + finish_reason=finish_reasons[i], + message=ChatCompletionMessage( + role="assistant", content=response_contents[i], function_call=None + ), + ) + ) + else: + # If streaming is not enabled or using functions, send a regular chat completion request + # Functions are not supported, so ensure streaming is disabled + params = params.copy() + params["stream"] = False + response = completions.create(**params) + return response + + def cost(self, response: Union[ChatCompletion, Completion]) -> float: + """Calculate the cost of the response.""" + model = response.model + if model not in oai_price1k: + # TODO: add logging to warn that the model is not found + return 0 + + n_input_tokens = response.usage.prompt_tokens + n_output_tokens = response.usage.completion_tokens + tmp_price1K = oai_price1k[model] + # First value is input token rate, second value is output token rate + if isinstance(tmp_price1K, tuple): + return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000 + return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 + + class OpenAIWrapper: """A wrapper class for openai client.""" @@ -218,8 +372,10 @@ def _client(self, config, openai_config): """ openai_config = {**openai_config, **{k: v for k, v 
in config.items() if k in self.openai_kwargs}} self._process_for_azure(openai_config, config) - client = OpenAI(**openai_config) - return client + if "local_model" in config: + return RLClient(config) + else: + return OpenAIClient(openai_config) @classmethod def instantiate( @@ -271,100 +427,21 @@ def yes_or_no_filter(context, response): try: response = self.response_creator.create( client=client, + client_id=i, is_last=(i == last), - completions_create=self._completions_create, create_config=create_config, extra_kwargs=extra_kwargs, ) - if response is None: - continue # filter is not passed; try the next config - response.config_id = i - response.cost = self.cost(response) - return response + + if response is not None: + return response except APIError: logger.debug(f"config {i} failed", exc_info=1) if i == last: raise - def cost(self, response: Union[ChatCompletion, Completion]) -> float: - """Calculate the cost of the response.""" - model = response.model - if model not in oai_price1k: - # TODO: add logging to warn that the model is not found - return 0 - - n_input_tokens = response.usage.prompt_tokens - n_output_tokens = response.usage.completion_tokens - tmp_price1K = oai_price1k[model] - # First value is input token rate, second value is output token rate - if isinstance(tmp_price1K, tuple): - return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000 - return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 - - def _completions_create(self, client, params): - completions = client.chat.completions if "messages" in params else client.completions - # If streaming is enabled, has messages, and does not have functions, then - # iterate over the chunks of the response - if params.get("stream", False) and "messages" in params and "functions" not in params: - response_contents = [""] * params.get("n", 1) - finish_reasons = [""] * params.get("n", 1) - completion_tokens = 0 - - # Set the terminal text color to green - print("\033[32m", end="") - - # Send the chat completion request to OpenAI's API and process the response in chunks - for chunk in completions.create(**params): - if chunk.choices: - for choice in chunk.choices: - content = choice.delta.content - finish_reasons[choice.index] = choice.finish_reason - # If content is present, print it to the terminal and update response variables - if content is not None: - print(content, end="", flush=True) - response_contents[choice.index] += content - completion_tokens += 1 - else: - print() - - # Reset the terminal text color - print("\033[0m\n") - - # Prepare the final ChatCompletion object based on the accumulated data - model = chunk.model.replace("gpt-35", "gpt-3.5") # hack for Azure API - prompt_tokens = count_token(params["messages"], model) - response = ChatCompletion( - id=chunk.id, - model=chunk.model, - created=chunk.created, - object="chat.completion", - choices=[], - usage=CompletionUsage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - for i in range(len(response_contents)): - response.choices.append( - Choice( - index=i, - finish_reason=finish_reasons[i], - message=ChatCompletionMessage( - role="assistant", content=response_contents[i], function_call=None - ), - ) - ) - else: - # If streaming is not enabled or using functions, send a regular chat completion request - # Functions are not supported, so ensure streaming is disabled - params = params.copy() - params["stream"] = False - response = 
completions.create(**params) - return response - @classmethod - def extract_text_or_function_call(cls, response: ChatCompletion | Completion) -> List[str]: + def extract_text_or_function_call(cls, response) -> List[str]: """Extract the text or function calls from a completion or chat response. Args: From 473b2802bc8d90587fa5a19714a90e3166e19fd1 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 30 Nov 2023 23:51:47 +0000 Subject: [PATCH 3/3] make client pluggable by external code --- autogen/oai/__init__.py | 2 + autogen/oai/client.py | 85 ++++++++++------------------------------- 2 files changed, 23 insertions(+), 64 deletions(-) diff --git a/autogen/oai/__init__.py b/autogen/oai/__init__.py index dbcd2f796074..df3057b5b7f1 100644 --- a/autogen/oai/__init__.py +++ b/autogen/oai/__init__.py @@ -8,6 +8,7 @@ config_list_from_json, config_list_from_dotenv, ) +from autogen.oai.client import Client __all__ = [ "OpenAIWrapper", @@ -19,4 +20,5 @@ "config_list_from_models", "config_list_from_json", "config_list_from_dotenv", + "Client", ] diff --git a/autogen/oai/client.py b/autogen/oai/client.py index ad5f2d80ad23..502fdb2810ea 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -6,7 +6,7 @@ import logging import inspect from flaml.automl.logger import logger_formatter -from types import SimpleNamespace +from abc import ABC, abstractmethod from autogen.oai.openai_utils import get_key, oai_price1k from autogen.token_count_utils import count_token @@ -24,12 +24,6 @@ ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.") OpenAI = object -try: - from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig - ERROR = None -except ImportError: - ERROR = ImportError("Please install transformers and diskcache to use autogen.RLClientWrapper.") - logger = logging.getLogger(__name__) if not logger.handlers: # Add the console handler. 
@@ -38,6 +32,14 @@ logger.addHandler(_ch) +def import_class_from_path(path, class_name): + import importlib.util + spec = importlib.util.spec_from_file_location("module.name", path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + cls = getattr(module, class_name) + return cls + def template_formatter( template: str | Callable | None, context: Optional[Dict] = None, @@ -125,66 +127,18 @@ def create(self, client, client_id, is_last, create_config: Dict, extra_kwargs: return response return None +class Client(ABC): -class RLClient: - def __init__(self, config: Dict): - import torch - - self.device = ( - ("cuda" if torch.cuda.is_available() else "cpu") if config.get("device", None) is None else config["device"] - ) - self.tokenizer = AutoTokenizer.from_pretrained(config["local_model"], load_in_8bit=True, use_fast=False) - self.model = AutoModelForCausalLM.from_pretrained(config["local_model"]).to(self.device) - # get max_length from config or set to 1000 - self.max_length = config.get("max_length", 1000) - self.gen_config_params = config.get("params", {}) - # correct max_length in self.params - self.gen_config_params["max_length"] = self.max_length - self.gen_config_params["eos_token_id"] = self.tokenizer.eos_token_id - self.gen_config_params["pad_token_id"] = self.tokenizer.eos_token_id - print(f"Loaded model {config['local_model']} to {self.device}") - + @abstractmethod def create(self, params): - if params.get("stream", False) and "messages" in params and "functions" not in params: - raise NotImplementedError("Local models do not support streaming or functions") - else: - response_contents = [""] * params.get("n", 1) - finish_reasons = [""] * params.get("n", 1) - completion_tokens = 0 + pass - response = SimpleNamespace() - inputs = self.tokenizer.apply_chat_template( - params["messages"], return_tensors="pt", add_generation_prompt=True - ).to(self.device) - - inputs_length = inputs.shape[-1] - # copy gen config params - gen_config_params = self.gen_config_params.copy() - # add inputs_length to max_length - gen_config_params["max_length"] += inputs_length - generation_config = GenerationConfig(**gen_config_params) - - - response.choices = [] - - for _ in range(len(response_contents)): - outputs = self.model.generate(inputs, generation_config=generation_config) - # Decode only the newly generated text, excluding the prompt - text = self.tokenizer.decode(outputs[0, inputs_length:], skip_special_tokens=True) - choice = SimpleNamespace() - choice.message = SimpleNamespace() - choice.message.content = text - choice.message.function_call = None - response.choices.append(choice) - - return response - - def cost(self, response) -> float: - """Calculate the cost of the response.""" - return 0 + @abstractmethod + def cost(self, response): + pass -class OpenAIClient: +class OpenAIClient(Client): def __init__(self, config: Dict): self.client = OpenAI(**config) @@ -372,8 +326,11 @@ def _client(self, config, openai_config): """ openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}} self._process_for_azure(openai_config, config) - if "local_model" in config: - return RLClient(config) + if "custom_client" in config and "custom_client_code_path" in config: + custom_client = config["custom_client"] + custom_client_code_path = config["custom_client_code_path"] + CustomClient = import_class_from_path(custom_client_code_path, custom_client) + return CustomClient(config) else: return OpenAIClient(openai_config)
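
Worked example (PATCH 1/3): the `instantiate` logic is pulled out into a module-level `template_formatter`, which `OpenAIWrapper.instantiate` now simply delegates to. Below is a minimal sketch of its behavior, assuming the patched `autogen.oai.client` module is importable; the names and strings in the asserts are illustrative only.

```python
from autogen.oai.client import template_formatter

# Plain string templates are only formatted when allow_format_str_template is True.
assert template_formatter("Hi {name}", {"name": "Ada"}) == "Hi {name}"
assert template_formatter("Hi {name}", {"name": "Ada"}, allow_format_str_template=True) == "Hi Ada"

# Callable templates always receive the context dict.
assert template_formatter(lambda ctx: f"Hi {ctx['name']}", {"name": "Ada"}) == "Hi Ada"

# With no context (or a None template), the template is returned unchanged.
assert template_formatter("Hi {name}", None, allow_format_str_template=True) == "Hi {name}"
```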
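Worked example (PATCH 3/3): once the client is pluggable, external code only has to implement the `Client` ABC (`create` and `cost`). A minimal sketch of such a module follows, assuming it is saved as a standalone file; the file name, `EchoClient` class, and `echo_prefix` key are hypothetical and not part of the patch.

```python
# my_custom_client.py -- hypothetical file, loaded later via custom_client_code_path.
from types import SimpleNamespace

from autogen.oai import Client


class EchoClient(Client):
    """Toy client that echoes the last user message back as the completion."""

    def __init__(self, config):
        # The whole config-list entry is passed in, so custom keys are available here.
        self.prefix = config.get("echo_prefix", "")

    def create(self, params):
        # Return an object exposing the attributes the wrapper reads:
        # response.choices[i].message.content and .message.function_call.
        last_content = params["messages"][-1]["content"]
        choice = SimpleNamespace()
        choice.message = SimpleNamespace(content=self.prefix + last_content, function_call=None)
        # SimpleNamespace also accepts the attributes ResponseCreator sets later
        # (pass_filter, config_id, cost).
        return SimpleNamespace(choices=[choice], model="echo")

    def cost(self, response) -> float:
        # Local/offline clients can simply report zero cost.
        return 0
```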
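Wiring it up: `_client()` takes the custom branch only when a config entry carries both `custom_client` and `custom_client_code_path`; the named class is then loaded with `import_class_from_path` and handed the full config entry. A usage sketch under the assumption that the module above exists at the (hypothetical) path shown and the patched autogen is installed:

```python
from autogen import OpenAIWrapper

config_list = [
    {
        # Both keys are required for the custom-client branch in _client().
        "custom_client": "EchoClient",  # class name defined in the module above
        "custom_client_code_path": "/path/to/my_custom_client.py",  # hypothetical location of that file
        "echo_prefix": "echo: ",  # illustrative extra key, forwarded to EchoClient.__init__
    }
]

wrapper = OpenAIWrapper(config_list=config_list)
response = wrapper.create(messages=[{"role": "user", "content": "hello"}])
print(wrapper.extract_text_or_function_call(response))  # -> ['echo: hello']
```

Extra keys in the config entry (here `echo_prefix`) simply flow through to the custom client's constructor and its `create` params, so a plug-in client can define whatever configuration surface it needs without touching the wrapper.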