Skip to content
This repository was archived by the owner on Nov 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
72a2bd9
updating blocks
Jul 3, 2024
6d6f4d9
merge with main
Jul 3, 2024
7b2d811
update all pub databuilders
Jul 3, 2024
4ee38bb
caching llm
Jul 3, 2024
aea2444
updating with main
Jul 3, 2024
d859057
adding compatibility_tests
Jul 3, 2024
b84e7de
merge main
Jul 3, 2024
3e1fbc0
template update
Jul 3, 2024
aec74ff
template update
Jul 3, 2024
8a215de
rm import
Jul 3, 2024
52d26f6
remove old return type
Jul 4, 2024
94ffa99
add parquet saving / loading
Jul 5, 2024
a19116f
adding utility block
Jul 5, 2024
e48ec42
adding utility block
Jul 5, 2024
0ebe2a0
rm block suffix
Jul 8, 2024
1d654dd
remove config argument
Jul 8, 2024
6d51079
demonstrate default vals
Jul 8, 2024
0d17bb6
demonstrate default vals
Jul 8, 2024
fba82b6
remove abstract method to simplify
Jul 8, 2024
e7db4dd
remove abstract method to simplify
Jul 8, 2024
af01c4e
misc minor changes
Jul 8, 2024
a94c7eb
call to generate
Jul 9, 2024
d06a26d
call to generate
Jul 9, 2024
b7fcb59
non base functions
Jul 9, 2024
8b03d98
merge with main
Jul 10, 2024
cb29984
registry change
Jul 10, 2024
b49ffa1
genai req bug fix
Jul 10, 2024
c7a49a6
Update fms_dgt/base/block.py
mvcrouse Jul 10, 2024
80b426f
Update fms_dgt/base/block.py
mvcrouse Jul 10, 2024
e2bc1dc
Update fms_dgt/base/block.py
mvcrouse Jul 10, 2024
79278fd
Update fms_dgt/base/block.py
mvcrouse Jul 10, 2024
fed9702
throw type error
Jul 10, 2024
030caad
Merge branch 'block_design' of github.com:mvcrouse/fms-sdg into block…
Jul 10, 2024
ed9e13b
instance methods
Jul 10, 2024
f78b81d
dataset type
Jul 10, 2024
13701af
dataset type
Jul 10, 2024
0953d9b
dataset type
Jul 10, 2024
e1a5c23
Update fms_dgt/base/block.py
mvcrouse Jul 10, 2024
fb48faf
Update fms_dgt/base/block.py
mvcrouse Jul 10, 2024
b4088ce
Update fms_dgt/base/block.py
mvcrouse Jul 10, 2024
64f971a
Update fms_dgt/base/block.py
mvcrouse Jul 10, 2024
41c9144
fixing base block class
Jul 10, 2024
6f5631a
consistency
Jul 10, 2024
41009e9
removing empty classes
Jul 10, 2024
7fb7f67
make blocks a list, easier for duplicate checking
Jul 11, 2024
63d4ea6
simpler type check
Jul 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions fms_dgt/base/block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Standard
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, List, Optional, Union

# Third Party
from datasets import Dataset
import pandas as pd

DATASET_ROW_TYPE = Union[Dict[str, Any], pd.Series]
DATASET_TYPE = Union[Iterable[DATASET_ROW_TYPE], pd.DataFrame, Dataset]


class BaseBlock(ABC):
    """Base class for all Blocks.

    A block operates over rows of a dataset (plain dicts, pandas objects, or
    HF ``Dataset``): it extracts configured fields from a row as positional /
    keyword arguments for the underlying implementation, and writes the
    result back into a configured field of the row.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        arg_fields: Optional[List[str]] = None,
        kwarg_fields: Optional[List[str]] = None,
        result_field: Optional[str] = None,
    ) -> None:
        """Initialize the block.

        Args:
            name: Identifier for this block instance.
            arg_fields: Default row fields extracted as positional args.
            kwarg_fields: Default row fields extracted as keyword args.
            result_field: Default row field where results are written.

        Raises:
            TypeError: If arg_fields / kwarg_fields is not a list, or
                result_field is not a str (None is allowed everywhere).
        """
        # Fail fast on misconfiguration rather than at first row processed
        if not isinstance(arg_fields, (list, type(None))):
            raise TypeError("arg_fields must be of type 'list'")
        if not isinstance(kwarg_fields, (list, type(None))):
            raise TypeError("kwarg_fields must be of type 'list'")
        if not isinstance(result_field, (str, type(None))):
            raise TypeError("result_field must be of type 'str'")

        self._name = name

        self._arg_fields = arg_fields
        self._kwarg_fields = kwarg_fields
        self._result_field = result_field

    @property
    def name(self) -> Optional[str]:
        """Name of this block instance."""
        return self._name

    @property
    def arg_fields(self) -> Optional[List[str]]:
        """Default fields extracted as positional args."""
        return self._arg_fields

    @property
    def kwarg_fields(self) -> Optional[List[str]]:
        """Default fields extracted as keyword args."""
        return self._kwarg_fields

    @property
    def result_field(self) -> Optional[str]:
        """Default field where results are written."""
        return self._result_field

    def get_args_kwargs(
        self,
        inp: DATASET_ROW_TYPE,
        arg_fields: Optional[List[str]] = None,
        kwarg_fields: Optional[List[str]] = None,
    ):
        """Extract positional and keyword arguments from a row.

        Per-call field lists take precedence over the instance defaults;
        both fall back to empty lists when unset.

        Args:
            inp: A single row (dict, pandas Series/DataFrame, or Dataset).
            arg_fields: Overrides ``self.arg_fields`` for this call.
            kwarg_fields: Overrides ``self.kwarg_fields`` for this call.

        Returns:
            A (args-list, kwargs-dict) pair; missing fields yield None.

        Raises:
            TypeError: If the input type is unsupported.
        """
        arg_fields = arg_fields or self.arg_fields or []
        kwarg_fields = kwarg_fields or self.kwarg_fields or []

        # pd.Series added: DATASET_ROW_TYPE declares it a valid row type,
        # but it was previously rejected by this check
        if isinstance(inp, (dict, pd.Series, pd.DataFrame, Dataset)):
            # NOTE(review): for a DataFrame, .get(field) returns a whole
            # column rather than a scalar — confirm callers pass single rows
            return (
                [inp.get(arg) for arg in arg_fields],
                {kwarg: inp.get(kwarg) for kwarg in kwarg_fields},
            )
        raise TypeError(f"Unexpected input type: {type(inp)}")

    def write_result(
        self,
        inp: DATASET_ROW_TYPE,
        res: Any,
        result_field: Optional[str] = None,
    ):
        """Write a result value into a row, mutating the row in place.

        Args:
            inp: A single row (dict, pandas Series/DataFrame, or Dataset).
            res: The value to store.
            result_field: Overrides ``self.result_field`` for this call.

        Raises:
            AssertionError: If no result field is configured.
            TypeError: If the input type is unsupported.
        """
        result_field = result_field or self.result_field

        assert result_field is not None, "Result field cannot be None!"

        # pd.Series added for consistency with DATASET_ROW_TYPE
        if isinstance(inp, (dict, pd.Series, pd.DataFrame, Dataset)):
            # NOTE(review): datasets.Dataset does not generally support item
            # assignment — confirm this branch is exercised with Dataset inputs
            inp[result_field] = res
            return

        raise TypeError(f"Unexpected input type: {type(inp)}")

    @abstractmethod
    def generate(
        self,
        inputs: DATASET_TYPE,
        *,
        arg_fields: Optional[List[str]] = None,
        kwarg_fields: Optional[List[str]] = None,
        result_field: Optional[str] = None,
        **kwargs,
    ):
        """The generate function is the primary interface to a Block

        args:
            inputs (DATASET_TYPE): A block operates over a logical iterable
                of rows with named columns (see DATASET_TYPE)

        kwargs:
            arg_fields (Optional[List[str]]): Names of fields within the rows of
                the inputs that should be extracted and passed as positional
                args to the underlying implementation methods.
            kwarg_fields (Optional[List[str]]): Names of fields within the rows
                of the inputs that should be extracted and passed as keyword
                args to the underlying implementation methods.
            result_field (Optional[str]): Name of the field within the rows of
                the inputs where the result of the block should be written.
            **kwargs: Additional keyword args that may be passed to the derived
                block's generate function
        """


class BaseValidatorBlock(BaseBlock):
    """Base class for blocks that validate rows and optionally drop invalid ones."""

    def __init__(self, filter: bool = False, **kwargs: Any) -> None:
        """Initialize the validator block.

        Args:
            filter: When True, rows that fail validation are excluded from
                the output of :meth:`generate` instead of being annotated.
            **kwargs: Forwarded to :class:`BaseBlock`.
        """
        # NOTE: 'filter' shadows the builtin but is kept for backward
        # compatibility with existing block configs
        super().__init__(**kwargs)
        self._filter_invalids = filter

    def generate(
        self,
        inputs: DATASET_TYPE,
        *,
        arg_fields: Optional[List[str]] = None,
        kwarg_fields: Optional[List[str]] = None,
        # fixed: was Optional[List[str]], but this is a single field name
        # (see BaseBlock.generate and write_result)
        result_field: Optional[str] = None,
    ):
        """Run ``_validate`` on every input row.

        Each retained row is mutated in place: the boolean validation result
        is written into ``result_field``. When filtering is enabled, rows
        that fail validation are silently dropped from the returned list.

        Returns:
            The list of retained (and annotated) rows.
        """
        outputs = []
        for x in inputs:
            inp_args, inp_kwargs = self.get_args_kwargs(x, arg_fields, kwarg_fields)
            res = self._validate(*inp_args, **inp_kwargs)
            if res or not self._filter_invalids:
                self.write_result(x, res, result_field)
                outputs.append(x)
        return outputs

    @abstractmethod
    def _validate(self, *args: Any, **kwargs: Any) -> bool:
        """Derived validators must implement _validate with their core logic"""
138 changes: 65 additions & 73 deletions fms_dgt/base/databuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,36 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Iterable, List, Mapping, Optional, Union
import json
import os
import time

# Third Party
from tqdm import tqdm

# Local
from fms_dgt.base.generator import BaseGenerator
from fms_dgt.base.registry import get_generator, get_validator
from fms_dgt.base.block import BaseBlock
from fms_dgt.base.registry import get_block
from fms_dgt.base.task import SdgData, SdgTask
from fms_dgt.base.validator import BaseValidator
from fms_dgt.generators.llm import CachingLM, LMGenerator
from fms_dgt.blocks.generators.llm import CachingLM, LMGenerator
from fms_dgt.utils import all_annotations, sdg_logger


@dataclass
class DataBuilderConfig(dict):
# data builder naming/registry
name: Optional[str] = None
generators: Optional[Union[str, list]] = None
validators: Optional[Union[str, list]] = None
blocks: Optional[dict] = None
generation_kwargs: Optional[dict] = None
metadata: Optional[
dict
] = None # by default, not used in the code. allows for users to pass arbitrary info to data builders

def __post_init__(self) -> None:
if self.generation_kwargs is not None:
if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float(
self.generation_kwargs["temperature"]
)
pass


NAME_KEY = "name"
TYPE_KEY = "type"


Expand Down Expand Up @@ -66,7 +62,7 @@ def __init__(
self._restart_generation = restart_generation

# initializing generators / validators
self._init_gv(lm_cache=lm_cache)
self._init_blocks(lm_cache=lm_cache)

# TODO: Data loader goes here
self._tasks: List[SdgTask] = [
Expand All @@ -85,74 +81,70 @@ def __init__(
output_dir, f"discarded_{self.config.name}_{date_suffix}.log"
)

    @property
    def name(self) -> str:
        """Name of this data builder, as set at construction time."""
        return self._name

    @property
    def config(self) -> DataBuilderConfig:
        """Returns the DataBuilderConfig associated with this class."""
        return self._config

@property
def generators(self) -> List[BaseGenerator]:
"""Returns the generators associated with this class."""
return self._generators
def blocks(self) -> List[BaseBlock]:
"""Returns the blocks associated with this class."""
return self._blocks

@property
def validators(self) -> List[BaseValidator]:
"""Returns the validators associated with this class."""
return self._validators

def _init_gv(self, lm_cache: str = None):
_generators = (
[self.config.generators]
if type(self.config.generators) == str
else self.config.generators
)
_validators = (
[self.config.validators]
if type(self.config.validators) == str
else self.config.validators
)
self._generators: List[BaseGenerator] = []
self._validators: List[BaseValidator] = []

# TODO: need to handle nested generators / validators
for i, info_src in enumerate([_generators, _validators]):
# user may not define a generator / validator
if info_src is not None:
for obj_name, obj_config in info_src.items():
sdg_logger.debug(
"Initializing object %s with config %s", obj_name, obj_config
)
obj = (get_generator if i == 0 else get_validator)(
obj_config[TYPE_KEY]
)(obj_name, obj_config)

if lm_cache is not None and isinstance(obj, LMGenerator):
sdg_logger.info(
"Using cache at %s",
lm_cache + "_rank" + str(obj.rank) + ".db",
)
obj = CachingLM(
obj,
lm_cache
# each rank receives a different cache db.
# necessary to avoid multiple writes to cache at once
+ f"_model{os.path.split(obj.model_id_or_path)[-1]}_rank{obj.rank}.db",
)

type_annotations = all_annotations(type(self))
assert (
obj_name in type_annotations
), f"Object {obj_name} is missing from definition of DataBuilder {self.__class__}"

obj_type = type_annotations[obj_name]

# double check types
assert isinstance(obj, obj_type) or (
isinstance(obj, CachingLM) and isinstance(obj.lm, obj_type)
), f"Type of retrieved object {obj.__class__} for {obj_name} does not match type {obj_type} specified in DataBuilder {self.__class__}"

setattr(self, obj_name, obj)
(self._generators if i == 0 else self._validators).append(obj)
    def _init_blocks(self, lm_cache: Optional[str] = None):
        """Instantiate every block listed in the data builder config.

        Each block entry must provide 'name' and 'type' keys; the remaining
        kwargs are passed to the block constructor. Each created block is
        both appended to ``self._blocks`` and set as an attribute on the
        data builder (attribute name = block name), and must match the type
        annotation declared for that name on the data builder class.

        Args:
            lm_cache: Optional base path for an LM cache db; when set, any
                LMGenerator block is wrapped in a CachingLM.
        """
        self._blocks: List[BaseBlock] = []

        # TODO: need to handle nested blocks
        for obj_kwargs in self.config.blocks:

            # both keys are required: 'name' for registration on this data
            # builder, 'type' to look the block class up in the registry
            for req_key in (NAME_KEY, TYPE_KEY):
                assert (
                    req_key in obj_kwargs
                ), f"'{req_key}' field missing in data builder config from block with args:\n{json.dumps(obj_kwargs, indent=4)} "

            obj_name = obj_kwargs["name"]

            # blocks live in a list, so duplicates must be checked explicitly
            assert not any(
                block.name == obj_name for block in self._blocks
            ), f"Duplicate '{obj_name}' block in '{self.name}' data builder"

            sdg_logger.debug(
                "Initializing object %s with config %s", obj_name, obj_kwargs
            )

            # 'type' is popped so the block constructor only sees its own kwargs
            obj = get_block(obj_kwargs.pop(TYPE_KEY))(**obj_kwargs)

            if lm_cache is not None and isinstance(obj, LMGenerator):
                sdg_logger.info(
                    "Using cache at %s",
                    lm_cache + "_rank" + str(obj.rank) + ".db",
                )
                obj = CachingLM(
                    obj,
                    lm_cache
                    # each rank receives a different cache db.
                    # necessary to avoid multiple writes to cache at once
                    + f"_model{os.path.split(obj.model_id_or_path)[-1]}_rank{obj.rank}.db",
                )

            # the data builder class must declare an annotated attribute for
            # every configured block; this catches config/class mismatches early
            type_annotations = all_annotations(type(self))
            assert (
                obj_name in type_annotations
            ), f"Object {obj_name} is missing from definition of DataBuilder {self.__class__}"

            obj_type = type_annotations[obj_name]

            # double check types
            assert isinstance(obj, obj_type) or (
                isinstance(obj, CachingLM) and isinstance(obj.lm, obj_type)
            ), f"Type of retrieved object {obj.__class__} for {obj_name} does not match type {obj_type} specified in DataBuilder {self.__class__}"

            setattr(self, obj_name, obj)
            self._blocks.append(obj)

def execute_tasks(self):
# main entry point to task execution
Expand Down
35 changes: 0 additions & 35 deletions fms_dgt/base/generator.py

This file was deleted.

4 changes: 2 additions & 2 deletions fms_dgt/base/instance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Standard
from dataclasses import dataclass, field
from typing import Any, List, Optional
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
Expand Down
Loading