From c1e13ae21024d441b7a949d9e08baccbb58346c8 Mon Sep 17 00:00:00 2001
From: Dhruv Agarwal
Date: Sat, 18 Mar 2023 04:23:50 -0400
Subject: [PATCH] Inference solver, parallel eval iterations, sweep config
 changes (#41)

---
 e2e_debug/solve.py                        |  41 +--
 e2e_pipeline/sdp_layer.py                 |  41 ++-
 e2e_pipeline/uncompress_layer.py          |   2 +-
 e2e_scripts/evaluate.py                   | 140 ++++++++--
 e2e_scripts/preprocess_s2and_data.py      |   2 +-
 e2e_scripts/train.py                      | 326 ++++++++--------------
 e2e_scripts/train_utils.py                | 183 ++++++++++--
 s2and/data.py                             |   7 +-
 utils/parser.py                           |   5 +-
 wandb_configs/sweeps/e2e-nosdp-warm.json  |   6 +-
 wandb_configs/sweeps/e2e-nosdp.json       |   6 +-
 wandb_configs/sweeps/e2e-warm.json        |   6 +-
 wandb_configs/sweeps/e2e.json             |   6 +-
 wandb_configs/sweeps/frac-nosdp-warm.json |   6 +-
 wandb_configs/sweeps/frac-nosdp.json      |   6 +-
 wandb_configs/sweeps/frac-warm.json       |   6 +-
 wandb_configs/sweeps/frac.json            |   6 +-
 wandb_configs/sweeps/mlp.json             |   1 +
 18 files changed, 487 insertions(+), 309 deletions(-)

diff --git a/e2e_debug/solve.py b/e2e_debug/solve.py
index 6e47ef2..ccefa35 100644
--- a/e2e_debug/solve.py
+++ b/e2e_debug/solve.py
@@ -50,6 +50,9 @@ def __init__(self):
         self.add_argument(
             "--scs_log_csv_filename", type=str,
         )
+        self.add_argument(
+            "--max_scaling", action="store_true",
+        )
         self.add_argument(
             "--interactive", action="store_true",
         )
@@ -63,15 +66,18 @@ def __init__(self):

     # Read error file
     logger.info("Reading input data")
-    with open(args.data_fpath, 'r') as fh:
-        data = json.load(fh)
-    assert len(data['errors']) > 0
-    # Pick specific error instance to process
-    error_data = data['errors'][args.data_idx]
+    if args.data_fpath.endswith('.pt'):
+        _W_val = torch.load(args.data_fpath, map_location='cpu').numpy()
+    else:
+        with open(args.data_fpath, 'r') as fh:
+            data = json.load(fh)
+        assert len(data['errors']) > 0
+        # Pick specific error instance to process
+        error_data = data['errors'][args.data_idx]

-    # Extract input data from the error instance
-    _raw = np.array(error_data['model_call_args']['data'])
-    _W_val = np.array(error_data['cvxpy_layer_args']['W_val'])
+        # Extract input data from the error instance
+        _raw = np.array(error_data['model_call_args']['data'])
+        _W_val = np.array(error_data['cvxpy_layer_args']['W_val'])

     # Construct cvxpy problem
     logger.info('Constructing optimization problem')
@@ -84,7 +90,7 @@
     constraints = [
         cp.diag(X) == np.ones((n,)),
         X[:n, :] >= 0,
-        X[:n, :] <= 1
+        # X[:n, :] <= 1
     ]

     # Setup HAC Cut
@@ -94,12 +100,14 @@
     sdp_obj_value = float('inf')
     result_idxs, results_X, results_clustering = [], [], []
     no_solution_scaling_factors = []
-    for i in range(1, 10):  # n
+    for i in range(0, 10):  # n
         # Skipping 1; no scaling leads to non-convergence (infinite objective value)
-        if i == 1:
-            scaling_factor = np.max(W)
+        if i == 0:
+            scaling_factor = np.max(np.abs(W))
         else:
             scaling_factor = i
+            if args.max_scaling:
+                continue
         logger.info(f'Scaling factor={scaling_factor}')
         # Create problem
         W_scaled = W / scaling_factor
@@ -114,8 +122,7 @@
             alpha=args.scs_alpha,
             scale=args.scs_scale,
             use_indirect=args.scs_use_indirect,
-            use_quad_obj=not args.scs_dont_use_quad_obj,
-            log_csv_filename=args.scs_log_csv_filename
+            # use_quad_obj=not args.scs_dont_use_quad_obj
         )
         logger.info(f"@scaling={scaling_factor}, objective value = {sdp_obj_value}, norm={np.linalg.norm(W_scaled)}")
         if sdp_obj_value != float('inf'):
@@ -129,9 +136,9 @@

     logger.info(f"Solution not found = {len(no_solution_scaling_factors)}")
    logger.info(f"Solution found = {len(results_X)}")
-    logger.info("Same clustering:")
-    for i in range(len(results_clustering)-1):
-        logger.info(np.array_equal(results_clustering[i], results_clustering[i + 1]))
+    # logger.info("Same clustering:")
+    # for i in range(len(results_clustering)-1):
+    #     logger.info(np.array_equal(results_clustering[i], results_clustering[i + 1]))
     # logger.info(f"Solution found with scaling factor = {scaling_factor}")
     # if args.interactive and sdp_obj_value == float('inf'):
     #     embed()
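The loop above probes whether SCS converges when the weight matrix is rescaled before solving; the new --max_scaling flag restricts the sweep to the max|W| normalization alone. A minimal standalone sketch of that rescaling idea (toy data, not the repo's saved error instances):

    import numpy as np

    W = np.triu(np.random.randn(5, 5), k=1)   # toy upper-triangular weight matrix
    scaling_factor = np.max(np.abs(W))        # the i == 0 branch in the loop above
    W_scaled = W / scaling_factor             # all entries now lie in [-1, 1]
    # The SDP objective tr(W_scaled @ X) is 1/scaling_factor times the original,
    # so the optimal clustering is unchanged while SCS conditioning improves.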
diff --git a/e2e_pipeline/sdp_layer.py b/e2e_pipeline/sdp_layer.py
index bdc10c0..be914b8 100644
--- a/e2e_pipeline/sdp_layer.py
+++ b/e2e_pipeline/sdp_layer.py
@@ -50,22 +50,33 @@ def build_and_solve_sdp(self, W_val, N, verbose=False):
             X[:N, :] >= 0,
         ]

-        # create problem
-        prob = cp.Problem(cp.Maximize(cp.trace(W @ X)), constraints)
-        # Note: maximizing the trace is equivalent to maximizing the sum_E (w_uv * X_uv) objective
-        # because W is upper-triangular and X is symmetric
-
-        # Build the SDP cvxpylayer
-        cvxpy_layer = CvxpyLayer(prob, parameters=[W], variables=[X])
-
-        # Forward pass through the SDP cvxpylayer
         try:
-            pw_prob_matrix = cvxpy_layer(W_val, solver_args={
-                "solve_method": "SCS",
-                "verbose": verbose,
-                "max_iters": self.max_iters,
-                "eps": self.eps
-            })[0]
+            if self.training:
+                # create problem
+                prob = cp.Problem(cp.Maximize(cp.trace(W @ X)), constraints)
+                # Note: maximizing the trace is equivalent to maximizing the sum_E (w_uv * X_uv) objective
+                # because W is upper-triangular and X is symmetric
+                # Build the SDP cvxpylayer
+                cvxpy_layer = CvxpyLayer(prob, parameters=[W], variables=[X])
+                # Forward pass through the SDP cvxpylayer
+                pw_prob_matrix = cvxpy_layer(W_val, solver_args={
+                    "solve_method": "SCS",
+                    "verbose": verbose,
+                    "max_iters": self.max_iters,
+                    "eps": self.eps
+                })[0]
+            else:
+                # create problem
+                prob = cp.Problem(cp.Maximize(cp.trace(W_val.cpu().numpy() @ X)), constraints)
+                _solve_val = prob.solve(
+                    solver=cp.SCS,
+                    verbose=verbose,
+                    max_iters=self.max_iters,
+                    eps=self.eps
+                )
+                if _solve_val == float('inf'):
+                    raise ValueError()
+                pw_prob_matrix = torch.tensor(X.value, device=W_val.device)
             # Fix to prevent invalid solution values close to 0 and 1 but outside the range
             pw_prob_matrix = torch.clamp(pw_prob_matrix, min=0, max=1)
         except:
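The core of this change: the differentiable cvxpylayers path is now reserved for training, while eval-mode calls solve the same SDP directly with cvxpy/SCS and raise on an unbounded (inf) objective so the existing exception handling can catch it. A self-contained sketch of the same pattern on a toy problem (class and constraint shapes here are illustrative, not the repo's SDPLayer):

    import cvxpy as cp
    import torch
    from cvxpylayers.torch import CvxpyLayer

    class TinySDPLayer(torch.nn.Module):
        # Same objective and constraints on both paths, but only the training
        # path builds a differentiable layer with gradient bookkeeping.
        def forward(self, W_val):
            n = W_val.shape[0]
            X = cp.Variable((n, n), PSD=True)
            constraints = [cp.diag(X) == 1, X >= 0]
            if self.training:
                W = cp.Parameter((n, n))
                prob = cp.Problem(cp.Maximize(cp.trace(W @ X)), constraints)
                return CvxpyLayer(prob, parameters=[W], variables=[X])(W_val)[0]
            # Inference: call SCS directly; no autograd graph is constructed.
            prob = cp.Problem(cp.Maximize(cp.trace(W_val.cpu().numpy() @ X)), constraints)
            prob.solve(solver=cp.SCS)
            return torch.tensor(X.value, device=W_val.device)

Toggling model.train() / model.eval() selects the branch, which is what lets evaluation run the solver without paying for the differentiable layer.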
diff --git a/e2e_pipeline/uncompress_layer.py b/e2e_pipeline/uncompress_layer.py
index 93d99dd..7198b07 100644
--- a/e2e_pipeline/uncompress_layer.py
+++ b/e2e_pipeline/uncompress_layer.py
@@ -6,7 +6,7 @@ def __init__(self):
         super().__init__()

     def forward(self, compressed_matrix, N, make_symmetric=False, ones_diagonal=False):
-        device = compressed_matrix.get_device()
+        device = compressed_matrix.device
         triu_indices = torch.triu_indices(N, N, offset=1, device=device)
         if make_symmetric:
             sym_indices = torch.stack((torch.cat((triu_indices[0], triu_indices[1])),
diff --git a/e2e_scripts/evaluate.py b/e2e_scripts/evaluate.py
index b7ba424..71ec3d6 100644
--- a/e2e_scripts/evaluate.py
+++ b/e2e_scripts/evaluate.py
@@ -8,13 +8,14 @@
 from sklearn.metrics.cluster import v_measure_score
 from sklearn.metrics import roc_curve, auc
 from sklearn.metrics import precision_recall_fscore_support
+from torch.multiprocessing import Process, Manager
 import numpy as np
 import torch

 from e2e_pipeline.cc_inference import CCInference
 from e2e_pipeline.hac_inference import HACInference
 from e2e_pipeline.sdp_layer import CvxpyException
-from e2e_scripts.train_utils import compute_b3_f1, save_to_wandb_run
+from e2e_scripts.train_utils import compute_b3_f1, save_to_wandb_run, copy_and_load_model

 from IPython import embed

@@ -24,13 +25,50 @@
 logger = logging.getLogger(__name__)


+def _run_iter(model_class, state_dict_path, _fork_id, _shared_list, eval_fn, **kwargs):
+    model = model_class(*kwargs['model_args'])
+    model.load_state_dict(torch.load(state_dict_path))
+    model.to('cpu')
+    model.eval()
+    with torch.no_grad():
+        res = eval_fn(model=model, **kwargs)
+    _shared_list.append(res)
+    del model
+
+
+def _fork_iter(batch_idx, _fork_id, _shared_list, eval_fn, **kwargs):
+    kwargs['model_class'] = kwargs['model'].__class__
+    kwargs['state_dict_path'] = copy_and_load_model(kwargs['model'], kwargs['run_dir'], 'cpu', store_only=True)
+    del kwargs['model']
+    kwargs['overfit_batch_idx'] = batch_idx
+    kwargs['tqdm_label'] = f'{kwargs["tqdm_label"]} (fork{_fork_id})'
+    kwargs['_fork_id'] = _fork_id
+    kwargs['tqdm_position'] = (0 if kwargs['tqdm_position'] is None else kwargs['tqdm_position']) + _fork_id + 1
+    kwargs['return_iter'] = True
+    kwargs['fork_size'] = -1
+    kwargs['_shared_list'] = _shared_list
+    kwargs['disable_tqdm'] = True
+    kwargs['device'] = 'cpu'
+    kwargs['eval_fn'] = eval_fn
+    _proc = Process(target=_run_iter, kwargs=kwargs)
+    _proc.start()
+    return _proc
+
+
 def evaluate(model, dataloader, overfit_batch_idx=-1, clustering_fn=None, clustering_threshold=None,
              val_dataloader=None, tqdm_label='', device=None, verbose=False, debug=False, _errors=None,
-             run_dir='./', tqdm_position=None):
+             run_dir='./', tqdm_position=None, model_args=None, return_iter=False, fork_size=500,
+             disable_tqdm=False):
     """
     clustering_fn, clustering_threshold, val_dataloader: unused when pairwise_mode is False
     (only added to keep fn signature identical)
     """
+    fn_args = locals()
+    fork_enabled = fork_size > -1 and model_args is not None
+    if fork_enabled:
+        _fork_id = 1
+        _shared_list = Manager().list()
+        _procs = []
     device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
     n_features = dataloader.dataset[0][0].shape[1]
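The two helpers above implement a detach-to-subprocess pattern: serialize the model to disk, reload it on CPU in a child process, run a single block as an overfit-style iteration, and push the result into a Manager().list(). The skeleton of that fork-and-join flow, stripped of the model specifics (toy worker, hypothetical names):

    from torch.multiprocessing import Manager, Process, set_start_method

    def _worker(batch_idx, shared):
        # Child process: one expensive eval iteration, result pushed to the parent.
        shared.append({'block_idx': batch_idx, 'result': batch_idx ** 2})

    if __name__ == '__main__':
        set_start_method('spawn', force=True)  # same start method train.py forces
        shared = Manager().list()              # proxy list usable across processes
        procs = [Process(target=_worker, args=(i, shared)) for i in range(3)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()                           # parent blocks here, like the join loop
        assert len(shared) == len(procs)       # mirrors the patch's sanity assert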
@@ -43,7 +81,8 @@ def evaluate(model, dataloader, overfit_batch_idx=-1, clustering_fn=None, cluste
     }
     max_pred_id = -1
     n_exceptions = 0
-    for (idx, batch) in enumerate(tqdm(dataloader, desc=f'Evaluating {tqdm_label}', position=tqdm_position)):
+    pbar = tqdm(dataloader, desc=f'Eval {tqdm_label}', position=tqdm_position, disable=disable_tqdm)
+    for (idx, batch) in enumerate(pbar):
         if overfit_batch_idx > -1:
             if idx < overfit_batch_idx:
                 continue
@@ -51,11 +90,16 @@
             break
         data, _, cluster_ids = batch
         block_size = len(cluster_ids)
-        all_gold += list(np.reshape(cluster_ids, (block_size,)))
+        pbar.set_description(f'Eval {tqdm_label} (sz={block_size})')
         data = data.reshape(-1, n_features).float()
         if data.shape[0] == 0:
             # Only one signature in block; manually assign a unique cluster
             pred_cluster_ids = [max_pred_id + 1]
+        elif fork_enabled and block_size >= fork_size:
+            _proc = _fork_iter(idx, _fork_id, _shared_list, evaluate, **fn_args)
+            _fork_id += 1
+            _procs.append((_proc, block_size))
+            continue
         else:
             # Forward pass through the e2e model
             data = data.to(device)
@@ -79,8 +123,6 @@
                     save_to_wandb_run({'errors': _errors}, 'errors.json', run_dir, logger)
                 if not debug:  # if tqdm_label is not 'dev' and not debug:
                     raise CvxpyException(data=_error_obj)
-                # If split is dev, skip batch and continue
-                all_gold = all_gold[:-len(cluster_ids)]
                 n_exceptions += 1
                 logger.info(f'Caught CvxpyException {n_exceptions}: skipping batch')
                 continue
@@ -89,8 +131,34 @@
         cc_obj_vals['sdp'].append(model.sdp_layer.objective_value)
         cc_obj_vals['block_idxs'].append(idx)
         cc_obj_vals['block_sizes'].append(block_size)
+        all_gold += list(np.reshape(cluster_ids, (block_size,)))
         max_pred_id = max(pred_cluster_ids)
         all_pred += list(pred_cluster_ids)
+        if overfit_batch_idx > -1 and return_iter:
+            return {
+                'cluster_labels': model.hac_cut_layer.cluster_labels,
+                'round_objective_value': model.hac_cut_layer.objective_value,
+                'sdp_objective_value': model.sdp_layer.objective_value,
+                'block_idx': idx,
+                'block_size': block_size,
+                'cluster_ids': cluster_ids
+            }
+
+    if fork_enabled and len(_procs) > 0:
+        _procs.sort(key=lambda x: x[1])  # To visualize progress
+        for _proc in tqdm(_procs, desc=f'Eval {tqdm_label} (waiting for forks to join)', position=tqdm_position):
+            _proc[0].join()
+        assert len(_procs) == len(_shared_list), "All forked eval iterations did not return results"
+        for _data in _shared_list:
+            pred_cluster_ids = (_data['cluster_labels'] + (max_pred_id + 1)).tolist()
+            cc_obj_vals['round'].append(_data['round_objective_value'])
+            cc_obj_vals['sdp'].append(_data['sdp_objective_value'])
+            cc_obj_vals['block_idxs'].append(_data['block_idx'])
+            cc_obj_vals['block_sizes'].append(_data['block_size'])
+            all_gold += list(np.reshape(_data['cluster_ids'], (_data['block_size'],)))
+            max_pred_id = max(pred_cluster_ids)
+            all_pred += list(pred_cluster_ids)
+
     vmeasure = v_measure_score(all_gold, all_pred)
     b3_f1 = compute_b3_f1(all_gold, all_pred)[2]
     return b3_f1, vmeasure, cc_obj_vals
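Forked blocks come back with locally numbered cluster labels, so the merge loop above shifts each result by max_pred_id + 1 before extending all_pred; that keeps predicted IDs globally unique across blocks. The same re-labeling in miniature, assuming numpy label arrays:

    import numpy as np

    # Hypothetical blocks with locally numbered cluster labels
    blocks = [np.array([0, 0, 1]), np.array([0, 1, 1, 2])]
    all_pred, max_pred_id = [], -1
    for labels in blocks:
        shifted = (labels + (max_pred_id + 1)).tolist()  # offset past prior blocks
        max_pred_id = max(shifted)
        all_pred += shifted
    print(all_pred)  # [0, 0, 1, 2, 3, 3, 4]: no ID collisions across blocks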
@@ -99,7 +167,13 @@
 def evaluate_pairwise(model, dataloader, overfit_batch_idx=-1, mode="macro", return_pred_only=False,
                       thresh_for_f1=0.5, clustering_fn=None, clustering_threshold=None, val_dataloader=None,
                       tqdm_label='', device=None, verbose=False, debug=False, _errors=None, run_dir='./',
-                      tqdm_position=None):
+                      tqdm_position=None, model_args=None, return_iter=False, fork_size=500, disable_tqdm=False):
+    fn_args = locals()
+    fork_enabled = fork_size > -1 and model_args is not None
+    if fork_enabled:
+        _fork_id = 1
+        _shared_list = Manager().list()
+        _procs = []
     device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
     n_features = dataloader.dataset[0][0].shape[1]
@@ -117,7 +191,8 @@
     }
     max_pred_id = -1  # In each iteration, add to all blockwise predicted IDs to distinguish from previous blocks
     n_exceptions = 0
-    for (idx, batch) in enumerate(tqdm(dataloader, desc=f'Evaluating {tqdm_label}', position=tqdm_position)):
+    pbar = tqdm(dataloader, desc=f'Eval {tqdm_label}', position=tqdm_position, disable=disable_tqdm)
+    for (idx, batch) in enumerate(pbar):
         if overfit_batch_idx > -1:
             if idx < overfit_batch_idx:
                 continue
@@ -125,11 +200,16 @@
             break
         data, _, cluster_ids = batch
         block_size = len(cluster_ids)
-        all_gold += list(np.reshape(cluster_ids, (block_size,)))
+        pbar.set_description(f'Eval {tqdm_label} (sz={block_size})')
         data = data.reshape(-1, n_features).float()
         if data.shape[0] == 0:
             # Only one signature in block; manually assign a unique cluster
             pred_cluster_ids = [max_pred_id + 1]
+        elif fork_enabled and block_size >= fork_size and clustering_fn.__class__ is CCInference:
+            _proc = _fork_iter(idx, _fork_id, _shared_list, evaluate_pairwise, **fn_args)
+            _fork_id += 1
+            _procs.append((_proc, block_size))
+            continue
         else:
             # Forward pass through the e2e model
             data = data.to(device)
@@ -155,24 +235,49 @@
                     save_to_wandb_run({'errors': _errors}, 'errors.json', run_dir, logger)
                 if not debug:  # if tqdm_label is not 'dev' and not debug:
                     raise CvxpyException(data=_error_obj)
-                # If split is dev, skip batch and continue
-                all_gold = all_gold[:-len(cluster_ids)]
                 n_exceptions += 1
                 logger.info(f'Caught CvxpyException {n_exceptions}: skipping batch')
                 continue
+        if clustering_fn.__class__ is CCInference:
+            cc_obj_vals['round'].append(clustering_fn.hac_cut_layer.objective_value)
+            cc_obj_vals['sdp'].append(clustering_fn.sdp_layer.objective_value)
+            cc_obj_vals['block_idxs'].append(idx)
+            cc_obj_vals['block_sizes'].append(block_size)
+        all_gold += list(np.reshape(cluster_ids, (block_size,)))
         max_pred_id = max(pred_cluster_ids)
         all_pred += list(pred_cluster_ids)
-        if clustering_fn.__class__ is CCInference:
-            cc_obj_vals['round'].append(clustering_fn.hac_cut_layer.objective_value)
-            cc_obj_vals['sdp'].append(clustering_fn.sdp_layer.objective_value)
-            cc_obj_vals['block_idxs'].append(idx)
-            cc_obj_vals['block_sizes'].append(block_size)
+        if overfit_batch_idx > -1 and return_iter:
+            return {
+                'cluster_labels': list(np.array(pred_cluster_ids) - (max_pred_id + 1)),
+                'round_objective_value': clustering_fn.hac_cut_layer.objective_value,
+                'sdp_objective_value': clustering_fn.sdp_layer.objective_value,
+                'block_idx': idx,
+                'block_size': block_size,
+                'cluster_ids': cluster_ids
+            }
+
+    if fork_enabled and len(_procs) > 0:
+        _procs.sort(key=lambda x: x[1])  # To visualize progress
+        for _proc in tqdm(_procs, desc=f'Eval {tqdm_label} (waiting for forks to join)', position=tqdm_position):
+            _proc[0].join()
+        assert len(_procs) == len(_shared_list), "All forked eval iterations did not return results"
+        for _data in _shared_list:
+            pred_cluster_ids = (_data['cluster_labels'] + (max_pred_id + 1)).tolist()
+            cc_obj_vals['round'].append(_data['round_objective_value'])
+            cc_obj_vals['sdp'].append(_data['sdp_objective_value'])
+            cc_obj_vals['block_idxs'].append(_data['block_idx'])
+            cc_obj_vals['block_sizes'].append(_data['block_size'])
+            all_gold += list(np.reshape(_data['cluster_ids'], (_data['block_size'],)))
+            max_pred_id = max(pred_cluster_ids)
+            all_pred += list(pred_cluster_ids)
+
     vmeasure = v_measure_score(all_gold, all_pred)
     b3_f1 = compute_b3_f1(all_gold, all_pred)[2]
     return (b3_f1, vmeasure, cc_obj_vals) if clustering_fn.__class__ is CCInference else (b3_f1, vmeasure)

     y_pred, targets = [], []
-    for (idx, batch) in enumerate(tqdm(dataloader, desc=f'Evaluating {tqdm_label}', position=tqdm_position)):
+    pbar = tqdm(dataloader, desc=f'Eval {tqdm_label}', position=tqdm_position, disable=disable_tqdm)
+    for (idx, batch) in enumerate(pbar):
         if overfit_batch_idx > -1:
             if idx < overfit_batch_idx:
                 continue
@@ -180,6 +285,7 @@
             break
         data, target = batch
         data = data.reshape(-1, n_features).float()
+        pbar.set_description(f'Eval {tqdm_label} (sz={len(data)})')
         assert data.shape[0] != 0
         target = target.flatten().float()
         # Forward pass through the pairwise model
diff --git a/e2e_scripts/preprocess_s2and_data.py b/e2e_scripts/preprocess_s2and_data.py
index cb46521..d97283b 100644
--- a/e2e_scripts/preprocess_s2and_data.py
+++ b/e2e_scripts/preprocess_s2and_data.py
@@ -118,7 +118,7 @@ def find_total_num_train_pairs(blockwise_data):
     DATA_HOME_DIR = params["data_home_dir"]
     dataset = params["dataset_name"]

-    random_seeds = {1, 2, 3, 4, 5}
+    random_seeds = [1, 2, 3, 4, 5] if params["dataset_seed"] is None else [params["dataset_seed"]]
     for seed in random_seeds:
         print("Preprocessing started for seed value", seed)
         save_blockwise_featurized_data(dataset, seed)
diff --git a/e2e_scripts/train.py b/e2e_scripts/train.py
index 23b36d8..a645257 100644
--- a/e2e_scripts/train.py
+++ b/e2e_scripts/train.py
@@ -1,7 +1,6 @@
 import glob
 import json
 import os
-import sys
 import time
 import logging
 import random
@@ -10,8 +9,8 @@
 import wandb
 import torch
 import numpy as np
-
 from tqdm import tqdm
+from torch.multiprocessing import set_start_method, Manager

 from e2e_pipeline.cc_inference import CCInference
 from e2e_pipeline.hac_inference import HACInference
@@ -21,119 +20,21 @@
 from e2e_scripts.evaluate import evaluate, evaluate_pairwise
 from e2e_scripts.train_utils import DEFAULT_HYPERPARAMS, get_dataloaders, get_matrix_size_from_triu, \
     uncompress_target_tensor, count_parameters, log_cc_objective_values, save_to_wandb_run, FrobeniusLoss, \
-    copy_and_load_model
+    get_feature_count, _check_process, fork_eval, init_eval, dev_eval
 from utils.parser import Parser
-from torch.multiprocessing import Process, set_start_method, Manager
+from IPython import embed

 try:
     set_start_method('spawn', force=True)
 except RuntimeError:
     pass

-from IPython import embed
-
-
 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
                     level=logging.INFO)
 logger = logging.getLogger(__name__)


-def _check_process(_proc, _return_dict, logger, run, overfit_batch_idx, use_lr_scheduler, hyp,
-                   scheduler, eval_metric_to_idx, dev_opt_metric, i, best_epoch, best_dev_score,
-                   best_dev_scores, best_dev_state_dict, sync=False):
-    if _proc is not None:
-        if _return_dict['_state'] == 'done' or (sync and _return_dict['_state'] == 'start'):
-            _proc.join()
-            _return_dict['_state'] = 'finish'
-            if _return_dict['_method'] == 'init_eval':
-                logger.info(_return_dict['local'])
-                run.log(_return_dict['wandb'])
-            elif _return_dict['_method'] == 'dev_eval':
-                logger.info(_return_dict['local'])
-                run.log(_return_dict['wandb'])
-                if overfit_batch_idx > -1:
-                    if use_lr_scheduler:
-                        if hyp['lr_scheduler'] == 'plateau':
-                            scheduler.step(_return_dict['train_scores'][eval_metric_to_idx[dev_opt_metric]])
-                        elif hyp['lr_scheduler'] == 'step':
-                            scheduler.step()
-                else:
-                    dev_scores = _return_dict['dev_scores']
-                    dev_opt_score = dev_scores[eval_metric_to_idx[dev_opt_metric]]
-                    if dev_opt_score > best_dev_score:
-                        logger.info(f"New best dev {dev_opt_metric} score @ epoch{i+1}: {dev_opt_score}")
-                        best_epoch = i
-                        best_dev_score = dev_opt_score
-                        best_dev_scores = dev_scores
-                        best_dev_state_dict = torch.load(_return_dict['state_dict_path'], device)
-                    if use_lr_scheduler:
-                        if hyp['lr_scheduler'] == 'plateau':
-                            scheduler.step(dev_scores[eval_metric_to_idx[dev_opt_metric]])
-                        elif hyp['lr_scheduler'] == 'step':
-                            scheduler.step()
-    return best_epoch, best_dev_score, best_dev_scores, best_dev_state_dict
-
-
-def init_eval(model_class, model_args, state_dict_path, overfit_batch_idx, eval_fn, train_dataloader, device, verbose,
-              debug, _errors, eval_metric_to_idx, val_dataloader, return_dict):
-    return_dict['_state'] = 'start'
-    return_dict['_method'] = 'init_eval'
-    model = model_class(*model_args)
-    model.load_state_dict(torch.load(state_dict_path))
-    model.to(device)
-    with torch.no_grad():
-        model.eval()
-        if overfit_batch_idx > -1:
-            train_scores = eval_fn(model, train_dataloader, overfit_batch_idx=overfit_batch_idx,
-                                   tqdm_label='train', device=device, verbose=verbose, debug=debug,
-                                   _errors=_errors, tqdm_position=0)
-            return_dict['local'] = f"Initial: train_{list(eval_metric_to_idx)[0]}={train_scores[0]}, " + \
-                                   f"train_{list(eval_metric_to_idx)[1]}={train_scores[1]}"
-            return_dict['wandb'] = {'epoch': 0, f'train_{list(eval_metric_to_idx)[0]}': train_scores[0],
-                                    f'train_{list(eval_metric_to_idx)[1]}': train_scores[1]}
-        else:
-            dev_scores = eval_fn(model, val_dataloader, tqdm_label='dev 0', device=device, verbose=verbose,
-                                 debug=debug, _errors=_errors, tqdm_position=0)
-            return_dict['local'] = f"Initial: dev_{list(eval_metric_to_idx)[0]}={dev_scores[0]}, " + \
-                                   f"dev_{list(eval_metric_to_idx)[1]}={dev_scores[1]}"
-            return_dict['wandb'] = {'epoch': 0, f'dev_{list(eval_metric_to_idx)[0]}': dev_scores[0],
-                                    f'dev_{list(eval_metric_to_idx)[1]}': dev_scores[1]}
-    del model
-    return_dict['_state'] = 'done'
-
-
-def dev_eval(model_class, model_args, state_dict_path, overfit_batch_idx, eval_fn, train_dataloader, device, verbose,
-             debug, _errors, eval_metric_to_idx, val_dataloader, return_dict, i, run_dir):
-    return_dict['_state'] = 'start'
-    return_dict['_method'] = 'dev_eval'
-    return_dict['state_dict_path'] = state_dict_path
-    model = model_class(*model_args)
-    model.load_state_dict(torch.load(state_dict_path))
-    model.to(device)
-    with torch.no_grad():
-        model.eval()
-        if overfit_batch_idx > -1:
-            train_scores = eval_fn(model, train_dataloader, overfit_batch_idx=overfit_batch_idx,
-                                   tqdm_label='train', device=device, verbose=verbose, debug=debug,
-                                   _errors=_errors)
-            return_dict['local'] = f"Epoch {i + 1}: train_{list(eval_metric_to_idx)[0]}={train_scores[0]}, " + \
-                                   f"train_{list(eval_metric_to_idx)[1]}={train_scores[1]}"
-            return_dict['wandb'] = {f'train_{list(eval_metric_to_idx)[0]}': train_scores[0],
-                                    f'train_{list(eval_metric_to_idx)[1]}': train_scores[1]}
-            return_dict['train_scores'] = train_scores
-        else:
-            dev_scores = eval_fn(model, val_dataloader, tqdm_label=f'dev {i+1}', device=device, verbose=verbose,
-                                 debug=debug, _errors=_errors)
-            return_dict['local'] = f"Epoch {i + 1}: dev_{list(eval_metric_to_idx)[0]}={dev_scores[0]}, " + \
-                                   f"dev_{list(eval_metric_to_idx)[1]}={dev_scores[1]}"
-            return_dict['wandb'] = {f'dev_{list(eval_metric_to_idx)[0]}': dev_scores[0],
-                                    f'dev_{list(eval_metric_to_idx)[1]}': dev_scores[1]}
-            return_dict['dev_scores'] = dev_scores
-    del model
-    return_dict['_state'] = 'done'
-
-
 def train(hyperparams={}, verbose=False, project=None, entity=None, tags=None, group=None,
           save_model=False, load_model_from_wandb_run=None, load_model_from_fpath=None,
           eval_only_split=None, eval_all=False, skip_initial_eval=False, pairwise_eval_clustering=None,
@@ -206,14 +107,20 @@ def train(hyperparams={}, verbose=False, project=None, entity=None, tags=None, g
     eval_metric_to_idx = clustering_metrics if not pairwise_mode else pairwise_metrics
     dev_opt_metric = hyp['dev_opt_metric'] if hyp['dev_opt_metric'] in eval_metric_to_idx \
         else list(eval_metric_to_idx)[0]
+    training_mode = not eval_all and eval_only_split is None

     # Get data loaders (optionally with imputation, normalization)
-    train_dataloader, val_dataloader, test_dataloader = get_dataloaders(hyp["dataset"], hyp["dataset_random_seed"],
-                                                                        hyp["convert_nan"], hyp["nan_value"],
-                                                                        hyp["normalize_data"], hyp["subsample_sz_train"],
-                                                                        hyp["subsample_sz_dev"], pairwise_mode,
-                                                                        batch_size)
-    n_features = train_dataloader.dataset[0][0].shape[1]
+    if training_mode:
+        train_dataloader, val_dataloader, test_dataloader = get_dataloaders(hyp["dataset"],
                                                                             hyp["dataset_random_seed"],
+                                                                            hyp["convert_nan"], hyp["nan_value"],
+                                                                            hyp["normalize_data"],
+                                                                            hyp["subsample_sz_train"],
+                                                                            hyp["subsample_sz_dev"], pairwise_mode,
+                                                                            batch_size)
+        n_features = train_dataloader.dataset[0][0].shape[1]
+    else:
+        n_features = get_feature_count(hyp["dataset"], hyp["dataset_random_seed"])

     # Create model with hyperparams
     if not pairwise_mode:
@@ -222,53 +129,44 @@
                       negative_slope, hidden_config, sdp_max_iters, sdp_eps, sdp_scale,
                       use_rounded_loss, (e2e_loss == "bce"), use_sdp)
         model = EntResModel(*model_args)
-        # Define loss
-        if e2e_loss not in ["frob", "bce"]:
-            raise ValueError("Invalid value for e2e_loss")
-        loss_fn_e2e = FrobeniusLoss() if e2e_loss == 'frob' else torch.nn.BCELoss()
-
-        pos_weight = None
-        if weighted_loss:
-            if overfit_batch_idx > -1:
-                n_pos = train_dataloader.dataset[overfit_batch_idx][1].sum()
-                pos_weight = (len(train_dataloader.dataset[overfit_batch_idx][1]) - n_pos) / n_pos
-            else:
-                _n_pos, _n_total = 0., 0.
-                for _i in range(len(train_dataloader.dataset)):
-                    _n_pos += train_dataloader.dataset[_i][1].sum()
-                    _n_total += len(train_dataloader.dataset[_i][1])
-                pos_weight = (_n_total - _n_pos) / _n_pos
         # Define eval
         eval_fn = evaluate
         pairwise_clustering_fns = [None]  # Unused when pairwise_mode is False
-        if n_warmstart_epochs > 0:
-            train_dataloader_pairwise, _, _ = get_dataloaders(hyp["dataset"],
-                                                              hyp["dataset_random_seed"],
-                                                              hyp["convert_nan"],
-                                                              hyp["nan_value"],
-                                                              hyp["normalize_data"],
-                                                              hyp["subsample_sz_train"],
-                                                              hyp["subsample_sz_dev"],
-                                                              True, hyp['batch_size'])
+
+        if training_mode:  # => model will be used for training
             # Define loss
-            loss_fn_pairwise = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
+            if e2e_loss not in ["frob", "bce"]:
+                raise ValueError("Invalid value for e2e_loss")
+            loss_fn_e2e = FrobeniusLoss() if e2e_loss == 'frob' else torch.nn.BCELoss()
+
+            pos_weight = None
+            if weighted_loss:
+                if overfit_batch_idx > -1:
+                    n_pos = train_dataloader.dataset[overfit_batch_idx][1].sum()
+                    pos_weight = (len(train_dataloader.dataset[overfit_batch_idx][1]) - n_pos) / n_pos
+                else:
+                    _n_pos, _n_total = 0., 0.
+                    for _i in range(len(train_dataloader.dataset)):
+                        _n_pos += train_dataloader.dataset[_i][1].sum()
+                        _n_total += len(train_dataloader.dataset[_i][1])
+                    pos_weight = (_n_total - _n_pos) / _n_pos if _n_pos > 0 else 1.
+            if n_warmstart_epochs > 0:
+                train_dataloader_pairwise = get_dataloaders(hyp["dataset"],
+                                                            hyp["dataset_random_seed"],
+                                                            hyp["convert_nan"],
+                                                            hyp["nan_value"],
+                                                            hyp["normalize_data"],
+                                                            hyp["subsample_sz_train"],
+                                                            hyp["subsample_sz_dev"],
+                                                            pairwise_mode=True, batch_size=hyp['batch_size'],
+                                                            split='train')
+                # Define loss
+                loss_fn_pairwise = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
     else:
         model_args = (n_features, neumiss_depth, dropout_p, dropout_only_once, add_neumiss,
-                        neumiss_deq, hidden_dim, n_hidden_layers, add_batchnorm, activation,
-                        negative_slope, hidden_config)
+                      neumiss_deq, hidden_dim, n_hidden_layers, add_batchnorm, activation,
+                      negative_slope, hidden_config)
         model = PairwiseModel(*model_args)
-        # Define loss
-        pos_weight = None
-        if weighted_loss:
-            if overfit_batch_idx > -1:
-                n_pos = \
-                    train_dataloader.dataset[overfit_batch_idx * batch_size:(overfit_batch_idx + 1) * batch_size][
-                        1].sum()
-                pos_weight = torch.tensor((batch_size - n_pos) / n_pos)
-            else:
-                n_pos = train_dataloader.dataset[:][1].sum()
-                pos_weight = torch.tensor((len(train_dataloader.dataset) - n_pos) / n_pos)
-        loss_fn_pairwise = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
         # Define eval
         eval_fn = evaluate_pairwise
         pairwise_clustering_fns = [None]
@@ -287,14 +185,28 @@
                 pairwise_clustering_fn_labels = ['cc', 'hac', 'cc-fixed']
             else:
                 raise ValueError('Invalid argument passed to --pairwise_eval_clustering')
-            _, val_dataloader_e2e, test_dataloader_e2e = get_dataloaders(hyp["dataset"],
-                                                                         hyp["dataset_random_seed"],
-                                                                         hyp["convert_nan"],
-                                                                         hyp["nan_value"],
-                                                                         hyp["normalize_data"],
-                                                                         hyp["subsample_sz_train"],
-                                                                         hyp["subsample_sz_dev"],
-                                                                         pairwise_mode=False, batch_size=1)
+            val_dataloader_e2e, test_dataloader_e2e = get_dataloaders(hyp["dataset"],
+                                                                      hyp["dataset_random_seed"],
+                                                                      hyp["convert_nan"],
+                                                                      hyp["nan_value"],
+                                                                      hyp["normalize_data"],
+                                                                      hyp["subsample_sz_train"],
+                                                                      hyp["subsample_sz_dev"],
+                                                                      pairwise_mode=False, batch_size=1,
+                                                                      split=['dev', 'test'])
+        if training_mode:  # => model will be used for training
+            # Define loss
+            pos_weight = None
+            if weighted_loss:
+                if overfit_batch_idx > -1:
+                    n_pos = \
+                        train_dataloader.dataset[overfit_batch_idx * batch_size:(overfit_batch_idx + 1) * batch_size][
+                            1].sum()
+                    pos_weight = torch.tensor((batch_size - n_pos) / n_pos if n_pos > 0 else 1.)
+                else:
+                    n_pos = train_dataloader.dataset[:][1].sum()
+                    pos_weight = torch.tensor((len(train_dataloader.dataset) - n_pos) / n_pos if n_pos > 0 else 1.)
+            loss_fn_pairwise = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

     logger.info(f"Model loaded: {model}", )

     # Load stored model, if available
@@ -323,14 +235,15 @@
                                       'cc-nosdp', 'cc-nosdp-fixed']
             cc_inference_sdp.eval()
             cc_inference_nosdp.eval()
-            _, val_dataloader_e2e, test_dataloader_e2e = get_dataloaders(hyp["dataset"],
-                                                                         hyp["dataset_random_seed"],
-                                                                         hyp["convert_nan"],
-                                                                         hyp["nan_value"],
-                                                                         hyp["normalize_data"],
-                                                                         hyp["subsample_sz_train"],
-                                                                         hyp["subsample_sz_dev"],
-                                                                         pairwise_mode=False, batch_size=1)
+            val_dataloader_e2e, test_dataloader_e2e = get_dataloaders(hyp["dataset"],
+                                                                      hyp["dataset_random_seed"],
+                                                                      hyp["convert_nan"],
+                                                                      hyp["nan_value"],
+                                                                      hyp["normalize_data"],
+                                                                      hyp["subsample_sz_train"],
+                                                                      hyp["subsample_sz_dev"],
+                                                                      pairwise_mode=False, batch_size=1,
+                                                                      split=['dev', 'test'])
             start_time = time.time()
             with torch.no_grad():
                 model.eval()
@@ -342,7 +255,7 @@
                                        clustering_threshold=clustering_threshold if i % 2 == 0 else None,
                                        val_dataloader=val_dataloader_e2e,
                                        tqdm_label='test clustering', device=device, verbose=verbose,
-                                       debug=debug, _errors=_errors)
+                                       debug=debug, _errors=_errors, model_args=model_args)
                 if inference_fn.__class__ is HACInference:
                     clustering_threshold = inference_fn.cut_threshold
                 logger.info(
@@ -360,18 +273,16 @@
             end_time = time.time()
     elif eval_only_split is not None:
         # Run inference on the specified split and exit
-        dataloaders = {
-            'train': train_dataloader,
-            'dev': val_dataloader,
-            'test': test_dataloader
-        }
         start_time = time.time()
         with torch.no_grad():
             model.eval()
-
-            eval_dataloader = dataloaders[eval_only_split]
+            eval_dataloader = get_dataloaders(hyp["dataset"], hyp["dataset_random_seed"],
+                                              hyp["convert_nan"], hyp["nan_value"],
+                                              hyp["normalize_data"], hyp["subsample_sz_train"],
+                                              hyp["subsample_sz_dev"], pairwise_mode,
+                                              batch_size, split=eval_only_split)
             eval_scores = eval_fn(model, eval_dataloader, tqdm_label=eval_only_split, device=device, verbose=verbose,
-                                  debug=debug, _errors=_errors)
+                                  debug=debug, _errors=_errors, model_args=model_args)
             logger.info(f"Eval: {eval_only_split}_{list(eval_metric_to_idx)[0]}={eval_scores[0]}, " +
                         f"{eval_only_split}_{list(eval_metric_to_idx)[1]}={eval_scores[1]}")
             # Log eval metrics
@@ -380,7 +291,6 @@
             if len(eval_scores) == 3:
                 log_cc_objective_values(scores=eval_scores, split_name=eval_only_split, log_prefix='Eval',
                                         verbose=verbose, logger=logger)
-
             # For pairwise-mode:
             if pairwise_clustering_fns[0] is not None:
                 clustering_threshold = None
@@ -390,7 +300,7 @@
                                                              clustering_threshold=clustering_threshold,
                                                              val_dataloader=val_dataloader_e2e,
                                                              tqdm_label='test clustering', device=device, verbose=verbose,
-                                                             debug=debug, _errors=_errors)
+                                                             debug=debug, _errors=_errors, model_args=model_args)
                     if pairwise_clustering_fn.__class__ is HACInference:
                         clustering_threshold = pairwise_clustering_fn.cut_threshold
                     logger.info(
@@ -430,15 +340,14 @@

         if not skip_initial_eval:
             # Get initial model performance on dev (or 'train' for overfitting runs)
-            _state_dict_path = copy_and_load_model(model, run.dir, device, store_only=True)
-            _proc = Process(target=init_eval,
-                            kwargs=dict(model_class=model.__class__, model_args=model_args,
-                                        state_dict_path=_state_dict_path,
-                                        overfit_batch_idx=overfit_batch_idx, eval_fn=eval_fn,
-                                        train_dataloader=train_dataloader, device=device, verbose=verbose,
-                                        debug=debug, _errors=_errors, eval_metric_to_idx=eval_metric_to_idx,
-                                        val_dataloader=val_dataloader, return_dict=_return_dict))
-            _proc.start()
+            _proc = fork_eval(target=init_eval, args=dict(model_args=model_args,
+                                                          overfit_batch_idx=overfit_batch_idx, eval_fn=eval_fn,
+                                                          train_dataloader=train_dataloader, device=device,
+                                                          verbose=verbose,
+                                                          debug=debug, _errors=_errors,
+                                                          eval_metric_to_idx=eval_metric_to_idx,
+                                                          val_dataloader=val_dataloader, return_dict=_return_dict),
+                              model=model, run_dir=run.dir, device=device, logger=logger)

         if not pairwise_mode and grad_acc > 1:
             grad_acc_steps = []
             _seen_pw = 0
@@ -473,9 +382,9 @@
             grad_acc_idx = 0
             optimizer.zero_grad()

-            for (idx, batch) in enumerate(tqdm(_train_dataloader,
-                                               desc=f"{'Warm-starting' if warmstart_mode else 'Training'} {i + 1}",
-                                               position=1)):
+            pbar = tqdm(_train_dataloader, desc=f"{'Warm-starting' if warmstart_mode else 'Training'} {i + 1}",
+                        position=1)
+            for (idx, batch) in enumerate(pbar):
                 best_epoch, best_dev_score, best_dev_scores, best_dev_state_dict = _check_process(_proc,
                                                                                                   _return_dict,
                                                                                                   logger, run,
@@ -506,6 +415,8 @@
                         # Block contains only one signature pair; batchnorm throws error
                         continue
                     block_size = get_matrix_size_from_triu(data)
+                pbar.set_description(f"{'Warm-starting' if warmstart_mode else 'Training'} {i + 1} " + \
+                                     f"(sz={len(data) if (pairwise_mode or warmstart_mode) else block_size})")
                 target = target.flatten().float()
                 if verbose:
                     logger.info(f"Batch shape: {data.shape}")
@@ -611,29 +522,29 @@
                 wandb.log({f'train_epoch_loss': np.mean(running_loss)})

                 # Get model performance on dev (or 'train' for overfitting runs)
-                _state_dict_path = copy_and_load_model(model, run.dir, device, store_only=True)
-                _proc = Process(target=dev_eval,
-                                kwargs=dict(model_class=model.__class__, model_args=model_args,
-                                            state_dict_path=_state_dict_path, overfit_batch_idx=overfit_batch_idx,
-                                            eval_fn=eval_fn, train_dataloader=train_dataloader, device=device,
+                _proc = fork_eval(target=dev_eval,
+                                  args=dict(model_args=model_args,
+                                            overfit_batch_idx=overfit_batch_idx, eval_fn=eval_fn,
+                                            train_dataloader=train_dataloader, device=device,
                                             verbose=verbose, debug=debug, _errors=_errors,
                                             eval_metric_to_idx=eval_metric_to_idx, val_dataloader=val_dataloader,
-                                            return_dict=_return_dict, i=i, run_dir=run.dir))
-                _proc.start()
+                                            return_dict=_return_dict, i=i),
+                                  model=model, run_dir=run.dir, device=device, logger=logger,
+                                  sync=(idx == len(_train_dataloader.dataset) - 1))
         end_time = time.time()

         best_epoch, best_dev_score, best_dev_scores, best_dev_state_dict = _check_process(_proc, _return_dict,
-                                                                                              logger, run,
-                                                                                              overfit_batch_idx,
-                                                                                              use_lr_scheduler,
-                                                                                              hyp, scheduler,
-                                                                                              eval_metric_to_idx,
-                                                                                              dev_opt_metric, i,
-                                                                                              best_epoch,
-                                                                                              best_dev_score,
-                                                                                              best_dev_scores,
-                                                                                              best_dev_state_dict,
-                                                                                              sync=True)
+                                                                                          logger, run,
+                                                                                          overfit_batch_idx,
+                                                                                          use_lr_scheduler,
+                                                                                          hyp, scheduler,
+                                                                                          eval_metric_to_idx,
+                                                                                          dev_opt_metric, i,
+                                                                                          best_epoch,
+                                                                                          best_dev_score,
+                                                                                          best_dev_scores,
+                                                                                          best_dev_state_dict,
+                                                                                          sync=True)
         # Save model
         if save_model:
             torch.save(best_dev_state_dict, os.path.join(run.dir, 'model_state_dict_best.pt'))
@@ -646,7 +557,7 @@
             with torch.no_grad():
                 model.eval()
                 test_scores = eval_fn(model, test_dataloader, tqdm_label='test', device=device, verbose=verbose,
-                                      debug=debug, _errors=_errors, tqdm_position=2)
+                                      debug=debug, _errors=_errors, tqdm_position=2, model_args=model_args)
                 logger.info(f"Final: test_{list(eval_metric_to_idx)[0]}={test_scores[0]}, " +
                             f"test_{list(eval_metric_to_idx)[1]}={test_scores[1]}")
                 # Log final metrics
@@ -667,7 +578,8 @@
                                                                  clustering_threshold=clustering_threshold,
                                                                  val_dataloader=val_dataloader_e2e,
                                                                  tqdm_label='test clustering', device=device, verbose=verbose,
-                                                                 debug=debug, _errors=_errors, tqdm_position=2)
+                                                                 debug=debug, _errors=_errors, tqdm_position=2,
+                                                                 model_args=model_args)
                     if pairwise_clustering_fn.__class__ is HACInference:
                         clustering_threshold = pairwise_clustering_fn.cut_threshold
                     logger.info(f"Final: test_{list(clustering_metrics)[0]}_{pairwise_clustering_fn_labels[i]}={clustering_scores[0]}, " +
diff --git a/e2e_scripts/train_utils.py b/e2e_scripts/train_utils.py
index c86c498..0fe40c9 100644
--- a/e2e_scripts/train_utils.py
+++ b/e2e_scripts/train_utils.py
@@ -9,26 +9,27 @@
 from typing import Tuple, Optional
 import math
 import pickle
+import torch
+import numpy as np
+import wandb
 from time import time
+from sklearn.preprocessing import StandardScaler
 from torch.utils.data import DataLoader
 from s2and.consts import PREPROCESSED_DATA_DIR
 from s2and.data import S2BlocksDataset
 from s2and.eval import b3_precision_recall_fscore
 from torch import Tensor
-import torch
-import numpy as np
-import wandb
+from torch.multiprocessing import Process

 from IPython import embed

-
 # Default hyperparameters
 DEFAULT_HYPERPARAMS = {
     # Dataset
     "dataset": "pubmed",
     "dataset_random_seed": 1,
-    "subsample_sz_train": 80,
-    "subsample_sz_dev": 100,
+    "subsample_sz_train": 60,
+    "subsample_sz_dev": -1,
     # Run config
     "run_random_seed": 17,
     "pairwise_mode": False,
@@ -56,11 +57,11 @@
     "sdp_eps": 1e-3,
     "sdp_scale": True,
     # Training config
-    "batch_size": 10000,  # pairwise only; used by e2e if gradient_accumulation is true
-    "lr": 4e-3,
+    "batch_size": 8000,  # pairwise only; used by e2e if gradient_accumulation is true
+    "lr": 1e-3,
     "n_epochs": 5,
     "n_warmstart_epochs": 0,
-    "weighted_loss": False,
+    "weighted_loss": True,
     "use_lr_scheduler": True,
     "lr_scheduler": "plateau",  # "plateau", "step"
     "lr_factor": 0.4,
@@ -69,7 +70,7 @@
     "lr_step_size": 2,
     "lr_gamma": 0.4,
     "weight_decay": 0.01,
-    "gradient_accumulation": False,  # e2e only; accumulate over pairwise examples
+    "gradient_accumulation": True,  # e2e only; accumulate over pairwise examples
     "dev_opt_metric": 'b3_f1',  # e2e: {'b3_f1', 'vmeasure'}; pairwise: {'auroc', 'f1'}
     "overfit_batch_idx": -1
 }
@@ -83,25 +84,42 @@ def read_blockwise_features(pkl):

 def get_dataloaders(dataset, dataset_seed, convert_nan, nan_value, normalize, subsample_sz_train, subsample_sz_dev,
-                    pairwise_mode, batch_size):
-    train_pkl = f"{PREPROCESSED_DATA_DIR}/{dataset}/seed{dataset_seed}/train_features.pkl"
-    val_pkl = f"{PREPROCESSED_DATA_DIR}/{dataset}/seed{dataset_seed}/val_features.pkl"
-    test_pkl = f"{PREPROCESSED_DATA_DIR}/{dataset}/seed{dataset_seed}/test_features.pkl"
+                    pairwise_mode, batch_size, shuffle=False, split=None):
+    pickle_path = {
+        'train': f"{PREPROCESSED_DATA_DIR}/{dataset}/seed{dataset_seed}/train_features.pkl",
+        'dev': f"{PREPROCESSED_DATA_DIR}/{dataset}/seed{dataset_seed}/val_features.pkl",
+        'test': f"{PREPROCESSED_DATA_DIR}/{dataset}/seed{dataset_seed}/test_features.pkl"
+    }
+    subsample_sz = {
+        'train': subsample_sz_train,
+        'dev': subsample_sz_dev,
+        'test': -1
+    }
+    train_scaler = StandardScaler()
+    train_X = np.concatenate(list(map(lambda x: x[0], read_blockwise_features(pickle_path['train']).values())))
+    train_scaler.fit(train_X)

-    train_dataset = S2BlocksDataset(read_blockwise_features(train_pkl), convert_nan=convert_nan, nan_value=nan_value,
-                                    scale=normalize, subsample_sz=subsample_sz_train, pairwise_mode=pairwise_mode)
-    train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=batch_size)
+    def _get_dataloader(_split):
+        dataset = S2BlocksDataset(read_blockwise_features(pickle_path[_split]), convert_nan=convert_nan,
+                                  nan_value=nan_value, scale=normalize, scaler=train_scaler,
+                                  subsample_sz=subsample_sz[_split],
+                                  pairwise_mode=pairwise_mode, sort_desc=(_split in ['dev', 'test']))
+        dataloader = DataLoader(dataset, shuffle=shuffle, batch_size=batch_size)
+        return dataloader

-    val_dataset = S2BlocksDataset(read_blockwise_features(val_pkl), convert_nan=convert_nan, nan_value=nan_value,
-                                  scale=normalize, scaler=train_dataset.scaler, subsample_sz=subsample_sz_dev,
-                                  pairwise_mode=pairwise_mode)
-    val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)
+    if split is None:
+        return _get_dataloader('train'), _get_dataloader('dev'), _get_dataloader('test')
+    if type(split) is str:
+        return _get_dataloader(split)
+    if type(split) is list:
+        return tuple([_get_dataloader(_split) for _split in split])
+    raise ValueError('Invalid argument to split')

-    test_dataset = S2BlocksDataset(read_blockwise_features(test_pkl), convert_nan=convert_nan, nan_value=nan_value,
-                                   scale=normalize, scaler=train_dataset.scaler, pairwise_mode=pairwise_mode)
-    test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

-    return train_dataloader, val_dataloader, test_dataloader
+def get_feature_count(dataset, dataset_seed):
+    data_fpath = f"{PREPROCESSED_DATA_DIR}/{dataset}/seed{dataset_seed}/test_features.pkl"
+    block_dict = read_blockwise_features(data_fpath)
+    return next(iter(block_dict.values()))[0].shape[1]

 def uncompress_target_tensor(compressed_targets, make_symmetric=True, device=None):
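get_dataloaders now fits the StandardScaler on the train split once, hands the same scaler to every split, and lets callers request only the splits they need. A sketch of the call patterns this enables (argument values illustrative, not the repo's defaults):

    # All three splits (previous behavior, still the default)
    train_dl, dev_dl, test_dl = get_dataloaders('pubmed', 1, True, -1, True, 60, -1,
                                                pairwise_mode=False, batch_size=1)
    # A single split, e.g. for --eval_only_split
    test_dl = get_dataloaders('pubmed', 1, True, -1, True, 60, -1,
                              pairwise_mode=False, batch_size=1, split='test')
    # A subset, e.g. the e2e dev/test loaders used for clustering eval
    dev_dl, test_dl = get_dataloaders('pubmed', 1, True, -1, True, 60, -1,
                                      pairwise_mode=False, batch_size=1,
                                      split=['dev', 'test'])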
f"{PREPROCESSED_DATA_DIR}/{dataset}/seed{dataset_seed}/val_features.pkl", + 'test': f"{PREPROCESSED_DATA_DIR}/{dataset}/seed{dataset_seed}/test_features.pkl" + } + subsample_sz = { + 'train': subsample_sz_train, + 'dev': subsample_sz_dev, + 'test': -1 + } + train_scaler = StandardScaler() + train_X = np.concatenate(list(map(lambda x: x[0], read_blockwise_features(pickle_path['train']).values()))) + train_scaler.fit(train_X) - train_dataset = S2BlocksDataset(read_blockwise_features(train_pkl), convert_nan=convert_nan, nan_value=nan_value, - scale=normalize, subsample_sz=subsample_sz_train, pairwise_mode=pairwise_mode) - train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=batch_size) + def _get_dataloader(_split): + dataset = S2BlocksDataset(read_blockwise_features(pickle_path[_split]), convert_nan=convert_nan, + nan_value=nan_value, scale=normalize, scaler=train_scaler, + subsample_sz=subsample_sz[_split], + pairwise_mode=pairwise_mode, sort_desc=(_split in ['dev', 'test'])) + dataloader = DataLoader(dataset, shuffle=shuffle, batch_size=batch_size) + return dataloader - val_dataset = S2BlocksDataset(read_blockwise_features(val_pkl), convert_nan=convert_nan, nan_value=nan_value, - scale=normalize, scaler=train_dataset.scaler, subsample_sz=subsample_sz_dev, - pairwise_mode=pairwise_mode) - val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size) + if split is None: + return _get_dataloader('train'), _get_dataloader('dev'), _get_dataloader('test') + if type(split) is str: + return _get_dataloader(split) + if type(split) is list: + return tuple([_get_dataloader(_split) for _split in split]) + raise ValueError('Invalid argument to split') - test_dataset = S2BlocksDataset(read_blockwise_features(test_pkl), convert_nan=convert_nan, nan_value=nan_value, - scale=normalize, scaler=train_dataset.scaler, pairwise_mode=pairwise_mode) - test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size) - return train_dataloader, val_dataloader, test_dataloader +def get_feature_count(dataset, dataset_seed): + data_fpath = f"{PREPROCESSED_DATA_DIR}/{dataset}/seed{dataset_seed}/test_features.pkl" + block_dict = read_blockwise_features(data_fpath) + return next(iter(block_dict.values()))[0].shape[1] def uncompress_target_tensor(compressed_targets, make_symmetric=True, device=None): @@ -205,3 +223,118 @@ def copy_and_load_model(model, run_dir, device, store_only=False): _model.load_state_dict(_STATE_DICT) os.remove(_PATH) return _model + + +def _check_process(_proc, _return_dict, logger, run, overfit_batch_idx, use_lr_scheduler, hyp, + scheduler, eval_metric_to_idx, dev_opt_metric, i, best_epoch, best_dev_score, + best_dev_scores, best_dev_state_dict, sync=False): + if _proc is not None: + if _return_dict['_state'] == 'done' or (sync and _return_dict['_state'] != 'finish'): + _proc.join() + _return_dict['_state'] = 'finish' + if _return_dict['_method'] == 'init_eval': + logger.info(_return_dict['local']) + run.log(_return_dict['wandb']) + if overfit_batch_idx == -1: + best_dev_scores = _return_dict['dev_scores'] + best_dev_score = best_dev_scores[eval_metric_to_idx[dev_opt_metric]] + elif _return_dict['_method'] == 'dev_eval': + logger.info(_return_dict['local']) + run.log(_return_dict['wandb']) + if overfit_batch_idx > -1: + if use_lr_scheduler: + if hyp['lr_scheduler'] == 'plateau': + scheduler.step(_return_dict['train_scores'][eval_metric_to_idx[dev_opt_metric]]) + elif hyp['lr_scheduler'] == 'step': + scheduler.step() + else: + dev_scores = 
+
+
+def init_eval(model_class, model_args, state_dict_path, overfit_batch_idx, eval_fn, train_dataloader, device, verbose,
+              debug, _errors, eval_metric_to_idx, val_dataloader, return_dict):
+    return_dict['_state'] = 'start'
+    return_dict['_method'] = 'init_eval'
+    model = model_class(*model_args)
+    model.load_state_dict(torch.load(state_dict_path))
+    model.to(device)
+    with torch.no_grad():
+        model.eval()
+        if overfit_batch_idx > -1:
+            train_scores = eval_fn(model, train_dataloader, overfit_batch_idx=overfit_batch_idx,
+                                   tqdm_label='train', device=device, verbose=verbose, debug=debug,
+                                   _errors=_errors, tqdm_position=0, model_args=model_args)
+            return_dict['local'] = f"Initial: train_{list(eval_metric_to_idx)[0]}={train_scores[0]}, " + \
+                                   f"train_{list(eval_metric_to_idx)[1]}={train_scores[1]}"
+            return_dict['wandb'] = {'epoch': 0, f'train_{list(eval_metric_to_idx)[0]}': train_scores[0],
+                                    f'train_{list(eval_metric_to_idx)[1]}': train_scores[1]}
+        else:
+            dev_scores = eval_fn(model, val_dataloader, tqdm_label='dev 0', device=device, verbose=verbose,
+                                 debug=debug, _errors=_errors, tqdm_position=0, model_args=model_args)
+            return_dict['local'] = f"Initial: dev_{list(eval_metric_to_idx)[0]}={dev_scores[0]}, " + \
+                                   f"dev_{list(eval_metric_to_idx)[1]}={dev_scores[1]}"
+            return_dict['wandb'] = {'epoch': 0, f'dev_{list(eval_metric_to_idx)[0]}': dev_scores[0],
+                                    f'dev_{list(eval_metric_to_idx)[1]}': dev_scores[1]}
+            return_dict['dev_scores'] = dev_scores
+    del model
+    return_dict['_state'] = 'done'
+    return return_dict
+
+
+def dev_eval(model_class, model_args, state_dict_path, overfit_batch_idx, eval_fn, train_dataloader, device, verbose,
+             debug, _errors, eval_metric_to_idx, val_dataloader, return_dict, i):
+    return_dict['_state'] = 'start'
+    return_dict['_method'] = 'dev_eval'
+    return_dict['state_dict_path'] = state_dict_path
+    model = model_class(*model_args)
+    model.load_state_dict(torch.load(state_dict_path))
+    model.to(device)
+    with torch.no_grad():
+        model.eval()
+        if overfit_batch_idx > -1:
+            train_scores = eval_fn(model, train_dataloader, overfit_batch_idx=overfit_batch_idx,
+                                   tqdm_label='train', device=device, verbose=verbose, debug=debug,
+                                   _errors=_errors, model_args=model_args)
+            return_dict['local'] = f"Epoch {i + 1}: train_{list(eval_metric_to_idx)[0]}={train_scores[0]}, " + \
+                                   f"train_{list(eval_metric_to_idx)[1]}={train_scores[1]}"
+            return_dict['wandb'] = {f'train_{list(eval_metric_to_idx)[0]}': train_scores[0],
+                                    f'train_{list(eval_metric_to_idx)[1]}': train_scores[1]}
+            return_dict['train_scores'] = train_scores
+        else:
+            dev_scores = eval_fn(model, val_dataloader, tqdm_label=f'dev {i + 1}', device=device, verbose=verbose,
+                                 debug=debug, _errors=_errors, model_args=model_args)
+            return_dict['local'] = f"Epoch {i + 1}: dev_{list(eval_metric_to_idx)[0]}={dev_scores[0]}, " + \
+                                   f"dev_{list(eval_metric_to_idx)[1]}={dev_scores[1]}"
+            return_dict['wandb'] = {f'dev_{list(eval_metric_to_idx)[0]}': dev_scores[0],
+                                    f'dev_{list(eval_metric_to_idx)[1]}': dev_scores[1]}
+            return_dict['dev_scores'] = dev_scores
+    del model
+    return_dict['_state'] = 'done'
+    return return_dict
+
+
+def fork_eval(target, args, model, run_dir, device, logger, sync=False):
+    state_dict_path = copy_and_load_model(model, run_dir, device, store_only=True)
+    args['model_class'] = model.__class__
+    args['state_dict_path'] = state_dict_path
+    if sync:
+        target(**args)
+        proc = Process()
+    else:
+        proc = Process(target=target, kwargs=args)
+        logger.info('Forking eval')
+    proc.start()
+    return proc
diff --git a/s2and/data.py b/s2and/data.py
index 9d75eb1..745251b 100644
--- a/s2and/data.py
+++ b/s2and/data.py
@@ -128,7 +128,7 @@ class S2BlocksDataset(Dataset):
     """
     def __init__(self, block_dict: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
                  convert_nan=True, nan_value=-1, scale=False, scaler=None, subsample_sz=-1,
-                 pairwise_mode=False):
+                 pairwise_mode=False, sort_desc=False):
         self.pairwise_mode = pairwise_mode
         self.block_dict = block_dict
         self.convert_nan = convert_nan
@@ -171,6 +171,11 @@ def __init__(self, block_dict: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarra
             else:
                 self.blockwise_data.append((X, y, cluster_ids))
                 self.blockwise_keys.append(dict_key)
+        if sort_desc:
+            self.blockwise_keys = list(map(lambda x: x[1], sorted(enumerate(self.blockwise_keys),
+                                                                  key=lambda x: len(self.blockwise_data[x[0]][2]),
+                                                                  reverse=True)))
+            self.blockwise_data.sort(key=lambda x: -len(x[2]))
         if self.pairwise_mode:
             self.pairwise_data = {'X': [], 'y': []}
             self.cluster_ids = []
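sort_desc orders dev/test blocks from largest to smallest so the biggest (slowest, most fork-worthy) blocks are dispatched first, with the keys list re-ordered in lockstep with the data list. The same keys-follow-data sort in isolation (toy lists):

    # Toy stand-ins: three blocks with 3, 1, and 2 signatures respectively
    blockwise_keys = ['a', 'b', 'c']
    blockwise_data = [([], [], [0, 0, 1]), ([], [], [0]), ([], [], [0, 1])]
    order = sorted(enumerate(blockwise_keys),
                   key=lambda x: len(blockwise_data[x[0]][2]), reverse=True)
    blockwise_keys = [k for _, k in order]           # ['a', 'c', 'b']
    blockwise_data.sort(key=lambda x: -len(x[2]))    # sizes 3, 2, 1
    print(blockwise_keys, [len(d[2]) for d in blockwise_data])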
diff --git a/utils/parser.py b/utils/parser.py
index f16f61f..de21db8 100644
--- a/utils/parser.py
+++ b/utils/parser.py
@@ -38,6 +38,9 @@ def add_preprocessing_args(self):
         parser.add_argument(
             "--dataset_name", type=str, help="name of AND dataset that you want to preprocess"
         )
+        parser.add_argument(
+            "--dataset_seed", type=int
+        )

     def add_training_args(self):
         """
@@ -94,7 +97,7 @@
             help="Whether to prevent wandb sweep early terminate or not",
         )
         parser.add_argument(
-            "--wandb_max_runs", type=int, default=600,
+            "--wandb_max_runs", type=int, default=120,
             help="Maximum number of runs to try in the sweep",
         )
         parser.add_argument(
diff --git a/wandb_configs/sweeps/e2e-nosdp-warm.json b/wandb_configs/sweeps/e2e-nosdp-warm.json
index 3846f0d..294b543 100644
--- a/wandb_configs/sweeps/e2e-nosdp-warm.json
+++ b/wandb_configs/sweeps/e2e-nosdp-warm.json
@@ -8,9 +8,9 @@
     "n_hidden_layers": {"values": [1, 2]},
     "dropout_p": {"values": [0, 0.1, 0.2, 0.3, 0.4, 0.5]},
     "lr_scheduler": {"value": "plateau"},
-    "subsample_sz_train": {"value": 80},
-    "subsample_sz_dev": {"value": 100},
     "activation": {"values": ["leaky_relu", "relu"]},
     "use_sdp": {"value": false},
-    "n_warmstart_epochs": {"value": 2}
+    "n_warmstart_epochs": {"value": 2},
+    "gradient_accumulation": {"values": [true, false]},
+    "weighted_loss": {"values": [true, false]}
 }
diff --git a/wandb_configs/sweeps/e2e-nosdp.json b/wandb_configs/sweeps/e2e-nosdp.json
index 4e02afe..5b47c39 100644
--- a/wandb_configs/sweeps/e2e-nosdp.json
+++ b/wandb_configs/sweeps/e2e-nosdp.json
@@ -8,8 +8,8 @@
     "n_hidden_layers": {"values": [1, 2]},
     "dropout_p": {"values": [0, 0.1, 0.2, 0.3, 0.4, 0.5]},
     "lr_scheduler": {"value": "plateau"},
-    "subsample_sz_train": {"value": 80},
-    "subsample_sz_dev": {"value": 100},
     "activation": {"values": ["leaky_relu", "relu"]},
-    "use_sdp": {"value": false}
+    "use_sdp": {"value": false},
+    "gradient_accumulation": {"values": [true, false]},
+    "weighted_loss": {"values": [true, false]}
 }
diff --git a/wandb_configs/sweeps/e2e-warm.json b/wandb_configs/sweeps/e2e-warm.json
index 77de43c..19e511b 100644
--- a/wandb_configs/sweeps/e2e-warm.json
+++ b/wandb_configs/sweeps/e2e-warm.json
@@ -8,8 +8,8 @@
     "n_hidden_layers": {"values": [1, 2]},
     "dropout_p": {"values": [0, 0.1, 0.2, 0.3, 0.4, 0.5]},
     "lr_scheduler": {"value": "plateau"},
-    "subsample_sz_train": {"value": 80},
-    "subsample_sz_dev": {"value": 100},
     "activation": {"values": ["leaky_relu", "relu"]},
-    "n_warmstart_epochs": {"value": 2}
+    "n_warmstart_epochs": {"value": 2},
+    "gradient_accumulation": {"values": [true, false]},
+    "weighted_loss": {"values": [true, false]}
 }
diff --git a/wandb_configs/sweeps/e2e.json b/wandb_configs/sweeps/e2e.json
index 20991ba..e084f00 100644
--- a/wandb_configs/sweeps/e2e.json
+++ b/wandb_configs/sweeps/e2e.json
@@ -8,7 +8,7 @@
     "n_hidden_layers": {"values": [1, 2]},
     "dropout_p": {"values": [0, 0.1, 0.2, 0.3, 0.4, 0.5]},
     "lr_scheduler": {"value": "plateau"},
-    "subsample_sz_train": {"value": 80},
-    "subsample_sz_dev": {"value": 100},
-    "activation": {"values": ["leaky_relu", "relu"]}
+    "activation": {"values": ["leaky_relu", "relu"]},
+    "gradient_accumulation": {"values": [true, false]},
+    "weighted_loss": {"values": [true, false]}
 }
diff --git a/wandb_configs/sweeps/frac-nosdp-warm.json b/wandb_configs/sweeps/frac-nosdp-warm.json
index 75503ce..491e04c 100644
--- a/wandb_configs/sweeps/frac-nosdp-warm.json
+++ b/wandb_configs/sweeps/frac-nosdp-warm.json
@@ -8,10 +8,10 @@
     "n_hidden_layers": {"values": [1, 2]},
     "dropout_p": {"values": [0, 0.1, 0.2, 0.3, 0.4, 0.5]},
     "lr_scheduler": {"value": "plateau"},
-    "subsample_sz_train": {"value": 80},
-    "subsample_sz_dev": {"value": 100},
     "activation": {"values": ["leaky_relu", "relu"]},
     "use_rounded_loss": {"value": false},
     "use_sdp": {"value": false},
-    "n_warmstart_epochs": {"value": 2}
+    "n_warmstart_epochs": {"value": 2},
+    "gradient_accumulation": {"values": [true, false]},
+    "weighted_loss": {"values": [true, false]}
 }
diff --git a/wandb_configs/sweeps/frac-nosdp.json b/wandb_configs/sweeps/frac-nosdp.json
index f27ee08..d9a1e41 100644
--- a/wandb_configs/sweeps/frac-nosdp.json
+++ b/wandb_configs/sweeps/frac-nosdp.json
@@ -8,9 +8,9 @@
     "n_hidden_layers": {"values": [1, 2]},
     "dropout_p": {"values": [0, 0.1, 0.2, 0.3, 0.4, 0.5]},
     "lr_scheduler": {"value": "plateau"},
-    "subsample_sz_train": {"value": 80},
-    "subsample_sz_dev": {"value": 100},
     "activation": {"values": ["leaky_relu", "relu"]},
     "use_rounded_loss": {"value": false},
-    "use_sdp": {"value": false}
+    "use_sdp": {"value": false},
+    "gradient_accumulation": {"values": [true, false]},
+    "weighted_loss": {"values": [true, false]}
 }
diff --git a/wandb_configs/sweeps/frac-warm.json b/wandb_configs/sweeps/frac-warm.json
index fa4b935..b13efc5 100644
--- a/wandb_configs/sweeps/frac-warm.json
+++ b/wandb_configs/sweeps/frac-warm.json
@@ -8,9 +8,9 @@
     "n_hidden_layers": {"values": [1, 2]},
     "dropout_p": {"values": [0, 0.1, 0.2, 0.3, 0.4, 0.5]},
     "lr_scheduler": {"value": "plateau"},
-    "subsample_sz_train": {"value": 80},
-    "subsample_sz_dev": {"value": 100},
     "activation": {"values": ["leaky_relu", "relu"]},
     "use_rounded_loss": {"value": false},
-    "n_warmstart_epochs": {"value": 2}
+    "n_warmstart_epochs": {"value": 2},
+    "gradient_accumulation": {"values": [true, false]},
+    "weighted_loss": {"values": [true, false]}
 }
diff --git a/wandb_configs/sweeps/frac.json b/wandb_configs/sweeps/frac.json
index 7eb6812..a572b76 100644
--- a/wandb_configs/sweeps/frac.json
+++ b/wandb_configs/sweeps/frac.json
@@ -8,8 +8,8 @@
     "n_hidden_layers": {"values": [1, 2]},
     "dropout_p": {"values": [0, 0.1, 0.2, 0.3, 0.4, 0.5]},
     "lr_scheduler": {"value": "plateau"},
-    "subsample_sz_train": {"value": 80},
-    "subsample_sz_dev": {"value": 100},
     "activation": {"values": ["leaky_relu", "relu"]},
-    "use_rounded_loss": {"value": false}
+    "use_rounded_loss": {"value": false},
+    "gradient_accumulation": {"values": [true, false]},
+    "weighted_loss": {"values": [true, false]}
 }
diff --git a/wandb_configs/sweeps/mlp.json b/wandb_configs/sweeps/mlp.json
index a5f49fc..24274c7 100644
--- a/wandb_configs/sweeps/mlp.json
+++ b/wandb_configs/sweeps/mlp.json
@@ -10,5 +10,6 @@
     "dropout_p": {"values": [0, 0.1, 0.2, 0.3, 0.4, 0.5]},
     "lr_scheduler": {"value": "plateau"},
     "activation": {"values": ["leaky_relu", "relu"]},
+    "gradient_accumulation": {"value": false},
     "weighted_loss": {"value": true}
 }
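These JSON files hold only the hyperparameter grid: subsample sizes now come from the updated defaults in train_utils.py, while gradient_accumulation and weighted_loss are swept per configuration. Assuming the repo's sweep runner wraps them into a full wandb sweep spec (the wrapper fields below are an assumption, not part of these files), registering one would look roughly like:

    import json
    import wandb

    with open('wandb_configs/sweeps/e2e.json') as fh:
        parameters = json.load(fh)
    sweep_config = {
        'method': 'bayes',                                   # assumed wrapper field
        'metric': {'name': 'dev_b3_f1', 'goal': 'maximize'}, # assumed wrapper field
        'parameters': parameters,
    }
    sweep_id = wandb.sweep(sweep_config, project='e2e-AND')  # project name illustrative
    # wandb.agent(sweep_id, function=train, count=120)       # matches --wandb_max_runs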