From 72a2bd9226e56905d906318ad2d193b548d95072 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 3 Jul 2024 09:30:57 -0500 Subject: [PATCH 01/41] updating blocks --- fms_sdg/base/block.py | 49 ++++++++ fms_sdg/base/databuilder.py | 119 +++++++----------- fms_sdg/base/generator.py | 35 ------ fms_sdg/base/registry.py | 55 ++------ fms_sdg/base/validator.py | 7 +- fms_sdg/{ => blocks}/generators/README.md | 0 fms_sdg/{ => blocks}/generators/__init__.py | 0 fms_sdg/{ => blocks}/generators/genai.py | 11 +- fms_sdg/{ => blocks}/generators/llm.py | 23 +++- fms_sdg/{ => blocks}/generators/openai.py | 11 +- fms_sdg/{ => blocks}/generators/utils.py | 0 fms_sdg/{ => blocks}/generators/vllm.py | 12 +- fms_sdg/{ => blocks}/validators/api.py | 10 +- fms_sdg/{ => blocks}/validators/lm_judge.py | 16 +-- .../nl2sql/sql_execution_validator.py | 8 +- .../validators/nl2sql/sql_syntax_validator.py | 8 +- fms_sdg/{ => blocks}/validators/rouge.py | 8 +- .../api/api_function_checking.yaml | 3 +- .../api/api_yes_no_detection.yaml | 3 +- fms_sdg/databuilders/api/generate.py | 8 +- fms_sdg/databuilders/nl2sql/generate.py | 10 +- fms_sdg/databuilders/nl2sql/nl2sql.yaml | 3 +- fms_sdg/databuilders/simple/generate.py | 18 ++- fms_sdg/databuilders/simple/simple.yaml | 3 +- fms_sdg/validators/genai.py | 19 --- pyproject.toml | 3 +- templates/generator/template.py | 4 +- 27 files changed, 204 insertions(+), 242 deletions(-) create mode 100644 fms_sdg/base/block.py delete mode 100644 fms_sdg/base/generator.py rename fms_sdg/{ => blocks}/generators/README.md (100%) rename fms_sdg/{ => blocks}/generators/__init__.py (100%) rename fms_sdg/{ => blocks}/generators/genai.py (96%) rename fms_sdg/{ => blocks}/generators/llm.py (92%) rename fms_sdg/{ => blocks}/generators/openai.py (95%) rename fms_sdg/{ => blocks}/generators/utils.py (100%) rename fms_sdg/{ => blocks}/generators/vllm.py (98%) rename fms_sdg/{ => blocks}/validators/api.py (95%) rename fms_sdg/{ => blocks}/validators/lm_judge.py (61%) rename fms_sdg/{ => blocks}/validators/nl2sql/sql_execution_validator.py (91%) rename fms_sdg/{ => blocks}/validators/nl2sql/sql_syntax_validator.py (87%) rename fms_sdg/{ => blocks}/validators/rouge.py (92%) delete mode 100644 fms_sdg/validators/genai.py diff --git a/fms_sdg/base/block.py b/fms_sdg/base/block.py new file mode 100644 index 00000000..e395c994 --- /dev/null +++ b/fms_sdg/base/block.py @@ -0,0 +1,49 @@ +# Standard +from abc import ABC +from typing import Any, Dict, List, Optional, Union +import abc + +# Third Party +from datasets import Dataset +import pandas as pd + + +class BaseBlock(ABC): + """Base Class for all Blocks""" + + def __init__(self, name: str, config: Dict, **kwargs: Any) -> None: + self._name = name + self._config: Dict = config + self._blocks: List[BaseBlock] = [] + + # overwrite config fields with kwargs (usually these will be command line args) + self._config.update(kwargs) + + self._arg_fields = self._config.get("arg_fields", None) + self._kwarg_fields = self._config.get("kwarg_fields", None) + self._result_field = self._config.get("result_field", None) + + @property + def name(self): + return self._name + + @property + def config(self): + return self._config + + @property + def blocks(self) -> List: + """Returns the constituent blocks associated with this class.""" + return self._blocks + + @abc.abstractmethod + def __call__( + self, + inputs: Union[List[Dict], pd.DataFrame, Dataset], + *args: Any, + arg_fields: Optional[List[str]] = None, + kwarg_fields: Optional[List[str]] = None, + result_field: Optional[str] = None, + **kwargs: Any, + ) -> None: + pass diff --git a/fms_sdg/base/databuilder.py b/fms_sdg/base/databuilder.py index 8519b770..6fe8c233 100644 --- a/fms_sdg/base/databuilder.py +++ b/fms_sdg/base/databuilder.py @@ -10,11 +10,10 @@ from tqdm import tqdm # Local -from fms_sdg.base.generator import BaseGenerator -from fms_sdg.base.registry import get_generator, get_validator +from fms_sdg.base.block import BaseBlock +from fms_sdg.base.registry import get_block from fms_sdg.base.task import SdgData, SdgTask -from fms_sdg.base.validator import BaseValidator -from fms_sdg.generators.llm import CachingLM, LMGenerator +from fms_sdg.blocks.generators.llm import CachingLM, LMGeneratorBlock from fms_sdg.utils import all_annotations, sdg_logger @@ -22,19 +21,14 @@ class DataBuilderConfig(dict): # data builder naming/registry name: Optional[str] = None - generators: Optional[Union[str, list]] = None - validators: Optional[Union[str, list]] = None + blocks: Optional[dict] = None generation_kwargs: Optional[dict] = None metadata: Optional[ dict ] = None # by default, not used in the code. allows for users to pass arbitrary info to data builders def __post_init__(self) -> None: - if self.generation_kwargs is not None: - if "temperature" in self.generation_kwargs: - self.generation_kwargs["temperature"] = float( - self.generation_kwargs["temperature"] - ) + pass TYPE_KEY = "type" @@ -66,7 +60,7 @@ def __init__( self._restart_generation = restart_generation # initializing generators / validators - self._init_gv(lm_cache=lm_cache) + self._init_blocks(lm_cache=lm_cache) # TODO: Data loader goes here self._tasks: List[SdgTask] = [ @@ -91,68 +85,47 @@ def config(self) -> DataBuilderConfig: return self._config @property - def generators(self) -> List[BaseGenerator]: - """Returns the generators associated with this class.""" - return self._generators + def blocks(self) -> List[BaseBlock]: + """Returns the blocks associated with this class.""" + return self._blocks - @property - def validators(self) -> List[BaseValidator]: - """Returns the validators associated with this class.""" - return self._validators - - def _init_gv(self, lm_cache: str = None): - _generators = ( - [self.config.generators] - if type(self.config.generators) == str - else self.config.generators - ) - _validators = ( - [self.config.validators] - if type(self.config.validators) == str - else self.config.validators - ) - self._generators: List[BaseGenerator] = [] - self._validators: List[BaseValidator] = [] - - # TODO: need to handle nested generators / validators - for i, info_src in enumerate([_generators, _validators]): - # user may not define a generator / validator - if info_src is not None: - for obj_name, obj_config in info_src.items(): - sdg_logger.debug( - "Initializing object %s with config %s", obj_name, obj_config - ) - obj = (get_generator if i == 0 else get_validator)( - obj_config[TYPE_KEY] - )(obj_name, obj_config) - - if lm_cache is not None and isinstance(obj, LMGenerator): - sdg_logger.info( - "Using cache at %s", - lm_cache + "_rank" + str(obj.rank) + ".db", - ) - obj = CachingLM( - obj, - lm_cache - # each rank receives a different cache db. - # necessary to avoid multiple writes to cache at once - + f"_model{os.path.split(obj.model_id_or_path)[-1]}_rank{obj.rank}.db", - ) - - type_annotations = all_annotations(type(self)) - assert ( - obj_name in type_annotations - ), f"Object {obj_name} is missing from definition of DataBuilder {self.__class__}" - - obj_type = type_annotations[obj_name] - - # double check types - assert isinstance(obj, obj_type) or ( - isinstance(obj, CachingLM) and isinstance(obj.lm, obj_type) - ), f"Type of retrieved object {obj.__class__} for {obj_name} does not match type {obj_type} specified in DataBuilder {self.__class__}" - - setattr(self, obj_name, obj) - (self._generators if i == 0 else self._validators).append(obj) + def _init_blocks(self, lm_cache: str = None): + self._blocks: List[BaseBlock] = [] + + # TODO: need to handle nested blocks + for obj_name, obj_config in self.config.blocks.items(): + sdg_logger.debug( + "Initializing object %s with config %s", obj_name, obj_config + ) + obj = get_block(obj_config[TYPE_KEY])(obj_name, obj_config) + + if lm_cache is not None and isinstance(obj, LMGeneratorBlock): + sdg_logger.info( + "Using cache at %s", + lm_cache + "_rank" + str(obj.rank) + ".db", + ) + obj = CachingLM( + obj, + lm_cache + # each rank receives a different cache db. + # necessary to avoid multiple writes to cache at once + + f"_model{os.path.split(obj.model_id_or_path)[-1]}_rank{obj.rank}.db", + ) + + type_annotations = all_annotations(type(self)) + assert ( + obj_name in type_annotations + ), f"Object {obj_name} is missing from definition of DataBuilder {self.__class__}" + + obj_type = type_annotations[obj_name] + + # double check types + assert isinstance(obj, obj_type) or ( + isinstance(obj, CachingLM) and isinstance(obj.lm, obj_type) + ), f"Type of retrieved object {obj.__class__} for {obj_name} does not match type {obj_type} specified in DataBuilder {self.__class__}" + + setattr(self, obj_name, obj) + self._blocks.append(obj) def execute_tasks(self): # main entry point to task execution diff --git a/fms_sdg/base/generator.py b/fms_sdg/base/generator.py deleted file mode 100644 index 7739fc30..00000000 --- a/fms_sdg/base/generator.py +++ /dev/null @@ -1,35 +0,0 @@ -# Standard -from abc import ABC -from typing import Any, Dict, List, Union -import abc - - -class BaseGenerator(ABC): - """Base Class for all Generators""" - - def __init__(self, name: str, config: Dict, **kwargs: Any) -> None: - self._name = name - self._config: Dict = config - self._generators: List[BaseGenerator] = [] - - # overwrite config fields with kwargs (usually these will be command line args) - self._config.update(kwargs) - - @property - def name(self): - return self._name - - @property - def config(self): - return self._config - - @property - def generators(self) -> List: - """Returns the generators associated with this class.""" - return self._generators - - @abc.abstractmethod - def generate_batch( - self, *args: Union[List, Dict], **kwargs: Union[str, Dict] - ) -> None: - raise NotImplementedError diff --git a/fms_sdg/base/registry.py b/fms_sdg/base/registry.py index 81b6f292..157b4219 100644 --- a/fms_sdg/base/registry.py +++ b/fms_sdg/base/registry.py @@ -3,76 +3,43 @@ import logging # Local -from fms_sdg.base.generator import BaseGenerator +from fms_sdg.base.block import BaseBlock from fms_sdg.base.resource import BaseResource -from fms_sdg.base.validator import BaseValidator eval_logger = logging.getLogger("fms_sdg") # TODO: generator registry, validator registry, task registry -GENERATOR_REGISTRY = {} +BLOCK_REGISTRY = {} -def register_generator(*names): +def register_block(*names): # either pass a list or a single alias. # function receives them as a tuple of strings def decorate(cls): for name in names: assert issubclass( - cls, BaseGenerator - ), f"Generator '{name}' ({cls.__name__}) must extend BaseGenerator class" + cls, BaseBlock + ), f"Block '{name}' ({cls.__name__}) must extend BaseBlock class" assert ( - name not in GENERATOR_REGISTRY - ), f"Generator named '{name}' conflicts with existing generator! Please register with a non-conflicting alias instead." + name not in BLOCK_REGISTRY + ), f"Block named '{name}' conflicts with existing block! Please register with a non-conflicting alias instead." - GENERATOR_REGISTRY[name] = cls + BLOCK_REGISTRY[name] = cls return cls return decorate -def get_generator(model_name): +def get_block(block_name): try: - return GENERATOR_REGISTRY[model_name] + return BLOCK_REGISTRY[block_name] except KeyError: raise ValueError( - f"Attempted to load generator '{model_name}', but no generator for this name found! Supported generator names: {', '.join(GENERATOR_REGISTRY.keys())}" - ) - - -VALIDATOR_REGISTRY = {} - - -def register_validator(*names): - # either pass a list or a single alias. - # function receives them as a tuple of strings - - def decorate(cls): - for name in names: - assert issubclass( - cls, BaseValidator - ), f"Validator '{name}' ({cls.__name__}) must extend BaseValidator class" - - assert ( - name not in VALIDATOR_REGISTRY - ), f"Validator named '{name}' conflicts with existing validator! Please register with a non-conflicting alias instead." - - VALIDATOR_REGISTRY[name] = cls - return cls - - return decorate - - -def get_validator(model_name): - try: - return VALIDATOR_REGISTRY[model_name] - except KeyError: - raise ValueError( - f"Attempted to load validator '{model_name}', but no validator for this name found! Supported validator names: {', '.join(VALIDATOR_REGISTRY.keys())}" + f"Attempted to load block '{block_name}', but no block for this name found! Supported block names: {', '.join(BLOCK_REGISTRY.keys())}" ) diff --git a/fms_sdg/base/validator.py b/fms_sdg/base/validator.py index 6b4fad23..1864d3f5 100644 --- a/fms_sdg/base/validator.py +++ b/fms_sdg/base/validator.py @@ -5,17 +5,18 @@ import collections # Local -from fms_sdg.base.generator import BaseGenerator +from fms_sdg.base.block import BaseBlock +from fms_sdg.base.generator import BaseGeneratorBlock from fms_sdg.base.instance import Instance -class BaseValidator(ABC): +class BaseValidatorBlock(BaseBlock): """Base Class for all Validators""" def __init__(self, name: str, config: Dict) -> None: self._name = name self._config = config - self._generators: List[BaseGenerator] = [] + self._generators: List[BaseGeneratorBlock] = [] self._validators: List[BaseValidator] = [] @property diff --git a/fms_sdg/generators/README.md b/fms_sdg/blocks/generators/README.md similarity index 100% rename from fms_sdg/generators/README.md rename to fms_sdg/blocks/generators/README.md diff --git a/fms_sdg/generators/__init__.py b/fms_sdg/blocks/generators/__init__.py similarity index 100% rename from fms_sdg/generators/__init__.py rename to fms_sdg/blocks/generators/__init__.py diff --git a/fms_sdg/generators/genai.py b/fms_sdg/blocks/generators/genai.py similarity index 96% rename from fms_sdg/generators/genai.py rename to fms_sdg/blocks/generators/genai.py index faf74a0e..3c3ba8b7 100644 --- a/fms_sdg/generators/genai.py +++ b/fms_sdg/blocks/generators/genai.py @@ -1,5 +1,4 @@ # Standard -from collections import defaultdict from typing import Any, Dict, List import copy import os @@ -9,10 +8,10 @@ # Local from fms_sdg.base.instance import Instance -from fms_sdg.base.registry import get_resource, register_generator -from fms_sdg.generators.llm import LMGenerator +from fms_sdg.base.registry import get_resource, register_block +from fms_sdg.blocks.generators.llm import LMGeneratorBlock from fms_sdg.resources.genai import GenAIKeyResource -import fms_sdg.generators.utils as generator_utils +import fms_sdg.blocks.generators.utils as generator_utils import fms_sdg.utils as utils try: @@ -29,8 +28,8 @@ pass -@register_generator("genai") -class GenAIGenerator(LMGenerator): +@register_block("genai") +class GenAIGeneratorBlock(LMGeneratorBlock): """GenAI Generator""" def __init__(self, name: str, config: Dict, **kwargs: Any): diff --git a/fms_sdg/generators/llm.py b/fms_sdg/blocks/generators/llm.py similarity index 92% rename from fms_sdg/generators/llm.py rename to fms_sdg/blocks/generators/llm.py index 4be899f3..4a9418db 100644 --- a/fms_sdg/generators/llm.py +++ b/fms_sdg/blocks/generators/llm.py @@ -11,7 +11,7 @@ """ # Standard -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import abc import copy import hashlib @@ -19,20 +19,22 @@ import os # Third Party +from datasets import Dataset from genai.schema import TextGenerationParameters from sqlitedict import SqliteDict from tqdm import tqdm +import pandas as pd import transformers # Local -from fms_sdg.base.generator import BaseGenerator +from fms_sdg.base.block import BaseBlock from fms_sdg.base.instance import Instance from fms_sdg.utils import sdg_logger MODEL_ID_OR_PATH = "model_id_or_path" -class LMGenerator(BaseGenerator): +class LMGeneratorBlock(BaseBlock): """Class for LLM Generators""" def __init__(self, name: str, config: Dict, **kwargs: Any): @@ -136,6 +138,19 @@ def generate_batch( def set_cache_hook(self, cache_hook) -> None: self.cache_hook = cache_hook + def __call__( + self, + inputs: Union[List[Dict], pd.DataFrame, Dataset], + *args: Any, + arg_fields: Optional[List[str]] = None, + kwarg_fields: Optional[List[str]] = None, + result_field: Optional[str] = None, + method: str = None, + ) -> None: + assert method in ["generate", "loglikelihood"] + if method == "generate": + self.generate_batch(inputs) + ### SQLite-based caching of LM responses def hash_args(attr, request): @@ -159,7 +174,7 @@ def add_partial(self, attr, req, res) -> None: class CachingLM: - def __init__(self, lm: LMGenerator, cache_db) -> None: + def __init__(self, lm: LMGeneratorBlock, cache_db) -> None: """LM wrapper that returns cached results if they exist, and uses the underlying LM if not. :param lm: LM diff --git a/fms_sdg/generators/openai.py b/fms_sdg/blocks/generators/openai.py similarity index 95% rename from fms_sdg/generators/openai.py rename to fms_sdg/blocks/generators/openai.py index b1c6fb94..8c6cf668 100644 --- a/fms_sdg/generators/openai.py +++ b/fms_sdg/blocks/generators/openai.py @@ -11,7 +11,6 @@ """ # Standard -from collections import defaultdict from importlib.util import find_spec from typing import Any, Dict, List import copy @@ -21,10 +20,10 @@ # Local from fms_sdg.base.instance import Instance -from fms_sdg.base.registry import get_resource, register_generator -from fms_sdg.generators.llm import LMGenerator +from fms_sdg.base.registry import get_resource, register_block +from fms_sdg.blocks.generators.llm import LMGeneratorBlock from fms_sdg.resources.openai import OpenAIKeyResource -import fms_sdg.generators.utils as generator_utils +import fms_sdg.blocks.generators.utils as generator_utils import fms_sdg.utils as utils try: @@ -68,8 +67,8 @@ def completion(): return completion() -@register_generator("openai-chat", "local-chat-completions") -class OpenaiChatCompletionsLM(LMGenerator): +@register_block("openai-chat", "local-chat-completions") +class OpenaiChatCompletionsLMBlock(LMGeneratorBlock): def __init__(self, name: str, config: Dict, **kwargs: Any) -> None: """ diff --git a/fms_sdg/generators/utils.py b/fms_sdg/blocks/generators/utils.py similarity index 100% rename from fms_sdg/generators/utils.py rename to fms_sdg/blocks/generators/utils.py diff --git a/fms_sdg/generators/vllm.py b/fms_sdg/blocks/generators/vllm.py similarity index 98% rename from fms_sdg/generators/vllm.py rename to fms_sdg/blocks/generators/vllm.py index ec291192..18e1456a 100644 --- a/fms_sdg/generators/vllm.py +++ b/fms_sdg/blocks/generators/vllm.py @@ -23,11 +23,11 @@ # Local from fms_sdg.base.instance import Instance -from fms_sdg.base.registry import register_generator -from fms_sdg.generators.llm import LMGenerator -from fms_sdg.generators.utils import Collator, undistribute +from fms_sdg.base.registry import register_block +from fms_sdg.blocks.generators.llm import LMGeneratorBlock +from fms_sdg.blocks.generators.utils import Collator, undistribute from fms_sdg.utils import sdg_logger -import fms_sdg.generators.utils as generator_utils +import fms_sdg.blocks.generators.utils as generator_utils try: # Third Party @@ -41,8 +41,8 @@ # TODO: this can be made more efficient for our purposes by rewriting the async code ourselves -@register_generator("vllm") -class vLLMGenerator(LMGenerator): +@register_block("vllm") +class vLLMGeneratorBlock(LMGeneratorBlock): """vLLM Generator""" _DEFAULT_MAX_LENGTH = 2048 diff --git a/fms_sdg/validators/api.py b/fms_sdg/blocks/validators/api.py similarity index 95% rename from fms_sdg/validators/api.py rename to fms_sdg/blocks/validators/api.py index 2a5991a3..329ac173 100644 --- a/fms_sdg/validators/api.py +++ b/fms_sdg/blocks/validators/api.py @@ -3,9 +3,9 @@ import json # Local +from fms_sdg.base.block import BaseBlock from fms_sdg.base.instance import Instance -from fms_sdg.base.registry import register_validator -from fms_sdg.base.validator import BaseValidator +from fms_sdg.base.registry import register_block # Constants @@ -19,8 +19,8 @@ # Classes -@register_validator("api_function_checking") -class APIGenSpecValidator(BaseValidator): +@register_block("api_function_checking") +class APIGenSpecValidator(BaseBlock): """Class for API Sequence Prediction Validator""" def validate_batch(self, inputs: List[Instance], **kwargs: Any) -> None: @@ -149,7 +149,7 @@ def is_nested_match(arg_content: str, prev_components: List[Dict], api_info: Dic return False -@register_validator("api_yes_no") +@register_block("api_yes_no") class ApiGenSpecYesNoValidation(APIGenSpecValidator): """Class for API Intent Detection Validator""" diff --git a/fms_sdg/validators/lm_judge.py b/fms_sdg/blocks/validators/lm_judge.py similarity index 61% rename from fms_sdg/validators/lm_judge.py rename to fms_sdg/blocks/validators/lm_judge.py index d178b270..df19fcae 100644 --- a/fms_sdg/validators/lm_judge.py +++ b/fms_sdg/blocks/validators/lm_judge.py @@ -3,21 +3,23 @@ # Local from fms_sdg.base.instance import Instance -from fms_sdg.base.registry import get_generator, register_validator -from fms_sdg.base.validator import BaseValidator -from fms_sdg.generators.llm import LMGenerator +from fms_sdg.base.registry import get_block, register_block +from fms_sdg.base.validator import BaseBlock +from fms_sdg.blocks.generators.llm import LMGeneratorBlock TYPE_KEY = "lm_type" -@register_validator("llm_judge") -class LMJudgeValidator(BaseValidator): +@register_block("llm_judge") +class LMJudgeValidator(BaseBlock): """LLM-based Validator""" def __init__(self, name: str, config: Dict, **kwargs: Any): super().__init__(name, config, **kwargs) - self._llm_generator: LMGenerator = get_generator(config[TYPE_KEY])(name, config) - self._generators.append(self._llm_generator) + self._llm_generator: LMGeneratorBlock = get_block(config[TYPE_KEY])( + name, config + ) + self._blocks.append(self._llm_generator) def validate_batch(self, inputs: List[Instance], **kwargs: Any) -> None: generator_inputs = [Instance([x.args[0]], x.kwargs) for x in inputs] diff --git a/fms_sdg/validators/nl2sql/sql_execution_validator.py b/fms_sdg/blocks/validators/nl2sql/sql_execution_validator.py similarity index 91% rename from fms_sdg/validators/nl2sql/sql_execution_validator.py rename to fms_sdg/blocks/validators/nl2sql/sql_execution_validator.py index d2717bfe..ef4ff785 100644 --- a/fms_sdg/validators/nl2sql/sql_execution_validator.py +++ b/fms_sdg/blocks/validators/nl2sql/sql_execution_validator.py @@ -7,16 +7,16 @@ import sqlglot # Local +from fms_sdg.base.block import BaseBlock from fms_sdg.base.instance import Instance -from fms_sdg.base.registry import register_validator -from fms_sdg.base.validator import BaseValidator +from fms_sdg.base.registry import register_block logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) -@register_validator("sql_execution_validator") -class SQLExecutionValidator(BaseValidator): +@register_block("sql_execution_validator") +class SQLExecutionValidator(BaseBlock): """SQL execution validator.""" def validate_batch(self, inputs: List[Instance], **kwargs: Any) -> None: diff --git a/fms_sdg/validators/nl2sql/sql_syntax_validator.py b/fms_sdg/blocks/validators/nl2sql/sql_syntax_validator.py similarity index 87% rename from fms_sdg/validators/nl2sql/sql_syntax_validator.py rename to fms_sdg/blocks/validators/nl2sql/sql_syntax_validator.py index 830cfa1e..37cd92be 100644 --- a/fms_sdg/validators/nl2sql/sql_syntax_validator.py +++ b/fms_sdg/blocks/validators/nl2sql/sql_syntax_validator.py @@ -6,16 +6,16 @@ import sqlglot # Local +from fms_sdg.base.block import BaseBlock from fms_sdg.base.instance import Instance -from fms_sdg.base.registry import register_validator -from fms_sdg.base.validator import BaseValidator +from fms_sdg.base.registry import register_block logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) -@register_validator("sql_syntax_validator") -class SQLSyntaxValidator(BaseValidator): +@register_block("sql_syntax_validator") +class SQLSyntaxValidator(BaseBlock): """SQL syntax validator.""" def validate_batch(self, inputs: List[Instance], **kwargs: Any) -> None: diff --git a/fms_sdg/validators/rouge.py b/fms_sdg/blocks/validators/rouge.py similarity index 92% rename from fms_sdg/validators/rouge.py rename to fms_sdg/blocks/validators/rouge.py index d5229655..6e9d262f 100644 --- a/fms_sdg/validators/rouge.py +++ b/fms_sdg/blocks/validators/rouge.py @@ -3,9 +3,9 @@ from typing import Any, Dict, List, Union # Local +from fms_sdg.base.block import BaseBlock from fms_sdg.base.instance import Instance -from fms_sdg.base.registry import register_validator -from fms_sdg.base.validator import BaseValidator +from fms_sdg.base.registry import register_block try: # Third Party @@ -14,8 +14,8 @@ pass -@register_validator("rouge_scorer") -class RougeValidator(BaseValidator): +@register_block("rouge_scorer") +class RougeValidator(BaseBlock): """Base Class for all Validators""" def __init__(self, name: str, config: Dict) -> None: diff --git a/fms_sdg/databuilders/api/api_function_checking.yaml b/fms_sdg/databuilders/api/api_function_checking.yaml index 647655a6..56da07a3 100644 --- a/fms_sdg/databuilders/api/api_function_checking.yaml +++ b/fms_sdg/databuilders/api/api_function_checking.yaml @@ -1,5 +1,5 @@ name: api_function_checking -generators: +blocks: llm1: type: genai decoding_method: sample @@ -7,7 +7,6 @@ generators: max_new_tokens: 1024 min_new_tokens: 1 model_id_or_path: mistralai/mixtral-8x7b-instruct-v01 -validators: val1: type: api_function_checking val2: diff --git a/fms_sdg/databuilders/api/api_yes_no_detection.yaml b/fms_sdg/databuilders/api/api_yes_no_detection.yaml index 1c81dd7a..b7fd2ec0 100644 --- a/fms_sdg/databuilders/api/api_yes_no_detection.yaml +++ b/fms_sdg/databuilders/api/api_yes_no_detection.yaml @@ -1,5 +1,5 @@ name: api_yes_no_detection -generators: +blocks: llm1: type: genai decoding_method: sample @@ -7,7 +7,6 @@ generators: max_new_tokens: 1024 min_new_tokens: 1 model_id_or_path: mistralai/mixtral-8x7b-instruct-v01 -validators: val1: type: api_yes_no val2: diff --git a/fms_sdg/databuilders/api/generate.py b/fms_sdg/databuilders/api/generate.py index 3e542235..1e02b24b 100644 --- a/fms_sdg/databuilders/api/generate.py +++ b/fms_sdg/databuilders/api/generate.py @@ -9,8 +9,8 @@ from fms_sdg.base.instance import Instance from fms_sdg.base.registry import register_data_builder from fms_sdg.base.task import group_data_by_task +from fms_sdg.blocks.generators.llm import LMGeneratorBlock from fms_sdg.databuilders.api.task import ApiSdgData, ApiSdgTask -from fms_sdg.generators.llm import LMGenerator from fms_sdg.utils import sdg_logger from fms_sdg.validators.api import APIGenSpecValidator, ApiGenSpecYesNoValidation from fms_sdg.validators.rouge import RougeValidator @@ -38,7 +38,7 @@ def __init__( ), "Number of prompt examples must be at least 1" # llm1 is the main generator that will produce the synthetic examples - llm1: LMGenerator + llm1: LMGeneratorBlock val1: APIGenSpecValidator val2: RougeValidator @@ -218,7 +218,7 @@ class ApiYesNoDataBuilder(ApiDataBuilder): """Class for API Sequence task""" # llm1 is the main generator that will produce the synthetic examples - llm1: LMGenerator + llm1: LMGeneratorBlock val1: ApiGenSpecYesNoValidation @@ -227,5 +227,5 @@ class ApiDetectionDataBuilder(ApiDataBuilder): """Class for API Sequence task""" # llm1 is the main generator that will produce the synthetic examples - llm1: LMGenerator + llm1: LMGeneratorBlock val1: APIGenSpecValidator diff --git a/fms_sdg/databuilders/nl2sql/generate.py b/fms_sdg/databuilders/nl2sql/generate.py index 8d784e6b..2b967955 100644 --- a/fms_sdg/databuilders/nl2sql/generate.py +++ b/fms_sdg/databuilders/nl2sql/generate.py @@ -8,6 +8,11 @@ from fms_sdg.base.instance import Instance from fms_sdg.base.registry import register_data_builder from fms_sdg.base.task import SdgTask +from fms_sdg.blocks.generators.llm import LMGeneratorBlock +from fms_sdg.blocks.validators.nl2sql.sql_execution_validator import ( + SQLExecutionValidator, +) +from fms_sdg.blocks.validators.nl2sql.sql_syntax_validator import SQLSyntaxValidator from fms_sdg.databuilders.nl2sql.sqlinstruct.models import ( SQLDataGenerationSchema, SQLTriplet, @@ -18,10 +23,7 @@ from fms_sdg.databuilders.nl2sql.sqlinstruct.prompts import PromptFactory from fms_sdg.databuilders.nl2sql.task import SqlSdgData, SqlSdgTask from fms_sdg.databuilders.simple.task import InstructLabSdgData -from fms_sdg.generators.llm import LMGenerator from fms_sdg.utils import sdg_logger -from fms_sdg.validators.nl2sql.sql_execution_validator import SQLExecutionValidator -from fms_sdg.validators.nl2sql.sql_syntax_validator import SQLSyntaxValidator @register_data_builder("nl2sql") @@ -38,7 +40,7 @@ def __init__( super().__init__(*args, **kwargs) # llm1 is a code generator for the synthetic examples - llm1: LMGenerator + llm1: LMGeneratorBlock # val1 is the validator which checks SQL syntax val1: SQLSyntaxValidator diff --git a/fms_sdg/databuilders/nl2sql/nl2sql.yaml b/fms_sdg/databuilders/nl2sql/nl2sql.yaml index 2ead1a91..411ed7d0 100644 --- a/fms_sdg/databuilders/nl2sql/nl2sql.yaml +++ b/fms_sdg/databuilders/nl2sql/nl2sql.yaml @@ -1,12 +1,11 @@ name: nl2sql -generators: +blocks: llm1: type: genai temperature: 0.0 max_new_tokens: 512 min_new_tokens: 1 model_id_or_path: ibm/granite-8b-code-instruct -validators: val1: type: sql_syntax_validator val2: diff --git a/fms_sdg/databuilders/simple/generate.py b/fms_sdg/databuilders/simple/generate.py index ed2ccafa..a9d9c82c 100644 --- a/fms_sdg/databuilders/simple/generate.py +++ b/fms_sdg/databuilders/simple/generate.py @@ -4,15 +4,18 @@ import random import time +# Third Party +import pandas as pd + # Local from fms_sdg.base.databuilder import DataBuilder from fms_sdg.base.instance import Instance from fms_sdg.base.registry import register_data_builder from fms_sdg.base.task import SdgTask, group_data_by_task +from fms_sdg.blocks.generators.llm import LMGeneratorBlock +from fms_sdg.blocks.validators.rouge import RougeValidator from fms_sdg.databuilders.simple.task import InstructLabSdgData, InstructLabSdgTask -from fms_sdg.generators.llm import LMGenerator from fms_sdg.utils import sdg_logger -from fms_sdg.validators.rouge import RougeValidator import fms_sdg.databuilders.simple.utils as utils @@ -23,7 +26,7 @@ class SimpleInstructDataBuilder(DataBuilder): TASK_TYPE: SdgTask = InstructLabSdgTask # llm1 is the main generator that will produce the synthetic examples - llm1: LMGenerator + llm1: LMGeneratorBlock # val1 is the validator which checks rouge score val1: RougeValidator @@ -55,6 +58,9 @@ def __call__( instruction_data: List[InstructLabSdgData], ) -> List[InstructLabSdgData]: + print(pd.DataFrame([instruction_data[0]]).to_markdown()) + input("--") + inputs: List[Instance] = [] instruction_data = instruction_data + [] random.shuffle(instruction_data) @@ -66,6 +72,12 @@ def __call__( prompt = self._encode_prompt(prompt_instructions) args = [prompt] kwargs = {"stop_sequences": [f"* Task {len(prompt_instructions)+2}"]} + print( + pd.DataFrame( + [Instance(args, kwargs, data=prompt_instructions)] + ).to_markdown() + ) + input("--") inputs.append(Instance(args, kwargs, data=prompt_instructions)) request_start = time.time() diff --git a/fms_sdg/databuilders/simple/simple.yaml b/fms_sdg/databuilders/simple/simple.yaml index a3f6e963..477b05e3 100644 --- a/fms_sdg/databuilders/simple/simple.yaml +++ b/fms_sdg/databuilders/simple/simple.yaml @@ -1,12 +1,11 @@ name: simple -generators: +blocks: llm1: type: genai temperature: 0.0 max_new_tokens: 512 min_new_tokens: 1 model_id_or_path: mistralai/mixtral-8x7b-instruct-v01 -validators: val1: type: rouge_scorer threshold: 1.0 diff --git a/fms_sdg/validators/genai.py b/fms_sdg/validators/genai.py deleted file mode 100644 index 7fe79be8..00000000 --- a/fms_sdg/validators/genai.py +++ /dev/null @@ -1,19 +0,0 @@ -# Standard -from typing import Any, Dict, List - -# Local -from fms_sdg.base.instance import Instance -from fms_sdg.base.registry import register_validator -from fms_sdg.base.validator import BaseValidator -from fms_sdg.generators.genai import GenAIGenerator - - -@register_validator("genai") -class GenAIValidator(GenAIGenerator, BaseValidator): - """GenAI Validator""" - - def __init__(self, name: str, config: Dict, **kwargs: Any): - super().__init__(name, config, **kwargs) - - def validate_batch(self, inputs: List[Instance], **kwargs: Any) -> None: - pass diff --git a/pyproject.toml b/pyproject.toml index 23d9b915..da84e0d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ readme = "README.md" classifiers = [ "License :: OSI Approved :: Apache License", ] -requires-python = ">=3.8" +requires-python = ">=3.9" dependencies = [ "accelerate>=0.26.0", "datasets>=2.16.0", @@ -25,6 +25,7 @@ dependencies = [ "torch>=2.3", "tqdm-multiprocess", "transformers>=4.1", + "tabulate>=0.9", "zstandard", "dill", "word2number", diff --git a/templates/generator/template.py b/templates/generator/template.py index 6d84e4d2..db633be3 100644 --- a/templates/generator/template.py +++ b/templates/generator/template.py @@ -2,13 +2,13 @@ from typing import Any, Dict, List # Local -from fms_sdg.base.generator import BaseGenerator +from fms_sdg.base.generator import BaseGeneratorBlock from fms_sdg.base.instance import Instance from fms_sdg.base.registry import register_generator @register_generator("template_generator") -class TemplateGenerator(BaseGenerator): +class TemplateGenerator(BaseGeneratorBlock): """Base Class for all Generators""" def __init__(self, name: str, config: Dict, **kwargs: Any) -> None: From 7b2d811cc4a0e91a4fe75a1baf87c9ce0a10ebce Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 3 Jul 2024 10:59:37 -0500 Subject: [PATCH 02/41] update all pub databuilders --- fms_sdg/base/block.py | 69 +++++++++++++ fms_sdg/blocks/generators/genai.py | 24 ++++- fms_sdg/blocks/generators/llm.py | 52 +++++++++- fms_sdg/blocks/generators/openai.py | 24 ++++- fms_sdg/blocks/generators/vllm.py | 22 +++++ fms_sdg/blocks/validators/api.py | 32 +++++-- fms_sdg/blocks/validators/lm_judge.py | 4 +- .../nl2sql/sql_execution_validator.py | 34 ++++--- .../validators/nl2sql/sql_syntax_validator.py | 34 ++++--- fms_sdg/blocks/validators/rouge.py | 32 +++++-- .../api/api_function_checking.yaml | 2 + .../api/api_yes_no_detection.yaml | 2 + fms_sdg/databuilders/api/generate.py | 82 ++++++++++------ fms_sdg/databuilders/nl2sql/generate.py | 82 +++++++++------- fms_sdg/databuilders/nl2sql/nl2sql.yaml | 2 + .../nl2sql/sqlinstruct/pipeline.py | 17 ++-- fms_sdg/databuilders/simple/generate.py | 56 ++++++----- fms_sdg/databuilders/simple/simple.yaml | 1 + tests/{ => blocks}/generators/__init__.py | 0 tests/{ => blocks}/generators/test_llm.py | 96 ++++++++++--------- tests/{ => blocks}/validators/__init__.py | 0 tests/{ => blocks}/validators/test_api.py | 0 .../{ => blocks}/validators/test_lm_judge.py | 0 tests/{ => blocks}/validators/test_nl2sql.py | 0 tests/{ => blocks}/validators/test_rouge.py | 0 25 files changed, 476 insertions(+), 191 deletions(-) rename tests/{ => blocks}/generators/__init__.py (100%) rename tests/{ => blocks}/generators/test_llm.py (57%) rename tests/{ => blocks}/validators/__init__.py (100%) rename tests/{ => blocks}/validators/test_api.py (100%) rename tests/{ => blocks}/validators/test_lm_judge.py (100%) rename tests/{ => blocks}/validators/test_nl2sql.py (100%) rename tests/{ => blocks}/validators/test_rouge.py (100%) diff --git a/fms_sdg/base/block.py b/fms_sdg/base/block.py index e395c994..0bc6a6b0 100644 --- a/fms_sdg/base/block.py +++ b/fms_sdg/base/block.py @@ -36,6 +36,44 @@ def blocks(self) -> List: """Returns the constituent blocks associated with this class.""" return self._blocks + def get_args_kwargs( + self, + inp: Union[Dict, pd.DataFrame, Dataset], + arg_fields: Optional[List[str]] = None, + kwarg_fields: Optional[List[str]] = None, + ): + arg_fields = arg_fields if arg_fields is not None else self._arg_fields + kwarg_fields = kwarg_fields if kwarg_fields is not None else self._kwarg_fields + + if arg_fields is None: + arg_fields = [] + if kwarg_fields is None: + kwarg_fields = [] + + if type(inp) == dict: + args = [inp.get(arg) for arg in arg_fields] + kwargs = {kwarg: inp.get(kwarg) for kwarg in kwarg_fields} + elif type(inp) in [pd.DataFrame, Dataset]: + args = [inp.get(arg) for arg in arg_fields] + kwargs = {kwarg: inp.get(kwarg) for kwarg in kwarg_fields} + else: + raise ValueError(f"Unexpected input type: {type(inp)}") + + return args, kwargs + + def write_result( + self, inp: Union[Dict, pd.DataFrame, Dataset], res: Any, result_field: str + ): + result_field = result_field if result_field is not None else self._result_field + assert result_field is not None, "Result field cannot be None!" + + if type(inp) == dict: + inp[result_field] = res + elif type(inp) in [pd.DataFrame, Dataset]: + inp[result_field] = res + else: + raise ValueError(f"Unexpected input type: {type(inp)}") + @abc.abstractmethod def __call__( self, @@ -47,3 +85,34 @@ def __call__( **kwargs: Any, ) -> None: pass + + +class BaseGeneratorBlock(BaseBlock): + pass + + +class BaseValidatorBlock(BaseBlock): + def __init__(self, name: str, config: Dict, **kwargs: Any) -> None: + super().__init__(name, config, **kwargs) + self._filter_invalids = config.get("filter", False) + + def __call__( + self, + inputs: Union[List[Dict], pd.DataFrame, Dataset], + *args: Any, + arg_fields: Optional[List[str]] = None, + kwarg_fields: Optional[List[str]] = None, + result_field: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + outputs = [] + for x in inputs: + inp_args, inp_kwargs = self.get_args_kwargs(x, arg_fields, kwarg_fields) + res = self._validate(*inp_args, **inp_kwargs) + if res or not self._filter_invalids: + self.write_result(x, res, result_field) + outputs.append(x) + return outputs + + def _validate(self, *args: Any, **kwargs: Any) -> bool: + raise NotImplementedError diff --git a/fms_sdg/blocks/generators/genai.py b/fms_sdg/blocks/generators/genai.py index 7302c10c..671da0f5 100644 --- a/fms_sdg/blocks/generators/genai.py +++ b/fms_sdg/blocks/generators/genai.py @@ -1,9 +1,11 @@ # Standard -from typing import Any, Dict, List +from typing import Any, Dict, List, Union import copy import os # Third Party +from datasets import Dataset +from pandas import DataFrame from tqdm import tqdm # Local @@ -205,3 +207,23 @@ def loglikelihood_batch( pbar.update(1) pbar.close() + + def __call__( + self, + inputs: Union[List[Dict], DataFrame, Dataset], + *args: Any, + arg_fields: Union[List[str], None] = None, + kwarg_fields: Union[List[str], None] = None, + result_field: Union[str, None] = None, + method: str = "generate", + **kwargs: Any, + ) -> None: + return super().__call__( + inputs, + *args, + arg_fields=arg_fields, + kwarg_fields=kwarg_fields, + result_field=result_field, + method=method, + **kwargs, + ) diff --git a/fms_sdg/blocks/generators/llm.py b/fms_sdg/blocks/generators/llm.py index 10252d96..e87f7855 100644 --- a/fms_sdg/blocks/generators/llm.py +++ b/fms_sdg/blocks/generators/llm.py @@ -27,14 +27,14 @@ import transformers # Local -from fms_sdg.base.block import BaseBlock +from fms_sdg.base.block import BaseGeneratorBlock from fms_sdg.base.instance import Instance from fms_sdg.utils import sdg_logger MODEL_ID_OR_PATH = "model_id_or_path" -class LMGeneratorBlock(BaseBlock): +class LMGeneratorBlock(BaseGeneratorBlock): """Class for LLM Generators""" def __init__(self, name: str, config: Dict, **kwargs: Any): @@ -98,11 +98,40 @@ def __call__( arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, result_field: Optional[str] = None, - method: str = None, + method: str = "generate", + **kwargs: Any, ) -> None: - assert method in ["generate", "loglikelihood"] + + # simplify generation here + instances: List[Instance] = [] + for inp in inputs: + inp_args, inp_kwargs = self.get_args_kwargs(inp, arg_fields, kwarg_fields) + instances.append(Instance(args=inp_args, kwargs=inp_kwargs, data=inp)) + if method == "generate": - self.generate_batch(inputs) + self.generate_batch( + instances, + **kwargs, + ) + elif method == "loglikelihood": + self.loglikelihood_batch( + instances, + **kwargs, + ) + else: + err_str = ( + f"Unhandled method type: {method}" + if method is not None + else "Must set 'method' kwarg to 'generate' or 'loglikelihood'" + ) + raise ValueError(err_str) + + outputs = [] + for inst in instances: + self.write_result(inst.data, inst.result, result_field) + outputs.append(inst.data) + + return outputs ### SQLite-based caching of LM responses @@ -147,6 +176,8 @@ def __init__(self, lm: LMGeneratorBlock, cache_db) -> None: self.dbdict def __getattr__(self, attr): + print(attr) + input("--") lm_attr = getattr(self.lm, attr) if not callable(lm_attr): return lm_attr @@ -214,5 +245,16 @@ def fn(requests: List[Instance]): return fn + # def __call__( + # self, + # inputs: Union[List[Dict], pd.DataFrame, Dataset], + # *args: Any, + # arg_fields: Optional[List[str]] = None, + # kwarg_fields: Optional[List[str]] = None, + # result_field: Optional[str] = None, + # method: str = "generate", + # **kwargs: Any, + # ) -> None: + def get_cache_hook(self): return CacheHook(self) diff --git a/fms_sdg/blocks/generators/openai.py b/fms_sdg/blocks/generators/openai.py index f334364f..9dbdc9d8 100644 --- a/fms_sdg/blocks/generators/openai.py +++ b/fms_sdg/blocks/generators/openai.py @@ -12,10 +12,12 @@ # Standard from importlib.util import find_spec -from typing import Any, Dict, List +from typing import Any, Dict, List, Union import copy # Third Party +from datasets import Dataset +from pandas import DataFrame from tqdm import tqdm # Local @@ -179,3 +181,23 @@ def generate_batch( def loglikelihood_batch(self, requests, disable_tqdm: bool = False): raise NotImplementedError("No support for logits.") + + def __call__( + self, + inputs: Union[List[Dict], DataFrame, Dataset], + *args: Any, + arg_fields: Union[List[str], None] = None, + kwarg_fields: Union[List[str], None] = None, + result_field: Union[str, None] = None, + method: str = "generate", + **kwargs: Any, + ) -> None: + return super().__call__( + inputs, + *args, + arg_fields=arg_fields, + kwarg_fields=kwarg_fields, + result_field=result_field, + method=method, + **kwargs, + ) diff --git a/fms_sdg/blocks/generators/vllm.py b/fms_sdg/blocks/generators/vllm.py index 4894f388..31efb2d3 100644 --- a/fms_sdg/blocks/generators/vllm.py +++ b/fms_sdg/blocks/generators/vllm.py @@ -17,8 +17,10 @@ import copy # Third Party +from datasets import Dataset from more_itertools import distribute from packaging.version import parse as parse_version +from pandas import DataFrame from tqdm import tqdm # Local @@ -461,3 +463,23 @@ def modify_gen_kwargs(kwargs: dict) -> dict: "spaces_between_special_tokens", False ) return kwargs + + def __call__( + self, + inputs: Union[List[Dict], DataFrame, Dataset], + *args: Any, + arg_fields: Union[List[str], None] = None, + kwarg_fields: Union[List[str], None] = None, + result_field: Union[str, None] = None, + method: str = "generate", + **kwargs: Any, + ) -> None: + return super().__call__( + inputs, + *args, + arg_fields=arg_fields, + kwarg_fields=kwarg_fields, + result_field=result_field, + method=method, + **kwargs, + ) diff --git a/fms_sdg/blocks/validators/api.py b/fms_sdg/blocks/validators/api.py index 329ac173..0244d0f3 100644 --- a/fms_sdg/blocks/validators/api.py +++ b/fms_sdg/blocks/validators/api.py @@ -1,10 +1,13 @@ # Standard -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Union import json +# Third Party +from datasets import Dataset +from pandas import DataFrame + # Local -from fms_sdg.base.block import BaseBlock -from fms_sdg.base.instance import Instance +from fms_sdg.base.block import BaseValidatorBlock from fms_sdg.base.registry import register_block # Constants @@ -20,13 +23,26 @@ @register_block("api_function_checking") -class APIGenSpecValidator(BaseBlock): +class APIGenSpecValidator(BaseValidatorBlock): """Class for API Sequence Prediction Validator""" - def validate_batch(self, inputs: List[Instance], **kwargs: Any) -> None: - """Takes in a list of Instance objects (each containing their own arg / kwargs) and sets their result flag to true or false""" - for x in inputs: - x.result = self._validate(*x.args, **x.kwargs) + def __call__( + self, + inputs: Union[List[Dict], DataFrame, Dataset], + *args: Any, + arg_fields: Optional[List[str]] = None, + kwarg_fields: Optional[List[str]] = None, + result_field: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + return super().__call__( + inputs, + *args, + arg_fields=arg_fields, + kwarg_fields=kwarg_fields, + result_field=result_field, + **kwargs, + ) def _validate( self, diff --git a/fms_sdg/blocks/validators/lm_judge.py b/fms_sdg/blocks/validators/lm_judge.py index df19fcae..0a9b4e88 100644 --- a/fms_sdg/blocks/validators/lm_judge.py +++ b/fms_sdg/blocks/validators/lm_judge.py @@ -4,14 +4,14 @@ # Local from fms_sdg.base.instance import Instance from fms_sdg.base.registry import get_block, register_block -from fms_sdg.base.validator import BaseBlock +from fms_sdg.base.validator import BaseValidatorBlock from fms_sdg.blocks.generators.llm import LMGeneratorBlock TYPE_KEY = "lm_type" @register_block("llm_judge") -class LMJudgeValidator(BaseBlock): +class LMJudgeValidator(BaseValidatorBlock): """LLM-based Validator""" def __init__(self, name: str, config: Dict, **kwargs: Any): diff --git a/fms_sdg/blocks/validators/nl2sql/sql_execution_validator.py b/fms_sdg/blocks/validators/nl2sql/sql_execution_validator.py index ef4ff785..5a9ee1e6 100644 --- a/fms_sdg/blocks/validators/nl2sql/sql_execution_validator.py +++ b/fms_sdg/blocks/validators/nl2sql/sql_execution_validator.py @@ -1,14 +1,15 @@ # Standard -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Union import logging import sqlite3 # Third Party +from datasets import Dataset +from pandas import DataFrame import sqlglot # Local -from fms_sdg.base.block import BaseBlock -from fms_sdg.base.instance import Instance +from fms_sdg.base.block import BaseValidatorBlock from fms_sdg.base.registry import register_block logger = logging.getLogger(__name__) @@ -16,17 +17,26 @@ @register_block("sql_execution_validator") -class SQLExecutionValidator(BaseBlock): +class SQLExecutionValidator(BaseValidatorBlock): """SQL execution validator.""" - def validate_batch(self, inputs: List[Instance], **kwargs: Any) -> None: - """Validate a batch. - - Args: - inputs: list of instances. - """ - for x in inputs: - x.result = self._validate(*x.args, **x.kwargs) + def __call__( + self, + inputs: Union[List[Dict], DataFrame, Dataset], + *args: Any, + arg_fields: Optional[List[str]] = None, + kwarg_fields: Optional[List[str]] = None, + result_field: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + return super().__call__( + inputs, + *args, + arg_fields=arg_fields, + kwarg_fields=kwarg_fields, + result_field=result_field, + **kwargs, + ) def _validate(self, record: Dict[str, str], **kwargs: Any) -> bool: """Validate a record containing information on schema, query and utterance. diff --git a/fms_sdg/blocks/validators/nl2sql/sql_syntax_validator.py b/fms_sdg/blocks/validators/nl2sql/sql_syntax_validator.py index 37cd92be..a66c5c08 100644 --- a/fms_sdg/blocks/validators/nl2sql/sql_syntax_validator.py +++ b/fms_sdg/blocks/validators/nl2sql/sql_syntax_validator.py @@ -1,13 +1,14 @@ # Standard -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Union import logging # Third Party +from datasets import Dataset +from pandas import DataFrame import sqlglot # Local -from fms_sdg.base.block import BaseBlock -from fms_sdg.base.instance import Instance +from fms_sdg.base.block import BaseValidatorBlock from fms_sdg.base.registry import register_block logger = logging.getLogger(__name__) @@ -15,17 +16,26 @@ @register_block("sql_syntax_validator") -class SQLSyntaxValidator(BaseBlock): +class SQLSyntaxValidator(BaseValidatorBlock): """SQL syntax validator.""" - def validate_batch(self, inputs: List[Instance], **kwargs: Any) -> None: - """Validate a batch. - - Args: - inputs: list of instances. - """ - for x in inputs: - x.result = self._validate(*x.args, **x.kwargs) + def __call__( + self, + inputs: Union[List[Dict], DataFrame, Dataset], + *args: Any, + arg_fields: Optional[List[str]] = None, + kwarg_fields: Optional[List[str]] = None, + result_field: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + return super().__call__( + inputs, + *args, + arg_fields=arg_fields, + kwarg_fields=kwarg_fields, + result_field=result_field, + **kwargs, + ) def _validate( self, record: Dict[str, str], sql_dialect: str = "postgres", **kwargs: Any diff --git a/fms_sdg/blocks/validators/rouge.py b/fms_sdg/blocks/validators/rouge.py index 6e9d262f..e0cf70da 100644 --- a/fms_sdg/blocks/validators/rouge.py +++ b/fms_sdg/blocks/validators/rouge.py @@ -1,10 +1,13 @@ # Standard from functools import partial -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union + +# Third Party +from datasets import Dataset +from pandas import DataFrame # Local -from fms_sdg.base.block import BaseBlock -from fms_sdg.base.instance import Instance +from fms_sdg.base.block import BaseValidatorBlock from fms_sdg.base.registry import register_block try: @@ -15,7 +18,7 @@ @register_block("rouge_scorer") -class RougeValidator(BaseBlock): +class RougeValidator(BaseValidatorBlock): """Base Class for all Validators""" def __init__(self, name: str, config: Dict) -> None: @@ -41,10 +44,23 @@ def tokenize(self, inp: Union[List, str]): self._cache[inp] = self.scorer._tokenizer.tokenize(inp) return self._cache[inp] - def validate_batch(self, inputs: List[Instance], **kwargs: Any) -> None: - """Takes in a list of Instance objects (each containing their own arg / kwargs) and returns a list of tuples [[, instance0], [, instance1], ...]""" - for x in inputs: - x.result = self._validate(*x.args, **x.kwargs) + def __call__( + self, + inputs: Union[List[Dict], DataFrame, Dataset], + *args: Any, + arg_fields: Optional[List[str]] = None, + kwarg_fields: Optional[List[str]] = None, + result_field: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + return super().__call__( + inputs, + *args, + arg_fields=arg_fields, + kwarg_fields=kwarg_fields, + result_field=result_field, + **kwargs, + ) def _validate(self, new_tokens: List[int], check_tokens: List[List[int]]) -> bool: """Runs through all the validators if data list is None. Otherwise just runs through the validators specified for data in the List""" diff --git a/fms_sdg/databuilders/api/api_function_checking.yaml b/fms_sdg/databuilders/api/api_function_checking.yaml index 56da07a3..bcab7b8e 100644 --- a/fms_sdg/databuilders/api/api_function_checking.yaml +++ b/fms_sdg/databuilders/api/api_function_checking.yaml @@ -9,8 +9,10 @@ blocks: model_id_or_path: mistralai/mixtral-8x7b-instruct-v01 val1: type: api_function_checking + filter: true val2: type: rouge_scorer + filter: true threshold: 0.35 metadata: version: 1.0 diff --git a/fms_sdg/databuilders/api/api_yes_no_detection.yaml b/fms_sdg/databuilders/api/api_yes_no_detection.yaml index b7fd2ec0..cbbace1b 100644 --- a/fms_sdg/databuilders/api/api_yes_no_detection.yaml +++ b/fms_sdg/databuilders/api/api_yes_no_detection.yaml @@ -9,8 +9,10 @@ blocks: model_id_or_path: mistralai/mixtral-8x7b-instruct-v01 val1: type: api_yes_no + filter: true val2: type: rouge_scorer + filter: true threshold: 0.35 metadata: version: 1.0 diff --git a/fms_sdg/databuilders/api/generate.py b/fms_sdg/databuilders/api/generate.py index 1e02b24b..13bd56db 100644 --- a/fms_sdg/databuilders/api/generate.py +++ b/fms_sdg/databuilders/api/generate.py @@ -1,19 +1,17 @@ # Standard -from typing import Any, List, Optional, Tuple -import copy +from typing import Any, Dict, List, Optional import random import time # Local from fms_sdg.base.databuilder import DataBuilder -from fms_sdg.base.instance import Instance from fms_sdg.base.registry import register_data_builder from fms_sdg.base.task import group_data_by_task from fms_sdg.blocks.generators.llm import LMGeneratorBlock +from fms_sdg.blocks.validators.api import APIGenSpecValidator, ApiGenSpecYesNoValidation +from fms_sdg.blocks.validators.rouge import RougeValidator from fms_sdg.databuilders.api.task import ApiSdgData, ApiSdgTask from fms_sdg.utils import sdg_logger -from fms_sdg.validators.api import APIGenSpecValidator, ApiGenSpecYesNoValidation -from fms_sdg.validators.rouge import RougeValidator import fms_sdg.databuilders.api.utils as api_utils @@ -51,22 +49,26 @@ def __call__( # first generate new data instruction_data = instruction_data + [] random.shuffle(instruction_data) - gen_inputs: List[Instance] = [] + gen_inputs: List[Dict] = [] for task_data in group_data_by_task(instruction_data): for _ in range(self._num_base_examples): prompt, new_instr = self._construct_new_data(task_data) - args = [prompt] - kwargs = {"stop_sequences": [f"API:"]} - gen_inputs.append(Instance(args, kwargs, data=new_instr)) + inp = {"prompt": prompt, "stop_sequences": [f"API:"], "data": new_instr} + gen_inputs.append(inp) request_start = time.time() - self.llm1.generate_batch(gen_inputs) + llm_outputs = self.llm1( + gen_inputs, + arg_fields=["prompt"], + kwarg_fields=["stop_sequences"], + result_field="output", + ) request_duration = time.time() - request_start # now begin filtering generated data post_process_start = time.time() - outputs, wf_discarded = self._wf_filter_data(gen_inputs) + outputs, wf_discarded = self._wf_filter_data(llm_outputs) outputs, rouge_discarded = self._rouge_filter_data(outputs) @@ -82,13 +84,13 @@ def __call__( return outputs - def _wf_filter_data(self, data_to_filter: List[Instance]): + def _wf_filter_data(self, data_to_filter: List[Dict]): # Well-formedness filtering - val1_inputs: List[Instance] = [] + val1_inputs: List[Dict] = [] discarded = 0 for gen_inp in data_to_filter: - new_instr: ApiSdgData = gen_inp.data - components = gen_inp.result.split("A:") + new_instr: ApiSdgData = gen_inp["data"] + components = gen_inp["output"].split("A:") if len(components) == 2: question, answer = [x.strip() for x in components] new_instr.input = question @@ -101,8 +103,10 @@ def _wf_filter_data(self, data_to_filter: List[Instance]): } # grab schema from input - args = [new_apis, question, answer] - kwargs = { + inp = { + "new_apis": new_apis, + "question": question, + "answer": answer, "check_arg_question_overlap": new_instr.check_arg_question_overlap, "intent_only": new_instr.intent_only, "require_nested": new_instr.require_nested, @@ -116,17 +120,30 @@ def _wf_filter_data(self, data_to_filter: List[Instance]): if new_instr.single_function else len(new_instr.positive_functions) ), + "data": new_instr, } - val1_inputs.append(Instance(args, kwargs, data=new_instr)) + + val1_inputs.append(inp) else: discarded += 1 - self.val1.validate_batch(val1_inputs) - # filter invalid data - outputs: List[ApiSdgData] = [ - val1_input.data for val1_input in val1_inputs if val1_input.result + outputs = [ + output["data"] + for output in self.val1( + val1_inputs, + arg_fields=["new_apis", "question", "answer"], + kwarg_fields=[ + "check_arg_question_overlap", + "intent_only", + "require_nested", + "min_ct", + "max_ct", + ], + result_field="output", + ) ] + discarded += len(val1_inputs) - len(outputs) return outputs, discarded @@ -137,16 +154,27 @@ def _rouge_filter_data(self, data_to_filter: List[ApiSdgData]): [instr.input for instr in data_to_filter] ) - val2_inputs: List[Instance] = [] + val2_inputs: List[Dict] = [] for new_data in data_to_filter: # computing similarity with the pre-tokenized instructions new_instruction_tokens = self.val2.tokenize(new_data.input) - args = [new_instruction_tokens, all_instruction_tokens] - val2_inputs.append(Instance(args, data=new_data)) - self.val2.validate_batch(val2_inputs) + inp = { + "new_instruction_tokens": new_instruction_tokens, + "all_instruction_tokens": all_instruction_tokens, + "data": new_data, + } + val2_inputs.append(inp) # filter rouge failed data - outputs = [val2_input.data for val2_input in val2_inputs if val2_input.result] + outputs = [ + output["data"] + for output in self.val2( + val2_inputs, + arg_fields=["new_instruction_tokens", "all_instruction_tokens"], + result_field="output", + ) + ] + discarded = len(val2_inputs) - len(outputs) return outputs, discarded diff --git a/fms_sdg/databuilders/nl2sql/generate.py b/fms_sdg/databuilders/nl2sql/generate.py index 2b967955..95f116d7 100644 --- a/fms_sdg/databuilders/nl2sql/generate.py +++ b/fms_sdg/databuilders/nl2sql/generate.py @@ -1,7 +1,6 @@ # Standard from dataclasses import asdict from typing import Any, Iterable, List, Set, Tuple -import time # Local from fms_sdg.base.databuilder import DataBuilder @@ -74,20 +73,22 @@ def __call__( instances = prompting_pipeline.run( data_generation_schema=data_generation_schema ) - self.llm1.generate_batch(instances) + llm_outputs = self.llm1( + instances, arg_fields=["prompt"], result_field="output" + ) sdg_logger.info("Post-processing generated data...") # NOTE: we process outputs in form of a tuple: schema, utterance, query to easily drop duplicates processed_outputs: Set[Tuple[str, str, str]] = set() - for instance in instances: - text = instance.args[0] + instance.result + for instance in llm_outputs: + text = instance["prompt"] + instance["output"] for prompt_class in PromptFactory.prompts.values(): if prompt_class.is_compatible(text): entries = prompt_class.get_utterances_and_queries(text) for entry in entries: processed_outputs.add( ( - instance.data["schema"], + instance["data"]["schema"], entry["utterance"], entry["query"], ) @@ -95,44 +96,51 @@ def __call__( sdg_logger.info("Validating generated data...") instances_for_validation = [ - Instance( - kwargs={ - "record": { - "sql_schema": sql_schema, - "utterance": utterance, - "sql_query": sql_query, - }, - "sql_dialect": str(data_generation_schema.database_type.name), + { + "record": { + "sql_schema": sql_schema, + "utterance": utterance, + "sql_query": sql_query, }, - data=SQLTriplet( - schema=sql_schema, utterances=[utterance], queries=[sql_query] + "sql_dialect": str(data_generation_schema.database_type.name), + "data": SQLTriplet( + schema=sql_schema, + utterances=[utterance], + queries=[sql_query], ).to_instruction(), - ) + } for sql_schema, utterance, sql_query in processed_outputs ] - self.val1.validate_batch(inputs=instances_for_validation) - self.val2.validate_batch(inputs=instances_for_validation) + filtered_output = self.val1( + instances_for_validation, + kwarg_fields=["record", "sql_dialect"], + result_field="output", + ) + filtered_output = self.val2( + filtered_output, + kwarg_fields=["record", "sql_dialect"], + result_field="output", + ) sdg_logger.info("Converting to instructions...") - for instance in instances_for_validation: - # NOTE: we keep only valid ones - if instance.result: - # NOTE: convert the generated instructions to a format compatible with fms_sdg. - converted_instruction = InstructLabSdgData( - # NOTE: coming from the package configuration - task_name=instruction_data_item.taxonomy_path, - # NOTE: info coming from taxonomy - taxonomy_path=instruction_data_item.taxonomy_path, - task_description=instruction_data_item.task_description, - # NOTE: info coming from generated entries - instruction=instance.data["user"], - input="", - document=None, - output=instance.data["assistant"], - ) - outputs.append(converted_instruction) - else: - discarded += 1 + for instance in filtered_output: + # NOTE: convert the generated instructions to a format compatible with fms_sdg. + converted_instruction = InstructLabSdgData( + # NOTE: coming from the package configuration + task_name=instruction_data_item.taxonomy_path, + # NOTE: info coming from taxonomy + taxonomy_path=instruction_data_item.taxonomy_path, + task_description=instruction_data_item.task_description, + # NOTE: info coming from generated entries + instruction=instance["data"]["user"], + input="", + document=None, + output=instance["data"]["assistant"], + ) + outputs.append(converted_instruction) + + discarded += len(instances_for_validation) - len(filtered_output) + sdg_logger.info("Data generation completed.") return outputs diff --git a/fms_sdg/databuilders/nl2sql/nl2sql.yaml b/fms_sdg/databuilders/nl2sql/nl2sql.yaml index 411ed7d0..e29664e9 100644 --- a/fms_sdg/databuilders/nl2sql/nl2sql.yaml +++ b/fms_sdg/databuilders/nl2sql/nl2sql.yaml @@ -8,7 +8,9 @@ blocks: model_id_or_path: ibm/granite-8b-code-instruct val1: type: sql_syntax_validator + filter: true val2: type: sql_execution_validator + filter: true metadata: version: 1.0 diff --git a/fms_sdg/databuilders/nl2sql/sqlinstruct/pipeline.py b/fms_sdg/databuilders/nl2sql/sqlinstruct/pipeline.py index 9f21fcc7..0a845bcc 100644 --- a/fms_sdg/databuilders/nl2sql/sqlinstruct/pipeline.py +++ b/fms_sdg/databuilders/nl2sql/sqlinstruct/pipeline.py @@ -2,7 +2,7 @@ # Standard from copy import deepcopy -from typing import List +from typing import Dict, List # Local from .models import SQLDataGenerationSchema, SQLTriplet @@ -11,7 +11,6 @@ SchemaAndQueryToUtterancePrompt, SchemaToUtteranceAndQueryPrompt, ) -from fms_sdg.base.instance import Instance class SQLDataGenerationPromptingPipeline: @@ -19,7 +18,7 @@ def __init__(self) -> None: """Initialize SQLDataGenerationPromptingPipeline.""" self.prompt_factory = PromptFactory() - def run(self, data_generation_schema: SQLDataGenerationSchema) -> List[Instance]: + def run(self, data_generation_schema: SQLDataGenerationSchema) -> List[Dict]: """Run the data generation pipeline. Args: @@ -55,7 +54,7 @@ def run(self, data_generation_schema: SQLDataGenerationSchema) -> List[Instance] ) ) - instances: List[Instance] = [] + instances: List[Dict] = [] # NOTE: this is trivially parallel for prompt_method_name in self.prompt_factory.prompts.keys(): prompt_object = self.prompt_factory.build( @@ -87,10 +86,12 @@ def run(self, data_generation_schema: SQLDataGenerationSchema) -> List[Instance] ) instances.extend( [ - Instance( - args=[prompt_object.encode_prompt(sql_triplet=sql_triplet)], - data=sql_triplet.model_dump(by_alias=True), - ) + { + "prompt": prompt_object.encode_prompt( + sql_triplet=sql_triplet + ), + "data": sql_triplet.model_dump(by_alias=True), + } for sql_triplet in sql_triplets ] ) diff --git a/fms_sdg/databuilders/simple/generate.py b/fms_sdg/databuilders/simple/generate.py index a9d9c82c..e7d4bb7f 100644 --- a/fms_sdg/databuilders/simple/generate.py +++ b/fms_sdg/databuilders/simple/generate.py @@ -1,5 +1,5 @@ # Standard -from typing import Any, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import copy import random import time @@ -9,7 +9,6 @@ # Local from fms_sdg.base.databuilder import DataBuilder -from fms_sdg.base.instance import Instance from fms_sdg.base.registry import register_data_builder from fms_sdg.base.task import SdgTask, group_data_by_task from fms_sdg.blocks.generators.llm import LMGeneratorBlock @@ -58,10 +57,7 @@ def __call__( instruction_data: List[InstructLabSdgData], ) -> List[InstructLabSdgData]: - print(pd.DataFrame([instruction_data[0]]).to_markdown()) - input("--") - - inputs: List[Instance] = [] + inputs: List[Dict] = [] instruction_data = instruction_data + [] random.shuffle(instruction_data) for grouped_data in group_data_by_task(instruction_data): @@ -70,27 +66,30 @@ def __call__( i : i + self._num_prompt_instructions ] prompt = self._encode_prompt(prompt_instructions) - args = [prompt] - kwargs = {"stop_sequences": [f"* Task {len(prompt_instructions)+2}"]} - print( - pd.DataFrame( - [Instance(args, kwargs, data=prompt_instructions)] - ).to_markdown() - ) - input("--") - inputs.append(Instance(args, kwargs, data=prompt_instructions)) + inp = { + "prompt": prompt, + "stop_sequences": [f"* Task {len(prompt_instructions)+2}"], + "data": prompt_instructions, + } + inputs.append(inp) request_start = time.time() - self.llm1.generate_batch(inputs) + + llm_outputs = self.llm1( + inputs, + arg_fields=["prompt"], + kwarg_fields=["stop_sequences"], + result_field="output", + ) request_duration = time.time() - request_start post_process_start = time.time() instruction_data = [] - for gen_inp in inputs: - prompt_instructions: List[InstructLabSdgData] = gen_inp.data + for gen_inp in llm_outputs: + prompt_instructions: List[InstructLabSdgData] = gen_inp["data"] new_instruction_dicts, discarded = utils.post_process_gpt3_response( len(prompt_instructions), - gen_inp.result, + gen_inp["output"], ) # make sure the generated instruction carried over extra fields for new_ins_dict, orig_ins in zip( @@ -116,18 +115,27 @@ def __call__( [instr.instruction for instr in instruction_data] ) - val1_inputs: List[Instance] = [] + val1_inputs: List[Dict] = [] for instruction_data_entry in instruction_data: # computing similarity with the pre-tokenized instructions new_instruction_tokens = self.val1.tokenize( instruction_data_entry.instruction ) - args = [new_instruction_tokens, all_instruction_tokens] - val1_inputs.append(Instance(args, data=instruction_data_entry)) - self.val1.validate_batch(val1_inputs) + inp = { + "new_toks": new_instruction_tokens, + "all_toks": all_instruction_tokens, + "data": instruction_data_entry, + } + val1_inputs.append(inp) # filter rouge failed data - outputs = [val1_input.data for val1_input in val1_inputs if val1_input.result] + outputs = [ + output["data"] + for output in self.val1( + val1_inputs, arg_fields=["new_toks", "all_toks"], result_field="output" + ) + ] + discarded += len(val1_inputs) - len(outputs) assess_duration = time.time() - assess_start diff --git a/fms_sdg/databuilders/simple/simple.yaml b/fms_sdg/databuilders/simple/simple.yaml index 477b05e3..7f5cc743 100644 --- a/fms_sdg/databuilders/simple/simple.yaml +++ b/fms_sdg/databuilders/simple/simple.yaml @@ -8,6 +8,7 @@ blocks: model_id_or_path: mistralai/mixtral-8x7b-instruct-v01 val1: type: rouge_scorer + filter: true threshold: 1.0 metadata: version: 1.0 diff --git a/tests/generators/__init__.py b/tests/blocks/generators/__init__.py similarity index 100% rename from tests/generators/__init__.py rename to tests/blocks/generators/__init__.py diff --git a/tests/generators/test_llm.py b/tests/blocks/generators/test_llm.py similarity index 57% rename from tests/generators/test_llm.py rename to tests/blocks/generators/test_llm.py index 15bfea99..364c9564 100644 --- a/tests/generators/test_llm.py +++ b/tests/blocks/generators/test_llm.py @@ -1,5 +1,5 @@ # Standard -from typing import List +from typing import Dict, List import copy import os import time @@ -8,7 +8,6 @@ import pytest # Local -from fms_sdg.base.instance import Instance from fms_sdg.base.registry import get_block from fms_sdg.blocks.generators.llm import CachingLM, LMGeneratorBlock @@ -55,20 +54,20 @@ def test_generate_batch(self, model_cfg): name=f"test_{model_cfg['type']}", config=model_cfg ) - inputs: List[Instance] = [] + inputs: List[Dict] = [] for prompt in PROMPTS: - args = [prompt] - inputs.append(Instance(args)) + inp = {"prompt": prompt} + inputs.append(inp) inputs_copy = copy.deepcopy(inputs) - lm.generate_batch(inputs) + lm(inputs, arg_fields=["prompt"], result_field="output") for i, inp in enumerate(inputs): assert ( - inp.args == inputs_copy[i].args + inp["prompt"] == inputs_copy[i]["prompt"] ), f"Input list has been rearranged at index {i}" - assert isinstance(inp.result, str) + assert isinstance(inp["output"], str) @pytest.mark.parametrize("model_cfg", [GREEDY_GENAI_CFG]) # , GREEDY_VLLM_CFG]) def test_loglikelihood_batch(self, model_cfg): @@ -76,48 +75,53 @@ def test_loglikelihood_batch(self, model_cfg): name=f"test_{model_cfg['type']}", config=model_cfg ) - inputs: List[Instance] = [] + inputs: List[Dict] = [] for prompt in PROMPTS: - args = [prompt, prompt] - inputs.append(Instance(args)) + inp = {"prompt1": prompt, "prompt2": prompt} + inputs.append(inp) inputs_copy = copy.deepcopy(inputs) - lm.loglikelihood_batch(inputs) + lm( + inputs, + arg_fields=["prompt1", "prompt2"], + result_field="output", + method="loglikelihood", + ) for i, inp in enumerate(inputs): assert ( - inp.args == inputs_copy[i].args + inp["prompt1"] == inputs_copy[i]["prompt1"] ), f"Input list has been rearranged at index {i}" - assert isinstance(inp.result, float) + assert isinstance(inp["output"], float) - def test_loglikelihood_batch_alignment(self): - vllm_config, genai_config = dict(GREEDY_VLLM_CFG), dict(GREEDY_GENAI_CFG) - vllm_config["model_id_or_path"] = "ibm-granite/granite-8b-code-instruct" - genai_config["model_id_or_path"] = "ibm/granite-8b-code-instruct" + # def test_loglikelihood_batch_alignment(self): + # vllm_config, genai_config = dict(GREEDY_VLLM_CFG), dict(GREEDY_GENAI_CFG) + # vllm_config["model_id_or_path"] = "ibm-granite/granite-8b-code-instruct" + # genai_config["model_id_or_path"] = "ibm/granite-8b-code-instruct" - vllm: LMGeneratorBlock = get_block(vllm_config["type"])( - name=f"test_{vllm_config['type']}", config=vllm_config - ) - genai: LMGeneratorBlock = get_block(genai_config["type"])( - name=f"test_{genai_config['type']}", config=genai_config - ) + # vllm: LMGeneratorBlock = get_block(vllm_config["type"])( + # name=f"test_{vllm_config['type']}", config=vllm_config + # ) + # genai: LMGeneratorBlock = get_block(genai_config["type"])( + # name=f"test_{genai_config['type']}", config=genai_config + # ) - inputs: List[Instance] = [] - for prompt in PROMPTS[:1]: - args = [prompt, prompt] - inputs.append(Instance(args)) + # inputs: List[Instance] = [] + # for prompt in PROMPTS[:1]: + # args = [prompt, prompt] + # inputs.append(Instance(args)) - inputs_vllm = copy.deepcopy(inputs) - inputs_genai = copy.deepcopy(inputs) + # inputs_vllm = copy.deepcopy(inputs) + # inputs_genai = copy.deepcopy(inputs) - vllm.loglikelihood_batch(inputs_vllm) - genai.loglikelihood_batch(inputs_genai) + # vllm.loglikelihood_batch(inputs_vllm) + # genai.loglikelihood_batch(inputs_genai) - for i, inp in enumerate(inputs): - assert ( - inp.args == inputs_vllm[i].args == inputs_genai[i].args - ), f"Input list has been rearranged at index {i}" + # for i, inp in enumerate(inputs): + # assert ( + # inp.args == inputs_vllm[i].args == inputs_genai[i].args + # ), f"Input list has been rearranged at index {i}" def test_lm_caching(self): cache_path = os.path.join( @@ -130,16 +134,16 @@ def test_lm_caching(self): name=f"test_{GREEDY_GENAI_CFG['type']}", config=GREEDY_GENAI_CFG ) - non_cache_inputs: List[Instance] = [] + non_cache_inputs: List[Dict] = [] for prompt in PROMPTS: - args = [prompt] - non_cache_inputs.append(Instance(args)) + inp = {"prompt": prompt} + non_cache_inputs.append(inp) pre_cache_inputs = copy.deepcopy(non_cache_inputs) post_cache_inputs = copy.deepcopy(non_cache_inputs) non_cache_time = time.time() - lm.generate_batch(non_cache_inputs) + lm(non_cache_inputs, arg_fields=["prompt"], result_field="output") non_cache_time = time.time() - non_cache_time cache_lm = CachingLM( @@ -148,11 +152,11 @@ def test_lm_caching(self): ) pre_cache_time = time.time() - cache_lm.generate_batch(pre_cache_inputs) + cache_lm(pre_cache_inputs, arg_fields=["prompt"], result_field="output") pre_cache_time = time.time() - pre_cache_time post_cache_time = time.time() - cache_lm.generate_batch(post_cache_inputs) + cache_lm(post_cache_inputs, arg_fields=["prompt"], result_field="output") post_cache_time = time.time() - post_cache_time os.remove(cache_path) @@ -165,8 +169,8 @@ def test_lm_caching(self): zip(non_cache_inputs, pre_cache_inputs, post_cache_inputs) ): assert ( - non.args == pre.args == post.args - ), f"Input list has been rearranged at index {i}: {(non.args, pre.args, post.args)}" + non["prompt"] == pre["prompt"] == post["prompt"] + ), f"Input list has been rearranged at index {i}: {(non['prompt'], pre['prompt'], post['prompt'])}" assert ( - non.result == pre.result == post.result - ), f"Different results detected at index {i}: {(non.result, pre.result, post.result)}" + non["output"] == pre["output"] == post["output"] + ), f"Different results detected at index {i}: {(non['output'], pre['output'], post['output'])}" diff --git a/tests/validators/__init__.py b/tests/blocks/validators/__init__.py similarity index 100% rename from tests/validators/__init__.py rename to tests/blocks/validators/__init__.py diff --git a/tests/validators/test_api.py b/tests/blocks/validators/test_api.py similarity index 100% rename from tests/validators/test_api.py rename to tests/blocks/validators/test_api.py diff --git a/tests/validators/test_lm_judge.py b/tests/blocks/validators/test_lm_judge.py similarity index 100% rename from tests/validators/test_lm_judge.py rename to tests/blocks/validators/test_lm_judge.py diff --git a/tests/validators/test_nl2sql.py b/tests/blocks/validators/test_nl2sql.py similarity index 100% rename from tests/validators/test_nl2sql.py rename to tests/blocks/validators/test_nl2sql.py diff --git a/tests/validators/test_rouge.py b/tests/blocks/validators/test_rouge.py similarity index 100% rename from tests/validators/test_rouge.py rename to tests/blocks/validators/test_rouge.py From 4ee38bbe2647997cd189dfaf24ae99d475a632f6 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 3 Jul 2024 12:39:34 -0500 Subject: [PATCH 03/41] caching llm --- fms_sdg/blocks/generators/llm.py | 55 +++++++++++++++++++++++++------- tests/blocks/__init__.py | 0 2 files changed, 43 insertions(+), 12 deletions(-) create mode 100644 tests/blocks/__init__.py diff --git a/fms_sdg/blocks/generators/llm.py b/fms_sdg/blocks/generators/llm.py index e87f7855..a61a55d5 100644 --- a/fms_sdg/blocks/generators/llm.py +++ b/fms_sdg/blocks/generators/llm.py @@ -176,8 +176,6 @@ def __init__(self, lm: LMGeneratorBlock, cache_db) -> None: self.dbdict def __getattr__(self, attr): - print(attr) - input("--") lm_attr = getattr(self.lm, attr) if not callable(lm_attr): return lm_attr @@ -245,16 +243,49 @@ def fn(requests: List[Instance]): return fn - # def __call__( - # self, - # inputs: Union[List[Dict], pd.DataFrame, Dataset], - # *args: Any, - # arg_fields: Optional[List[str]] = None, - # kwarg_fields: Optional[List[str]] = None, - # result_field: Optional[str] = None, - # method: str = "generate", - # **kwargs: Any, - # ) -> None: + def __call__( + self, + inputs: Union[List[Dict], pd.DataFrame, Dataset], + *args: Any, + arg_fields: Optional[List[str]] = None, + kwarg_fields: Optional[List[str]] = None, + result_field: Optional[str] = None, + method: str = "generate", + **kwargs: Any, + ) -> None: + + # simplify generation here + instances: List[Instance] = [] + for inp in inputs: + inp_args, inp_kwargs = self.lm.get_args_kwargs( + inp, arg_fields, kwarg_fields + ) + instances.append(Instance(args=inp_args, kwargs=inp_kwargs, data=inp)) + + if method == "generate": + self.generate_batch( + instances, + **kwargs, + ) + elif method == "loglikelihood": + self.loglikelihood_batch( + instances, + **kwargs, + ) + else: + err_str = ( + f"Unhandled method type: {method}" + if method is not None + else "Must set 'method' kwarg to 'generate' or 'loglikelihood'" + ) + raise ValueError(err_str) + + outputs = [] + for inst in instances: + self.lm.write_result(inst.data, inst.result, result_field) + outputs.append(inst.data) + + return outputs def get_cache_hook(self): return CacheHook(self) diff --git a/tests/blocks/__init__.py b/tests/blocks/__init__.py new file mode 100644 index 00000000..e69de29b From d8590574b30e25e80cff1a8013027fd39c979934 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 3 Jul 2024 13:41:50 -0500 Subject: [PATCH 04/41] adding compatibility_tests --- tests/compatibility_tests/.gitignore | 2 + tests/compatibility_tests/README.md | 12 ++++++ tests/compatibility_tests/blocks.py | 59 ++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+) create mode 100644 tests/compatibility_tests/.gitignore create mode 100644 tests/compatibility_tests/README.md create mode 100644 tests/compatibility_tests/blocks.py diff --git a/tests/compatibility_tests/.gitignore b/tests/compatibility_tests/.gitignore new file mode 100644 index 00000000..d884e5f7 --- /dev/null +++ b/tests/compatibility_tests/.gitignore @@ -0,0 +1,2 @@ +sdg +data-prep-kit \ No newline at end of file diff --git a/tests/compatibility_tests/README.md b/tests/compatibility_tests/README.md new file mode 100644 index 00000000..89d22e0e --- /dev/null +++ b/tests/compatibility_tests/README.md @@ -0,0 +1,12 @@ +For the purposes of testing compatibility between DPK, instructlab-sdg, fms-dgt + +```bash +cd ./tests/compatibility_tests +git clone git@github.com:instructlab/sdg.git +pip install ./sdg +export PYTHONPATH="$PYTHONPATH:${PWD}" + +git clone git@github.com:IBM/data-prep-kit.git +cd data-prep-kit +pip install ./data-prep-kit +``` diff --git a/tests/compatibility_tests/blocks.py b/tests/compatibility_tests/blocks.py new file mode 100644 index 00000000..f284f9b7 --- /dev/null +++ b/tests/compatibility_tests/blocks.py @@ -0,0 +1,59 @@ +# Standard +from typing import Any, Callable, Dict, List, Optional, Union +import json +import operator + +# Third Party +from datasets import Dataset +from sdg.src.instructlab.sdg.filterblock import FilterByValueBlock +import pandas as pd + +# First Party +from fms_dgt.base.block import BaseBlock + + +class TestFilterBlock(BaseBlock): + """Base Class for all Blocks""" + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + self._filter_block = FilterByValueBlock( + filter_column=self.config.get("filter_column"), + filter_value=self.config.get("filter_value"), + operation=self.config.get("operation"), + ) + + def __call__( + self, + inputs: Union[List[Dict], pd.DataFrame, Dataset], + **kwargs: Any, + ) -> Any: + return self._filter_block.generate(inputs) + + +def main(): + dataset = Dataset.from_dict( + {"test": ["keep", "remove", "keep"], "nochange": ["data", "data", "data"]} + ) + test_block = TestFilterBlock( + name="block", + config={ + "filter_column": "test", + "filter_value": "remove", + "operation": operator.ne, + }, + ) + ret_dataset: Dataset = test_block(dataset) + + print(json.dumps(dataset.to_dict(), indent=4)) + print("\n=====\n") + print(json.dumps(ret_dataset.to_dict(), indent=4)) + + +if __name__ == "__main__": + main() From 3e1fbc03cab7c2ff6ae38dd3c6eff65085c72ab4 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 3 Jul 2024 14:35:38 -0500 Subject: [PATCH 05/41] template update --- templates/databuilder/data_builder_name.yaml | 4 +-- templates/databuilder/generate.py | 37 +++++++++++--------- templates/generator/template.py | 16 +++++++-- templates/validator/template.py | 27 +++++++++++--- 4 files changed, 58 insertions(+), 26 deletions(-) diff --git a/templates/databuilder/data_builder_name.yaml b/templates/databuilder/data_builder_name.yaml index eeb42500..afbea628 100644 --- a/templates/databuilder/data_builder_name.yaml +++ b/templates/databuilder/data_builder_name.yaml @@ -1,14 +1,14 @@ name: data_builder_name # MUST match the name of the file! -generators: +blocks: llm1: type: genai temperature: 0.0 max_new_tokens: 512 min_new_tokens: 1 model_id_or_path: mistralai/mixtral-8x7b-instruct-v01 -validators: val1: type: rouge_scorer + filter: true threshold: 0.0 metadata: version: 1.0 diff --git a/templates/databuilder/generate.py b/templates/databuilder/generate.py index 58453ff6..dd374c89 100644 --- a/templates/databuilder/generate.py +++ b/templates/databuilder/generate.py @@ -10,8 +10,8 @@ from fms_dgt.base.instance import Instance from fms_dgt.base.registry import register_data_builder from fms_dgt.base.task import SdgTask -from fms_dgt.generators.llm import LMGenerator -from fms_dgt.validators.rouge import RougeValidator +from fms_dgt.blocks.generators.llm import LMGeneratorBlock +from fms_dgt.blocks.validators.rouge import RougeValidator @register_data_builder("data_builder_name") @@ -21,7 +21,7 @@ class TemplateDataBuilder(DataBuilder): TASK_TYPE: SdgTask = TemplateSdgTask # llm1 is the main generator that will produce the synthetic examples - llm1: LMGenerator + llm1: LMGeneratorBlock # val1 is the validator which checks rouge score val1: RougeValidator @@ -45,34 +45,37 @@ def __call__( for idata in instruction_data: # example of how to form an argument to the LLM generator prompt = idata.instruction + "\n\n" + idata.input - args = [prompt] - kwargs = {"seed": request_idx} - generator_inputs.append(Instance(args, kwargs, data=idata)) + inp = {"prompt": prompt, "seed": request_idx, "data": idata} + generator_inputs.append(inp) - self.llm1.generate_batch(generator_inputs) + llm_outputs = self.llm1( + generator_inputs, + arg_fields=["prompt"], + kwarg_fields=["seed"], + result_field="output", + ) validator_inputs = [] - for generator_input in generator_inputs: + for output in llm_outputs: # original input example - orig_input: TemplateSdgData = generator_input.data + orig_input: TemplateSdgData = output["data"] # getting output - generator_output = generator_input.result + generator_output = output["output"] # assign output to instruction new_instruction = copy.copy(orig_input) new_instruction.instruction = generator_output - args = [generator_output] - validator_inputs.append(Instance(args, data=new_instruction)) - - self.val1.validate_batch(validator_inputs) + inp = {"to_val": generator_output, "data": new_instruction} + validator_inputs.append(inp) # filter rouge failed data outputs = [ - validator_input.data - for validator_input in validator_inputs - if validator_input.result + output["data"] + for output in self.val1( + validator_inputs, arg_fields=["to_val"], result_field="output" + ) ] discarded = len(validator_inputs) - len(outputs) diff --git a/templates/generator/template.py b/templates/generator/template.py index 6aa468ec..02a2f692 100644 --- a/templates/generator/template.py +++ b/templates/generator/template.py @@ -1,5 +1,9 @@ # Standard -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Union + +# Third Party +from datasets import Dataset +from pandas import DataFrame # Local from fms_dgt.base.block import BaseGeneratorBlock @@ -14,5 +18,13 @@ class TemplateGenerator(BaseGeneratorBlock): def __init__(self, name: str, config: Dict, **kwargs: Any) -> None: super().__init__(name, config, **kwargs) - def __call__(self, inputs: List[Instance], **kwargs: Any) -> None: + def __call__( + self, + inputs: Union[List[Dict], DataFrame, Dataset], + *args: Any, + arg_fields: Optional[List[str]] = None, + kwarg_fields: Optional[List[str]] = None, + result_field: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: raise NotImplementedError diff --git a/templates/validator/template.py b/templates/validator/template.py index 0807bff8..aff3b44b 100644 --- a/templates/validator/template.py +++ b/templates/validator/template.py @@ -1,5 +1,9 @@ # Standard -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Union + +# Third Party +from datasets import Dataset +from pandas import DataFrame # Local from fms_dgt.base.block import BaseValidatorBlock @@ -14,10 +18,23 @@ class TemplateValidator(BaseValidatorBlock): def __init__(self, name: str, config: Dict) -> None: super().__init__(name, config) - def __call__(self, inputs: List[Instance], **kwargs: Any) -> None: - """Takes in a list of Instance objects (each containing their own arg / kwargs)""" - for x in inputs: - x.result = self._validate(*x.args, **x.kwargs) + def __call__( + self, + inputs: Union[List[Dict], DataFrame, Dataset], + *args: Any, + arg_fields: Optional[List[str]] = None, + kwarg_fields: Optional[List[str]] = None, + result_field: Optional[List[str]] = None, + **kwargs: Any, + ) -> None: + return super().__call__( + inputs, + *args, + arg_fields=arg_fields, + kwarg_fields=kwarg_fields, + result_field=result_field, + **kwargs, + ) def _validate(self, *args, **kwargs) -> bool: """Return True if valid and False otherwise""" From aec74ff649f0cdb687004446d1a8910e4d8af39a Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 3 Jul 2024 14:36:18 -0500 Subject: [PATCH 06/41] template update --- templates/databuilder/generate.py | 5 ++--- templates/generator/template.py | 1 - templates/validator/template.py | 1 - 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/templates/databuilder/generate.py b/templates/databuilder/generate.py index dd374c89..4f69bcad 100644 --- a/templates/databuilder/generate.py +++ b/templates/databuilder/generate.py @@ -1,5 +1,5 @@ # Standard -from typing import Any, List, Tuple +from typing import Any, Dict, List, Tuple import copy # First Party @@ -7,7 +7,6 @@ # Local from fms_dgt.base.databuilder import DataBuilder -from fms_dgt.base.instance import Instance from fms_dgt.base.registry import register_data_builder from fms_dgt.base.task import SdgTask from fms_dgt.blocks.generators.llm import LMGeneratorBlock @@ -41,7 +40,7 @@ def __call__( # None of this code should work, you should replace it with your own SDG flow. However, it will illustrate the general process - generator_inputs: List[Instance] = [] + generator_inputs: List[Dict] = [] for idata in instruction_data: # example of how to form an argument to the LLM generator prompt = idata.instruction + "\n\n" + idata.input diff --git a/templates/generator/template.py b/templates/generator/template.py index 02a2f692..ceff7d5a 100644 --- a/templates/generator/template.py +++ b/templates/generator/template.py @@ -7,7 +7,6 @@ # Local from fms_dgt.base.block import BaseGeneratorBlock -from fms_dgt.base.instance import Instance from fms_dgt.base.registry import register_block diff --git a/templates/validator/template.py b/templates/validator/template.py index aff3b44b..fedb93de 100644 --- a/templates/validator/template.py +++ b/templates/validator/template.py @@ -7,7 +7,6 @@ # Local from fms_dgt.base.block import BaseValidatorBlock -from fms_dgt.base.instance import Instance from fms_dgt.base.registry import register_block From 8a215de970ebf3d4c70bdeed67db9cf7ffce8c72 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 3 Jul 2024 15:05:43 -0500 Subject: [PATCH 07/41] rm import --- fms_dgt/blocks/generators/llm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fms_dgt/blocks/generators/llm.py b/fms_dgt/blocks/generators/llm.py index 42ee5adc..faa263f4 100644 --- a/fms_dgt/blocks/generators/llm.py +++ b/fms_dgt/blocks/generators/llm.py @@ -24,9 +24,8 @@ from sqlitedict import SqliteDict from tqdm import tqdm import pandas as pd -import transformers -# First Party +# Local from fms_dgt.base.block import BaseGeneratorBlock from fms_dgt.base.instance import Instance from fms_dgt.utils import sdg_logger From 52d26f6e76eb1782d364f8fd061b564b5cbf713c Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Thu, 4 Jul 2024 11:48:52 -0500 Subject: [PATCH 08/41] remove old return type --- fms_dgt/base/block.py | 4 ++-- fms_dgt/blocks/generators/genai.py | 4 ++-- fms_dgt/blocks/generators/llm.py | 2 +- fms_dgt/blocks/generators/openai.py | 4 ++-- fms_dgt/blocks/generators/vllm.py | 4 ++-- fms_dgt/blocks/validators/api.py | 4 ++-- fms_dgt/blocks/validators/nl2sql/sql_execution_validator.py | 4 ++-- fms_dgt/blocks/validators/nl2sql/sql_syntax_validator.py | 4 ++-- fms_dgt/blocks/validators/rouge.py | 4 ++-- 9 files changed, 17 insertions(+), 17 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index 0bc6a6b0..566441db 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -83,7 +83,7 @@ def __call__( kwarg_fields: Optional[List[str]] = None, result_field: Optional[str] = None, **kwargs: Any, - ) -> None: + ): pass @@ -104,7 +104,7 @@ def __call__( kwarg_fields: Optional[List[str]] = None, result_field: Optional[List[str]] = None, **kwargs: Any, - ) -> None: + ): outputs = [] for x in inputs: inp_args, inp_kwargs = self.get_args_kwargs(x, arg_fields, kwarg_fields) diff --git a/fms_dgt/blocks/generators/genai.py b/fms_dgt/blocks/generators/genai.py index b136fd64..f62b2176 100644 --- a/fms_dgt/blocks/generators/genai.py +++ b/fms_dgt/blocks/generators/genai.py @@ -8,7 +8,7 @@ from pandas import DataFrame from tqdm import tqdm -# First Party +# Local from fms_dgt.base.instance import Instance from fms_dgt.base.registry import get_resource, register_block from fms_dgt.blocks.generators.llm import LMGeneratorBlock @@ -217,7 +217,7 @@ def __call__( result_field: Union[str, None] = None, method: str = "generate", **kwargs: Any, - ) -> None: + ): return super().__call__( inputs, *args, diff --git a/fms_dgt/blocks/generators/llm.py b/fms_dgt/blocks/generators/llm.py index faa263f4..ee22de3b 100644 --- a/fms_dgt/blocks/generators/llm.py +++ b/fms_dgt/blocks/generators/llm.py @@ -99,7 +99,7 @@ def __call__( result_field: Optional[str] = None, method: str = "generate", **kwargs: Any, - ) -> None: + ): # simplify generation here instances: List[Instance] = [] diff --git a/fms_dgt/blocks/generators/openai.py b/fms_dgt/blocks/generators/openai.py index a336c3fd..b5a07203 100644 --- a/fms_dgt/blocks/generators/openai.py +++ b/fms_dgt/blocks/generators/openai.py @@ -20,7 +20,7 @@ from pandas import DataFrame from tqdm import tqdm -# First Party +# Local from fms_dgt.base.instance import Instance from fms_dgt.base.registry import get_resource, register_block from fms_dgt.blocks.generators.llm import LMGeneratorBlock @@ -191,7 +191,7 @@ def __call__( result_field: Union[str, None] = None, method: str = "generate", **kwargs: Any, - ) -> None: + ): return super().__call__( inputs, *args, diff --git a/fms_dgt/blocks/generators/vllm.py b/fms_dgt/blocks/generators/vllm.py index 29a0ba2b..45933a17 100644 --- a/fms_dgt/blocks/generators/vllm.py +++ b/fms_dgt/blocks/generators/vllm.py @@ -23,7 +23,7 @@ from pandas import DataFrame from tqdm import tqdm -# First Party +# Local from fms_dgt.base.instance import Instance from fms_dgt.base.registry import register_block from fms_dgt.blocks.generators.llm import LMGeneratorBlock @@ -473,7 +473,7 @@ def __call__( result_field: Union[str, None] = None, method: str = "generate", **kwargs: Any, - ) -> None: + ): return super().__call__( inputs, *args, diff --git a/fms_dgt/blocks/validators/api.py b/fms_dgt/blocks/validators/api.py index 32e4337d..717f7eb1 100644 --- a/fms_dgt/blocks/validators/api.py +++ b/fms_dgt/blocks/validators/api.py @@ -6,7 +6,7 @@ from datasets import Dataset from pandas import DataFrame -# First Party +# Local from fms_dgt.base.block import BaseValidatorBlock from fms_dgt.base.registry import register_block @@ -34,7 +34,7 @@ def __call__( kwarg_fields: Optional[List[str]] = None, result_field: Optional[List[str]] = None, **kwargs: Any, - ) -> None: + ): return super().__call__( inputs, *args, diff --git a/fms_dgt/blocks/validators/nl2sql/sql_execution_validator.py b/fms_dgt/blocks/validators/nl2sql/sql_execution_validator.py index 1d882604..329b795c 100644 --- a/fms_dgt/blocks/validators/nl2sql/sql_execution_validator.py +++ b/fms_dgt/blocks/validators/nl2sql/sql_execution_validator.py @@ -8,7 +8,7 @@ from pandas import DataFrame import sqlglot -# First Party +# Local from fms_dgt.base.block import BaseValidatorBlock from fms_dgt.base.registry import register_block @@ -28,7 +28,7 @@ def __call__( kwarg_fields: Optional[List[str]] = None, result_field: Optional[List[str]] = None, **kwargs: Any, - ) -> None: + ): return super().__call__( inputs, *args, diff --git a/fms_dgt/blocks/validators/nl2sql/sql_syntax_validator.py b/fms_dgt/blocks/validators/nl2sql/sql_syntax_validator.py index bfab7841..9195ad09 100644 --- a/fms_dgt/blocks/validators/nl2sql/sql_syntax_validator.py +++ b/fms_dgt/blocks/validators/nl2sql/sql_syntax_validator.py @@ -7,7 +7,7 @@ from pandas import DataFrame import sqlglot -# First Party +# Local from fms_dgt.base.block import BaseValidatorBlock from fms_dgt.base.registry import register_block @@ -27,7 +27,7 @@ def __call__( kwarg_fields: Optional[List[str]] = None, result_field: Optional[List[str]] = None, **kwargs: Any, - ) -> None: + ): return super().__call__( inputs, *args, diff --git a/fms_dgt/blocks/validators/rouge.py b/fms_dgt/blocks/validators/rouge.py index 359f9ffc..acb5465e 100644 --- a/fms_dgt/blocks/validators/rouge.py +++ b/fms_dgt/blocks/validators/rouge.py @@ -6,7 +6,7 @@ from datasets import Dataset from pandas import DataFrame -# First Party +# Local from fms_dgt.base.block import BaseValidatorBlock from fms_dgt.base.registry import register_block @@ -52,7 +52,7 @@ def __call__( kwarg_fields: Optional[List[str]] = None, result_field: Optional[List[str]] = None, **kwargs: Any, - ) -> None: + ): return super().__call__( inputs, *args, From 94ffa99c86d3ce558d8ade8a0db51f952ef2ea68 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Fri, 5 Jul 2024 13:51:11 -0500 Subject: [PATCH 09/41] add parquet saving / loading --- fms_dgt/base/block.py | 6 ++-- fms_dgt/base/task.py | 69 ++++++++++++++++++++++++++++++------------- pyproject.toml | 1 + 3 files changed, 52 insertions(+), 24 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index 566441db..b43c34ae 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -1,6 +1,6 @@ # Standard from abc import ABC -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Type, Union import abc # Third Party @@ -77,7 +77,7 @@ def write_result( @abc.abstractmethod def __call__( self, - inputs: Union[List[Dict], pd.DataFrame, Dataset], + inputs: Union[List[Dict], Type[pd.DataFrame], Type[Dataset]], *args: Any, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, @@ -98,7 +98,7 @@ def __init__(self, name: str, config: Dict, **kwargs: Any) -> None: def __call__( self, - inputs: Union[List[Dict], pd.DataFrame, Dataset], + inputs: Union[List[Dict], Type[pd.DataFrame], Type[Dataset]], *args: Any, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, diff --git a/fms_dgt/base/task.py b/fms_dgt/base/task.py index 148bb2fb..21285f09 100644 --- a/fms_dgt/base/task.py +++ b/fms_dgt/base/task.py @@ -5,10 +5,11 @@ import json import os -# First Party -from fms_dgt.utils import group_data_by_attribute +# Third Party +import pandas as pd -DEFAULT_OUTPUT_DIR = "output" +# Local +from fms_dgt.utils import group_data_by_attribute @dataclass @@ -42,7 +43,8 @@ def __init__( task_description: str, created_by: str, data_builder: str, - output_dir: Optional[str] = None, + output_dir: Optional[str] = "output", + output_format: Optional[str] = "parquet", seed_data: Optional[List[Any]] = None, num_outputs_to_generate: Optional[int] = None, **kwargs: Any, @@ -54,8 +56,8 @@ def __init__( self._num_outputs_to_generate = num_outputs_to_generate self.machine_data = [] - self._output_dir = output_dir if output_dir is not None else DEFAULT_OUTPUT_DIR - self._output_path = self._get_default_output_path() + self._output_dir = output_dir + self._output_path = self._get_default_output_path(output_format) self._seed_data = seed_data def __post_init__(self): @@ -97,35 +99,60 @@ def num_outputs_to_generate(self): def is_complete(self): return len(self.machine_data) > self.num_outputs_to_generate - def _get_default_output_path(self): + def _get_default_output_path(self, output_format: str = None): path_components = [] path_components.append(self._output_dir) path_components.append(self._name) - path_components.append("generated_instructions.jsonl") + path_components.append("generated_instructions." + output_format) return os.path.join(*path_components) def save_data( - self, new_data: Union[SdgData, List[SdgData]], output_path: str = None + self, + new_data: Union[SdgData, List[SdgData]], + output_path: str = None, ) -> None: if type(new_data) != list: new_data = [new_data] output_path = self._output_path if output_path is None else output_path - os.makedirs(os.path.dirname(output_path), exist_ok=True) - with open(output_path, "a") as f: - for d in new_data: - f.write(json.dumps(d.to_output_dict()) + "\n") + output_format = os.path.splitext(output_path)[-1] + + if output_format == ".jsonl": + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, "a") as f: + for d in new_data: + f.write(json.dumps(d.to_output_dict()) + "\n") + elif output_format == ".parquet": + os.makedirs(os.path.dirname(output_path), exist_ok=True) + pd.DataFrame(new_data).to_parquet( + output_path, engine="fastparquet", append=os.path.isfile(output_path) + ) + else: + raise ValueError(f"Unhandled output format: {output_format}") def load_data(self, output_path: str = None) -> List[SdgData]: output_path = self._output_path if output_path is None else output_path - with open(output_path, "r") as f: - try: - machine_data = [ - self.instantiate_output_example(**json.loads(l.strip())) - for l in f.readlines() - ] - except ValueError: - machine_data = [] + output_format = os.path.splitext(output_path)[-1] + if output_format == ".jsonl": + with open(output_path, "r") as f: + try: + machine_data = [ + self.instantiate_output_example(**json.loads(l.strip())) + for l in f.readlines() + ] + except ValueError: + machine_data = [] + elif output_format == ".parquet": + machine_data = [ + self.instantiate_output_example(**r) + for r in ( + pd.read_parquet(output_path, engine="fastparquet") + .apply(dict, axis=1) + .to_list() + ) + ] + else: + raise ValueError(f"Unhandled output format: {output_format}") self.machine_data = machine_data diff --git a/pyproject.toml b/pyproject.toml index e7de7251..d0516c01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "more_itertools", "GitPython", "Jinja2", + "fastparquet>=2024.5.0" ] [tool.setuptools.packages.find] From a19116f37c610263b3d9354eef987b88c8a51446 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Fri, 5 Jul 2024 17:31:12 -0500 Subject: [PATCH 10/41] adding utility block --- fms_dgt/base/block.py | 4 ++ fms_dgt/blocks/utilities/flatten_field.py | 49 +++++++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 fms_dgt/blocks/utilities/flatten_field.py diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index b43c34ae..fd1315ab 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -87,6 +87,10 @@ def __call__( pass +class BaseUtilityBlock(BaseBlock): + pass + + class BaseGeneratorBlock(BaseBlock): pass diff --git a/fms_dgt/blocks/utilities/flatten_field.py b/fms_dgt/blocks/utilities/flatten_field.py new file mode 100644 index 00000000..5f47bb4a --- /dev/null +++ b/fms_dgt/blocks/utilities/flatten_field.py @@ -0,0 +1,49 @@ +# Standard +from typing import Any, Dict, List, Optional, Union +import copy + +# Third Party +from datasets import Dataset +from pandas import DataFrame + +# Local +from fms_dgt.base.block import BaseUtilityBlock +from fms_dgt.base.registry import register_block + + +@register_block("flatten_field") +class FlattenFieldBlock(BaseUtilityBlock): + """Flatten specified args""" + + def __call__( + self, + inputs: Union[List[Dict], DataFrame, Dataset], + *args: Any, + arg_fields: Optional[List[str]] = None, + kwarg_fields: Optional[List[str]] = None, + result_field: Optional[List[str]] = None, + **kwargs: Any, + ): + arg_fields = arg_fields if arg_fields is not None else self._arg_fields + if arg_fields is None: + arg_fields = [] + + assert ( + len(arg_fields) == 1 + ), f"{self.__class__.__name__} can only have 1 arg field!" + + outputs = [] + for x in inputs: + inp_args, inp_kwargs = self.get_args_kwargs(x, arg_fields, kwarg_fields) + to_flatten = inp_args[0] if type(inp_args[0]) == list else [inp_args[0]] + + # remove flattened attribute + x_copy = copy.copy(x) + delattr(x_copy, arg_fields[0]) + + for el in to_flatten: + outputs.append(copy.copy(x_copy)) + delattr(outputs[-1], arg_fields[0]) + self.write_result(outputs[-1], el, result_field) + + return outputs From e48ec42e9f6923d70b1fd9250a0206ed1c825e53 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Fri, 5 Jul 2024 18:20:18 -0500 Subject: [PATCH 11/41] adding utility block --- fms_dgt/blocks/utilities/flatten_field.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms_dgt/blocks/utilities/flatten_field.py b/fms_dgt/blocks/utilities/flatten_field.py index 5f47bb4a..ef8c44d9 100644 --- a/fms_dgt/blocks/utilities/flatten_field.py +++ b/fms_dgt/blocks/utilities/flatten_field.py @@ -34,7 +34,7 @@ def __call__( outputs = [] for x in inputs: - inp_args, inp_kwargs = self.get_args_kwargs(x, arg_fields, kwarg_fields) + inp_args, _ = self.get_args_kwargs(x, arg_fields, kwarg_fields) to_flatten = inp_args[0] if type(inp_args[0]) == list else [inp_args[0]] # remove flattened attribute From 0ebe2a06cf06b448986622a1dd4d57a8b33f7d21 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Mon, 8 Jul 2024 09:47:29 -0500 Subject: [PATCH 12/41] rm block suffix --- fms_dgt/base/databuilder.py | 6 +++--- fms_dgt/blocks/generators/genai.py | 4 ++-- fms_dgt/blocks/generators/llm.py | 4 ++-- fms_dgt/blocks/generators/openai.py | 4 ++-- fms_dgt/blocks/generators/vllm.py | 4 ++-- fms_dgt/blocks/utilities/flatten_field.py | 2 +- fms_dgt/blocks/validators/lm_judge.py | 8 +++----- fms_dgt/databuilders/api/generate.py | 10 +++++----- fms_dgt/databuilders/nl2sql/generate.py | 6 +++--- fms_dgt/databuilders/simple/generate.py | 6 +++--- templates/databuilder/generate.py | 4 ++-- tests/blocks/generators/test_llm.py | 10 +++++----- 12 files changed, 33 insertions(+), 35 deletions(-) diff --git a/fms_dgt/base/databuilder.py b/fms_dgt/base/databuilder.py index f039b7da..88078f45 100644 --- a/fms_dgt/base/databuilder.py +++ b/fms_dgt/base/databuilder.py @@ -9,11 +9,11 @@ # Third Party from tqdm import tqdm -# First Party +# Local from fms_dgt.base.block import BaseBlock from fms_dgt.base.registry import get_block from fms_dgt.base.task import SdgData, SdgTask -from fms_dgt.blocks.generators.llm import CachingLM, LMGeneratorBlock +from fms_dgt.blocks.generators.llm import CachingLM, LMGenerator from fms_dgt.utils import all_annotations, sdg_logger @@ -99,7 +99,7 @@ def _init_blocks(self, lm_cache: str = None): ) obj = get_block(obj_config[TYPE_KEY])(obj_name, obj_config) - if lm_cache is not None and isinstance(obj, LMGeneratorBlock): + if lm_cache is not None and isinstance(obj, LMGenerator): sdg_logger.info( "Using cache at %s", lm_cache + "_rank" + str(obj.rank) + ".db", diff --git a/fms_dgt/blocks/generators/genai.py b/fms_dgt/blocks/generators/genai.py index f62b2176..88482bd7 100644 --- a/fms_dgt/blocks/generators/genai.py +++ b/fms_dgt/blocks/generators/genai.py @@ -11,7 +11,7 @@ # Local from fms_dgt.base.instance import Instance from fms_dgt.base.registry import get_resource, register_block -from fms_dgt.blocks.generators.llm import LMGeneratorBlock +from fms_dgt.blocks.generators.llm import LMGenerator from fms_dgt.resources.genai import GenAIKeyResource import fms_dgt.blocks.generators.utils as generator_utils import fms_dgt.utils as utils @@ -31,7 +31,7 @@ @register_block("genai") -class GenAIGeneratorBlock(LMGeneratorBlock): +class GenAIGenerator(LMGenerator): """GenAI Generator""" def __init__(self, name: str, config: Dict, **kwargs: Any): diff --git a/fms_dgt/blocks/generators/llm.py b/fms_dgt/blocks/generators/llm.py index ee22de3b..3f9eee28 100644 --- a/fms_dgt/blocks/generators/llm.py +++ b/fms_dgt/blocks/generators/llm.py @@ -33,7 +33,7 @@ MODEL_ID_OR_PATH = "model_id_or_path" -class LMGeneratorBlock(BaseGeneratorBlock): +class LMGenerator(BaseGeneratorBlock): """Class for LLM Generators""" def __init__(self, name: str, config: Dict, **kwargs: Any): @@ -155,7 +155,7 @@ def add_partial(self, attr, req, res) -> None: class CachingLM: - def __init__(self, lm: LMGeneratorBlock, cache_db) -> None: + def __init__(self, lm: LMGenerator, cache_db) -> None: """LM wrapper that returns cached results if they exist, and uses the underlying LM if not. :param lm: LM diff --git a/fms_dgt/blocks/generators/openai.py b/fms_dgt/blocks/generators/openai.py index b5a07203..f9f274a6 100644 --- a/fms_dgt/blocks/generators/openai.py +++ b/fms_dgt/blocks/generators/openai.py @@ -23,7 +23,7 @@ # Local from fms_dgt.base.instance import Instance from fms_dgt.base.registry import get_resource, register_block -from fms_dgt.blocks.generators.llm import LMGeneratorBlock +from fms_dgt.blocks.generators.llm import LMGenerator from fms_dgt.resources.openai import OpenAIKeyResource import fms_dgt.blocks.generators.utils as generator_utils import fms_dgt.utils as utils @@ -70,7 +70,7 @@ def completion(): @register_block("openai-chat", "local-chat-completions") -class OpenaiChatCompletionsLMBlock(LMGeneratorBlock): +class OpenaiChatCompletionsLM(LMGenerator): def __init__(self, name: str, config: Dict, **kwargs: Any) -> None: """ diff --git a/fms_dgt/blocks/generators/vllm.py b/fms_dgt/blocks/generators/vllm.py index 45933a17..b06934ea 100644 --- a/fms_dgt/blocks/generators/vllm.py +++ b/fms_dgt/blocks/generators/vllm.py @@ -26,7 +26,7 @@ # Local from fms_dgt.base.instance import Instance from fms_dgt.base.registry import register_block -from fms_dgt.blocks.generators.llm import LMGeneratorBlock +from fms_dgt.blocks.generators.llm import LMGenerator from fms_dgt.blocks.generators.utils import Collator, undistribute from fms_dgt.utils import sdg_logger import fms_dgt.blocks.generators.utils as generator_utils @@ -45,7 +45,7 @@ # TODO: this can be made more efficient for our purposes by rewriting the async code ourselves @register_block("vllm") -class vLLMGeneratorBlock(LMGeneratorBlock): +class vLLMGenerator(LMGenerator): """vLLM Generator""" _DEFAULT_MAX_LENGTH = 2048 diff --git a/fms_dgt/blocks/utilities/flatten_field.py b/fms_dgt/blocks/utilities/flatten_field.py index ef8c44d9..cfeaa007 100644 --- a/fms_dgt/blocks/utilities/flatten_field.py +++ b/fms_dgt/blocks/utilities/flatten_field.py @@ -12,7 +12,7 @@ @register_block("flatten_field") -class FlattenFieldBlock(BaseUtilityBlock): +class FlattenField(BaseUtilityBlock): """Flatten specified args""" def __call__( diff --git a/fms_dgt/blocks/validators/lm_judge.py b/fms_dgt/blocks/validators/lm_judge.py index 538bd0ac..76f0e37a 100644 --- a/fms_dgt/blocks/validators/lm_judge.py +++ b/fms_dgt/blocks/validators/lm_judge.py @@ -1,11 +1,11 @@ # Standard from typing import Any, Dict, List -# First Party +# Local from fms_dgt.base.block import BaseValidatorBlock from fms_dgt.base.instance import Instance from fms_dgt.base.registry import get_block, register_block -from fms_dgt.blocks.generators.llm import LMGeneratorBlock +from fms_dgt.blocks.generators.llm import LMGenerator TYPE_KEY = "lm_type" @@ -16,9 +16,7 @@ class LMJudgeValidator(BaseValidatorBlock): def __init__(self, name: str, config: Dict, **kwargs: Any): super().__init__(name, config, **kwargs) - self._llm_generator: LMGeneratorBlock = get_block(config[TYPE_KEY])( - name, config - ) + self._llm_generator: LMGenerator = get_block(config[TYPE_KEY])(name, config) self._blocks.append(self._llm_generator) def validate_batch(self, inputs: List[Instance], **kwargs: Any) -> None: diff --git a/fms_dgt/databuilders/api/generate.py b/fms_dgt/databuilders/api/generate.py index 367a29eb..ae0566a0 100644 --- a/fms_dgt/databuilders/api/generate.py +++ b/fms_dgt/databuilders/api/generate.py @@ -3,11 +3,11 @@ import random import time -# First Party +# Local from fms_dgt.base.databuilder import DataBuilder from fms_dgt.base.registry import register_data_builder from fms_dgt.base.task import group_data_by_task -from fms_dgt.blocks.generators.llm import LMGeneratorBlock +from fms_dgt.blocks.generators.llm import LMGenerator from fms_dgt.blocks.validators.api import APIGenSpecValidator, ApiGenSpecYesNoValidation from fms_dgt.blocks.validators.rouge import RougeValidator from fms_dgt.databuilders.api.task import ApiSdgData, ApiSdgTask @@ -36,7 +36,7 @@ def __init__( ), "Number of prompt examples must be at least 1" # llm1 is the main generator that will produce the synthetic examples - llm1: LMGeneratorBlock + llm1: LMGenerator val1: APIGenSpecValidator val2: RougeValidator @@ -246,7 +246,7 @@ class ApiYesNoDataBuilder(ApiDataBuilder): """Class for API Sequence task""" # llm1 is the main generator that will produce the synthetic examples - llm1: LMGeneratorBlock + llm1: LMGenerator val1: ApiGenSpecYesNoValidation @@ -255,5 +255,5 @@ class ApiDetectionDataBuilder(ApiDataBuilder): """Class for API Sequence task""" # llm1 is the main generator that will produce the synthetic examples - llm1: LMGeneratorBlock + llm1: LMGenerator val1: APIGenSpecValidator diff --git a/fms_dgt/databuilders/nl2sql/generate.py b/fms_dgt/databuilders/nl2sql/generate.py index 2aee47db..5eff3ee6 100644 --- a/fms_dgt/databuilders/nl2sql/generate.py +++ b/fms_dgt/databuilders/nl2sql/generate.py @@ -2,11 +2,11 @@ from dataclasses import asdict from typing import Any, Iterable, List, Set, Tuple -# First Party +# Local from fms_dgt.base.databuilder import DataBuilder from fms_dgt.base.registry import register_data_builder from fms_dgt.base.task import SdgTask -from fms_dgt.blocks.generators.llm import LMGeneratorBlock +from fms_dgt.blocks.generators.llm import LMGenerator from fms_dgt.blocks.validators.nl2sql.sql_execution_validator import ( SQLExecutionValidator, ) @@ -38,7 +38,7 @@ def __init__( super().__init__(*args, **kwargs) # llm1 is a code generator for the synthetic examples - llm1: LMGeneratorBlock + llm1: LMGenerator # val1 is the validator which checks SQL syntax val1: SQLSyntaxValidator diff --git a/fms_dgt/databuilders/simple/generate.py b/fms_dgt/databuilders/simple/generate.py index 1d7b2658..3994eb77 100644 --- a/fms_dgt/databuilders/simple/generate.py +++ b/fms_dgt/databuilders/simple/generate.py @@ -7,11 +7,11 @@ # Third Party import pandas as pd -# First Party +# Local from fms_dgt.base.databuilder import DataBuilder from fms_dgt.base.registry import register_data_builder from fms_dgt.base.task import SdgTask, group_data_by_task -from fms_dgt.blocks.generators.llm import LMGeneratorBlock +from fms_dgt.blocks.generators.llm import LMGenerator from fms_dgt.blocks.validators.rouge import RougeValidator from fms_dgt.databuilders.simple.task import InstructLabSdgData, InstructLabSdgTask from fms_dgt.utils import sdg_logger @@ -25,7 +25,7 @@ class SimpleInstructDataBuilder(DataBuilder): TASK_TYPE: SdgTask = InstructLabSdgTask # llm1 is the main generator that will produce the synthetic examples - llm1: LMGeneratorBlock + llm1: LMGenerator # val1 is the validator which checks rouge score val1: RougeValidator diff --git a/templates/databuilder/generate.py b/templates/databuilder/generate.py index 4f69bcad..555e22b9 100644 --- a/templates/databuilder/generate.py +++ b/templates/databuilder/generate.py @@ -9,7 +9,7 @@ from fms_dgt.base.databuilder import DataBuilder from fms_dgt.base.registry import register_data_builder from fms_dgt.base.task import SdgTask -from fms_dgt.blocks.generators.llm import LMGeneratorBlock +from fms_dgt.blocks.generators.llm import LMGenerator from fms_dgt.blocks.validators.rouge import RougeValidator @@ -20,7 +20,7 @@ class TemplateDataBuilder(DataBuilder): TASK_TYPE: SdgTask = TemplateSdgTask # llm1 is the main generator that will produce the synthetic examples - llm1: LMGeneratorBlock + llm1: LMGenerator # val1 is the validator which checks rouge score val1: RougeValidator diff --git a/tests/blocks/generators/test_llm.py b/tests/blocks/generators/test_llm.py index 1e20b244..463716cb 100644 --- a/tests/blocks/generators/test_llm.py +++ b/tests/blocks/generators/test_llm.py @@ -7,9 +7,9 @@ # Third Party import pytest -# First Party +# Local from fms_dgt.base.registry import get_block -from fms_dgt.blocks.generators.llm import CachingLM, LMGeneratorBlock +from fms_dgt.blocks.generators.llm import CachingLM, LMGenerator # hf cache @@ -50,7 +50,7 @@ class TestLlmGenerators: "model_cfg", [GREEDY_GENAI_CFG, GREEDY_OPENAI_CFG] ) # GREEDY_VLLM_CFG] def test_generate_batch(self, model_cfg): - lm: LMGeneratorBlock = get_block(model_cfg["type"])( + lm: LMGenerator = get_block(model_cfg["type"])( name=f"test_{model_cfg['type']}", config=model_cfg ) @@ -71,7 +71,7 @@ def test_generate_batch(self, model_cfg): @pytest.mark.parametrize("model_cfg", [GREEDY_GENAI_CFG]) # , GREEDY_VLLM_CFG]) def test_loglikelihood_batch(self, model_cfg): - lm: LMGeneratorBlock = get_block(model_cfg["type"])( + lm: LMGenerator = get_block(model_cfg["type"])( name=f"test_{model_cfg['type']}", config=model_cfg ) @@ -130,7 +130,7 @@ def test_lm_caching(self): if os.path.exists(cache_path): os.remove(cache_path) - lm: LMGeneratorBlock = get_block(GREEDY_GENAI_CFG["type"])( + lm: LMGenerator = get_block(GREEDY_GENAI_CFG["type"])( name=f"test_{GREEDY_GENAI_CFG['type']}", config=GREEDY_GENAI_CFG ) From 1d654ddfcc1274ab07b29ed36bf3fcb8d92f2338 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Mon, 8 Jul 2024 10:34:29 -0500 Subject: [PATCH 13/41] remove config argument --- fms_dgt/base/block.py | 31 +++++------ fms_dgt/base/databuilder.py | 7 ++- fms_dgt/blocks/generators/genai.py | 6 +-- fms_dgt/blocks/generators/llm.py | 27 ++++++---- fms_dgt/blocks/generators/openai.py | 31 ++++------- fms_dgt/blocks/generators/vllm.py | 67 +++++++++++------------- fms_dgt/blocks/validators/lm_judge.py | 6 +-- fms_dgt/blocks/validators/rouge.py | 6 +-- tests/blocks/generators/test_llm.py | 18 +++---- tests/blocks/validators/test_api.py | 32 +++++------ tests/blocks/validators/test_lm_judge.py | 4 +- tests/blocks/validators/test_nl2sql.py | 6 +-- tests/blocks/validators/test_rouge.py | 4 +- 13 files changed, 123 insertions(+), 122 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index fd1315ab..2353cb19 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -11,26 +11,27 @@ class BaseBlock(ABC): """Base Class for all Blocks""" - def __init__(self, name: str, config: Dict, **kwargs: Any) -> None: + def __init__( + self, + name: str = None, + arg_fields: List[str] = None, + kwarg_fields: List[str] = None, + result_field: str = None, + ) -> None: + + assert name is not None, f"'name' field cannot be empty in block definition" + self._name = name - self._config: Dict = config self._blocks: List[BaseBlock] = [] - # overwrite config fields with kwargs (usually these will be command line args) - self._config.update(kwargs) - - self._arg_fields = self._config.get("arg_fields", None) - self._kwarg_fields = self._config.get("kwarg_fields", None) - self._result_field = self._config.get("result_field", None) + self._arg_fields = arg_fields + self._kwarg_fields = kwarg_fields + self._result_field = result_field @property def name(self): return self._name - @property - def config(self): - return self._config - @property def blocks(self) -> List: """Returns the constituent blocks associated with this class.""" @@ -96,9 +97,9 @@ class BaseGeneratorBlock(BaseBlock): class BaseValidatorBlock(BaseBlock): - def __init__(self, name: str, config: Dict, **kwargs: Any) -> None: - super().__init__(name, config, **kwargs) - self._filter_invalids = config.get("filter", False) + def __init__(self, filter: bool = False, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._filter_invalids = filter def __call__( self, diff --git a/fms_dgt/base/databuilder.py b/fms_dgt/base/databuilder.py index 88078f45..c34749de 100644 --- a/fms_dgt/base/databuilder.py +++ b/fms_dgt/base/databuilder.py @@ -94,10 +94,15 @@ def _init_blocks(self, lm_cache: str = None): # TODO: need to handle nested blocks for obj_name, obj_config in self.config.blocks.items(): + obj_kwargs = {**obj_config, "name": obj_name} sdg_logger.debug( "Initializing object %s with config %s", obj_name, obj_config ) - obj = get_block(obj_config[TYPE_KEY])(obj_name, obj_config) + + assert ( + TYPE_KEY in obj_kwargs + ), f"'type' field missing from {obj_name} in data builder config" + obj = get_block(obj_kwargs.pop(TYPE_KEY))(**obj_kwargs) if lm_cache is not None and isinstance(obj, LMGenerator): sdg_logger.info( diff --git a/fms_dgt/blocks/generators/genai.py b/fms_dgt/blocks/generators/genai.py index 88482bd7..00d7e3bb 100644 --- a/fms_dgt/blocks/generators/genai.py +++ b/fms_dgt/blocks/generators/genai.py @@ -34,8 +34,8 @@ class GenAIGenerator(LMGenerator): """GenAI Generator""" - def __init__(self, name: str, config: Dict, **kwargs: Any): - super().__init__(name, config, **kwargs) + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) try: # Third Party @@ -77,7 +77,7 @@ def generate_batch( until = None if isinstance(kwargs := copy.deepcopy(gen_kwargs), dict): - # start with default params in self.config then overwrite with kwargs + # start with default params then overwrite with kwargs kwargs = {**self._base_kwargs, **kwargs} until = kwargs.get("stop_sequences", None) model_id = kwargs.pop("model_id", self.model_id_or_path) diff --git a/fms_dgt/blocks/generators/llm.py b/fms_dgt/blocks/generators/llm.py index 3f9eee28..1b31cf41 100644 --- a/fms_dgt/blocks/generators/llm.py +++ b/fms_dgt/blocks/generators/llm.py @@ -36,22 +36,29 @@ class LMGenerator(BaseGeneratorBlock): """Class for LLM Generators""" - def __init__(self, name: str, config: Dict, **kwargs: Any): - super().__init__(name, config, **kwargs) + def __init__( + self, + model_id_or_path: str = None, + **kwargs: Any, + ): + # TODO: define exact kwargs that are supported + default_kwargs = {"decoding_method": "sample"} + cfg_kwargs = { + k: kwargs.pop(k) + for k in copy.copy(kwargs) + if k in TextGenerationParameters.model_fields + } + + super().__init__(**kwargs) + self._rank = 0 self.cache_hook = CacheHook(None) - self.model_id_or_path: str = config.get(MODEL_ID_OR_PATH, None) + self.model_id_or_path: str = model_id_or_path assert ( self.model_id_or_path is not None - ), f"Must specify model for Generator {name}" + ), f"Must specify model for Generator {self.name}" - default_kwargs = {"decoding_method": "sample"} - cfg_kwargs = { - k: v - for k, v in copy.deepcopy(self.config).items() - if k in TextGenerationParameters.model_fields - } self._base_kwargs = {**default_kwargs, **cfg_kwargs} @property diff --git a/fms_dgt/blocks/generators/openai.py b/fms_dgt/blocks/generators/openai.py index f9f274a6..a75fa802 100644 --- a/fms_dgt/blocks/generators/openai.py +++ b/fms_dgt/blocks/generators/openai.py @@ -12,7 +12,7 @@ # Standard from importlib.util import find_spec -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union import copy # Third Party @@ -71,19 +71,13 @@ def completion(): @register_block("openai-chat", "local-chat-completions") class OpenaiChatCompletionsLM(LMGenerator): - def __init__(self, name: str, config: Dict, **kwargs: Any) -> None: - """ - - :param model: str - Implements an OpenAI-style chat completion API for - accessing both OpenAI OR locally-hosted models using - HuggingFace Tokenizer - OpenAI API model (e.g. gpt-3.5-turbo) - using the **gen_kwargs passed on init - :param truncate: bool - Truncate input if too long (if False and input is too long, throw error) - """ - super().__init__(name, config, **kwargs) + def __init__( + self, + base_url: str = None, + truncate: bool = False, + **kwargs: Any, + ): + super().__init__(**kwargs) try: # Third Party import openai # noqa: E401 @@ -93,11 +87,8 @@ def __init__(self, name: str, config: Dict, **kwargs: Any) -> None: please install these via `pip install .[openai]`", ) - self.model_id_or_path: str = config.get( - "model_id", "gpt-3.5-turbo" - ) # GPT model or Local model using HuggingFace model paths - self.base_url: str = config.get("base_url", None) - self.truncate: bool = config.get("truncate", False) + self.base_url: str = base_url + self.truncate: bool = truncate # Read from environment variable OPENAI_API_KEY # Set to EMPTY for local @@ -139,7 +130,7 @@ def generate_batch( until = None if isinstance(kwargs := copy.deepcopy(gen_kwargs), dict): - # start with default params in self.config then overwrite with kwargs + # start with default params then overwrite with kwargs kwargs = {**self._base_kwargs, **kwargs} model_id = kwargs.pop("model_id_or_path", self.model_id_or_path) kwargs["stop"] = until diff --git a/fms_dgt/blocks/generators/vllm.py b/fms_dgt/blocks/generators/vllm.py index b06934ea..b5e28a08 100644 --- a/fms_dgt/blocks/generators/vllm.py +++ b/fms_dgt/blocks/generators/vllm.py @@ -50,8 +50,32 @@ class vLLMGenerator(LMGenerator): _DEFAULT_MAX_LENGTH = 2048 - def __init__(self, name: str, config: Dict, **kwargs: Any): - super().__init__(name, config, **kwargs) + def __init__( + self, + dtype: Literal["float16", "bfloat16", "float32", "auto"] = "auto", + revision: Optional[str] = None, + trust_remote_code: Optional[bool] = False, + tokenizer: Optional[str] = None, + tokenizer_mode: Literal["auto", "slow"] = "auto", + tokenizer_revision: Optional[str] = None, + add_bos_token: Optional[bool] = False, + prefix_token_id: Optional[int] = None, + tensor_parallel_size: int = 1, + quantization: Optional[str] = None, + max_gen_toks: int = 256, + swap_space: int = 4, + batch_size: Union[str, int] = 1, + max_batch_size=None, + max_length: int = None, + max_model_len: int = None, + seed: int = 1234, + gpu_memory_utilization: float = 0.9, + device: str = "cuda", + data_parallel_size: int = 1, + lora_local_path: str = None, + **kwargs: Any, + ): + super().__init__(**kwargs) if not find_spec("vllm"): raise Exception( @@ -59,35 +83,6 @@ def __init__(self, name: str, config: Dict, **kwargs: Any): "Please install vllm via `pip install fms_dgt[vllm]`" ) - pretrained = self._config.get( - "model_id_or_path", - None, - ) - dtype: Literal["float16", "bfloat16", "float32", "auto"] = self._config.get( - "dtype", "auto" - ) - revision: Optional[str] = self._config.get("revision", None) - trust_remote_code: Optional[bool] = self._config.get("trust_remote_code", False) - tokenizer: Optional[str] = self._config.get("tokenizer", None) - tokenizer_mode: Literal["auto", "slow"] = self._config.get( - "tokenizer_mode", "auto" - ) - tokenizer_revision: Optional[str] = self._config.get("tokenizer_revision", None) - add_bos_token: Optional[bool] = self._config.get("add_bos_token", False) - prefix_token_id: Optional[int] = self._config.get("prefix_token_id", None) - tensor_parallel_size: int = self._config.get("tensor_parallel_size", 1) - quantization: Optional[str] = self._config.get("quantization", None) - max_gen_toks: int = self._config.get("max_gen_toks", 256) - swap_space: int = self._config.get("swap_space", 4) - batch_size: Union[str, int] = self._config.get("batch_size", "auto") - max_batch_size = self._config.get("max_batch_size", None) - max_length: int = self._config.get("max_length", None) - max_model_len: int = self._config.get("max_model_len", None) - seed: int = self._config.get("seed", 1234) - gpu_memory_utilization: float = self._config.get("gpu_memory_utilization", 0.9) - device: str = self._config.get("device", "cuda") - data_parallel_size: int = self._config.get("data_parallel_size", 1) - assert "cuda" in device or device is None, "vLLM only supports CUDA" assert ( max_length is None or max_model_len is None @@ -97,7 +92,7 @@ def __init__(self, name: str, config: Dict, **kwargs: Any): self.tensor_parallel_size = int(tensor_parallel_size) self.data_parallel_size = int(data_parallel_size) self.model_args = { - "model": pretrained, + "model": self.model_id_or_path, "gpu_memory_utilization": float(gpu_memory_utilization), "revision": revision, "dtype": dtype, @@ -136,10 +131,12 @@ def __init__(self, name: str, config: Dict, **kwargs: Any): from transformers import AutoConfig self._config = AutoConfig.from_pretrained( - pretrained, trust_remote_code=trust_remote_code, revision=revision + self.model_id_or_path, + trust_remote_code=trust_remote_code, + revision=revision, ) self.tokenizer = get_tokenizer( - tokenizer if tokenizer else pretrained, + tokenizer if tokenizer else self.model_id_or_path, tokenizer_mode=tokenizer_mode, trust_remote_code=trust_remote_code, tokenizer_revision=tokenizer_revision, @@ -260,7 +257,7 @@ def generate_batch( # unpack our keyword arguments. until = None if isinstance(kwargs := copy.deepcopy(gen_kwargs), dict): - # start with default params in self.config then overwrite with kwargs + # start with default params then overwrite with kwargs kwargs = {**self._base_kwargs, **kwargs} if "stop_sequences" in kwargs: until = kwargs.pop("stop_sequences") diff --git a/fms_dgt/blocks/validators/lm_judge.py b/fms_dgt/blocks/validators/lm_judge.py index 76f0e37a..fb725ee1 100644 --- a/fms_dgt/blocks/validators/lm_judge.py +++ b/fms_dgt/blocks/validators/lm_judge.py @@ -14,9 +14,9 @@ class LMJudgeValidator(BaseValidatorBlock): """LLM-based Validator""" - def __init__(self, name: str, config: Dict, **kwargs: Any): - super().__init__(name, config, **kwargs) - self._llm_generator: LMGenerator = get_block(config[TYPE_KEY])(name, config) + def __init__(self, lm_type: str = None, **kwargs: Any): + super().__init__(**kwargs) + self._llm_generator: LMGenerator = get_block(lm_type)(self.name, **kwargs) self._blocks.append(self._llm_generator) def validate_batch(self, inputs: List[Instance], **kwargs: Any) -> None: diff --git a/fms_dgt/blocks/validators/rouge.py b/fms_dgt/blocks/validators/rouge.py index acb5465e..4aebad2e 100644 --- a/fms_dgt/blocks/validators/rouge.py +++ b/fms_dgt/blocks/validators/rouge.py @@ -21,9 +21,9 @@ class RougeValidator(BaseValidatorBlock): """Base Class for all Validators""" - def __init__(self, name: str, config: Dict) -> None: - super().__init__(name, config) - self._threshold = config.get("threshold", None) + def __init__(self, threshold: float = -1, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._threshold = threshold if self._threshold <= 0: self._threshold = None diff --git a/tests/blocks/generators/test_llm.py b/tests/blocks/generators/test_llm.py index 463716cb..ba7bbc31 100644 --- a/tests/blocks/generators/test_llm.py +++ b/tests/blocks/generators/test_llm.py @@ -50,9 +50,9 @@ class TestLlmGenerators: "model_cfg", [GREEDY_GENAI_CFG, GREEDY_OPENAI_CFG] ) # GREEDY_VLLM_CFG] def test_generate_batch(self, model_cfg): - lm: LMGenerator = get_block(model_cfg["type"])( - name=f"test_{model_cfg['type']}", config=model_cfg - ) + model_cfg = dict(model_cfg) + model_type = model_cfg.pop("type") + lm: LMGenerator = get_block(model_type)(name=f"test_{model_type}", **model_cfg) inputs: List[Dict] = [] for prompt in PROMPTS: @@ -71,9 +71,9 @@ def test_generate_batch(self, model_cfg): @pytest.mark.parametrize("model_cfg", [GREEDY_GENAI_CFG]) # , GREEDY_VLLM_CFG]) def test_loglikelihood_batch(self, model_cfg): - lm: LMGenerator = get_block(model_cfg["type"])( - name=f"test_{model_cfg['type']}", config=model_cfg - ) + model_cfg = dict(model_cfg) + model_type = model_cfg.pop("type") + lm: LMGenerator = get_block(model_type)(name=f"test_{model_type}", **model_cfg) inputs: List[Dict] = [] for prompt in PROMPTS: @@ -130,9 +130,9 @@ def test_lm_caching(self): if os.path.exists(cache_path): os.remove(cache_path) - lm: LMGenerator = get_block(GREEDY_GENAI_CFG["type"])( - name=f"test_{GREEDY_GENAI_CFG['type']}", config=GREEDY_GENAI_CFG - ) + model_cfg = dict(GREEDY_GENAI_CFG) + model_type = model_cfg.pop("type") + lm: LMGenerator = get_block(model_type)(name=f"test_{model_type}", **model_cfg) non_cache_inputs: List[Dict] = [] for prompt in PROMPTS: diff --git a/tests/blocks/validators/test_api.py b/tests/blocks/validators/test_api.py index 8385173a..600b2047 100644 --- a/tests/blocks/validators/test_api.py +++ b/tests/blocks/validators/test_api.py @@ -5,7 +5,7 @@ # Third Party import pytest -# First Party +# Local from fms_dgt.base.instance import Instance from fms_dgt.blocks.validators.api import APIGenSpecValidator, ApiGenSpecYesNoValidation @@ -45,7 +45,7 @@ def get_args(func_calls): class TestApiValidator: def test_single_intent(self): - validator = APIGenSpecValidator("test_single_intent", dict()) + validator = APIGenSpecValidator(name="test_single_intent") # single intent func_calls = [{"name": "add"}] @@ -54,11 +54,11 @@ def test_single_intent(self): args = [api_info, question, json.dumps(func_calls)] test_instance = [Instance(args, single_intent_kwargs)] - validator.validate_batch(test_instance) + validator(test_instance) assert test_instance[0].result def test_multi_intent(self): - validator = APIGenSpecValidator("test_multi_intent", dict()) + validator = APIGenSpecValidator(name="test_multi_intent") # multiple intent func_calls = [ {"name": "add"}, @@ -69,11 +69,11 @@ def test_multi_intent(self): args = [api_info, question, json.dumps(func_calls)] test_instance = [Instance(args, multi_intent_kwargs)] - validator.validate_batch(test_instance) + validator(test_instance) assert test_instance[0].result def test_parallel_single(self): - validator = APIGenSpecValidator("test_parallel_single", dict()) + validator = APIGenSpecValidator(name="test_parallel_single") # parallel single func_calls = [ @@ -85,7 +85,7 @@ def test_parallel_single(self): args = [api_info, question, json.dumps(func_calls)] test_instance = [Instance(args, parallel_kwargs)] - validator.validate_batch(test_instance) + validator(test_instance) assert test_instance[0].result func_calls = [ @@ -97,13 +97,13 @@ def test_parallel_single(self): args = [api_info, question, json.dumps(func_calls)] test_instance = [Instance(args, parallel_kwargs)] - validator.validate_batch(test_instance) + validator(test_instance) assert not test_instance[ 0 ].result, "Validator should have failed due to required args!" def test_parallel_multiple(self): - validator = APIGenSpecValidator("test_parallel_multiple", dict()) + validator = APIGenSpecValidator(name="test_parallel_multiple") # parallel multiple func_calls = [ @@ -115,7 +115,7 @@ def test_parallel_multiple(self): args = [api_info, question, json.dumps(func_calls)] test_instance = [Instance(args, parallel_kwargs)] - validator.validate_batch(test_instance) + validator(test_instance) assert test_instance[0].result func_calls = [ @@ -127,7 +127,7 @@ def test_parallel_multiple(self): args = [api_info, question, json.dumps(func_calls)] test_instance = [Instance(args, parallel_kwargs)] - validator.validate_batch(test_instance) + validator(test_instance) assert not test_instance[ 0 ].result, ( @@ -135,7 +135,7 @@ def test_parallel_multiple(self): ) def test_parallel_nested(self): - validator = APIGenSpecValidator("test_parallel_nested", dict()) + validator = APIGenSpecValidator(name="test_parallel_nested") # parallel multiple func_calls = [ @@ -147,7 +147,7 @@ def test_parallel_nested(self): args = [api_info, question, json.dumps(func_calls)] test_instance = [Instance(args, parallel_nested_kwargs)] - validator.validate_batch(test_instance) + validator(test_instance) assert test_instance[0].result func_calls = [ @@ -159,7 +159,7 @@ def test_parallel_nested(self): args = [api_info, question, json.dumps(func_calls)] test_instance = [Instance(args, parallel_nested_kwargs)] - validator.validate_batch(test_instance) + validator(test_instance) assert not test_instance[ 0 ].result, ( @@ -167,12 +167,12 @@ def test_parallel_nested(self): ) def test_yes_no(self): - validator = ApiGenSpecYesNoValidation("test_yes_no", dict()) + validator = ApiGenSpecYesNoValidation(name="test_yes_no") for arg_inp in ["YES", "NO", "MAYBE"]: args = [TEST_APIS, "this is a test question", arg_inp] test_instance = [Instance(args)] - validator.validate_batch(test_instance) + validator(test_instance) assert test_instance[0].result == (arg_inp in ["YES", "NO"]) diff --git a/tests/blocks/validators/test_lm_judge.py b/tests/blocks/validators/test_lm_judge.py index 6dc95027..4e6af55d 100644 --- a/tests/blocks/validators/test_lm_judge.py +++ b/tests/blocks/validators/test_lm_judge.py @@ -7,7 +7,7 @@ # Third Party import pytest -# First Party +# Local from fms_dgt.base.instance import Instance from fms_dgt.blocks.validators.lm_judge import LMJudgeValidator @@ -24,7 +24,7 @@ class TestLlmJudgeValidator: @pytest.mark.parametrize("model_backend", ["genai"]) def test_generate_batch(self, model_backend): - lm_judge = LMJudgeValidator(name=f"test_{model_backend}", config=GREEDY_CFG) + lm_judge = LMJudgeValidator(name=f"test_{model_backend}", **GREEDY_CFG) inputs = [ Instance( diff --git a/tests/blocks/validators/test_nl2sql.py b/tests/blocks/validators/test_nl2sql.py index 436b9e95..ce3cb53f 100644 --- a/tests/blocks/validators/test_nl2sql.py +++ b/tests/blocks/validators/test_nl2sql.py @@ -1,4 +1,4 @@ -# First Party +# Local from fms_dgt.blocks.validators.nl2sql.sql_execution_validator import ( SQLExecutionValidator, ) @@ -6,7 +6,7 @@ def test_sql_syntax_validator(): - validator = SQLSyntaxValidator(name="sql_syntax_validator", config={}) + validator = SQLSyntaxValidator(name="sql_syntax_validator") assert validator._validate( record=dict( sql_schema="CREATE TABLE users (\n user_id INTEGER,\n first_name VARCHAR(155),\n last_name VARCHAR(155),\n email VARCHAR(155),\n city VARCHAR(155),\n country VARCHAR(155)\n);\nCREATE TABLE orders (\n order_id INTEGER,\n user_id INTEGER,\n product_name VARCHAR(155),\n price FLOAT,\n quantity INT,\n order_date DATE,\n order_state VARCHAR(155),\n CONSTRAINT orders_user_id_fkey FOREIGN KEY (user_id) REFERENCES users (user_id)\n);", @@ -24,7 +24,7 @@ def test_sql_syntax_validator(): def test_sql_execution_validator(): - validator = SQLExecutionValidator(name="sql_execution_validator", config={}) + validator = SQLExecutionValidator(name="sql_execution_validator") assert validator._validate( record=dict( sql_schema="CREATE TABLE users (\n user_id INTEGER,\n first_name VARCHAR(155),\n last_name VARCHAR(155),\n email VARCHAR(155),\n city VARCHAR(155),\n country VARCHAR(155)\n);\nCREATE TABLE orders (\n order_id INTEGER,\n user_id INTEGER,\n product_name VARCHAR(155),\n price FLOAT,\n quantity INT,\n order_date DATE,\n order_state VARCHAR(155),\n CONSTRAINT orders_user_id_fkey FOREIGN KEY (user_id) REFERENCES users (user_id)\n);", diff --git a/tests/blocks/validators/test_rouge.py b/tests/blocks/validators/test_rouge.py index a3d74caf..b46e11fc 100644 --- a/tests/blocks/validators/test_rouge.py +++ b/tests/blocks/validators/test_rouge.py @@ -5,14 +5,14 @@ # Third Party import pytest -# First Party +# Local from fms_dgt.base.instance import Instance from fms_dgt.blocks.validators.rouge import RougeValidator class TestRougeValidator: def test_matches(self): - validator = RougeValidator("test_rouge_validator", {"threshold": 0.0}) + validator = RougeValidator(name="test_rouge_validator", threshold=0.0) all_data = [ "I went to the store", From 6d51079305b29b1309f7b9454960ee59f976999c Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Mon, 8 Jul 2024 11:40:18 -0500 Subject: [PATCH 14/41] demonstrate default vals --- fms_dgt/base/block.py | 11 +++++++++++ fms_dgt/databuilders/simple/generate.py | 14 ++------------ fms_dgt/databuilders/simple/simple.yaml | 9 +++++++++ 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index 2353cb19..42c11dc8 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -24,6 +24,17 @@ def __init__( self._name = name self._blocks: List[BaseBlock] = [] + # minor type checking + if type(arg_fields) == str: + arg_fields = [arg_fields] + if type(kwarg_fields) == str: + kwarg_fields = [kwarg_fields] + if type(result_field) == list: + assert ( + len(result_field) == 1 + ), "Cannot have multiple 'result' fields for {name}" + result_field = result_field[0] + self._arg_fields = arg_fields self._kwarg_fields = kwarg_fields self._result_field = result_field diff --git a/fms_dgt/databuilders/simple/generate.py b/fms_dgt/databuilders/simple/generate.py index 3994eb77..004bd2d8 100644 --- a/fms_dgt/databuilders/simple/generate.py +++ b/fms_dgt/databuilders/simple/generate.py @@ -75,12 +75,7 @@ def __call__( request_start = time.time() - llm_outputs = self.llm1( - inputs, - arg_fields=["prompt"], - kwarg_fields=["stop_sequences"], - result_field="output", - ) + llm_outputs = self.llm1(inputs) request_duration = time.time() - request_start post_process_start = time.time() @@ -129,12 +124,7 @@ def __call__( val1_inputs.append(inp) # filter rouge failed data - outputs = [ - output["data"] - for output in self.val1( - val1_inputs, arg_fields=["new_toks", "all_toks"], result_field="output" - ) - ] + outputs = [output["data"] for output in self.val1(val1_inputs)] discarded += len(val1_inputs) - len(outputs) diff --git a/fms_dgt/databuilders/simple/simple.yaml b/fms_dgt/databuilders/simple/simple.yaml index 7f5cc743..48618c55 100644 --- a/fms_dgt/databuilders/simple/simple.yaml +++ b/fms_dgt/databuilders/simple/simple.yaml @@ -2,12 +2,21 @@ name: simple blocks: llm1: type: genai + arg_fields: + - prompt + kwarg_fields: + - stop_sequences + result_field: output temperature: 0.0 max_new_tokens: 512 min_new_tokens: 1 model_id_or_path: mistralai/mixtral-8x7b-instruct-v01 val1: type: rouge_scorer + arg_fields: + - new_toks + - all_toks + result_field: output filter: true threshold: 1.0 metadata: From 0d17bb69d401e8828cf52ef554153ec294976379 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Mon, 8 Jul 2024 11:58:20 -0500 Subject: [PATCH 15/41] demonstrate default vals --- fms_dgt/base/block.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index 42c11dc8..d2042ead 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -19,8 +19,6 @@ def __init__( result_field: str = None, ) -> None: - assert name is not None, f"'name' field cannot be empty in block definition" - self._name = name self._blocks: List[BaseBlock] = [] From fba82b61f41d8a2e10374ed6250336de2d48e719 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Mon, 8 Jul 2024 12:53:12 -0500 Subject: [PATCH 16/41] remove abstract method to simplify --- fms_dgt/base/block.py | 7 +------ fms_dgt/blocks/generators/genai.py | 20 ------------------- fms_dgt/blocks/generators/openai.py | 20 ------------------- fms_dgt/blocks/generators/vllm.py | 20 ------------------- fms_dgt/blocks/utilities/flatten_field.py | 4 +--- fms_dgt/blocks/validators/api.py | 18 ----------------- .../nl2sql/sql_execution_validator.py | 18 ----------------- .../validators/nl2sql/sql_syntax_validator.py | 18 ----------------- fms_dgt/blocks/validators/rouge.py | 18 ----------------- 9 files changed, 2 insertions(+), 141 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index d2042ead..ff446f54 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -84,17 +84,14 @@ def write_result( else: raise ValueError(f"Unexpected input type: {type(inp)}") - @abc.abstractmethod def __call__( self, inputs: Union[List[Dict], Type[pd.DataFrame], Type[Dataset]], - *args: Any, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, result_field: Optional[str] = None, - **kwargs: Any, ): - pass + raise NotImplementedError class BaseUtilityBlock(BaseBlock): @@ -113,11 +110,9 @@ def __init__(self, filter: bool = False, **kwargs: Any) -> None: def __call__( self, inputs: Union[List[Dict], Type[pd.DataFrame], Type[Dataset]], - *args: Any, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, result_field: Optional[List[str]] = None, - **kwargs: Any, ): outputs = [] for x in inputs: diff --git a/fms_dgt/blocks/generators/genai.py b/fms_dgt/blocks/generators/genai.py index 00d7e3bb..1035076c 100644 --- a/fms_dgt/blocks/generators/genai.py +++ b/fms_dgt/blocks/generators/genai.py @@ -207,23 +207,3 @@ def loglikelihood_batch( pbar.update(1) pbar.close() - - def __call__( - self, - inputs: Union[List[Dict], DataFrame, Dataset], - *args: Any, - arg_fields: Union[List[str], None] = None, - kwarg_fields: Union[List[str], None] = None, - result_field: Union[str, None] = None, - method: str = "generate", - **kwargs: Any, - ): - return super().__call__( - inputs, - *args, - arg_fields=arg_fields, - kwarg_fields=kwarg_fields, - result_field=result_field, - method=method, - **kwargs, - ) diff --git a/fms_dgt/blocks/generators/openai.py b/fms_dgt/blocks/generators/openai.py index a75fa802..fa462254 100644 --- a/fms_dgt/blocks/generators/openai.py +++ b/fms_dgt/blocks/generators/openai.py @@ -172,23 +172,3 @@ def generate_batch( def loglikelihood_batch(self, requests, disable_tqdm: bool = False): raise NotImplementedError("No support for logits.") - - def __call__( - self, - inputs: Union[List[Dict], DataFrame, Dataset], - *args: Any, - arg_fields: Union[List[str], None] = None, - kwarg_fields: Union[List[str], None] = None, - result_field: Union[str, None] = None, - method: str = "generate", - **kwargs: Any, - ): - return super().__call__( - inputs, - *args, - arg_fields=arg_fields, - kwarg_fields=kwarg_fields, - result_field=result_field, - method=method, - **kwargs, - ) diff --git a/fms_dgt/blocks/generators/vllm.py b/fms_dgt/blocks/generators/vllm.py index b5e28a08..d8a2cd53 100644 --- a/fms_dgt/blocks/generators/vllm.py +++ b/fms_dgt/blocks/generators/vllm.py @@ -460,23 +460,3 @@ def modify_gen_kwargs(kwargs: dict) -> dict: "spaces_between_special_tokens", False ) return kwargs - - def __call__( - self, - inputs: Union[List[Dict], DataFrame, Dataset], - *args: Any, - arg_fields: Union[List[str], None] = None, - kwarg_fields: Union[List[str], None] = None, - result_field: Union[str, None] = None, - method: str = "generate", - **kwargs: Any, - ): - return super().__call__( - inputs, - *args, - arg_fields=arg_fields, - kwarg_fields=kwarg_fields, - result_field=result_field, - method=method, - **kwargs, - ) diff --git a/fms_dgt/blocks/utilities/flatten_field.py b/fms_dgt/blocks/utilities/flatten_field.py index cfeaa007..cadb9be0 100644 --- a/fms_dgt/blocks/utilities/flatten_field.py +++ b/fms_dgt/blocks/utilities/flatten_field.py @@ -18,11 +18,9 @@ class FlattenField(BaseUtilityBlock): def __call__( self, inputs: Union[List[Dict], DataFrame, Dataset], - *args: Any, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, - result_field: Optional[List[str]] = None, - **kwargs: Any, + result_field: Optional[str] = None, ): arg_fields = arg_fields if arg_fields is not None else self._arg_fields if arg_fields is None: diff --git a/fms_dgt/blocks/validators/api.py b/fms_dgt/blocks/validators/api.py index 717f7eb1..06b7739a 100644 --- a/fms_dgt/blocks/validators/api.py +++ b/fms_dgt/blocks/validators/api.py @@ -26,24 +26,6 @@ class APIGenSpecValidator(BaseValidatorBlock): """Class for API Sequence Prediction Validator""" - def __call__( - self, - inputs: Union[List[Dict], DataFrame, Dataset], - *args: Any, - arg_fields: Optional[List[str]] = None, - kwarg_fields: Optional[List[str]] = None, - result_field: Optional[List[str]] = None, - **kwargs: Any, - ): - return super().__call__( - inputs, - *args, - arg_fields=arg_fields, - kwarg_fields=kwarg_fields, - result_field=result_field, - **kwargs, - ) - def _validate( self, api_info: dict, diff --git a/fms_dgt/blocks/validators/nl2sql/sql_execution_validator.py b/fms_dgt/blocks/validators/nl2sql/sql_execution_validator.py index 329b795c..6c272961 100644 --- a/fms_dgt/blocks/validators/nl2sql/sql_execution_validator.py +++ b/fms_dgt/blocks/validators/nl2sql/sql_execution_validator.py @@ -20,24 +20,6 @@ class SQLExecutionValidator(BaseValidatorBlock): """SQL execution validator.""" - def __call__( - self, - inputs: Union[List[Dict], DataFrame, Dataset], - *args: Any, - arg_fields: Optional[List[str]] = None, - kwarg_fields: Optional[List[str]] = None, - result_field: Optional[List[str]] = None, - **kwargs: Any, - ): - return super().__call__( - inputs, - *args, - arg_fields=arg_fields, - kwarg_fields=kwarg_fields, - result_field=result_field, - **kwargs, - ) - def _validate(self, record: Dict[str, str], **kwargs: Any) -> bool: """Validate a record containing information on schema, query and utterance. diff --git a/fms_dgt/blocks/validators/nl2sql/sql_syntax_validator.py b/fms_dgt/blocks/validators/nl2sql/sql_syntax_validator.py index 9195ad09..167e6b53 100644 --- a/fms_dgt/blocks/validators/nl2sql/sql_syntax_validator.py +++ b/fms_dgt/blocks/validators/nl2sql/sql_syntax_validator.py @@ -19,24 +19,6 @@ class SQLSyntaxValidator(BaseValidatorBlock): """SQL syntax validator.""" - def __call__( - self, - inputs: Union[List[Dict], DataFrame, Dataset], - *args: Any, - arg_fields: Optional[List[str]] = None, - kwarg_fields: Optional[List[str]] = None, - result_field: Optional[List[str]] = None, - **kwargs: Any, - ): - return super().__call__( - inputs, - *args, - arg_fields=arg_fields, - kwarg_fields=kwarg_fields, - result_field=result_field, - **kwargs, - ) - def _validate( self, record: Dict[str, str], sql_dialect: str = "postgres", **kwargs: Any ) -> bool: diff --git a/fms_dgt/blocks/validators/rouge.py b/fms_dgt/blocks/validators/rouge.py index 4aebad2e..4c9cb49d 100644 --- a/fms_dgt/blocks/validators/rouge.py +++ b/fms_dgt/blocks/validators/rouge.py @@ -44,24 +44,6 @@ def tokenize(self, inp: Union[List, str]): self._cache[inp] = self.scorer._tokenizer.tokenize(inp) return self._cache[inp] - def __call__( - self, - inputs: Union[List[Dict], DataFrame, Dataset], - *args: Any, - arg_fields: Optional[List[str]] = None, - kwarg_fields: Optional[List[str]] = None, - result_field: Optional[List[str]] = None, - **kwargs: Any, - ): - return super().__call__( - inputs, - *args, - arg_fields=arg_fields, - kwarg_fields=kwarg_fields, - result_field=result_field, - **kwargs, - ) - def _validate(self, new_tokens: List[int], check_tokens: List[List[int]]) -> bool: """Runs through all the validators if data list is None. Otherwise just runs through the validators specified for data in the List""" From e7db4dd1bc982955f4d723f38c6c2281744083b8 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Mon, 8 Jul 2024 12:54:19 -0500 Subject: [PATCH 17/41] remove abstract method to simplify --- fms_dgt/blocks/generators/llm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fms_dgt/blocks/generators/llm.py b/fms_dgt/blocks/generators/llm.py index 1b31cf41..8ee94693 100644 --- a/fms_dgt/blocks/generators/llm.py +++ b/fms_dgt/blocks/generators/llm.py @@ -100,7 +100,6 @@ def set_cache_hook(self, cache_hook) -> None: def __call__( self, inputs: Union[List[Dict], pd.DataFrame, Dataset], - *args: Any, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, result_field: Optional[str] = None, From af01c4ed11c6f25aed214ef4c08814b32dc25749 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Mon, 8 Jul 2024 16:18:00 -0500 Subject: [PATCH 18/41] misc minor changes --- fms_dgt/base/block.py | 15 ++++++--------- fms_dgt/blocks/generators/llm.py | 7 +++---- fms_dgt/blocks/utilities/flatten_field.py | 4 ++-- fms_dgt/blocks/validators/lm_judge.py | 2 +- templates/generator/template.py | 4 ++-- templates/validator/template.py | 18 ------------------ tests/compatibility_tests/blocks.py | 6 +++--- 7 files changed, 17 insertions(+), 39 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index ff446f54..c4a36350 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -1,12 +1,15 @@ # Standard from abc import ABC -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Type, Union import abc # Third Party from datasets import Dataset import pandas as pd +BLOCK_ROW_TYPE = Union[Dict, Type[pd.Series]] +BLOCK_INPUT_TYPE = Union[Iterable[BLOCK_ROW_TYPE], pd.DataFrame, Dataset] + class BaseBlock(ABC): """Base Class for all Blocks""" @@ -20,7 +23,6 @@ def __init__( ) -> None: self._name = name - self._blocks: List[BaseBlock] = [] # minor type checking if type(arg_fields) == str: @@ -41,11 +43,6 @@ def __init__( def name(self): return self._name - @property - def blocks(self) -> List: - """Returns the constituent blocks associated with this class.""" - return self._blocks - def get_args_kwargs( self, inp: Union[Dict, pd.DataFrame, Dataset], @@ -86,7 +83,7 @@ def write_result( def __call__( self, - inputs: Union[List[Dict], Type[pd.DataFrame], Type[Dataset]], + inputs: BLOCK_INPUT_TYPE, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, result_field: Optional[str] = None, @@ -109,7 +106,7 @@ def __init__(self, filter: bool = False, **kwargs: Any) -> None: def __call__( self, - inputs: Union[List[Dict], Type[pd.DataFrame], Type[Dataset]], + inputs: BLOCK_INPUT_TYPE, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, result_field: Optional[List[str]] = None, diff --git a/fms_dgt/blocks/generators/llm.py b/fms_dgt/blocks/generators/llm.py index 8ee94693..32979362 100644 --- a/fms_dgt/blocks/generators/llm.py +++ b/fms_dgt/blocks/generators/llm.py @@ -26,7 +26,7 @@ import pandas as pd # Local -from fms_dgt.base.block import BaseGeneratorBlock +from fms_dgt.base.block import BLOCK_INPUT_TYPE, BaseGeneratorBlock from fms_dgt.base.instance import Instance from fms_dgt.utils import sdg_logger @@ -99,7 +99,7 @@ def set_cache_hook(self, cache_hook) -> None: def __call__( self, - inputs: Union[List[Dict], pd.DataFrame, Dataset], + inputs: BLOCK_INPUT_TYPE, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, result_field: Optional[str] = None, @@ -250,8 +250,7 @@ def fn(requests: List[Instance]): def __call__( self, - inputs: Union[List[Dict], pd.DataFrame, Dataset], - *args: Any, + inputs: BLOCK_INPUT_TYPE, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, result_field: Optional[str] = None, diff --git a/fms_dgt/blocks/utilities/flatten_field.py b/fms_dgt/blocks/utilities/flatten_field.py index cadb9be0..bc7f4f4c 100644 --- a/fms_dgt/blocks/utilities/flatten_field.py +++ b/fms_dgt/blocks/utilities/flatten_field.py @@ -7,7 +7,7 @@ from pandas import DataFrame # Local -from fms_dgt.base.block import BaseUtilityBlock +from fms_dgt.base.block import BLOCK_INPUT_TYPE, BaseUtilityBlock from fms_dgt.base.registry import register_block @@ -17,7 +17,7 @@ class FlattenField(BaseUtilityBlock): def __call__( self, - inputs: Union[List[Dict], DataFrame, Dataset], + inputs: BLOCK_INPUT_TYPE, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, result_field: Optional[str] = None, diff --git a/fms_dgt/blocks/validators/lm_judge.py b/fms_dgt/blocks/validators/lm_judge.py index fb725ee1..b0ff68ab 100644 --- a/fms_dgt/blocks/validators/lm_judge.py +++ b/fms_dgt/blocks/validators/lm_judge.py @@ -17,7 +17,7 @@ class LMJudgeValidator(BaseValidatorBlock): def __init__(self, lm_type: str = None, **kwargs: Any): super().__init__(**kwargs) self._llm_generator: LMGenerator = get_block(lm_type)(self.name, **kwargs) - self._blocks.append(self._llm_generator) + self.blocks = [self._llm_generator] def validate_batch(self, inputs: List[Instance], **kwargs: Any) -> None: generator_inputs = [Instance([x.args[0]], x.kwargs) for x in inputs] diff --git a/templates/generator/template.py b/templates/generator/template.py index ceff7d5a..34cffe8e 100644 --- a/templates/generator/template.py +++ b/templates/generator/template.py @@ -6,7 +6,7 @@ from pandas import DataFrame # Local -from fms_dgt.base.block import BaseGeneratorBlock +from fms_dgt.base.block import BLOCK_INPUT_TYPE, BaseGeneratorBlock from fms_dgt.base.registry import register_block @@ -19,7 +19,7 @@ def __init__(self, name: str, config: Dict, **kwargs: Any) -> None: def __call__( self, - inputs: Union[List[Dict], DataFrame, Dataset], + inputs: BLOCK_INPUT_TYPE, *args: Any, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, diff --git a/templates/validator/template.py b/templates/validator/template.py index fedb93de..92da5e6a 100644 --- a/templates/validator/template.py +++ b/templates/validator/template.py @@ -17,24 +17,6 @@ class TemplateValidator(BaseValidatorBlock): def __init__(self, name: str, config: Dict) -> None: super().__init__(name, config) - def __call__( - self, - inputs: Union[List[Dict], DataFrame, Dataset], - *args: Any, - arg_fields: Optional[List[str]] = None, - kwarg_fields: Optional[List[str]] = None, - result_field: Optional[List[str]] = None, - **kwargs: Any, - ) -> None: - return super().__call__( - inputs, - *args, - arg_fields=arg_fields, - kwarg_fields=kwarg_fields, - result_field=result_field, - **kwargs, - ) - def _validate(self, *args, **kwargs) -> bool: """Return True if valid and False otherwise""" return True diff --git a/tests/compatibility_tests/blocks.py b/tests/compatibility_tests/blocks.py index f284f9b7..68efe1b6 100644 --- a/tests/compatibility_tests/blocks.py +++ b/tests/compatibility_tests/blocks.py @@ -8,8 +8,8 @@ from sdg.src.instructlab.sdg.filterblock import FilterByValueBlock import pandas as pd -# First Party -from fms_dgt.base.block import BaseBlock +# Local +from fms_dgt.base.block import BLOCK_INPUT_TYPE, BaseBlock class TestFilterBlock(BaseBlock): @@ -30,7 +30,7 @@ def __init__( def __call__( self, - inputs: Union[List[Dict], pd.DataFrame, Dataset], + inputs: BLOCK_INPUT_TYPE, **kwargs: Any, ) -> Any: return self._filter_block.generate(inputs) From a94c7ebf3fbca82eb51d08a87eff57ff72d30669 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Tue, 9 Jul 2024 13:24:22 -0500 Subject: [PATCH 19/41] call to generate --- fms_dgt/base/block.py | 4 +- fms_dgt/blocks/generators/llm.py | 4 +- fms_dgt/blocks/utilities/flatten_field.py | 2 +- fms_dgt/databuilders/api/generate.py | 6 +- fms_dgt/databuilders/nl2sql/generate.py | 7 +- fms_dgt/databuilders/simple/generate.py | 4 +- templates/generator/template.py | 2 +- tests/blocks/generators/test_llm.py | 14 +- tests/blocks/validators/test_api.py | 186 ++++++++++++++++------ tests/blocks/validators/test_rouge.py | 14 +- tests/compatibility_tests/blocks.py | 4 +- 11 files changed, 172 insertions(+), 75 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index c4a36350..066f6b10 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -81,7 +81,7 @@ def write_result( else: raise ValueError(f"Unexpected input type: {type(inp)}") - def __call__( + def generate( self, inputs: BLOCK_INPUT_TYPE, arg_fields: Optional[List[str]] = None, @@ -104,7 +104,7 @@ def __init__(self, filter: bool = False, **kwargs: Any) -> None: super().__init__(**kwargs) self._filter_invalids = filter - def __call__( + def generate( self, inputs: BLOCK_INPUT_TYPE, arg_fields: Optional[List[str]] = None, diff --git a/fms_dgt/blocks/generators/llm.py b/fms_dgt/blocks/generators/llm.py index 32979362..845107e2 100644 --- a/fms_dgt/blocks/generators/llm.py +++ b/fms_dgt/blocks/generators/llm.py @@ -97,7 +97,7 @@ def loglikelihood_batch( def set_cache_hook(self, cache_hook) -> None: self.cache_hook = cache_hook - def __call__( + def generate( self, inputs: BLOCK_INPUT_TYPE, arg_fields: Optional[List[str]] = None, @@ -248,7 +248,7 @@ def fn(requests: List[Instance]): return fn - def __call__( + def generate( self, inputs: BLOCK_INPUT_TYPE, arg_fields: Optional[List[str]] = None, diff --git a/fms_dgt/blocks/utilities/flatten_field.py b/fms_dgt/blocks/utilities/flatten_field.py index bc7f4f4c..a9e81d09 100644 --- a/fms_dgt/blocks/utilities/flatten_field.py +++ b/fms_dgt/blocks/utilities/flatten_field.py @@ -15,7 +15,7 @@ class FlattenField(BaseUtilityBlock): """Flatten specified args""" - def __call__( + def generate( self, inputs: BLOCK_INPUT_TYPE, arg_fields: Optional[List[str]] = None, diff --git a/fms_dgt/databuilders/api/generate.py b/fms_dgt/databuilders/api/generate.py index ae0566a0..d7bc74cd 100644 --- a/fms_dgt/databuilders/api/generate.py +++ b/fms_dgt/databuilders/api/generate.py @@ -57,7 +57,7 @@ def __call__( gen_inputs.append(inp) request_start = time.time() - llm_outputs = self.llm1( + llm_outputs = self.llm1.generate( gen_inputs, arg_fields=["prompt"], kwarg_fields=["stop_sequences"], @@ -130,7 +130,7 @@ def _wf_filter_data(self, data_to_filter: List[Dict]): # filter invalid data outputs = [ output["data"] - for output in self.val1( + for output in self.val1.generate( val1_inputs, arg_fields=["new_apis", "question", "answer"], kwarg_fields=[ @@ -168,7 +168,7 @@ def _rouge_filter_data(self, data_to_filter: List[ApiSdgData]): # filter rouge failed data outputs = [ output["data"] - for output in self.val2( + for output in self.val2.generate( val2_inputs, arg_fields=["new_instruction_tokens", "all_instruction_tokens"], result_field="output", diff --git a/fms_dgt/databuilders/nl2sql/generate.py b/fms_dgt/databuilders/nl2sql/generate.py index 5eff3ee6..25282668 100644 --- a/fms_dgt/databuilders/nl2sql/generate.py +++ b/fms_dgt/databuilders/nl2sql/generate.py @@ -42,6 +42,7 @@ def __init__( # val1 is the validator which checks SQL syntax val1: SQLSyntaxValidator + # val2 is the validator which checks SQL execution val2: SQLExecutionValidator @@ -72,7 +73,7 @@ def __call__( instances = prompting_pipeline.run( data_generation_schema=data_generation_schema ) - llm_outputs = self.llm1( + llm_outputs = self.llm1.generate( instances, arg_fields=["prompt"], result_field="output" ) @@ -110,12 +111,12 @@ def __call__( } for sql_schema, utterance, sql_query in processed_outputs ] - filtered_output = self.val1( + filtered_output = self.val1.generate( instances_for_validation, kwarg_fields=["record", "sql_dialect"], result_field="output", ) - filtered_output = self.val2( + filtered_output = self.val2.generate( filtered_output, kwarg_fields=["record", "sql_dialect"], result_field="output", diff --git a/fms_dgt/databuilders/simple/generate.py b/fms_dgt/databuilders/simple/generate.py index 004bd2d8..8a9921c4 100644 --- a/fms_dgt/databuilders/simple/generate.py +++ b/fms_dgt/databuilders/simple/generate.py @@ -75,7 +75,7 @@ def __call__( request_start = time.time() - llm_outputs = self.llm1(inputs) + llm_outputs = self.llm1.generate(inputs) request_duration = time.time() - request_start post_process_start = time.time() @@ -124,7 +124,7 @@ def __call__( val1_inputs.append(inp) # filter rouge failed data - outputs = [output["data"] for output in self.val1(val1_inputs)] + outputs = [output["data"] for output in self.val1.generate(val1_inputs)] discarded += len(val1_inputs) - len(outputs) diff --git a/templates/generator/template.py b/templates/generator/template.py index 34cffe8e..f83e3948 100644 --- a/templates/generator/template.py +++ b/templates/generator/template.py @@ -17,7 +17,7 @@ class TemplateGenerator(BaseGeneratorBlock): def __init__(self, name: str, config: Dict, **kwargs: Any) -> None: super().__init__(name, config, **kwargs) - def __call__( + def generate( self, inputs: BLOCK_INPUT_TYPE, *args: Any, diff --git a/tests/blocks/generators/test_llm.py b/tests/blocks/generators/test_llm.py index ba7bbc31..46751fad 100644 --- a/tests/blocks/generators/test_llm.py +++ b/tests/blocks/generators/test_llm.py @@ -61,7 +61,7 @@ def test_generate_batch(self, model_cfg): inputs_copy = copy.deepcopy(inputs) - lm(inputs, arg_fields=["prompt"], result_field="output") + lm.generate(inputs, arg_fields=["prompt"], result_field="output") for i, inp in enumerate(inputs): assert ( @@ -82,7 +82,7 @@ def test_loglikelihood_batch(self, model_cfg): inputs_copy = copy.deepcopy(inputs) - lm( + lm.generate( inputs, arg_fields=["prompt1", "prompt2"], result_field="output", @@ -143,7 +143,7 @@ def test_lm_caching(self): post_cache_inputs = copy.deepcopy(non_cache_inputs) non_cache_time = time.time() - lm(non_cache_inputs, arg_fields=["prompt"], result_field="output") + lm.generate(non_cache_inputs, arg_fields=["prompt"], result_field="output") non_cache_time = time.time() - non_cache_time cache_lm = CachingLM( @@ -152,11 +152,15 @@ def test_lm_caching(self): ) pre_cache_time = time.time() - cache_lm(pre_cache_inputs, arg_fields=["prompt"], result_field="output") + cache_lm.generate( + pre_cache_inputs, arg_fields=["prompt"], result_field="output" + ) pre_cache_time = time.time() - pre_cache_time post_cache_time = time.time() - cache_lm(post_cache_inputs, arg_fields=["prompt"], result_field="output") + cache_lm.generate( + post_cache_inputs, arg_fields=["prompt"], result_field="output" + ) post_cache_time = time.time() - post_cache_time os.remove(cache_path) diff --git a/tests/blocks/validators/test_api.py b/tests/blocks/validators/test_api.py index 600b2047..3fc75945 100644 --- a/tests/blocks/validators/test_api.py +++ b/tests/blocks/validators/test_api.py @@ -51,11 +51,22 @@ def test_single_intent(self): func_calls = [{"name": "add"}] question = "add 3 with 4" api_info = get_args(func_calls) - args = [api_info, question, json.dumps(func_calls)] - test_instance = [Instance(args, single_intent_kwargs)] - validator(test_instance) - assert test_instance[0].result + test_instance = [ + { + "a": api_info, + "b": question, + "c": json.dumps(func_calls), + **single_intent_kwargs, + } + ] + validator.generate( + test_instance, + arg_fields=["a", "b", "c"], + kwarg_fields=list(single_intent_kwargs.keys()), + result_field="result", + ) + assert test_instance[0]["result"] def test_multi_intent(self): validator = APIGenSpecValidator(name="test_multi_intent") @@ -66,11 +77,22 @@ def test_multi_intent(self): ] question = "add 3 with 4" api_info = get_args(func_calls) - args = [api_info, question, json.dumps(func_calls)] - test_instance = [Instance(args, multi_intent_kwargs)] - validator(test_instance) - assert test_instance[0].result + test_instance = [ + { + "a": api_info, + "b": question, + "c": json.dumps(func_calls), + **multi_intent_kwargs, + } + ] + validator.generate( + test_instance, + arg_fields=["a", "b", "c"], + kwarg_fields=list(multi_intent_kwargs.keys()), + result_field="result", + ) + assert test_instance[0]["result"] def test_parallel_single(self): validator = APIGenSpecValidator(name="test_parallel_single") @@ -82,11 +104,22 @@ def test_parallel_single(self): ] question = "add 3 with 4 then add 4 with 5" api_info = get_args(func_calls) - args = [api_info, question, json.dumps(func_calls)] - test_instance = [Instance(args, parallel_kwargs)] - validator(test_instance) - assert test_instance[0].result + test_instance = [ + { + "a": api_info, + "b": question, + "c": json.dumps(func_calls), + **parallel_kwargs, + } + ] + validator.generate( + test_instance, + arg_fields=["a", "b", "c"], + kwarg_fields=list(parallel_kwargs.keys()), + result_field="result", + ) + assert test_instance[0]["result"] func_calls = [ {"name": "add", "arguments": {"n1": 3}}, @@ -94,13 +127,24 @@ def test_parallel_single(self): ] question = "add 3 with 4" api_info = get_args(func_calls) - args = [api_info, question, json.dumps(func_calls)] - test_instance = [Instance(args, parallel_kwargs)] - validator(test_instance) - assert not test_instance[ - 0 - ].result, "Validator should have failed due to required args!" + test_instance = [ + { + "a": api_info, + "b": question, + "c": json.dumps(func_calls), + **parallel_kwargs, + } + ] + validator.generate( + test_instance, + arg_fields=["a", "b", "c"], + kwarg_fields=list(parallel_kwargs.keys()), + result_field="result", + ) + assert not test_instance[0][ + "result" + ], "Validator should have failed due to required args!" def test_parallel_multiple(self): validator = APIGenSpecValidator(name="test_parallel_multiple") @@ -112,11 +156,22 @@ def test_parallel_multiple(self): ] question = "add 3 with 4 and add an event store to my calendar" api_info = get_args(func_calls) - args = [api_info, question, json.dumps(func_calls)] - test_instance = [Instance(args, parallel_kwargs)] - validator(test_instance) - assert test_instance[0].result + test_instance = [ + { + "a": api_info, + "b": question, + "c": json.dumps(func_calls), + **parallel_kwargs, + } + ] + validator.generate( + test_instance, + arg_fields=["a", "b", "c"], + kwarg_fields=list(parallel_kwargs.keys()), + result_field="result", + ) + assert test_instance[0]["result"] func_calls = [ {"name": "add", "arguments": {"n1": 3, "n2": 4}}, @@ -124,15 +179,24 @@ def test_parallel_multiple(self): ] question = "add 3 with 4 and add an event store to my calendar" api_info = get_args(func_calls) - args = [api_info, question, json.dumps(func_calls)] - - test_instance = [Instance(args, parallel_kwargs)] - validator(test_instance) - assert not test_instance[ - 0 - ].result, ( - "Validator should have failed due to arg content not being in question!" + + test_instance = [ + { + "a": api_info, + "b": question, + "c": json.dumps(func_calls), + **parallel_kwargs, + } + ] + validator.generate( + test_instance, + arg_fields=["a", "b", "c"], + kwarg_fields=list(parallel_kwargs.keys()), + result_field="result", ) + assert not test_instance[0][ + "result" + ], "Validator should have failed due to arg content not being in question!" def test_parallel_nested(self): validator = APIGenSpecValidator(name="test_parallel_nested") @@ -144,11 +208,22 @@ def test_parallel_nested(self): ] question = "add 3 with 4 and add an event with the result of the earlier addition to my calendar" api_info = get_args(func_calls) - args = [api_info, question, json.dumps(func_calls)] - test_instance = [Instance(args, parallel_nested_kwargs)] - validator(test_instance) - assert test_instance[0].result + test_instance = [ + { + "a": api_info, + "b": question, + "c": json.dumps(func_calls), + **parallel_nested_kwargs, + } + ] + validator.generate( + test_instance, + arg_fields=["a", "b", "c"], + kwarg_fields=list(parallel_nested_kwargs.keys()), + result_field="result", + ) + assert test_instance[0]["result"] func_calls = [ {"name": "add", "arguments": {"n1": 3, "n2": 4}}, @@ -156,24 +231,43 @@ def test_parallel_nested(self): ] question = "add 3 with 4 and add an event store to my calendar" api_info = get_args(func_calls) - args = [api_info, question, json.dumps(func_calls)] - - test_instance = [Instance(args, parallel_nested_kwargs)] - validator(test_instance) - assert not test_instance[ - 0 - ].result, ( - "Validator should have failed due to arg content not being in question!" + + test_instance = [ + { + "a": api_info, + "b": question, + "c": json.dumps(func_calls), + **parallel_nested_kwargs, + } + ] + validator.generate( + test_instance, + arg_fields=["a", "b", "c"], + kwarg_fields=list(parallel_nested_kwargs.keys()), + result_field="result", ) + assert not test_instance[0][ + "result" + ], "Validator should have failed due to arg content not being in question!" def test_yes_no(self): validator = ApiGenSpecYesNoValidation(name="test_yes_no") for arg_inp in ["YES", "NO", "MAYBE"]: - args = [TEST_APIS, "this is a test question", arg_inp] - test_instance = [Instance(args)] - validator(test_instance) - assert test_instance[0].result == (arg_inp in ["YES", "NO"]) + + test_instance = [ + { + "a": TEST_APIS, + "b": "this is a test question", + "c": arg_inp, + } + ] + validator.generate( + test_instance, + arg_fields=["a", "b", "c"], + result_field="result", + ) + assert test_instance[0]["result"] == (arg_inp in ["YES", "NO"]) TEST_APIS = { diff --git a/tests/blocks/validators/test_rouge.py b/tests/blocks/validators/test_rouge.py index b46e11fc..fede1a0c 100644 --- a/tests/blocks/validators/test_rouge.py +++ b/tests/blocks/validators/test_rouge.py @@ -23,16 +23,14 @@ def test_matches(self): data_entry = "I went to the store" new_tokens = validator.tokenize(data_entry) - args = [new_tokens, all_tokens] - inputs = [Instance(args)] + inputs = [{"a": new_tokens, "b": all_tokens}] validator._threshold = 0.91 - validator.validate_batch(inputs) - assert inputs[0].result + validator.generate(inputs, arg_fields=["a", "b"], result_field="result") + assert inputs[0]["result"] data_entry = "one two three" new_tokens = validator.tokenize(data_entry) - args = [new_tokens, all_tokens] - inputs = [Instance(args)] + inputs = [{"a": new_tokens, "b": all_tokens}] validator._threshold = 0.0 - validator.validate_batch(inputs) - assert not inputs[0].result + validator.generate(inputs, arg_fields=["a", "b"], result_field="result") + assert not inputs[0]["result"] diff --git a/tests/compatibility_tests/blocks.py b/tests/compatibility_tests/blocks.py index 68efe1b6..c86ae062 100644 --- a/tests/compatibility_tests/blocks.py +++ b/tests/compatibility_tests/blocks.py @@ -28,7 +28,7 @@ def __init__( operation=self.config.get("operation"), ) - def __call__( + def generate( self, inputs: BLOCK_INPUT_TYPE, **kwargs: Any, @@ -48,7 +48,7 @@ def main(): "operation": operator.ne, }, ) - ret_dataset: Dataset = test_block(dataset) + ret_dataset: Dataset = test_block.generate(dataset) print(json.dumps(dataset.to_dict(), indent=4)) print("\n=====\n") From d06a26dc394a493bff393e038d433836e70c7bc1 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Tue, 9 Jul 2024 13:25:18 -0500 Subject: [PATCH 20/41] call to generate --- tests/blocks/validators/test_rouge.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/blocks/validators/test_rouge.py b/tests/blocks/validators/test_rouge.py index fede1a0c..2a13d8ac 100644 --- a/tests/blocks/validators/test_rouge.py +++ b/tests/blocks/validators/test_rouge.py @@ -6,7 +6,6 @@ import pytest # Local -from fms_dgt.base.instance import Instance from fms_dgt.blocks.validators.rouge import RougeValidator From b7fcb59622d9ef905708d552da003f75376852ae Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Tue, 9 Jul 2024 14:08:54 -0500 Subject: [PATCH 21/41] non base functions --- fms_dgt/base/block.py | 94 ++++++++++++++++++-------------- fms_dgt/blocks/generators/llm.py | 21 +++++-- 2 files changed, 68 insertions(+), 47 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index 066f6b10..cbd8a4fa 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -1,13 +1,12 @@ # Standard from abc import ABC -from typing import Any, Dict, Iterable, List, Optional, Type, Union -import abc +from typing import Any, Dict, Iterable, List, Optional, Union # Third Party from datasets import Dataset import pandas as pd -BLOCK_ROW_TYPE = Union[Dict, Type[pd.Series]] +BLOCK_ROW_TYPE = Union[Dict, pd.Series] BLOCK_INPUT_TYPE = Union[Iterable[BLOCK_ROW_TYPE], pd.DataFrame, Dataset] @@ -43,43 +42,17 @@ def __init__( def name(self): return self._name - def get_args_kwargs( - self, - inp: Union[Dict, pd.DataFrame, Dataset], - arg_fields: Optional[List[str]] = None, - kwarg_fields: Optional[List[str]] = None, - ): - arg_fields = arg_fields if arg_fields is not None else self._arg_fields - kwarg_fields = kwarg_fields if kwarg_fields is not None else self._kwarg_fields - - if arg_fields is None: - arg_fields = [] - if kwarg_fields is None: - kwarg_fields = [] - - if type(inp) == dict: - args = [inp.get(arg) for arg in arg_fields] - kwargs = {kwarg: inp.get(kwarg) for kwarg in kwarg_fields} - elif type(inp) in [pd.DataFrame, Dataset]: - args = [inp.get(arg) for arg in arg_fields] - kwargs = {kwarg: inp.get(kwarg) for kwarg in kwarg_fields} - else: - raise ValueError(f"Unexpected input type: {type(inp)}") - - return args, kwargs - - def write_result( - self, inp: Union[Dict, pd.DataFrame, Dataset], res: Any, result_field: str - ): - result_field = result_field if result_field is not None else self._result_field - assert result_field is not None, "Result field cannot be None!" + @property + def arg_fields(self): + return self._arg_fields - if type(inp) == dict: - inp[result_field] = res - elif type(inp) in [pd.DataFrame, Dataset]: - inp[result_field] = res - else: - raise ValueError(f"Unexpected input type: {type(inp)}") + @property + def kwarg_fields(self): + return self._kwarg_fields + + @property + def result_field(self): + return self._result_field def generate( self, @@ -113,12 +86,51 @@ def generate( ): outputs = [] for x in inputs: - inp_args, inp_kwargs = self.get_args_kwargs(x, arg_fields, kwarg_fields) + inp_args, inp_kwargs = get_args_kwargs( + x, arg_fields or self.arg_fields, kwarg_fields or self.kwarg_fields + ) res = self._validate(*inp_args, **inp_kwargs) if res or not self._filter_invalids: - self.write_result(x, res, result_field) + write_result(x, res, result_field or self.result_field) outputs.append(x) return outputs def _validate(self, *args: Any, **kwargs: Any) -> bool: raise NotImplementedError + + +def get_args_kwargs( + inp: BLOCK_ROW_TYPE, + arg_fields: Optional[List[str]] = None, + kwarg_fields: Optional[List[str]] = None, +): + if arg_fields is None: + arg_fields = [] + if kwarg_fields is None: + kwarg_fields = [] + + if type(inp) == dict: + args = [inp.get(arg) for arg in arg_fields] + kwargs = {kwarg: inp.get(kwarg) for kwarg in kwarg_fields} + elif type(inp) in [pd.DataFrame, Dataset]: + args = [inp.get(arg) for arg in arg_fields] + kwargs = {kwarg: inp.get(kwarg) for kwarg in kwarg_fields} + else: + raise ValueError(f"Unexpected input type: {type(inp)}") + + return args, kwargs + + +def write_result( + inp: BLOCK_ROW_TYPE, + res: Any, + result_field: str, +): + assert result_field is not None, "Result field cannot be None!" + + if type(inp) == dict: + inp[result_field] = res + elif type(inp) in [pd.DataFrame, Dataset]: + inp[result_field] = res + else: + raise ValueError(f"Unexpected input type: {type(inp)}") diff --git a/fms_dgt/blocks/generators/llm.py b/fms_dgt/blocks/generators/llm.py index 845107e2..7c7bc8f9 100644 --- a/fms_dgt/blocks/generators/llm.py +++ b/fms_dgt/blocks/generators/llm.py @@ -26,7 +26,12 @@ import pandas as pd # Local -from fms_dgt.base.block import BLOCK_INPUT_TYPE, BaseGeneratorBlock +from fms_dgt.base.block import ( + BLOCK_INPUT_TYPE, + BaseGeneratorBlock, + get_args_kwargs, + write_result, +) from fms_dgt.base.instance import Instance from fms_dgt.utils import sdg_logger @@ -110,7 +115,9 @@ def generate( # simplify generation here instances: List[Instance] = [] for inp in inputs: - inp_args, inp_kwargs = self.get_args_kwargs(inp, arg_fields, kwarg_fields) + inp_args, inp_kwargs = get_args_kwargs( + inp, arg_fields or self.arg_fields, kwarg_fields or self.kwarg_fields + ) instances.append(Instance(args=inp_args, kwargs=inp_kwargs, data=inp)) if method == "generate": @@ -133,7 +140,7 @@ def generate( outputs = [] for inst in instances: - self.write_result(inst.data, inst.result, result_field) + write_result(inst.data, inst.result, result_field or self.result_field) outputs.append(inst.data) return outputs @@ -261,8 +268,10 @@ def generate( # simplify generation here instances: List[Instance] = [] for inp in inputs: - inp_args, inp_kwargs = self.lm.get_args_kwargs( - inp, arg_fields, kwarg_fields + inp_args, inp_kwargs = get_args_kwargs( + inp, + arg_fields or self.lm.arg_fields, + kwarg_fields or self.lm.kwarg_fields, ) instances.append(Instance(args=inp_args, kwargs=inp_kwargs, data=inp)) @@ -286,7 +295,7 @@ def generate( outputs = [] for inst in instances: - self.lm.write_result(inst.data, inst.result, result_field) + write_result(inst.data, inst.result, result_field or self.lm.result_field) outputs.append(inst.data) return outputs From cb29984d2a9f3346650fb02bd05002402be63daf Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 08:28:28 -0500 Subject: [PATCH 22/41] registry change --- fms_dgt/base/registry.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/fms_dgt/base/registry.py b/fms_dgt/base/registry.py index faa7abfb..5ff3a377 100644 --- a/fms_dgt/base/registry.py +++ b/fms_dgt/base/registry.py @@ -1,17 +1,11 @@ # Standard from typing import Any -import logging # Local from fms_dgt.base.block import BaseBlock from fms_dgt.base.dataloader import BaseDataloader from fms_dgt.base.resource import BaseResource -eval_logger = logging.getLogger("fms_dgt") - -# TODO: generator registry, validator registry, task registry - - BLOCK_REGISTRY = {} From b49ffa1a98f81f72cf6fde12e2194856a0782949 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 08:50:57 -0500 Subject: [PATCH 23/41] genai req bug fix --- fms_dgt/blocks/generators/llm.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/fms_dgt/blocks/generators/llm.py b/fms_dgt/blocks/generators/llm.py index 7c7bc8f9..f6c481e2 100644 --- a/fms_dgt/blocks/generators/llm.py +++ b/fms_dgt/blocks/generators/llm.py @@ -13,17 +13,13 @@ # Standard from typing import Any, Dict, List, Optional, Union import abc -import copy import hashlib import json import os # Third Party -from datasets import Dataset -from genai.schema import TextGenerationParameters from sqlitedict import SqliteDict from tqdm import tqdm -import pandas as pd # Local from fms_dgt.base.block import ( @@ -44,16 +40,14 @@ class LMGenerator(BaseGeneratorBlock): def __init__( self, model_id_or_path: str = None, + decoding_method: str = "sample", + max_new_tokens: int = None, + min_new_tokens: int = None, + random_seed: int = None, + stop_sequences: List[str] = None, + temperature: float = None, **kwargs: Any, ): - # TODO: define exact kwargs that are supported - default_kwargs = {"decoding_method": "sample"} - cfg_kwargs = { - k: kwargs.pop(k) - for k in copy.copy(kwargs) - if k in TextGenerationParameters.model_fields - } - super().__init__(**kwargs) self._rank = 0 @@ -64,7 +58,19 @@ def __init__( self.model_id_or_path is not None ), f"Must specify model for Generator {self.name}" - self._base_kwargs = {**default_kwargs, **cfg_kwargs} + cfg_kwargs = dict() + for k, v in { + "decoding_method": decoding_method, + "max_new_tokens": max_new_tokens, + "min_new_tokens": min_new_tokens, + "random_seed": random_seed, + "stop_sequences": stop_sequences, + "temperature": temperature, + }.items(): + if v is not None: + cfg_kwargs[k] = v + + self._base_kwargs = cfg_kwargs @property def rank(self): From c7a49a6bd7afd29ebfce088bdfd3aff82d2abe89 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 08:53:23 -0500 Subject: [PATCH 24/41] Update fms_dgt/base/block.py Co-authored-by: Gabe Goodhart --- fms_dgt/base/block.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index cbd8a4fa..b0bb8b69 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -54,14 +54,32 @@ def kwarg_fields(self): def result_field(self): return self._result_field + @abstractmethod def generate( self, inputs: BLOCK_INPUT_TYPE, + *, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, result_field: Optional[str] = None, + **kwargs, ): - raise NotImplementedError + """The generate function is the primary interface to a Block + + args: + inputs (BLOCK_INPUT_TYPE): A block operates over a logical iterable + of rows with named columns (see BLOCK_INPUT_TYPE) + + kwargs: + arg_fields (Optional[List[str]]): Names of fields within the rows of + the inputs that should be extracted and passed as positional + args to the underlying implementation methods. + kwarg_fields (Optional[List[str]]): Names of fields within the rows + of the inputs that should be extracted and passed as keyword + args to the underlying implementation methods. + **kwargs: Additional keyword args that may be passed to the derived + block's generate function + """ class BaseUtilityBlock(BaseBlock): From 80b426f83fe53a21877698725948a328622c795b Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 08:53:49 -0500 Subject: [PATCH 25/41] Update fms_dgt/base/block.py Co-authored-by: Gabe Goodhart --- fms_dgt/base/block.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index b0bb8b69..483e8ddc 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -122,10 +122,8 @@ def get_args_kwargs( arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, ): - if arg_fields is None: - arg_fields = [] - if kwarg_fields is None: - kwarg_fields = [] + arg_fields = arg_fields or [] + kwarg_fields = or kwarg_fields or [] if type(inp) == dict: args = [inp.get(arg) for arg in arg_fields] From e2bc1dc907f223f5ee0f174293168a50a3b93b5e Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 08:54:05 -0500 Subject: [PATCH 26/41] Update fms_dgt/base/block.py Co-authored-by: Gabe Goodhart --- fms_dgt/base/block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index 483e8ddc..4adefca6 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -1,5 +1,5 @@ # Standard -from abc import ABC +from abc import ABC, abstractmethod from typing import Any, Dict, Iterable, List, Optional, Union # Third Party From 79278fd1912c12a80a941d35a35ff8ebcad9ce6e Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 08:59:08 -0500 Subject: [PATCH 27/41] Update fms_dgt/base/block.py Co-authored-by: Gabe Goodhart --- fms_dgt/base/block.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index 4adefca6..e4d7fcfd 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -144,9 +144,7 @@ def write_result( ): assert result_field is not None, "Result field cannot be None!" - if type(inp) == dict: - inp[result_field] = res - elif type(inp) in [pd.DataFrame, Dataset]: + if isinstance(inp, (dict, pd.DataFrame, Dataset): inp[result_field] = res else: raise ValueError(f"Unexpected input type: {type(inp)}") From fed9702420940c29cb46ca90030076d9f13e4b5b Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 09:11:18 -0500 Subject: [PATCH 28/41] throw type error --- fms_dgt/base/block.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index cbd8a4fa..fc5b26c0 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -21,18 +21,14 @@ def __init__( result_field: str = None, ) -> None: - self._name = name + if not (isinstance(arg_fields, list) or arg_fields is None): + raise TypeError(f"arg_fields must be of type 'list'") + if not (isinstance(kwarg_fields, list) or kwarg_fields is None): + raise TypeError(f"kwarg_fields must be of type 'list'") + if not (isinstance(result_field, str) or result_field is None): + raise TypeError(f"result_field must be of type 'str'") - # minor type checking - if type(arg_fields) == str: - arg_fields = [arg_fields] - if type(kwarg_fields) == str: - kwarg_fields = [kwarg_fields] - if type(result_field) == list: - assert ( - len(result_field) == 1 - ), "Cannot have multiple 'result' fields for {name}" - result_field = result_field[0] + self._name = name self._arg_fields = arg_fields self._kwarg_fields = kwarg_fields From ed9e13b423b327c7783e8f0afeab23ce36f8e75c Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 09:46:31 -0500 Subject: [PATCH 29/41] instance methods --- fms_dgt/base/block.py | 74 ++++++++++++++++---------------- fms_dgt/blocks/generators/llm.py | 21 +++------ 2 files changed, 43 insertions(+), 52 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index 5dc055a6..e7f5d680 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -50,6 +50,39 @@ def kwarg_fields(self): def result_field(self): return self._result_field + def get_args_kwargs( + self, + inp: BLOCK_ROW_TYPE, + arg_fields: Optional[List[str]] = None, + kwarg_fields: Optional[List[str]] = None, + ): + + arg_fields = arg_fields or self.arg_fields or [] + kwarg_fields = kwarg_fields or self.kwarg_fields or [] + + if isinstance(inp, (dict, pd.DataFrame, Dataset)): + args = [inp.get(arg) for arg in arg_fields] + kwargs = {kwarg: inp.get(kwarg) for kwarg in kwarg_fields} + else: + raise TypeError(f"Unexpected input type: {type(inp)}") + + return args, kwargs + + def write_result( + self, + inp: BLOCK_ROW_TYPE, + res: Any, + result_field: str = None, + ): + result_field = result_field or self.result_field + + assert result_field is not None, "Result field cannot be None!" + + if isinstance(inp, (dict, pd.DataFrame, Dataset)): + inp[result_field] = res + else: + raise TypeError(f"Unexpected input type: {type(inp)}") + @abstractmethod def generate( self, @@ -61,7 +94,7 @@ def generate( **kwargs, ): """The generate function is the primary interface to a Block - + args: inputs (BLOCK_INPUT_TYPE): A block operates over a logical iterable of rows with named columns (see BLOCK_INPUT_TYPE) @@ -100,47 +133,12 @@ def generate( ): outputs = [] for x in inputs: - inp_args, inp_kwargs = get_args_kwargs( - x, arg_fields or self.arg_fields, kwarg_fields or self.kwarg_fields - ) + inp_args, inp_kwargs = self.get_args_kwargs(x, arg_fields, kwarg_fields) res = self._validate(*inp_args, **inp_kwargs) if res or not self._filter_invalids: - write_result(x, res, result_field or self.result_field) + self.write_result(x, res, result_field) outputs.append(x) return outputs def _validate(self, *args: Any, **kwargs: Any) -> bool: raise NotImplementedError - - -def get_args_kwargs( - inp: BLOCK_ROW_TYPE, - arg_fields: Optional[List[str]] = None, - kwarg_fields: Optional[List[str]] = None, -): - arg_fields = arg_fields or [] - kwarg_fields = or kwarg_fields or [] - - if type(inp) == dict: - args = [inp.get(arg) for arg in arg_fields] - kwargs = {kwarg: inp.get(kwarg) for kwarg in kwarg_fields} - elif type(inp) in [pd.DataFrame, Dataset]: - args = [inp.get(arg) for arg in arg_fields] - kwargs = {kwarg: inp.get(kwarg) for kwarg in kwarg_fields} - else: - raise ValueError(f"Unexpected input type: {type(inp)}") - - return args, kwargs - - -def write_result( - inp: BLOCK_ROW_TYPE, - res: Any, - result_field: str, -): - assert result_field is not None, "Result field cannot be None!" - - if isinstance(inp, (dict, pd.DataFrame, Dataset): - inp[result_field] = res - else: - raise ValueError(f"Unexpected input type: {type(inp)}") diff --git a/fms_dgt/blocks/generators/llm.py b/fms_dgt/blocks/generators/llm.py index f6c481e2..6e7ad8c1 100644 --- a/fms_dgt/blocks/generators/llm.py +++ b/fms_dgt/blocks/generators/llm.py @@ -22,12 +22,7 @@ from tqdm import tqdm # Local -from fms_dgt.base.block import ( - BLOCK_INPUT_TYPE, - BaseGeneratorBlock, - get_args_kwargs, - write_result, -) +from fms_dgt.base.block import BLOCK_INPUT_TYPE, BaseGeneratorBlock from fms_dgt.base.instance import Instance from fms_dgt.utils import sdg_logger @@ -121,9 +116,7 @@ def generate( # simplify generation here instances: List[Instance] = [] for inp in inputs: - inp_args, inp_kwargs = get_args_kwargs( - inp, arg_fields or self.arg_fields, kwarg_fields or self.kwarg_fields - ) + inp_args, inp_kwargs = self.get_args_kwargs(inp, arg_fields, kwarg_fields) instances.append(Instance(args=inp_args, kwargs=inp_kwargs, data=inp)) if method == "generate": @@ -146,7 +139,7 @@ def generate( outputs = [] for inst in instances: - write_result(inst.data, inst.result, result_field or self.result_field) + self.write_result(inst.data, inst.result, result_field) outputs.append(inst.data) return outputs @@ -274,10 +267,10 @@ def generate( # simplify generation here instances: List[Instance] = [] for inp in inputs: - inp_args, inp_kwargs = get_args_kwargs( + inp_args, inp_kwargs = self.lm.get_args_kwargs( inp, - arg_fields or self.lm.arg_fields, - kwarg_fields or self.lm.kwarg_fields, + arg_fields, + kwarg_fields, ) instances.append(Instance(args=inp_args, kwargs=inp_kwargs, data=inp)) @@ -301,7 +294,7 @@ def generate( outputs = [] for inst in instances: - write_result(inst.data, inst.result, result_field or self.lm.result_field) + self.lm.write_result(inst.data, inst.result, result_field) outputs.append(inst.data) return outputs From f78b81dd7374cd508636ea2844f34d57598bdf3d Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 10:34:09 -0500 Subject: [PATCH 30/41] dataset type --- fms_dgt/base/block.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index e7f5d680..118434ef 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -6,8 +6,8 @@ from datasets import Dataset import pandas as pd -BLOCK_ROW_TYPE = Union[Dict, pd.Series] -BLOCK_INPUT_TYPE = Union[Iterable[BLOCK_ROW_TYPE], pd.DataFrame, Dataset] +DATASET_ROW_TYPE = Union[Dict, pd.Series] +DATASET_TYPE = Union[Iterable[DATASET_ROW_TYPE], pd.DataFrame, Dataset] class BaseBlock(ABC): From 13701af651f1c684e7008bbde94ab24c2a5e6d34 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 11:11:30 -0500 Subject: [PATCH 31/41] dataset type --- fms_dgt/base/block.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index 118434ef..72ecaf09 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -52,7 +52,7 @@ def result_field(self): def get_args_kwargs( self, - inp: BLOCK_ROW_TYPE, + inp: DATASET_ROW_TYPE, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, ): @@ -70,7 +70,7 @@ def get_args_kwargs( def write_result( self, - inp: BLOCK_ROW_TYPE, + inp: DATASET_ROW_TYPE, res: Any, result_field: str = None, ): @@ -86,7 +86,7 @@ def write_result( @abstractmethod def generate( self, - inputs: BLOCK_INPUT_TYPE, + inputs: DATASET_TYPE, *, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, @@ -126,7 +126,7 @@ def __init__(self, filter: bool = False, **kwargs: Any) -> None: def generate( self, - inputs: BLOCK_INPUT_TYPE, + inputs: DATASET_TYPE, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, result_field: Optional[List[str]] = None, From 0953d9be9ccda60f365f403a189b096037bb4901 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 11:16:22 -0500 Subject: [PATCH 32/41] dataset type --- fms_dgt/blocks/generators/llm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fms_dgt/blocks/generators/llm.py b/fms_dgt/blocks/generators/llm.py index 6e7ad8c1..2b7238f1 100644 --- a/fms_dgt/blocks/generators/llm.py +++ b/fms_dgt/blocks/generators/llm.py @@ -22,7 +22,7 @@ from tqdm import tqdm # Local -from fms_dgt.base.block import BLOCK_INPUT_TYPE, BaseGeneratorBlock +from fms_dgt.base.block import DATASET_TYPE, BaseGeneratorBlock from fms_dgt.base.instance import Instance from fms_dgt.utils import sdg_logger @@ -105,7 +105,7 @@ def set_cache_hook(self, cache_hook) -> None: def generate( self, - inputs: BLOCK_INPUT_TYPE, + inputs: DATASET_TYPE, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, result_field: Optional[str] = None, @@ -256,7 +256,7 @@ def fn(requests: List[Instance]): def generate( self, - inputs: BLOCK_INPUT_TYPE, + inputs: DATASET_TYPE, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, result_field: Optional[str] = None, From e1a5c23331ae0a21143c940906a8a484205961e4 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 11:20:06 -0500 Subject: [PATCH 33/41] Update fms_dgt/base/block.py Co-authored-by: Gabe Goodhart --- fms_dgt/base/block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index 72ecaf09..d6162fd6 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -72,7 +72,7 @@ def write_result( self, inp: DATASET_ROW_TYPE, res: Any, - result_field: str = None, + result_field: Optional[str] = None, ): result_field = result_field or self.result_field From fb48faffd162b142268a1c862970e669c6678b39 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 11:24:46 -0500 Subject: [PATCH 34/41] Update fms_dgt/base/block.py Co-authored-by: Gabe Goodhart --- fms_dgt/base/block.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index d6162fd6..5270a157 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -61,12 +61,12 @@ def get_args_kwargs( kwarg_fields = kwarg_fields or self.kwarg_fields or [] if isinstance(inp, (dict, pd.DataFrame, Dataset)): - args = [inp.get(arg) for arg in arg_fields] - kwargs = {kwarg: inp.get(kwarg) for kwarg in kwarg_fields} - else: - raise TypeError(f"Unexpected input type: {type(inp)}") + return ( + [inp.get(arg) for arg in arg_fields], + {kwarg: inp.get(kwarg) for kwarg in kwarg_fields} + ) + raise TypeError(f"Unexpected input type: {type(inp)}") - return args, kwargs def write_result( self, From b4088ce7c1f31f81104d07c89de891095c02ef26 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 11:24:57 -0500 Subject: [PATCH 35/41] Update fms_dgt/base/block.py Co-authored-by: Gabe Goodhart --- fms_dgt/base/block.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index 5270a157..a43265da 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -140,5 +140,6 @@ def generate( outputs.append(x) return outputs + @abstractmethod def _validate(self, *args: Any, **kwargs: Any) -> bool: - raise NotImplementedError + """Derived validators must implement _validate with their core logic""" From 64f971a96883849c022776f1503d08011e761299 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 11:25:58 -0500 Subject: [PATCH 36/41] Update fms_dgt/base/block.py Co-authored-by: Gabe Goodhart --- fms_dgt/base/block.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index a43265da..efff9ac6 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -80,8 +80,7 @@ def write_result( if isinstance(inp, (dict, pd.DataFrame, Dataset)): inp[result_field] = res - else: - raise TypeError(f"Unexpected input type: {type(inp)}") + raise TypeError(f"Unexpected input type: {type(inp)}") @abstractmethod def generate( From 41c91445c9a25a98b8d075846c77e0358d603d1a Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 11:35:08 -0500 Subject: [PATCH 37/41] fixing base block class --- fms_dgt/base/block.py | 14 ++++++++------ fms_dgt/blocks/generators/llm.py | 1 + fms_dgt/blocks/utilities/flatten_field.py | 5 +++-- tests/compatibility_tests/blocks.py | 4 ++-- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index efff9ac6..552e3d15 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -6,7 +6,7 @@ from datasets import Dataset import pandas as pd -DATASET_ROW_TYPE = Union[Dict, pd.Series] +DATASET_ROW_TYPE = Union[Dict[str, Any], pd.Series] DATASET_TYPE = Union[Iterable[DATASET_ROW_TYPE], pd.DataFrame, Dataset] @@ -22,11 +22,11 @@ def __init__( ) -> None: if not (isinstance(arg_fields, list) or arg_fields is None): - raise TypeError(f"arg_fields must be of type 'list'") + raise TypeError("arg_fields must be of type 'list'") if not (isinstance(kwarg_fields, list) or kwarg_fields is None): - raise TypeError(f"kwarg_fields must be of type 'list'") + raise TypeError("kwarg_fields must be of type 'list'") if not (isinstance(result_field, str) or result_field is None): - raise TypeError(f"result_field must be of type 'str'") + raise TypeError("result_field must be of type 'str'") self._name = name @@ -63,11 +63,10 @@ def get_args_kwargs( if isinstance(inp, (dict, pd.DataFrame, Dataset)): return ( [inp.get(arg) for arg in arg_fields], - {kwarg: inp.get(kwarg) for kwarg in kwarg_fields} + {kwarg: inp.get(kwarg) for kwarg in kwarg_fields}, ) raise TypeError(f"Unexpected input type: {type(inp)}") - def write_result( self, inp: DATASET_ROW_TYPE, @@ -80,6 +79,8 @@ def write_result( if isinstance(inp, (dict, pd.DataFrame, Dataset)): inp[result_field] = res + return + raise TypeError(f"Unexpected input type: {type(inp)}") @abstractmethod @@ -126,6 +127,7 @@ def __init__(self, filter: bool = False, **kwargs: Any) -> None: def generate( self, inputs: DATASET_TYPE, + *, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, result_field: Optional[List[str]] = None, diff --git a/fms_dgt/blocks/generators/llm.py b/fms_dgt/blocks/generators/llm.py index 2b7238f1..94ac9cd8 100644 --- a/fms_dgt/blocks/generators/llm.py +++ b/fms_dgt/blocks/generators/llm.py @@ -106,6 +106,7 @@ def set_cache_hook(self, cache_hook) -> None: def generate( self, inputs: DATASET_TYPE, + *, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, result_field: Optional[str] = None, diff --git a/fms_dgt/blocks/utilities/flatten_field.py b/fms_dgt/blocks/utilities/flatten_field.py index a9e81d09..d1242c12 100644 --- a/fms_dgt/blocks/utilities/flatten_field.py +++ b/fms_dgt/blocks/utilities/flatten_field.py @@ -7,7 +7,7 @@ from pandas import DataFrame # Local -from fms_dgt.base.block import BLOCK_INPUT_TYPE, BaseUtilityBlock +from fms_dgt.base.block import DATASET_TYPE, BaseUtilityBlock from fms_dgt.base.registry import register_block @@ -17,7 +17,8 @@ class FlattenField(BaseUtilityBlock): def generate( self, - inputs: BLOCK_INPUT_TYPE, + inputs: DATASET_TYPE, + *, arg_fields: Optional[List[str]] = None, kwarg_fields: Optional[List[str]] = None, result_field: Optional[str] = None, diff --git a/tests/compatibility_tests/blocks.py b/tests/compatibility_tests/blocks.py index c86ae062..0c8dd71b 100644 --- a/tests/compatibility_tests/blocks.py +++ b/tests/compatibility_tests/blocks.py @@ -9,7 +9,7 @@ import pandas as pd # Local -from fms_dgt.base.block import BLOCK_INPUT_TYPE, BaseBlock +from fms_dgt.base.block import DATASET_TYPE, BaseBlock class TestFilterBlock(BaseBlock): @@ -30,7 +30,7 @@ def __init__( def generate( self, - inputs: BLOCK_INPUT_TYPE, + inputs: DATASET_TYPE, **kwargs: Any, ) -> Any: return self._filter_block.generate(inputs) From 6f5631ad361106cdcf64ec5f41d842361d67a0d6 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 14:30:33 -0500 Subject: [PATCH 38/41] consistency --- fms_dgt/blocks/utilities/flatten_field.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/fms_dgt/blocks/utilities/flatten_field.py b/fms_dgt/blocks/utilities/flatten_field.py index d1242c12..351bb530 100644 --- a/fms_dgt/blocks/utilities/flatten_field.py +++ b/fms_dgt/blocks/utilities/flatten_field.py @@ -23,9 +23,7 @@ def generate( kwarg_fields: Optional[List[str]] = None, result_field: Optional[str] = None, ): - arg_fields = arg_fields if arg_fields is not None else self._arg_fields - if arg_fields is None: - arg_fields = [] + arg_fields = arg_fields or self._arg_fields or [] assert ( len(arg_fields) == 1 From 41009e9bf73812e002ff3d1b4aef0ac5846616c8 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Wed, 10 Jul 2024 16:00:59 -0500 Subject: [PATCH 39/41] removing empty classes --- fms_dgt/base/block.py | 14 +++----------- fms_dgt/blocks/generators/llm.py | 4 ++-- fms_dgt/blocks/utilities/flatten_field.py | 4 ++-- 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index 552e3d15..ea17cd37 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -16,9 +16,9 @@ class BaseBlock(ABC): def __init__( self, name: str = None, - arg_fields: List[str] = None, - kwarg_fields: List[str] = None, - result_field: str = None, + arg_fields: Optional[List[str]] = None, + kwarg_fields: Optional[List[str]] = None, + result_field: Optional[str] = None, ) -> None: if not (isinstance(arg_fields, list) or arg_fields is None): @@ -111,14 +111,6 @@ def generate( """ -class BaseUtilityBlock(BaseBlock): - pass - - -class BaseGeneratorBlock(BaseBlock): - pass - - class BaseValidatorBlock(BaseBlock): def __init__(self, filter: bool = False, **kwargs: Any) -> None: super().__init__(**kwargs) diff --git a/fms_dgt/blocks/generators/llm.py b/fms_dgt/blocks/generators/llm.py index 94ac9cd8..330af2b1 100644 --- a/fms_dgt/blocks/generators/llm.py +++ b/fms_dgt/blocks/generators/llm.py @@ -22,14 +22,14 @@ from tqdm import tqdm # Local -from fms_dgt.base.block import DATASET_TYPE, BaseGeneratorBlock +from fms_dgt.base.block import DATASET_TYPE, BaseBlock from fms_dgt.base.instance import Instance from fms_dgt.utils import sdg_logger MODEL_ID_OR_PATH = "model_id_or_path" -class LMGenerator(BaseGeneratorBlock): +class LMGenerator(BaseBlock): """Class for LLM Generators""" def __init__( diff --git a/fms_dgt/blocks/utilities/flatten_field.py b/fms_dgt/blocks/utilities/flatten_field.py index 351bb530..9e58eeac 100644 --- a/fms_dgt/blocks/utilities/flatten_field.py +++ b/fms_dgt/blocks/utilities/flatten_field.py @@ -7,12 +7,12 @@ from pandas import DataFrame # Local -from fms_dgt.base.block import DATASET_TYPE, BaseUtilityBlock +from fms_dgt.base.block import DATASET_TYPE, BaseBlock from fms_dgt.base.registry import register_block @register_block("flatten_field") -class FlattenField(BaseUtilityBlock): +class FlattenField(BaseBlock): """Flatten specified args""" def generate( From 7fb7f67ae522fce69af83965892f058589c60a50 Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Thu, 11 Jul 2024 08:52:28 -0500 Subject: [PATCH 40/41] make blocks a list, easier for duplicate checking --- fms_dgt/base/databuilder.py | 26 ++++++++++++++----- .../api/api_function_checking.yaml | 6 ++--- .../api/api_yes_no_detection.yaml | 6 ++--- fms_dgt/databuilders/nl2sql/nl2sql.yaml | 6 ++--- fms_dgt/databuilders/simple/simple.yaml | 4 +-- 5 files changed, 31 insertions(+), 17 deletions(-) diff --git a/fms_dgt/base/databuilder.py b/fms_dgt/base/databuilder.py index 764c0911..8fb5c935 100644 --- a/fms_dgt/base/databuilder.py +++ b/fms_dgt/base/databuilder.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, Iterable, List, Mapping, Optional, Union +import json import os import time @@ -31,6 +32,7 @@ def __post_init__(self) -> None: pass +NAME_KEY = "name" TYPE_KEY = "type" @@ -79,6 +81,10 @@ def __init__( output_dir, f"discarded_{self.config.name}_{date_suffix}.log" ) + @property + def name(self) -> str: + return self._name + @property def config(self) -> DataBuilderConfig: """Returns the DataBuilderConfig associated with this class.""" @@ -93,15 +99,23 @@ def _init_blocks(self, lm_cache: str = None): self._blocks: List[BaseBlock] = [] # TODO: need to handle nested blocks - for obj_name, obj_config in self.config.blocks.items(): - obj_kwargs = {**obj_config, "name": obj_name} + for obj_kwargs in self.config.blocks: + + for req_key in (NAME_KEY, TYPE_KEY): + assert ( + req_key in obj_kwargs + ), f"'{req_key}' field missing in data builder config from block with args:\n{json.dumps(obj_kwargs, indent=4)} " + + obj_name = obj_kwargs["name"] + + assert not any( + block.name == obj_name for block in self._blocks + ), f"Duplicate '{obj_name}' block in '{self.name}' data builder" + sdg_logger.debug( - "Initializing object %s with config %s", obj_name, obj_config + "Initializing object %s with config %s", obj_name, obj_kwargs ) - assert ( - TYPE_KEY in obj_kwargs - ), f"'type' field missing from {obj_name} in data builder config" obj = get_block(obj_kwargs.pop(TYPE_KEY))(**obj_kwargs) if lm_cache is not None and isinstance(obj, LMGenerator): diff --git a/fms_dgt/databuilders/api/api_function_checking.yaml b/fms_dgt/databuilders/api/api_function_checking.yaml index bcab7b8e..ed8e46d7 100644 --- a/fms_dgt/databuilders/api/api_function_checking.yaml +++ b/fms_dgt/databuilders/api/api_function_checking.yaml @@ -1,16 +1,16 @@ name: api_function_checking blocks: - llm1: + - name: llm1 type: genai decoding_method: sample temperature: 0.5 max_new_tokens: 1024 min_new_tokens: 1 model_id_or_path: mistralai/mixtral-8x7b-instruct-v01 - val1: + - name: val1 type: api_function_checking filter: true - val2: + - name: val2 type: rouge_scorer filter: true threshold: 0.35 diff --git a/fms_dgt/databuilders/api/api_yes_no_detection.yaml b/fms_dgt/databuilders/api/api_yes_no_detection.yaml index cbbace1b..181dd26c 100644 --- a/fms_dgt/databuilders/api/api_yes_no_detection.yaml +++ b/fms_dgt/databuilders/api/api_yes_no_detection.yaml @@ -1,16 +1,16 @@ name: api_yes_no_detection blocks: - llm1: + - name: llm1 type: genai decoding_method: sample temperature: 0.5 max_new_tokens: 1024 min_new_tokens: 1 model_id_or_path: mistralai/mixtral-8x7b-instruct-v01 - val1: + - name: val1 type: api_yes_no filter: true - val2: + - name: val2 type: rouge_scorer filter: true threshold: 0.35 diff --git a/fms_dgt/databuilders/nl2sql/nl2sql.yaml b/fms_dgt/databuilders/nl2sql/nl2sql.yaml index e29664e9..753c7c6a 100644 --- a/fms_dgt/databuilders/nl2sql/nl2sql.yaml +++ b/fms_dgt/databuilders/nl2sql/nl2sql.yaml @@ -1,15 +1,15 @@ name: nl2sql blocks: - llm1: + - name: llm1 type: genai temperature: 0.0 max_new_tokens: 512 min_new_tokens: 1 model_id_or_path: ibm/granite-8b-code-instruct - val1: + - name: val1 type: sql_syntax_validator filter: true - val2: + - name: val2 type: sql_execution_validator filter: true metadata: diff --git a/fms_dgt/databuilders/simple/simple.yaml b/fms_dgt/databuilders/simple/simple.yaml index 48618c55..ec45063d 100644 --- a/fms_dgt/databuilders/simple/simple.yaml +++ b/fms_dgt/databuilders/simple/simple.yaml @@ -1,6 +1,6 @@ name: simple blocks: - llm1: + - name: llm1 type: genai arg_fields: - prompt @@ -11,7 +11,7 @@ blocks: max_new_tokens: 512 min_new_tokens: 1 model_id_or_path: mistralai/mixtral-8x7b-instruct-v01 - val1: + - name: val1 type: rouge_scorer arg_fields: - new_toks From 63d4ea68895ac475a4b6158f51479b75f26390ed Mon Sep 17 00:00:00 2001 From: Max Crouse Date: Thu, 11 Jul 2024 11:43:16 -0500 Subject: [PATCH 41/41] simpler type check --- fms_dgt/base/block.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fms_dgt/base/block.py b/fms_dgt/base/block.py index ea17cd37..c36901f9 100644 --- a/fms_dgt/base/block.py +++ b/fms_dgt/base/block.py @@ -21,11 +21,11 @@ def __init__( result_field: Optional[str] = None, ) -> None: - if not (isinstance(arg_fields, list) or arg_fields is None): + if not isinstance(arg_fields, (list, type(None))): raise TypeError("arg_fields must be of type 'list'") - if not (isinstance(kwarg_fields, list) or kwarg_fields is None): + if not isinstance(kwarg_fields, (list, type(None))): raise TypeError("kwarg_fields must be of type 'list'") - if not (isinstance(result_field, str) or result_field is None): + if not isinstance(result_field, (str, type(None))): raise TypeError("result_field must be of type 'str'") self._name = name