From 17312f55a01890882e2c0d7026b8050aa674f19f Mon Sep 17 00:00:00 2001 From: VHemanth45 Date: Wed, 11 Dec 2024 17:30:41 +0530 Subject: [PATCH 1/2] fixed the documentation of split.py --- openml/tasks/split.py | 202 ++++++++++++++++++++++++++++-------------- 1 file changed, 135 insertions(+), 67 deletions(-) diff --git a/openml/tasks/split.py b/openml/tasks/split.py index ac538496e..569041fee 100644 --- a/openml/tasks/split.py +++ b/openml/tasks/split.py @@ -1,64 +1,92 @@ # License: BSD 3-Clause -from __future__ import annotations +from collections import namedtuple, OrderedDict +import os import pickle -from collections import OrderedDict -from pathlib import Path -from typing import Any -from typing_extensions import NamedTuple -import arff # type: ignore import numpy as np +import arff -class Split(NamedTuple): - """A single split of a dataset.""" - - train: np.ndarray - test: np.ndarray +# Named tuple to represent a train-test split +Split = namedtuple("Split", ["train", "test"]) class OpenMLSplit: - """OpenML Split object. + """ + Represents a split object for OpenML datasets. + + This class manages train-test splits for a dataset across multiple + repetitions, folds, and samples. Parameters ---------- name : int or str + The name or ID of the split. description : str + A textual description of the split. split : dict + A dictionary containing the splits organized by repetition, fold, + and sample. + + Attributes + ---------- + name : int or str + The name or ID of the split. + description : str + Description of the split. + split : dict + Nested dictionary holding the train-test indices for each repetition, + fold, and sample. + repeats : int + Number of repetitions in the split. + folds : int + Number of folds in each repetition. + samples : int + Number of samples in each fold. + + Raises + ------ + ValueError + If the number of folds is inconsistent across repetitions. """ - def __init__( - self, - name: int | str, - description: str, - split: dict[int, dict[int, dict[int, tuple[np.ndarray, np.ndarray]]]], - ): + def __init__(self, name, description, split): self.description = description self.name = name - self.split: dict[int, dict[int, dict[int, tuple[np.ndarray, np.ndarray]]]] = {} + self.split = dict() - # Add splits according to repetition + # Populate splits according to repetitions for repetition in split: - _rep = int(repetition) - self.split[_rep] = OrderedDict() - for fold in split[_rep]: - self.split[_rep][fold] = OrderedDict() - for sample in split[_rep][fold]: - self.split[_rep][fold][sample] = split[_rep][fold][sample] + repetition = int(repetition) + self.split[repetition] = OrderedDict() + for fold in split[repetition]: + self.split[repetition][fold] = OrderedDict() + for sample in split[repetition][fold]: + self.split[repetition][fold][sample] = split[repetition][fold][sample] self.repeats = len(self.split) - - # TODO(eddiebergman): Better error message - if any(len(self.split[0]) != len(self.split[i]) for i in range(self.repeats)): - raise ValueError("") - + if any([len(self.split[0]) != len(self.split[i]) for i in range(self.repeats)]): + raise ValueError("Number of folds is inconsistent across repetitions.") self.folds = len(self.split[0]) self.samples = len(self.split[0][0]) - def __eq__(self, other: Any) -> bool: + def __eq__(self, other): + """ + Check if two OpenMLSplit objects are equal. + + Parameters + ---------- + other : OpenMLSplit + Another OpenMLSplit object to compare against. + + Returns + ------- + bool + True if the objects are equal, False otherwise. + """ if ( - (not isinstance(self, type(other))) + type(self) != type(other) or self.name != other.name or self.description != other.description or self.split.keys() != other.split.keys() @@ -84,32 +112,49 @@ def __eq__(self, other: Any) -> bool: return True @classmethod - def _from_arff_file(cls, filename: Path) -> OpenMLSplit: # noqa: C901, PLR0912 - repetitions = None - name = None + def _from_arff_file(cls, filename: str) -> "OpenMLSplit": + """ + Create an OpenMLSplit object from an ARFF file. + + Parameters + ---------- + filename : str + Path to the ARFF file. - pkl_filename = filename.with_suffix(".pkl.py3") + Returns + ------- + OpenMLSplit + The constructed OpenMLSplit object. - if pkl_filename.exists(): - with pkl_filename.open("rb") as fh: - # TODO(eddiebergman): Would be good to figure out what _split is and assert it is - _split = pickle.load(fh) # noqa: S301 - repetitions = _split["repetitions"] - name = _split["name"] + Raises + ------ + FileNotFoundError + If the ARFF file does not exist. + ValueError + If an unknown split type is encountered. + """ + repetitions = None + pkl_filename = filename.replace(".arff", ".pkl.py3") - # Cache miss - if repetitions is None: - # Faster than liac-arff and sufficient in this situation! - if not filename.exists(): - raise FileNotFoundError(f"Split arff {filename} does not exist!") + # Try loading from a cached pickle file + if os.path.exists(pkl_filename): + with open(pkl_filename, "rb") as fh: + _ = pickle.load(fh) + repetitions = _["repetitions"] + name = _["name"] - file_data = arff.load(filename.open("r"), return_type=arff.DENSE_GEN) + # Cache miss: load from ARFF file + if repetitions is None: + if not os.path.exists(filename): + raise FileNotFoundError(f"Split ARFF file {filename} does not exist.") + file_data = arff.load(open(filename), return_type=arff.DENSE_GEN) splits = file_data["data"] name = file_data["relation"] attrnames = [attr[0] for attr in file_data["attributes"]] repetitions = OrderedDict() + # Identify attribute indices type_idx = attrnames.index("type") rowid_idx = attrnames.index("rowid") repeat_idx = attrnames.index("repeat") @@ -117,7 +162,6 @@ def _from_arff_file(cls, filename: Path) -> OpenMLSplit: # noqa: C901, PLR0912 sample_idx = attrnames.index("sample") if "sample" in attrnames else None for line in splits: - # A line looks like type, rowid, repeat, fold repetition = int(line[repeat_idx]) fold = int(line[fold_idx]) sample = 0 @@ -138,8 +182,9 @@ def _from_arff_file(cls, filename: Path) -> OpenMLSplit: # noqa: C901, PLR0912 elif type_ == "TEST": split[1].append(line[rowid_idx]) else: - raise ValueError(type_) + raise ValueError(f"Unknown split type: {type_}") + # Convert lists to numpy arrays for repetition in repetitions: for fold in repetitions[repetition]: for sample in repetitions[repetition][fold]: @@ -148,38 +193,61 @@ def _from_arff_file(cls, filename: Path) -> OpenMLSplit: # noqa: C901, PLR0912 np.array(repetitions[repetition][fold][sample][1], dtype=np.int32), ) - with pkl_filename.open("wb") as fh: + # Cache the parsed splits + with open(pkl_filename, "wb") as fh: pickle.dump({"name": name, "repetitions": repetitions}, fh, protocol=2) - assert name is not None return cls(name, "", repetitions) - def get(self, repeat: int = 0, fold: int = 0, sample: int = 0) -> tuple[np.ndarray, np.ndarray]: - """Returns the specified data split from the CrossValidationSplit object. + def from_dataset(self, X, Y, folds, repeats): + """ + Construct splits from a dataset. + + Parameters + ---------- + X : array-like + Feature matrix. + Y : array-like + Target array. + folds : int + Number of folds. + repeats : int + Number of repetitions. + + Raises + ------ + NotImplementedError + This method is not yet implemented. + """ + raise NotImplementedError("from_dataset method is not implemented.") + + def get(self, repeat=0, fold=0, sample=0): + """ + Retrieve a specific split. Parameters ---------- - repeat : int - Index of the repeat to retrieve. - fold : int - Index of the fold to retrieve. - sample : int - Index of the sample to retrieve. + repeat : int, optional + Repetition index (default is 0). + fold : int, optional + Fold index (default is 0). + sample : int, optional + Sample index (default is 0). Returns ------- - numpy.ndarray - The data split for the specified repeat, fold, and sample. + Split + A named tuple containing train and test indices. Raises ------ ValueError - If the specified repeat, fold, or sample is not known. + If the specified repeat, fold, or sample does not exist. """ if repeat not in self.split: - raise ValueError(f"Repeat {repeat!s} not known") + raise ValueError(f"Repeat {repeat} not known.") if fold not in self.split[repeat]: - raise ValueError(f"Fold {fold!s} not known") + raise ValueError(f"Fold {fold} not known.") if sample not in self.split[repeat][fold]: - raise ValueError(f"Sample {sample!s} not known") + raise ValueError(f"Sample {sample} not known.") return self.split[repeat][fold][sample] From cb1262e846a1b92a31e7158acb5d62d5d78f7278 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Dec 2024 12:08:11 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- openml/tasks/split.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/openml/tasks/split.py b/openml/tasks/split.py index 569041fee..f93a887fb 100644 --- a/openml/tasks/split.py +++ b/openml/tasks/split.py @@ -1,12 +1,12 @@ # License: BSD 3-Clause +from __future__ import annotations -from collections import namedtuple, OrderedDict import os import pickle +from collections import OrderedDict, namedtuple -import numpy as np import arff - +import numpy as np # Named tuple to represent a train-test split Split = namedtuple("Split", ["train", "test"]) @@ -16,7 +16,7 @@ class OpenMLSplit: """ Represents a split object for OpenML datasets. - This class manages train-test splits for a dataset across multiple + This class manages train-test splits for a dataset across multiple repetitions, folds, and samples. Parameters @@ -26,7 +26,7 @@ class OpenMLSplit: description : str A textual description of the split. split : dict - A dictionary containing the splits organized by repetition, fold, + A dictionary containing the splits organized by repetition, fold, and sample. Attributes @@ -36,7 +36,7 @@ class OpenMLSplit: description : str Description of the split. split : dict - Nested dictionary holding the train-test indices for each repetition, + Nested dictionary holding the train-test indices for each repetition, fold, and sample. repeats : int Number of repetitions in the split. @@ -54,7 +54,7 @@ class OpenMLSplit: def __init__(self, name, description, split): self.description = description self.name = name - self.split = dict() + self.split = {} # Populate splits according to repetitions for repetition in split: @@ -66,7 +66,7 @@ def __init__(self, name, description, split): self.split[repetition][fold][sample] = split[repetition][fold][sample] self.repeats = len(self.split) - if any([len(self.split[0]) != len(self.split[i]) for i in range(self.repeats)]): + if any(len(self.split[0]) != len(self.split[i]) for i in range(self.repeats)): raise ValueError("Number of folds is inconsistent across repetitions.") self.folds = len(self.split[0]) self.samples = len(self.split[0][0]) @@ -112,7 +112,7 @@ def __eq__(self, other): return True @classmethod - def _from_arff_file(cls, filename: str) -> "OpenMLSplit": + def _from_arff_file(cls, filename: str) -> OpenMLSplit: """ Create an OpenMLSplit object from an ARFF file.