Nca temp #1

Open
wants to merge 10 commits into base: nca
2 changes: 1 addition & 1 deletion doc/modules/classes.rst
@@ -1113,7 +1113,7 @@ Model validation
neighbors.RadiusNeighborsRegressor
neighbors.NearestCentroid
neighbors.NearestNeighbors
neighbors.NeighborhoodComponentAnalysis
neighbors.NeighborhoodComponentsAnalysis

.. autosummary::
:toctree: generated/
4 changes: 2 additions & 2 deletions sklearn/neighbors/__init__.py
@@ -14,7 +14,7 @@
from .kde import KernelDensity
from .approximate import LSHForest
from .lof import LocalOutlierFactor
from .nca import NeighborhoodComponentAnalysis
from .nca import NeighborhoodComponentsAnalysis

__all__ = ['BallTree',
'DistanceMetric',
@@ -30,4 +30,4 @@
'KernelDensity',
'LSHForest',
'LocalOutlierFactor',
'NeighborhoodComponentAnalysis']
'NeighborhoodComponentsAnalysis']
119 changes: 63 additions & 56 deletions sklearn/neighbors/nca.py
@@ -12,6 +12,7 @@
import time
from scipy.misc import logsumexp
from scipy.optimize import minimize
from sklearn.preprocessing import OneHotEncoder

from ..base import BaseEstimator, TransformerMixin
from ..preprocessing import LabelEncoder
@@ -22,12 +23,12 @@
from ..externals.six import integer_types


class NeighborhoodComponentAnalysis(BaseEstimator, TransformerMixin):
"""Neighborhood Component Analysis
class NeighborhoodComponentsAnalysis(BaseEstimator, TransformerMixin):
"""Neighborhood Components Analysis

Parameters
----------
n_features_out: int, optional (default=None)
n_features_out : int, optional (default=None)
Preferred dimensionality of the embedding.

init : string or numpy array, optional (default='pca')
@@ -87,27 +88,27 @@ class NeighborhoodComponentAnalysis(BaseEstimator, TransformerMixin):
Attributes
----------
transformation_ : array, shape (n_features_out, n_features)
The linear transformation learned during fitting.
The linear transformation learned during fitting.

n_iter_ : int
Counts the number of iterations performed by the optimizer.
Counts the number of iterations performed by the optimizer.

opt_result_ : scipy.optimize.OptimizeResult (optional)
A dictionary of information representing the optimization result.
This is stored only if ``store_opt_result`` was True.

Examples
--------
>>> from sklearn.neighbors.nca import NeighborhoodComponentAnalysis
>>> from sklearn.neighbors.nca import NeighborhoodComponentsAnalysis
>>> from sklearn.neighbors import KNeighborsClassifier
>>> from sklearn.datasets import load_iris
>>> from sklearn.model_selection import train_test_split
>>> X, y = load_iris(return_X_y=True)
>>> X_train, X_test, y_train, y_test = train_test_split(X, y,
... stratify=y, test_size=0.7, random_state=42)
>>> nca = NeighborhoodComponentAnalysis(None,random_state=42)
>>> nca = NeighborhoodComponentsAnalysis(random_state=42)
>>> nca.fit(X_train, y_train) # doctest: +ELLIPSIS
NeighborhoodComponentAnalysis(...)
NeighborhoodComponentsAnalysis(...)
>>> knn = KNeighborsClassifier(n_neighbors=3)
>>> knn.fit(X_train, y_train) # doctest: +ELLIPSIS
KNeighborsClassifier(...)
@@ -121,25 +122,23 @@ class NeighborhoodComponentAnalysis(BaseEstimator, TransformerMixin):
Notes
-----
Neighborhood Component Analysis (NCA) is a machine learning algorithm for
metric learning. It learns a linear transformation of the space in a
supervised fashion to improve the classification accuracy of a
stochastic nearest neighbors rule in this new space.

.. warning::

As NCA is optimizing a non-convex objective function, it will
likely end up in a local optimum. Several runs with independent random
init might be necessary to get a good convergence.
metric learning. It learns a linear transformation in a supervised fashion
to improve the classification accuracy of a stochastic nearest neighbors
rule in the transformed space.
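
For reference, the objective behind this stochastic rule, following Goldberger et al. [1] and the loss implemented further down in this diff (the notation below is an editorial summary, not text from the PR): with the learned linear map A (the ``transformation_`` attribute), sample x_i picks neighbour j with a softmax probability over negative squared distances in the embedded space, and fitting maximizes the expected number of correctly classified samples:

p_{ij} = \frac{\exp(-\lVert A x_i - A x_j \rVert^2)}{\sum_{k \neq i} \exp(-\lVert A x_i - A x_k \rVert^2)}, \qquad p_{ii} = 0,

p_i = \sum_{j \in C_i} p_{ij}, \qquad \max_A \; f(A) = \sum_i p_i,

where C_i is the set of samples that share the class of x_i.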

References
----------
.. [1] J. Goldberger, G. Hinton, S. Roweis, R. Salakhutdinov.
"Neighbourhood Components Analysis". Advances in Neural Information
Processing Systems. 17, 513-520, 2005.
http://www.cs.nyu.edu/~roweis/papers/ncanips.pdf

.. [2] Wikipedia entry on Neighborhood Components Analysis
https://en.wikipedia.org/wiki/Neighbourhood_components_analysis

"""

def __init__(self, n_features_out=None, init='identity', max_iter=50,
def __init__(self, n_features_out=None, init='pca', max_iter=50,
tol=1e-5, callback=None, store_opt_result=False, verbose=0,
random_state=None):

@@ -167,7 +166,7 @@ def fit(self, X, y):
Returns
-------
self : object
returns a trained NeighborhoodComponentAnalysis model.
returns a trained NeighborhoodComponentsAnalysis model.
"""

# Verify inputs X and y and NCA parameters, and transform a copy if
@@ -182,7 +181,8 @@ def fit(self, X, y):

# Compute arrays that stay fixed during optimization:
# mask for fast lookup of same-class samples
masks = _make_masks(y_valid)
masks = OneHotEncoder(sparse=False,
dtype=bool).fit_transform(y_valid[:, np.newaxis])
# pairwise differences
diffs = X_valid[:, np.newaxis] - X_valid[np.newaxis]
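
As an editorial aside (not part of the PR), the OneHotEncoder call above builds the same boolean class-membership mask that the removed _make_masks helper at the end of this diff used to build. A minimal NumPy sketch with toy labels, assuming y_valid has already been mapped to 0..n_classes-1 by LabelEncoder:

import numpy as np

y = np.array([0, 1, 0, 2, 1])                       # toy LabelEncoder-style labels
masks = np.zeros((y.shape[0], y.max() + 1), dtype=bool)
masks[np.arange(y.shape[0]), y] = True              # masks[i, c] is True iff y[i] == c

# masks[:, y[i]] selects every sample in the same class as x_i, which is how
# the loss below gathers its same-class terms (the ``ci`` indexing).
print(masks[:, y[0]])                               # [ True False  True False False]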

@@ -193,7 +193,7 @@
disp = self.verbose - 2 if self.verbose > 1 else -1
optimizer_params = {'method': 'L-BFGS-B',
'fun': self._loss_grad_lbfgs,
'args': (X_valid, y_valid, diffs, masks),
'args': (X_valid, y_valid, diffs, masks, -1.0),
'jac': True,
'x0': transformation,
'tol': self.tol,
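
The trailing -1.0 in ``args`` above is the ``sign`` argument added to ``_loss_grad_lbfgs`` below: L-BFGS-B only minimizes, while NCA maximizes the sum of the p_i, so fit negates both the loss and its gradient. A small self-contained sketch of the same pattern on a toy objective (names are illustrative, not from the PR):

import numpy as np
from scipy.optimize import minimize


def loss_grad(w, sign=1.0):
    """Toy objective f(w) = -(w - 3)^2 and its gradient, both scaled by sign."""
    loss = -(w[0] - 3.0) ** 2
    grad = np.array([-2.0 * (w[0] - 3.0)])
    return sign * loss, sign * grad


# sign=-1.0 turns maximizing f into minimizing -f; jac=True tells scipy that
# the callable returns the pair (loss, gradient).
res = minimize(loss_grad, x0=np.array([0.0]), args=(-1.0,),
               method='L-BFGS-B', jac=True)
print(res.x)  # approximately [3.], the maximizer of f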
@@ -401,7 +401,7 @@ def _callback(self, transformation):
self.n_iter_ += 1

def _loss_grad_lbfgs(self, transformation, X, y, diffs,
masks):
masks, sign=1.0):
"""Compute the loss and the loss gradient w.r.t. ``transformation``.

Parameters
@@ -430,31 +430,58 @@ def _loss_grad_lbfgs(self, transformation, X, y, diffs,
The new (flattened) gradient of the loss.
"""

if self.n_iter_ == 0:
self.n_iter_ += 1
if self.verbose:
header_fields = ['Iteration', 'Objective Value', 'Time(s)']
header_fmt = '{:>10} {:>20} {:>10}'
header = header_fmt.format(*header_fields)
cls_name = self.__class__.__name__
print('[{}]'.format(cls_name))
print('[{}] {}\n[{}] {}'.format(cls_name, header,
cls_name, '-' * len(header)))

t_funcall = time.time()

transformation = transformation.reshape(-1, X.shape[1])
loss = 0
gradient = np.zeros(transformation.shape)
X_embedded = transformation.dot(X.T).T

# for every sample, compute its contribution to loss and gradient
# for every sample x_i, compute its contribution to loss and gradient
for i in range(X.shape[0]):
# compute squared distances to x_i in embedded space
diff_embedded = X_embedded[i] - X_embedded
sum_of_squares = np.einsum('ij,ij->i', diff_embedded,
diff_embedded)
sum_of_squares[i] = np.inf
soft = np.exp(-sum_of_squares - logsumexp(-sum_of_squares))
ci = masks[:, y[i]]
p_i_j = soft[ci]
not_ci = np.logical_not(ci)
diff_ci = diffs[i, ci, :] # n_samples * n_features
diff_not_ci = diffs[i, not_ci, :]
dist_embedded = np.einsum('ij,ij->i', diff_embedded,
diff_embedded)
dist_embedded[i] = np.inf

# compute exponentiated distances (use the log-sum-exp trick to
# avoid numerical instabilities)
exp_dist_embedded = np.exp(-dist_embedded -
logsumexp(-dist_embedded))
ci = masks[:, y[i]] # samples that are in the same class as x_i
p_i_j = exp_dist_embedded[ci]
diff_ci = diffs[i, ci, :]
diff_not_ci = diffs[i, ~ci, :]
sum_ci = diff_ci.T.dot(
(p_i_j[:, np.newaxis] * diff_embedded[ci, :]))
sum_not_ci = diff_not_ci.T.dot((soft[not_ci][:, np.newaxis] *
diff_embedded[not_ci, :]))
p_i = np.sum(p_i_j)
sum_not_ci = diff_not_ci.T.dot((exp_dist_embedded[~ci][:,
np.newaxis] *
diff_embedded[~ci, :]))
p_i = np.sum(p_i_j)  # probability of x_i being correctly classified
gradient += 2 * (p_i * (sum_ci.T + sum_not_ci.T) - sum_ci.T)
loss += p_i
return - loss, - gradient.ravel()

if self.verbose:
t_funcall = time.time() - t_funcall
values_fmt = '[{}] {:>10} {:>20.6e} {:>10.2f}'
print(values_fmt.format(self.__class__.__name__, self.n_iter_,
loss, t_funcall))
sys.stdout.flush()

return sign * loss, sign * gradient.ravel()
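
The accumulation above is an implementation of the gradient of f(A) = \sum_i p_i from Goldberger et al. [1]; in this editorial transcription (not text from the PR), A is the (n_features_out, n_features) transformation and x_{ik} = x_i - x_k:

\frac{\partial f}{\partial A} = 2 A \sum_i \Bigl( p_i \sum_{k} p_{ik} \, x_{ik} x_{ik}^{\top} - \sum_{j \in C_i} p_{ij} \, x_{ij} x_{ij}^{\top} \Bigr)

Here ``sum_ci`` and ``sum_not_ci`` split the inner sum over k into its same-class and other-class parts (with A already applied through ``diff_embedded``), and the result is negated through ``sign`` before being handed to the minimizer, as described above.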


##########################
@@ -502,23 +529,3 @@ def _check_scalar(x, name, target_type, min_val=None, max_val=None):

if max_val is not None and x > max_val:
raise ValueError('`{}`= {}, must be <= {}.'.format(name, x, max_val))


def _make_masks(y):
"""Create one-hot encoding of vector ``y``.

Parameters
----------
y : array, shape (n_samples,)
Data samples labels.

Returns
-------
masks: array, shape (n_samples, n_classes)
One-hot encoding of ``y``.
"""

n = y.shape[0]
masks = np.zeros((n, y.max() + 1))
masks[np.arange(n), y] = [1]
return masks.astype(bool)