From 5568e2adaaca770b3146587f298bd8325380105b Mon Sep 17 00:00:00 2001 From: Thilo von Neumann Date: Mon, 16 May 2022 11:30:28 +0200 Subject: [PATCH] Update notebooks --- mms_msg/visualization/plot.py | 18 ++++++----- notebooks/extending_mms_msg.ipynb | 4 ++- notebooks/mixture_generator.ipynb | 50 ++++++++++++++++++------------- 3 files changed, 43 insertions(+), 29 deletions(-) diff --git a/mms_msg/visualization/plot.py b/mms_msg/visualization/plot.py index 664c27a..007efde 100644 --- a/mms_msg/visualization/plot.py +++ b/mms_msg/visualization/plot.py @@ -1,13 +1,17 @@ +import contextlib from collections import defaultdict import paderbox as pb -def plot_meeting(ex): - with pb.visualization.axes_context(columns=1, figure_size=(10, 3)) as ac: - speech_activity = defaultdict(pb.array.interval.zeros) - num_samples = pb.utils.nested.get_by_path(ex, 'num_samples.original_source', allow_early_stopping=True) - for o, l, s, in zip(ex['offset']['original_source'], num_samples, ex['speaker_id']): - speech_activity[s][o:o + l] = True +def plot_mixture(ex, ax=None): + if ax is None: + from matplotlib import pyplot as plt + _, ax = plt.subplots(1, 1) + speech_activity = defaultdict(pb.array.interval.zeros) + num_samples = pb.utils.nested.get_by_path(ex, 'num_samples.original_source', allow_early_stopping=True) + for o, l, s, in zip(ex['offset']['original_source'], num_samples, ex['speaker_id']): + speech_activity[s][o:o + l] = True - pb.visualization.plot.activity(speech_activity, ax=ac.new) + pb.visualization.plot.activity(speech_activity, ax=ax) + ax.axvline(ex['num_samples']['observation']) diff --git a/notebooks/extending_mms_msg.ipynb b/notebooks/extending_mms_msg.ipynb index fccbd4b..4c5f490 100644 --- a/notebooks/extending_mms_msg.ipynb +++ b/notebooks/extending_mms_msg.ipynb @@ -146,7 +146,9 @@ }, "outputs": [], "source": [ - "mms_msg.visualization.plot.plot_meeting(ds.map(SequentialOffsetSampler())[0])" + "import paderbox as pb\n", + "with pb.visualization.figure_context():\n", + " mms_msg.visualization.plot.plot_mixture(ds.map(SequentialOffsetSampler())[0])" ] }, { diff --git a/notebooks/mixture_generator.ipynb b/notebooks/mixture_generator.ipynb index 7cbab90..c1ae192 100644 --- a/notebooks/mixture_generator.ipynb +++ b/notebooks/mixture_generator.ipynb @@ -40,19 +40,20 @@ "source": [ "from collections import defaultdict\n", "import itertools\n", - "from mms_msg.visualization.plot import plot_meeting\n", + "from mms_msg.visualization.plot import plot_mixture\n", "from mms_msg import keys\n", "from mms_msg.simulation.utils import load_audio\n", " \n", - "def plot_meetings(generator_dataset, number=6, columns=3, figure_width=10):\n", + "def plot_mixtures(generator_dataset, number=6, columns=3, figure_width=10):\n", " with pb.visualization.axes_context(columns=columns, figure_size=(figure_width, 3)) as ac:\n", " for ex in itertools.islice(generator_dataset, number):\n", - " activity = defaultdict(pb.array.interval.zeros)\n", - " num_samples = pb.utils.nested.get_by_path(ex, 'num_samples.original_source', allow_early_stopping=True)\n", - " for o, l, s in zip(ex['offset']['original_source'], num_samples, ex['speaker_id']):\n", - " activity[s][o:o+l] = True\n", - "\n", - " pb.visualization.plot.activity(activity, ax=ac.new)" + " plot_mixture(ex, ac.new)\n", + " # activity = defaultdict(pb.array.interval.zeros)\n", + " # num_samples = pb.utils.nested.get_by_path(ex, 'num_samples.original_source', allow_early_stopping=True)\n", + " # for o, l, s in zip(ex['offset']['original_source'], num_samples, ex['speaker_id']):\n", + " # activity[s][o:o+l] = True\n", + " #\n", + " # pb.visualization.plot.activity(activity, ax=ac.new)" ] }, { @@ -136,6 +137,9 @@ "# If required: Add log_weights to simulate volume differences\n", "ds = ds.map(sampling.environment.scaling.UniformScalingSampler(max_weight=5))\n", "\n", + "# If required: Truncate to the shorter utterance\n", + "ds = ds.map(mms_msg.simulation.truncation.truncate_min)\n", + "\n", "len(ds), ds[0]" ] }, @@ -146,7 +150,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot_meetings(ds)" + "plot_mixtures(ds)" ] }, { @@ -180,7 +184,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot_meetings(ds)" + "plot_mixtures(ds)" ] }, { @@ -191,7 +195,6 @@ "outputs": [], "source": [ "# Load an example\n", - "import functools\n", "from mms_msg import keys\n", "ds = ds\\\n", " .map(lambda example: load_audio(example, keys.ORIGINAL_SOURCE, keys.RIR))\\\n", @@ -236,10 +239,7 @@ "source": [ "# Check that iterating two times gives different examples\n", "for _ in range(2):\n", - " for e in ds:\n", - " print(e)\n", - " print()\n", - " break" + " plot_mixtures(ds, number=3)" ] }, { @@ -279,7 +279,7 @@ "metadata": {}, "outputs": [], "source": [ - "plot_meetings(ds, columns=2, figure_width=20, number=6)" + "plot_mixtures(ds, columns=2, figure_width=20, number=6)" ] }, { @@ -342,7 +342,7 @@ "metadata": {}, "outputs": [], "source": [ - "db.get_dataset('test_clean')[0]" + "plot_mixtures(db.get_dataset('test_clean'), number=3)" ] }, { @@ -356,14 +356,18 @@ }, "outputs": [], "source": [ - "# Dynamic mixing can be enabled by appending \"_rng\" (for a random seed) or \"_rng\" (for a fixed seed) to the dataset name\n", - "next(iter(db.get_dataset('train_clean_100_rng')))" + "# Dynamic mixing can be enabled by appending \"_rng\" (for a random seed) or \"_rng\" (for a fixed seed) to the dataset name.\n", + "# The top two potted rows are different because the seed is random by default\n", + "# The bottom two plotted rows are equal because the seed is fixed to 42\n", + "plot_mixtures(db.get_dataset('train_clean_100_rng'), number=3)\n", + "plot_mixtures(db.get_dataset('train_clean_100_rng'), number=3)\n", + "plot_mixtures(db.get_dataset('train_clean_100_rng42'), number=3)\n", + "plot_mixtures(db.get_dataset('train_clean_100_rng42'), number=3)" ] }, { "cell_type": "code", "execution_count": null, - "id": "63fa20cf", "metadata": { "pycharm": { "name": "#%%\n" @@ -371,7 +375,11 @@ }, "outputs": [], "source": [ - "next(iter(db.get_dataset('train_clean_100_rng42')))" + "# Audio can be loaded with the `load_example` method of the database object\n", + "ex = db.load_example(db.get_dataset('test_clean')[0])\n", + "pb.io.play(ex['audio_data']['observation'], name='observation')\n", + "pb.io.play(ex['audio_data']['speech_image'][0], name='speech_image 1')\n", + "pb.io.play(ex['audio_data']['speech_image'][1], name='speech_image 2')" ] }, {