Skip to content

Commit b56db98

Browse files
committed
Added separation with beamformers that use true source locations. Also did some checks on rank of STFT tensor with no luck - no low-rank tensors found
1 parent 3b35cb9 commit b56db98

10 files changed

Lines changed: 553 additions & 73 deletions

File tree

MatlabUtils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def find_engine():
2424

2525
engine.addpath('amsbss/ILRMA')
2626
engine.addpath('amsbss/AuxIVA')
27+
engine.addpath('amsbss/AuxIVA/STFT')
2728
return engine
2829

2930

Model.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from RecorderClass import Recorder
77
import threading
88
from Player import play
9+
import matplotlib.pyplot as plt
910

1011

1112
def mix(s_input: np.ndarray, sim: dict, data_set: dict):
@@ -54,19 +55,8 @@ def mix_linear(S: np.ndarray, sim: dict) -> Tuple[np.ndarray, np.ndarray, dict]:
5455
return filtered, mixed, {'mixing_matrix': A}
5556

5657

57-
def hexagonal_points(d: float) -> np.ndarray:
58-
return d * np.array([[-1, 0, 0],
59-
[-1 / 2, 3 ** 0.5 / 2, 0],
60-
[-1 / 2, -3 ** 0.5 / 2, 0],
61-
[0, 0, 0],
62-
[1 / 2, 3 ** 0.5 / 2, 0],
63-
[1 / 2, -3 ** 0.5 / 2, 0],
64-
[1, 0, 0]]).T
65-
66-
6758
def mix_convolutive(S: np.array, sim: dict, data_set: dict) -> Tuple[np.ndarray, np.ndarray, dict]:
6859
# Get parameters
69-
import matplotlib.pyplot as plt
7060
opts = sim['env_options']
7161
N = S.shape[0] # number of sources
7262
M = sim['microphones'] if 'microphones' in sim else N # number of microphones
@@ -84,26 +74,19 @@ def mix_convolutive(S: np.array, sim: dict, data_set: dict) -> Tuple[np.ndarray,
8474
max_order=max_order,
8575
sigma2_awgn=opts['sigma2_awgn'])
8676
# Microphone locations for hexagonal array
87-
array_loc = np.array([[3], [2], [0.5]])
88-
micro_locs = hexagonal_points(sim['microphones_distance'])
89-
micro_locs += array_loc
77+
micro_locs = opts['micro_locations']
9078

9179
# Check that required number of microphones has it's locations
92-
if micro_locs.shape[0] < M:
80+
if micro_locs.shape[1] < M:
9381
raise ValueError('{} microphones required, but only {} microphone locations specified'
9482
.format(M, micro_locs.shape[0]))
9583

9684
# Select as much microphones as needed
9785
R = micro_locs[:, :M]
9886

9987
room.add_microphone_array(pra.MicrophoneArray(R, room.fs))
100-
# room.add_microphone_array(pra.Beamformer(R, room.fs))
10188
# Place the sources inside the room
102-
source_locs = np.array([
103-
[3., 3, 0.85], # source 1
104-
[3., 1, 0.85], # source 2
105-
[5., 2, 0.85], # source 3
106-
])
89+
source_locs = opts['source_locations']
10790

10891
# Check that required number of microphones has it's locations
10992
if source_locs.shape[0] < N:
@@ -116,10 +99,10 @@ def mix_convolutive(S: np.array, sim: dict, data_set: dict) -> Tuple[np.ndarray,
11699
room.add_source(loc, signal=np.zeros_like(sig))
117100
# Make separate recordings
118101

119-
#room.plot_rir()
120-
#fig = plt.gcf()
121-
#fig.set_size_inches(9, 6)
122-
#plt.show()
102+
# room.plot_rir()
103+
# fig = plt.gcf()
104+
# fig.set_size_inches(9, 6)
105+
# plt.show()
123106

124107
filtered = []
125108
for source, s in zip(room.sources, S):

Normalizer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55

66
def form_source_matrix(S_input: list) -> np.ndarray:
77
S = copy.deepcopy(S_input)
8-
l_max = max(len(s) for s in S)
8+
l_min = min(len(s) for s in S)
99

1010
for s in S:
11-
s.resize((1, l_max), refcheck=False)
11+
s.resize((1, l_min), refcheck=False)
1212

1313
return np.vstack(S)
1414

@@ -25,7 +25,8 @@ def normalize(s: np.ndarray) -> np.ndarray:
2525
# span = np.max(s) - np.min(s)
2626
# span = 1 if span == 0 else span # safety check to avoid division by zero
2727
# s = ((s - np.min(s)) * 2) / span - 1
28-
return s / np.max(np.abs(s))
28+
max_v = np.max(np.abs(s))
29+
return s / np.max(np.abs(s)) if max_v > 0 else s
2930

3031

3132
def normalize_old(S: np.ndarray) -> np.ndarray:

0 commit comments

Comments
 (0)