Skip to content

Commit

Permalink
Update notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
thequilo committed May 16, 2022
1 parent 0bdec0b commit 5568e2a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 29 deletions.
18 changes: 11 additions & 7 deletions mms_msg/visualization/plot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import contextlib
from collections import defaultdict

import paderbox as pb


def plot_meeting(ex):
with pb.visualization.axes_context(columns=1, figure_size=(10, 3)) as ac:
speech_activity = defaultdict(pb.array.interval.zeros)
num_samples = pb.utils.nested.get_by_path(ex, 'num_samples.original_source', allow_early_stopping=True)
for o, l, s, in zip(ex['offset']['original_source'], num_samples, ex['speaker_id']):
speech_activity[s][o:o + l] = True
def plot_mixture(ex, ax=None):
if ax is None:
from matplotlib import pyplot as plt
_, ax = plt.subplots(1, 1)
speech_activity = defaultdict(pb.array.interval.zeros)
num_samples = pb.utils.nested.get_by_path(ex, 'num_samples.original_source', allow_early_stopping=True)
for o, l, s, in zip(ex['offset']['original_source'], num_samples, ex['speaker_id']):
speech_activity[s][o:o + l] = True

pb.visualization.plot.activity(speech_activity, ax=ac.new)
pb.visualization.plot.activity(speech_activity, ax=ax)
ax.axvline(ex['num_samples']['observation'])
4 changes: 3 additions & 1 deletion notebooks/extending_mms_msg.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@
},
"outputs": [],
"source": [
"mms_msg.visualization.plot.plot_meeting(ds.map(SequentialOffsetSampler())[0])"
"import paderbox as pb\n",
"with pb.visualization.figure_context():\n",
" mms_msg.visualization.plot.plot_mixture(ds.map(SequentialOffsetSampler())[0])"
]
},
{
Expand Down
50 changes: 29 additions & 21 deletions notebooks/mixture_generator.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,20 @@
"source": [
"from collections import defaultdict\n",
"import itertools\n",
"from mms_msg.visualization.plot import plot_meeting\n",
"from mms_msg.visualization.plot import plot_mixture\n",
"from mms_msg import keys\n",
"from mms_msg.simulation.utils import load_audio\n",
" \n",
"def plot_meetings(generator_dataset, number=6, columns=3, figure_width=10):\n",
"def plot_mixtures(generator_dataset, number=6, columns=3, figure_width=10):\n",
" with pb.visualization.axes_context(columns=columns, figure_size=(figure_width, 3)) as ac:\n",
" for ex in itertools.islice(generator_dataset, number):\n",
" activity = defaultdict(pb.array.interval.zeros)\n",
" num_samples = pb.utils.nested.get_by_path(ex, 'num_samples.original_source', allow_early_stopping=True)\n",
" for o, l, s in zip(ex['offset']['original_source'], num_samples, ex['speaker_id']):\n",
" activity[s][o:o+l] = True\n",
"\n",
" pb.visualization.plot.activity(activity, ax=ac.new)"
" plot_mixture(ex, ac.new)\n",
" # activity = defaultdict(pb.array.interval.zeros)\n",
" # num_samples = pb.utils.nested.get_by_path(ex, 'num_samples.original_source', allow_early_stopping=True)\n",
" # for o, l, s in zip(ex['offset']['original_source'], num_samples, ex['speaker_id']):\n",
" # activity[s][o:o+l] = True\n",
" #\n",
" # pb.visualization.plot.activity(activity, ax=ac.new)"
]
},
{
Expand Down Expand Up @@ -136,6 +137,9 @@
"# If required: Add log_weights to simulate volume differences\n",
"ds = ds.map(sampling.environment.scaling.UniformScalingSampler(max_weight=5))\n",
"\n",
"# If required: Truncate to the shorter utterance\n",
"ds = ds.map(mms_msg.simulation.truncation.truncate_min)\n",
"\n",
"len(ds), ds[0]"
]
},
Expand All @@ -146,7 +150,7 @@
"metadata": {},
"outputs": [],
"source": [
"plot_meetings(ds)"
"plot_mixtures(ds)"
]
},
{
Expand Down Expand Up @@ -180,7 +184,7 @@
"metadata": {},
"outputs": [],
"source": [
"plot_meetings(ds)"
"plot_mixtures(ds)"
]
},
{
Expand All @@ -191,7 +195,6 @@
"outputs": [],
"source": [
"# Load an example\n",
"import functools\n",
"from mms_msg import keys\n",
"ds = ds\\\n",
" .map(lambda example: load_audio(example, keys.ORIGINAL_SOURCE, keys.RIR))\\\n",
Expand Down Expand Up @@ -236,10 +239,7 @@
"source": [
"# Check that iterating two times gives different examples\n",
"for _ in range(2):\n",
" for e in ds:\n",
" print(e)\n",
" print()\n",
" break"
" plot_mixtures(ds, number=3)"
]
},
{
Expand Down Expand Up @@ -279,7 +279,7 @@
"metadata": {},
"outputs": [],
"source": [
"plot_meetings(ds, columns=2, figure_width=20, number=6)"
"plot_mixtures(ds, columns=2, figure_width=20, number=6)"
]
},
{
Expand Down Expand Up @@ -342,7 +342,7 @@
"metadata": {},
"outputs": [],
"source": [
"db.get_dataset('test_clean')[0]"
"plot_mixtures(db.get_dataset('test_clean'), number=3)"
]
},
{
Expand All @@ -356,22 +356,30 @@
},
"outputs": [],
"source": [
"# Dynamic mixing can be enabled by appending \"_rng\" (for a random seed) or \"_rng<seed>\" (for a fixed seed) to the dataset name\n",
"next(iter(db.get_dataset('train_clean_100_rng')))"
"# Dynamic mixing can be enabled by appending \"_rng\" (for a random seed) or \"_rng<seed>\" (for a fixed seed) to the dataset name.\n",
"# The top two potted rows are different because the seed is random by default\n",
"# The bottom two plotted rows are equal because the seed is fixed to 42\n",
"plot_mixtures(db.get_dataset('train_clean_100_rng'), number=3)\n",
"plot_mixtures(db.get_dataset('train_clean_100_rng'), number=3)\n",
"plot_mixtures(db.get_dataset('train_clean_100_rng42'), number=3)\n",
"plot_mixtures(db.get_dataset('train_clean_100_rng42'), number=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "63fa20cf",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"next(iter(db.get_dataset('train_clean_100_rng42')))"
"# Audio can be loaded with the `load_example` method of the database object\n",
"ex = db.load_example(db.get_dataset('test_clean')[0])\n",
"pb.io.play(ex['audio_data']['observation'], name='observation')\n",
"pb.io.play(ex['audio_data']['speech_image'][0], name='speech_image 1')\n",
"pb.io.play(ex['audio_data']['speech_image'][1], name='speech_image 2')"
]
},
{
Expand Down

0 comments on commit 5568e2a

Please sign in to comment.