diff --git a/config/config_slakh_f1=0.65.yaml b/config/config_slakh_f1_0.65.yaml similarity index 60% rename from config/config_slakh_f1=0.65.yaml rename to config/config_slakh_f1_0.65.yaml index d633aa2..27dd291 100755 --- a/config/config_slakh_f1=0.65.yaml +++ b/config/config_slakh_f1_0.65.yaml @@ -9,11 +9,16 @@ path: event_length: 1024 mel_length: 256 num_rows_per_batch: 12 +split_frame_length: 256 +dataset_is_deterministic: False +dataset_is_randomize_tokens: True +dataset_use_tf_spectral_ops: False optim: lr: 2e-4 warmup_steps: 64500 num_epochs: ${num_epochs} + num_steps_per_epoch: 1289 # TODO: this is not good practice. Ideally we can get this from dataloader. min_lr: 1e-4 grad_accum: 1 @@ -21,10 +26,10 @@ grad_accum: 1 dataloader: train: batch_size: 1 - num_workers: 12 + num_workers: 2 val: batch_size: 1 - num_workers: 12 + num_workers: 0 modelcheckpoint: monitor: 'val_loss' @@ -32,6 +37,7 @@ modelcheckpoint: save_last: True save_top_k: 5 save_weights_only: False + every_n_epochs: 50 filename: '{epoch}-{step}-{val_loss:.4f}' trainer: @@ -43,6 +49,20 @@ trainer: log_every_n_steps: 100 strategy: "ddp_find_unused_parameters_false" devices: ${devices} + check_val_every_n_epoch: 10 + +eval: + is_sanity_check: False + eval_first_n_examples: + eval_after_num_epoch: 400 + eval_per_epoch: 1 + eval_dataset: + exp_tag_name: + audio_dir: + midi_dir: + contiguous_inference: + batch_size: 8 + use_tf_spectral_ops: False # change this to True if using pretrained/mt3.pth defaults: - model: MT3Net diff --git a/config/dataset/Slakh.yaml b/config/dataset/Slakh.yaml index 819ec20..9203fa7 100644 --- a/config/dataset/Slakh.yaml +++ b/config/dataset/Slakh.yaml @@ -7,6 +7,10 @@ train: inst_filename: inst_names.json audio_filename: mix_16k.wav num_rows_per_batch: ${num_rows_per_batch} + split_frame_length: ${split_frame_length} + is_deterministic: ${dataset_is_deterministic} + is_randomize_tokens: ${dataset_is_randomize_tokens} + use_tf_spectral_ops: ${dataset_use_tf_spectral_ops} val: _target_: dataset.dataset_2_random.SlakhDataset # choosing which data class to use root_dir: "/data2/kinwai/slakh2100_flac_redux/validation/" @@ -16,6 +20,10 @@ val: inst_filename: inst_names.json audio_filename: mix_16k.wav num_rows_per_batch: ${num_rows_per_batch} + split_frame_length: ${split_frame_length} + is_deterministic: ${dataset_is_deterministic} + is_randomize_tokens: ${dataset_is_randomize_tokens} + use_tf_spectral_ops: ${dataset_use_tf_spectral_ops} test: root_dir: "/data/slakh2100_flac_redux/test" collate_fn: dataset.dataset_2_random.collate_fn diff --git a/contrib/metrics_utils.py b/contrib/metrics_utils.py index 3d0d5e6..66f1307 100755 --- a/contrib/metrics_utils.py +++ b/contrib/metrics_utils.py @@ -21,9 +21,7 @@ from contrib import event_codec, note_sequences, run_length_encoding -import note_seq import numpy as np -import pretty_midi S = TypeVar('S') T = TypeVar('T') @@ -143,57 +141,4 @@ def event_predictions_to_ns( 'est_ns': ns, 'est_invalid_events': total_invalid_events, 'est_dropped_events': total_dropped_events, - } - - -def get_prettymidi_pianoroll(ns: note_seq.NoteSequence, fps: float, - is_drum: bool): - """Convert NoteSequence to pianoroll through pretty_midi.""" - for note in ns.notes: - if is_drum or note.end_time - note.start_time < 0.05: - # Give all drum notes a fixed length, and all others a min length - note.end_time = note.start_time + 0.05 - - pm = note_seq.note_sequence_to_pretty_midi(ns) - end_time = pm.get_end_time() - cc = [ - # all sound off - pretty_midi.ControlChange(number=120, value=0, time=end_time), - # all notes off - pretty_midi.ControlChange(number=123, value=0, time=end_time) - ] - pm.instruments[0].control_changes = cc - if is_drum: - # If inst.is_drum is set, pretty_midi will return an all zero pianoroll. - for inst in pm.instruments: - inst.is_drum = False - pianoroll = pm.get_piano_roll(fs=fps) - return pianoroll - - -def frame_metrics(ref_pianoroll: np.ndarray, - est_pianoroll: np.ndarray, - velocity_threshold: int) -> Tuple[float, float, float]: - """Frame Precision, Recall, and F1.""" - import sklearn - # Pad to same length - if ref_pianoroll.shape[1] > est_pianoroll.shape[1]: - diff = ref_pianoroll.shape[1] - est_pianoroll.shape[1] - est_pianoroll = np.pad( - est_pianoroll, [(0, 0), (0, diff)], mode='constant') - elif est_pianoroll.shape[1] > ref_pianoroll.shape[1]: - diff = est_pianoroll.shape[1] - ref_pianoroll.shape[1] - ref_pianoroll = np.pad( - ref_pianoroll, [(0, 0), (0, diff)], mode='constant') - - # For ref, remove any notes that are too quiet (consistent with Cerberus.) - ref_frames_bool = ref_pianoroll > velocity_threshold - # For est, keep all predicted notes. - est_frames_bool = est_pianoroll > 0 - - precision, recall, f1, _ = sklearn.metrics.precision_recall_fscore_support( - ref_frames_bool.flatten(), - est_frames_bool.flatten(), - labels=[True, False]) - - return precision[0], recall[0], f1[0] + } \ No newline at end of file diff --git a/contrib/spectrograms.py b/contrib/spectrograms.py index 567912a..d351926 100755 --- a/contrib/spectrograms.py +++ b/contrib/spectrograms.py @@ -15,9 +15,14 @@ """Audio spectrogram functions.""" import dataclasses +import torch +from torchaudio.transforms import MelSpectrogram +import librosa +import numpy as np -from ddsp import spectral_ops -import tensorflow as tf +# this is to suppress a warning from torch melspectrogram +import warnings +warnings.filterwarnings("ignore") # defaults for spectrogram config DEFAULT_SAMPLE_RATE = 16000 @@ -35,6 +40,7 @@ class SpectrogramConfig: sample_rate: int = DEFAULT_SAMPLE_RATE hop_width: int = DEFAULT_HOP_WIDTH num_mel_bins: int = DEFAULT_NUM_MEL_BINS + use_tf_spectral_ops: bool = False @property def abbrev_str(self): @@ -53,29 +59,63 @@ def frames_per_second(self): def split_audio(samples, spectrogram_config): - """Split audio into frames.""" - return tf.signal.frame( + """Split audio into frames using librosa.""" + if samples.shape[0] % spectrogram_config.hop_width != 0: + samples = np.pad( + samples, + (0, spectrogram_config.hop_width - samples.shape[0] % spectrogram_config.hop_width), + 'constant', + constant_values=0 + ) + return librosa.util.frame( samples, frame_length=spectrogram_config.hop_width, - frame_step=spectrogram_config.hop_width, - pad_end=True) - - -def compute_spectrogram(samples, spectrogram_config): - """Compute a mel spectrogram.""" - overlap = 1 - (spectrogram_config.hop_width / FFT_SIZE) - return spectral_ops.compute_logmel( - samples, - bins=spectrogram_config.num_mel_bins, - lo_hz=MEL_LO_HZ, - overlap=overlap, - fft_size=FFT_SIZE, - sample_rate=spectrogram_config.sample_rate) + hop_length=spectrogram_config.hop_width, + axis=-1).T + + +def compute_spectrogram( + samples, + spectrogram_config, +): + """ + Compute a mel spectrogram. + Due to multiprocessing issues running TF and PyTorch together, we use librosa + and only keep `spectral_ops.compute_logmel` for evaluation purposes. + """ + if spectrogram_config.use_tf_spectral_ops: + # NOTE: we only keep this for evaluating existing models + # This is because I find even with an equivalent PyTorch / librosa implementation + # that gives close-enough results (melspec MAE ~ 2e-3), the model output is still affected badly. + # lazy load + from ddsp import spectral_ops + overlap = 1 - (spectrogram_config.hop_width / FFT_SIZE) + return spectral_ops.compute_logmel( + samples, + bins=spectrogram_config.num_mel_bins, + lo_hz=MEL_LO_HZ, + overlap=overlap, + fft_size=FFT_SIZE, + sample_rate=spectrogram_config.sample_rate) + else: + transform = MelSpectrogram( + sample_rate=spectrogram_config.sample_rate, + n_fft=FFT_SIZE, + hop_length=spectrogram_config.hop_width, + n_mels=spectrogram_config.num_mel_bins, + f_min=MEL_LO_HZ, + power=1.0, + ) + samples = torch.from_numpy(samples).float() + S = transform(samples) + S[S<0] = 0 + S = torch.log(S + 1e-6) + return S.numpy().T def flatten_frames(frames): """Convert frames back into a flat array of samples.""" - return tf.reshape(frames, [-1]) + return np.reshape(frames, (-1,)) def input_depth(spectrogram_config): diff --git a/contrib/vocabularies.py b/contrib/vocabularies.py index 17ceff5..2a2413a 100755 --- a/contrib/vocabularies.py +++ b/contrib/vocabularies.py @@ -23,7 +23,6 @@ import note_seq import seqio import t5.data -import tensorflow as tf DECODED_EOS_ID = -1 @@ -220,7 +219,7 @@ def _decode_id(encoded_id): ids = [_decode_id(int(i)) for i in ids] return ids - def _encode_tf(self, token_ids: tf.Tensor) -> tf.Tensor: + def _encode_tf(self, token_ids): """Encode a list of tokens to a tf.Tensor. Args: @@ -229,46 +228,48 @@ def _encode_tf(self, token_ids: tf.Tensor) -> tf.Tensor: Returns: a 1d tf.Tensor with dtype tf.int32 """ - with tf.control_dependencies( - [tf.debugging.assert_less( - token_ids, tf.cast(self._num_regular_tokens, token_ids.dtype)), - tf.debugging.assert_greater_equal( - token_ids, tf.cast(0, token_ids.dtype)) - ]): - tf_ids = token_ids + self._num_special_tokens - return tf_ids - - def _decode_tf(self, ids: tf.Tensor) -> tf.Tensor: - """Decode in TensorFlow. - - The special tokens of PAD and UNK as well as extra_ids will be - replaced with DECODED_INVALID_ID in the output. If EOS is present, it and - all following tokens in the decoded output and will be represented by - DECODED_EOS_ID. - - Args: - ids: a 1d tf.Tensor with dtype tf.int32 - - Returns: - a 1d tf.Tensor with dtype tf.int32 - """ - # Create a mask that is true from the first EOS position onward. - # First, create an array that is True whenever there is an EOS, then cumsum - # that array so that every position after and including the first True is - # >1, then cast back to bool for the final mask. - eos_and_after = tf.cumsum( - tf.cast(tf.equal(ids, self.eos_id), tf.int32), exclusive=False, axis=-1) - eos_and_after = tf.cast(eos_and_after, tf.bool) - - return tf.where( - eos_and_after, - DECODED_EOS_ID, - tf.where( - tf.logical_and( - tf.greater_equal(ids, self._num_special_tokens), - tf.less(ids, self._base_vocab_size)), - ids - self._num_special_tokens, - DECODED_INVALID_ID)) + return None + # with tf.control_dependencies( + # [tf.debugging.assert_less( + # token_ids, tf.cast(self._num_regular_tokens, token_ids.dtype)), + # tf.debugging.assert_greater_equal( + # token_ids, tf.cast(0, token_ids.dtype)) + # ]): + # tf_ids = token_ids + self._num_special_tokens + # return tf_ids + + def _decode_tf(self, ids): + return None + # """Decode in TensorFlow. + + # The special tokens of PAD and UNK as well as extra_ids will be + # replaced with DECODED_INVALID_ID in the output. If EOS is present, it and + # all following tokens in the decoded output and will be represented by + # DECODED_EOS_ID. + + # Args: + # ids: a 1d tf.Tensor with dtype tf.int32 + + # Returns: + # a 1d tf.Tensor with dtype tf.int32 + # """ + # # Create a mask that is true from the first EOS position onward. + # # First, create an array that is True whenever there is an EOS, then cumsum + # # that array so that every position after and including the first True is + # # >1, then cast back to bool for the final mask. + # eos_and_after = tf.cumsum( + # tf.cast(tf.equal(ids, self.eos_id), tf.int32), exclusive=False, axis=-1) + # eos_and_after = tf.cast(eos_and_after, tf.bool) + + # return tf.where( + # eos_and_after, + # DECODED_EOS_ID, + # tf.where( + # tf.logical_and( + # tf.greater_equal(ids, self._num_special_tokens), + # tf.less(ids, self._base_vocab_size)), + # ids - self._num_special_tokens, + # DECODED_INVALID_ID)) def num_special_tokens(self): return self._num_special_tokens diff --git a/dataset/dataset_2_random.py b/dataset/dataset_2_random.py index 1a7fdf5..35e129a 100755 --- a/dataset/dataset_2_random.py +++ b/dataset/dataset_2_random.py @@ -1,9 +1,6 @@ import torch from torch.utils.data import Dataset, DataLoader -import tensorflow as tf -tf.config.set_visible_devices([], 'GPU') - import json import random from typing import Dict, List, Optional, Sequence, Tuple @@ -33,10 +30,15 @@ def __init__( midi_folder='MIDI', inst_filename='inst_names.json', shuffle=True, - num_rows_per_batch=8 + num_rows_per_batch=8, + split_frame_length=2000, + is_randomize_tokens=True, + is_deterministic=False, + use_tf_spectral_ops=True ) -> None: super().__init__() self.spectrogram_config = spectrograms.SpectrogramConfig() + self.spectrogram_config.use_tf_spectral_ops = use_tf_spectral_ops self.codec = vocabularies.build_codec(vocab_config=vocabularies.VocabularyConfig( num_velocity_bins=1)) self.vocab = vocabularies.vocabulary_from_codec(self.codec) @@ -52,6 +54,9 @@ def __init__( self.onsets_only = onsets_only self.tie_token = self.codec.encode_event(event_codec.Event('tie', 0)) if self.include_ties else None self.num_rows_per_batch = num_rows_per_batch + self.split_frame_length = split_frame_length + self.is_deterministic = is_deterministic + self.is_randomize_tokens = is_randomize_tokens def _build_dataset(self, root_dir, shuffle=True): df = [] @@ -213,14 +218,15 @@ def _run_length_encode_shifts( # NOTE: this needs to be uncommented if not using random-order augmentation # because random-order augmentation use `_remove_redundant_tokens` to replace this part - # is_redundant = False - # for i, (min_index, max_index) in enumerate(state_change_event_ranges): - # if (min_index <= event) and (event <= max_index): - # if current_state[i] == event: - # is_redundant = True - # current_state[i] = event - # if is_redundant: - # continue + if not self.is_randomize_tokens: + is_redundant = False + for i, (min_index, max_index) in enumerate(state_change_event_ranges): + if (min_index <= event) and (event <= max_index): + if current_state[i] == event: + is_redundant = True + current_state[i] = event + if is_redundant: + continue # Once we've reached a non-shift event, RLE all previous shift events # before outputting the non-shift event. @@ -319,7 +325,10 @@ def _random_chunk(self, row): random_length = input_length - self.mel_length if random_length < 1: return row - start_length = random.randint(0, random_length) + if self.is_deterministic: + start_length = 16 + else: + start_length = random.randint(0, random_length) for k in row.keys(): if k in ['inputs', 'input_times', 'input_event_start_indices', 'input_event_end_indices', 'input_state_event_indices']: new_row[k] = row[k][start_length:start_length+self.mel_length] @@ -358,23 +367,33 @@ def _postprocess_batch(self, result): result = result.cpu().numpy() return result - def __getitem__(self, idx): - row = self.df[idx] + def _preprocess_inputs(self, row): ns, inst_names = self._parse_midi(row['midi_path'], row['inst_names']) audio, sr = librosa.load(row['audio_path'], sr=None) if sr != self.spectrogram_config.sample_rate: audio = librosa.resample(audio, orig_sr=sr, target_sr=self.spectrogram_config.sample_rate) + + return ns, audio, inst_names + + def __getitem__(self, idx): + ns, audio, inst_names = self._preprocess_inputs(self.df[idx]) + row = self._tokenize(ns, audio, inst_names) # NOTE: by default, this is self._split_frame(row, length=2000) # this does not guarantee the chunks in `rows` to be contiguous. # if we need to ensure that the chunks in `rows` to be contiguous, use: - # rows = self._split_frame(row, length=self.mel_length) - rows = self._split_frame(row) + rows = self._split_frame(row, length=self.split_frame_length) + # rows = self._split_frame(row) + + # print('self.is_deterministic', self.is_deterministic) inputs, targets, frame_times, num_insts = [], [], [], [] if len(rows) > self.num_rows_per_batch: - start_idx = random.randint(0, len(rows) - self.num_rows_per_batch) + if self.is_deterministic: + start_idx = 2 + else: + start_idx = random.randint(0, len(rows) - self.num_rows_per_batch) rows = rows[start_idx : start_idx + self.num_rows_per_batch] predictions = [] @@ -391,13 +410,13 @@ def __getitem__(self, idx): row = self._compute_spectrogram(row) # -- random order augmentation -- - # If turned on, comment out `is_redundant` code in `run_length_encoding` - # print("=======") + if self.is_randomize_tokens: + t = self.randomize_tokens([self.get_token_name(t) for t in row["targets"]]) + t = np.array([self.token_to_idx(k) for k in t]) + t = self._remove_redundant_tokens(t) + row["targets"] = t + # print(j, [self.get_token_name(t) for t in row["targets"]]) - t = self.randomize_tokens([self.get_token_name(t) for t in row["targets"]]) - t = np.array([self.token_to_idx(k) for k in t]) - t = self._remove_redundant_tokens(t) - row["targets"] = t row = self._pad_length(row) inputs.append(row["inputs"]) @@ -439,7 +458,6 @@ def __getitem__(self, idx): # note_seq.sequence_proto_to_midi_file(result['est_ns'], "test_out.mid") # sf.write(f"test_out.wav", np.concatenate(wavs), 16000, "PCM_24") # ========== for reconstructing the MIDI from events =========== # - # num_insts = np.stack(num_insts) return torch.stack(inputs), torch.stack(targets) @@ -538,7 +556,10 @@ def collate_fn(lst): shuffle=False, is_train=False, include_ties=True, - mel_length=256 + mel_length=256, + split_frame_length=256, + is_deterministic=False, + is_randomize_tokens=False ) print("pitch", dataset.codec.event_type_range("pitch")) print("velocity", dataset.codec.event_type_range("velocity")) @@ -548,6 +569,6 @@ def collate_fn(lst): dl = DataLoader(dataset, batch_size=1, num_workers=0, collate_fn=collate_fn, shuffle=False) for idx, item in enumerate(dl): inputs, targets = item - print(idx, inputs.shape, targets[0]) + print(idx, inputs.shape, targets[0][:100]) break \ No newline at end of file diff --git a/evaluate.py b/evaluate.py index 4900f5d..9d4ad39 100755 --- a/evaluate.py +++ b/evaluate.py @@ -328,9 +328,7 @@ def func(item): mean_scores = {k: np.mean(v) for k, v in scores.items() if k != "F1 by program"} - for key in sorted(list(mean_scores)): - print("{}: {:.4}".format(key, mean_scores[key])) - + if enable_instrument_eval: print("====") program_f1_dict = {} @@ -362,6 +360,8 @@ def func(item): print("{}: {:.4}".format(d[key], program_f1_dict[key])) elif key * 8 in program_f1_dict: print("{}: {:.4}".format(d[key], program_f1_dict[key * 8])) + + return mean_scores if __name__ == "__main__": diff --git a/inference.py b/inference.py index 670ca61..399fd47 100755 --- a/inference.py +++ b/inference.py @@ -27,7 +27,8 @@ def __init__( weight_path=None, device=torch.device('cuda'), mel_norm=True, - contiguous_inference=False + contiguous_inference=False, + use_tf_spectral_ops=False, ) -> None: if model is None: # config_path = f'{root_path}/config.json' @@ -54,6 +55,7 @@ def __init__( self.SAMPLE_RATE = 16000 self.spectrogram_config = spectrograms.SpectrogramConfig() + self.spectrogram_config.use_tf_spectral_ops = use_tf_spectral_ops self.codec = vocabularies.build_codec(vocab_config=vocabularies.VocabularyConfig( num_velocity_bins=1)) self.vocab = vocabularies.vocabulary_from_codec(self.codec) @@ -158,6 +160,7 @@ def inference( num_beams=1, batch_size=5, max_length=1024, + verbose=False, ): """ `contiguous_inference` is True only for XL models as context from previous chunk is needed. @@ -178,7 +181,6 @@ def inference( inputs_tensor = torch.cat(inputs_tensor, dim=0) frame_times = [torch.tensor(k) for k in frame_times] frame_times = torch.cat(frame_times, dim=0) - print('inputs_tensor', inputs_tensor.shape, frame_times.shape) inputs_tensor = [inputs_tensor] frame_times = [frame_times] diff --git a/test.py b/test.py index f711fa5..d1fdec5 100755 --- a/test.py +++ b/test.py @@ -1,8 +1,8 @@ import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' -from inference import InferenceHandler import torch +from inference import InferenceHandler import glob import os from tqdm import tqdm @@ -12,73 +12,123 @@ from evaluate import evaluate_main -@hydra.main(config_path="config", config_name="config") -def main(cfg): - # convert .ckpt to .pth - assert cfg.path - assert cfg.path.endswith(".pt") or cfg.path.endswith("pth"), "Only .pt / .pth files are supported." - assert cfg.eval.exp_tag_name - assert cfg.eval.audio_dir - - pl = hydra.utils.instantiate(cfg.model, optim_cfg=cfg.optim) - model = pl.model - print(f"Loading weights from: {cfg.path}") - model.load_state_dict(torch.load(cfg.path)) - model.eval() - - dir = sorted(glob.glob(cfg.eval.audio_dir)) - if cfg.eval.eval_dataset == "NSynth": - # NOTE: skip vocals and mallets. Both Slakh and ComMU dataset does not have vocals, and ComMU does not have mallets. - dir = [d for d in dir if "vocal" not in d and "mallet" not in d] - if cfg.eval.is_sanity_check: - dir = dir[:100] - - mel_norm = False if "pretrained/mt3.pth" in cfg.path else True +def get_scores( + model, + eval_audio_dir=None, + mel_norm=True, + eval_dataset="Slakh", + exp_tag_name="test_midis", + ground_truth_midi_dir=None, + verbose=True, + contiguous_inference=False, + use_tf_spectral_ops=False, + batch_size=8, + max_length=1024 +): handler = InferenceHandler( model=model, device=torch.device('cuda'), mel_norm=mel_norm, - contiguous_inference=False + contiguous_inference=contiguous_inference, + use_tf_spectral_ops=use_tf_spectral_ops ) def func(fname): audio, _ = librosa.load(fname, sr=16000) - if cfg.eval.eval_dataset == "NSynth": + if eval_dataset == "NSynth": audio = np.pad(audio, (int(0.05 * 16000), 0), "constant", constant_values=0) return audio - print("Total songs:", len(dir)) - - exp_tag_name = cfg.eval.exp_tag_name + if verbose: + print("Total songs:", len(eval_audio_dir)) - for fname in tqdm(dir): + for fname in tqdm(eval_audio_dir): audio = func(fname) - if cfg.eval.eval_dataset == "Slakh": + if eval_dataset == "Slakh": name = fname.split("/")[-2] outpath = os.path.join(exp_tag_name, name, "mix.mid") - elif cfg.eval.eval_dataset == "ComMU" or cfg.eval.eval_dataset == "NSynth": + elif eval_dataset == "ComMU" or eval_dataset == "NSynth": name = fname.split("/")[-1] outpath = os.path.join(exp_tag_name, name.replace(".wav", ".mid")) else: raise ValueError("Invalid dataset name.") handler.inference( - audio, - fname, + audio=audio, + audio_path=fname, outpath=outpath, - batch_size=8, # changing this might affect results of sequential-inference models (e.g. XL). - max_length=256 + batch_size=batch_size, + max_length=max_length, + verbose=verbose ) - print("Evaluating...") + if verbose: + print("Evaluating...") current_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir - ground_truth_midi_dir = cfg.eval.midi_dir if cfg.eval.midi_dir else cfg.dataset.test.root_dir - evaluate_main( - dataset_name=cfg.eval.eval_dataset, - test_midi_dir=os.path.join(current_dir, cfg.eval.exp_tag_name), + scores = evaluate_main( + dataset_name=eval_dataset, + test_midi_dir=os.path.join(current_dir, exp_tag_name), + ground_truth_midi_dir=ground_truth_midi_dir, + ) + + if verbose: + for key in sorted(list(scores)): + print("{}: {:.4}".format(key, scores[key])) + + return scores + + +@hydra.main(config_path="config", config_name="config", version_base="1.1") +def main(cfg): + assert cfg.path + assert cfg.path.endswith(".pt") or \ + cfg.path.endswith("pth") or \ + cfg.path.endswith("ckpt"), "Only .pt, .pth, .ckpt files are supported." + assert cfg.eval.exp_tag_name + assert cfg.eval.audio_dir + + pl = hydra.utils.instantiate(cfg.model, optim_cfg=cfg.optim) + print(f"Loading weights from: {cfg.path}") + if cfg.path.endswith(".ckpt"): + # load lightning module from checkpoint + pl = pl.load_from_checkpoint( + cfg.path, + config=cfg.model.config, + optim_cfg=cfg.optim, + ) + model = pl.model + else: + # load weights for nn.Module + model = pl.model + # print(model) + model.load_state_dict(torch.load(cfg.path)) + + model.eval() + # if torch.cuda.is_available(): + # model.cuda() + + dir = sorted(glob.glob(cfg.eval.audio_dir)) + if cfg.eval.eval_dataset == "NSynth": + # NOTE: skip vocals and mallets. Both Slakh and ComMU dataset does not have vocals, and ComMU does not have mallets. + dir = [d for d in dir if "vocal" not in d and "mallet" not in d] + if cfg.eval.eval_first_n_examples: + dir = dir[:cfg.eval.eval_first_n_examples] + + mel_norm = False if "pretrained/mt3.pth" in cfg.path else True + ground_truth_midi_dir = cfg.eval.midi_dir if cfg.eval.midi_dir else cfg.dataset.test.root_dir + + get_scores( + model, + eval_audio_dir=dir, + mel_norm=mel_norm, + eval_dataset=cfg.eval.eval_dataset, + exp_tag_name=cfg.eval.exp_tag_name, ground_truth_midi_dir=ground_truth_midi_dir, + contiguous_inference=cfg.eval.contiguous_inference, + use_tf_spectral_ops=cfg.eval.use_tf_spectral_ops, + batch_size=cfg.eval.batch_size, )