Skip to content

Commit

Permalink
Add correlation clustering inference-only functionality (#32)
Browse files Browse the repository at this point in the history
* --pairwise_eval_clustering that accepts 'cc' or 'hac' to run clustering inference with pairwise model output
* Modified logging statements
* Fixed clustering batch size to 1 (to force operating on only one block at a time)
* Added missing model.eval() in --eval_only_split flow 
* Modularized train script
  • Loading branch information
dhdhagar authored Jan 19, 2023
1 parent 436f0c2 commit ba568fc
Show file tree
Hide file tree
Showing 7 changed files with 406 additions and 209 deletions.
45 changes: 45 additions & 0 deletions e2e_pipeline/cc_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch

from e2e_pipeline.mlp_layer import MLPLayer
from e2e_pipeline.sdp_layer import SDPLayer
from e2e_pipeline.hac_cut_layer import HACCutLayer
from e2e_pipeline.trellis_cut_layer import TrellisCutLayer
from e2e_pipeline.uncompress_layer import UncompressTransformLayer
import logging
from IPython import embed

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)

class CCInference(torch.nn.Module):
    """
    Correlation clustering inference-only model. Expects edge weights and the number of nodes as input.
    """

    def __init__(self, sdp_max_iters, sdp_eps):
        """
        Args:
            sdp_max_iters: max solver iterations passed through to the SDP layer.
            sdp_eps: convergence tolerance passed through to the SDP layer.
        """
        super().__init__()
        # Pipeline: uncompress triu edge weights -> SDP relaxation -> HAC-based rounding.
        self.uncompress_layer = UncompressTransformLayer()
        self.sdp_layer = SDPLayer(max_iters=sdp_max_iters, eps=sdp_eps)
        self.hac_cut_layer = HACCutLayer()

    def forward(self, edge_weights, N, verbose=False):
        """Cluster one block of N nodes; returns the HAC-cut layer's cluster labels."""
        w = torch.squeeze(edge_weights)
        w_matrix = self.uncompress_layer(w, N)
        probs = self.sdp_layer(w_matrix, N)
        rounded = self.hac_cut_layer(probs, w_matrix)

        if verbose:
            # Log each intermediate tensor of the pipeline in order.
            for label, tensor in (("W", w), ("W_matrix", w_matrix), ("X", probs), ("X_r", rounded)):
                logger.info(f"Size of {label} = {tensor.size()}")
                logger.info(f"\n{tensor}")

        return self.hac_cut_layer.cluster_labels
4 changes: 2 additions & 2 deletions e2e_pipeline/hac_cut_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,5 +123,5 @@ def get_rounded_solution(self, X, weights, _MAX_DIST=10, use_similarities=True,
self.objective_value = energy[max_node]
return self.round_matrix

def forward(self, X, W):
return X + (self.get_rounded_solution(X, W) - X).detach()
def forward(self, X, W, use_similarities=True):
    """Straight-through estimator around the rounding step.

    The forward value equals the rounded solution, while the gradient
    bypasses the (non-differentiable) rounding and flows straight to X.

    Args:
        X: relaxed (fractional) solution matrix.
        W: edge-weight matrix used by the rounding procedure.
        use_similarities: forwarded to get_rounded_solution
            (presumably toggles similarity vs. distance semantics — confirm there).
    """
    rounded = self.get_rounded_solution(X, W, use_similarities=use_similarities)
    # X + (rounded - X).detach() == rounded in value, but d(out)/dX == identity.
    return X + (rounded - X).detach()
27 changes: 27 additions & 0 deletions e2e_pipeline/hac_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import torch

from e2e_pipeline.mlp_layer import MLPLayer
from e2e_pipeline.sdp_layer import SDPLayer
from e2e_pipeline.hac_cut_layer import HACCutLayer
from e2e_pipeline.trellis_cut_layer import TrellisCutLayer
from e2e_pipeline.uncompress_layer import UncompressTransformLayer
import logging
from IPython import embed

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)

class HACInference:
    """
    HAC inference-only.
    """

    def __init__(self):
        super().__init__()

    def fit(self):
        """Placeholder: fitting logic not implemented yet."""
        return None

    def cluster(self):
        """Placeholder: clustering logic not implemented yet."""
        return None
132 changes: 132 additions & 0 deletions e2e_scripts/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""
Functions to evaluate end-to-end clustering and pairwise training
"""

from tqdm import tqdm
from sklearn.metrics.cluster import v_measure_score
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import torch

from e2e_scripts.train_utils import compute_b3_f1

from IPython import embed


def evaluate(model, dataloader, overfit_batch_idx=-1, clustering_fn=None, tqdm_label='', device=None):
    """
    clustering_fn: unused when pairwise_mode is False (only added to keep fn signature identical)

    Runs the e2e model over each block in `dataloader` and returns
    block-size-weighted (b3_f1, v_measure) averages.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_features = dataloader.dataset[0][0].shape[1]

    all_vmeasure = []
    all_b3_f1 = []
    block_sizes = []
    for idx, batch in enumerate(tqdm(dataloader, desc=f'Evaluating {tqdm_label}')):
        if overfit_batch_idx > -1:
            # Restrict evaluation to the single overfit batch.
            if idx < overfit_batch_idx:
                continue
            if idx > overfit_batch_idx:
                break
        data, target, cluster_ids = batch
        data = data.reshape(-1, n_features).float()
        if data.shape[0] == 0:
            # Only one signature in block -> trivially perfect clustering
            all_vmeasure.append(1.)
            all_b3_f1.append(1.)
            block_sizes.append(1)
            continue

        block_size = len(cluster_ids)  # get_matrix_size_from_triu(data)
        cluster_ids = np.reshape(cluster_ids, (block_size,))
        target = target.flatten().float()
        block_sizes.append(block_size)

        # Forward pass through the e2e model; labels are read off the cut layer.
        data, target = data.to(device), target.to(device)
        model(data, block_size)
        predicted_cluster_ids = model.hac_cut_layer.cluster_labels  # .detach()

        # Compute clustering metrics
        all_vmeasure.append(v_measure_score(predicted_cluster_ids, cluster_ids))
        all_b3_f1.append(compute_b3_f1(cluster_ids, predicted_cluster_ids)[2])

    weights = np.array(block_sizes)
    # np.average(x, weights=w) == sum(x*w)/sum(w): per-signature weighting.
    return (np.average(np.array(all_b3_f1), weights=weights),
            np.average(np.array(all_vmeasure), weights=weights))


def evaluate_pairwise(model, dataloader, overfit_batch_idx=-1, mode="macro", return_pred_only=False,
                      thresh_for_f1=0.5, clustering_fn=None, clustering_threshold=None, val_dataloader=None,
                      tqdm_label='', device=None):
    """Evaluate a pairwise model.

    With `clustering_fn` set, the dataloader is blockwise: pairwise scores are
    clustered per block and weighted (v_measure, b3_f1) averages are returned.
    Otherwise returns (roc_auc, f1) over flat pairwise predictions, or just the
    sigmoid predictions when `return_pred_only` is True.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_features = dataloader.dataset[0][0].shape[1]

    if clustering_fn is not None:
        # Then dataloader passed is blockwise
        all_vmeasure, all_b3_f1, block_sizes = [], [], []
        for idx, batch in enumerate(tqdm(dataloader, desc=f'Evaluating {tqdm_label}')):
            if overfit_batch_idx > -1:
                if idx < overfit_batch_idx:
                    continue
                if idx > overfit_batch_idx:
                    break
            data, target, cluster_ids = batch
            data = data.reshape(-1, n_features).float()
            if data.shape[0] == 0:
                # Only one signature in block -> trivially perfect clustering
                all_vmeasure.append(1.)
                all_b3_f1.append(1.)
                block_sizes.append(1)
                continue

            block_size = len(cluster_ids)  # get_matrix_size_from_triu(data)
            cluster_ids = np.reshape(cluster_ids, (block_size,))
            target = target.flatten().float()
            block_sizes.append(block_size)

            # Pairwise model scores the edges; clustering_fn turns them into labels.
            data, target = data.to(device), target.to(device)
            predicted_cluster_ids = clustering_fn(model(data), block_size)  # .detach()

            # Compute clustering metrics
            all_vmeasure.append(v_measure_score(predicted_cluster_ids, cluster_ids))
            all_b3_f1.append(compute_b3_f1(cluster_ids, predicted_cluster_ids)[2])

        weights = np.array(block_sizes)
        # NOTE(review): return order is (vmeasure, b3_f1) — the reverse of
        # evaluate() above; confirm callers expect this asymmetry.
        return (np.average(np.array(all_vmeasure), weights=weights),
                np.average(np.array(all_b3_f1), weights=weights))

    preds, labels = [], []
    for idx, batch in enumerate(tqdm(dataloader, desc=f'Evaluating {tqdm_label}')):
        if overfit_batch_idx > -1:
            if idx < overfit_batch_idx:
                continue
            if idx > overfit_batch_idx:
                break
        data, target = batch
        data = data.reshape(-1, n_features).float()
        assert data.shape[0] != 0
        target = target.flatten().float()
        # Forward pass through the pairwise model
        preds.append(torch.sigmoid(model(data.to(device))).cpu().numpy())
        labels.append(target)
    y_pred = np.hstack(preds)
    targets = np.hstack(labels)

    if return_pred_only:
        return y_pred

    fpr, tpr, _ = roc_curve(targets, y_pred)
    roc_auc = auc(fpr, tpr)
    _, _, f1, _ = precision_recall_fscore_support(targets, y_pred >= thresh_for_f1, beta=1.0,
                                                  average=mode, zero_division=0)
    return roc_auc, np.round(f1, 3)
Loading

0 comments on commit ba568fc

Please sign in to comment.