From ba568fc49c8afb8218dcdc748bb7ef86c64c0c66 Mon Sep 17 00:00:00 2001
From: Dhruv Agarwal
Date: Thu, 19 Jan 2023 05:08:06 -0500
Subject: [PATCH] Add correlation clustering inference-only functionality (#32)

* Added --pairwise_eval_clustering, which accepts 'cc' or 'hac', to run clustering inference on the pairwise model's output
* Modified logging statements
* Fixed the clustering batch size to 1 (to force operating on only one block at a time)
* Added missing model.eval() in the --eval_only_split flow
* Modularized the train script
---
 e2e_pipeline/cc_inference.py  |  45 ++++++
 e2e_pipeline/hac_cut_layer.py |   4 +-
 e2e_pipeline/hac_inference.py |  27 ++++
 e2e_scripts/evaluate.py       | 132 +++++++++++++++++
 e2e_scripts/train.py          | 264 ++++++++--------------------------
 e2e_scripts/train_utils.py    | 130 +++++++++++++++++
 utils/parser.py               |  13 +-
 7 files changed, 406 insertions(+), 209 deletions(-)
 create mode 100644 e2e_pipeline/cc_inference.py
 create mode 100644 e2e_pipeline/hac_inference.py
 create mode 100644 e2e_scripts/evaluate.py
 create mode 100644 e2e_scripts/train_utils.py

diff --git a/e2e_pipeline/cc_inference.py b/e2e_pipeline/cc_inference.py
new file mode 100644
index 0000000..4722f77
--- /dev/null
+++ b/e2e_pipeline/cc_inference.py
@@ -0,0 +1,45 @@
+import torch
+
+from e2e_pipeline.mlp_layer import MLPLayer
+from e2e_pipeline.sdp_layer import SDPLayer
+from e2e_pipeline.hac_cut_layer import HACCutLayer
+from e2e_pipeline.trellis_cut_layer import TrellisCutLayer
+from e2e_pipeline.uncompress_layer import UncompressTransformLayer
+import logging
+from IPython import embed
+
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
+                    level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class CCInference(torch.nn.Module):
+    """
+    Correlation clustering inference-only model. Expects edge weights and the number of nodes as input.
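+    Edge weights are the flattened upper-triangular entries of the N x N pairwise weight matrix
+    (the same convention as uncompress_target_tensor in e2e_scripts/train_utils.py), and forward()
+    returns the per-node cluster labels produced by rounding the SDP solution with the HAC cut.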
+    """
+
+    def __init__(self, sdp_max_iters, sdp_eps):
+        super().__init__()
+        self.uncompress_layer = UncompressTransformLayer()
+        self.sdp_layer = SDPLayer(max_iters=sdp_max_iters, eps=sdp_eps)
+        self.hac_cut_layer = HACCutLayer()
+
+    def forward(self, edge_weights, N, verbose=False):
+        edge_weights = torch.squeeze(edge_weights)
+        edge_weights_uncompressed = self.uncompress_layer(edge_weights, N)
+        output_probs = self.sdp_layer(edge_weights_uncompressed, N)
+        pred_clustering = self.hac_cut_layer(output_probs, edge_weights_uncompressed)
+
+        if verbose:
+            logger.info(f"Size of W = {edge_weights.size()}")
+            logger.info(f"\n{edge_weights}")
+
+            logger.info(f"Size of W_matrix = {edge_weights_uncompressed.size()}")
+            logger.info(f"\n{edge_weights_uncompressed}")
+
+            logger.info(f"Size of X = {output_probs.size()}")
+            logger.info(f"\n{output_probs}")
+
+            logger.info(f"Size of X_r = {pred_clustering.size()}")
+            logger.info(f"\n{pred_clustering}")
+
+        return self.hac_cut_layer.cluster_labels
diff --git a/e2e_pipeline/hac_cut_layer.py b/e2e_pipeline/hac_cut_layer.py
index b0eb581..22efe82 100644
--- a/e2e_pipeline/hac_cut_layer.py
+++ b/e2e_pipeline/hac_cut_layer.py
@@ -123,5 +123,5 @@ def get_rounded_solution(self, X, weights, _MAX_DIST=10, use_similarities=True,
         self.objective_value = energy[max_node]
         return self.round_matrix
 
-    def forward(self, X, W):
-        return X + (self.get_rounded_solution(X, W) - X).detach()
+    def forward(self, X, W, use_similarities=True):
+        return X + (self.get_rounded_solution(X, W, use_similarities=use_similarities) - X).detach()
diff --git a/e2e_pipeline/hac_inference.py b/e2e_pipeline/hac_inference.py
new file mode 100644
index 0000000..ae30bd7
--- /dev/null
+++ b/e2e_pipeline/hac_inference.py
@@ -0,0 +1,27 @@
+import torch
+
+from e2e_pipeline.mlp_layer import MLPLayer
+from e2e_pipeline.sdp_layer import SDPLayer
+from e2e_pipeline.hac_cut_layer import HACCutLayer
+from e2e_pipeline.trellis_cut_layer import TrellisCutLayer
+from e2e_pipeline.uncompress_layer import UncompressTransformLayer
+import logging
+from IPython import embed
+
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
+                    level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class HACInference:
+    """
+    HAC inference-only.
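+    TODO: fit() (cut-threshold selection) and cluster() are stubs; the 'hac' option of
+    --pairwise_eval_clustering is not implemented yet.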
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def fit(self):
+        pass
+
+    def cluster(self):
+        pass
diff --git a/e2e_scripts/evaluate.py b/e2e_scripts/evaluate.py
new file mode 100644
index 0000000..976988e
--- /dev/null
+++ b/e2e_scripts/evaluate.py
@@ -0,0 +1,132 @@
+"""
+    Functions to evaluate end-to-end clustering and pairwise training
+"""
+
+from tqdm import tqdm
+from sklearn.metrics.cluster import v_measure_score
+from sklearn.metrics import roc_curve, auc
+from sklearn.metrics import precision_recall_fscore_support
+import numpy as np
+import torch
+
+from e2e_scripts.train_utils import compute_b3_f1
+
+from IPython import embed
+
+
+def evaluate(model, dataloader, overfit_batch_idx=-1, clustering_fn=None, tqdm_label='', device=None):
+    """
+    clustering_fn: unused when pairwise_mode is False (only kept so the signature matches evaluate_pairwise)
+    """
+    device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    n_features = dataloader.dataset[0][0].shape[1]
+
+    vmeasure, b3_f1, sigs_per_block = [], [], []
+    for (idx, batch) in enumerate(tqdm(dataloader, desc=f'Evaluating {tqdm_label}')):
+        if overfit_batch_idx > -1:
+            if idx < overfit_batch_idx:
+                continue
+            if idx > overfit_batch_idx:
+                break
+        data, target, cluster_ids = batch
+        data = data.reshape(-1, n_features).float()
+        if data.shape[0] == 0:
+            # Only one signature in block -> predict correctly
+            vmeasure.append(1.)
+            b3_f1.append(1.)
+            sigs_per_block.append(1)
+        else:
+            block_size = len(cluster_ids)  # get_matrix_size_from_triu(data)
+            cluster_ids = np.reshape(cluster_ids, (block_size,))
+            target = target.flatten().float()
+            sigs_per_block.append(block_size)
+
+            # Forward pass through the e2e model
+            data, target = data.to(device), target.to(device)
+            _ = model(data, block_size)
+            predicted_cluster_ids = model.hac_cut_layer.cluster_labels  # .detach()
+
+            # Compute clustering metrics
+            vmeasure.append(v_measure_score(predicted_cluster_ids, cluster_ids))
+            b3_f1_metrics = compute_b3_f1(cluster_ids, predicted_cluster_ids)
+            b3_f1.append(b3_f1_metrics[2])
+
+    vmeasure = np.array(vmeasure)
+    b3_f1 = np.array(b3_f1)
+    sigs_per_block = np.array(sigs_per_block)
+
+    return np.sum(b3_f1 * sigs_per_block) / np.sum(sigs_per_block), \
+        np.sum(vmeasure * sigs_per_block) / np.sum(sigs_per_block)
+
+
+def evaluate_pairwise(model, dataloader, overfit_batch_idx=-1, mode="macro", return_pred_only=False,
+                      thresh_for_f1=0.5, clustering_fn=None, clustering_threshold=None, val_dataloader=None,
+                      tqdm_label='', device=None):
+    device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    n_features = dataloader.dataset[0][0].shape[1]
+
+    if clustering_fn is not None:
+        # In this case, the dataloader passed in is blockwise
+        vmeasure, b3_f1, sigs_per_block = [], [], []
+        for (idx, batch) in enumerate(tqdm(dataloader, desc=f'Evaluating {tqdm_label}')):
+            if overfit_batch_idx > -1:
+                if idx < overfit_batch_idx:
+                    continue
+                if idx > overfit_batch_idx:
+                    break
+            data, target, cluster_ids = batch
+            data = data.reshape(-1, n_features).float()
+            if data.shape[0] == 0:
+                # Only one signature in block -> predict correctly
+                vmeasure.append(1.)
+                b3_f1.append(1.)
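+                # (an empty pairwise tensor means the block contains a single signature)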
+                sigs_per_block.append(1)
+            else:
+                block_size = len(cluster_ids)  # get_matrix_size_from_triu(data)
+                cluster_ids = np.reshape(cluster_ids, (block_size,))
+                target = target.flatten().float()
+                sigs_per_block.append(block_size)
+
+                # Forward pass through the pairwise model, then cluster its output
+                data, target = data.to(device), target.to(device)
+                predicted_cluster_ids = clustering_fn(model(data), block_size)  # .detach()
+
+                # Compute clustering metrics
+                vmeasure.append(v_measure_score(predicted_cluster_ids, cluster_ids))
+                b3_f1_metrics = compute_b3_f1(cluster_ids, predicted_cluster_ids)
+                b3_f1.append(b3_f1_metrics[2])
+
+        vmeasure = np.array(vmeasure)
+        b3_f1 = np.array(b3_f1)
+        sigs_per_block = np.array(sigs_per_block)
+
+        return np.sum(b3_f1 * sigs_per_block) / np.sum(sigs_per_block), \
+            np.sum(vmeasure * sigs_per_block) / np.sum(sigs_per_block)
+
+    y_pred, targets = [], []
+    for (idx, batch) in enumerate(tqdm(dataloader, desc=f'Evaluating {tqdm_label}')):
+        if overfit_batch_idx > -1:
+            if idx < overfit_batch_idx:
+                continue
+            if idx > overfit_batch_idx:
+                break
+        data, target = batch
+        data = data.reshape(-1, n_features).float()
+        assert data.shape[0] != 0
+        target = target.flatten().float()
+        # Forward pass through the pairwise model
+        data = data.to(device)
+        y_pred.append(torch.sigmoid(model(data)).cpu().numpy())
+        targets.append(target)
+    y_pred = np.hstack(y_pred)
+    targets = np.hstack(targets)
+
+    if return_pred_only:
+        return y_pred
+
+    fpr, tpr, _ = roc_curve(targets, y_pred)
+    roc_auc = auc(fpr, tpr)
+    pr, rc, f1, _ = precision_recall_fscore_support(targets, y_pred >= thresh_for_f1, beta=1.0, average=mode,
+                                                    zero_division=0)
+
+    return roc_auc, np.round(f1, 3)
diff --git a/e2e_scripts/train.py b/e2e_scripts/train.py
index fd4be1b..19dd8ff 100644
--- a/e2e_scripts/train.py
+++ b/e2e_scripts/train.py
@@ -1,29 +1,23 @@
 import json
 import os
 import time
-from collections import defaultdict
-from typing import Dict
-from typing import Tuple
-import math
 import logging
 import random
 import copy
-import pickle
 
-import torch
 import wandb
-from torch.utils.data import DataLoader
+import torch
 import numpy as np
-from sklearn.metrics.cluster import v_measure_score
-from sklearn.metrics import roc_curve, auc
-from sklearn.metrics import precision_recall_fscore_support
+
 from tqdm import tqdm
 
+from e2e_pipeline.cc_inference import CCInference
+from e2e_pipeline.hac_inference import HACInference
 from e2e_pipeline.model import EntResModel
 from e2e_pipeline.pairwise_model import PairwiseModel
-from s2and.consts import PREPROCESSED_DATA_DIR
-from s2and.data import S2BlocksDataset
-from s2and.eval import b3_precision_recall_fscore
+from e2e_scripts.evaluate import evaluate, evaluate_pairwise
+from e2e_scripts.train_utils import DEFAULT_HYPERPARAMS, get_dataloaders, get_matrix_size_from_triu, \
+    uncompress_target_tensor, count_parameters
 from utils.parser import Parser
 
 from IPython import embed
@@ -33,190 +27,11 @@
                     level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Default hyperparameters
-DEFAULT_HYPERPARAMS = {
-    # Dataset
-    "dataset": "pubmed",
-    "dataset_random_seed": 1,
-    "subsample_sz": -1,
-    "subsample_dev": True,
-    # Run config
-    "run_random_seed": 17,
-    # Data config
-    "convert_nan": False,
-    "nan_value": -1,
-    "drop_feat_nan_pct": -1,
-    "normalize_data": True,
-    # Model config
-    "neumiss_deq": False,
-    "neumiss_depth": 20,
-    "hidden_dim": 512,
-    "n_hidden_layers": 2,
-    "dropout_p": 0.1,
-    "dropout_only_once": True,
-    "batchnorm": True,
-    "hidden_config": None,
-    "activation": "leaky_relu",
-    "negative_slope": 0.01,
-    # Solver config
-    "sdp_max_iters": 50000,
-    "sdp_eps": 1e-3,
-    # Training config
-    "batch_size": 1,
-    "lr": 1e-4,
-    "n_epochs": 5,
-    "weighted_loss": True,  # Only applies to pairwise model currently; TODO: Implement for e2e
-    "use_lr_scheduler": True,
-    "lr_scheduler": "plateau",  # "step"
-    "lr_factor": 0.7,
-    "lr_min": 1e-6,
-    "lr_scheduler_patience": 10,
-    "lr_step_size": 200,
-    "lr_gamma": 0.1,
-    "weight_decay": 0.01,
-    "dev_opt_metric": 'b3_f1',  # e2e: {'vmeasure', 'b3_f1'}; pairwise: {'auroc', 'f1'}
-    "overfit_batch_idx": -1
-}
-
-
-def read_blockwise_features(pkl):
-    blockwise_data: Dict[str, Tuple[np.ndarray, np.ndarray]]
-    with open(pkl, "rb") as _pkl_file:
-        blockwise_data = pickle.load(_pkl_file)
-    return blockwise_data
-
-
-def get_dataloaders(dataset, dataset_seed, convert_nan, nan_value, normalize, subsample_sz, subsample_dev,
-                    pairwise_mode, batch_size):
-    train_pkl = f"{PREPROCESSED_DATA_DIR}/{dataset}/seed{dataset_seed}/train_features.pkl"
-    val_pkl = f"{PREPROCESSED_DATA_DIR}/{dataset}/seed{dataset_seed}/val_features.pkl"
-    test_pkl = f"{PREPROCESSED_DATA_DIR}/{dataset}/seed{dataset_seed}/test_features.pkl"
-
-    train_dataset = S2BlocksDataset(read_blockwise_features(train_pkl), convert_nan=convert_nan, nan_value=nan_value,
-                                    scale=normalize, subsample_sz=subsample_sz, pairwise_mode=pairwise_mode)
-    train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=batch_size)
-
-    val_dataset = S2BlocksDataset(read_blockwise_features(val_pkl), convert_nan=convert_nan, nan_value=nan_value,
-                                  scale=normalize, scaler=train_dataset.scaler,
-                                  subsample_sz=subsample_sz if subsample_dev else -1, pairwise_mode=pairwise_mode)
-    val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)
-
-    test_dataset = S2BlocksDataset(read_blockwise_features(test_pkl), convert_nan=convert_nan, nan_value=nan_value,
-                                   scale=normalize, scaler=train_dataset.scaler, pairwise_mode=pairwise_mode)
-    test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
-
-    return train_dataloader, val_dataloader, test_dataloader
-
-
-def uncompress_target_tensor(compressed_targets, make_symmetric=True):
-    n = round(math.sqrt(2 * compressed_targets.size(dim=0))) + 1
-    # Convert the 1D pairwise-similarities list to nxn upper triangular matrix
-    ind0, ind1 = torch.triu_indices(n, n, offset=1)
-    target = torch.eye(n, device=device)
-    target[ind0, ind1] = compressed_targets
-    if make_symmetric:
-        target[ind1, ind0] = compressed_targets
-    return target
-
-
-# Count parameters in the model
-def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-
-def get_matrix_size_from_triu(triu):
-    return round(math.sqrt(2 * len(triu))) + 1
-
-
-def compute_b3_f1(true_cluster_ids, pred_cluster_ids):
-    """
-    Compute the B^3 variant of precision, recall and F-score.
-    Returns:
-        Precision
-        Recall
-        F1
-        Per signature metrics
-        Overmerging ratios
-        Undermerging ratios
-    """
-    true_cluster_dict, pred_cluster_dict = defaultdict(list), defaultdict(list)
-    for i in range(len(true_cluster_ids)):
-        true_cluster_dict[true_cluster_ids[i]].append(i)
-        pred_cluster_dict[pred_cluster_ids[i].item()].append(i)
-    return b3_precision_recall_fscore(true_cluster_dict, pred_cluster_dict)
-
-
-def evaluate(model, dataloader, overfit_batch_idx=-1):
-    n_features = dataloader.dataset[0][0].shape[1]
-    vmeasure, b3_f1, sigs_per_block = [], [], []
-    for (idx, batch) in enumerate(tqdm(dataloader, desc='Evaluating')):
-        if overfit_batch_idx > -1:
-            if idx < overfit_batch_idx:
-                continue
-            if idx > overfit_batch_idx:
-                break
-        data, target, cluster_ids = batch
-        data = data.reshape(-1, n_features).float()
-        if data.shape[0] == 0:
-            # Only one signature in block -> predict correctly
-            vmeasure.append(1.)
-            b3_f1.append(1.)
-            sigs_per_block.append(1)
-        else:
-            block_size = get_matrix_size_from_triu(data)
-            cluster_ids = np.reshape(cluster_ids, (block_size, ))
-            target = target.flatten().float()
-            sigs_per_block.append(block_size)
-
-            # Forward pass through the e2e model
-            data, target = data.to(device), target.to(device)
-            _ = model(data, block_size)
-            predicted_cluster_ids = model.hac_cut_layer.cluster_labels.detach()
-
-            # Compute clustering metrics
-            vmeasure.append(v_measure_score(predicted_cluster_ids, cluster_ids))
-            b3_f1_metrics = compute_b3_f1(cluster_ids, predicted_cluster_ids)
-            b3_f1.append(b3_f1_metrics[2])
-
-    vmeasure = np.array(vmeasure)
-    b3_f1 = np.array(b3_f1)
-    sigs_per_block = np.array(sigs_per_block)
-
-    return np.sum(vmeasure * sigs_per_block) / np.sum(sigs_per_block), \
-        np.sum(b3_f1 * sigs_per_block) / np.sum(sigs_per_block)
-
-
-def evaluate_pairwise(model, dataloader, overfit_batch_idx=-1, mode="macro", return_pred_only=False,
-                      thresh_for_f1=0.5):
-    n_features = dataloader.dataset[0][0].shape[1]
-    y_pred, targets = [], []
-    for (idx, batch) in enumerate(tqdm(dataloader, desc='Evaluating')):
-        if overfit_batch_idx > -1:
-            if idx < overfit_batch_idx:
-                continue
-            if idx > overfit_batch_idx:
-                break
-        data, target = batch
-        data = data.reshape(-1, n_features).float()
-        assert data.shape[0] != 0
-        target = target.flatten().float()
-        # Forward pass through the pairwise model
-        data = data.to(device)
-        y_pred.append(torch.sigmoid(model(data)).cpu().numpy())
-        targets.append(target)
-    y_pred = np.hstack(y_pred)
-    targets = np.hstack(targets)
-
-    if return_pred_only:
-        return y_pred
-
-    fpr, tpr, _ = roc_curve(targets, y_pred)
-    roc_auc = auc(fpr, tpr)
-    pr, rc, f1, _ = precision_recall_fscore_support(targets, y_pred >= thresh_for_f1, beta=1.0, average=mode,
-                                                    zero_division=0)
-
-    return roc_auc, np.round(f1, 3)
-
 
 def train(hyperparams={}, verbose=False, project=None, entity=None, tags=None, group=None,
           save_model=False, load_model_from_wandb_run=None, load_model_from_fpath=None,
-          eval_only_split=None, skip_initial_eval=False, pairwise_mode=False):
+          eval_only_split=None, skip_initial_eval=False, pairwise_mode=False,
+          pairwise_eval_clustering=None):
     init_args = {
         'config': DEFAULT_HYPERPARAMS
     }
@@ -245,7 +60,7 @@
     np.random.seed(hyp['run_random_seed'])
 
     weighted_loss = hyp['weighted_loss']
-    batch_size = hyp['batch_size']
+    batch_size = hyp['batch_size'] if pairwise_mode else 1  # Force clustering runs to operate on 1 block only
    n_epochs = hyp['n_epochs']
     use_lr_scheduler = hyp['use_lr_scheduler']
     hidden_dim = hyp["hidden_dim"]
@@ -262,7 +77,9 @@
     sdp_max_iters = hyp["sdp_max_iters"]
     sdp_eps = hyp["sdp_eps"]
     overfit_batch_idx = hyp['overfit_batch_idx']
-    eval_metric_to_idx = {'vmeasure': 0, 'b3_f1': 1} if not pairwise_mode else {'auroc': 0, 'f1': 1}
+    clustering_metrics = {'b3_f1': 0, 'vmeasure': 1}
+    pairwise_metrics = {'auroc': 0, 'f1': 1}
+    eval_metric_to_idx = clustering_metrics if not pairwise_mode else pairwise_metrics
     dev_opt_metric = hyp['dev_opt_metric'] if hyp['dev_opt_metric'] in eval_metric_to_idx \
         else list(eval_metric_to_idx)[0]
 
@@ -283,6 +100,7 @@
         loss_fn = lambda pred, gold: torch.norm(gold - pred)
         # Define eval
         eval_fn = evaluate
+        pairwise_clustering_fn = None  # Unused when pairwise_mode is False
     else:
         model = PairwiseModel(n_features, neumiss_depth, dropout_p, dropout_only_once, add_neumiss,
                               neumiss_deq, hidden_dim, n_hidden_layers, add_batchnorm, activation,
@@ -301,6 +119,19 @@
         loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
         # Define eval
         eval_fn = evaluate_pairwise
+        pairwise_clustering_fn = None
+        if pairwise_eval_clustering is not None:
+            if pairwise_eval_clustering == 'cc':
+                pairwise_clustering_fn = CCInference(sdp_max_iters, sdp_eps)
+                pairwise_clustering_fn.eval()
+            elif pairwise_eval_clustering == 'hac':
+                pairwise_clustering_fn = HACInference()  # TODO: Implement
+            else:
+                raise ValueError('Invalid argument passed to --pairwise_eval_clustering')
+            _, _, clustering_test_dataloader = get_dataloaders(hyp["dataset"], hyp["dataset_random_seed"],
+                                                               hyp["convert_nan"], hyp["nan_value"],
+                                                               hyp["normalize_data"], hyp["subsample_sz"],
+                                                               hyp["subsample_dev"], False, 1)
 
     logger.info(f"Model loaded: {model}", )
     # Load stored model, if available
@@ -324,8 +155,16 @@
             'test': test_dataloader
         }
         with torch.no_grad():
+            model.eval()
+            if pairwise_clustering_fn is not None:
+                assert eval_only_split == 'test'  # Clustering in --eval_only_split implemented only for test set
+                eval_metric_to_idx = clustering_metrics
+                eval_dataloader = clustering_test_dataloader
+            else:
+                eval_dataloader = dataloaders[eval_only_split]
             start_time = time.time()
-            eval_scores = eval_fn(model, dataloaders[eval_only_split])
+            eval_scores = eval_fn(model, eval_dataloader, clustering_fn=pairwise_clustering_fn,
+                                  tqdm_label=eval_only_split, device=device)
             end_time = time.time()
             if verbose:
                 logger.info(
@@ -360,14 +199,15 @@
     with torch.no_grad():
         model.eval()
         if overfit_batch_idx > -1:
-            train_scores = eval_fn(model, train_dataloader, overfit_batch_idx)
+            train_scores = eval_fn(model, train_dataloader, overfit_batch_idx=overfit_batch_idx,
+                                   tqdm_label='train', device=device)
             if verbose:
                 logger.info(f"Initial: train_{list(eval_metric_to_idx)[0]}={train_scores[0]}, " +
                             f"train_{list(eval_metric_to_idx)[1]}={train_scores[1]}")
             wandb.log({'epoch': 0, f'train_{list(eval_metric_to_idx)[0]}': train_scores[0],
                        f'train_{list(eval_metric_to_idx)[1]}': train_scores[1]})
         else:
-            dev_scores = eval_fn(model, val_dataloader)
+            dev_scores = eval_fn(model, val_dataloader, tqdm_label='dev', device=device)
             if verbose:
                 logger.info(f"Initial: dev_{list(eval_metric_to_idx)[0]}={dev_scores[0]}, " +
                             f"dev_{list(eval_metric_to_idx)[1]}={dev_scores[1]}")
@@ -379,7 +219,7 @@
     for i in range(n_epochs):
         wandb.log({'epoch': i + 1})
         running_loss = []
-        for (idx, batch) in enumerate(tqdm(train_dataloader, desc="Training")):
+        for (idx, batch) in enumerate(tqdm(train_dataloader, desc=f"Training {i + 1}")):
             if overfit_batch_idx > -1:
                 if idx < overfit_batch_idx:
                     continue
@@ -409,7 +249,7 @@
 
             # Calculate the loss
             if not pairwise_mode:
-                gold_output = uncompress_target_tensor(target)
+                gold_output = uncompress_target_tensor(target, device=device)
                 if verbose:
                     logger.info(f"Gold:\n{gold_output}")
                 loss = loss_fn(output.view_as(gold_output), gold_output) / (2 * block_size)
@@ -434,7 +274,8 @@
         with torch.no_grad():
             model.eval()
             if overfit_batch_idx > -1:
-                train_scores = eval_fn(model, train_dataloader, overfit_batch_idx)
+                train_scores = eval_fn(model, train_dataloader, overfit_batch_idx=overfit_batch_idx,
+                                       tqdm_label='train', device=device)
                 if verbose:
                     logger.info(f"Epoch {i + 1}: train_{list(eval_metric_to_idx)[0]}={train_scores[0]}, " +
                                 f"train_{list(eval_metric_to_idx)[1]}={train_scores[1]}")
@@ -446,7 +287,7 @@
                     elif hyp['lr_scheduler'] == 'step':
                         scheduler.step()
             else:
-                dev_scores = eval_fn(model, val_dataloader)
+                dev_scores = eval_fn(model, val_dataloader, tqdm_label='dev', device=device)
                 if verbose:
                     logger.info(f"epoch {i + 1}: dev_{list(eval_metric_to_idx)[0]}={dev_scores[0]}, " +
                                 f"dev_{list(eval_metric_to_idx)[1]}={dev_scores[1]}")
@@ -474,7 +315,7 @@
         model.load_state_dict(best_dev_state_dict)
     with torch.no_grad():
         model.eval()
-        test_scores = eval_fn(model, test_dataloader)
+        test_scores = eval_fn(model, test_dataloader, tqdm_label='test', device=device)
        if verbose:
             logger.info(f"Final: test_{list(eval_metric_to_idx)[0]}={test_scores[0]}, " +
                         f"test_{list(eval_metric_to_idx)[1]}={test_scores[1]}")
@@ -484,6 +325,17 @@
                    f'best_dev_{list(eval_metric_to_idx)[1]}': best_dev_scores[1],
                    f'best_test_{list(eval_metric_to_idx)[0]}': test_scores[0],
                    f'best_test_{list(eval_metric_to_idx)[1]}': test_scores[1]})
+    if pairwise_clustering_fn is not None:
+        clustering_scores = eval_fn(model, clustering_test_dataloader,
+                                    clustering_fn=pairwise_clustering_fn, tqdm_label='test clustering',
+                                    device=device)
+        if verbose:
+            logger.info(f"Final: test_{list(clustering_metrics)[0]}={clustering_scores[0]}, " +
+                        f"test_{list(clustering_metrics)[1]}={clustering_scores[1]}")
+        # Log final metrics
+        wandb.log({f'best_test_{list(clustering_metrics)[0]}': clustering_scores[0],
+                   f'best_test_{list(clustering_metrics)[1]}': clustering_scores[1]})
+
     run.summary["z_model_parameters"] = count_parameters(model)
     run.summary["z_run_time"] = round(end_time - start_time)
 
@@ -493,6 +345,7 @@
     if save_model:
         torch.save(best_dev_state_dict, os.path.join(run.dir, 'model_state_dict_best.pt'))
         wandb.save('model_state_dict_best.pt')
+        logger.info(f"Saved best model on dev to {os.path.join(run.dir, 'model_state_dict_best.pt')}")
 
     logger.info(f"Run directory: {run.dir}")
     logger.info("End of train() call")
@@ -593,5 +446,6 @@
               load_model_from_fpath=args['load_model_from_fpath'],
               eval_only_split=args['eval_only_split'],
               skip_initial_eval=args['skip_initial_eval'],
-              pairwise_mode=args['pairwise_mode'])
+              pairwise_mode=args['pairwise_mode'],
+              pairwise_eval_clustering=args['pairwise_eval_clustering'])
     logger.info("End of run")
diff --git a/e2e_scripts/train_utils.py b/e2e_scripts/train_utils.py
new file mode 100644
index 0000000..0b4b24c
--- /dev/null
+++ b/e2e_scripts/train_utils.py
@@ -0,0 +1,130 @@
+"""
+    Helper functions and constants for e2e_scripts/train.py
+"""
+
+from collections import defaultdict
+from typing import Dict
+from typing import Tuple
+import math
+import pickle
+from torch.utils.data import DataLoader
+from s2and.consts import PREPROCESSED_DATA_DIR
+from s2and.data import S2BlocksDataset
+from s2and.eval import b3_precision_recall_fscore
+import torch
+import numpy as np
+
+from IPython import embed
+
+
+# Default hyperparameters
+DEFAULT_HYPERPARAMS = {
+    # Dataset
+    "dataset": "pubmed",
+    "dataset_random_seed": 1,
+    "subsample_sz": -1,
+    "subsample_dev": True,
+    # Run config
+    "run_random_seed": 17,
+    # Data config
+    "convert_nan": False,
+    "nan_value": -1,
+    "drop_feat_nan_pct": -1,
+    "normalize_data": True,
+    # Model config
+    "neumiss_deq": False,
+    "neumiss_depth": 20,
+    "hidden_dim": 512,
+    "n_hidden_layers": 2,
+    "dropout_p": 0.1,
+    "dropout_only_once": True,
+    "batchnorm": True,
+    "hidden_config": None,
+    "activation": "leaky_relu",
+    "negative_slope": 0.01,
+    # Solver config
+    "sdp_max_iters": 50000,
+    "sdp_eps": 1e-3,
+    # Training config
+    "batch_size": 10000,  # For pairwise_mode only
+    "lr": 1e-4,
+    "n_epochs": 5,
+    "weighted_loss": True,  # For pairwise_mode only; TODO: Implement for e2e
+    "use_lr_scheduler": True,
+    "lr_scheduler": "plateau",  # "step"
+    "lr_factor": 0.7,
+    "lr_min": 1e-6,
+    "lr_scheduler_patience": 10,
+    "lr_step_size": 200,
+    "lr_gamma": 0.1,
+    "weight_decay": 0.01,
+    "dev_opt_metric": 'b3_f1',  # e2e: {'vmeasure', 'b3_f1'}; pairwise: {'auroc', 'f1'}
+    "overfit_batch_idx": -1
+}
+
+
+def read_blockwise_features(pkl):
+    blockwise_data: Dict[str, Tuple[np.ndarray, np.ndarray]]
+    with open(pkl, "rb") as _pkl_file:
+        blockwise_data = pickle.load(_pkl_file)
+    return blockwise_data
+
+
+def get_dataloaders(dataset, dataset_seed, convert_nan, nan_value, normalize, subsample_sz, subsample_dev,
+                    pairwise_mode, batch_size):
+    train_pkl = f"{PREPROCESSED_DATA_DIR}/{dataset}/seed{dataset_seed}/train_features.pkl"
+    val_pkl = f"{PREPROCESSED_DATA_DIR}/{dataset}/seed{dataset_seed}/val_features.pkl"
+    test_pkl = f"{PREPROCESSED_DATA_DIR}/{dataset}/seed{dataset_seed}/test_features.pkl"
+
+    train_dataset = S2BlocksDataset(read_blockwise_features(train_pkl), convert_nan=convert_nan, nan_value=nan_value,
+                                    scale=normalize, subsample_sz=subsample_sz, pairwise_mode=pairwise_mode)
+    train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=batch_size)
+
+    val_dataset = S2BlocksDataset(read_blockwise_features(val_pkl), convert_nan=convert_nan, nan_value=nan_value,
+                                  scale=normalize, scaler=train_dataset.scaler,
+                                  subsample_sz=subsample_sz if subsample_dev else -1, pairwise_mode=pairwise_mode)
+    val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)
+
+    test_dataset = S2BlocksDataset(read_blockwise_features(test_pkl), convert_nan=convert_nan, nan_value=nan_value,
+                                   scale=normalize, scaler=train_dataset.scaler, pairwise_mode=pairwise_mode)
+    test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
+
+    return train_dataloader, val_dataloader, test_dataloader
+
+
+def uncompress_target_tensor(compressed_targets, make_symmetric=True, device=None):
+    device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    n = round(math.sqrt(2 * compressed_targets.size(dim=0))) + 1
+    # Convert the 1D pairwise-similarities list to nxn upper triangular matrix
+    ind0, ind1 = torch.triu_indices(n, n, offset=1)
+    target = torch.eye(n, device=device)
+    target[ind0, ind1] = compressed_targets
+    if make_symmetric:
+        target[ind1, ind0] = compressed_targets
+    return target
+
+
+# Count parameters in the model
+def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+def get_matrix_size_from_triu(triu):
+    return round(math.sqrt(2 * len(triu))) + 1
+
+
+def compute_b3_f1(true_cluster_ids, pred_cluster_ids):
+    """
+    Compute the B^3 variant of precision, recall and F-score.
+    Returns:
+        Precision
+        Recall
+        F1
+        Per signature metrics
+        Overmerging ratios
+        Undermerging ratios
+    """
+    true_cluster_dict, pred_cluster_dict = defaultdict(list), defaultdict(list)
+    for i in range(len(true_cluster_ids)):
+        true_cluster_dict[true_cluster_ids[i]].append(i)
+        pred_cluster_dict[pred_cluster_ids[i].item()].append(i)
+    return b3_precision_recall_fscore(true_cluster_dict, pred_cluster_dict)
diff --git a/utils/parser.py b/utils/parser.py
index 7771f1b..a5f10f2 100644
--- a/utils/parser.py
+++ b/utils/parser.py
@@ -70,7 +70,7 @@
         )
         parser.add_argument(
             "--wandb_sweep_id", type=str,
-            help="Wandb sweep id (optional -- if run is already started)",
+            help="Attach wandb agents to an existing wandb sweep (expects 'entity/project/sweep_id' as input)",
         )
         parser.add_argument(
             "--wandb_sweep_method", type=str, default="bayes",
@@ -118,7 +118,7 @@
         )
         parser.add_argument(
             "--load_model_from_wandb_run", type=str,
-            help="Load model state_dict from a previous wandb run",
+            help="Load model state_dict from a previous wandb run (expects 'entity/project/runid' as input)",
         )
         parser.add_argument(
             "--load_model_from_fpath", type=str,
@@ -140,3 +140,12 @@
             "--pairwise_mode", action='store_true',
             help="Whether to use the pairwise MLP-only model or the e2e clustering model",
         )
+        parser.add_argument(
+            "--pairwise_eval_clustering", type=str,
+            help="(only in --pairwise_mode) Whether to run clustering during --eval_only_split and final test eval. " +
+                 "Accepts 'cc' for correlation clustering and 'hac' for agglomerative clustering.",
+        )
+        parser.add_argument(
+            "--load_gbdt_model", type=str,
+            help="(only in --pairwise_mode and --eval_only_split) Load a GBDT model to compute the pairwise weights."
+        )
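
Example usage of the new flag (hypothetical invocations; the flags are defined in utils/parser.py,
the model path is a placeholder, and a configured wandb environment is assumed):

    # Train a pairwise model, then run correlation-clustering eval on the test blocks
    python e2e_scripts/train.py --pairwise_mode --pairwise_eval_clustering cc

    # Cluster the test split with a previously saved pairwise model
    python e2e_scripts/train.py --pairwise_mode --eval_only_split test \
        --load_model_from_fpath model_state_dict_best.pt \
        --pairwise_eval_clustering cc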