Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fixed the documentation of split.py #1385

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 134 additions & 66 deletions openml/tasks/split.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,92 @@
# License: BSD 3-Clause
from __future__ import annotations

import os
import pickle
from collections import OrderedDict
from pathlib import Path
from typing import Any
from typing_extensions import NamedTuple
from collections import OrderedDict, namedtuple

import arff # type: ignore
import arff
import numpy as np


class Split(NamedTuple):
    """A single train/test split of a dataset.

    Behaves like a plain 2-tuple ``(train, test)`` and additionally exposes
    both members by name.

    Attributes
    ----------
    train : np.ndarray
        Indices assigned to the training part of the split.
    test : np.ndarray
        Indices assigned to the test part of the split.
    """

    train: np.ndarray
    test: np.ndarray


class OpenMLSplit:
"""OpenML Split object.
"""
Represents a split object for OpenML datasets.

This class manages train-test splits for a dataset across multiple
repetitions, folds, and samples.

Parameters
----------
name : int or str
The name or ID of the split.
description : str
A textual description of the split.
split : dict
A dictionary containing the splits organized by repetition, fold,
and sample.

Attributes
----------
name : int or str
The name or ID of the split.
description : str
Description of the split.
split : dict
Nested dictionary holding the train-test indices for each repetition,
fold, and sample.
repeats : int
Number of repetitions in the split.
folds : int
Number of folds in each repetition.
samples : int
Number of samples in each fold.

Raises
------
ValueError
If the number of folds is inconsistent across repetitions.
"""

def __init__(
self,
name: int | str,
description: str,
split: dict[int, dict[int, dict[int, tuple[np.ndarray, np.ndarray]]]],
):
def __init__(self, name, description, split):
self.description = description
self.name = name
self.split: dict[int, dict[int, dict[int, tuple[np.ndarray, np.ndarray]]]] = {}
self.split = {}

# Add splits according to repetition
# Populate splits according to repetitions
for repetition in split:
_rep = int(repetition)
self.split[_rep] = OrderedDict()
for fold in split[_rep]:
self.split[_rep][fold] = OrderedDict()
for sample in split[_rep][fold]:
self.split[_rep][fold][sample] = split[_rep][fold][sample]
repetition = int(repetition)
self.split[repetition] = OrderedDict()
for fold in split[repetition]:
self.split[repetition][fold] = OrderedDict()
for sample in split[repetition][fold]:
self.split[repetition][fold][sample] = split[repetition][fold][sample]

self.repeats = len(self.split)

# TODO(eddiebergman): Better error message
if any(len(self.split[0]) != len(self.split[i]) for i in range(self.repeats)):
raise ValueError("")

raise ValueError("Number of folds is inconsistent across repetitions.")
self.folds = len(self.split[0])
self.samples = len(self.split[0][0])

def __eq__(self, other: Any) -> bool:
def __eq__(self, other):
"""
Check if two OpenMLSplit objects are equal.

Parameters
----------
other : OpenMLSplit
Another OpenMLSplit object to compare against.

Returns
-------
bool
True if the objects are equal, False otherwise.
"""
if (
(not isinstance(self, type(other)))
type(self) != type(other)
or self.name != other.name
or self.description != other.description
or self.split.keys() != other.split.keys()
Expand All @@ -84,40 +112,56 @@ def __eq__(self, other: Any) -> bool:
return True

@classmethod
def _from_arff_file(cls, filename: Path) -> OpenMLSplit: # noqa: C901, PLR0912
repetitions = None
name = None
def _from_arff_file(cls, filename: str) -> OpenMLSplit:
"""
Create an OpenMLSplit object from an ARFF file.

pkl_filename = filename.with_suffix(".pkl.py3")
Parameters
----------
filename : str
Path to the ARFF file.

if pkl_filename.exists():
with pkl_filename.open("rb") as fh:
# TODO(eddiebergman): Would be good to figure out what _split is and assert it is
_split = pickle.load(fh) # noqa: S301
repetitions = _split["repetitions"]
name = _split["name"]
Returns
-------
OpenMLSplit
The constructed OpenMLSplit object.

# Cache miss
if repetitions is None:
# Faster than liac-arff and sufficient in this situation!
if not filename.exists():
raise FileNotFoundError(f"Split arff {filename} does not exist!")
Raises
------
FileNotFoundError
If the ARFF file does not exist.
ValueError
If an unknown split type is encountered.
"""
repetitions = None
pkl_filename = filename.replace(".arff", ".pkl.py3")

file_data = arff.load(filename.open("r"), return_type=arff.DENSE_GEN)
# Try loading from a cached pickle file
if os.path.exists(pkl_filename):
with open(pkl_filename, "rb") as fh:
_ = pickle.load(fh)
repetitions = _["repetitions"]
name = _["name"]

# Cache miss: load from ARFF file
if repetitions is None:
if not os.path.exists(filename):
raise FileNotFoundError(f"Split ARFF file {filename} does not exist.")
file_data = arff.load(open(filename), return_type=arff.DENSE_GEN)
splits = file_data["data"]
name = file_data["relation"]
attrnames = [attr[0] for attr in file_data["attributes"]]

repetitions = OrderedDict()

# Identify attribute indices
type_idx = attrnames.index("type")
rowid_idx = attrnames.index("rowid")
repeat_idx = attrnames.index("repeat")
fold_idx = attrnames.index("fold")
sample_idx = attrnames.index("sample") if "sample" in attrnames else None

for line in splits:
# A line looks like type, rowid, repeat, fold
repetition = int(line[repeat_idx])
fold = int(line[fold_idx])
sample = 0
Expand All @@ -138,8 +182,9 @@ def _from_arff_file(cls, filename: Path) -> OpenMLSplit: # noqa: C901, PLR0912
elif type_ == "TEST":
split[1].append(line[rowid_idx])
else:
raise ValueError(type_)
raise ValueError(f"Unknown split type: {type_}")

# Convert lists to numpy arrays
for repetition in repetitions:
for fold in repetitions[repetition]:
for sample in repetitions[repetition][fold]:
Expand All @@ -148,38 +193,61 @@ def _from_arff_file(cls, filename: Path) -> OpenMLSplit: # noqa: C901, PLR0912
np.array(repetitions[repetition][fold][sample][1], dtype=np.int32),
)

with pkl_filename.open("wb") as fh:
# Cache the parsed splits
with open(pkl_filename, "wb") as fh:
pickle.dump({"name": name, "repetitions": repetitions}, fh, protocol=2)

assert name is not None
return cls(name, "", repetitions)

def get(self, repeat: int = 0, fold: int = 0, sample: int = 0) -> tuple[np.ndarray, np.ndarray]:
"""Returns the specified data split from the CrossValidationSplit object.
def from_dataset(self, X, Y, folds, repeats):
"""
Construct splits from a dataset.

Parameters
----------
X : array-like
Feature matrix.
Y : array-like
Target array.
folds : int
Number of folds.
repeats : int
Number of repetitions.

Raises
------
NotImplementedError
This method is not yet implemented.
"""
raise NotImplementedError("from_dataset method is not implemented.")

def get(self, repeat=0, fold=0, sample=0):
"""
Retrieve a specific split.

Parameters
----------
repeat : int
Index of the repeat to retrieve.
fold : int
Index of the fold to retrieve.
sample : int
Index of the sample to retrieve.
repeat : int, optional
Repetition index (default is 0).
fold : int, optional
Fold index (default is 0).
sample : int, optional
Sample index (default is 0).

Returns
-------
numpy.ndarray
The data split for the specified repeat, fold, and sample.
Split
A named tuple containing train and test indices.

Raises
------
ValueError
If the specified repeat, fold, or sample is not known.
If the specified repeat, fold, or sample does not exist.
"""
if repeat not in self.split:
raise ValueError(f"Repeat {repeat!s} not known")
raise ValueError(f"Repeat {repeat} not known.")
if fold not in self.split[repeat]:
raise ValueError(f"Fold {fold!s} not known")
raise ValueError(f"Fold {fold} not known.")
if sample not in self.split[repeat][fold]:
raise ValueError(f"Sample {sample!s} not known")
raise ValueError(f"Sample {sample} not known.")
return self.split[repeat][fold][sample]