diff --git a/src/kaczmarz/__init__.py b/src/kaczmarz/__init__.py index f42fbfb..3ea4938 100644 --- a/src/kaczmarz/__init__.py +++ b/src/kaczmarz/__init__.py @@ -13,14 +13,19 @@ from ._abc import Base from ._variants import ( # OrthogonalMaxDistance, + BiasedSubsampledMaxDistance, Cyclic, + GramianSelectableSet, MaxDistance, MaxDistanceLookahead, + Nonrepetitive, ParallelOrthoUpdate, Quantile, Random, - RandomOrthoGraph, + RelaxedGreedy, SampledQuantile, + SubsampledMaxDistance, + SubsampledPlusNeighborsMaxDistance, SVRandom, UniformRandom, WindowedQuantile, diff --git a/src/kaczmarz/_normalize.py b/src/kaczmarz/_normalize.py index b895ff9..2b36847 100644 --- a/src/kaczmarz/_normalize.py +++ b/src/kaczmarz/_normalize.py @@ -59,6 +59,7 @@ def normalize_system(A, b): A = np.array(A) row_norms = compute_row_norms(A) + row_norms[row_norms==0] = 1 A = normalize_matrix(A, row_norms=row_norms) b = np.array(b).ravel() / row_norms diff --git a/src/kaczmarz/_variants.py b/src/kaczmarz/_variants.py index 48270de..b82f40d 100644 --- a/src/kaczmarz/_variants.py +++ b/src/kaczmarz/_variants.py @@ -95,6 +95,8 @@ class Random(kaczmarz.Base): def __init__(self, *base_args, p=None, **base_kwargs): super().__init__(*base_args, **base_kwargs) + if p is None: + p = np.ones(self._n_rows) / self._n_rows self._p = p # p=None corresponds to uniform. def _select_row_index(self, xk): @@ -223,8 +225,56 @@ def _distance(self, xk, ik): def _threshold_distances(self, xk): return self._window +class Nonrepetitive(Random): + """Do not sample the most recently projected row.""" + def _select_row_index(self, xk): + i = super()._select_row_index(xk) + + # This loops infinitely if there is only one row. + while i == self._ik: + i = super()._select_row_index(xk) + + return i + +class RelaxedGreedy(kaczmarz.Base): + """Only sample equations that lead to a sufficiently large update. + + Parameters + ---------- + theta : float, optional + Parameter in the range [0,1] + + References + ---------- + 1. Zhong-Zhi Bai, Wen-Ting Wu, + On relaxed greedy randomized Kaczmarz methods for solving large sparse linear systems, + Applied Mathematics Letters, Volume 83, 2018, Pages 21-26, + """ + def __init__(self, *args, theta=0.5, **kwargs): + super().__init__(*args, **kwargs) + if theta < 0 or theta > 1: + raise Exception("Theta value outside parameter range [0, 1]") + self._theta = theta + self._row_norms_sq = self._row_norms **2 + self._fro_sq = np.sum(self._row_norms_sq) + + # Bai and Wu's algorithm + def _select_row_index(self, xk): + residual_sq = (self._b - self._A @ xk) ** 2 + residual_unnormalized_sq = self._row_norms_sq * residual_sq + res_norm_sq = residual_unnormalized_sq.sum() + epsilon = self._theta / res_norm_sq * residual_sq + (1 - self._theta) / self._fro_sq + + index_bool = (residual_unnormalized_sq >= epsilon * res_norm_sq * (self._row_norms_sq)) + if ~np.any(index_bool): + raise Exception("Index set empty") -class RandomOrthoGraph(kaczmarz.Base): + prob = residual_unnormalized_sq + prob[~index_bool] = 0 + prob /= prob.sum() + return np.random.choice(self._n_rows, p=prob) + +class GramianSelectableSet(kaczmarz.Base): """Try to only sample equations which are not already satisfied. Use the orthogonality graph defined in [1] to decide which rows should @@ -277,7 +327,7 @@ def selectable(self): return self._selectable.copy() -class ParallelOrthoUpdate(RandomOrthoGraph): +class ParallelOrthoUpdate(GramianSelectableSet): """Perform multiple updates in parallel, using only rows which are mutually orthogonal Parameters @@ -320,3 +370,77 @@ def _select_row_index(self, xk): self._update_selectable(i) return tauk + + +class SubsampledMaxDistance(Random): + """Choose the best row amongst a random subset at each iteration. + + Parameters + ---------- + n_samples : int + Numbers of rows to sample at each iteration. + """ + + def __init__(self, *base_args, n_samples=1, **base_kwargs): + super().__init__(*base_args, **base_kwargs) + self._n_samples = n_samples + self._row_idxs = [] + + def _get_samples(self): + return np.random.choice(self._n_rows, self._n_samples, p=self._p) + + def _select_row_index(self, xk): + row_idxs = self._get_samples() + self._row_idxs = row_idxs + + residual = self._b[row_idxs] - self._A[row_idxs] @ xk + return row_idxs[np.argmax(np.abs(residual))] + + +class SubsampledPlusNeighborsMaxDistance(SubsampledMaxDistance): + """Add neighbors of the most recent row to the subsample for BiasedSubsampledMaxDistance.""" + + def __init__(self, *base_args, **base_kwargs): + super().__init__(*base_args, **base_kwargs) + + self._gramian = self._A @ self._A.T + + # Map each row index i to indexes of rows that are NOT orthogonal to it. + self._i_to_neighbors = {} + for i in range(self._n_rows): + self._i_to_neighbors[i] = self._gramian[[i], :].nonzero()[1] + + def _get_samples(self): + samples = super()._get_samples() + if self.ik != -1: + neighbors = self._i_to_neighbors[self.ik] + # Add the neighbors to the sample. + samples = np.union1d(samples, neighbors) + + return samples + + +class BiasedSubsampledMaxDistance(SubsampledPlusNeighborsMaxDistance): + """Bias the subset for SubsampledMaxDistance toward neighbors of the most recent row. + + Parameters + ---------- + bias : float + Neighbors of the most recently used row will be `bias` times extra likely to be part of the subsample. + """ + + def __init__(self, *base_args, bias=1, **base_kwargs): + super().__init__(*base_args, **base_kwargs) + + self._bias = bias + + def _get_samples(self): + if self.ik == -1: + p = self._p + else: + neighbor_idxs = self._i_to_neighbors[self.ik] + p = self._p.copy() + p[neighbor_idxs] *= self._bias + p /= p.sum() + + return np.random.choice(self._n_rows, self._n_samples, p=p) diff --git a/tests/test_ortho_graph.py b/tests/test_ortho_graph.py index f9374b6..f03709c 100644 --- a/tests/test_ortho_graph.py +++ b/tests/test_ortho_graph.py @@ -7,7 +7,7 @@ def test_selectable_set(eye33, ones3): x0 = np.zeros(3) - solver = kaczmarz.RandomOrthoGraph(eye33, ones3, x0) + solver = kaczmarz.GramianSelectableSet(eye33, ones3, x0) # Length is 3 for two iterations because first iteration yields x0 without performing an update. assert 3 == sum(solver.selectable)