
Commit dabfebe

Merge pull request #54 from mathysgrapotte/linting
Linting passing
mathysgrapotte authored Jan 23, 2025
2 parents 5b81c79 + d40715f commit dabfebe
Showing 48 changed files with 1,127 additions and 1,213 deletions.
1 change: 1 addition & 0 deletions config/mypy.ini
@@ -3,3 +3,4 @@ ignore_missing_imports = true
exclude = tests/fixtures/
warn_unused_ignores = true
show_error_codes = true
explicit_package_bases = True
1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
"numpy>=1.26.0,<2.0.0",
"pandas>=2.2.0",
"polars-lts-cpu>=0.20.30,<1.12.0",
"pydantic>=2.0.0",
"ray[default,train,tune]>=2.23.0; python_version < '3.12'",
"ray[default,train,tune]>=2.38.0; python_version >= '3.12'",
"safetensors>=0.4.5",
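The new pydantic pin ties in with the model-style config classes used later in this PR (YamlConfigDict(**yaml_config) in split_yaml.py and YamlSubConfigDict(**yaml.safe_load(file)) in data_handlers.py). As a rough, hypothetical sketch of that pattern, assuming these classes are Pydantic v2 models (the field names below are invented for illustration):

```python
from pydantic import BaseModel, ValidationError

# Illustrative stand-in for a config model such as YamlConfigDict; the real
# model and its fields live in the stimulus YAML utilities and may differ.
class ExampleConfig(BaseModel):
    seed: int
    columns: list[str]

try:
    cfg = ExampleConfig(**{"seed": 42, "columns": ["sequence", "binding"]})
    print(cfg.seed)
except ValidationError as err:
    # Missing or mistyped keys fail at load time with a structured error.
    print(err)
```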
17 changes: 10 additions & 7 deletions src/stimulus/analysis/analysis_default.py
@@ -1,12 +1,12 @@
"""Default analysis module for stimulus package."""

import math
from typing import Any
from typing import Any, Union

import matplotlib as mpl
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.ticker import StrMethodFormatter
from torch.utils.data import DataLoader

from stimulus.data.handlertorch import TorchDataset
@@ -66,8 +66,11 @@ def heatmap(
im = ax.imshow(data, **kwargs)

# Create colorbar
cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
if ax.figure is not None and hasattr(ax.figure, "colorbar"):
cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
else:
cbar = None

# Show all ticks and label them with the respective list entries.
ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
@@ -93,7 +96,7 @@ def heatmap(
def annotate_heatmap(
im: Any,
data: np.ndarray | None = None,
valfmt: str = "{x:.2f}",
valfmt: Union[str, StrMethodFormatter] = "{x:.2f}",
textcolors: tuple[str, str] = ("black", "white"),
threshold: float | None = None,
**textkw: Any,
@@ -134,15 +137,15 @@ def annotate_heatmap(

# Get the formatter in case a string is supplied
if isinstance(valfmt, str):
valfmt = mpl.ticker.StrMethodFormatter(valfmt)
valfmt = StrMethodFormatter(valfmt)

# Loop over the data and create a `Text` for each "pixel".
# Change the text's color depending on the data.
texts = []
for i in range(data.shape[0]):
for j in range(data.shape[1]):
kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
text = im.axes.text(j, i, valfmt(data[i, j]), **kw)
texts.append(text)

return texts
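A note on the valfmt changes above: matplotlib's StrMethodFormatter accepts the tick position as an optional second argument, so calling valfmt(value) without the explicit None is enough once the formatter is imported directly. A minimal standalone sketch (not project code):

```python
from matplotlib.ticker import StrMethodFormatter

# StrMethodFormatter.__call__(x, pos=None): the tick position is optional,
# so formatting a single value needs only valfmt(value).
valfmt = StrMethodFormatter("{x:.2f}")
print(valfmt(3.14159))  # -> "3.14"
```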
11 changes: 6 additions & 5 deletions src/stimulus/cli/predict.py
@@ -95,11 +95,11 @@ def parse_y_keys(y: dict[str, Any], data: pl.DataFrame, y_type: str = "pred") ->
return y

parsed_y = {}
for k1 in y:
for k1, v1 in y.items():
for k2 in data.columns:
if k1 == k2.split(":")[0]:
new_key = f"{k1}:{y_type}:{k2.split(':')[2]}"
parsed_y[new_key] = y[k1]
parsed_y[new_key] = v1

return parsed_y

@@ -140,8 +140,8 @@ def main(
data_path: str,
output: str,
*,
return_labels: bool,
split: int | None,
return_labels: bool = False,
split: int | None = None,
) -> None:
"""Run model prediction pipeline.
@@ -171,7 +171,8 @@ def main(
shuffle=False,
)

out = PredictWrapper(model, dataloader).predict(return_labels=return_labels)
predictor = PredictWrapper(model, dataloader)
out = predictor.predict(return_labels=return_labels)
y_pred, y_true = out if return_labels else (out, {})

y_pred = {k: v.tolist() for k, v in y_pred.items()}
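The parse_y_keys rewrite above only switches the loop to dict.items(); the key renaming is unchanged. A self-contained sketch of that renaming, using made-up column names that follow the name:category:dtype pattern the code assumes:

```python
# Hypothetical inputs; the real values come from the prediction DataFrame
# columns and the model's output dictionary.
y = {"binding": [0.12, 0.87]}
columns = ["sequence:input:dna", "binding:label:float"]
y_type = "pred"

parsed_y = {}
for k1, v1 in y.items():
    for k2 in columns:
        if k1 == k2.split(":")[0]:
            # Keep the column name and dtype, swap the middle field for y_type.
            new_key = f"{k1}:{y_type}:{k2.split(':')[2]}"
            parsed_y[new_key] = v1

print(parsed_y)  # {'binding:pred:float': [0.12, 0.87]}
```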
2 changes: 1 addition & 1 deletion src/stimulus/cli/shuffle_csv.py
@@ -5,7 +5,7 @@
import json
import os

from stimulus.data.csv import CsvProcessing
from stimulus.data.data_handlers import CsvProcessing
from stimulus.utils.launch_utils import get_experiment


2 changes: 1 addition & 1 deletion src/stimulus/cli/split_csv.py
@@ -5,7 +5,7 @@
import json
import logging

from stimulus.data.csv import CsvProcessing
from stimulus.data.data_handlers import CsvProcessing
from stimulus.utils.launch_utils import get_experiment


11 changes: 6 additions & 5 deletions src/stimulus/cli/split_yaml.py
@@ -7,6 +7,7 @@
"""

import argparse
from typing import Any

import yaml

@@ -44,7 +45,7 @@ def get_args() -> argparse.Namespace:
return parser.parse_args()


def main(config_yaml: str, out_dir_path: str) -> str:
def main(config_yaml: str, out_dir_path: str) -> None:
"""Reads a YAML config file and generates all possible data configurations.
This script reads a YAML with a defined structure and creates all the YAML files ready to be passed to
@@ -58,16 +59,16 @@ def main(config_yaml: str, out_dir_path: str) -> str:
and uses the default split behavior.
"""
# read the yaml experiment config and load it to dictionary
yaml_config = {}
yaml_config: dict[str, Any] = {}
with open(config_yaml) as conf_file:
yaml_config = yaml.safe_load(conf_file)

yaml_config_dict: YamlConfigDict = YamlConfigDict(**yaml_config)
# check if the yaml schema is correct
check_yaml_schema(yaml_config)
check_yaml_schema(yaml_config_dict)

# generate all the YAML configs
config_dict = YamlConfigDict(**yaml_config)
data_configs = generate_data_configs(config_dict)
data_configs = generate_data_configs(yaml_config_dict)

# dump all the YAML configs into files
dump_yaml_list_into_files(data_configs, out_dir_path, "test")
2 changes: 1 addition & 1 deletion src/stimulus/cli/transform_csv.py
@@ -4,7 +4,7 @@
import argparse
import json

from stimulus.data.csv import CsvProcessing
from stimulus.data.data_handlers import CsvProcessing
from stimulus.utils.launch_utils import get_experiment


26 changes: 10 additions & 16 deletions src/stimulus/data/csv.py → src/stimulus/data/data_handlers.py
@@ -93,7 +93,7 @@ def categorize_columns_by_type(self) -> dict:

return {"input": input_columns, "label": label_columns, "meta": meta_columns}

def _load_config(self, config_path: str) -> dict:
def _load_config(self, config_path: str) -> yaml_data.YamlConfigDict:
"""Loads and parses a YAML configuration file.
Args:
@@ -111,7 +111,7 @@ def _load_config(self, config_path: str) -> dict:
with open(config_path) as file:
return yaml_data.YamlSubConfigDict(**yaml.safe_load(file))

def get_split_columns(self) -> str:
def get_split_columns(self) -> list[str]:
"""Get the columns that are used for splitting."""
return self.config.split.split_input_columns

@@ -273,18 +273,15 @@ def __init__(
config_path: str,
csv_path: str,
) -> None:
"""Initialize the DatasetHandler with required loaders and config.
"""Initialize the DatasetHandler with required config.
Args:
encoder_loader (experiments.EncoderLoader): Loader for getting column encoders.
transform_loader (experiments.TransformLoader): Loader for getting data transformations.
split_loader (experiments.SplitLoader): Loader for getting dataset split configurations.
config_path (str): Path to the dataset configuration file.
csv_path (str): Path to the CSV data file.
split (int): The split to load, 0 is train, 1 is validation, 2 is test.
"""
self.dataset_manager = DatasetManager(config_path)
self.columns = self.read_csv_header(csv_path)
self.data = self.load_csv(csv_path)

def read_csv_header(self, csv_path: str) -> list:
"""Get the column names from the header of the CSV file.
@@ -344,10 +341,8 @@ def add_split(self, split_manager: SplitManager, *, force: bool = False) -> None
An error exception is raised if the split column is already present in the csv file. This behaviour can be overriden by setting force=True.
Args:
config (dict) : the dictionary containing the following keys:
"name" (str) : the split_function name, as defined in the splitters class and experiment.
"parameters" (dict) : the split_function specific optional parameters, passed here as a dict with keys named as in the split function definition.
force (bool) : If True, the split column present in the csv file will be overwritten.
split_manager (SplitManager): Manager for handling dataset splitting
force (bool): If True, the split column present in the csv file will be overwritten.
"""
if ("split" in self.columns) and (not force):
raise ValueError(
@@ -389,7 +384,7 @@ def shuffle_labels(self, seed: Optional[float] = None) -> None:
# set the np seed
np.random.seed(seed)

label_keys = self.dataset_manager.get_label_columns()["label"]
label_keys = self.dataset_manager.column_categories["label"]
for key in label_keys:
self.data = self.data.with_columns(pl.Series(key, np.random.permutation(list(self.data[key]))))
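The shuffle_labels change above only swaps where the label column names come from (dataset_manager.column_categories instead of a removed helper); the permutation itself is untouched. A standalone sketch of that step with an invented two-column frame:

```python
import numpy as np
import polars as pl

# Toy frame standing in for self.data; only the listed label columns are shuffled.
np.random.seed(0)
df = pl.DataFrame({"sequence": ["AAA", "CCC", "GGG"], "binding": [0.1, 0.5, 0.9]})
label_keys = ["binding"]  # would come from column_categories["label"]

for key in label_keys:
    df = df.with_columns(pl.Series(key, np.random.permutation(df[key].to_list())))

print(df)
```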

@@ -438,9 +433,9 @@ def get_all_items(self) -> tuple[dict, dict, dict]:
meta_data = {key: self.data[key].to_list() for key in meta_columns}
return input_data, label_data, meta_data

def get_all_items_and_length(self) -> tuple[dict, dict, dict, int]:
def get_all_items_and_length(self) -> tuple[tuple[dict, dict, dict], int]:
"""Get the full dataset as three separate dictionaries for inputs, labels and metadata, and the length of the data."""
return self.get_all_items(), len(self)
return self.get_all_items(), len(self.data)

def load_csv_per_split(self, csv_path: str, split: int) -> pl.DataFrame:
"""Load the part of csv file that has the specified split value.
@@ -461,7 +456,7 @@ def __len__(self) -> int:
"""Return the length of the first list in input, assumes that all are the same length."""
return len(self.data)

def __getitem__(self, idx: Any) -> dict:
def __getitem__(self, idx: Any) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], dict[str, list]]:
"""Get the data at a given index, and encodes the input and label, leaving meta as it is.
Args:
Expand All @@ -471,7 +466,6 @@ def __getitem__(self, idx: Any) -> dict:
if isinstance(idx, slice):
data_at_index = self.data.slice(idx.start or 0, idx.stop or len(self.data))
elif isinstance(idx, int):
# Convert single row to DataFrame to maintain consistent interface
data_at_index = self.data.slice(idx, idx + 1)
else:
data_at_index = self.data[idx]
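The new __getitem__ annotation above makes the return contract explicit: a tuple of encoded inputs, encoded labels, and untouched metadata, each keyed by column name. The toy dataset below is purely illustrative (names, shapes, and values are invented) and only shows how a DataLoader consumes that tuple-of-dicts shape:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Invented stand-in for a TorchDataset-style class returning (inputs, labels, meta)."""

    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], dict[str, list]]:
        inputs = {"sequence": torch.zeros(8)}           # encoded input column
        labels = {"binding": torch.tensor(float(idx))}  # encoded label column
        meta = {"row_id": [idx]}                        # metadata passed through unencoded
        return inputs, labels, meta


loader = DataLoader(ToyDataset(), batch_size=2)
for inputs, labels, meta in loader:
    print(inputs["sequence"].shape, labels["binding"], meta["row_id"])
```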
1 change: 1 addition & 0 deletions src/stimulus/data/encoding/__init__.py
@@ -0,0 +1 @@
"""Encoding package for data transformation."""
(Diff truncated: the remaining changed files are not shown here.)
