diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..c6e722a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+checkpoints_detect/
+forecast_checkpoints/
+detect_result/
+forecast_results/
+*.pyc
+datasets/
+pretrain-library/
\ No newline at end of file
diff --git a/data_process/finetune_dataset.py b/data_process/finetune_dataset.py
new file mode 100644
index 0000000..104a1c6
--- /dev/null
+++ b/data_process/finetune_dataset.py
@@ -0,0 +1,257 @@
+import os
+import numpy as np
+import pandas as pd
+
+from torch.utils.data import Dataset
+from sklearn.preprocessing import StandardScaler
+
+
+#Forecasting dataset for data saved in csv format
+class Forecast_Dataset_csv(Dataset):
+    def __init__(self, root_path, data_path='ETTm1.csv', flag='train', size=None,
+                 data_split=[0.7, 0.1, 0.2], scale=True, scale_statistic=None):
+
+        # size [past_len, pred_len]
+        self.in_len = size[0]
+        self.out_len = size[1]
+
+        # init
+        assert flag in ['train', 'test', 'val']
+        type_map = {'train': 0, 'val': 1, 'test': 2}
+        self.set_type = type_map[flag]
+
+        self.scale = scale
+
+        self.root_path = root_path
+        self.data_path = data_path
+        self.data_split = data_split
+        self.scale_statistic = scale_statistic
+        self.__read_data__()
+
+    def __read_data__(self):
+        df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))
+
+        #int split, e.g. [34560, 11520, 11520] for ETTm1
+        if (self.data_split[0] > 1):
+            train_num = self.data_split[0]
+            val_num = self.data_split[1]
+            test_num = self.data_split[2]
+        #ratio split, e.g. [0.7, 0.1, 0.2] for Weather
+        else:
+            train_num = int(len(df_raw) * self.data_split[0])
+            test_num = int(len(df_raw) * self.data_split[2])
+            val_num = len(df_raw) - train_num - test_num
+
+        border1s = [0, train_num - self.in_len, train_num + val_num - self.in_len]
+        border2s = [train_num, train_num + val_num, train_num + val_num + test_num]
+
+        border1 = border1s[self.set_type]
+        border2 = border2s[self.set_type]
+
+        cols_data = df_raw.columns[1:]
+        df_data = df_raw[cols_data]
+
+        if self.scale:
+            if self.scale_statistic is None:
+                self.scaler = StandardScaler()
+                train_data = df_data[border1s[0]:border2s[0]]
+                self.scaler.fit(train_data.values)
+            else:
+                # sklearn's StandardScaler takes no mean/std constructor arguments;
+                # populate its fitted attributes from the precomputed statistics instead
+                self.scaler = StandardScaler()
+                self.scaler.mean_ = np.asarray(self.scale_statistic['mean'])
+                self.scaler.scale_ = np.asarray(self.scale_statistic['std'])
+            data = self.scaler.transform(df_data.values)
+        else:
+            data = df_data.values
+
+        self.data_x = data[border1:border2]
+        self.data_y = data[border1:border2]
+
+    def __getitem__(self, index):
+        past_begin = index
+        past_end = past_begin + self.in_len
+        pred_begin = past_end
+        pred_end = pred_begin + self.out_len
+
+        seq_x = self.data_x[past_begin:past_end].transpose()  # [ts_d, ts_len]
+        seq_y = self.data_y[pred_begin:pred_end].transpose()
+
+        return seq_x, seq_y
+
+    def __len__(self):
+        return len(self.data_x) - self.in_len - self.out_len + 1
+
+    def inverse_transform(self, data):
+        return self.scaler.inverse_transform(data)
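+
+# Illustrative window-arithmetic sketch (hypothetical sizes, not part of the
+# pipeline): with in_len=12 and out_len=6 on a 100-row csv, the class above
+# yields 100 - 12 - 6 + 1 = 83 sliding windows; index 0 returns
+# seq_x = data[0:12].T (shape [ts_d, 12]) and seq_y = data[12:18].T (shape [ts_d, 6]).
+# if __name__ == '__main__':
+#     ds = Forecast_Dataset_csv(root_path='./datasets/', data_path='ETTm1.csv', flag='train', size=[12, 6])
+#     x, y = ds[0]
+#     print(x.shape, y.shape)  # (ts_d, 12) (ts_d, 6)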
+
+
+#Forecasting dataset for data saved in npy format
+class Forecast_Dataset_npy(Dataset):
+    '''
+    For datasets stored in .npy files (originally for anomaly detection), usually saved as '*_train.npy', '*_test.npy' and '*_test_label.npy'
+    We split the original training set into training and validation sets
+    We now use this dataset for the forecasting task
+    '''
+
+    def __init__(self, root_path, data_name='SMD', flag="train", size=None, step=1,
+                 valid_prop=0.2, scale=True, scale_statistic=None, lead_in=1000):
+
+        self.in_len = size[0]
+        self.out_len = size[1]
+
+        self.root_path = root_path
+        self.data_name = data_name
+        self.flag = flag
+        self.step = step  #like the stride in convolution, we may skip several steps when sliding the window over a large dataset (e.g. SMD has 566,724 timestamps)
+        self.valid_prop = valid_prop
+        self.scale = scale
+        self.scale_statistic = scale_statistic
+
+        '''
+        the front part of the test series is preserved for model input,
+        to keep consistent with the csv format, where some steps of the val set are fed to the model to predict the test set;
+        with lead_in, if pred_len remains unchanged, varying the input length will not change y in the test set
+        '''
+        self.lead_in = lead_in
+
+        self.__read_data__()
+
+    def __read_data__(self):
+        train_val_data = np.load(os.path.join(self.root_path, '{}_train.npy'.format(self.data_name)))
+        test_data = np.load(os.path.join(self.root_path, '{}_test.npy'.format(self.data_name)))
+        self.data_dim = train_val_data.shape[1]
+        # we do not need the anomaly label for forecasting
+
+        train_num = int(len(train_val_data) * (1 - self.valid_prop))
+
+        if self.scale:
+            # use the mean and std of the training set
+            if self.scale_statistic is None:
+                self.scaler = StandardScaler()
+                train_data = train_val_data[:train_num]
+                self.scaler.fit(train_data)
+            else:
+                # populate sklearn's fitted attributes from the precomputed statistics
+                self.scaler = StandardScaler()
+                self.scaler.mean_ = np.asarray(self.scale_statistic['mean'])
+                self.scaler.scale_ = np.asarray(self.scale_statistic['std'])
+
+            scale_train_val_data = self.scaler.transform(train_val_data)
+            scale_test_data = self.scaler.transform(test_data)
+
+            if self.flag == 'train':
+                self.data_x = scale_train_val_data[:train_num]
+                self.data_y = scale_train_val_data[:train_num]
+            elif self.flag == 'val':
+                self.data_x = scale_train_val_data[train_num - self.in_len:]
+                self.data_y = scale_train_val_data[train_num - self.in_len:]
+            elif self.flag == 'test':
+                '''
+                |------------------|------|----------------------------------------|
+                                   ^      ^
+                                   |      |
+                        lead_in - in_len  lead_in
+                '''
+                self.data_x = scale_test_data[self.lead_in - self.in_len:]
+                self.data_y = scale_test_data[self.lead_in - self.in_len:]
+
+        else:
+            if self.flag == 'train':
+                self.data_x = train_val_data[:train_num]
+                self.data_y = train_val_data[:train_num]
+            elif self.flag == 'val':
+                self.data_x = train_val_data[train_num - self.in_len:]
+                self.data_y = train_val_data[train_num - self.in_len:]
+            elif self.flag == 'test':
+                self.data_x = test_data[self.lead_in - self.in_len:]
+                self.data_y = test_data[self.lead_in - self.in_len:]
+
+    def __len__(self):
+        return (len(self.data_x) - self.in_len - self.out_len) // self.step + 1
+
+    def __getitem__(self, index):
+        index = index * self.step
+
+        past_begin = index
+        past_end = past_begin + self.in_len
+        pred_begin = past_end
+        pred_end = pred_begin + self.out_len
+
+        seq_x = self.data_x[past_begin:past_end].transpose()  # [ts_d, ts_len]
+        seq_y = self.data_y[pred_begin:pred_end].transpose()
+
+        return seq_x, seq_y
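+
+# Strided-window sketch (hypothetical numbers, illustration only): with step=5,
+# in_len=12, out_len=6 on 100 rows, len = (100 - 12 - 6) // 5 + 1 = 17 windows,
+# and dataset[k] starts at row 5*k. In test mode with lead_in=1000, data_x/data_y
+# both start at row lead_in - in_len, so the first predicted point is always row
+# 1000 of the test series regardless of the chosen input length.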
+
+#Anomaly Detection dataset (all saved in npy format)
+class Detection_Dataset_npy(Dataset):
+    '''
+    For datasets stored in .npy files, usually saved as '*_train.npy', '*_test.npy' and '*_test_label.npy'
+    We split the original training set into training and validation sets
+    '''
+
+    def __init__(self, root_path, data_name='SMD', flag="train", seg_len=100, step=None,
+                 valid_prop=0.2, scale=True, scale_statistic=None):
+        self.root_path = root_path
+        self.data_name = data_name
+        self.flag = flag
+        self.seg_len = seg_len  # length of each time-series segment, usually 100 for all anomaly detection experiments
+        self.step = step if step is not None else seg_len  #use step to skip some steps when the set is too large
+        self.valid_prop = valid_prop
+        self.scale = scale
+        self.scale_statistic = scale_statistic
+
+        self.__read_data__()
+
+    def __read_data__(self):
+        train_val_ts = np.load(os.path.join(self.root_path, '{}_train.npy'.format(self.data_name)))
+        test_ts = np.load(os.path.join(self.root_path, '{}_test.npy'.format(self.data_name)))
+        test_label = np.load(os.path.join(self.root_path, '{}_test_label.npy'.format(self.data_name)))
+        self.data_dim = train_val_ts.shape[1]
+
+        data_len = len(train_val_ts)
+        train_ts = train_val_ts[0:int(data_len * (1 - self.valid_prop))]
+        val_ts = train_val_ts[int(data_len * (1 - self.valid_prop)):]
+
+        if self.scale:
+            if self.scale_statistic is None:
+                # use the mean and std of the training set
+                self.scaler = StandardScaler()
+                self.scaler.fit(train_ts)
+            else:
+                # populate sklearn's fitted attributes from the precomputed statistics
+                self.scaler = StandardScaler()
+                self.scaler.mean_ = np.asarray(self.scale_statistic['mean'])
+                self.scaler.scale_ = np.asarray(self.scale_statistic['std'])
+            self.train_ts = self.scaler.transform(train_ts)
+            self.val_ts = self.scaler.transform(val_ts)
+            self.test_ts = self.scaler.transform(test_ts)
+
+        else:
+            self.train_ts = train_ts
+            self.val_ts = val_ts
+            self.test_ts = test_ts
+        self.threshold_ts = np.concatenate([self.train_ts, self.val_ts], axis=0)  #use both the training and validation sets to set the threshold
+        self.test_label = test_label
+
+    def __len__(self):
+        # number of non-overlapping time-series segments
+        if self.flag == "train":
+            return (self.train_ts.shape[0] - self.seg_len) // self.step + 1
+        elif (self.flag == 'val'):
+            return (self.val_ts.shape[0] - self.seg_len) // self.step + 1
+        elif (self.flag == 'test'):
+            return (self.test_ts.shape[0] - self.seg_len) // self.step + 1
+        elif (self.flag == 'threshold'):
+            return (self.threshold_ts.shape[0] - self.seg_len) // self.step + 1
+
+    def __getitem__(self, index):
+        # select data by flag
+        if self.flag == "train":
+            ts = self.train_ts
+        elif (self.flag == 'val'):
+            ts = self.val_ts
+        elif (self.flag == 'test'):
+            ts = self.test_ts
+        elif (self.flag == 'threshold'):
+            ts = self.threshold_ts
+
+        index = index * self.step
+
+        ts_seg = ts[index:index + self.seg_len, :].transpose()  # [ts_dim, seg_len]
+        ts_label = np.zeros(self.seg_len)
+        if self.flag == 'test':
+            ts_label = self.test_label[index:index + self.seg_len]
+
+        return ts_seg, ts_label
\ No newline at end of file
diff --git a/exp/__init__.py b/exp/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/exp/exp_detect.py b/exp/exp_detect.py
new file mode 100644
index 0000000..eef2fc8
--- /dev/null
+++ b/exp/exp_detect.py
@@ -0,0 +1,265 @@
+import os
+import torch
+import numpy as np
+
+from models.finetune_model.UP2ME_detector import UP2ME_Detector
+from data_process.finetune_dataset import Detection_Dataset_npy
+import time
+from sklearn.metrics import precision_recall_fscore_support, accuracy_score
+
+import torch.nn as nn
+from torch import optim
+from torch.utils.data import DataLoader
+from torch.nn import DataParallel
+
+import json
+import pickle
+
+from utils.tools import EarlyStopping, adjust_learning_rate
+from utils.metrics import segment_adjust, adjusted_precision_recall_curve
+
+class UP2ME_exp_detect(object):
+    def __init__(self, args):
+        self.args = args
+        self.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu')
+        self.model = self._build_model().to(self.device)
+
+    def _build_model(self):
+        model = UP2ME_Detector(
+            pretrained_model_path = self.args.pretrained_model_path,
+            pretrain_args=self.args.pretrain_args,
+            finetune_flag=(not self.args.IR_mode),
+            finetune_layers=self.args.finetune_layers,
+
dropout=self.args.dropout + ).float() + + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print('Trainable parameters: {}'.format(trainable_params)) + + if self.args.use_multi_gpu and self.args.use_gpu: + model = nn.DataParallel(model, device_ids=self.args.device_ids) + return model + + def _get_data(self, flag): + args = self.args + + if flag == 'test' or flag == 'threshold': + shuffle_flag = False; drop_last = False; batch_size = args.batch_size; step=args.seg_len; + else: + shuffle_flag = True; drop_last = False; batch_size = args.batch_size; step=args.slide_step; + + data_set = Detection_Dataset_npy( + root_path=args.root_path, + data_name=args.data_name, + flag=flag, + seg_len=args.seg_len, + step=step, + valid_prop=args.valid_prop, + scale=True, + scale_statistic=None + ) + + print(flag, len(data_set)) + + data_loader = DataLoader( + data_set, + batch_size=batch_size, + shuffle=shuffle_flag, + num_workers=args.num_workers, + drop_last=drop_last) + + return data_set, data_loader + + def _select_optimizer(self): + model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate) + return model_optim + + def _select_criterion(self): + criterion = nn.MSELoss() + return criterion + + def _process_one_batch(self, dataset_object, batch_x): + batch_x = batch_x.float().to(self.device) + + if self.args.IR_mode: + outputs = self.model.immediate_detect(batch_x) + else: + outputs = self.model(batch_x, self.args.neighbor_num) + + return outputs, batch_x + + def vali(self, vali_data, vali_loader, criterion): + self.model.eval_mode() + total_loss = [] + with torch.no_grad(): + for i, (batch_x,_) in enumerate(vali_loader): + pred, true = self._process_one_batch( + vali_data, batch_x) + loss = criterion(pred.detach().cpu(), true.detach().cpu()) + total_loss.append(loss.detach().item()) + total_loss = np.average(total_loss) + + return total_loss + + def train(self, setting): + train_data, train_loader = self._get_data(flag = 'train') + vali_data, vali_loader = self._get_data(flag = 'val') + + path = os.path.join(self.args.checkpoints, setting) + if not os.path.exists(path): + os.makedirs(path) + with open(os.path.join(path, "args.json"), 'w') as f: + json.dump(vars(self.args), f, indent=True) + + train_steps = len(train_loader) + early_stopping = EarlyStopping(patience=self.args.tolerance, verbose=True) + + model_optim = self._select_optimizer() + criterion = self._select_criterion() + + for epoch in range(self.args.train_epochs): + time_now = time.time() + iter_count = 0 + train_loss = [] + + if isinstance(self.model, DataParallel): + self.model.module.train_mode() + else: + self.model.train_mode() + + epoch_time = time.time() + for i, (batch_x,batch_y) in enumerate(train_loader): + iter_count += 1 + + model_optim.zero_grad() + pred, true = self._process_one_batch(train_data, batch_x) + loss = criterion(pred, true) + train_loss.append(loss.item()) + + if (i+1) % 100==0: + print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item())) + speed = (time.time()-time_now)/iter_count + left_time = speed*((self.args.train_epochs - epoch)*train_steps - i) + print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time)) + iter_count = 0 + time_now = time.time() + + loss.backward() + model_optim.step() + + print("Epoch: {} cost time: {}".format(epoch+1, time.time()-epoch_time)) + train_loss = np.average(train_loss) + vali_loss = self.vali(vali_data, vali_loader, criterion) + + print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} 
Vali Loss: {3:.7f}".format( + epoch + 1, train_steps, train_loss, vali_loss)) + early_stopping(vali_loss, self.model, path) + adjust_learning_rate(model_optim, epoch + 1, self.args) + if early_stopping.early_stop: + print("Early stopping") + break + + best_model_path = path+'/'+'checkpoint.pth' + self.model.load_state_dict(torch.load(best_model_path)) + state_dict = self.model.module.state_dict() if isinstance(self.model, DataParallel) else self.model.state_dict() + torch.save(state_dict, path+'/'+'checkpoint.pth') + + return self.model + + def test(self, setting, save_pred = False): + test_data, test_loader = self._get_data(flag='test') + threshold_data, threshold_loader = self._get_data(flag='threshold') + + self.model.eval_mode() + + # (1) use the threshold loader (train + val) to select the threshold for anomaly annotation + anomaly_score_threshold = [] + threshold_preds = [] + threshold_trues = [] + self.anomaly_criterion = nn.MSELoss(reduction='none') + + self.model.eval_mode() + with torch.no_grad(): + for i, (batch_x, _) in enumerate(threshold_loader): + # reconstruction + pred, true = self._process_one_batch(threshold_data, batch_x) + threshold_preds.append(pred.detach().cpu().numpy()) + threshold_trues.append(true.detach().cpu().numpy()) + # criterion + score = torch.mean(self.anomaly_criterion(pred, true), dim=-2) # pred and true in shape [batch_size, ts_len] + score = score.detach().cpu().numpy() + anomaly_score_threshold.append(score) + + anomaly_score_threshold = np.concatenate(anomaly_score_threshold, axis=0).reshape(-1) + threshold = np.percentile(anomaly_score_threshold, 100 - self.args.anomaly_ratio) #error of both train and val + print("Threshold :", threshold) + + # (2) calculate the anomaly score on the test set + test_anomaly_score = [] + test_labels = [] + test_preds = [] + test_trues = [] + with torch.no_grad(): + for i, (batch_x, batch_y) in enumerate(test_loader): + batch_x = batch_x.float().to(self.device) + # reconstruction + pred, true = self._process_one_batch(test_data, batch_x) + test_preds.append(pred.detach().cpu().numpy()) + test_trues.append(true.detach().cpu().numpy()) + # criterion + score = torch.mean(self.anomaly_criterion(pred, true), dim=-2) + score = score.detach().cpu().numpy() + test_anomaly_score.append(score) + test_labels.append(batch_y) + + test_anomaly_score = np.concatenate(test_anomaly_score, axis=0).reshape(-1) + + # (3) assign a binary label according to the threshold + anomaly_pred = (test_anomaly_score > threshold).astype(int) + test_labels = np.concatenate(test_labels, axis=0).reshape(-1) + anomaly_gt = test_labels.astype(int) + + print("pred: ", anomaly_pred.shape) + print("gt: ", anomaly_gt.shape) + + + # (4) perfrom segment adjustment to measure precision, recall and F1-score + adjusted_anomaly_pred = segment_adjust(anomaly_gt, anomaly_pred) + accuracy = accuracy_score(anomaly_gt, adjusted_anomaly_pred) + precision, recall, f_score, _ = precision_recall_fscore_support(anomaly_gt, adjusted_anomaly_pred, average='binary') + + # (5) evaluate precision-recall curve that is agnostic to the threshold + precision_list, recall_list, average_precision = adjusted_precision_recall_curve(anomaly_gt, test_anomaly_score) + + print("Accuracy : {:0.4f}, Precision : {:0.4f}, Recall : {:0.4f}, F-score : {:0.4f}, AP : {:0.4f} ".format( + accuracy, precision, recall, f_score, average_precision)) + + folder_path = self.args.save_folder + setting + '/' + if not os.path.exists(folder_path): + os.makedirs(folder_path) + + f = open(folder_path + 
"result.txt", "w") + f.write(setting + " \n") + f.write("Accuracy : {:0.4f}, Precision : {:0.4f}, Recall : {:0.4f}, F-score : {:0.4f}, AP : {:0.4f} \n".format( + accuracy, precision, + recall, f_score, average_precision)) + f.write('\n') + f.write('\n') + f.close() + + if save_pred: + threshold_preds = np.concatenate(threshold_preds, axis=0) + threshold_trues = np.concatenate(threshold_trues, axis=0) + test_preds = np.concatenate(test_preds, axis=0) + test_trues = np.concatenate(test_trues, axis=0) + data_save = { + 'threshold_reconstruct': threshold_preds, + 'threshold_trues': threshold_trues, + 'test_reconstruct': test_preds, + 'test_trues': test_trues, + 'test_labels': anomaly_gt + } + with open(folder_path + "reconstruction.pkl", 'wb') as f: + pickle.dump(data_save, f) + + return \ No newline at end of file diff --git a/exp/exp_forecast.py b/exp/exp_forecast.py new file mode 100644 index 0000000..ae9f2ab --- /dev/null +++ b/exp/exp_forecast.py @@ -0,0 +1,253 @@ +import os +import torch +import numpy as np +import random + +from models.finetune_model.UP2ME_forecaster import UP2ME_forecaster +from data_process.finetune_dataset import Forecast_Dataset_csv, Forecast_Dataset_npy +import time +from utils.metrics import metric + +import torch.nn as nn +from torch import optim +from torch.utils.data import DataLoader +from torch.nn import DataParallel + +import json +import pickle + +from utils.tools import EarlyStopping, adjust_learning_rate + +class UP2ME_exp_forecast(object): + def __init__(self, args): + self.args = args + self.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() else 'cpu') + self.model = self._build_model().to(self.device) + + def _build_model(self): + model = UP2ME_forecaster( + pretrained_model_path = self.args.pretrained_model_path, + pretrain_args=self.args.pretrain_args, + finetune_flag=(not self.args.IR_mode), + finetune_layers=self.args.finetune_layers, + dropout=self.args.dropout + ).float() + + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print('Trainable parameters: {}'.format(trainable_params)) + + if self.args.use_multi_gpu and self.args.use_gpu: + model = nn.DataParallel(model, device_ids=self.args.device_ids) + return model + + def _get_data(self, flag): + args = self.args + + if flag == 'test': + shuffle_flag = False; drop_last = False; batch_size = args.batch_size; step = 1; + elif flag == 'val': + shuffle_flag = False; drop_last = False; batch_size = args.batch_size; step = args.slide_step; + else: + shuffle_flag = True; drop_last = False; batch_size = args.batch_size; step = args.slide_step; + + if args.data_format == 'csv': + data_set = Forecast_Dataset_csv( + root_path=args.root_path, + data_path=args.data_path, + flag=flag, + size=[args.in_len, args.out_len], + data_split = args.data_split, + scale=True, + scale_statistic=None + ) + elif args.data_format == 'npy': + data_set = Forecast_Dataset_npy( + root_path=args.root_path, + data_name=args.data_name, + flag=flag, + size=[args.in_len, args.out_len], + step=step, + valid_prop=args.valid_prop, + scale=True, + scale_statistic=None, + ) + + data_loader = DataLoader( + data_set, + batch_size=batch_size, + shuffle=shuffle_flag, + num_workers=args.num_workers, + drop_last=drop_last) + + return data_set, data_loader + + def _select_optimizer(self): + model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate) + return model_optim + + def _select_criterion(self): + criterion = nn.MSELoss() + return criterion + + def 
_process_one_batch(self, dataset_object, batch_x, batch_y): + batch_x = batch_x.float().to(self.device) + batch_y = batch_y.float().to(self.device) + + if self.args.IR_mode: + outputs = self.model.immediate_forecast(batch_x, self.args.out_len) + else: + outputs = self.model(batch_x, self.args.out_len//self.args.pretrain_args['patch_size'], self.args.neighbor_num) + + return outputs, batch_y + + def vali(self, vali_data, vali_loader, criterion): + if isinstance(self.model, DataParallel): + self.model.module.eval_mode() + else: + self.model.eval_mode() + total_loss = [] + with torch.no_grad(): + for i, (batch_x,batch_y) in enumerate(vali_loader): + pred, true = self._process_one_batch( + vali_data, batch_x, batch_y) + loss = criterion(pred.detach().cpu(), true.detach().cpu()) + total_loss.append(loss.detach().item()) + total_loss = np.average(total_loss) + + return total_loss + + def train(self, setting): + train_data, train_loader = self._get_data(flag = 'train') + vali_data, vali_loader = self._get_data(flag = 'val') + + path = os.path.join(self.args.checkpoints, setting) + if not os.path.exists(path): + os.makedirs(path) + with open(os.path.join(path, "args.json"), 'w') as f: + json.dump(vars(self.args), f, indent=True) + + train_steps = len(train_loader) + early_stopping = EarlyStopping(patience=self.args.tolerance, verbose=True) + + model_optim = self._select_optimizer() + criterion = self._select_criterion() + + for epoch in range(self.args.train_epochs): + time_now = time.time() + iter_count = 0 + train_loss = [] + + if isinstance(self.model, DataParallel): + self.model.module.train_mode() + else: + self.model.train_mode() + epoch_time = time.time() + for i, (batch_x,batch_y) in enumerate(train_loader): + iter_count += 1 + + model_optim.zero_grad() + pred, true = self._process_one_batch( + train_data, batch_x, batch_y) + loss = criterion(pred, true) + train_loss.append(loss.item()) + + if (i+1) % 100==0: + print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item())) + speed = (time.time()-time_now)/iter_count + left_time = speed*((self.args.train_epochs - epoch)*train_steps - i) + print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time)) + iter_count = 0 + time_now = time.time() + + loss.backward() + model_optim.step() + + print("Epoch: {} cost time: {}".format(epoch+1, time.time()-epoch_time)) + train_loss = np.average(train_loss) + vali_loss = self.vali(vali_data, vali_loader, criterion) + + print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f}".format( + epoch + 1, train_steps, train_loss, vali_loss)) + early_stopping(vali_loss, self.model, path) + adjust_learning_rate(model_optim, epoch + 1, self.args) + if early_stopping.early_stop: + print("Early stopping") + break + + best_model_path = path+'/'+'checkpoint.pth' + self.model.load_state_dict(torch.load(best_model_path)) + state_dict = self.model.module.state_dict() if isinstance(self.model, DataParallel) else self.model.state_dict() + torch.save(state_dict, path+'/'+'checkpoint.pth') + + return self.model + + def test(self, setting, save_pred = False): + test_data, test_loader = self._get_data(flag='test') + + if self.args.is_training == 0 and not self.args.IR_mode: #load previous finetuned model + print('loading model ......') + path = os.path.join(self.args.checkpoints, setting) + best_model_path = path+'/'+'checkpoint.pth' + self.model.load_state_dict(torch.load(best_model_path, map_location=torch.device('cpu'))) + + if isinstance(self.model, DataParallel): + 
self.model.module.eval_mode() + else: + self.model.eval_mode() + + pasts = [] + preds = [] + trues = [] + metrics_all = [] + instance_num = 0 + + with torch.no_grad(): + for i, (batch_x,batch_y) in enumerate(test_loader): + pred, true = self._process_one_batch( + test_data, batch_x, batch_y) + batch_size = pred.shape[0] + instance_num += batch_size + batch_metric = np.array(metric(pred.detach().cpu().numpy(), true.detach().cpu().numpy())) * batch_size + metrics_all.append(batch_metric) + if (save_pred): + pasts.append(batch_x.detach().cpu().numpy()) + preds.append(pred.detach().cpu().numpy()) + trues.append(true.detach().cpu().numpy()) + + print('test instance num: {}'.format(instance_num)) + metrics_all = np.stack(metrics_all, axis = 0) + metrics_mean = metrics_all.sum(axis = 0) / instance_num + + # result save + folder_path = './forecast_results/' + setting +'/' + if not os.path.exists(folder_path): + os.makedirs(folder_path) + + mae, mse, rmse, mape, mspe = metrics_mean + print('mse:{}, mae:{}'.format(mse, mae)) + + np.save(folder_path+'metrics.npy', np.array([mae, mse, rmse, mape, mspe])) + if (save_pred): + print('saving predicted results in {}'.format(folder_path)) + preds = np.concatenate(preds, axis = 0) + trues = np.concatenate(trues, axis = 0) + pasts = np.concatenate(pasts, axis = 0) + print(preds.shape) + np.save(folder_path +'pred.npy', preds) + np.save(folder_path +'true.npy', trues) + np.save(folder_path +'past.npy', pasts) + + # write to txt + if self.args.IR_mode: + save_file = folder_path+'IR-mode_metrics.txt' + else: + save_file = folder_path+'finetune_metrics.txt' + with open(save_file, 'w') as f: + f.write(setting+'\n') + f.write('mse:{}\n'.format(mse)) + f.write('mae:{}\n'.format(mae)) + f.write('rmse:{}\n'.format(rmse)) + f.write('mape:{}\n'.format(mape)) + f.write('mspe:{}\n'.format(mspe)) + + return \ No newline at end of file diff --git a/models/finetune_model/UP2ME_detector.py b/models/finetune_model/UP2ME_detector.py new file mode 100644 index 0000000..0160bc3 --- /dev/null +++ b/models/finetune_model/UP2ME_detector.py @@ -0,0 +1,136 @@ +import torch +import torch.nn as nn +from ..pretrain_model.UP2ME_model import UP2ME_model +from .temporal_channel_layer import Temporal_Channel_Layer +from ..pretrain_model.embed import learnable_position_embedding +from .graph_structure import graph_construct +from einops import rearrange + +class UP2ME_Detector(nn.Module): + def __init__(self, pretrained_model_path, pretrain_args, finetune_flag=False, finetune_layers=3, dropout=0.0): + super(UP2ME_Detector, self).__init__() + self.pretrained_model_path = pretrained_model_path + self.pretrain_args = pretrain_args + self.data_dim = pretrain_args['data_dim'] + self.patch_size = pretrain_args['patch_size'] + + self.finetune_flag = finetune_flag + self.finetune_layers = finetune_layers + self.dropout = dropout + + # load pre-trained model + self.pretrained_model = UP2ME_model( + data_dim=pretrain_args['data_dim'], patch_size=pretrain_args['patch_size'],\ + d_model=pretrain_args['d_model'], d_ff = pretrain_args['d_ff'], n_heads=pretrain_args['n_heads'], \ + e_layers=pretrain_args['e_layers'], d_layers = pretrain_args['d_layers'], dropout=pretrain_args['dropout']) + self.load_pre_trained_model() + + # if fine-tune, add new layers + if self.finetune_flag: + self.learnable_patch = nn.Parameter(torch.randn(pretrain_args['d_model'])) + self.position_embedding = learnable_position_embedding(pretrain_args['d_model']) + self.channel_embedding = nn.Embedding(pretrain_args['data_dim'], 
pretrain_args['d_model']) + self.init_enc2dec_param() + + self.temporal_channel_layers = nn.ModuleList() + for _ in range(finetune_layers): + self.temporal_channel_layers.append(Temporal_Channel_Layer( + d_model=pretrain_args['d_model'], n_heads=pretrain_args['n_heads'], d_ff=pretrain_args['d_ff'], dropout=dropout)) + + def load_pre_trained_model(self): + #load the pre-trained model + pretrained_model = torch.load(self.pretrained_model_path, map_location='cpu') + self.pretrained_model.load_state_dict(pretrained_model) + + #freeze the encoder and decoder of pre-trained model + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + def init_enc2dec_param(self): + self.position_embedding.load_state_dict(self.pretrained_model.position_embedding.state_dict()) + self.channel_embedding.load_state_dict(self.pretrained_model.channel_embedding.state_dict()) + self.learnable_patch.data.copy_(self.pretrained_model.learnable_patch.data) + + + def train_mode(self): + self.train() + self.pretrained_model.eval() + return + + def eval_mode(self): + self.eval() + return + + def immediate_detect(self, multi_ts): + ''' + Immediate anomaly detection, the key idea is to iteratively mask one patch and use the remaining to reconstruct it + It is more difficult to reconstruct abnormal patches than normal ones + Args: + multi_ts: [batch_size, ts_d, ts_len] + ''' + + batch_size, ts_d, ts_len = multi_ts.shape + + reconstruct_patches = [] + patch_num = ts_len // self.patch_size + patch_idx = torch.arange(patch_num)[None, None, :].expand(batch_size, ts_d, -1).to(multi_ts.device) + + for i in range(patch_num): + masked_patch_index = i * torch.ones([batch_size, ts_d, 1], dtype=torch.long).to(multi_ts.device) # [batch_size, ts_d, 1] + unmasked_patch_index = torch.cat([patch_idx[:, :, :i], patch_idx[:, :, i + 1:]], dim=-1) # [batch_size, ts_d, patch_num-1] + encoded_patch_unmasked = self.pretrained_model.encode_multi_to_patch(multi_ts, masked_patch_index, unmasked_patch_index) # [batch_size, ts_d, patch_num-1, d_model] + + recons_full_ts = self.pretrained_model.decode_patch_to_multi(encoded_patch_unmasked, masked_patch_index, unmasked_patch_index) # [batch_size, ts_d, ts_len] + recons_patch = recons_full_ts[:, :, i * self.patch_size: (i + 1) * self.patch_size] #pick out the [batch_size, ts_d, patch_size] + reconstruct_patches.append(recons_patch) + + reconstruct_ts = torch.cat(reconstruct_patches, dim=-1) # [batch_size, ts_d, ts_len] + + return reconstruct_ts + + def forward(self, input_ts, neighbor_num = 10): + batch_size, ts_d, ts_len = input_ts.shape + + encoded_patch = self.pretrained_model.encode_multi_to_patch(input_ts) + _, _, patch_num, d_model = encoded_patch.shape + + #compute the graph structure + graph_adj = graph_construct(encoded_patch, k = neighbor_num) + + #<----------------------------------------------------prepare unmasked patches-------------------------------------------------------> + channel_idx = torch.arange(ts_d).to(input_ts.device) + channel_embed = self.channel_embedding(channel_idx) #[ts_d, d_model] + + masked_patch_embed = self.learnable_patch[None, None, None, :].expand(batch_size, ts_d, -1, -1) + + full_patch_idx = torch.arange(patch_num)[None, :].expand(batch_size, -1).to(input_ts.device) + position_embed_masked = self.position_embedding([batch_size, patch_num, d_model], full_patch_idx) #[batch_size, patch_num, d_model] + + masked_patches_embed = channel_embed[None, :, None, :] + masked_patch_embed + position_embed_masked[:, None, :, :] #[batch_size, ts_d, 
patch_num, d_model] + #<------------------------------------------------------------------------------------------------------------------------------------------> + + #iteratively mask each patch along time axes and use others to reconstruct it + reconstruct_segments = [] + patch_idx = torch.arange(patch_num)[None, None, :].expand(batch_size, ts_d, -1).to(input_ts.device) + + for i in range(patch_num): + masked_patch_idx = i * torch.ones([batch_size, ts_d, 1], dtype=torch.long).to(input_ts.device) + unmasked_patch_idx = torch.cat([patch_idx[:, :, :i], patch_idx[:, :, i + 1:]], dim=-1) + encoded_patch_unmasked = self.pretrained_model.encode_multi_to_patch(input_ts, masked_patch_idx, unmasked_patch_idx) + patches_full = torch.cat([encoded_patch_unmasked[:, :, :i, :], masked_patches_embed[:, :, i:i+1, :], encoded_patch_unmasked[:, :, i:, :]], dim = -2) #[batch_size, ts_d, patch_num, d_model] + + #passing TC layers + for layer in self.temporal_channel_layers: + patches_full = layer(patches_full, graph_adj) + + #passing pretrained decoder + flatten_patches_full = rearrange(patches_full, 'batch_size ts_d seq_len d_model -> (batch_size ts_d) seq_len d_model') + flatten_channel_index = channel_idx[None, :].expand(batch_size, -1).reshape(-1) + flatten_full_ts = self.pretrained_model.pretrain_decode(flatten_patches_full, flatten_channel_index) # (batch_size*ts_d, past_len+pred_len) + full_ts = rearrange(flatten_full_ts, '(batch_size ts_d) ts_len -> batch_size ts_d ts_len', batch_size = batch_size) + recons_segment = full_ts[:, :, i*self.patch_size: (i+1)*self.patch_size] #[batch_size, ts_d, patch_size] + reconstruct_segments.append(recons_segment) + + reconstruct_ts = torch.cat(reconstruct_segments, dim=-1) #[batch_size, ts_d, ts_len] + + return reconstruct_ts \ No newline at end of file diff --git a/models/finetune_model/UP2ME_forecaster.py b/models/finetune_model/UP2ME_forecaster.py new file mode 100644 index 0000000..a29fa96 --- /dev/null +++ b/models/finetune_model/UP2ME_forecaster.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +from ..pretrain_model.UP2ME_model import UP2ME_model +from .temporal_channel_layer import Temporal_Channel_Layer +from ..pretrain_model.embed import learnable_position_embedding +from .graph_structure import graph_construct +from einops import rearrange + +class UP2ME_forecaster(nn.Module): + def __init__(self, pretrained_model_path, pretrain_args, finetune_flag=False, finetune_layers=3, dropout=0.0): + super(UP2ME_forecaster, self).__init__() + self.pretrained_model_path = pretrained_model_path + self.pretrain_args = pretrain_args + self.data_dim = pretrain_args['data_dim'] + self.patch_size = pretrain_args['patch_size'] + + self.finetune_flag = finetune_flag + self.finetune_layers = finetune_layers + self.dropout = dropout + + # load pre-trained model + self.pretrained_model = UP2ME_model( + data_dim=pretrain_args['data_dim'], patch_size=pretrain_args['patch_size'],\ + d_model=pretrain_args['d_model'], d_ff = pretrain_args['d_ff'], n_heads=pretrain_args['n_heads'], \ + e_layers=pretrain_args['e_layers'], d_layers = pretrain_args['d_layers'], dropout=pretrain_args['dropout']) + self.load_pre_trained_model() + + # if fine-tune, add new layers + if self.finetune_flag: + self.enc_2_dec = nn.Linear(pretrain_args['d_model'], pretrain_args['d_model']) + self.learnable_patch = nn.Parameter(torch.randn(pretrain_args['d_model'])) + self.position_embedding = learnable_position_embedding(pretrain_args['d_model']) + self.channel_embedding = 
nn.Embedding(pretrain_args['data_dim'], pretrain_args['d_model']) + self.init_enc2dec_param() + + self.temporal_channel_layers = nn.ModuleList() + for _ in range(finetune_layers): + self.temporal_channel_layers.append(Temporal_Channel_Layer( + d_model=pretrain_args['d_model'], n_heads=pretrain_args['n_heads'], d_ff=pretrain_args['d_ff'], dropout=dropout)) + + def load_pre_trained_model(self): + #load the pre-trained model + pretrained_model = torch.load(self.pretrained_model_path, map_location='cpu') + self.pretrained_model.load_state_dict(pretrained_model) + + #freeze the encoder and decoder of pre-trained model + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + def init_enc2dec_param(self): + self.position_embedding.load_state_dict(self.pretrained_model.position_embedding.state_dict()) + self.channel_embedding.load_state_dict(self.pretrained_model.channel_embedding.state_dict()) + self.learnable_patch.data.copy_(self.pretrained_model.learnable_patch.data) + self.enc_2_dec.load_state_dict(self.pretrained_model.enc_2_dec.state_dict()) + + + def train_mode(self): + self.train() + self.pretrained_model.eval() + return + + def eval_mode(self): + self.eval() + return + + def immediate_forecast(self, multi_ts, pred_len): + ''' + Immediate reaction mode, directly use the pretrained model to perform multi-variate forecasting without any parameter modification + Args: + multi_ts: [batch_size, ts_d, past_len] + pred_len: [batch_size] + ''' + + batch_size, ts_d, past_len = multi_ts.shape + + # encode past patches + encoded_past_patch = self.pretrained_model.encode_multi_to_patch(multi_ts) + + #prepare masked and unmasked indices + full_len = past_len + pred_len + full_patch_num = full_len // self.patch_size + past_patch_num = past_len // self.patch_size + full_patch_idx = torch.arange(full_patch_num)[None, None, :].expand(batch_size, ts_d, -1).to(multi_ts.device) + past_patch_idx = full_patch_idx[:, :, :past_patch_num] + pred_patch_idx = full_patch_idx[:, :, past_patch_num:] + + reconstructed_full_ts = self.pretrained_model.decode_patch_to_multi(encoded_past_patch, pred_patch_idx, past_patch_idx) + + pred_ts = reconstructed_full_ts[:, :, past_len:] + + return pred_ts + + def forward(self, past_ts, pred_patch_num, neighbor_num = 10): + batch_size, ts_d, _ = past_ts.shape + + encoded_patch_past = self.pretrained_model.encode_multi_to_patch(past_ts) + + #compute the graph structure + graph_adj = graph_construct(encoded_patch_past, k = neighbor_num) + + encoded_patch_past_transformed = self.enc_2_dec(encoded_patch_past) #[batch_size, ts_d, patch_num, d_model] + + #<----------------------------------concatenate past and future patches--------------------------------------------> + _, _, past_patch_num, d_model = encoded_patch_past_transformed.shape + channel_idx = torch.arange(ts_d).to(past_ts.device) + channel_embed = self.channel_embedding(channel_idx) #[ts_d, d_model] + channel_embed_future = channel_embed[None, :, None, :].expand(batch_size, -1, pred_patch_num, -1) #[batch_size, ts_d, pred_patch_num, d_model] + + patch_embed_future = self.learnable_patch[None, None, None, :].expand(batch_size, ts_d, pred_patch_num, -1) + + future_patch_idx = torch.arange(past_patch_num, past_patch_num + pred_patch_num)[None, :].expand(batch_size, -1).to(past_ts.device) + position_embed_future = self.position_embedding([batch_size, pred_patch_num, d_model], future_patch_idx) #[batch_size, pred_patch_num, d_model] + position_embed_future = position_embed_future[:, None, :, 
:].expand(-1, ts_d, -1, -1) #[batch_size, ts_d, pred_patch_num, d_model] + + patches_future = patch_embed_future + position_embed_future + channel_embed_future #[batch_size, ts_d, pred_patch_num, d_model] + patches_past = encoded_patch_past_transformed + patches_full = torch.cat((patches_past, patches_future), dim = -2) #[batch_size, ts_d, past_patch_num+pred_patch_num, d_model] + #<-----------------------------------------------------------------------------------------------------------------> + + #passing TC layers + for layer in self.temporal_channel_layers: + patches_full = layer(patches_full, graph_adj) + + #passing pretrained decoder + flatten_patches_full = rearrange(patches_full, 'batch_size ts_d seq_len d_model -> (batch_size ts_d) seq_len d_model') + flatten_channel_index = channel_idx[None, :].expand(batch_size, -1).reshape(-1) + flatten_full_ts = self.pretrained_model.pretrain_decode(flatten_patches_full, flatten_channel_index) # (batch_size*ts_d, past_len+pred_len) + full_ts = rearrange(flatten_full_ts, '(batch_size ts_d) ts_len -> batch_size ts_d ts_len', batch_size = batch_size) + pred_ts = full_ts[:, :, -pred_patch_num*self.pretrain_args['patch_size']:] + + return pred_ts \ No newline at end of file diff --git a/models/finetune_model/graph_structure.py b/models/finetune_model/graph_structure.py new file mode 100644 index 0000000..1a0db28 --- /dev/null +++ b/models/finetune_model/graph_structure.py @@ -0,0 +1,55 @@ +import torch +import math +import numpy as np + + +def batch_cosine_similarity(x, y, eps=1e-8): + ''' + compute the cosine similarity matrix among variables and get D * D matrix + x, y: [batch_size, ts_d, d_model] + ''' + + inner_dot = torch.einsum('bqd,bkd->bqk', x, y) + x_norm = torch.norm(x, dim=-1, keepdim=True) # [batch_size, ts_d, 1] + y_norm = torch.norm(y, dim=-1, keepdim=True) # [batch_size, ts_d, 1] + norm_dot = torch.einsum('bqd,bkd->bqk', x_norm, y_norm) # [batch_size, ts_d, ts_d] + cos_corr = inner_dot / (norm_dot + eps) + + return cos_corr + +def k_nearest_neighbor(corr_matrix, k=10): + ''' + return an adjacency matrix of the k nearest neighbors + corr_matrix: [batch_size, ts_d, ts_d] + ''' + batch_size, ts_d, _ = corr_matrix.shape + edges_knn = torch.topk(corr_matrix, k, dim=-1)[1] # [batch_size * ts_d, k] + + knn_adj = torch.zeros(batch_size, ts_d, ts_d).to(corr_matrix.device) + knn_adj.scatter_(-1, edges_knn, 1) + knn_adj = knn_adj.permute(0, 2, 1) # source to target + + return knn_adj + + +def graph_construct(encoded_patch, patch_mask=None, k=10): + # encoded_patch: [batch_size, ts_d, patch_num, d_model] + # patch_mask: [batch_size, ts_d, patch_num], 1 for masked, 0 for unmasked + batch_size, ts_d, patch_num, d_model = encoded_patch.shape + + if patch_mask is not None: + encoded_patch = encoded_patch.masked_fill(patch_mask[:, :, :, None] == 1, -np.inf) + + channel_encode = encoded_patch.max(dim=-2)[0] # max pooling over the patch dimension, [batch_size, ts_d, d_model] + corr_matrix = batch_cosine_similarity(channel_encode, channel_encode) # [batch_size, ts_d, ts_d] + + # KNN graph + knn_adj = k_nearest_neighbor(corr_matrix, k) # [batch_size, ts_d, ts_d] i-->j + + # top k*N graph + top_k_threshold = torch.topk(corr_matrix.reshape(batch_size, -1), k * ts_d, dim=-1)[0][:, -1] # [batch_size] + top_k_adj = (corr_matrix >= top_k_threshold[:, None, None]).long() # [batch_size, ts_d, ts_d] i-->j + + graph_adj = knn_adj * top_k_adj # consider both knn and top-k threshold, [batch_size, ts_d, ts_d] i-->j + + return graph_adj diff --git 
a/models/finetune_model/temporal_channel_layer.py b/models/finetune_model/temporal_channel_layer.py new file mode 100644 index 0000000..b3c79fa --- /dev/null +++ b/models/finetune_model/temporal_channel_layer.py @@ -0,0 +1,131 @@ +import torch +import numpy as np +import math +import torch.nn as nn +from torch.nn import TransformerEncoderLayer +from torch_geometric.nn.conv import TransformerConv +from torch_geometric.utils import to_dense_batch +from einops import rearrange + +def batch_to_sparse(batch_node, adj_matrix): + ''' + convert the batched node features and adjacency matrix to sparse tensor for pytorch geometric + batch_node: [batch_size, node_num, d_model] + adj_matrix: [batch_size, node_num, node_num], binary adjacency matrix + ''' + batch_size, node_num, _ = batch_node.shape + offset, row, col = torch.nonzero(adj_matrix > 0).t() # [edge_num] + + row = row + offset * node_num + col = col + offset * node_num + edge_index = torch.stack([row, col], dim=0).long() + list_node = batch_node.reshape(batch_size * node_num, -1) + batch_idx= torch.arange(0, batch_size, device=batch_node.device)[:, None].expand(-1, node_num).reshape(-1) + + return list_node, edge_index, batch_idx + +class Graph_MultiHeadAttention(nn.Module): + def __init__(self, d_model, n_heads, dropout): + super(Graph_MultiHeadAttention, self).__init__() + self.d_model = d_model + self.n_heads = n_heads + self.head_dim = d_model // n_heads + self.dropout = dropout + + self.MHA_layer = TransformerConv(in_channels=d_model, out_channels=self.head_dim, heads=n_heads, concat=True, dropout=dropout) + self.out_projection = nn.Linear(self.n_heads * self.head_dim, d_model) + + def forward(self, node, edge_index): + # node: [|V|, d_model] + # edge_index: [2, |E|] + # batch_idx: [|V|] + + node = self.MHA_layer(node, edge_index) # [|V|, n_heads * head_dim] + + output = self.out_projection(node) # [|V|, d_model] + + return output + +class Graph_TransformerEncoderLayer(nn.Module): + def __init__(self, d_model, n_heads, d_ff, dropout): + super(Graph_TransformerEncoderLayer, self).__init__() + self.d_model = d_model + self.n_heads = n_heads + self.d_ff = d_ff + self.dropout = dropout + + self.attention = Graph_MultiHeadAttention(d_model, n_heads, dropout) + self.feed_forward = nn.Sequential( + nn.Linear(d_model, d_ff), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(d_ff, d_model) + ) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + def forward(self, node, edge_index, norm_first=False): + + if norm_first: + node = node + self._sa_block(self.norm1(node), edge_index) + node = node + self._ff_block(self.norm2(node)) + else: + node = self.norm1(node + self._sa_block(node, edge_index)) + node = self.norm2(node + self._ff_block(node)) + + return node + + def _sa_block(self, node, edge_index): + node = self.attention(node, edge_index) + return self.dropout1(node) + + # feed forward block + def _ff_block(self, node): + node = self.feed_forward(node) + return self.dropout2(node) + +class Temporal_Channel_Layer(nn.Module): + def __init__(self, d_model, n_heads, d_ff, dropout): + super(Temporal_Channel_Layer, self).__init__() + self.d_model = d_model + self.n_heads = n_heads + self.d_ff = d_ff + self.dropout = dropout + + self.temporal_layer_blocks = 1 + self.temporal_layers = nn.ModuleList() + for i in range(self.temporal_layer_blocks): + self.temporal_layers.append(TransformerEncoderLayer(d_model, n_heads, d_ff, dropout, batch_first = 
True))
+        self.channel_layer_blocks = 1
+        self.channel_layers = nn.ModuleList()
+        for i in range(self.channel_layer_blocks):
+            self.channel_layers.append(Graph_TransformerEncoderLayer(d_model, n_heads, d_ff, dropout))
+
+    def forward(self, x, graph_adj):
+        # x: [batch_size, ts_d, seq_len, d_model]
+        # graph_adj: [batch_size, ts_d, ts_d]
+
+        batch_size, ts_d, seq_len, d_model = x.shape
+        temporal_input = x.reshape(batch_size * ts_d, seq_len, d_model)
+        temporal_output = temporal_input
+        for temporal_layer in self.temporal_layers:
+            temporal_output = temporal_layer(temporal_output)  # [batch_size * ts_d, seq_len, d_model]
+
+        channel_input = rearrange(temporal_output, '(batch_size ts_d) seq_len d_model -> (batch_size seq_len) ts_d d_model', batch_size = batch_size)
+        graph_adj_expand = graph_adj[:, None, :, :].expand(-1, seq_len, -1, -1)
+        graph_adj_expand = rearrange(graph_adj_expand, 'batch_size seq_len ts_d1 ts_d2 -> (batch_size seq_len) ts_d1 ts_d2')
+
+        channel_input_sparse, edge_index, batch_idx = batch_to_sparse(channel_input, graph_adj_expand)  #for torch geometric, [batch_size * seq_len * ts_d, d_model], [2, edge_num], [batch_size * seq_len * ts_d]
+        channel_output_sparse = channel_input_sparse
+
+        for channel_layer in self.channel_layers:
+            channel_output_sparse = channel_layer(channel_output_sparse, edge_index)
+        channel_output_batch = to_dense_batch(channel_output_sparse, batch=batch_idx)[0]  # [batch_size * seq_len, ts_d, d_model]
+
+        output = rearrange(channel_output_batch, '(batch_size seq_len) ts_d d_model -> batch_size ts_d seq_len d_model', batch_size = batch_size)
+
+        return output
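+
+# Shape walkthrough of the layer above (hypothetical sizes, comments only):
+# x: [B=2, D=7, L=5, d=256]. Temporal attention runs per channel on the
+# [B*D=14, L=5, d] sequences; channel attention then runs per time step on
+# B*L=10 graphs of D=7 nodes each, with edge_index built from the [B, D, D]
+# adjacency expanded to every one of the L steps. to_dense_batch packs the
+# nodes back to [B*L, D, d] before the final rearrange restores [B, D, L, d].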
diff --git a/models/pretrain_model/UP2ME_model.py b/models/pretrain_model/UP2ME_model.py
new file mode 100644
index 0000000..069ee95
--- /dev/null
+++ b/models/pretrain_model/UP2ME_model.py
@@ -0,0 +1,161 @@
+import torch
+from torch import nn
+from .embed import patch_embedding, learnable_position_embedding, RevIN
+from .encoder_decoder import encoder, decoder
+from einops import rearrange, repeat
+from torch.nn.utils.rnn import pad_sequence
+from loguru import logger
+
+class UP2ME_model(nn.Module):
+    def __init__(self, data_dim, patch_size,
+                 d_model=256, d_ff=512, n_heads=4, e_layers=3, d_layers=1, dropout=0.0,
+                 mask_ratio=0.75, device=torch.device('cuda:0')):
+        super(UP2ME_model, self).__init__()
+
+        self.data_dim = data_dim
+        self.patch_size = patch_size
+        self.d_model = d_model
+        self.d_ff = d_ff
+        self.n_heads = n_heads
+        self.e_layers = e_layers
+        self.d_layers = d_layers
+        self.dropout = dropout
+        self.device = device
+
+        self.RevIN = RevIN(data_dim, affine=True)
+        self.patch_embedding = patch_embedding(patch_size, d_model)
+        self.position_embedding = learnable_position_embedding(d_model)
+        self.channel_embedding = nn.Embedding(data_dim, d_model)
+
+        #encoder
+        self.encoder = encoder(e_layers, d_model, n_heads, d_ff, dropout)
+
+        #encoder-space to decoder-space
+        self.enc_2_dec = nn.Linear(d_model, d_model)
+        self.learnable_patch = nn.Parameter(torch.randn(d_model))
+
+        #decoder
+        self.decoder = decoder(patch_size, d_layers, d_model, n_heads, d_ff, dropout)
+
+    def encode_uni_to_patch(self, ts, channel_idx, masked_patch_index=None, unmasked_patch_index=None, imputation_point_mask=None):
+        '''
+        Encode the unmasked patches of a univariate time series to latent patches.
+
+        Args:
+            ts: time series with shape [batch_size, ts_length]
+            channel_idx: channel index of the time series with shape [batch_size]
+            masked_patch_index: masked patch index with shape [batch_size, masked_patch_num]
+            unmasked_patch_index: unmasked patch index with shape [batch_size, unmasked_patch_num]
+            imputation_point_mask: point mask with shape [batch_size, ts_length], for the imputation task
+        '''
+
+        ts = self.RevIN(ts, channel_idx, mode='norm', mask=imputation_point_mask)
+
+        patch_embed = self.patch_embedding.forward_uni(ts)  # [batch_size, patch_num, d_model]
+        position_embed = self.position_embedding(patch_embed.shape)
+        channel_embed = self.channel_embedding(channel_idx)  # [batch_size, d_model]
+        patches = patch_embed + position_embed + channel_embed[:, None, :]  # [batch_size, patch_num, d_model]
+
+        imputation_patch_mask = None
+        if masked_patch_index is not None and unmasked_patch_index is not None:  # only encode the unmasked patches
+            encoder_input = patches.gather(1, unmasked_patch_index[:, :, None].expand(-1, -1, self.d_model))
+        elif imputation_point_mask is not None:
+            imputation_patch_mask = rearrange(imputation_point_mask, 'batch_size (patch_num patch_size) -> batch_size patch_num patch_size', patch_size=self.patch_size)
+            imputation_patch_mask = imputation_patch_mask.sum(dim=-1) > 0  # [batch_size, patch_num]
+            encoder_input = patches
+        else:
+            encoder_input = patches
+        encoded_patch_unmasked = self.encoder.forward_uni(encoder_input, imputation_patch_mask)  # [batch_size, unmasked_patch_num, d_model]
+
+        return encoded_patch_unmasked
+
+    def patch_concatenate(self, encoded_patch_unmasked, channel_idx, masked_patch_index, unmasked_patch_index):
+        '''
+        concatenate encoded unmasked patches and tokens indicating masked patches, i.e. the first line in Equation (4) except enc-to-dec
+
+        Args:
+            encoded_patch_unmasked: encoded unmasked patches with shape [batch_size, unmasked_patch_num, d_model]
+            channel_idx: channel index of the time series with shape [batch_size]
+            masked_patch_index: masked patch index with shape [batch_size, masked_patch_num]
+            unmasked_patch_index: unmasked patch index with shape [batch_size, unmasked_patch_num]
+        '''
+        batch_size, unmasked_patch_num, _ = encoded_patch_unmasked.shape
+        masked_patch_num = masked_patch_index.shape[1]
+
+        patch_embed_masked = self.learnable_patch[None, None, :].expand(batch_size, masked_patch_num, -1)
+        position_embed_masked = self.position_embedding(patch_embed_masked.shape, masked_patch_index)
+        channel_embed_masked = self.channel_embedding(channel_idx)
+        patches_masked = patch_embed_masked + position_embed_masked + channel_embed_masked[:, None, :]
+
+        patches_full = torch.cat([patches_masked, encoded_patch_unmasked], dim=1)  #concatenate masked & unmasked patches
+        patch_index_full = torch.cat([masked_patch_index, unmasked_patch_index], dim=1)
+        origin_patch_index = torch.argsort(patch_index_full, dim=1)
+        origin_patch_index = origin_patch_index.to(encoded_patch_unmasked.device)
+        patches_full_sorted = patches_full.gather(1, origin_patch_index[:, :, None].expand(-1, -1, self.d_model))  #restore the original patch order
+
+        return patches_full_sorted
+
+    def pretrain_decode(self, full_patches, channel_idx):
+        '''
+        Decoding process: pass the decoder and perform the final projection
+
+        Args:
+            full_patches: masked & unmasked patches [batch_size, total_patch_num, d_model]
+            channel_idx: channel index of the time series with shape [batch_size]
+        '''
+
+        reconstructed_ts = self.decoder.forward_uni(full_patches)
+        reconstructed_ts =
self.RevIN(reconstructed_ts, channel_idx, mode='denorm') + + return reconstructed_ts + + #some functions for downstream tasks + def encode_multi_to_patch(self, multi_ts, masked_patch_index=None, unmasked_patch_index=None, imputation_point_mask=None): + ''' + Encode the unmaksed patches of multivariate time series to latent patches. + + Args: + multi_ts: time series with shape [batch_size, ts_d, ts_length] + masked_patch_index: masked patch index with shape [batch_size, ts_d, masked_patch_num] + unmasked_patch_index: unmasked patch index with shape [batch_size, ts_d, unmasked_patch_num] + point_mask: point mask with shape [batch_size, ts_d, ts_length], for imputation task. + ''' + + batch_size, ts_d, ts_length = multi_ts.shape + + ts_flatten = rearrange(multi_ts, 'batch_size ts_d ts_length -> (batch_size ts_d) ts_length') + channel_idx = torch.arange(self.data_dim)[None, :].expand(batch_size, -1).to(multi_ts.device) + channel_flatten = rearrange(channel_idx, 'batch_size ts_d -> (batch_size ts_d)') + + if masked_patch_index is not None and unmasked_patch_index is not None: # only encode the unmasked patches + masked_patch_flatten = rearrange(masked_patch_index, 'batch_size ts_d masked_patch_num -> (batch_size ts_d) masked_patch_num') + unmasked_patch_flatten = rearrange(unmasked_patch_index, 'batch_size ts_d unmasked_patch_num -> (batch_size ts_d) unmasked_patch_num') + else: + masked_patch_flatten, unmasked_patch_flatten = None, None + + if imputation_point_mask is not None: + imputation_point_mask_flatten = rearrange(imputation_point_mask, 'batch_size ts_d ts_length -> (batch_size ts_d) ts_length') + else: + imputation_point_mask_flatten = None + + encoded_patch_flatten = self.encode_uni_to_patch(ts_flatten, channel_flatten, masked_patch_flatten, unmasked_patch_flatten, imputation_point_mask_flatten) + encoded_patch = rearrange(encoded_patch_flatten, '(batch_size ts_d) patch_num d_model -> batch_size ts_d patch_num d_model', batch_size=batch_size) + + return encoded_patch + + def decode_patch_to_multi(self, encoded_patch_unmasked, masked_patch_index, unmasked_patch_index): + batch_size, ts_d, unmasked_patch_num, _ = encoded_patch_unmasked.shape + + flatten_encoded_patch_unmasked = rearrange(encoded_patch_unmasked, 'batch_size ts_d unmasked_patch_num d_model -> (batch_size ts_d) unmasked_patch_num d_model') + flatten_masked_patch_index = rearrange(masked_patch_index, 'batch_size ts_d masked_patch_num -> (batch_size ts_d) masked_patch_num') + flatten_unmasked_patch_index = rearrange(unmasked_patch_index, 'batch_size ts_d unmasked_patch_num -> (batch_size ts_d) unmasked_patch_num') + flatten_channel_idx = torch.arange(self.data_dim)[None, :].expand(batch_size, -1).reshape(batch_size * ts_d).to(encoded_patch_unmasked.device) + + flatten_encoded_patch_unmasked = self.enc_2_dec(flatten_encoded_patch_unmasked) + flatten_full_patch = self.patch_concatenate(flatten_encoded_patch_unmasked, flatten_channel_idx, flatten_masked_patch_index, flatten_unmasked_patch_index) + flatten_reconstructed_ts = self.pretrain_decode(flatten_full_patch, flatten_channel_idx) + + reconstructed_ts = rearrange(flatten_reconstructed_ts, '(batch_size ts_d) ts_len -> batch_size ts_d ts_len', batch_size=batch_size) + + return reconstructed_ts + diff --git a/models/pretrain_model/embed.py b/models/pretrain_model/embed.py new file mode 100644 index 0000000..4de76c8 --- /dev/null +++ b/models/pretrain_model/embed.py @@ -0,0 +1,144 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import 
rearrange, repeat + +class patch_embedding(nn.Module): + def __init__(self, patch_size, d_model): + super(patch_embedding, self).__init__() + self.patch_size = patch_size + self.linear = nn.Linear(patch_size, d_model, bias = False) + + def forward_uni(self, x): + #the univariate forward mode + + batch_size, ts_len = x.shape + + x_patch = rearrange(x, 'b (patch_num patch_size) -> b patch_num patch_size', patch_size = self.patch_size) + x_embed = self.linear(x_patch) #[batch_size, patch_num, d_model] + + return x_embed + + def forward_multi(self, x): + #the multivariate forward mode + + batch_size, ts_len, ts_dim = x.shape + + x_patch = rearrange(x, 'b (patch_num patch_size) d -> b d patch_num patch_size', patch_size = self.patch_size) + x_embed = self.linear(x_patch) #[batch_size, ts_dim, patch_num, d_model] + + return x_embed + +class learnable_position_embedding(nn.Module): + def __init__(self, d_model, max_len: int = 1000): + super().__init__() + self.d_model = d_model + self.position_embedding = nn.Parameter(torch.randn(max_len, d_model), requires_grad=True) + + def forward(self, patch_shape, index=None): + """Positional encoding + + Args: + patch_shape: shape of time series, should be [batch_size, patch_num, d_model] or [batch_size, ts_d, patch_num, d_model] + index (list or None): add positional embedding by index, [batch_size, patch_num] or [batch_size, ts_d, patch_num] + + Returns: + torch.tensor: output sequence + """ + if (len(patch_shape) == 3): + #for univariate time series + batch_size, patch_num, _ = patch_shape + position_embedding_expand = self.position_embedding[None, :, :].expand(batch_size, -1, -1) + if index is None: + pe = position_embedding_expand[:, :patch_num, :] #not assigned, 0 ~ patch_num - 1 + else: + index_expand = index[:, :, None].expand(-1, -1, self.d_model) + pe = position_embedding_expand.gather(1, index_expand) + return pe #[batch_size, patch_num, d_model] + + elif (len(patch_shape) == 4): + #for multivariate time series + batch_size, ts_d, patch_num, _ = patch_shape + position_embedding_expand = self.position_embedding[None, None, :, :].expand(batch_size, ts_d, -1, -1) + if index is None: + pe = position_embedding_expand[:, :, :patch_num, :] + else: + index_expand = index[:, :, :, None].expand(-1, -1, -1, self.d_model) + pe = position_embedding_expand.gather(2, index_expand) + return pe #[batch_size, ts_d, patch_num, d_model] + + +class RevIN(nn.Module): + def __init__(self, num_features: int, eps=1e-7, affine=True): + """ + :param num_features: the number of features or channels + :param eps: a value added for numerical stability + :param affine: if True, RevIN has learnable affine parameters + """ + super(RevIN, self).__init__() + self.num_features = num_features + self.eps = eps + self.affine = affine + if self.affine: + self._init_params() + + def forward(self, x, channel_idx, mode: str, mask=None): + if mode == 'norm': + self._get_statistics(x, mask) + x = self._normalize(x, channel_idx) + elif mode == 'denorm': + x = self._denormalize(x, channel_idx) + else: + raise NotImplementedError + return x + + def _init_params(self): + # initialize RevIN params: (C,) + self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) + self.affine_weight = nn.Parameter(torch.ones(self.num_features)) + + def _get_statistics(self, x, mask): + ''' + x: time series with shape [batch_size, ts_length] + channel_idx: channel index of time series with shape [batch_size] + mask: binary mask with the same shape as x, 1 for non-sense values to be masked + ''' + if mask is 
not None: + masked_x = x.masked_fill(mask == 1, 0) + self.mean = (torch.sum(masked_x, dim=-1, keepdim=True) / torch.sum(mask == 0, dim=-1, keepdim=True)).detach() + diff = masked_x - self.mean + masked_diff = diff.masked_fill(mask == 1, 0) + self.stdev = (torch.sqrt(torch.sum(masked_diff * masked_diff, dim=-1, keepdim=True) / torch.sum(mask == 0, dim=-1, keepdim=True) + self.eps)).detach() + else: + self.mean = torch.mean(x, dim=-1, keepdim=True).detach() # [batch_size, 1] + self.stdev = torch.sqrt(torch.var(x, dim=-1, keepdim=True, unbiased=False) + self.eps).detach() # [batch_size, 1] + + def _normalize(self, x, channel_idx): + ''' + channel_idx: [batch_size] + x: [batch_size, ts_length] + ''' + x = x - self.mean + x = x / self.stdev + + if self.affine: + affine_weight = self.affine_weight.gather(0, channel_idx) # [batch_size] + affine_bias = self.affine_bias.gather(0, channel_idx) # [batch_size] + + x = x * affine_weight[:, None] + x = x + affine_bias[:, None] + + return x + + def _denormalize(self, x, channel_idx): + if self.affine: + affine_weight = self.affine_weight.gather(0, channel_idx) # [batch_size] + affine_bias = self.affine_bias.gather(0, channel_idx) # [batch_size] + + x = x - affine_bias[:, None] + x = x / (affine_weight[:, None] + self.eps * self.eps) # eps**2 guards against division by a near-zero affine weight + + x = x * self.stdev + x = x + self.mean + + return x \ No newline at end of file diff --git a/models/pretrain_model/encoder_decoder.py b/models/pretrain_model/encoder_decoder.py new file mode 100644 index 0000000..ea627ed --- /dev/null +++ b/models/pretrain_model/encoder_decoder.py @@ -0,0 +1,48 @@ +from torch import nn +from torch.nn import TransformerEncoder, TransformerEncoderLayer +from einops import rearrange, repeat + +class encoder(nn.Module): + def __init__(self, n_layers=3, d_model=256, n_heads=4, d_ff=512, dropout=0.): + super(encoder, self).__init__() + # the encoder does not handle patching itself; it just encodes the given patches + self.n_layers = n_layers + self.d_model = d_model + self.n_heads = n_heads + self.d_ff = d_ff + self.dropout = dropout + + encoder_layer = TransformerEncoderLayer(d_model, n_heads, d_ff, dropout, batch_first = True) + self.encoder_layers = TransformerEncoder(encoder_layer, n_layers) + + def forward_uni(self, patch, mask = None): + batch_size, patch_num, _ = patch.shape + + encoded_patch = self.encoder_layers(patch, src_key_padding_mask = mask) + + return encoded_patch + +class decoder(nn.Module): + def __init__(self, patch_size, n_layers=1, d_model=256, n_heads=4, d_ff=512, dropout=0.): + super(decoder, self).__init__() + # the decoder takes encoded tokens plus indicator tokens as input and projects them back to the original space + self.patch_size = patch_size + self.n_layers = n_layers + self.d_model = d_model + self.n_heads = n_heads + self.d_ff = d_ff + self.dropout = dropout + + decoder_layer = TransformerEncoderLayer(d_model, n_heads, d_ff, dropout, batch_first = True) + self.decoder_layers = TransformerEncoder(decoder_layer, n_layers) + + self.output_layer = nn.Linear(d_model, patch_size) + + def forward_uni(self, patch): + batch_size, patch_num, _ = patch.shape + + decoded_patch = self.decoder_layers(patch) + decoded_ts = self.output_layer(decoded_patch) + decoded_ts = rearrange(decoded_ts, 'b patch_num patch_size -> b (patch_num patch_size)') + + return decoded_ts \ No newline at end of file diff --git a/run_detect.py b/run_detect.py new file mode 100644 index 0000000..d2cfa4b --- /dev/null +++ b/run_detect.py @@ -0,0 +1,74 @@ +import argparse +import os +import torch +import json +from 
exp.exp_detect import UP2ME_exp_detect + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='UP2ME for anomaly detection') + + parser.add_argument('--is_training', type=int, default=1, help='status') + parser.add_argument('--IR_mode', action='store_true', help='whether to use immediate reaction mode', default=False) + + parser.add_argument('--root_path', type=str, default='./datasets/NIPS_Water/', help='root path of the data file') + parser.add_argument('--data_name', type=str, default='NIPS_Water', help='data name') + parser.add_argument('--valid_prop', type=float, default=0.2, help='proportion of the train set held out for validation') + parser.add_argument('--checkpoints', type=str, default='./checkpoints_detect/', help='location to store model checkpoints') + + parser.add_argument('--pretrained_model_path', type=str, default='pretrain-library/U2MNIPS_Water-Base_dataNIPS_Water_dim9_patch10_minPatch5_maxPatch100\ + _mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-270000.pth', help='location of the pretrained model') + parser.add_argument('--pretrain_args_path', type=str, default='pretrain-library/U2MNIPS_Water-Base_dataNIPS_Water_dim9_patch10_minPatch5_maxPatch100\ + _mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json', help='location of the pretrained model parameters') + + parser.add_argument('--seg_len', type=int, default=100, help='the non-overlapping segment length') + parser.add_argument('--finetune_layers', type=int, default=1, help='number of task layers to finetune') + parser.add_argument('--dropout', type=float, default=0.2, help='dropout ratio for finetune layers') + parser.add_argument('--neighbor_num', type=int, default=10, help='number of neighbors for graph (for high dimensional data)') + + parser.add_argument('--anomaly_ratio', type=float, default=1, help='anomaly ratio in the dataset (in %)') + + parser.add_argument('--num_workers', type=int, default=0, help='data loader num workers') + parser.add_argument('--batch_size', type=int, default=128, help='batch size of train input data') + parser.add_argument('--train_epochs', type=int, default=20, help='train epochs') + parser.add_argument('--learning_rate', type=float, default=1e-4, help='optimizer initial learning rate') + parser.add_argument('--lradj', type=str, default='none', help='adjust learning rate') + parser.add_argument('--itr', type=int, default=1, help='number of experiment repetitions') + parser.add_argument('--slide_step', type=int, default=10, help='sliding steps for the sliding window of train and valid') + parser.add_argument('--tolerance', type=int, default=3, help='tolerance for early stopping') + + parser.add_argument('--save_folder', type=str, default='./detect_result/', help='folder path to save the detection results') + parser.add_argument('--save_pred', action='store_true', help='whether to save the reconstructed MTS', default=False) + + parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu') + parser.add_argument('--gpu', type=int, default=0, help='gpu') + parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False) + parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multiple gpus') + + parser.add_argument('--label', type=str, default='ft', help='labels to attach to setting') + + args = parser.parse_args() + + args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False + + if args.use_gpu and args.use_multi_gpu: + args.devices = args.devices.replace(' ','') + device_ids = 
args.devices.split(',') + args.device_ids = [int(id_) for id_ in device_ids] + args.gpu = args.device_ids[0] + + args.pretrain_args = json.load(open(args.pretrain_args_path)) + + print('Args in experiment:') + print(args) + + for i in range(args.itr): + setting = 'UP2ME_detect_{}_data{}_seglen{}_IRmode{}_ftlayers{}_neighbor{}_itr{}'.format(args.label, args.data_name, args.seg_len, \ + args.IR_mode, args.finetune_layers, args.neighbor_num, i) + exp = UP2ME_exp_detect(args) + + if args.is_training: + print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting)) + exp.train(setting) + + print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting)) + exp.test(setting, args.save_pred) \ No newline at end of file diff --git a/run_forecast.py b/run_forecast.py new file mode 100644 index 0000000..725233b --- /dev/null +++ b/run_forecast.py @@ -0,0 +1,88 @@ +import argparse +import os +import torch +import json +import random +import numpy as np + +from exp.exp_forecast import UP2ME_exp_forecast +from utils.tools import string_split + +parser = argparse.ArgumentParser(description='UP2ME for forecasting') + +parser.add_argument('--is_training', type=int, default=1, help='status') +parser.add_argument('--IR_mode', action='store_true', help='whether to use immediate reaction mode', default=False) + +parser.add_argument('--data_format', type=str, default='csv', help='data format') +parser.add_argument('--data_name', type=str, default='SMD', help='data name') +parser.add_argument('--root_path', type=str, default='./datasets/', help='root path of the data file') +parser.add_argument('--data_path', type=str, default='ETTh1.csv', help='data file') +parser.add_argument('--valid_prop', type=float, default=0.2, help='proportion of validation set, for numpy data only') +parser.add_argument('--data_split', type=str, default='0.7,0.1,0.2', help='train/val/test split, can be ratio or number') +parser.add_argument('--checkpoints', type=str, default='./checkpoints_forecast/', help='location to store model checkpoints') + +parser.add_argument('--pretrained_model_path', type=str, default='./checkpoints/U2M_ETTm1.csv_dim7_patch12_minPatch20_maxPatch200\ + _mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/epoch-80000.pth', help='location of the pretrained model') +parser.add_argument('--pretrain_args_path', type=str, default='./checkpoints/U2M_ETTm1.csv_dim7_patch12_minPatch20_maxPatch200\ + _mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json', help='location of the pretrained model parameters') + +parser.add_argument('--in_len', type=int, default=720, help='input MTS length (T)') +parser.add_argument('--out_len', type=int, default=96, help='output MTS length (tau)') + +parser.add_argument('--finetune_layers', type=int, default=1, help='forecast layers to finetune') +parser.add_argument('--dropout', type=float, default=0.0, help='dropout ratio for finetune layers') +parser.add_argument('--neighbor_num', type=int, default=10, help='number of neighbors for graph (for high dimensional data)') + +parser.add_argument('--num_workers', type=int, default=0, help='data loader num workers') +parser.add_argument('--batch_size', type=int, default=128, help='batch size of train input data') +parser.add_argument('--train_epochs', type=int, default=20, help='train epochs') +parser.add_argument('--learning_rate', type=float, default=1e-4, help='optimizer initial learning rate') +parser.add_argument('--lradj', type=str, default='none', help='adjust learning rate') 
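+# 'type1' and 'type2' lradj schedules are defined in utils/tools.py (adjust_learning_rate); 'none' keeps the learning rate fixed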
+parser.add_argument('--tolerance', type=int, default=3, help='early stopping tolerance') +parser.add_argument('--itr', type=int, default=1, help='number of experiment repetitions') +parser.add_argument('--slide_step', type=int, default=10, help='sliding steps for the sliding window of train and valid') + +parser.add_argument('--save_pred', action='store_true', help='whether to save the predicted future MTS', default=False) + +parser.add_argument('--use_gpu', type=bool, default=True, help='use gpu') +parser.add_argument('--gpu', type=int, default=0, help='gpu') +parser.add_argument('--use_multi_gpu', action='store_true', help='use multiple gpus', default=False) +parser.add_argument('--devices', type=str, default='0,1,2,3', help='device ids of multiple gpus') + +parser.add_argument('--label', type=str, default='ft', help='labels to attach to setting') + +args = parser.parse_args() + +args.data_split = string_split(args.data_split) +args.use_gpu = True if torch.cuda.is_available() and args.use_gpu else False + +if args.use_gpu and args.use_multi_gpu: + args.devices = args.devices.replace(' ', '') + device_ids = args.devices.split(',') + args.device_ids = [int(id_) for id_ in device_ids] + args.gpu = args.device_ids[0] + +args.pretrain_args = json.load(open(args.pretrain_args_path)) + +# fix random seed +torch.manual_seed(2023) +random.seed(2023) +np.random.seed(2023) +torch.cuda.manual_seed_all(2023) + +print('Args in experiment:') +print(args) + +for i in range(args.itr): + setting = 'U2M_forecast_data{}_dim{}_patch{}_dm{}_dff{}_heads{}_eLayer{}_dLayer{}_IRmode{}_ftLayer{}_neighbor{}_inlen{}_outlen{}_itr{}'.format(args.data_name, + args.pretrain_args['data_dim'], args.pretrain_args['patch_size'], args.pretrain_args['d_model'], args.pretrain_args['d_ff'], + args.pretrain_args['n_heads'], args.pretrain_args['e_layers'], args.pretrain_args['d_layers'], args.IR_mode, args.finetune_layers, args.neighbor_num, + args.in_len, args.out_len, i) + exp = UP2ME_exp_forecast(args) + + if args.is_training: + print('>>>>>>>start training : {}>>>>>>>>>>>>>>>>>>>>>>>>>>'.format(setting)) + exp.train(setting) + + print('>>>>>>>testing : {}<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<'.format(setting)) + exp.test(setting, args.save_pred) \ No newline at end of file diff --git a/scripts/detect_scripts/NIPS_Water.sh b/scripts/detect_scripts/NIPS_Water.sh new file mode 100644 index 0000000..ab87c60 --- /dev/null +++ b/scripts/detect_scripts/NIPS_Water.sh @@ -0,0 +1,13 @@ +python run_detect.py --root_path ./datasets/NIPS_Water --data_name NIPS_Water \ +--pretrained_model_path ./pretrain-library/U2MNIPS_Water-Base_dataNIPS_Water_dim9_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-500000.pth \ +--pretrain_args_path ./pretrain-library/U2MNIPS_Water-Base_dataNIPS_Water_dim9_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ +--seg_len 100 --anomaly_ratio 2 \ +--batch_size 64 --gpu 0 \ +--is_training 0 --IR_mode + +python run_detect.py --root_path ./datasets/NIPS_Water --data_name NIPS_Water \ +--pretrained_model_path ./pretrain-library/U2MNIPS_Water-Base_dataNIPS_Water_dim9_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-500000.pth \ +--pretrain_args_path ./pretrain-library/U2MNIPS_Water-Base_dataNIPS_Water_dim9_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ +--seg_len 100 --anomaly_ratio 2 \ +--finetune_layers 1 --dropout 0.2 --neighbor_num 5 
--slide_step 1 --learning_rate 1e-5 --train_epochs 10 --tolerance 3 --batch_size 64 --gpu 0 \ +--is_training 1 \ No newline at end of file diff --git a/scripts/detect_scripts/PSM.sh b/scripts/detect_scripts/PSM.sh new file mode 100644 index 0000000..6de7a12 --- /dev/null +++ b/scripts/detect_scripts/PSM.sh @@ -0,0 +1,13 @@ +python run_detect.py --root_path ./datasets/PSM/ --data_name PSM \ +--pretrained_model_path ./pretrain-library/U2MPSM-Base_dataPSM_dim25_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-490000.pth \ +--pretrain_args_path ./pretrain-library/U2MPSM-Base_dataPSM_dim25_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ +--seg_len 100 --anomaly_ratio 1 \ +--batch_size 32 --gpu 0 \ +--is_training 0 --IR_mode + +python run_detect.py --root_path ./datasets/PSM/ --data_name PSM \ +--pretrained_model_path ./pretrain-library/U2MPSM-Base_dataPSM_dim25_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-490000.pth \ +--pretrain_args_path ./pretrain-library/U2MPSM-Base_dataPSM_dim25_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ +--seg_len 100 --anomaly_ratio 1 \ +--finetune_layers 1 --dropout 0.2 --neighbor_num 10 --slide_step 10 --learning_rate 1e-5 --train_epochs 10 --tolerance 3 --batch_size 32 --gpu 0 \ +--is_training 1 \ No newline at end of file diff --git a/scripts/detect_scripts/SMD.sh b/scripts/detect_scripts/SMD.sh new file mode 100644 index 0000000..ee3c630 --- /dev/null +++ b/scripts/detect_scripts/SMD.sh @@ -0,0 +1,13 @@ +python run_detect.py --root_path ./datasets/SMD/ --data_name SMD \ +--pretrained_model_path ./pretrain-library/U2MSMD-Base_dataSMD_dim38_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-140000.pth \ +--pretrain_args_path ./pretrain-library/U2MSMD-Base_dataSMD_dim38_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ +--seg_len 100 --anomaly_ratio 0.5 \ +--batch_size 16 --gpu 0 \ +--is_training 0 --IR_mode + +python run_detect.py --root_path ./datasets/SMD/ --data_name SMD \ +--pretrained_model_path ./pretrain-library/U2MSMD-Base_dataSMD_dim38_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-140000.pth \ +--pretrain_args_path ./pretrain-library/U2MSMD-Base_dataSMD_dim38_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ +--seg_len 100 --anomaly_ratio 0.5 \ +--finetune_layers 1 --dropout 0.2 --neighbor_num 10 --slide_step 10 --learning_rate 1e-5 --train_epochs 10 --tolerance 3 --batch_size 64 --gpu 0 \ +--is_training 1 \ No newline at end of file diff --git a/scripts/detect_scripts/SWaT.sh b/scripts/detect_scripts/SWaT.sh new file mode 100644 index 0000000..463d1d2 --- /dev/null +++ b/scripts/detect_scripts/SWaT.sh @@ -0,0 +1,13 @@ +python run_detect.py --root_path ./datasets/SWaT/ --data_name SWaT \ +--pretrained_model_path ./pretrain-library/U2MSWaT-Base_dataSWaT_dim51_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-500000.pth \ +--pretrain_args_path ./pretrain-library/U2MSWaT-Base_dataSWaT_dim51_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ +--seg_len 100 --anomaly_ratio 1 \ +--batch_size 32 --gpu 0 \ +--is_training 0 --IR_mode + +python run_detect.py --root_path 
./datasets/SWaT/ --data_name SWaT \ +--pretrained_model_path ./pretrain-library/U2MSWaT-Base_dataSWaT_dim51_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-500000.pth \ +--pretrain_args_path ./pretrain-library/U2MSWaT-Base_dataSWaT_dim51_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ +--seg_len 100 --anomaly_ratio 1 \ +--finetune_layers 1 --dropout 0.2 --neighbor_num 10 --slide_step 10 --learning_rate 1e-5 --train_epochs 10 --tolerance 3 --batch_size 32 --gpu 0 \ +--is_training 1 \ No newline at end of file diff --git a/scripts/forecast_scripts/ECL.sh b/scripts/forecast_scripts/ECL.sh new file mode 100644 index 0000000..6907b07 --- /dev/null +++ b/scripts/forecast_scripts/ECL.sh @@ -0,0 +1,19 @@ +for out_len in 96 192 336 720 +do + #immediate reaction forecast + python run_forecast.py --data_format csv --data_name ECL --root_path ./datasets/ECL/ --data_path ECL.csv \ + --data_split 0.7,0.1,0.2 --checkpoints ./forecast_checkpoints/ \ + --pretrained_model_path pretrain-library/U2MECL-Base_dataElectricity_dim321_patch12_minPatch20_maxPatch200_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-475000.pth \ + --pretrain_args_path pretrain-library/U2MECL-Base_dataElectricity_dim321_patch12_minPatch20_maxPatch200_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ + --in_len 336 --out_len $out_len \ + --batch_size 16 --gpu 0 --is_training 0 --IR_mode + + #finetune + python run_forecast.py --data_format csv --data_name ECL --root_path ./datasets/ECL/ --data_path ECL.csv \ + --data_split 0.7,0.1,0.2 --checkpoints ./forecast_checkpoints/ \ + --pretrained_model_path pretrain-library/U2MECL-Base_dataElectricity_dim321_patch12_minPatch20_maxPatch200_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-475000.pth \ + --pretrain_args_path pretrain-library/U2MECL-Base_dataElectricity_dim321_patch12_minPatch20_maxPatch200_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ + --in_len 336 --out_len $out_len \ + --finetune_layers 1 --neighbor_num 10 --dropout 0.2 --learning_rate 1e-5 --tolerance 3 \ + --batch_size 16 --use_multi_gpu --devices 0,1 --is_training 1 +done \ No newline at end of file diff --git a/scripts/forecast_scripts/ETTm1.sh b/scripts/forecast_scripts/ETTm1.sh new file mode 100644 index 0000000..f7f8540 --- /dev/null +++ b/scripts/forecast_scripts/ETTm1.sh @@ -0,0 +1,19 @@ +for out_len in 96 192 336 720 +do + #immediate reaction forecast + python run_forecast.py --data_format csv --data_name ETTm1 --root_path ./datasets/ETT/ --data_path ETTm1.csv \ + --data_split 34560,11520,11520 --checkpoints ./forecast_checkpoints/ \ + --pretrained_model_path pretrain-library/U2METTm1-Base_dataETTm1_dim7_patch12_minPatch20_maxPatch200_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-35000.pth \ + --pretrain_args_path pretrain-library/U2METTm1-Base_dataETTm1_dim7_patch12_minPatch20_maxPatch200_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ + --in_len 336 --out_len $out_len \ + --gpu 0 --is_training 0 --IR_mode + + #finetune + python run_forecast.py --data_format csv --data_name ETTm1 --root_path ./datasets/ETT/ --data_path ETTm1.csv \ + --data_split 34560,11520,11520 --checkpoints ./forecast_checkpoints/ \ + --pretrained_model_path pretrain-library/U2METTm1-Base_dataETTm1_dim7_patch12_minPatch20_maxPatch200_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-35000.pth \ + 
--pretrain_args_path pretrain-library/U2METTm1-Base_dataETTm1_dim7_patch12_minPatch20_maxPatch200_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ + --in_len 336 --out_len $out_len \ + --finetune_layers 1 --neighbor_num 4 --dropout 0.2 --learning_rate 1e-5 --tolerance 3 \ + --gpu 0 --is_training 1 +done \ No newline at end of file diff --git a/scripts/forecast_scripts/NIPS_Water.sh b/scripts/forecast_scripts/NIPS_Water.sh new file mode 100644 index 0000000..0be0489 --- /dev/null +++ b/scripts/forecast_scripts/NIPS_Water.sh @@ -0,0 +1,17 @@ +for out_len in 50 100 150 200 +do + #immediate forecast + python run_forecast.py --data_format npy --data_name NIPS_Water --root_path ./datasets/NIPS_Water/ --valid_prop 0.2 \ + --pretrained_model_path pretrain-library/U2MNIPS_Water-Base_dataNIPS_Water_dim9_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-500000.pth \ + --pretrain_args_path pretrain-library/U2MNIPS_Water-Base_dataNIPS_Water_dim9_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ + --in_len 400 --out_len $out_len \ + --gpu 0 --is_training 0 --IR_mode + + #finetune + python run_forecast.py --data_format npy --data_name NIPS_Water --root_path ./datasets/NIPS_Water/ --valid_prop 0.2 \ + --pretrained_model_path pretrain-library/U2MNIPS_Water-Base_dataNIPS_Water_dim9_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-500000.pth \ + --pretrain_args_path pretrain-library/U2MNIPS_Water-Base_dataNIPS_Water_dim9_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ + --in_len 400 --out_len $out_len \ + --finetune_layers 1 --neighbor_num 5 --dropout 0.2 --learning_rate 1e-5 --tolerance 3 --slide_step 1 \ + --gpu 0 --is_training 1 +done \ No newline at end of file diff --git a/scripts/forecast_scripts/PSM.sh b/scripts/forecast_scripts/PSM.sh new file mode 100644 index 0000000..457faa3 --- /dev/null +++ b/scripts/forecast_scripts/PSM.sh @@ -0,0 +1,17 @@ +for out_len in 50 100 150 200 +do + #immediate forecast + python run_forecast.py --data_format npy --data_name PSM --root_path ./datasets/PSM/ --valid_prop 0.2 \ + --pretrained_model_path pretrain-library/U2MPSM-Base_dataPSM_dim25_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-490000.pth \ + --pretrain_args_path pretrain-library/U2MPSM-Base_dataPSM_dim25_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ + --in_len 400 --out_len $out_len \ + --batch_size 64 --gpu 0 --is_training 0 --IR_mode + + #finetune + python run_forecast.py --data_format npy --data_name PSM --root_path ./datasets/PSM/ --valid_prop 0.2 \ + --pretrained_model_path pretrain-library/U2MPSM-Base_dataPSM_dim25_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-490000.pth \ + --pretrain_args_path pretrain-library/U2MPSM-Base_dataPSM_dim25_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ + --in_len 400 --out_len $out_len \ + --finetune_layers 1 --neighbor_num 10 --dropout 0.2 --learning_rate 1e-5 --tolerance 3 --slide_step 10 \ + --batch_size 64 --gpu 0 --is_training 1 +done \ No newline at end of file diff --git a/scripts/forecast_scripts/SMD.sh b/scripts/forecast_scripts/SMD.sh new file mode 100644 index 0000000..884e2fb --- /dev/null +++ b/scripts/forecast_scripts/SMD.sh @@ -0,0 +1,17 @@ 
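+# each horizon below first runs immediate-reaction (IR) forecasting with the frozen pretrained model, then finetunes the task layers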
+for out_len in 50 100 150 200 +do + #immediate forecast + python run_forecast.py --data_format npy --data_name SMD --root_path ./datasets/SMD/ --valid_prop 0.2 \ + --pretrained_model_path pretrain-library/U2MSMD-Base_dataSMD_dim38_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-140000.pth \ + --pretrain_args_path pretrain-library/U2MSMD-Base_dataSMD_dim38_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ + --in_len 400 --out_len $out_len \ + --gpu 0 --is_training 0 --IR_mode + + #finetune + python run_forecast.py --data_format npy --data_name SMD --root_path ./datasets/SMD/ --valid_prop 0.2 \ + --pretrained_model_path pretrain-library/U2MSMD-Base_dataSMD_dim38_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-140000.pth \ + --pretrain_args_path pretrain-library/U2MSMD-Base_dataSMD_dim38_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ + --in_len 400 --out_len $out_len \ + --finetune_layers 1 --neighbor_num 10 --dropout 0.2 --learning_rate 1e-5 --tolerance 3 --slide_step 10 \ + --gpu 0 --is_training 1 +done \ No newline at end of file diff --git a/scripts/forecast_scripts/SWaT.sh b/scripts/forecast_scripts/SWaT.sh new file mode 100644 index 0000000..7548664 --- /dev/null +++ b/scripts/forecast_scripts/SWaT.sh @@ -0,0 +1,17 @@ +for out_len in 50 100 150 200 +do + #immediate forecast + python run_forecast.py --data_format npy --data_name SWaT --root_path ./datasets/SWaT/ --valid_prop 0.2 \ + --pretrained_model_path pretrain-library/U2MSWaT-Base_dataSWaT_dim51_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-500000.pth \ + --pretrain_args_path pretrain-library/U2MSWaT-Base_dataSWaT_dim51_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ + --in_len 400 --out_len $out_len \ + --batch_size 64 --gpu 0 --is_training 0 --IR_mode + + #finetune + python run_forecast.py --data_format npy --data_name SWaT --root_path ./datasets/SWaT/ --valid_prop 0.2 \ + --pretrained_model_path pretrain-library/U2MSWaT-Base_dataSWaT_dim51_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-500000.pth \ + --pretrain_args_path pretrain-library/U2MSWaT-Base_dataSWaT_dim51_patch10_minPatch5_maxPatch100_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ + --in_len 400 --out_len $out_len \ + --finetune_layers 1 --neighbor_num 10 --dropout 0.2 --learning_rate 1e-5 --tolerance 3 --slide_step 10 \ + --batch_size 64 --gpu 0 --is_training 1 +done \ No newline at end of file diff --git a/scripts/forecast_scripts/traffic.sh b/scripts/forecast_scripts/traffic.sh new file mode 100644 index 0000000..0224fa1 --- /dev/null +++ b/scripts/forecast_scripts/traffic.sh @@ -0,0 +1,19 @@ +for out_len in 96 192 336 720 +do + #immediate reaction forecast + python run_forecast.py --data_format csv --data_name traffic --root_path ./datasets/traffic/ --data_path traffic.csv \ + --data_split 0.7,0.1,0.2 --checkpoints ./forecast_checkpoints/ \ + --pretrained_model_path pretrain-library/U2MTraffic-Base_dataTraffic_dim862_patch12_minPatch20_maxPatch200_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-205000.pth \ + --pretrain_args_path 
pretrain-library/U2MTraffic-Base_dataTraffic_dim862_patch12_minPatch20_maxPatch200_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ + --in_len 336 --out_len $out_len \ + --batch_size 6 --gpu 0 --is_training 0 --IR_mode + + #finetune + python run_forecast.py --data_format csv --data_name traffic --root_path ./datasets/traffic/ --data_path traffic.csv \ + --data_split 0.7,0.1,0.2 --checkpoints ./forecast_checkpoints/ \ + --pretrained_model_path pretrain-library/U2MTraffic-Base_dataTraffic_dim862_patch12_minPatch20_maxPatch200_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-205000.pth \ + --pretrain_args_path pretrain-library/U2MTraffic-Base_dataTraffic_dim862_patch12_minPatch20_maxPatch200_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ + --in_len 336 --out_len $out_len \ + --finetune_layers 1 --neighbor_num 10 --dropout 0.2 --learning_rate 1e-5 --tolerance 3 \ + --batch_size 6 --use_multi_gpu --devices 0,1 --is_training 1 +done \ No newline at end of file diff --git a/scripts/forecast_scripts/weather.sh b/scripts/forecast_scripts/weather.sh new file mode 100644 index 0000000..893832f --- /dev/null +++ b/scripts/forecast_scripts/weather.sh @@ -0,0 +1,19 @@ +for out_len in 96 192 336 720 +do + #immediate reaction forecast + python run_forecast.py --data_format csv --data_name weather --root_path ./datasets/weather/ --data_path weather.csv \ + --data_split 0.7,0.1,0.2 --checkpoints ./forecast_checkpoints/ \ + --pretrained_model_path pretrain-library/U2MWeather-Base_dataweather_dim21_patch12_minPatch20_maxPatch200_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-60000.pth \ + --pretrain_args_path pretrain-library/U2MWeather-Base_dataweather_dim21_patch12_minPatch20_maxPatch200_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ + --in_len 336 --out_len $out_len \ + --gpu 0 --is_training 0 --IR_mode + + #finetune + python run_forecast.py --data_format csv --data_name weather --root_path ./datasets/weather/ --data_path weather.csv \ + --data_split 0.7,0.1,0.2 --checkpoints ./forecast_checkpoints/ \ + --pretrained_model_path pretrain-library/U2MWeather-Base_dataweather_dim21_patch12_minPatch20_maxPatch200_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/bestValid-step-60000.pth \ + --pretrain_args_path pretrain-library/U2MWeather-Base_dataweather_dim21_patch12_minPatch20_maxPatch200_mask0.5_dm256_dff512_heads4_eLayer4_dLayer1_dropout0.0/args.json \ + --in_len 336 --out_len $out_len \ + --finetune_layers 1 --neighbor_num 10 --dropout 0.2 --learning_rate 1e-5 --tolerance 3 \ + --gpu 0 --is_training 1 +done \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/metrics.py b/utils/metrics.py new file mode 100644 index 0000000..5b80c13 --- /dev/null +++ b/utils/metrics.py @@ -0,0 +1,135 @@ +import numpy as np + +def RSE(pred, true): + return np.sqrt(np.sum((true-pred)**2)) / np.sqrt(np.sum((true-true.mean())**2)) + +def CORR(pred, true): + u = ((true-true.mean(0))*(pred-pred.mean(0))).sum(0) + d = np.sqrt(((true-true.mean(0))**2*(pred-pred.mean(0))**2).sum(0)) + return (u/d).mean(-1) + +def MAE(pred, true): + return np.mean(np.abs(pred-true)) + +def MSE(pred, true): + return np.mean((pred-true)**2) + +def RMSE(pred, true): + return np.sqrt(MSE(pred, true)) + +def MAPE(pred, true): + return np.mean(np.abs((pred - true) / (true + 1e-8))) + +def MSPE(pred, true): + return np.mean(np.square((pred - true) / (true + 1e-8))) 
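+ +# the 1e-8 added to the denominators of MAPE/MSPE guards against division by zero when true values are exactly 0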
+ +def metric(pred, true): + mae = MAE(pred, true) + mse = MSE(pred, true) + rmse = RMSE(pred, true) + mape = MAPE(pred, true) + mspe = MSPE(pred, true) + + return mae, mse, rmse, mape, mspe + +def segment_adjust(gt, pred): + ''' + point-adjust protocol: if any point inside a ground-truth anomaly segment is predicted as anomalous, the whole segment is counted as detected + details can be found in "Unsupervised Anomaly Detection via Variational Auto-Encoder for Seasonal KPIs in Web Applications (WWW18)" + gt: long continuous ground truth labels, [ts_len] + pred: long continuous predicted labels, [ts_len], modified in place + ''' + adjusted_flag = np.zeros_like(pred) + + for i in range(len(gt)): + if adjusted_flag[i]: + continue # this point has been adjusted + else: + if gt[i] == 1 and pred[i] == 1: + # detected an anomaly point, adjust the whole segment + for j in range(i, len(gt)): # adjust the right side + if gt[j] == 0: + break + else: + pred[j] = 1 + adjusted_flag[j] = 1 + for j in range(i - 1, -1, -1): # adjust the left side, down to and including index 0 + if gt[j] == 0: + break + else: + pred[j] = 1 + adjusted_flag[j] = 1 + else: + continue # gt=1, pred=0; gt=0, pred=0; gt=0, pred=1: do nothing + + return pred + + +def segment_adjust_flip(gt, old_pred, flip_idx): + ''' + flip one prediction from 0 to 1, then re-adjust the affected segment + used to compute the precision-recall curve + ''' + + new_pred = old_pred # note: aliases old_pred, so the flip is applied in place + delta_true_positive = 0 + delta_false_positive = 0 + if new_pred[flip_idx] == 1: # has already been adjusted by a previous flip + return gt, new_pred, delta_true_positive, delta_false_positive + else: + new_pred[flip_idx] = 1 + if (gt[flip_idx] == 1): delta_true_positive += 1 + else: delta_false_positive += 1 + + if gt[flip_idx] == 1 and new_pred[flip_idx] == 1: + # detected an anomaly point, adjust the whole segment + for j in range(flip_idx + 1, len(gt)): # adjust the right side + if gt[j] == 0: + break + else: + new_pred[j] = 1 + delta_true_positive += 1 + + for j in range(flip_idx - 1, -1, -1): # adjust the left side + if gt[j] == 0: + break + else: + new_pred[j] = 1 + delta_true_positive += 1 + + return gt, new_pred, delta_true_positive, delta_false_positive + + +def adjusted_precision_recall_curve(gt, anomaly_score): + precisions = []; recalls = []; precisions.append(0.); recalls.append(0.)
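+ # sweep the threshold from high to low: flip points to 'anomaly' in descending score order and + # update true/false positives incrementally via segment_adjust_flip instead of re-running the full adjustment per threshold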
+ + bound_idx = np.argsort(anomaly_score)[::-1] + + pred = np.zeros_like(anomaly_score) #start from all zero + + true_positive_num = 0 + false_positive_num = 0 + positive_num = np.sum(gt == 1) + flip_idx = bound_idx[0] + gt, pred, delta_tp, delta_fp = segment_adjust_flip(gt, pred, flip_idx) + true_positive_num += delta_tp + false_positive_num += delta_fp + + precision = 1.0 * true_positive_num / (true_positive_num + false_positive_num); precisions.append(precision) + recall = 1.0 * true_positive_num / positive_num; recalls.append(recall) + + for flip_idx in bound_idx[1:]: + gt, pred, delta_tp, delta_fp = segment_adjust_flip(gt, pred, flip_idx) + + true_positive_num += delta_tp + false_positive_num += delta_fp + precision = 1.0 * true_positive_num / (true_positive_num + false_positive_num) + recall = 1.0 * true_positive_num / positive_num + + precisions.append(precision); recalls.append(recall) + precisions = np.array(precisions) + recalls = np.array(recalls) + + AP = ((recalls[1:] - recalls[:-1]) * precisions[1:]).sum() + + return precisions, recalls, AP \ No newline at end of file diff --git a/utils/tools.py b/utils/tools.py new file mode 100644 index 0000000..1d8643f --- /dev/null +++ b/utils/tools.py @@ -0,0 +1,78 @@ +import numpy as np +import torch +import json +import argparse + +def adjust_learning_rate(optimizer, epoch, args): + if args.lradj == 'type1': + lr_adjust = {2: args.learning_rate * 0.5 ** 1, 4: args.learning_rate * 0.5 ** 2, + 6: args.learning_rate * 0.5 ** 3, 8: args.learning_rate * 0.5 ** 4, + 10: args.learning_rate * 0.5 ** 5} + elif args.lradj == 'type2': + lr_adjust = {5: args.learning_rate * 0.5 ** 1, 10: args.learning_rate * 0.5 ** 2, + 15: args.learning_rate * 0.5 ** 3, 20: args.learning_rate * 0.5 ** 4, + 25: args.learning_rate * 0.5 ** 5} + else: + lr_adjust = {} + if epoch in lr_adjust.keys(): + lr = lr_adjust[epoch] + for param_group in optimizer.param_groups: + param_group['lr'] = lr + print('Updating learning rate to {}'.format(lr)) + + +class EarlyStopping: + def __init__(self, patience=7, verbose=False, delta=0): + self.patience = patience + self.verbose = verbose + self.counter = 0 + self.best_score = None + self.early_stop = False + self.val_loss_min = np.Inf + self.delta = delta + self.best_model = None + + def __call__(self, val_loss, model, path): + score = -val_loss + if self.best_score is None: + self.best_score = score + self.save_checkpoint(val_loss, model, path) + elif score < self.best_score + self.delta: + self.counter += 1 + print(f'EarlyStopping counter: {self.counter} out of {self.patience}') + if self.counter >= self.patience: + self.early_stop = True + else: + self.best_score = score + self.save_checkpoint(val_loss, model, path) + self.counter = 0 + + def save_checkpoint(self, val_loss, model, path): + if self.verbose: + print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). 
Saving model ...') + torch.save(model.state_dict(), path + '/' + 'checkpoint.pth') + self.best_model = model + self.val_loss_min = val_loss + +def load_args(filename): + with open(filename, 'r') as f: + args = json.load(f) + return args + + +def string_split(str_for_split): + str_no_space = str_for_split.replace(' ', '') + str_split = str_no_space.split(',') + value_list = [eval(x) for x in str_split] # eval converts each piece to int or float, so a split may mix ratios and counts + + return value_list + +def str2bool(value): + if isinstance(value, bool): + return value + if value.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif value.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Invalid boolean type')
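
As a quick sanity check of the point-adjust metrics above, here is a minimal usage sketch (an editor's illustration, not part of the diff; it assumes numpy is installed and utils/metrics.py is importable, and the toy arrays are made up for the example):

import numpy as np
from utils.metrics import segment_adjust, adjusted_precision_recall_curve

# one ground-truth anomaly segment at indices 2-5; the detector fires only at index 3
gt = np.array([0, 0, 1, 1, 1, 1, 0, 0])
pred = np.array([0, 0, 0, 1, 0, 0, 0, 0])

adjusted = segment_adjust(gt, pred.copy())  # pred is modified in place, hence the copy
print(adjusted)  # [0 0 1 1 1 1 0 0] -- the single hit expands to the whole segment

# adjusted average precision computed directly from anomaly scores
scores = np.array([0.1, 0.2, 0.4, 0.9, 0.3, 0.5, 0.2, 0.1])
precisions, recalls, AP = adjusted_precision_recall_curve(gt.copy(), scores)
print(round(AP, 3))  # 1.0 here: the top-scored point lies in the segment, so all positives are recovered before any false positive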