Skip to content

Commit 6b2f8e3

Browse files
ENH: refine classification API
1 parent d13c71b commit 6b2f8e3

File tree

1 file changed

+14
-39
lines changed

1 file changed

+14
-39
lines changed

mapie_v1/classification.py

+14-39
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Optional, Union, List
3+
from typing import Optional, Union, Tuple, Iterable
44
from typing_extensions import Self
55

66
import numpy as np
@@ -16,7 +16,7 @@ class SplitConformalClassifier:
1616
def __init__(
1717
self,
1818
estimator: ClassifierMixin = LogisticRegression(),
19-
confidence_level: Union[float, List[float]] = 0.9,
19+
confidence_level: Union[float, Iterable[float]] = 0.9,
2020
conformity_score: Union[str, BaseClassificationScore] = "lac",
2121
prefit: bool = True,
2222
n_jobs: Optional[int] = None,
@@ -42,44 +42,31 @@ def conformalize(
4242
return self
4343

4444
def predict(self, X: ArrayLike) -> NDArray:
45-
"""
46-
Return
47-
-----
48-
Return ponctual prediction similar to predict method of
49-
scikit-learn classifiers
50-
Shape (n_samples,)
51-
"""
5245
return np.ndarray(0)
5346

54-
def predict_sets(
47+
def predict_set(
5548
self,
5649
X: ArrayLike,
5750
conformity_score_params: Optional[dict] = None,
5851
# Parameters specific to conformal method,
5952
# For example: include_last_label
60-
) -> NDArray:
53+
) -> Tuple[NDArray, NDArray]:
6154
"""
62-
Return
63-
-----
64-
An array containing the prediction sets
65-
Shape (n_samples, n_classes) if confidence_level is float,
66-
Shape (n_samples, n_classes, confidence_level) if confidence_level
67-
is a list of floats
55+
Shape: (n, ), (n, n_class, n_confidence_level)
6856
"""
69-
return np.ndarray(0)
57+
return np.ndarray(0), np.ndarray(0)
7058

7159

7260
class CrossConformalClassifier:
7361
def __init__(
7462
self,
7563
estimator: ClassifierMixin = LogisticRegression(),
76-
confidence_level: Union[float, List[float]] = 0.9,
64+
confidence_level: Union[float, Iterable[float]] = 0.9,
7765
conformity_score: Union[str, BaseClassificationScore] = "lac",
78-
cross_val: Union[BaseCrossValidator, str] = 5,
66+
cv: Union[int, BaseCrossValidator] = 5,
7967
n_jobs: Optional[int] = None,
8068
verbose: int = 0,
8169
random_state: Optional[Union[int, np.random.RandomState]] = None,
82-
8370
) -> None:
8471
pass
8572

@@ -95,34 +82,22 @@ def conformalize(
9582
self,
9683
X_conformalize: ArrayLike,
9784
y_conformalize: ArrayLike,
85+
groups: Optional[ArrayLike] = None,
9886
predict_params: Optional[dict] = None
9987
) -> Self:
10088
return self
10189

102-
def predict(self,
103-
X: ArrayLike) -> NDArray:
104-
"""
105-
Return
106-
-----
107-
Return ponctual prediction similar to predict method of
108-
scikit-learn classifiers
109-
Shape (n_samples,)
110-
"""
90+
def predict(self, X: ArrayLike) -> NDArray:
11191
return np.ndarray(0)
11292

113-
def predict_sets(
93+
def predict_set(
11494
self,
11595
X: ArrayLike,
11696
aggregation_method: Optional[str] = "mean",
11797
# How to aggregate the scores by the estimators on test data
11898
conformity_score_params: Optional[dict] = None
119-
) -> NDArray:
99+
) -> Tuple[NDArray, NDArray]:
120100
"""
121-
Return
122-
-----
123-
An array containing the prediction sets
124-
Shape (n_samples, n_classes) if confidence_level is float,
125-
Shape (n_samples, n_classes, confidence_level) if confidence_level
126-
is a list of floats
101+
Shape: (n, ), (n, n_class, n_confidence_level)
127102
"""
128-
return np.ndarray(0)
103+
return np.ndarray(0), np.ndarray(0)

0 commit comments

Comments
 (0)