 import time
 from scipy.misc import logsumexp
 from scipy.optimize import minimize
+from sklearn.preprocessing import OneHotEncoder

 from ..base import BaseEstimator, TransformerMixin
 from ..preprocessing import LabelEncoder
 from ..externals.six import integer_types


-class NeighborhoodComponentAnalysis(BaseEstimator, TransformerMixin):
-    """Neighborhood Component Analysis
+class NeighborhoodComponentsAnalysis(BaseEstimator, TransformerMixin):
+    """Neighborhood Components Analysis

     Parameters
     ----------
@@ -98,16 +99,16 @@ class NeighborhoodComponentAnalysis(BaseEstimator, TransformerMixin):

     Examples
     --------
-    >>> from sklearn.neighbors.nca import NeighborhoodComponentAnalysis
+    >>> from sklearn.neighbors.nca import NeighborhoodComponentsAnalysis
     >>> from sklearn.neighbors import KNeighborsClassifier
     >>> from sklearn.datasets import load_iris
     >>> from sklearn.model_selection import train_test_split
     >>> X, y = load_iris(return_X_y=True)
     >>> X_train, X_test, y_train, y_test = train_test_split(X, y,
     ... stratify=y, test_size=0.7, random_state=42)
-    >>> nca = NeighborhoodComponentAnalysis(None, random_state=42)
+    >>> nca = NeighborhoodComponentsAnalysis(None, random_state=42)
     >>> nca.fit(X_train, y_train) # doctest: +ELLIPSIS
-    NeighborhoodComponentAnalysis(...)
+    NeighborhoodComponentsAnalysis(...)
     >>> knn = KNeighborsClassifier(n_neighbors=3)
     >>> knn.fit(X_train, y_train) # doctest: +ELLIPSIS
     KNeighborsClassifier(...)
@@ -123,23 +124,21 @@ class NeighborhoodComponentAnalysis(BaseEstimator, TransformerMixin):
     Neighborhood Component Analysis (NCA) is a machine learning algorithm for
     metric learning. It learns a linear transformation in a supervised fashion
     to improve the classification accuracy of a stochastic nearest neighbors
-    rule in the new space.
-
-    .. warning::
-
-        As NCA is optimizing a non-convex objective function, it will
-        likely end up in a local optimum. Several runs with independent random
-        init might be necessary to get a good convergence.
+    rule in the transformed space.

     References
     ----------
     .. [1] J. Goldberger, G. Hinton, S. Roweis, R. Salakhutdinov.
            "Neighbourhood Components Analysis". Advances in Neural Information
            Processing Systems. 17, 513-520, 2005.
            http://www.cs.nyu.edu/~roweis/papers/ncanips.pdf
+
+    .. [2] Wikipedia entry on Neighborhood Components Analysis
+           https://en.wikipedia.org/wiki/Neighbourhood_components_analysis
+
     """

-    def __init__(self, n_features_out=None, init='identity', max_iter=50,
+    def __init__(self, n_features_out=None, init='pca', max_iter=50,
                  tol=1e-5, callback=None, store_opt_result=False, verbose=0,
                  random_state=None):
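The objective the class optimizes can be illustrated outside the estimator. The sketch below is an illustration only (not code from this PR, and it uses scipy.special.logsumexp as a standalone stand-in for the scipy.misc import): it computes the stochastic neighbor probabilities p_ij and the per-sample probability p_i of correct classification for a toy dataset under a fixed candidate transformation.

import numpy as np
from scipy.special import logsumexp

X = np.array([[0.0, 0.0], [0.1, 0.0], [3.0, 3.0]])   # toy samples
y = np.array([0, 0, 1])                               # toy labels
L = np.eye(2)                                         # candidate linear transformation

X_embedded = X.dot(L.T)
objective = 0.0
for i in range(X.shape[0]):
    dist = ((X_embedded[i] - X_embedded) ** 2).sum(axis=1)
    dist[i] = np.inf                          # a sample never selects itself
    p_ij = np.exp(-dist - logsumexp(-dist))   # softmax over negative squared distances
    p_i = p_ij[y == y[i]].sum()               # prob. that x_i picks a same-class neighbor
    objective += p_i
# NCA looks for the transformation L that maximizes `objective`
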
@@ -167,7 +166,7 @@ def fit(self, X, y):
         Returns
         -------
         self : object
-            returns a trained NeighborhoodComponentAnalysis model.
+            returns a trained NeighborhoodComponentsAnalysis model.
         """

         # Verify inputs X and y and NCA parameters, and transform a copy if
@@ -182,7 +181,8 @@ def fit(self, X, y):

         # Compute arrays that stay fixed during optimization:
         # mask for fast lookup of same-class samples
-        masks = _make_masks(y_valid)
+        masks = OneHotEncoder(sparse=False,
+                              dtype=bool).fit_transform(y_valid[:, np.newaxis])
         # pairwise differences
         diffs = X_valid[:, np.newaxis] - X_valid[np.newaxis]
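For reference, a quick standalone illustration of the boolean one-hot mask this call produces, assuming (as in fit) that the labels have already been encoded to 0..n_classes-1 by LabelEncoder. Note that in recent scikit-learn releases the ``sparse`` parameter of OneHotEncoder has been renamed to ``sparse_output``.

import numpy as np
from sklearn.preprocessing import OneHotEncoder

y_valid = np.array([0, 1, 0, 2])   # toy encoded labels
masks = OneHotEncoder(sparse=False,
                      dtype=bool).fit_transform(y_valid[:, np.newaxis])
print(masks)
# [[ True False False]
#  [False  True False]
#  [ True False False]
#  [False False  True]]
print(masks[:, y_valid[0]])   # boolean selector of the samples sharing x_0's class
# [ True False  True False]
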
@@ -193,7 +193,7 @@ def fit(self, X, y):
         disp = self.verbose - 2 if self.verbose > 1 else -1
         optimizer_params = {'method': 'L-BFGS-B',
                             'fun': self._loss_grad_lbfgs,
-                            'args': (X_valid, y_valid, diffs, masks),
+                            'args': (X_valid, y_valid, diffs, masks, -1.0),
                             'jac': True,
                             'x0': transformation,
                             'tol': self.tol,
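The extra -1.0 in ``args`` is the ``sign`` forwarded to ``_loss_grad_lbfgs``: ``scipy.optimize.minimize`` only minimizes, so negating the loss and gradient turns the maximization of the NCA objective into a minimization. A minimal standalone sketch of the same trick, using a hypothetical toy objective rather than the NCA loss:

import numpy as np
from scipy.optimize import minimize

def f_and_grad(w, sign=1.0):
    # toy concave objective f(w) = -(w - 3)^2, maximized at w = 3
    loss = -(w[0] - 3.0) ** 2
    grad = np.array([-2.0 * (w[0] - 3.0)])
    return sign * loss, sign * grad

# passing sign=-1.0 through ``args`` makes L-BFGS-B maximize f
result = minimize(f_and_grad, x0=np.zeros(1), args=(-1.0,),
                  jac=True, method='L-BFGS-B')
print(result.x)   # approx. [3.]
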
@@ -401,7 +401,7 @@ def _callback(self, transformation):
         self.n_iter_ += 1

     def _loss_grad_lbfgs(self, transformation, X, y, diffs,
-                         masks):
+                         masks, sign=1.0):
         """Compute the loss and the loss gradient w.r.t. ``transformation``.

         Parameters
@@ -448,23 +448,29 @@ def _loss_grad_lbfgs(self, transformation, X, y, diffs,
         gradient = np.zeros(transformation.shape)
         X_embedded = transformation.dot(X.T).T

-        # for every sample, compute its contribution to loss and gradient
+        # for every sample x_i, compute its contribution to loss and gradient
         for i in range(X.shape[0]):
+            # compute distances to x_i in embedded space
             diff_embedded = X_embedded[i] - X_embedded
-            sum_of_squares = np.einsum('ij,ij->i', diff_embedded,
-                                       diff_embedded)
-            sum_of_squares[i] = np.inf
-            soft = np.exp(-sum_of_squares - logsumexp(-sum_of_squares))
-            ci = masks[:, y[i]]
-            p_i_j = soft[ci]
-            not_ci = np.logical_not(ci)
+            dist_embedded = np.einsum('ij,ij->i', diff_embedded,
+                                      diff_embedded)
+            dist_embedded[i] = np.inf
+
+            # compute exponentiated distances (use the log-sum-exp trick to
+            # avoid numerical instabilities)
+            exp_dist_embedded = np.exp(-dist_embedded -
+                                       logsumexp(-dist_embedded))
+            ci = masks[:, y[i]]  # samples that are in the same class as x_i
+            p_i_j = exp_dist_embedded[ci]
             diff_ci = diffs[i, ci, :]
-            diff_not_ci = diffs[i, not_ci, :]
+            diff_not_ci = diffs[i, ~ci, :]
             sum_ci = diff_ci.T.dot(
                 (p_i_j[:, np.newaxis] * diff_embedded[ci, :]))
-            sum_not_ci = diff_not_ci.T.dot((soft[not_ci][:, np.newaxis] *
-                                            diff_embedded[not_ci, :]))
-            p_i = np.sum(p_i_j)
+            sum_not_ci = diff_not_ci.T.dot(
+                (exp_dist_embedded[~ci][:, np.newaxis] *
+                 diff_embedded[~ci, :]))
+            p_i = np.sum(p_i_j)  # probability of x_i being correctly classified
             gradient += 2 * (p_i * (sum_ci.T + sum_not_ci.T) - sum_ci.T)
             loss += p_i
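The log-sum-exp comment above can be checked in isolation. The short sketch below (not part of the PR) shows the naive softmax underflowing to NaN for large squared distances, while the logsumexp form stays well defined:

import numpy as np
from scipy.special import logsumexp

dist = np.array([800.0, 801.0, 803.0])        # large squared distances
naive = np.exp(-dist) / np.exp(-dist).sum()   # exp underflows to 0, so 0/0 -> nan
stable = np.exp(-dist - logsumexp(-dist))     # same softmax via log-probabilities
print(naive)    # [nan nan nan]  (with a runtime warning)
print(stable)   # approx. [0.71 0.26 0.04], sums to 1
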
@@ -475,7 +481,7 @@ def _loss_grad_lbfgs(self, transformation, X, y, diffs,
                   loss, t_funcall))
             sys.stdout.flush()

-        return -loss, -gradient.ravel()
+        return sign * loss, sign * gradient.ravel()


 ##########################
@@ -538,8 +544,9 @@ def _make_masks(y):
     masks: array, shape (n_samples, n_classes)
         One-hot encoding of ``y``.
     """
-
-    n = y.shape[0]
-    masks = np.zeros((n, y.max() + 1))
-    masks[np.arange(n), y] = [1]
-    return masks.astype(bool)
+    masks = OneHotEncoder(sparse=False,
+                          dtype=bool).fit_transform(y[:, np.newaxis])
+    # n = y.shape[0]
+    # masks = np.zeros((n, y.max() + 1), dtype=bool)
+    # masks[np.arange(n), y] = [True]
+    return masks