Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ path:
event_length: 1024
mel_length: 256
num_rows_per_batch: 8
split_frame_length: 2000

optim:
lr: 6e-5
Expand Down Expand Up @@ -46,11 +47,13 @@ trainer:
devices: ${devices}

eval:
is_sanity_check: False
eval_dataset:
exp_tag_name:
audio_dir:
midi_dir:
eval_dataset: "Slakh"
eval_first_n_examples: 3
eval_after_num_epoch: 400
eval_per_epoch: 1
exp_tag_name: "test_midis"
audio_dir: "/data/slakh2100_flac_redux/validation/*/mix_16k.wav"
midi_dir: "/data/slakh2100_flac_redux/validation/"

defaults:
- model: MT3Net
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,24 @@ path:
event_length: 1024
mel_length: 256
num_rows_per_batch: 12
split_frame_length: 2000

optim:
lr: 2e-4
warmup_steps: 64500
num_epochs: ${num_epochs}
num_steps_per_epoch: 1289 # TODO: this is not good practice. Ideally we can get this from dataloader.
min_lr: 1e-4

grad_accum: 1

dataloader:
train:
batch_size: 1
num_workers: 12
num_workers: 4
val:
batch_size: 1
num_workers: 12
num_workers: 0

modelcheckpoint:
monitor: 'val_loss'
Expand All @@ -44,7 +46,15 @@ trainer:
strategy: "ddp_find_unused_parameters_false"
devices: ${devices}

eval:
eval_dataset: "Slakh"
eval_first_n_examples: 1
eval_after_num_epoch: 400
eval_per_epoch: 1
exp_tag_name: "test_midis"
audio_dir: "/data/slakh2100_flac_redux/validation/*/mix_16k.wav"
midi_dir: "/data/slakh2100_flac_redux/validation/"

defaults:
- model: MT3Net
- dataset: Slakh
# TODO: we need to specify num_samples_per_batch here from 8 to 12
- dataset: Slakh
2 changes: 2 additions & 0 deletions config/dataset/Slakh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ train:
inst_filename: inst_names.json
audio_filename: mix_16k.wav
num_rows_per_batch: ${num_rows_per_batch}
split_frame_length: ${split_frame_length}
val:
_target_: dataset.dataset_2_random.SlakhDataset # choosing which data class to use
root_dir: "/data2/kinwai/slakh2100_flac_redux/validation/"
Expand All @@ -16,6 +17,7 @@ val:
inst_filename: inst_names.json
audio_filename: mix_16k.wav
num_rows_per_batch: ${num_rows_per_batch}
split_frame_length: ${split_frame_length}
test:
root_dir: "/data/slakh2100_flac_redux/test"
collate_fn: dataset.dataset_2_random.collate_fn
2 changes: 2 additions & 0 deletions contrib/spectrograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

from ddsp import spectral_ops
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')


# defaults for spectrogram config
DEFAULT_SAMPLE_RATE = 16000
Expand Down
1 change: 1 addition & 0 deletions contrib/vocabularies.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import seqio
import t5.data
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')


DECODED_EOS_ID = -1
Expand Down
38 changes: 16 additions & 22 deletions dataset/dataset_2_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import librosa
import note_seq
from glob import glob
from contrib import event_codec, note_sequences, spectrograms, vocabularies, run_length_encoding, metrics_utils
from contrib import spectrograms, vocabularies, note_sequences, run_length_encoding, metrics_utils, event_codec
from contrib.preprocessor import slakh_class_to_program_and_is_drum, add_track_to_notesequence, PitchBendError
import soundfile as sf

Expand All @@ -33,7 +33,8 @@ def __init__(
midi_folder='MIDI',
inst_filename='inst_names.json',
shuffle=True,
num_rows_per_batch=8
num_rows_per_batch=8,
split_frame_length=2000
) -> None:
super().__init__()
self.spectrogram_config = spectrograms.SpectrogramConfig()
Expand All @@ -52,6 +53,7 @@ def __init__(
self.onsets_only = onsets_only
self.tie_token = self.codec.encode_event(event_codec.Event('tie', 0)) if self.include_ties else None
self.num_rows_per_batch = num_rows_per_batch
self.split_frame_length = split_frame_length

def _build_dataset(self, root_dir, shuffle=True):
df = []
Expand Down Expand Up @@ -79,7 +81,6 @@ def _audio_to_frames(
[0, frame_size - len(samples) % frame_size],
mode='constant')


frames = spectrograms.split_audio(samples, self.spectrogram_config)

num_frames = len(samples) // frame_size
Expand Down Expand Up @@ -365,17 +366,20 @@ def __getitem__(self, idx):
if sr != self.spectrogram_config.sample_rate:
audio = librosa.resample(audio, orig_sr=sr, target_sr=self.spectrogram_config.sample_rate)
row = self._tokenize(ns, audio, inst_names)

# NOTE: by default, this is self._split_frame(row, length=2000)
# this does not guarantee the chunks in `rows` to be contiguous.
# if we need to ensure that the chunks in `rows` to be contiguous, use:
# rows = self._split_frame(row, length=self.mel_length)
rows = self._split_frame(row)

# NOTE: if split_frame_length != mel_length, then it does not guarantee the chunks in `rows` to be contiguous.
rows = self._split_frame(row, length=self.split_frame_length)

inputs, targets, frame_times, num_insts = [], [], [], []
if len(rows) > self.num_rows_per_batch:
start_idx = random.randint(0, len(rows) - self.num_rows_per_batch)
rows = rows[start_idx : start_idx + self.num_rows_per_batch]

# randomly select N rows for training, as the GPU memory is limited
if self.num_rows_per_batch == -1:
# choose all rows in this case
pass
else:
if len(rows) > self.num_rows_per_batch:
start_idx = random.randint(0, len(rows) - self.num_rows_per_batch)
rows = rows[start_idx : start_idx + self.num_rows_per_batch]

predictions = []
# wavs = []
Expand Down Expand Up @@ -520,16 +524,6 @@ def token_to_idx(self, token_str):
def collate_fn(lst):
inputs = torch.cat([k[0] for k in lst])
targets = torch.cat([k[1] for k in lst])
# num_insts = torch.cat([k[2] for k in lst])

# add random shuffling here
# indices = np.arange(inputs.shape[0])
# np.random.shuffle(indices)
# indices = torch.from_numpy(indices)
# inputs = inputs[indices]
# targets = targets[indices]
# num_insts = num_insts[indices]

return inputs, targets

if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,6 @@ def func(item):


mean_scores = {k: np.mean(v) for k, v in scores.items() if k != "F1 by program"}
for key in sorted(list(mean_scores)):
print("{}: {:.4}".format(key, mean_scores[key]))

if enable_instrument_eval:
print("====")
Expand Down Expand Up @@ -362,6 +360,8 @@ def func(item):
print("{}: {:.4}".format(d[key], program_f1_dict[key]))
elif key * 8 in program_f1_dict:
print("{}: {:.4}".format(d[key], program_f1_dict[key * 8]))

return mean_scores


if __name__ == "__main__":
Expand Down
12 changes: 3 additions & 9 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import os
from typing import List

import torch.nn as nn
import torch

import numpy as np
from tqdm import tqdm
from models.t5 import T5ForConditionalGeneration, T5Config
from models.t5_xl import T5WithXLDecoder, T5Config
# from models.t5_xl_instrument import T5WithXLDecoder, T5Config
import torch.nn as nn
import torch
from contrib import spectrograms, vocabularies, note_sequences, metrics_utils
import note_seq
import traceback
Expand Down Expand Up @@ -58,7 +59,6 @@ def __init__(
num_velocity_bins=1))
self.vocab = vocabularies.vocabulary_from_codec(self.codec)
self.device = device
self.model.to(self.device)

self.mel_norm = mel_norm
# if "pretrained/mt3.pth" in weight_path:
Expand All @@ -72,7 +72,6 @@ def _audio_to_frames(self, audio):
frame_size = spectrogram_config.hop_width
padding = [0, frame_size - len(audio) % frame_size]
audio = np.pad(audio, padding, mode='constant')
print('audio', audio.shape, 'frame_size', frame_size)
frames = spectrograms.split_audio(audio, spectrogram_config)
num_frames = len(audio) // frame_size
times = np.arange(num_frames) / \
Expand Down Expand Up @@ -167,22 +166,18 @@ def inference(
invalid_programs = self._get_program_ids(valid_programs)
else:
invalid_programs = None
# print('preprocessing', audio_path)
inputs, frame_times = self._preprocess(audio)
inputs_tensor = torch.from_numpy(inputs)
results = []
inputs_tensor, frame_times = self._batching(inputs_tensor, frame_times, batch_size=batch_size)
print('inferencing', audio_path)

if self.contiguous_inference:
inputs_tensor = torch.cat(inputs_tensor, dim=0)
frame_times = [torch.tensor(k) for k in frame_times]
frame_times = torch.cat(frame_times, dim=0)
print('inputs_tensor', inputs_tensor.shape, frame_times.shape)
inputs_tensor = [inputs_tensor]
frame_times = [frame_times]

self.model.cuda()
for idx, batch in enumerate(inputs_tensor):
batch = batch.to(self.device)

Expand All @@ -200,7 +195,6 @@ def inference(
filename = audio_path.split('/')[-1].split('.')[0]
outpath = f'./out/{filename}.mid'
os.makedirs('/'.join(outpath.split('/')[:-1]), exist_ok=True)
print("saving", outpath)
note_seq.sequence_proto_to_midi_file(event, outpath)

except Exception as e:
Expand Down
4 changes: 3 additions & 1 deletion midi_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
Use this for evaluation!!!
"""

midis = sorted(glob.glob("/workspace/data/dataset/slakh2100_flac_redux/test/*/MIDI/"))
# NOTE: change the following path accordingly
midis = sorted(glob.glob("/data/slakh2100_flac_redux/validation/*/MIDI/"))
# midis = sorted(glob.glob("/workspace/data/dataset/slakh2100_flac_redux/test/*/MIDI/"))
for midi in tqdm(midis):
stems = sorted(glob.glob(midi + "*.mid"))
insts = []
Expand Down
48 changes: 43 additions & 5 deletions tasks/mt3_net.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,37 @@
from torch.optim import AdamW
from omegaconf import OmegaConf
import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_only

import torch
# NOTE: On Linux, Python's multiprocessing module default start method is "fork".
# we need to set the start method to "spawn"
# in order to properly use TF methods in PyTorch dataloader.
try:
torch.multiprocessing.set_start_method('spawn')
except RuntimeError:
pass

import torch.nn as nn
from transformers import T5Config

import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')


from models.t5 import T5ForConditionalGeneration
from utils import get_cosine_schedule_with_warmup
from test import get_scores
import glob


class MT3Net(pl.LightningModule):

def __init__(self, config, optim_cfg):
def __init__(self, config, optim_cfg, eval_cfg=None):
super().__init__()
self.config = config
self.optim_cfg = optim_cfg
self.eval_cfg = eval_cfg
T5config = T5Config.from_dict(OmegaConf.to_container(self.config))
self.model: nn.Module = T5ForConditionalGeneration(T5config)

Expand Down Expand Up @@ -43,9 +62,28 @@ def validation_step(self, batch, batch_idx):
)
self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)

# no need to use it in this stage
# return loss

@rank_zero_only
def validation_epoch_end(self, outs):
if self.current_epoch >= self.eval_cfg.eval_after_num_epoch:
if self.current_epoch % self.eval_cfg.eval_per_epoch == 0:
eval_audio_dir = sorted(glob.glob(self.eval_cfg.audio_dir))
if self.eval_cfg.eval_first_n_examples:
eval_audio_dir = eval_audio_dir[:self.eval_cfg.eval_first_n_examples]

self.model.eval()
scores = get_scores(
model=self.model,
eval_audio_dir=eval_audio_dir,
eval_dataset="Slakh",
ground_truth_midi_dir=self.eval_cfg.midi_dir,
verbose=False
)
print(scores)

self.log('val_f1_flat', scores['Onset F1'], on_step=False, on_epoch=True)
self.log('val_f1_midi_class', scores['Onset + program F1 (midi_class)'], on_step=False, on_epoch=True)
self.log('val_f1_full', scores['Onset + program F1 (full)'], on_step=False, on_epoch=True)

def configure_optimizers(self):
optimizer = AdamW(self.model.parameters(), self.optim_cfg.lr)
warmup_step = int(self.optim_cfg.warmup_steps)
Expand All @@ -69,7 +107,7 @@ def configure_optimizers(self):

class MT3NetWeightedLoss(pl.LightningModule):

def __init__(self, config, optim_cfg):
def __init__(self, config, optim_cfg, eval_cfg=None):
super().__init__()
self.config = config
self.optim_cfg = optim_cfg
Expand Down
Loading