Skip to content

Commit d7054de

Browse files
committed
Make ThresholdType an enum
1 parent 2f7b524 commit d7054de

File tree

6 files changed

+95
-35
lines changed

6 files changed

+95
-35
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ build/*
44
dist/*
55
spectralcluster.egg-info/*
66
.coverage
7+
.DS_Store

README.md

+4-3
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,14 @@ You can specify your refinement operations like this:
9191

9292
```
9393
from spectralcluster import RefinementOptions
94+
from spectralcluster import ThresholdType
9495
from spectralcluster import ICASSP2018_REFINEMENT_SEQUENCE
9596
9697
refinement_options = RefinementOptions(
9798
gaussian_blur_sigma=1,
9899
p_percentile=0.95,
99100
thresholding_soft_multiplier=0.01,
100-
thresholding_with_row_max=True,
101+
thresholding_type=ThresholdType.RowMax,
101102
refinement_sequence=ICASSP2018_REFINEMENT_SEQUENCE)
102103
```
103104

@@ -116,8 +117,8 @@ In the new version of this library, we support different types of Laplacian matr
116117

117118
* None Laplacian (affinity matrix): `W`
118119
* Unnormalized Laplacian: `L = D - W`
119-
* Graph cut Laplacian: `L' = D^{-1/2} L D^{-1/2}`
120-
* Random walk Laplacian: `L' = D^{-1} L`
120+
* Graph cut Laplacian: `L' = D^{-1/2} * L * D^{-1/2}`
121+
* Random walk Laplacian: `L' = D^{-1} * L`
121122

122123
You can specify the Laplacian matrix type with the `laplacian_type` argument of the `SpectralClusterer` class.
123124

docs/configs.html

+3-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ <h1 class="title">Module <code>spectralcluster.configs</code></h1>
3434

3535
RefinementName = refinement.RefinementName
3636
RefinementOptions = refinement.RefinementOptions
37+
ThresholdType = refinement.ThresholdType
38+
SymmetrizeType = refinement.SymmetrizeType
3739
SpectralClusterer = spectral_clusterer.SpectralClusterer
3840

3941

@@ -52,7 +54,7 @@ <h1 class="title">Module <code>spectralcluster.configs</code></h1>
5254
gaussian_blur_sigma=1,
5355
p_percentile=0.95,
5456
thresholding_soft_multiplier=0.01,
55-
thresholding_with_row_max=True,
57+
thresholding_type=ThresholdType.RowMax,
5658
refinement_sequence=ICASSP2018_REFINEMENT_SEQUENCE)
5759

5860
icassp2018_clusterer = SpectralClusterer(

docs/index.html

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ <h1 class="title">Package <code>spectralcluster</code></h1>
4747

4848
RefinementName = refinement.RefinementName
4949
RefinementOptions = refinement.RefinementOptions
50+
ThresholdType = refinement.ThresholdType
5051
SymmetrizeType = refinement.SymmetrizeType
5152

5253
SpectralClusterer = spectral_clusterer.SpectralClusterer

docs/refinement.html

+85-30
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,15 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
4545
RowWiseNormalize = 6
4646

4747

48+
class ThresholdType(enum.Enum):
49+
&#34;&#34;&#34;Different types of thresholding.&#34;&#34;&#34;
50+
# We clear values that are smaller than row_max*p_percentile
51+
RowMax = 1
52+
53+
# We clear the (p_percentile*100)% smallest values of the entire row
54+
Percentile = 2
55+
56+
4857
class SymmetrizeType(enum.Enum):
4958
&#34;&#34;&#34;Different types of symmetrization operation.&#34;&#34;&#34;
5059
# We use max(A, A^T)
@@ -61,7 +70,7 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
6170
gaussian_blur_sigma=1,
6271
p_percentile=0.95,
6372
thresholding_soft_multiplier=0.01,
64-
thresholding_with_row_max=True,
73+
thresholding_type=ThresholdType.RowMax,
6574
thresholding_with_binarization=False,
6675
thresholding_preserve_diagonal=False,
6776
symmetrize_type=SymmetrizeType.Max,
@@ -73,8 +82,7 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
7382
p_percentile: the p-percentile for the row wise thresholding
7483
thresholding_soft_multiplier: the multiplier for soft threshold, if this
7584
value is 0, then it&#39;s a hard thresholding
76-
thresholding_with_row_max: if true, we use row_max * p_percentile as row
77-
wise threshold, instead of doing a percentile-based thresholding
85+
thresholding_type: the type of thresholding operation
7886
thresholding_with_binarization: if true, we set values larger than the
7987
threshold to 1
8088
thresholding_preserve_diagonal: if true, in the row wise thresholding
@@ -88,7 +96,7 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
8896
self.gaussian_blur_sigma = gaussian_blur_sigma
8997
self.p_percentile = p_percentile
9098
self.thresholding_soft_multiplier = thresholding_soft_multiplier
91-
self.thresholding_with_row_max = thresholding_with_row_max
99+
self.thresholding_type = thresholding_type
92100
self.thresholding_with_binarization = thresholding_with_binarization
93101
self.thresholding_preserve_diagonal = thresholding_preserve_diagonal
94102
self.symmetrize_type = symmetrize_type
@@ -121,7 +129,7 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
121129
elif name == RefinementName.RowWiseThreshold:
122130
return RowWiseThreshold(self.p_percentile,
123131
self.thresholding_soft_multiplier,
124-
self.thresholding_with_row_max,
132+
self.thresholding_type,
125133
self.thresholding_with_binarization,
126134
self.thresholding_preserve_diagonal)
127135
elif name == RefinementName.Symmetrize:
@@ -203,12 +211,14 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
203211
def __init__(self,
204212
p_percentile=0.95,
205213
thresholding_soft_multiplier=0.01,
206-
thresholding_with_row_max=False,
214+
thresholding_type=ThresholdType.RowMax,
207215
thresholding_with_binarization=False,
208216
thresholding_preserve_diagonal=False):
209217
self.p_percentile = p_percentile
210218
self.multiplier = thresholding_soft_multiplier
211-
self.thresholding_with_row_max = thresholding_with_row_max
219+
if not isinstance(thresholding_type, ThresholdType):
220+
raise TypeError(&#34;thresholding_type must be a ThresholdType&#34;)
221+
self.thresholding_type = thresholding_type
212222
self.thresholding_with_binarization = thresholding_with_binarization
213223
self.thresholding_preserve_diagonal = thresholding_preserve_diagonal
214224

@@ -217,17 +227,19 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
217227
refined_affinity = np.copy(affinity)
218228
if self.thresholding_preserve_diagonal:
219229
np.fill_diagonal(refined_affinity, 0.0)
220-
if self.thresholding_with_row_max:
230+
if self.thresholding_type == ThresholdType.RowMax:
221231
# Row_max based thresholding
222232
row_max = refined_affinity.max(axis=1)
223233
row_max = np.expand_dims(row_max, axis=1)
224234
is_smaller = refined_affinity &lt; (row_max * self.p_percentile)
225-
else:
235+
elif self.thresholding_type == ThresholdType.Percentile:
226236
# Percentile based thresholding
227237
row_percentile = np.percentile(
228238
refined_affinity, self.p_percentile * 100, axis=1)
229239
row_percentile = np.expand_dims(row_percentile, axis=1)
230240
is_smaller = refined_affinity &lt; row_percentile
241+
else:
242+
raise ValueError(&#34;Unsupported thresholding_type&#34;)
231243
if self.thresholding_with_binarization:
232244
# For values larger than the threshold, we binarize them to 1
233245
refined_affinity = (np.ones_like(
@@ -245,13 +257,13 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
245257
&#34;&#34;&#34;The Symmetrization operation.&#34;&#34;&#34;
246258

247259
def __init__(self, symmetrize_type=SymmetrizeType.Max):
260+
if not isinstance(symmetrize_type, SymmetrizeType):
261+
raise TypeError(&#34;symmetrize_type must be a SymmetrizeType&#34;)
248262
self.symmetrize_type = symmetrize_type
249263

250264
def refine(self, affinity):
251265
self.check_input(affinity)
252-
if not isinstance(self.symmetrize_type, SymmetrizeType):
253-
raise TypeError(&#34;symmetrize_type must be a SymmetrizeType&#34;)
254-
elif self.symmetrize_type == SymmetrizeType.Max:
266+
if self.symmetrize_type == SymmetrizeType.Max:
255267
return np.maximum(affinity, np.transpose(affinity))
256268
elif self.symmetrize_type == SymmetrizeType.Average:
257269
return 0.5 * (affinity + np.transpose(affinity))
@@ -572,7 +584,7 @@ <h3>Class variables</h3>
572584
</dd>
573585
<dt id="spectralcluster.refinement.RefinementOptions"><code class="flex name class">
574586
<span>class <span class="ident">RefinementOptions</span></span>
575-
<span>(</span><span>gaussian_blur_sigma=1, p_percentile=0.95, thresholding_soft_multiplier=0.01, thresholding_with_row_max=True, thresholding_with_binarization=False, thresholding_preserve_diagonal=False, symmetrize_type=SymmetrizeType.Max, refinement_sequence=None)</span>
587+
<span>(</span><span>gaussian_blur_sigma=1, p_percentile=0.95, thresholding_soft_multiplier=0.01, thresholding_type=ThresholdType.RowMax, thresholding_with_binarization=False, thresholding_preserve_diagonal=False, symmetrize_type=SymmetrizeType.Max, refinement_sequence=None)</span>
576588
</code></dt>
577589
<dd>
578590
<div class="desc"><p>Refinement options for the affinity matrix.</p>
@@ -586,9 +598,8 @@ <h2 id="args">Args</h2>
586598
<dt><strong><code>thresholding_soft_multiplier</code></strong></dt>
587599
<dd>the multiplier for soft threshold, if this
588600
value is 0, then it's a hard thresholding</dd>
589-
<dt><strong><code>thresholding_with_row_max</code></strong></dt>
590-
<dd>if true, we use row_max * p_percentile as row
591-
wise threshold, instead of doing a percentile-based thresholding</dd>
601+
<dt><strong><code>thresholding_type</code></strong></dt>
602+
<dd>the type of thresholding operation</dd>
592603
<dt><strong><code>thresholding_with_binarization</code></strong></dt>
593604
<dd>if true, we set values larger than the
594605
threshold to 1</dd>
@@ -614,7 +625,7 @@ <h2 id="args">Args</h2>
614625
gaussian_blur_sigma=1,
615626
p_percentile=0.95,
616627
thresholding_soft_multiplier=0.01,
617-
thresholding_with_row_max=True,
628+
thresholding_type=ThresholdType.RowMax,
618629
thresholding_with_binarization=False,
619630
thresholding_preserve_diagonal=False,
620631
symmetrize_type=SymmetrizeType.Max,
@@ -626,8 +637,7 @@ <h2 id="args">Args</h2>
626637
p_percentile: the p-percentile for the row wise thresholding
627638
thresholding_soft_multiplier: the multiplier for soft threshold, if this
628639
value is 0, then it&#39;s a hard thresholding
629-
thresholding_with_row_max: if true, we use row_max * p_percentile as row
630-
wise threshold, instead of doing a percentile-based thresholding
640+
thresholding_type: the type of thresholding operation
631641
thresholding_with_binarization: if true, we set values larger than the
632642
threshold to 1
633643
thresholding_preserve_diagonal: if true, in the row wise thresholding
@@ -641,7 +651,7 @@ <h2 id="args">Args</h2>
641651
self.gaussian_blur_sigma = gaussian_blur_sigma
642652
self.p_percentile = p_percentile
643653
self.thresholding_soft_multiplier = thresholding_soft_multiplier
644-
self.thresholding_with_row_max = thresholding_with_row_max
654+
self.thresholding_type = thresholding_type
645655
self.thresholding_with_binarization = thresholding_with_binarization
646656
self.thresholding_preserve_diagonal = thresholding_preserve_diagonal
647657
self.symmetrize_type = symmetrize_type
@@ -674,7 +684,7 @@ <h2 id="args">Args</h2>
674684
elif name == RefinementName.RowWiseThreshold:
675685
return RowWiseThreshold(self.p_percentile,
676686
self.thresholding_soft_multiplier,
677-
self.thresholding_with_row_max,
687+
self.thresholding_type,
678688
self.thresholding_with_binarization,
679689
self.thresholding_preserve_diagonal)
680690
elif name == RefinementName.Symmetrize:
@@ -733,7 +743,7 @@ <h2 id="raises">Raises</h2>
733743
elif name == RefinementName.RowWiseThreshold:
734744
return RowWiseThreshold(self.p_percentile,
735745
self.thresholding_soft_multiplier,
736-
self.thresholding_with_row_max,
746+
self.thresholding_type,
737747
self.thresholding_with_binarization,
738748
self.thresholding_preserve_diagonal)
739749
elif name == RefinementName.Symmetrize:
@@ -783,7 +793,7 @@ <h3>Inherited members</h3>
783793
</dd>
784794
<dt id="spectralcluster.refinement.RowWiseThreshold"><code class="flex name class">
785795
<span>class <span class="ident">RowWiseThreshold</span></span>
786-
<span>(</span><span>p_percentile=0.95, thresholding_soft_multiplier=0.01, thresholding_with_row_max=False, thresholding_with_binarization=False, thresholding_preserve_diagonal=False)</span>
796+
<span>(</span><span>p_percentile=0.95, thresholding_soft_multiplier=0.01, thresholding_type=ThresholdType.RowMax, thresholding_with_binarization=False, thresholding_preserve_diagonal=False)</span>
787797
</code></dt>
788798
<dd>
789799
<div class="desc"><p>Apply row wise thresholding.</p></div>
@@ -797,12 +807,14 @@ <h3>Inherited members</h3>
797807
def __init__(self,
798808
p_percentile=0.95,
799809
thresholding_soft_multiplier=0.01,
800-
thresholding_with_row_max=False,
810+
thresholding_type=ThresholdType.RowMax,
801811
thresholding_with_binarization=False,
802812
thresholding_preserve_diagonal=False):
803813
self.p_percentile = p_percentile
804814
self.multiplier = thresholding_soft_multiplier
805-
self.thresholding_with_row_max = thresholding_with_row_max
815+
if not isinstance(thresholding_type, ThresholdType):
816+
raise TypeError(&#34;thresholding_type must be a ThresholdType&#34;)
817+
self.thresholding_type = thresholding_type
806818
self.thresholding_with_binarization = thresholding_with_binarization
807819
self.thresholding_preserve_diagonal = thresholding_preserve_diagonal
808820

@@ -811,17 +823,19 @@ <h3>Inherited members</h3>
811823
refined_affinity = np.copy(affinity)
812824
if self.thresholding_preserve_diagonal:
813825
np.fill_diagonal(refined_affinity, 0.0)
814-
if self.thresholding_with_row_max:
826+
if self.thresholding_type == ThresholdType.RowMax:
815827
# Row_max based thresholding
816828
row_max = refined_affinity.max(axis=1)
817829
row_max = np.expand_dims(row_max, axis=1)
818830
is_smaller = refined_affinity &lt; (row_max * self.p_percentile)
819-
else:
831+
elif self.thresholding_type == ThresholdType.Percentile:
820832
# Percentile based thresholding
821833
row_percentile = np.percentile(
822834
refined_affinity, self.p_percentile * 100, axis=1)
823835
row_percentile = np.expand_dims(row_percentile, axis=1)
824836
is_smaller = refined_affinity &lt; row_percentile
837+
else:
838+
raise ValueError(&#34;Unsupported thresholding_type&#34;)
825839
if self.thresholding_with_binarization:
826840
# For values larger than the threshold, we binarize them to 1
827841
refined_affinity = (np.ones_like(
@@ -862,13 +876,13 @@ <h3>Inherited members</h3>
862876
&#34;&#34;&#34;The Symmetrization operation.&#34;&#34;&#34;
863877

864878
def __init__(self, symmetrize_type=SymmetrizeType.Max):
879+
if not isinstance(symmetrize_type, SymmetrizeType):
880+
raise TypeError(&#34;symmetrize_type must be a SymmetrizeType&#34;)
865881
self.symmetrize_type = symmetrize_type
866882

867883
def refine(self, affinity):
868884
self.check_input(affinity)
869-
if not isinstance(self.symmetrize_type, SymmetrizeType):
870-
raise TypeError(&#34;symmetrize_type must be a SymmetrizeType&#34;)
871-
elif self.symmetrize_type == SymmetrizeType.Max:
885+
if self.symmetrize_type == SymmetrizeType.Max:
872886
return np.maximum(affinity, np.transpose(affinity))
873887
elif self.symmetrize_type == SymmetrizeType.Average:
874888
return 0.5 * (affinity + np.transpose(affinity))
@@ -923,6 +937,40 @@ <h3>Class variables</h3>
923937
</dd>
924938
</dl>
925939
</dd>
940+
<dt id="spectralcluster.refinement.ThresholdType"><code class="flex name class">
941+
<span>class <span class="ident">ThresholdType</span></span>
942+
<span>(</span><span>value, names=None, *, module=None, qualname=None, type=None, start=1)</span>
943+
</code></dt>
944+
<dd>
945+
<div class="desc"><p>Different types of thresholding.</p></div>
946+
<details class="source">
947+
<summary>
948+
<span>Expand source code</span>
949+
</summary>
950+
<pre><code class="python">class ThresholdType(enum.Enum):
951+
&#34;&#34;&#34;Different types of thresholding.&#34;&#34;&#34;
952+
# We clear values that are smaller than row_max*p_percentile
953+
RowMax = 1
954+
955+
# We clear the (p_percentile*100)% smallest values of the entire row
956+
Percentile = 2</code></pre>
957+
</details>
958+
<h3>Ancestors</h3>
959+
<ul class="hlist">
960+
<li>enum.Enum</li>
961+
</ul>
962+
<h3>Class variables</h3>
963+
<dl>
964+
<dt id="spectralcluster.refinement.ThresholdType.Percentile"><code class="name">var <span class="ident">Percentile</span></code></dt>
965+
<dd>
966+
<div class="desc"></div>
967+
</dd>
968+
<dt id="spectralcluster.refinement.ThresholdType.RowMax"><code class="name">var <span class="ident">RowMax</span></code></dt>
969+
<dd>
970+
<div class="desc"></div>
971+
</dd>
972+
</dl>
973+
</dd>
926974
</dl>
927975
</section>
928976
</article>
@@ -988,6 +1036,13 @@ <h4><code><a title="spectralcluster.refinement.SymmetrizeType" href="#spectralcl
9881036
<li><code><a title="spectralcluster.refinement.SymmetrizeType.Max" href="#spectralcluster.refinement.SymmetrizeType.Max">Max</a></code></li>
9891037
</ul>
9901038
</li>
1039+
<li>
1040+
<h4><code><a title="spectralcluster.refinement.ThresholdType" href="#spectralcluster.refinement.ThresholdType">ThresholdType</a></code></h4>
1041+
<ul class="">
1042+
<li><code><a title="spectralcluster.refinement.ThresholdType.Percentile" href="#spectralcluster.refinement.ThresholdType.Percentile">Percentile</a></code></li>
1043+
<li><code><a title="spectralcluster.refinement.ThresholdType.RowMax" href="#spectralcluster.refinement.ThresholdType.RowMax">RowMax</a></code></li>
1044+
</ul>
1045+
</li>
9911046
</ul>
9921047
</li>
9931048
</ul>

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import setuptools
44

5-
VERSION = "0.2.0"
5+
VERSION = "0.2.1"
66

77
with open("README.md", "r") as file_object:
88
LONG_DESCRIPTION = file_object.read()

0 commit comments

Comments
 (0)