@@ -50,13 +50,13 @@ class MatthewsCorrelationCoefficient(tf.keras.metrics.Metric):
5050
5151 Usage:
5252
53- >>> y_true = np.array([[1.0], [1.0], [1.0], [0.0]], dtype=np.float32)
54- >>> y_pred = np.array([[1.0], [0 .0], [1.0], [1.0]], dtype=np.float32)
55- >>> metric = tfa.metrics.MatthewsCorrelationCoefficient(num_classes=1 )
53+ >>> y_true = np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], dtype=np.float32)
54+ >>> y_pred = np.array([[0.0, 1.0], [1.0, 0 .0], [0.0, 1.0], [0.0, 1.0]], dtype=np.float32)
55+ >>> metric = tfa.metrics.MatthewsCorrelationCoefficient(num_classes=2 )
5656 >>> metric.update_state(y_true, y_pred)
5757 >>> result = metric.result()
5858 >>> result.numpy()
59- array([ -0.33333334], dtype=float32)
59+ -0.33333334
6060 """
6161
6262 @typechecked
@@ -70,28 +70,10 @@ def __init__(
7070 """Creates a Matthews Correlation Coefficient instance."""
7171 super ().__init__ (name = name , dtype = dtype )
7272 self .num_classes = num_classes
73- self .true_positives = self .add_weight (
74- "true_positives" ,
75- shape = [self .num_classes ],
76- initializer = "zeros" ,
77- dtype = self .dtype ,
78- )
79- self .false_positives = self .add_weight (
80- "false_positives" ,
81- shape = [self .num_classes ],
82- initializer = "zeros" ,
83- dtype = self .dtype ,
84- )
85- self .false_negatives = self .add_weight (
86- "false_negatives" ,
87- shape = [self .num_classes ],
88- initializer = "zeros" ,
89- dtype = self .dtype ,
90- )
91- self .true_negatives = self .add_weight (
92- "true_negatives" ,
93- shape = [self .num_classes ],
94- initializer = "zeros" ,
73+ self .conf_mtx = self .add_weight (
74+ "conf_mtx" ,
75+ shape = (self .num_classes , self .num_classes ),
76+ initializer = tf .keras .initializers .zeros ,
9577 dtype = self .dtype ,
9678 )
9779
@@ -100,43 +82,35 @@ def update_state(self, y_true, y_pred, sample_weight=None):
10082 y_true = tf .cast (y_true , dtype = self .dtype )
10183 y_pred = tf .cast (y_pred , dtype = self .dtype )
10284
103- true_positive = tf .math .count_nonzero ( y_true * y_pred , 0 )
104- # true_negative
105- y_true_negative = tf .math . not_equal ( y_true , 1.0 )
106- y_pred_negative = tf . math . not_equal ( y_pred , 1.0 )
107- true_negative = tf . math . count_nonzero (
108- tf . math . logical_and ( y_true_negative , y_pred_negative ), axis = 0
85+ new_conf_mtx = tf .math .confusion_matrix (
86+ labels = tf . argmax ( y_true , 1 ),
87+ predictions = tf .argmax ( y_pred , 1 ),
88+ num_classes = self . num_classes ,
89+ weights = sample_weight ,
90+ dtype = self . dtype ,
10991 )
110- # predicted sum
111- pred_sum = tf .math .count_nonzero (y_pred , 0 )
112- # Ground truth label sum
113- true_sum = tf .math .count_nonzero (y_true , 0 )
114- false_positive = pred_sum - true_positive
115- false_negative = true_sum - true_positive
116-
117- # true positive state_update
118- self .true_positives .assign_add (tf .cast (true_positive , self .dtype ))
119- # false positive state_update
120- self .false_positives .assign_add (tf .cast (false_positive , self .dtype ))
121- # false negative state_update
122- self .false_negatives .assign_add (tf .cast (false_negative , self .dtype ))
123- # true negative state_update
124- self .true_negatives .assign_add (tf .cast (true_negative , self .dtype ))
92+
93+ self .conf_mtx .assign_add (new_conf_mtx )
12594
12695 def result (self ):
127- # numerator
128- numerator1 = self .true_positives * self .true_negatives
129- numerator2 = self .false_positives * self .false_negatives
130- numerator = numerator1 - numerator2
131- # denominator
132- denominator1 = self .true_positives + self .false_positives
133- denominator2 = self .true_positives + self .false_negatives
134- denominator3 = self .true_negatives + self .false_positives
135- denominator4 = self .true_negatives + self .false_negatives
136- denominator = tf .math .sqrt (
137- denominator1 * denominator2 * denominator3 * denominator4
138- )
139- mcc = tf .math .divide_no_nan (numerator , denominator )
96+
97+ true_sum = tf .reduce_sum (self .conf_mtx , axis = 1 )
98+ pred_sum = tf .reduce_sum (self .conf_mtx , axis = 0 )
99+ num_correct = tf .linalg .trace (self .conf_mtx )
100+ num_samples = tf .reduce_sum (pred_sum )
101+
102+ # covariance true-pred
103+ cov_ytyp = num_correct * num_samples - tf .tensordot (true_sum , pred_sum , axes = 1 )
104+ # covariance pred-pred
105+ cov_ypyp = num_samples ** 2 - tf .tensordot (pred_sum , pred_sum , axes = 1 )
106+ # covariance true-true
107+ cov_ytyt = num_samples ** 2 - tf .tensordot (true_sum , true_sum , axes = 1 )
108+
109+ mcc = cov_ytyp / tf .math .sqrt (cov_ytyt * cov_ypyp )
110+
111+ if tf .math .is_nan (mcc ):
112+ mcc = tf .constant (0 , dtype = self .dtype )
113+
140114 return mcc
141115
142116 def get_config (self ):
@@ -150,5 +124,9 @@ def get_config(self):
150124
151125 def reset_states (self ):
152126 """Resets all of the metric state variables."""
153- reset_value = np .zeros (self .num_classes , dtype = self .dtype )
154- K .batch_set_value ([(v , reset_value ) for v in self .variables ])
127+
128+ for v in self .variables :
129+ K .set_value (
130+ v ,
131+ np .zeros ((self .num_classes , self .num_classes ), v .dtype .as_numpy_dtype ),
132+ )
0 commit comments