@@ -45,6 +45,15 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
45
45
RowWiseNormalize = 6
46
46
47
47
48
+ class ThresholdType(enum.Enum):
49
+ """Different types of thresholding."""
50
+ # We clear values that are smaller than row_max*p_percentile
51
+ RowMax = 1
52
+
53
+ # We clear (p_percentile*100)% smallest values of the entire row
54
+ Percentile = 2
55
+
56
+
48
57
class SymmetrizeType(enum.Enum):
49
58
"""Different types of symmetrization operation."""
50
59
# We use max(A, A^T)
@@ -61,7 +70,7 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
61
70
gaussian_blur_sigma=1,
62
71
p_percentile=0.95,
63
72
thresholding_soft_multiplier=0.01,
64
- thresholding_with_row_max=True ,
73
+ thresholding_type=ThresholdType.RowMax ,
65
74
thresholding_with_binarization=False,
66
75
thresholding_preserve_diagonal=False,
67
76
symmetrize_type=SymmetrizeType.Max,
@@ -73,8 +82,7 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
73
82
p_percentile: the p-percentile for the row wise thresholding
74
83
thresholding_soft_multiplier: the multiplier for soft threhsold, if this
75
84
value is 0, then it's a hard thresholding
76
- thresholding_with_row_max: if true, we use row_max * p_percentile as row
77
- wise threshold, instead of doing a percentile-based thresholding
85
+ thresholding_type: the type of thresholding operation
78
86
thresholding_with_binarization: if true, we set values larger than the
79
87
threshold to 1
80
88
thresholding_preserve_diagonal: if true, in the row wise thresholding
@@ -88,7 +96,7 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
88
96
self.gaussian_blur_sigma = gaussian_blur_sigma
89
97
self.p_percentile = p_percentile
90
98
self.thresholding_soft_multiplier = thresholding_soft_multiplier
91
- self.thresholding_with_row_max = thresholding_with_row_max
99
+ self.thresholding_type = thresholding_type
92
100
self.thresholding_with_binarization = thresholding_with_binarization
93
101
self.thresholding_preserve_diagonal = thresholding_preserve_diagonal
94
102
self.symmetrize_type = symmetrize_type
@@ -121,7 +129,7 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
121
129
elif name == RefinementName.RowWiseThreshold:
122
130
return RowWiseThreshold(self.p_percentile,
123
131
self.thresholding_soft_multiplier,
124
- self.thresholding_with_row_max ,
132
+ self.thresholding_type ,
125
133
self.thresholding_with_binarization,
126
134
self.thresholding_preserve_diagonal)
127
135
elif name == RefinementName.Symmetrize:
@@ -203,12 +211,14 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
203
211
def __init__(self,
204
212
p_percentile=0.95,
205
213
thresholding_soft_multiplier=0.01,
206
- thresholding_with_row_max=False ,
214
+ thresholding_type=ThresholdType.RowMax ,
207
215
thresholding_with_binarization=False,
208
216
thresholding_preserve_diagonal=False):
209
217
self.p_percentile = p_percentile
210
218
self.multiplier = thresholding_soft_multiplier
211
- self.thresholding_with_row_max = thresholding_with_row_max
219
+ if not isinstance(thresholding_type, ThresholdType):
220
+ raise TypeError("thresholding_type must be a ThresholdType")
221
+ self.thresholding_type = thresholding_type
212
222
self.thresholding_with_binarization = thresholding_with_binarization
213
223
self.thresholding_preserve_diagonal = thresholding_preserve_diagonal
214
224
@@ -217,17 +227,19 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
217
227
refined_affinity = np.copy(affinity)
218
228
if self.thresholding_preserve_diagonal:
219
229
np.fill_diagonal(refined_affinity, 0.0)
220
- if self.thresholding_with_row_max :
230
+ if self.thresholding_type == ThresholdType.RowMax :
221
231
# Row_max based thresholding
222
232
row_max = refined_affinity.max(axis=1)
223
233
row_max = np.expand_dims(row_max, axis=1)
224
234
is_smaller = refined_affinity < (row_max * self.p_percentile)
225
- else :
235
+ elif self.thresholding_type == ThresholdType.Percentile :
226
236
# Percentile based thresholding
227
237
row_percentile = np.percentile(
228
238
refined_affinity, self.p_percentile * 100, axis=1)
229
239
row_percentile = np.expand_dims(row_percentile, axis=1)
230
240
is_smaller = refined_affinity < row_percentile
241
+ else:
242
+ raise ValueError("Unsupported thresholding_type")
231
243
if self.thresholding_with_binarization:
232
244
# For values larger than the threshold, we binarize them to 1
233
245
refined_affinity = (np.ones_like(
@@ -245,13 +257,13 @@ <h1 class="title">Module <code>spectralcluster.refinement</code></h1>
245
257
"""The Symmetrization operation."""
246
258
247
259
def __init__(self, symmetrize_type=SymmetrizeType.Max):
260
+ if not isinstance(symmetrize_type, SymmetrizeType):
261
+ raise TypeError("symmetrize_type must be a SymmetrizeType")
248
262
self.symmetrize_type = symmetrize_type
249
263
250
264
def refine(self, affinity):
251
265
self.check_input(affinity)
252
- if not isinstance(self.symmetrize_type, SymmetrizeType):
253
- raise TypeError("symmetrize_type must be a SymmetrizeType")
254
- elif self.symmetrize_type == SymmetrizeType.Max:
266
+ if self.symmetrize_type == SymmetrizeType.Max:
255
267
return np.maximum(affinity, np.transpose(affinity))
256
268
elif self.symmetrize_type == SymmetrizeType.Average:
257
269
return 0.5 * (affinity + np.transpose(affinity))
@@ -572,7 +584,7 @@ <h3>Class variables</h3>
572
584
</ dd >
573
585
< dt id ="spectralcluster.refinement.RefinementOptions "> < code class ="flex name class ">
574
586
< span > class < span class ="ident "> RefinementOptions</ span > </ span >
575
- < span > (</ span > < span > gaussian_blur_sigma=1, p_percentile=0.95, thresholding_soft_multiplier=0.01, thresholding_with_row_max=True , thresholding_with_binarization=False, thresholding_preserve_diagonal=False, symmetrize_type=SymmetrizeType.Max, refinement_sequence=None)</ span >
587
+ < span > (</ span > < span > gaussian_blur_sigma=1, p_percentile=0.95, thresholding_soft_multiplier=0.01, thresholding_type=ThresholdType.RowMax , thresholding_with_binarization=False, thresholding_preserve_diagonal=False, symmetrize_type=SymmetrizeType.Max, refinement_sequence=None)</ span >
576
588
</ code > </ dt >
577
589
< dd >
578
590
< div class ="desc "> < p > Refinement options for the affinity matrix.</ p >
@@ -586,9 +598,8 @@ <h2 id="args">Args</h2>
586
598
< dt > < strong > < code > thresholding_soft_multiplier</ code > </ strong > </ dt >
587
599
< dd > the multiplier for soft threhsold, if this
588
600
value is 0, then it's a hard thresholding</ dd >
589
- < dt > < strong > < code > thresholding_with_row_max</ code > </ strong > </ dt >
590
- < dd > if true, we use row_max * p_percentile as row
591
- wise threshold, instead of doing a percentile-based thresholding</ dd >
601
+ < dt > < strong > < code > thresholding_type</ code > </ strong > </ dt >
602
+ < dd > the type of thresholding operation</ dd >
592
603
< dt > < strong > < code > thresholding_with_binarization</ code > </ strong > </ dt >
593
604
< dd > if true, we set values larger than the
594
605
threshold to 1</ dd >
@@ -614,7 +625,7 @@ <h2 id="args">Args</h2>
614
625
gaussian_blur_sigma=1,
615
626
p_percentile=0.95,
616
627
thresholding_soft_multiplier=0.01,
617
- thresholding_with_row_max=True ,
628
+ thresholding_type=ThresholdType.RowMax ,
618
629
thresholding_with_binarization=False,
619
630
thresholding_preserve_diagonal=False,
620
631
symmetrize_type=SymmetrizeType.Max,
@@ -626,8 +637,7 @@ <h2 id="args">Args</h2>
626
637
p_percentile: the p-percentile for the row wise thresholding
627
638
thresholding_soft_multiplier: the multiplier for soft threhsold, if this
628
639
value is 0, then it's a hard thresholding
629
- thresholding_with_row_max: if true, we use row_max * p_percentile as row
630
- wise threshold, instead of doing a percentile-based thresholding
640
+ thresholding_type: the type of thresholding operation
631
641
thresholding_with_binarization: if true, we set values larger than the
632
642
threshold to 1
633
643
thresholding_preserve_diagonal: if true, in the row wise thresholding
@@ -641,7 +651,7 @@ <h2 id="args">Args</h2>
641
651
self.gaussian_blur_sigma = gaussian_blur_sigma
642
652
self.p_percentile = p_percentile
643
653
self.thresholding_soft_multiplier = thresholding_soft_multiplier
644
- self.thresholding_with_row_max = thresholding_with_row_max
654
+ self.thresholding_type = thresholding_type
645
655
self.thresholding_with_binarization = thresholding_with_binarization
646
656
self.thresholding_preserve_diagonal = thresholding_preserve_diagonal
647
657
self.symmetrize_type = symmetrize_type
@@ -674,7 +684,7 @@ <h2 id="args">Args</h2>
674
684
elif name == RefinementName.RowWiseThreshold:
675
685
return RowWiseThreshold(self.p_percentile,
676
686
self.thresholding_soft_multiplier,
677
- self.thresholding_with_row_max ,
687
+ self.thresholding_type ,
678
688
self.thresholding_with_binarization,
679
689
self.thresholding_preserve_diagonal)
680
690
elif name == RefinementName.Symmetrize:
@@ -733,7 +743,7 @@ <h2 id="raises">Raises</h2>
733
743
elif name == RefinementName.RowWiseThreshold:
734
744
return RowWiseThreshold(self.p_percentile,
735
745
self.thresholding_soft_multiplier,
736
- self.thresholding_with_row_max ,
746
+ self.thresholding_type ,
737
747
self.thresholding_with_binarization,
738
748
self.thresholding_preserve_diagonal)
739
749
elif name == RefinementName.Symmetrize:
@@ -783,7 +793,7 @@ <h3>Inherited members</h3>
783
793
</ dd >
784
794
< dt id ="spectralcluster.refinement.RowWiseThreshold "> < code class ="flex name class ">
785
795
< span > class < span class ="ident "> RowWiseThreshold</ span > </ span >
786
- < span > (</ span > < span > p_percentile=0.95, thresholding_soft_multiplier=0.01, thresholding_with_row_max=False , thresholding_with_binarization=False, thresholding_preserve_diagonal=False)</ span >
796
+ < span > (</ span > < span > p_percentile=0.95, thresholding_soft_multiplier=0.01, thresholding_type=ThresholdType.RowMax , thresholding_with_binarization=False, thresholding_preserve_diagonal=False)</ span >
787
797
</ code > </ dt >
788
798
< dd >
789
799
< div class ="desc "> < p > Apply row wise thresholding.</ p > </ div >
@@ -797,12 +807,14 @@ <h3>Inherited members</h3>
797
807
def __init__(self,
798
808
p_percentile=0.95,
799
809
thresholding_soft_multiplier=0.01,
800
- thresholding_with_row_max=False ,
810
+ thresholding_type=ThresholdType.RowMax ,
801
811
thresholding_with_binarization=False,
802
812
thresholding_preserve_diagonal=False):
803
813
self.p_percentile = p_percentile
804
814
self.multiplier = thresholding_soft_multiplier
805
- self.thresholding_with_row_max = thresholding_with_row_max
815
+ if not isinstance(thresholding_type, ThresholdType):
816
+ raise TypeError("thresholding_type must be a ThresholdType")
817
+ self.thresholding_type = thresholding_type
806
818
self.thresholding_with_binarization = thresholding_with_binarization
807
819
self.thresholding_preserve_diagonal = thresholding_preserve_diagonal
808
820
@@ -811,17 +823,19 @@ <h3>Inherited members</h3>
811
823
refined_affinity = np.copy(affinity)
812
824
if self.thresholding_preserve_diagonal:
813
825
np.fill_diagonal(refined_affinity, 0.0)
814
- if self.thresholding_with_row_max :
826
+ if self.thresholding_type == ThresholdType.RowMax :
815
827
# Row_max based thresholding
816
828
row_max = refined_affinity.max(axis=1)
817
829
row_max = np.expand_dims(row_max, axis=1)
818
830
is_smaller = refined_affinity < (row_max * self.p_percentile)
819
- else :
831
+ elif self.thresholding_type == ThresholdType.Percentile :
820
832
# Percentile based thresholding
821
833
row_percentile = np.percentile(
822
834
refined_affinity, self.p_percentile * 100, axis=1)
823
835
row_percentile = np.expand_dims(row_percentile, axis=1)
824
836
is_smaller = refined_affinity < row_percentile
837
+ else:
838
+ raise ValueError("Unsupported thresholding_type")
825
839
if self.thresholding_with_binarization:
826
840
# For values larger than the threshold, we binarize them to 1
827
841
refined_affinity = (np.ones_like(
@@ -862,13 +876,13 @@ <h3>Inherited members</h3>
862
876
"""The Symmetrization operation."""
863
877
864
878
def __init__(self, symmetrize_type=SymmetrizeType.Max):
879
+ if not isinstance(symmetrize_type, SymmetrizeType):
880
+ raise TypeError("symmetrize_type must be a SymmetrizeType")
865
881
self.symmetrize_type = symmetrize_type
866
882
867
883
def refine(self, affinity):
868
884
self.check_input(affinity)
869
- if not isinstance(self.symmetrize_type, SymmetrizeType):
870
- raise TypeError("symmetrize_type must be a SymmetrizeType")
871
- elif self.symmetrize_type == SymmetrizeType.Max:
885
+ if self.symmetrize_type == SymmetrizeType.Max:
872
886
return np.maximum(affinity, np.transpose(affinity))
873
887
elif self.symmetrize_type == SymmetrizeType.Average:
874
888
return 0.5 * (affinity + np.transpose(affinity))
@@ -923,6 +937,40 @@ <h3>Class variables</h3>
923
937
</ dd >
924
938
</ dl >
925
939
</ dd >
940
+ < dt id ="spectralcluster.refinement.ThresholdType "> < code class ="flex name class ">
941
+ < span > class < span class ="ident "> ThresholdType</ span > </ span >
942
+ < span > (</ span > < span > value, names=None, *, module=None, qualname=None, type=None, start=1)</ span >
943
+ </ code > </ dt >
944
+ < dd >
945
+ < div class ="desc "> < p > Different types of thresholding.</ p > </ div >
946
+ < details class ="source ">
947
+ < summary >
948
+ < span > Expand source code</ span >
949
+ </ summary >
950
+ < pre > < code class ="python "> class ThresholdType(enum.Enum):
951
+ """Different types of thresholding."""
952
+ # We clear values that are smaller than row_max*p_percentile
953
+ RowMax = 1
954
+
955
+ # We clear (p_percentile*100)% smallest values of the entire row
956
+ Percentile = 2</ code > </ pre >
957
+ </ details >
958
+ < h3 > Ancestors</ h3 >
959
+ < ul class ="hlist ">
960
+ < li > enum.Enum</ li >
961
+ </ ul >
962
+ < h3 > Class variables</ h3 >
963
+ < dl >
964
+ < dt id ="spectralcluster.refinement.ThresholdType.Percentile "> < code class ="name "> var < span class ="ident "> Percentile</ span > </ code > </ dt >
965
+ < dd >
966
+ < div class ="desc "> </ div >
967
+ </ dd >
968
+ < dt id ="spectralcluster.refinement.ThresholdType.RowMax "> < code class ="name "> var < span class ="ident "> RowMax</ span > </ code > </ dt >
969
+ < dd >
970
+ < div class ="desc "> </ div >
971
+ </ dd >
972
+ </ dl >
973
+ </ dd >
926
974
</ dl >
927
975
</ section >
928
976
</ article >
@@ -988,6 +1036,13 @@ <h4><code><a title="spectralcluster.refinement.SymmetrizeType" href="#spectralcl
988
1036
< li > < code > < a title ="spectralcluster.refinement.SymmetrizeType.Max " href ="#spectralcluster.refinement.SymmetrizeType.Max "> Max</ a > </ code > </ li >
989
1037
</ ul >
990
1038
</ li >
1039
+ < li >
1040
+ < h4 > < code > < a title ="spectralcluster.refinement.ThresholdType " href ="#spectralcluster.refinement.ThresholdType "> ThresholdType</ a > </ code > </ h4 >
1041
+ < ul class ="">
1042
+ < li > < code > < a title ="spectralcluster.refinement.ThresholdType.Percentile " href ="#spectralcluster.refinement.ThresholdType.Percentile "> Percentile</ a > </ code > </ li >
1043
+ < li > < code > < a title ="spectralcluster.refinement.ThresholdType.RowMax " href ="#spectralcluster.refinement.ThresholdType.RowMax "> RowMax</ a > </ code > </ li >
1044
+ </ ul >
1045
+ </ li >
991
1046
</ ul >
992
1047
</ li >
993
1048
</ ul >
0 commit comments