Inference solver, parallel eval iterations, sweep config changes (#41)
dhdhagar authored Mar 18, 2023
1 parent 548f151 commit c1e13ae
Showing 18 changed files with 487 additions and 309 deletions.
41 changes: 24 additions & 17 deletions e2e_debug/solve.py
@@ -50,6 +50,9 @@ def __init__(self):
self.add_argument(
"--scs_log_csv_filename", type=str,
)
self.add_argument(
"--max_scaling", action="store_true",
)
self.add_argument(
"--interactive", action="store_true",
)
@@ -63,15 +66,18 @@ def __init__(self):

# Read error file
logger.info("Reading input data")
with open(args.data_fpath, 'r') as fh:
data = json.load(fh)
assert len(data['errors']) > 0
# Pick specific error instance to process
error_data = data['errors'][args.data_idx]
if args.data_fpath.endswith('.pt'):
_W_val = torch.load(args.data_fpath, map_location='cpu').numpy()
else:
with open(args.data_fpath, 'r') as fh:
data = json.load(fh)
assert len(data['errors']) > 0
# Pick specific error instance to process
error_data = data['errors'][args.data_idx]

# Extract input data from the error instance
_raw = np.array(error_data['model_call_args']['data'])
_W_val = np.array(error_data['cvxpy_layer_args']['W_val'])
# Extract input data from the error instance
_raw = np.array(error_data['model_call_args']['data'])
_W_val = np.array(error_data['cvxpy_layer_args']['W_val'])

# Construct cvxpy problem
logger.info('Constructing optimization problem')
@@ -84,7 +90,7 @@ def __init__(self):
constraints = [
cp.diag(X) == np.ones((n,)),
X[:n, :] >= 0,
X[:n, :] <= 1
# X[:n, :] <= 1
]
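(Note on the upper bound commented out above: for the SDP relaxation it is redundant, presumably why it was dropped. Any PSD matrix with unit diagonal has entries bounded by 1 via Cauchy–Schwarz:

$$X \succeq 0,\quad X_{uu} = X_{vv} = 1 \;\Longrightarrow\; |X_{uv}| \le \sqrt{X_{uu}\,X_{vv}} = 1,$$

so together with `X[:n, :] >= 0` the entries already stay in [0, 1].)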

# Setup HAC Cut
@@ -94,12 +100,14 @@ def __init__(self):
sdp_obj_value = float('inf')
result_idxs, results_X, results_clustering = [], [], []
no_solution_scaling_factors = []
for i in range(1, 10): # n
for i in range(0, 10): # n
# Skipping 1; no scaling leads to non-convergence (infinite objective value)
if i == 1:
scaling_factor = np.max(W)
if i == 0:
scaling_factor = np.max(np.abs(W))
else:
scaling_factor = i
if args.max_scaling:
continue
logger.info(f'Scaling factor={scaling_factor}')
# Create problem
W_scaled = W / scaling_factor
@@ -114,8 +122,7 @@ def __init__(self):
alpha=args.scs_alpha,
scale=args.scs_scale,
use_indirect=args.scs_use_indirect,
use_quad_obj=not args.scs_dont_use_quad_obj,
log_csv_filename=args.scs_log_csv_filename
# use_quad_obj=not args.scs_dont_use_quad_obj
)
logger.info(f"@scaling={scaling_factor}, objective value = {sdp_obj_value}, norm={np.linalg.norm(W_scaled)}")
if sdp_obj_value != float('inf'):
@@ -129,9 +136,9 @@ def __init__(self):
logger.info(f"Solution not found = {len(no_solution_scaling_factors)}")
logger.info(f"Solution found = {len(results_X)}")

logger.info("Same clustering:")
for i in range(len(results_clustering)-1):
logger.info(np.array_equal(results_clustering[i], results_clustering[i + 1]))
# logger.info("Same clustering:")
# for i in range(len(results_clustering)-1):
# logger.info(np.array_equal(results_clustering[i], results_clustering[i + 1]))
# logger.info(f"Solution found with scaling factor = {scaling_factor}")
# if args.interactive and sdp_obj_value == float('inf'):
# embed()
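Condensed, the updated sweep in solve.py now behaves roughly like the sketch below. `solve_sdp` is a hypothetical stand-in for the `cp.Problem` construction and SCS solve in the actual script; only the loop structure mirrors the diff.

```python
import numpy as np

def sweep_scaling_factors(W, solve_sdp, max_scaling_only=False):
    """Try several scalings of W and record which ones the SDP solver converges on.

    solve_sdp(W_scaled) is assumed to return the objective value
    (float('inf') when SCS fails to converge).
    """
    solved, failed = [], []
    for i in range(0, 10):
        if i == 0:
            # New in this commit: scale by the max absolute edge weight
            scaling_factor = np.max(np.abs(W))
        else:
            scaling_factor = i
            if max_scaling_only:
                # --max_scaling: only try the max-|W| scaling, skip 1..9
                continue
        obj = solve_sdp(W / scaling_factor)
        (failed if obj == float('inf') else solved).append(scaling_factor)
    return solved, failed
```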
41 changes: 26 additions & 15 deletions e2e_pipeline/sdp_layer.py
@@ -50,22 +50,33 @@ def build_and_solve_sdp(self, W_val, N, verbose=False):
X[:N, :] >= 0,
]

# create problem
prob = cp.Problem(cp.Maximize(cp.trace(W @ X)), constraints)
# Note: maximizing the trace is equivalent to maximizing the sum_E (w_uv * X_uv) objective
# because W is upper-triangular and X is symmetric

# Build the SDP cvxpylayer
cvxpy_layer = CvxpyLayer(prob, parameters=[W], variables=[X])

# Forward pass through the SDP cvxpylayer
try:
pw_prob_matrix = cvxpy_layer(W_val, solver_args={
"solve_method": "SCS",
"verbose": verbose,
"max_iters": self.max_iters,
"eps": self.eps
})[0]
if self.training:
# create problem
prob = cp.Problem(cp.Maximize(cp.trace(W @ X)), constraints)
# Note: maximizing the trace is equivalent to maximizing the sum_E (w_uv * X_uv) objective
# because W is upper-triangular and X is symmetric
# Build the SDP cvxpylayer
cvxpy_layer = CvxpyLayer(prob, parameters=[W], variables=[X])
# Forward pass through the SDP cvxpylayer
pw_prob_matrix = cvxpy_layer(W_val, solver_args={
"solve_method": "SCS",
"verbose": verbose,
"max_iters": self.max_iters,
"eps": self.eps
})[0]
else:
# create problem
prob = cp.Problem(cp.Maximize(cp.trace(W_val.cpu().numpy() @ X)), constraints)
_solve_val = prob.solve(
solver=cp.SCS,
verbose=verbose,
max_iters=self.max_iters,
eps=self.eps
)
if _solve_val == float('inf'):
raise ValueError()
pw_prob_matrix = torch.tensor(X.value, device=W_val.device)
# Fix to prevent invalid solution values close to 0 and 1 but outside the range
pw_prob_matrix = torch.clamp(pw_prob_matrix, min=0, max=1)
except:
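The gist of the new branching in sdp_layer.py, as a minimal sketch (not the verbatim method; `max_iters` and `eps` attributes as in the diff, error handling elided):

```python
import cvxpy as cp
import torch
from cvxpylayers.torch import CvxpyLayer

def build_and_solve(self, W, X, W_val, constraints, verbose=False):
    if self.training:
        # Differentiable path: keep the CvxpyLayer so gradients can
        # flow back through the SCS solve during training.
        prob = cp.Problem(cp.Maximize(cp.trace(W @ X)), constraints)
        layer = CvxpyLayer(prob, parameters=[W], variables=[X])
        pw_prob_matrix = layer(W_val, solver_args={
            "solve_method": "SCS", "verbose": verbose,
            "max_iters": self.max_iters, "eps": self.eps,
        })[0]
    else:
        # Inference path: a plain cvxpy solve with W_val baked in as a
        # constant, skipping the autograd machinery entirely.
        prob = cp.Problem(
            cp.Maximize(cp.trace(W_val.cpu().numpy() @ X)), constraints)
        if prob.solve(solver=cp.SCS, verbose=verbose,
                      max_iters=self.max_iters, eps=self.eps) == float('inf'):
            raise ValueError("SCS did not converge")
        pw_prob_matrix = torch.tensor(X.value, device=W_val.device)
    # Clamp tiny numerical violations back into [0, 1]
    return torch.clamp(pw_prob_matrix, min=0, max=1)
```

The design point: CvxpyLayer exists to make the solve differentiable, and that bookkeeping is wasted at inference time, where a direct solve is cheaper.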
2 changes: 1 addition & 1 deletion e2e_pipeline/uncompress_layer.py
@@ -6,7 +6,7 @@ def __init__(self):
super().__init__()

def forward(self, compressed_matrix, N, make_symmetric=False, ones_diagonal=False):
device = compressed_matrix.get_device()
device = compressed_matrix.device
triu_indices = torch.triu_indices(N, N, offset=1, device=device)
if make_symmetric:
sym_indices = torch.stack((torch.cat((triu_indices[0], triu_indices[1])),
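For context on the one-line uncompress_layer.py fix: on recent PyTorch versions `Tensor.get_device()` returns -1 for CPU tensors (older releases raised an error), and either way it is not a `torch.device`, while `.device` always is:

```python
import torch

t = torch.zeros(3)     # CPU tensor
print(t.get_device())  # -1 on CPU (an int ordinal, not a torch.device)
print(t.device)        # device(type='cpu')

# The fixed line works on CPU and GPU alike:
idx = torch.triu_indices(3, 3, offset=1, device=t.device)
```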
140 changes: 123 additions & 17 deletions e2e_scripts/evaluate.py
@@ -8,13 +8,14 @@
from sklearn.metrics.cluster import v_measure_score
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import precision_recall_fscore_support
from torch.multiprocessing import Process, Manager
import numpy as np
import torch

from e2e_pipeline.cc_inference import CCInference
from e2e_pipeline.hac_inference import HACInference
from e2e_pipeline.sdp_layer import CvxpyException
from e2e_scripts.train_utils import compute_b3_f1, save_to_wandb_run
from e2e_scripts.train_utils import compute_b3_f1, save_to_wandb_run, copy_and_load_model

from IPython import embed

@@ -24,13 +25,50 @@
logger = logging.getLogger(__name__)


def _run_iter(model_class, state_dict_path, _fork_id, _shared_list, eval_fn, **kwargs):
model = model_class(*kwargs['model_args'])
model.load_state_dict(torch.load(state_dict_path))
model.to('cpu')
model.eval()
with torch.no_grad():
res = eval_fn(model=model, **kwargs)
_shared_list.append(res)
del model


def _fork_iter(batch_idx, _fork_id, _shared_list, eval_fn, **kwargs):
kwargs['model_class'] = kwargs['model'].__class__
kwargs['state_dict_path'] = copy_and_load_model(kwargs['model'], kwargs['run_dir'], 'cpu', store_only=True)
del kwargs['model']
kwargs['overfit_batch_idx'] = batch_idx
kwargs['tqdm_label'] = f'{kwargs["tqdm_label"]} (fork{_fork_id})'
kwargs['_fork_id'] = _fork_id
kwargs['tqdm_position'] = (0 if kwargs['tqdm_position'] is None else kwargs['tqdm_position']) + _fork_id + 1
kwargs['return_iter'] = True
kwargs['fork_size'] = -1
kwargs['_shared_list'] = _shared_list
kwargs['disable_tqdm'] = True
kwargs['device'] = 'cpu'
kwargs['eval_fn'] = eval_fn
_proc = Process(target=_run_iter, kwargs=kwargs)
_proc.start()
return _proc
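
Below is a self-contained illustration of the fork-and-join pattern `_run_iter`/`_fork_iter` implement: snapshot the model to disk, let each child process rebuild it on CPU, and collect results through a `Manager` list. `TinyModel` and the /tmp path are placeholders, not repo code.

```python
import torch
import torch.nn as nn
from torch.multiprocessing import Manager, Process

class TinyModel(nn.Module):          # placeholder model, not from the repo
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
    def forward(self, x):
        return self.fc(x)

def run_eval(model_class, state_dict_path, x, shared_list):
    # Child process: rebuild the model from disk rather than inheriting
    # live (possibly CUDA) tensors from the parent, mirroring _run_iter.
    model = model_class()
    model.load_state_dict(torch.load(state_dict_path))
    model.eval()
    with torch.no_grad():
        shared_list.append(model(x).sum().item())

if __name__ == '__main__':
    model = TinyModel()
    torch.save(model.state_dict(), '/tmp/fork_demo.pt')
    shared = Manager().list()        # cross-process result buffer
    procs = [Process(target=run_eval,
                     args=(TinyModel, '/tmp/fork_demo.pt',
                           torch.randn(3, 4), shared))
             for _ in range(2)]
    for p in procs:
        p.start()                    # forks run while the parent keeps working
    for p in procs:
        p.join()                     # then wait and merge, as evaluate() does
    print(list(shared))
```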


def evaluate(model, dataloader, overfit_batch_idx=-1, clustering_fn=None, clustering_threshold=None,
val_dataloader=None, tqdm_label='', device=None, verbose=False, debug=False, _errors=None,
run_dir='./', tqdm_position=None):
run_dir='./', tqdm_position=None, model_args=None, return_iter=False, fork_size=500,
disable_tqdm=False):
"""
clustering_fn, clustering_threshold, val_dataloader: unused when pairwise_mode is False
(only added to keep fn signature identical)
"""
fn_args = locals()
fork_enabled = fork_size > -1 and model_args is not None
if fork_enabled:
_fork_id = 1
_shared_list = Manager().list()
_procs = []
device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_features = dataloader.dataset[0][0].shape[1]

@@ -43,19 +81,25 @@ def evaluate(model, dataloader, overfit_batch_idx=-1, clustering_fn=None, cluste
}
max_pred_id = -1
n_exceptions = 0
for (idx, batch) in enumerate(tqdm(dataloader, desc=f'Evaluating {tqdm_label}', position=tqdm_position)):
pbar = tqdm(dataloader, desc=f'Eval {tqdm_label}', position=tqdm_position, disable=disable_tqdm)
for (idx, batch) in enumerate(pbar):
if overfit_batch_idx > -1:
if idx < overfit_batch_idx:
continue
if idx > overfit_batch_idx:
break
data, _, cluster_ids = batch
block_size = len(cluster_ids)
all_gold += list(np.reshape(cluster_ids, (block_size,)))
pbar.set_description(f'Eval {tqdm_label} (sz={block_size})')
data = data.reshape(-1, n_features).float()
if data.shape[0] == 0:
# Only one signature in block; manually assign a unique cluster
pred_cluster_ids = [max_pred_id + 1]
elif fork_enabled and block_size >= fork_size:
_proc = _fork_iter(idx, _fork_id, _shared_list, evaluate, **fn_args)
_fork_id += 1
_procs.append((_proc, block_size))
continue
else:
# Forward pass through the e2e model
data = data.to(device)
@@ -79,8 +123,6 @@ def evaluate(model, dataloader, overfit_batch_idx=-1, clustering_fn=None, cluste
save_to_wandb_run({'errors': _errors}, 'errors.json', run_dir, logger)
if not debug: # if tqdm_label is not 'dev' and not debug:
raise CvxpyException(data=_error_obj)
# If split is dev, skip batch and continue
all_gold = all_gold[:-len(cluster_ids)]
n_exceptions += 1
logger.info(f'Caught CvxpyException {n_exceptions}: skipping batch')
continue
@@ -89,8 +131,34 @@ def evaluate(model, dataloader, overfit_batch_idx=-1, clustering_fn=None, cluste
cc_obj_vals['sdp'].append(model.sdp_layer.objective_value)
cc_obj_vals['block_idxs'].append(idx)
cc_obj_vals['block_sizes'].append(block_size)
all_gold += list(np.reshape(cluster_ids, (block_size,)))
max_pred_id = max(pred_cluster_ids)
all_pred += list(pred_cluster_ids)
if overfit_batch_idx > -1 and return_iter:
return {
'cluster_labels': model.hac_cut_layer.cluster_labels,
'round_objective_value': model.hac_cut_layer.objective_value,
'sdp_objective_value': model.sdp_layer.objective_value,
'block_idx': idx,
'block_size': block_size,
'cluster_ids': cluster_ids
}

if fork_enabled and len(_procs) > 0:
_procs.sort(key=lambda x: x[1]) # To visualize progress
for _proc in tqdm(_procs, desc=f'Eval {tqdm_label} (waiting for forks to join)', position=tqdm_position):
_proc[0].join()
assert len(_procs) == len(_shared_list), "All forked eval iterations did not return results"
for _data in _shared_list:
pred_cluster_ids = (_data['cluster_labels'] + (max_pred_id + 1)).tolist()
cc_obj_vals['round'].append(_data['round_objective_value'])
cc_obj_vals['sdp'].append(_data['sdp_objective_value'])
cc_obj_vals['block_idxs'].append(_data['block_idx'])
cc_obj_vals['block_sizes'].append(_data['block_size'])
all_gold += list(np.reshape(_data['cluster_ids'], (_data['block_size'],)))
max_pred_id = max(pred_cluster_ids)
all_pred += list(pred_cluster_ids)

vmeasure = v_measure_score(all_gold, all_pred)
b3_f1 = compute_b3_f1(all_gold, all_pred)[2]
return b3_f1, vmeasure, cc_obj_vals
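One subtlety in the merge loop above: every forked block's cluster labels start from 0, so each result is shifted by the running `max_pred_id + 1` before being appended, keeping predicted cluster IDs globally unique across blocks. Schematically (assuming NumPy-array labels, as the `.tolist()` call suggests):

```python
import numpy as np

max_pred_id = -1
all_pred = []
for labels in [np.array([0, 0, 1]), np.array([0, 1])]:  # two forked blocks
    pred = (labels + (max_pred_id + 1)).tolist()
    max_pred_id = max(pred)      # advance the offset for the next block
    all_pred += pred
print(all_pred)                  # [0, 0, 1, 2, 3] -- no ID collisions
```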
@@ -99,7 +167,13 @@ def evaluate(model, dataloader, overfit_batch_idx=-1, clustering_fn=None, cluste
def evaluate_pairwise(model, dataloader, overfit_batch_idx=-1, mode="macro", return_pred_only=False,
thresh_for_f1=0.5, clustering_fn=None, clustering_threshold=None, val_dataloader=None,
tqdm_label='', device=None, verbose=False, debug=False, _errors=None, run_dir='./',
tqdm_position=None):
tqdm_position=None, model_args=None, return_iter=False, fork_size=500, disable_tqdm=False):
fn_args = locals()
fork_enabled = fork_size > -1 and model_args is not None
if fork_enabled:
_fork_id = 1
_shared_list = Manager().list()
_procs = []
device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_features = dataloader.dataset[0][0].shape[1]

@@ -117,19 +191,25 @@ def evaluate_pairwise(model, dataloader, overfit_batch_idx=-1, mode="macro", ret
}
max_pred_id = -1 # In each iteration, add to all blockwise predicted IDs to distinguish from previous blocks
n_exceptions = 0
for (idx, batch) in enumerate(tqdm(dataloader, desc=f'Evaluating {tqdm_label}', position=tqdm_position)):
pbar = tqdm(dataloader, desc=f'Eval {tqdm_label}', position=tqdm_position, disable=disable_tqdm)
for (idx, batch) in enumerate(pbar):
if overfit_batch_idx > -1:
if idx < overfit_batch_idx:
continue
if idx > overfit_batch_idx:
break
data, _, cluster_ids = batch
block_size = len(cluster_ids)
all_gold += list(np.reshape(cluster_ids, (block_size,)))
pbar.set_description(f'Eval {tqdm_label} (sz={block_size})')
data = data.reshape(-1, n_features).float()
if data.shape[0] == 0:
# Only one signature in block; manually assign a unique cluster
pred_cluster_ids = [max_pred_id + 1]
elif fork_enabled and block_size >= fork_size and clustering_fn.__class__ is CCInference:
_proc = _fork_iter(idx, _fork_id, _shared_list, evaluate_pairwise, **fn_args)
_fork_id += 1
_procs.append((_proc, block_size))
continue
else:
# Forward pass through the e2e model
data = data.to(device)
@@ -155,31 +235,57 @@ def evaluate_pairwise(model, dataloader, overfit_batch_idx=-1, mode="macro", ret
save_to_wandb_run({'errors': _errors}, 'errors.json', run_dir, logger)
if not debug: # if tqdm_label is not 'dev' and not debug:
raise CvxpyException(data=_error_obj)
# If split is dev, skip batch and continue
all_gold = all_gold[:-len(cluster_ids)]
n_exceptions += 1
logger.info(f'Caught CvxpyException {n_exceptions}: skipping batch')
continue
if clustering_fn.__class__ is CCInference:
cc_obj_vals['round'].append(clustering_fn.hac_cut_layer.objective_value)
cc_obj_vals['sdp'].append(clustering_fn.sdp_layer.objective_value)
cc_obj_vals['block_idxs'].append(idx)
cc_obj_vals['block_sizes'].append(block_size)
all_gold += list(np.reshape(cluster_ids, (block_size,)))
max_pred_id = max(pred_cluster_ids)
all_pred += list(pred_cluster_ids)
if clustering_fn.__class__ is CCInference:
cc_obj_vals['round'].append(clustering_fn.hac_cut_layer.objective_value)
cc_obj_vals['sdp'].append(clustering_fn.sdp_layer.objective_value)
cc_obj_vals['block_idxs'].append(idx)
cc_obj_vals['block_sizes'].append(block_size)
if overfit_batch_idx > -1 and return_iter:
return {
'cluster_labels': list(np.array(pred_cluster_ids) - (max_pred_id + 1)),
'round_objective_value': clustering_fn.hac_cut_layer.objective_value,
'sdp_objective_value': clustering_fn.sdp_layer.objective_value,
'block_idx': idx,
'block_size': block_size,
'cluster_ids': cluster_ids
}

if fork_enabled and len(_procs) > 0:
_procs.sort(key=lambda x: x[1]) # To visualize progress
for _proc in tqdm(_procs, desc=f'Eval {tqdm_label} (waiting for forks to join)', position=tqdm_position):
_proc[0].join()
assert len(_procs) == len(_shared_list), "All forked eval iterations did not return results"
for _data in _shared_list:
pred_cluster_ids = (_data['cluster_labels'] + (max_pred_id + 1)).tolist()
cc_obj_vals['round'].append(_data['round_objective_value'])
cc_obj_vals['sdp'].append(_data['sdp_objective_value'])
cc_obj_vals['block_idxs'].append(_data['block_idx'])
cc_obj_vals['block_sizes'].append(_data['block_size'])
all_gold += list(np.reshape(_data['cluster_ids'], (_data['block_size'],)))
max_pred_id = max(pred_cluster_ids)
all_pred += list(pred_cluster_ids)

vmeasure = v_measure_score(all_gold, all_pred)
b3_f1 = compute_b3_f1(all_gold, all_pred)[2]
return (b3_f1, vmeasure, cc_obj_vals) if clustering_fn.__class__ is CCInference else (b3_f1, vmeasure)

y_pred, targets = [], []
for (idx, batch) in enumerate(tqdm(dataloader, desc=f'Evaluating {tqdm_label}', position=tqdm_position)):
pbar = tqdm(dataloader, desc=f'Eval {tqdm_label}', position=tqdm_position, disable=disable_tqdm)
for (idx, batch) in enumerate(pbar):
if overfit_batch_idx > -1:
if idx < overfit_batch_idx:
continue
if idx > overfit_batch_idx:
break
data, target = batch
data = data.reshape(-1, n_features).float()
pbar.set_description(f'Eval {tqdm_label} (sz={len(data)})')
assert data.shape[0] != 0
target = target.flatten().float()
# Forward pass through the pairwise model
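`copy_and_load_model` is imported from train_utils but its body is not part of this diff. From the call site, `copy_and_load_model(kwargs['model'], kwargs['run_dir'], 'cpu', store_only=True)`, which must return a path that `torch.load` can read, a plausible reading is something like this hypothetical reconstruction:

```python
import os
import time
import torch

def copy_and_load_model(model, run_dir, device, store_only=False):
    # Hypothetical reconstruction; only the store_only=True path is
    # exercised in this diff, so the rest is left unimplemented.
    path = os.path.join(run_dir, f'model_snapshot_{int(time.time())}.pt')
    torch.save(model.state_dict(), path)   # snapshot weights to disk
    if store_only:
        return path                        # child processes load it themselves
    raise NotImplementedError("behavior beyond store_only is not shown here")
```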
(Diff for the remaining 14 changed files not shown.)
