1
1
from __future__ import annotations
2
2
3
- from typing import Optional , Union , List
3
+ from typing import Optional , Union , Tuple , Iterable
4
4
from typing_extensions import Self
5
5
6
6
import numpy as np
@@ -16,7 +16,7 @@ class SplitConformalClassifier:
16
16
def __init__ (
17
17
self ,
18
18
estimator : ClassifierMixin = LogisticRegression (),
19
- confidence_level : Union [float , List [float ]] = 0.9 ,
19
+ confidence_level : Union [float , Iterable [float ]] = 0.9 ,
20
20
conformity_score : Union [str , BaseClassificationScore ] = "lac" ,
21
21
prefit : bool = True ,
22
22
n_jobs : Optional [int ] = None ,
@@ -42,44 +42,31 @@ def conformalize(
42
42
return self
43
43
44
44
def predict (self , X : ArrayLike ) -> NDArray :
45
- """
46
- Return
47
- -----
48
- Return ponctual prediction similar to predict method of
49
- scikit-learn classifiers
50
- Shape (n_samples,)
51
- """
52
45
return np .ndarray (0 )
53
46
54
- def predict_sets (
47
+ def predict_set (
55
48
self ,
56
49
X : ArrayLike ,
57
50
conformity_score_params : Optional [dict ] = None ,
58
51
# Parameters specific to conformal method,
59
52
# For example: include_last_label
60
- ) -> NDArray :
53
+ ) -> Tuple [ NDArray , NDArray ] :
61
54
"""
62
- Return
63
- -----
64
- An array containing the prediction sets
65
- Shape (n_samples, n_classes) if confidence_level is float,
66
- Shape (n_samples, n_classes, confidence_level) if confidence_level
67
- is a list of floats
55
+ Shape: (n, ), (n, n_class, n_confidence_level)
68
56
"""
69
- return np .ndarray (0 )
57
+ return np .ndarray (0 ), np . ndarray ( 0 )
70
58
71
59
72
60
class CrossConformalClassifier :
73
61
def __init__ (
74
62
self ,
75
63
estimator : ClassifierMixin = LogisticRegression (),
76
- confidence_level : Union [float , List [float ]] = 0.9 ,
64
+ confidence_level : Union [float , Iterable [float ]] = 0.9 ,
77
65
conformity_score : Union [str , BaseClassificationScore ] = "lac" ,
78
- cross_val : Union [BaseCrossValidator , str ] = 5 ,
66
+ cv : Union [int , BaseCrossValidator ] = 5 ,
79
67
n_jobs : Optional [int ] = None ,
80
68
verbose : int = 0 ,
81
69
random_state : Optional [Union [int , np .random .RandomState ]] = None ,
82
-
83
70
) -> None :
84
71
pass
85
72
@@ -95,34 +82,22 @@ def conformalize(
95
82
self ,
96
83
X_conformalize : ArrayLike ,
97
84
y_conformalize : ArrayLike ,
85
+ groups : Optional [ArrayLike ] = None ,
98
86
predict_params : Optional [dict ] = None
99
87
) -> Self :
100
88
return self
101
89
102
- def predict (self ,
103
- X : ArrayLike ) -> NDArray :
104
- """
105
- Return
106
- -----
107
- Return ponctual prediction similar to predict method of
108
- scikit-learn classifiers
109
- Shape (n_samples,)
110
- """
90
+ def predict (self , X : ArrayLike ) -> NDArray :
111
91
return np .ndarray (0 )
112
92
113
- def predict_sets (
93
+ def predict_set (
114
94
self ,
115
95
X : ArrayLike ,
116
96
aggregation_method : Optional [str ] = "mean" ,
117
97
# How to aggregate the scores by the estimators on test data
118
98
conformity_score_params : Optional [dict ] = None
119
- ) -> NDArray :
99
+ ) -> Tuple [ NDArray , NDArray ] :
120
100
"""
121
- Return
122
- -----
123
- An array containing the prediction sets
124
- Shape (n_samples, n_classes) if confidence_level is float,
125
- Shape (n_samples, n_classes, confidence_level) if confidence_level
126
- is a list of floats
101
+ Shape: (n, ), (n, n_class, n_confidence_level)
127
102
"""
128
- return np .ndarray (0 )
103
+ return np .ndarray (0 ), np . ndarray ( 0 )
0 commit comments