From 04e8b2870f5791443d3e3f898c839a38aeed2266 Mon Sep 17 00:00:00 2001
From: Mert Cobanov
Date: Mon, 11 Mar 2024 13:25:01 +0300
Subject: [PATCH] Add error handling and input validation to calculate_kmeans function

---
 tasnif/calculations.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/tasnif/calculations.py b/tasnif/calculations.py
index ec2a6ac..02f9eb2 100644
--- a/tasnif/calculations.py
+++ b/tasnif/calculations.py
@@ -41,7 +41,17 @@ def calculate_kmeans(pca_embeddings, num_classes):
     labels and centroids.
     """
     print("KMeans processing...")
-    centroid, labels = kmeans2(data=pca_embeddings, k=num_classes, minit="points")
-    counts = np.bincount(labels)
-    print("Kmeans done!")
-    return centroid, labels, counts
+    if not isinstance(pca_embeddings, np.ndarray):
+        raise ValueError("pca_embeddings must be a numpy array")
+
+    if num_classes > len(pca_embeddings):
+        raise ValueError(
+            "num_classes must be less than or equal to the number of samples in pca_embeddings"
+        )
+
+    try:
+        centroid, labels = kmeans2(data=pca_embeddings, k=num_classes, minit="points")
+        counts = np.bincount(labels)
+        return centroid, labels, counts
+    except Exception as e:
+        raise RuntimeError(f"An error occurred during KMeans processing: {e}")