Skip to content
17 changes: 17 additions & 0 deletions armory/instrument/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ def _write(self, name, batch, result):
f"neutral: {result['neutral']}/{total}, "
f"entailment: {result['entailment']}/{total}"
)
elif "confusion_matrix" in name:
f_result = f"{result}"
elif any(m in name for m in MEAN_AP_METRICS):
if "input_to" in name:
for m in MEAN_AP_METRICS:
Expand All @@ -216,6 +218,8 @@ def _write(self, name, batch, result):
elif any(m in name for m in QUANTITY_METRICS):
# Don't include % symbol
f_result = f"{np.mean(result):.2}"
elif isinstance(result, dict):
f_result = f"{result}"
else:
f_result = f"{np.mean(result):.2%}"
log.success(
Expand Down Expand Up @@ -253,6 +257,19 @@ def _task_metric(
elif name == "word_error_rate":
final = metrics.get("total_wer")
final_suffix = "total_word_error_rate"
elif name in [
"per_class_mean_accuracy",
"per_class_precision_and_recall",
"confusion_matrix",
]:
metric = metrics.get("identity_unzip")
func = metrics.get(name)

def final(x):
return func(*metrics.task.identity_zip(x))

final_suffix = name

elif use_mean:
final = np.mean
final_suffix = f"mean_{name}"
Expand Down
52 changes: 52 additions & 0 deletions armory/metrics/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,58 @@ def tpr_fpr(actual_conditions, predicted_conditions):
)


@populationwise
def per_class_precision_and_recall(y, y_pred):
    """
    Produce a dictionary whose keys are class labels, and values are (precision, recall) for that class

    y - 1D array-like of ground truth labels (assumes every class 0..N-1 is represented)
    y_pred - 1D array-like of predicted labels, or a 2D array of logits

    A class with no predicted members (or an empty row in the confusion matrix)
    gets 0.0 for the corresponding ratio instead of a 0/0 NaN.
    """
    # Assumes that every class is represented in y
    C = confusion_matrix(y, y_pred, normalize_rows=False)
    num_classes = C.shape[0]
    D = {}
    for class_ in range(num_classes):
        tp = C[class_, class_]

        # precision: true positives / number of items identified as class_
        total_selected = C[:, class_].sum()
        precision = tp / total_selected if total_selected else 0.0

        # recall: true positives / number of actual items in class_
        total_actual = C[class_, :].sum()
        recall = tp / total_actual if total_actual else 0.0

        D[class_] = (precision, recall)

    return D


@populationwise
def confusion_matrix(y, y_pred, normalize_rows=True):
    """
    Produce a matrix C such that C[i,j] describes how often class i is classified as class j.
    If normalize_rows is False, C[i,j] is the actual number of i's classified as j.
    If normalize_rows is True (default), the rows are normalized in L1, so that C[i,j] is the percentage of class i that was marked as j.
    """
    # Assumes that every class is represented in y; labels outside 0..N-1
    # (e.g. a prediction of a class absent from y) are silently not counted
    y = np.array(y)
    y_pred = np.array(y_pred)
    if y_pred.ndim == 2:  # if y_pred is logits, reduce to predicted label indices
        y_pred = np.argmax(y_pred, axis=1)
    N = len(np.unique(y))  # number of classes
    C = np.zeros((N, N))
    for i in range(N):
        for j in range(N):
            # count items of class i that were classified as j
            C[i, j] = np.sum(y_pred[y == i] == j)
    if normalize_rows:
        # divide rows by their sum so that each element is a percentage of class i, not a count
        sums = np.sum(C, axis=1)
        # guard empty rows (a class index in range(N) absent from y, which can
        # happen with non-contiguous labels) against 0/0 -> NaN; such rows stay 0
        C = C / np.where(sums == 0, 1, sums)[:, np.newaxis]
    return C


@batchwise
def per_class_accuracy(y, y_pred):
"""
Expand Down
5 changes: 4 additions & 1 deletion armory/utils/config_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,10 @@
"carla_od_disappearance_rate",
"carla_od_hallucinations_per_image",
"carla_od_misclassification_rate",
"carla_od_true_positive_rate"
"carla_od_true_positive_rate",
"per_class_mean_accuracy",
"confusion_matrix",
"per_class_precision_and_recall"
]
},
"sysconfig": {
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/test_task_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,33 @@
pytestmark = pytest.mark.unit


def test_confusion_matrix():
    truth = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    preds = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 0])

    # Perfect predictions yield the identity matrix (rows already L1-normalized)
    assert task.confusion_matrix(truth, truth) == pytest.approx(np.eye(2))

    expected_normalized = np.array([[0.6, 0.4], [0.2, 0.8]])
    assert task.confusion_matrix(truth, preds) == pytest.approx(expected_normalized)

    expected_counts = np.array([[3, 2], [1, 4]])
    assert task.confusion_matrix(truth, preds, normalize_rows=False) == pytest.approx(
        expected_counts
    )


def test_per_class_precision_and_recall():
    # Binary case: expected values are {label: (precision, recall)}
    truth = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    preds = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1, 0])
    result = task.per_class_precision_and_recall(truth, preds)
    for label, expected in {0: (0.75, 0.6), 1: (0.66666667, 0.8)}.items():
        assert result[label] == pytest.approx(expected)

    # Three-class case
    truth = np.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
    preds = np.array([0, 0, 0, 0, 1, 1, 2, 1, 2, 2, 0, 1])
    result = task.per_class_precision_and_recall(truth, preds)
    for label, expected in {
        0: (0.8, 1),
        1: (0.75, 0.75),
        2: (0.666666667, 0.5),
    }.items():
        assert result[label] == pytest.approx(expected)


@pytest.mark.docker_required
@pytest.mark.pytorch_deepspeech
@pytest.mark.slow
Expand Down