Skip to content

Commit

Permalink
[REVIEW] Support dict output for models (#98)
Browse files Browse the repository at this point in the history
* Add support for dict output from models

* Fix minor problems with a test

Signed-off-by: Vibhu Jawa <[email protected]>

* fix a input validation minor error

* Remove print output

* Add support of meta correctly for output dict

Signed-off-by: Vibhu Jawa <[email protected]>

* Add type hints for get_model_output

Signed-off-by: Vibhu Jawa <[email protected]>

* fix docstring slightly

Signed-off-by: Vibhu Jawa <[email protected]>

* fix based on reviews

Signed-off-by: Vibhu Jawa <[email protected]>

* fix based on reviews

Signed-off-by: Vibhu Jawa <[email protected]>

---------

Signed-off-by: Vibhu Jawa <[email protected]>
  • Loading branch information
VibhuJawa authored Oct 29, 2024
1 parent b5038ae commit 82f232f
Show file tree
Hide file tree
Showing 6 changed files with 355 additions and 54 deletions.
174 changes: 141 additions & 33 deletions crossfit/backend/torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum
from typing import Any, Dict, List, Union

import cudf
import cupy as cp
import torch

from crossfit.backend.cudf.series import (
create_list_series_from_1d_or_2d_ar,
Expand All @@ -23,16 +26,30 @@
from crossfit.utils.torch_utils import cleanup_torch_cache, concat_and_pad_tensors


class ModelOutputType(Enum):
NUMERIC = "numeric"
STRING = "string"


class Model:
def __init__(self, path_or_name: str, max_mem_gb: int = 16, model_output_type: str = "numeric"):
def __init__(
self,
path_or_name: str,
max_mem_gb: int = 16,
model_output_type: Union[str, Dict[str, str]] = "numeric",
):
"""Initialize a Crossfit Pytorch Model Instance.
Args:
path_or_name (str): Path to the model file or the model name to load.
max_mem_gb (int): Maximum memory in gigabytes to allocate for the model.Defaults to 16.
model_output_type (str, dict): Specifies the type of model output. Can be either
"numeric" or "string". If a dictionary is provided, it maps prediction names to
their respective types. Defaults to "numeric".
"""
self.path_or_name = path_or_name
self.max_mem_gb = max_mem_gb
if model_output_type in ["numeric", "string"]:
self.model_output_type = model_output_type
else:
raise ValueError(
"Invalid model output type provided. Allowed values are : 'string' or 'numeric'."
)
self.model_output_type = _validate_model_output_type(model_output_type)

def load_model(self, device="cuda"):
raise NotImplementedError()
Expand All @@ -50,6 +67,8 @@ def call_on_worker(self, worker, *args, **kwargs):
return worker.torch_model(*args, **kwargs)

def get_model(self, worker):
# TODO: We should not hard code the attribute name
# to torch_model. We should use the path_or_name_model
if not hasattr(worker, "torch_model"):
self.load_on_worker(worker)
return worker.torch_model
Expand All @@ -60,36 +79,125 @@ def estimate_memory(self, max_num_tokens: int, batch_size: int) -> int:
def max_seq_length(self) -> int:
raise NotImplementedError()

def get_model_output(self, all_outputs_ls, index, loader, pred_output_col) -> cudf.DataFrame:
def get_model_output(
self,
all_outputs_ls: List[Union[dict, torch.Tensor]],
index: Union[cudf.Index],
loader: Any,
pred_output_col: str,
) -> cudf.DataFrame:
# importing here to avoid cyclic import error
from crossfit.backend.torch.loader import SortedSeqLoader

out = cudf.DataFrame(index=index)
out_df = cudf.DataFrame(index=index)
_index = loader.sort_column(index.values) if type(loader) is SortedSeqLoader else index

if self.model_output_type == "string":
all_outputs = [o for output in all_outputs_ls for o in output]
out[pred_output_col] = cudf.Series(data=all_outputs, index=_index)
del all_outputs_ls
del loader
else:
outputs = cp.asarray(
concat_and_pad_tensors(
all_outputs_ls,
pad_token_id=loader.pad_token_id,
padding_side=loader.padding_side,
if isinstance(all_outputs_ls[0], dict):
if not isinstance(self.model_output_type, dict):
raise ValueError(
"model_output_type must be a dictionary when the model output is a dictionary"
)
for pred_name in all_outputs_ls[0].keys():
if pred_name not in self.model_output_type:
raise ValueError(
f"Invalid prediction name '{pred_name}'.\n"
f"Allowed prediction names: {list(self.model_output_type.keys())}\n"
"Please provide a valid prediction name ands its datatype "
"in the model_output_type dictionary."
)
model_output_type = self.model_output_type.get(pred_name, self.model_output_type)
_add_column_to_df(
out_df,
[o[pred_name] for o in all_outputs_ls],
_index,
loader,
pred_name,
model_output_type,
)
else:
_add_column_to_df(
out_df, all_outputs_ls, _index, loader, pred_output_col, self.model_output_type
)
del all_outputs_ls
del loader
cleanup_torch_cache()
if len(outputs.shape) <= 2:
out[pred_output_col] = create_list_series_from_1d_or_2d_ar(outputs, _index)
elif len(outputs.shape) == 3:
out[pred_output_col] = create_nested_list_series_from_3d_ar(outputs, _index)
else:
raise RuntimeError(f"Unexpected output shape: {outputs.shape}")
del outputs
del _index
cleanup_torch_cache()
return out
return out_df


def _add_column_to_df(
df: cudf.DataFrame,
all_outputs_ls: List[Any],
_index: Any,
loader: Any,
pred_output_col: str,
model_output_type: ModelOutputType,
) -> None:
if model_output_type is ModelOutputType.STRING:
_add_string_column(df, pred_output_col, all_outputs_ls)
elif model_output_type is ModelOutputType.NUMERIC:
_add_numeric_column(df, all_outputs_ls, _index, loader, pred_output_col)
else:
raise ValueError(f"Invalid model_output_type: {model_output_type}")


def _add_string_column(
df: cudf.DataFrame, pred_output_col: str, all_outputs_ls: List[List[str]]
) -> None:
df[pred_output_col] = [o for output in all_outputs_ls for o in output]


def _add_numeric_column(
df: cudf.DataFrame, all_outputs_ls: List[Any], _index: Any, loader: Any, pred_output_col: str
) -> None:
outputs = cp.asarray(
concat_and_pad_tensors(
all_outputs_ls,
pad_token_id=getattr(loader, "pad_token_id", None),
padding_side=getattr(loader, "padding_side", None),
)
)
del all_outputs_ls
del loader
cleanup_torch_cache()
if len(outputs.shape) == 1:
df[pred_output_col] = cudf.Series(outputs, index=_index)
elif len(outputs.shape) == 2:
df[pred_output_col] = create_list_series_from_1d_or_2d_ar(outputs, _index)
elif len(outputs.shape) == 3:
df[pred_output_col] = create_nested_list_series_from_3d_ar(outputs, _index)
else:
raise RuntimeError(f"Unexpected output shape: {outputs.shape}")


def _validate_model_output_type(
model_output_type: Union[str, ModelOutputType, dict[str, Union[str, ModelOutputType]]]
) -> Union[ModelOutputType, dict[str, ModelOutputType]]:
"""Validate and convert model output type to proper enum format.
Args:
model_output_type: Either a string/enum value, or a dict of string/enum values
Returns:
ModelOutputType or dict: Validated and converted output type(s)
Raises:
ValueError: If invalid output type is provided
"""

def _convert_single_type(value):
if isinstance(value, str):
try:
return ModelOutputType(value)
except ValueError:
raise ValueError(
f"Invalid model_output_type: {value}. "
f"Allowed values are: {[e.value for e in ModelOutputType]}"
)
elif isinstance(value, ModelOutputType):
return value
else:
raise ValueError(
f"model_output_type must be string or ModelOutputType, got {type(value)}"
)

if isinstance(model_output_type, dict):
return {key: _convert_single_type(value) for key, value in model_output_type.items()}
else:
return _convert_single_type(model_output_type)
78 changes: 66 additions & 12 deletions crossfit/backend/torch/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Optional

import torch
Expand All @@ -32,8 +33,9 @@ def __init__(
batch_size: int = DEFAULT_BATCH_SIZE,
max_mem: str = "16GB",
sorted_data_loader: bool = True,
model_output_col: Optional[str] = None,
pred_output_col: str = "preds",
model_output_col: Optional[str] = None, # Deprecated
model_output_cols: Optional[list[str]] = None,
pred_output_col: Optional[str] = None,
):
super().__init__(pre=pre, cols=cols, keep_cols=keep_cols)
self.model = model
Expand All @@ -42,8 +44,24 @@ def __init__(
self.max_mem = max_mem
self.max_mem_gb = int(self.max_mem.split("GB")[0]) / 2.5
self.sorted_data_loader = sorted_data_loader
self.model_output_col = model_output_col
self.pred_output_col = pred_output_col

if model_output_col and model_output_cols:
raise ValueError("Specify either model_output_col or model_output_cols, not both.")
elif model_output_col:
self.model_output_cols = [model_output_col]
elif model_output_cols:
self.model_output_cols = model_output_cols
else:
self.model_output_cols = None

if model_output_col:
warnings.warn("model_output_col is deprecated. Please use model_output_cols instead.")

if pred_output_col and self.model_output_cols and len(self.model_output_cols) > 1:
raise ValueError(
"pred_output_col can only be specified when model_output_cols has a single column."
)
self.pred_output_col = pred_output_col or "preds"

@torch.no_grad()
def call(self, data, partition_info=None):
Expand All @@ -67,19 +85,55 @@ def call(self, data, partition_info=None):
all_outputs_ls = []
for output in loader.map(self.model.get_model(self.get_worker())):
if isinstance(output, dict):
if self.model_output_col not in output:
raise ValueError(f"Column '{self.model_output_col}' not found in model output.")
output = output[self.model_output_col]

if self.model_output_cols:
output = {col: output[col] for col in self.model_output_cols if col in output}
if len(output) == 0:
raise ValueError(
"None of the specified model_output_cols were found in",
"the output dict. ",
f"Available output keys: {list(output.keys())}. ",
f"Requested columns: {self.model_output_cols}",
)
if len(output) == 1:
output = list(output.values())[0]
elif len(output) > 1 and self.model_output_cols is None:
raise ValueError(
"Model returned more than one output column, but model_output_cols ",
"was not specified. Please specify model_output_cols",
"to get all model outputs.",
)
if self.post is not None:
output = self.post(output)

all_outputs_ls.append(output)
out = self.model.get_model_output(all_outputs_ls, index, loader, self.pred_output_col)
return out

def meta(self):
if self.model.model_output_type == "string":
return {self.pred_output_col: "object"}
# Case 1: Multiple output columns
if self.model_output_cols and len(self.model_output_cols) > 1:
if not isinstance(self.model.model_output_type, dict):
raise ValueError(
"model_output_type must be a dictionary when "
"multiple model_output_cols are specified"
)
return {
col: "object" if self.model.model_output_type.get(col) == "string" else "float32"
for col in self.model_output_cols
}

# Case 2: Single output column or default behavior
if self.model_output_cols:
first_col = self.model_output_cols[0]
if isinstance(self.model.model_output_type, dict):
output_type = self.model.model_output_type.get(first_col)
else:
output_type = self.model.model_output_type
else:
return {self.pred_output_col: "float32"}
# If model_output_cols is empty, fallback to default output type
output_type = (
list(self.model.model_output_type.values())[0]
if isinstance(self.model.model_output_type, dict)
else self.model.model_output_type
)

return {self.pred_output_col: "object" if output_type == "string" else "float32"}
3 changes: 1 addition & 2 deletions crossfit/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, pre=None, cols=False, keep_cols=None):

@property
def worker_name(self):
return getattr(self.get_worker(), "name", 0)
return getattr(self.get_worker(), "worker_address")

def setup(self):
pass
Expand Down Expand Up @@ -59,7 +59,6 @@ def call_dask(self, data: dd.DataFrame):
def create_progress_bar(self, total, partition_info=None, **kwargs):
return tqdm(
total=total,
position=int(self.worker_name),
desc=f"GPU: {self.worker_name}, Part: {partition_info['number']}",
**kwargs,
)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from setuptools import find_packages, setup

VERSION = "0.0.6"
VERSION = "0.0.7"


def get_long_description():
Expand Down
Loading

0 comments on commit 82f232f

Please sign in to comment.