FIX: make type checker happy
mathysgrapotte committed Jan 22, 2025
1 parent e365fc8 commit d40715f
Showing 22 changed files with 286 additions and 204 deletions.
1 change: 1 addition & 0 deletions config/mypy.ini
@@ -3,3 +3,4 @@ ignore_missing_imports = true
exclude = tests/fixtures/
warn_unused_ignores = true
show_error_codes = true
explicit_package_bases = True
17 changes: 10 additions & 7 deletions src/stimulus/analysis/analysis_default.py
@@ -1,12 +1,12 @@
"""Default analysis module for stimulus package."""

import math
from typing import Any
from typing import Any, Union

import matplotlib as mpl
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.ticker import StrMethodFormatter
from torch.utils.data import DataLoader

from stimulus.data.handlertorch import TorchDataset
@@ -66,8 +66,11 @@ def heatmap(
    im = ax.imshow(data, **kwargs)

    # Create colorbar
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
    if ax.figure is not None and hasattr(ax.figure, "colorbar"):
        cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
        cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
    else:
        cbar = None

    # Show all ticks and label them with the respective list entries.
    ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
@@ -93,7 +96,7 @@ def heatmap(
def annotate_heatmap(
    im: Any,
    data: np.ndarray | None = None,
    valfmt: str = "{x:.2f}",
    valfmt: Union[str, StrMethodFormatter] = "{x:.2f}",
    textcolors: tuple[str, str] = ("black", "white"),
    threshold: float | None = None,
    **textkw: Any,
@@ -134,15 +137,15 @@ def annotate_heatmap(

    # Get the formatter in case a string is supplied
    if isinstance(valfmt, str):
        valfmt = mpl.ticker.StrMethodFormatter(valfmt)
        valfmt = StrMethodFormatter(valfmt)

    # Loop over the data and create a `Text` for each "pixel".
    # Change the text's color depending on the data.
    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            text = im.axes.text(j, i, valfmt(data[i, j]), **kw)
            texts.append(text)

    return texts
7 changes: 4 additions & 3 deletions src/stimulus/cli/predict.py
@@ -140,8 +140,8 @@ def main(
    data_path: str,
    output: str,
    *,
    return_labels: bool,
    split: int | None,
    return_labels: bool = False,
    split: int | None = None,
) -> None:
    """Run model prediction pipeline.
@@ -171,7 +171,8 @@
        shuffle=False,
    )

    out = PredictWrapper(model, dataloader).predict(return_labels=return_labels)
    predictor = PredictWrapper(model, dataloader)
    out = predictor.predict(return_labels=return_labels)
    y_pred, y_true = out if return_labels else (out, {})

    y_pred = {k: v.tolist() for k, v in y_pred.items()}
11 changes: 6 additions & 5 deletions src/stimulus/cli/split_yaml.py
@@ -7,6 +7,7 @@
"""

import argparse
from typing import Any

import yaml

@@ -44,7 +45,7 @@ def get_args() -> argparse.Namespace:
    return parser.parse_args()


def main(config_yaml: str, out_dir_path: str) -> str:
def main(config_yaml: str, out_dir_path: str) -> None:
    """Reads a YAML config file and generates all possible data configurations.
    This script reads a YAML with a defined structure and creates all the YAML files ready to be passed to
@@ -58,16 +59,16 @@ def main(config_yaml: str, out_dir_path: str) -> str:
    and uses the default split behavior.
    """
    # read the yaml experiment config and load it to dictionary
    yaml_config = {}
    yaml_config: dict[str, Any] = {}
    with open(config_yaml) as conf_file:
        yaml_config = yaml.safe_load(conf_file)

    yaml_config_dict: YamlConfigDict = YamlConfigDict(**yaml_config)
    # check if the yaml schema is correct
    check_yaml_schema(yaml_config)
    check_yaml_schema(yaml_config_dict)

    # generate all the YAML configs
    config_dict = YamlConfigDict(**yaml_config)
    data_configs = generate_data_configs(config_dict)
    data_configs = generate_data_configs(yaml_config_dict)

    # dump all the YAML configs into files
    dump_yaml_list_into_files(data_configs, out_dir_path, "test")
14 changes: 7 additions & 7 deletions src/stimulus/data/data_handlers.py
Expand Up @@ -93,7 +93,7 @@ def categorize_columns_by_type(self) -> dict:

return {"input": input_columns, "label": label_columns, "meta": meta_columns}

def _load_config(self, config_path: str) -> dict:
def _load_config(self, config_path: str) -> yaml_data.YamlConfigDict:
"""Loads and parses a YAML configuration file.
Args:
@@ -111,7 +111,7 @@ def _load_config(self, config_path: str) -> dict:
        with open(config_path) as file:
            return yaml_data.YamlSubConfigDict(**yaml.safe_load(file))

    def get_split_columns(self) -> str:
    def get_split_columns(self) -> list[str]:
        """Get the columns that are used for splitting."""
        return self.config.split.split_input_columns

@@ -281,6 +281,7 @@ def __init__(
"""
self.dataset_manager = DatasetManager(config_path)
self.columns = self.read_csv_header(csv_path)
self.data = self.load_csv(csv_path)

def read_csv_header(self, csv_path: str) -> list:
"""Get the column names from the header of the CSV file.
@@ -383,7 +384,7 @@ def shuffle_labels(self, seed: Optional[float] = None) -> None:
        # set the np seed
        np.random.seed(seed)

        label_keys = self.dataset_manager.get_label_columns()["label"]
        label_keys = self.dataset_manager.column_categories["label"]
        for key in label_keys:
            self.data = self.data.with_columns(pl.Series(key, np.random.permutation(list(self.data[key]))))

@@ -432,9 +433,9 @@ def get_all_items(self) -> tuple[dict, dict, dict]:
        meta_data = {key: self.data[key].to_list() for key in meta_columns}
        return input_data, label_data, meta_data

    def get_all_items_and_length(self) -> tuple[dict, dict, dict, int]:
    def get_all_items_and_length(self) -> tuple[tuple[dict, dict, dict], int]:
        """Get the full dataset as three separate dictionaries for inputs, labels and metadata, and the length of the data."""
        return self.get_all_items(), len(self)
        return self.get_all_items(), len(self.data)

    def load_csv_per_split(self, csv_path: str, split: int) -> pl.DataFrame:
        """Load the part of csv file that has the specified split value.
@@ -455,7 +456,7 @@ def __len__(self) -> int:
"""Return the length of the first list in input, assumes that all are the same length."""
return len(self.data)

def __getitem__(self, idx: Any) -> dict:
def __getitem__(self, idx: Any) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], dict[str, list]]:
"""Get the data at a given index, and encodes the input and label, leaving meta as it is.
Args:
Expand All @@ -465,7 +466,6 @@ def __getitem__(self, idx: Any) -> dict:
if isinstance(idx, slice):
data_at_index = self.data.slice(idx.start or 0, idx.stop or len(self.data))
elif isinstance(idx, int):
# Convert single row to DataFrame to maintain consistent interface
data_at_index = self.data.slice(idx, idx + 1)
else:
data_at_index = self.data[idx]
(Diffs for the remaining 17 changed files are not shown.)
