diff --git a/config/config.yaml b/config/config.yaml index c99f46b..e1e6295 100755 --- a/config/config.yaml +++ b/config/config.yaml @@ -9,6 +9,7 @@ path: event_length: 1024 mel_length: 256 num_rows_per_batch: 8 +split_frame_length: 2000 optim: lr: 6e-5 @@ -46,11 +47,13 @@ trainer: devices: ${devices} eval: - is_sanity_check: False - eval_dataset: - exp_tag_name: - audio_dir: - midi_dir: + eval_dataset: "Slakh" + eval_first_n_examples: 3 + eval_after_num_epoch: 400 + eval_per_epoch: 1 + exp_tag_name: "test_midis" + audio_dir: "/data/slakh2100_flac_redux/validation/*/mix_16k.wav" + midi_dir: "/data/slakh2100_flac_redux/validation/" defaults: - model: MT3Net diff --git a/config/config_slakh_f1=0.65.yaml b/config/config_slakh_f1_0.65.yaml similarity index 65% rename from config/config_slakh_f1=0.65.yaml rename to config/config_slakh_f1_0.65.yaml index d633aa2..b1fa6b3 100755 --- a/config/config_slakh_f1=0.65.yaml +++ b/config/config_slakh_f1_0.65.yaml @@ -9,11 +9,13 @@ path: event_length: 1024 mel_length: 256 num_rows_per_batch: 12 +split_frame_length: 2000 optim: lr: 2e-4 warmup_steps: 64500 num_epochs: ${num_epochs} + num_steps_per_epoch: 1289 # TODO: this is not good practice. Ideally we can get this from dataloader. min_lr: 1e-4 grad_accum: 1 @@ -21,10 +23,10 @@ grad_accum: 1 dataloader: train: batch_size: 1 - num_workers: 12 + num_workers: 4 val: batch_size: 1 - num_workers: 12 + num_workers: 0 modelcheckpoint: monitor: 'val_loss' @@ -44,7 +46,15 @@ trainer: strategy: "ddp_find_unused_parameters_false" devices: ${devices} +eval: + eval_dataset: "Slakh" + eval_first_n_examples: 1 + eval_after_num_epoch: 400 + eval_per_epoch: 1 + exp_tag_name: "test_midis" + audio_dir: "/data/slakh2100_flac_redux/validation/*/mix_16k.wav" + midi_dir: "/data/slakh2100_flac_redux/validation/" + defaults: - model: MT3Net - - dataset: Slakh - # TODO: we need to specify num_samples_per_batch here from 8 to 12 \ No newline at end of file + - dataset: Slakh \ No newline at end of file diff --git a/config/dataset/Slakh.yaml b/config/dataset/Slakh.yaml index 819ec20..bf9b802 100644 --- a/config/dataset/Slakh.yaml +++ b/config/dataset/Slakh.yaml @@ -7,6 +7,7 @@ train: inst_filename: inst_names.json audio_filename: mix_16k.wav num_rows_per_batch: ${num_rows_per_batch} + split_frame_length: ${split_frame_length} val: _target_: dataset.dataset_2_random.SlakhDataset # choosing which data class to use root_dir: "/data2/kinwai/slakh2100_flac_redux/validation/" @@ -16,6 +17,7 @@ val: inst_filename: inst_names.json audio_filename: mix_16k.wav num_rows_per_batch: ${num_rows_per_batch} + split_frame_length: ${split_frame_length} test: root_dir: "/data/slakh2100_flac_redux/test" collate_fn: dataset.dataset_2_random.collate_fn diff --git a/contrib/spectrograms.py b/contrib/spectrograms.py index 567912a..8c80614 100755 --- a/contrib/spectrograms.py +++ b/contrib/spectrograms.py @@ -18,6 +18,8 @@ from ddsp import spectral_ops import tensorflow as tf +tf.config.set_visible_devices([], 'GPU') + # defaults for spectrogram config DEFAULT_SAMPLE_RATE = 16000 diff --git a/contrib/vocabularies.py b/contrib/vocabularies.py index 17ceff5..fc516fd 100755 --- a/contrib/vocabularies.py +++ b/contrib/vocabularies.py @@ -24,6 +24,7 @@ import seqio import t5.data import tensorflow as tf +tf.config.set_visible_devices([], 'GPU') DECODED_EOS_ID = -1 diff --git a/dataset/dataset_2_random.py b/dataset/dataset_2_random.py index 1a7fdf5..91759d1 100755 --- a/dataset/dataset_2_random.py +++ b/dataset/dataset_2_random.py @@ -11,7 +11,7 @@ import librosa import note_seq from glob import glob -from contrib import event_codec, note_sequences, spectrograms, vocabularies, run_length_encoding, metrics_utils +from contrib import spectrograms, vocabularies, note_sequences, run_length_encoding, metrics_utils, event_codec from contrib.preprocessor import slakh_class_to_program_and_is_drum, add_track_to_notesequence, PitchBendError import soundfile as sf @@ -33,7 +33,8 @@ def __init__( midi_folder='MIDI', inst_filename='inst_names.json', shuffle=True, - num_rows_per_batch=8 + num_rows_per_batch=8, + split_frame_length=2000 ) -> None: super().__init__() self.spectrogram_config = spectrograms.SpectrogramConfig() @@ -52,6 +53,7 @@ def __init__( self.onsets_only = onsets_only self.tie_token = self.codec.encode_event(event_codec.Event('tie', 0)) if self.include_ties else None self.num_rows_per_batch = num_rows_per_batch + self.split_frame_length = split_frame_length def _build_dataset(self, root_dir, shuffle=True): df = [] @@ -79,7 +81,6 @@ def _audio_to_frames( [0, frame_size - len(samples) % frame_size], mode='constant') - frames = spectrograms.split_audio(samples, self.spectrogram_config) num_frames = len(samples) // frame_size @@ -365,17 +366,20 @@ def __getitem__(self, idx): if sr != self.spectrogram_config.sample_rate: audio = librosa.resample(audio, orig_sr=sr, target_sr=self.spectrogram_config.sample_rate) row = self._tokenize(ns, audio, inst_names) - - # NOTE: by default, this is self._split_frame(row, length=2000) - # this does not guarantee the chunks in `rows` to be contiguous. - # if we need to ensure that the chunks in `rows` to be contiguous, use: - # rows = self._split_frame(row, length=self.mel_length) - rows = self._split_frame(row) + + # NOTE: if split_frame_length != mel_length, then it does not guarantee the chunks in `rows` to be contiguous. + rows = self._split_frame(row, length=self.split_frame_length) inputs, targets, frame_times, num_insts = [], [], [], [] - if len(rows) > self.num_rows_per_batch: - start_idx = random.randint(0, len(rows) - self.num_rows_per_batch) - rows = rows[start_idx : start_idx + self.num_rows_per_batch] + + # randomly select N rows for training, as the GPU memory is limited + if self.num_rows_per_batch == -1: + # choose all rows in this case + pass + else: + if len(rows) > self.num_rows_per_batch: + start_idx = random.randint(0, len(rows) - self.num_rows_per_batch) + rows = rows[start_idx : start_idx + self.num_rows_per_batch] predictions = [] # wavs = [] @@ -520,16 +524,6 @@ def token_to_idx(self, token_str): def collate_fn(lst): inputs = torch.cat([k[0] for k in lst]) targets = torch.cat([k[1] for k in lst]) - # num_insts = torch.cat([k[2] for k in lst]) - - # add random shuffling here - # indices = np.arange(inputs.shape[0]) - # np.random.shuffle(indices) - # indices = torch.from_numpy(indices) - # inputs = inputs[indices] - # targets = targets[indices] - # num_insts = num_insts[indices] - return inputs, targets if __name__ == '__main__': diff --git a/evaluate.py b/evaluate.py index 4900f5d..b1e508f 100755 --- a/evaluate.py +++ b/evaluate.py @@ -328,8 +328,6 @@ def func(item): mean_scores = {k: np.mean(v) for k, v in scores.items() if k != "F1 by program"} - for key in sorted(list(mean_scores)): - print("{}: {:.4}".format(key, mean_scores[key])) if enable_instrument_eval: print("====") @@ -362,6 +360,8 @@ def func(item): print("{}: {:.4}".format(d[key], program_f1_dict[key])) elif key * 8 in program_f1_dict: print("{}: {:.4}".format(d[key], program_f1_dict[key * 8])) + + return mean_scores if __name__ == "__main__": diff --git a/inference.py b/inference.py index 670ca61..89a20fc 100755 --- a/inference.py +++ b/inference.py @@ -3,13 +3,14 @@ import os from typing import List +import torch.nn as nn +import torch + import numpy as np from tqdm import tqdm from models.t5 import T5ForConditionalGeneration, T5Config from models.t5_xl import T5WithXLDecoder, T5Config # from models.t5_xl_instrument import T5WithXLDecoder, T5Config -import torch.nn as nn -import torch from contrib import spectrograms, vocabularies, note_sequences, metrics_utils import note_seq import traceback @@ -58,7 +59,6 @@ def __init__( num_velocity_bins=1)) self.vocab = vocabularies.vocabulary_from_codec(self.codec) self.device = device - self.model.to(self.device) self.mel_norm = mel_norm # if "pretrained/mt3.pth" in weight_path: @@ -72,7 +72,6 @@ def _audio_to_frames(self, audio): frame_size = spectrogram_config.hop_width padding = [0, frame_size - len(audio) % frame_size] audio = np.pad(audio, padding, mode='constant') - print('audio', audio.shape, 'frame_size', frame_size) frames = spectrograms.split_audio(audio, spectrogram_config) num_frames = len(audio) // frame_size times = np.arange(num_frames) / \ @@ -167,22 +166,18 @@ def inference( invalid_programs = self._get_program_ids(valid_programs) else: invalid_programs = None - # print('preprocessing', audio_path) inputs, frame_times = self._preprocess(audio) inputs_tensor = torch.from_numpy(inputs) results = [] inputs_tensor, frame_times = self._batching(inputs_tensor, frame_times, batch_size=batch_size) - print('inferencing', audio_path) if self.contiguous_inference: inputs_tensor = torch.cat(inputs_tensor, dim=0) frame_times = [torch.tensor(k) for k in frame_times] frame_times = torch.cat(frame_times, dim=0) - print('inputs_tensor', inputs_tensor.shape, frame_times.shape) inputs_tensor = [inputs_tensor] frame_times = [frame_times] - self.model.cuda() for idx, batch in enumerate(inputs_tensor): batch = batch.to(self.device) @@ -200,7 +195,6 @@ def inference( filename = audio_path.split('/')[-1].split('.')[0] outpath = f'./out/{filename}.mid' os.makedirs('/'.join(outpath.split('/')[:-1]), exist_ok=True) - print("saving", outpath) note_seq.sequence_proto_to_midi_file(event, outpath) except Exception as e: diff --git a/midi_script.py b/midi_script.py index 022861e..ae54124 100755 --- a/midi_script.py +++ b/midi_script.py @@ -7,7 +7,9 @@ Use this for evaluation!!! """ -midis = sorted(glob.glob("/workspace/data/dataset/slakh2100_flac_redux/test/*/MIDI/")) +# NOTE: change the following path accordingly +midis = sorted(glob.glob("/data/slakh2100_flac_redux/validation/*/MIDI/")) +# midis = sorted(glob.glob("/workspace/data/dataset/slakh2100_flac_redux/test/*/MIDI/")) for midi in tqdm(midis): stems = sorted(glob.glob(midi + "*.mid")) insts = [] diff --git a/tasks/mt3_net.py b/tasks/mt3_net.py index a1388f6..77ebdf0 100644 --- a/tasks/mt3_net.py +++ b/tasks/mt3_net.py @@ -1,18 +1,37 @@ from torch.optim import AdamW from omegaconf import OmegaConf import pytorch_lightning as pl +from pytorch_lightning.utilities.rank_zero import rank_zero_only + import torch +# NOTE: On Linux, Python's multiprocessing module default start method is "fork". +# we need to set the start method to "spawn" +# in order to properly use TF methods in PyTorch dataloader. +try: + torch.multiprocessing.set_start_method('spawn') +except RuntimeError: + pass + import torch.nn as nn from transformers import T5Config + +import tensorflow as tf +tf.config.set_visible_devices([], 'GPU') + + from models.t5 import T5ForConditionalGeneration from utils import get_cosine_schedule_with_warmup +from test import get_scores +import glob + class MT3Net(pl.LightningModule): - def __init__(self, config, optim_cfg): + def __init__(self, config, optim_cfg, eval_cfg=None): super().__init__() self.config = config self.optim_cfg = optim_cfg + self.eval_cfg = eval_cfg T5config = T5Config.from_dict(OmegaConf.to_container(self.config)) self.model: nn.Module = T5ForConditionalGeneration(T5config) @@ -43,9 +62,28 @@ def validation_step(self, batch, batch_idx): ) self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True) - # no need to use it in this stage - # return loss - + @rank_zero_only + def validation_epoch_end(self, outs): + if self.current_epoch >= self.eval_cfg.eval_after_num_epoch: + if self.current_epoch % self.eval_cfg.eval_per_epoch == 0: + eval_audio_dir = sorted(glob.glob(self.eval_cfg.audio_dir)) + if self.eval_cfg.eval_first_n_examples: + eval_audio_dir = eval_audio_dir[:self.eval_cfg.eval_first_n_examples] + + self.model.eval() + scores = get_scores( + model=self.model, + eval_audio_dir=eval_audio_dir, + eval_dataset="Slakh", + ground_truth_midi_dir=self.eval_cfg.midi_dir, + verbose=False + ) + print(scores) + + self.log('val_f1_flat', scores['Onset F1'], on_step=False, on_epoch=True) + self.log('val_f1_midi_class', scores['Onset + program F1 (midi_class)'], on_step=False, on_epoch=True) + self.log('val_f1_full', scores['Onset + program F1 (full)'], on_step=False, on_epoch=True) + def configure_optimizers(self): optimizer = AdamW(self.model.parameters(), self.optim_cfg.lr) warmup_step = int(self.optim_cfg.warmup_steps) @@ -69,7 +107,7 @@ def configure_optimizers(self): class MT3NetWeightedLoss(pl.LightningModule): - def __init__(self, config, optim_cfg): + def __init__(self, config, optim_cfg, eval_cfg=None): super().__init__() self.config = config self.optim_cfg = optim_cfg diff --git a/test.py b/test.py index f711fa5..7fc1a82 100755 --- a/test.py +++ b/test.py @@ -1,8 +1,8 @@ import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' -from inference import InferenceHandler import torch +from inference import InferenceHandler import glob import os from tqdm import tqdm @@ -12,28 +12,15 @@ from evaluate import evaluate_main -@hydra.main(config_path="config", config_name="config") -def main(cfg): - # convert .ckpt to .pth - assert cfg.path - assert cfg.path.endswith(".pt") or cfg.path.endswith("pth"), "Only .pt / .pth files are supported." - assert cfg.eval.exp_tag_name - assert cfg.eval.audio_dir - - pl = hydra.utils.instantiate(cfg.model, optim_cfg=cfg.optim) - model = pl.model - print(f"Loading weights from: {cfg.path}") - model.load_state_dict(torch.load(cfg.path)) - model.eval() - - dir = sorted(glob.glob(cfg.eval.audio_dir)) - if cfg.eval.eval_dataset == "NSynth": - # NOTE: skip vocals and mallets. Both Slakh and ComMU dataset does not have vocals, and ComMU does not have mallets. - dir = [d for d in dir if "vocal" not in d and "mallet" not in d] - if cfg.eval.is_sanity_check: - dir = dir[:100] - - mel_norm = False if "pretrained/mt3.pth" in cfg.path else True +def get_scores( + model, + eval_audio_dir=None, + mel_norm=True, + eval_dataset="Slakh", + exp_tag_name="test_midis", + ground_truth_midi_dir=None, + verbose=True +): handler = InferenceHandler( model=model, device=torch.device('cuda'), @@ -43,41 +30,94 @@ def main(cfg): def func(fname): audio, _ = librosa.load(fname, sr=16000) - if cfg.eval.eval_dataset == "NSynth": + if eval_dataset == "NSynth": audio = np.pad(audio, (int(0.05 * 16000), 0), "constant", constant_values=0) return audio - print("Total songs:", len(dir)) - - exp_tag_name = cfg.eval.exp_tag_name + if verbose: + print("Total songs:", len(eval_audio_dir)) - for fname in tqdm(dir): + for fname in tqdm(eval_audio_dir): audio = func(fname) - if cfg.eval.eval_dataset == "Slakh": + if eval_dataset == "Slakh": name = fname.split("/")[-2] outpath = os.path.join(exp_tag_name, name, "mix.mid") - elif cfg.eval.eval_dataset == "ComMU" or cfg.eval.eval_dataset == "NSynth": + elif eval_dataset == "ComMU" or eval_dataset == "NSynth": name = fname.split("/")[-1] outpath = os.path.join(exp_tag_name, name.replace(".wav", ".mid")) else: raise ValueError("Invalid dataset name.") handler.inference( - audio, - fname, + audio=audio, + audio_path=fname, outpath=outpath, batch_size=8, # changing this might affect results of sequential-inference models (e.g. XL). max_length=256 ) - print("Evaluating...") + if verbose: + print("Evaluating...") current_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir - ground_truth_midi_dir = cfg.eval.midi_dir if cfg.eval.midi_dir else cfg.dataset.test.root_dir - evaluate_main( - dataset_name=cfg.eval.eval_dataset, - test_midi_dir=os.path.join(current_dir, cfg.eval.exp_tag_name), + scores = evaluate_main( + dataset_name=eval_dataset, + test_midi_dir=os.path.join(current_dir, exp_tag_name), + ground_truth_midi_dir=ground_truth_midi_dir, + ) + + if verbose: + for key in sorted(list(scores)): + print("{}: {:.4}".format(key, scores[key])) + + return scores + + +@hydra.main(config_path="config", config_name="config", version_base="1.1") +def main(cfg): + assert cfg.path + assert cfg.path.endswith(".pt") or \ + cfg.path.endswith("pth") or \ + cfg.path.endswith("ckpt"), "Only .pt, .pth, .ckpt files are supported." + assert cfg.eval.exp_tag_name + assert cfg.eval.audio_dir + + pl = hydra.utils.instantiate(cfg.model, optim_cfg=cfg.optim) + print(f"Loading weights from: {cfg.path}") + if cfg.path.endswith(".ckpt"): + # load lightning module from checkpoint + pl = pl.load_from_checkpoint( + cfg.path, + config=cfg.model.config, + optim_cfg=cfg.optim, + ) + model = pl.model + else: + # load weights for nn.Module + model = pl.model + model.load_state_dict(torch.load(cfg.path)) + + model.eval() + if torch.cuda.is_available(): + model.cuda() + + dir = sorted(glob.glob(cfg.eval.audio_dir)) + if cfg.eval.eval_dataset == "NSynth": + # NOTE: skip vocals and mallets. Both Slakh and ComMU dataset does not have vocals, and ComMU does not have mallets. + dir = [d for d in dir if "vocal" not in d and "mallet" not in d] + if cfg.eval.eval_first_n_examples: + dir = dir[:cfg.eval.eval_first_n_examples] + + mel_norm = False if "pretrained/mt3.pth" in cfg.path else True + ground_truth_midi_dir = cfg.eval.midi_dir if cfg.eval.midi_dir else cfg.dataset.test.root_dir + + get_scores( + model, + eval_audio_dir=dir, + mel_norm=mel_norm, + eval_dataset=cfg.eval.eval_dataset, + exp_tag_name=cfg.eval.exp_tag_name, ground_truth_midi_dir=ground_truth_midi_dir, ) diff --git a/test.sh b/test.sh index b198f9e..01439fa 100755 --- a/test.sh +++ b/test.sh @@ -1,22 +1,23 @@ # ComMU -python3 test.py \ - --config-path "config" \ - --config-name "config_commu" \ - path="../../../pretrained/commu_mt3.pt" \ - eval.eval_dataset="ComMU" \ - eval.exp_tag_name="commu_mt3" \ - eval.audio_dir="/data/datasets/ComMU/dataset_processed/commu_audio_v2/test/*.wav" \ - hydra/job_logging=disabled \ +# python3 test.py \ +# --config-path "config" \ +# --config-name "config_commu" \ +# path="../../../pretrained/commu_mt3.pt" \ +# eval.eval_dataset="ComMU" \ +# eval.exp_tag_name="commu_mt3" \ +# eval.audio_dir="/data/datasets/ComMU/dataset_processed/commu_audio_v2/test/*.wav" \ +# hydra/job_logging=disabled \ # Slakh MT3 baseline -# python3 test.py \ -# path="../../../pretrained/mt3.pth" \ -# eval.eval_dataset="Slakh" \ -# eval.exp_tag_name="slakh_mt3_official" \ -# eval.audio_dir="/data/slakh2100_flac_redux/test/*/mix_16k.wav" \ -# hydra/job_logging=disabled \ -# eval.is_sanity_check=True \ +python3 test.py \ + path="../../../pretrained/mt3.pth" \ + eval.eval_dataset="Slakh" \ + eval.exp_tag_name="slakh_mt3_official" \ + eval.audio_dir="/data/slakh2100_flac_redux/test/*/mix_16k.wav" \ + eval.midi_dir="/data/slakh2100_flac_redux/test/" \ + eval.eval_first_n_examples=3 \ + hydra/job_logging=disabled \ # NSynth @@ -37,4 +38,4 @@ python3 test.py \ # eval.exp_tag_name="commu_mt3_on_nsynth" \ # eval.audio_dir="/data/nsynth-valid/audio/*.wav" \ # eval.midi_dir="/data/nsynth-valid/midi/" \ -# hydra/job_logging=disabled \ \ No newline at end of file +# hydra/job_logging=disabled \ diff --git a/train.py b/train.py index be35d97..c830ddd 100755 --- a/train.py +++ b/train.py @@ -12,19 +12,22 @@ import torch import pytorch_lightning as pl -import os import hydra from tasks.mt3_net import MT3Net -@hydra.main(config_path="config", config_name="config") +@hydra.main(config_path="config", config_name="config", version_base="1.1") # def main(config, model_config, result_dir, mode, path): def main(cfg): # set seed to ensure reproducibility pl.seed_everything(cfg.seed) - model = hydra.utils.instantiate(cfg.model, optim_cfg=cfg.optim) + model = hydra.utils.instantiate( + cfg.model, + optim_cfg=cfg.optim, + eval_cfg=cfg.eval, + ) logger = TensorBoardLogger(save_dir='.', name=f"{cfg.model_type}_{cfg.dataset_type}")