@@ -74,10 +74,10 @@ def get_available_model_names(class_name):
 
 class ClusteringDiarizer(torch.nn.Module, Model, DiarizationMixin):
     """
-    Inference model Class for offline speaker diarization.
-    This class handles required functionality for diarization : Speech Activity Detection, Segmentation,
-    Extract Embeddings, Clustering, Resegmentation and Scoring.
-    All the parameters are passed through config file
+    Inference model Class for offline speaker diarization.
+    This class handles required functionality for diarization : Speech Activity Detection, Segmentation,
+    Extract Embeddings, Clustering, Resegmentation and Scoring.
+    All the parameters are passed through config file
     """
 
     def __init__(self, cfg: Union[DictConfig, Any], speaker_model=None):
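The docstring in this hunk is only re-indented, but it summarizes the whole offline pipeline the class drives. For orientation, a minimal usage sketch follows; the config filename and its contents are assumptions and not part of this change.

# Minimal usage sketch (assumed config path and contents; not code from this PR).
from omegaconf import OmegaConf

from nemo.collections.asr.models import ClusteringDiarizer

cfg = OmegaConf.load("diar_infer.yaml")  # hypothetical diarization inference config
diarizer = ClusteringDiarizer(cfg=cfg)   # VAD, segmentation, clustering and scoring all read their settings from cfg
diarizer.diarize()                       # runs the offline pipeline end to end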
@@ -137,7 +137,10 @@ def _init_speaker_model(self, speaker_model=None):
         Initialize speaker embedding model with model name or path passed through config
         """
         if speaker_model is not None:
-            self._speaker_model = speaker_model
+            if self._cfg.device is None and torch.cuda.is_available():
+                self._speaker_model = speaker_model.to(torch.device('cuda'))
+            else:
+                self._speaker_model = speaker_model
         else:
             model_path = self._cfg.diarizer.speaker_embeddings.model_path
             if model_path is not None and model_path.endswith('.nemo'):
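The new branch above moves an externally supplied speaker model onto the GPU when no device is set in the config and CUDA is available; otherwise the model is kept as passed. A standalone sketch of that device-resolution rule, with illustrative names only:

# Sketch of the device-selection rule added above; `cfg_device` mirrors
# self._cfg.device, where None means "choose automatically".
import torch

def resolve_device(cfg_device):
    if cfg_device is None and torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device(cfg_device) if cfg_device is not None else torch.device('cpu')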
@@ -158,7 +161,6 @@ def _init_speaker_model(self, speaker_model=None):
                 self._speaker_model = EncDecSpeakerLabelModel.from_pretrained(
                     model_name=model_path, map_location=self._cfg.device
                 )
-
         self.multiscale_args_dict = parse_scale_configs(
             self._diarizer_params.speaker_embeddings.parameters.window_length_in_sec,
             self._diarizer_params.speaker_embeddings.parameters.shift_length_in_sec,
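parse_scale_configs receives the per-scale window and shift lengths from the speaker_embeddings parameters. The values below are only an illustration of what such multiscale settings can look like; they are not taken from this PR or its config files.

# Illustrative multiscale segmentation settings (example values only, not from this PR):
# one window/shift pair per scale, longest scale first.
window_length_in_sec = [1.5, 1.25, 1.0, 0.75, 0.5]
shift_length_in_sec = [0.75, 0.625, 0.5, 0.375, 0.25]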
@@ -171,7 +173,9 @@ def _setup_vad_test_data(self, manifest_vad_input):
             'sample_rate': self._cfg.sample_rate,
             'batch_size': self._cfg.get('batch_size'),
             'vad_stream': True,
-            'labels': ['infer',],
+            'labels': [
+                'infer',
+            ],
             'window_length_in_sec': self._vad_window_length_in_sec,
             'shift_length_in_sec': self._vad_shift_length_in_sec,
             'trim_silence': False,
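The 'labels' change in this hunk is formatting only; the split literal evaluates to the same single-element list, as this quick check shows:

# Quick check that the reformatted literal is equivalent to the original one.
labels_before = ['infer',]
labels_after = [
    'infer',
]
assert labels_before == labels_after == ['infer']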
@@ -192,8 +196,8 @@ def _setup_spkr_test_data(self, manifest_file):
 
     def _run_vad(self, manifest_file):
         """
-        Run voice activity detection.
-        Get log probability of voice activity detection and smoothes using the post processing parameters.
+        Run voice activity detection.
+        Get log probability of voice activity detection and smoothes using the post processing parameters.
         Using generated frame level predictions generated manifest file for later speaker embedding extraction.
         input:
         manifest_file (str) : Manifest file containing path to audio file and label as infer
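The manifest the docstring refers to is a JSON-lines file in which each entry points at an audio file and carries the label 'infer'. A sketch of one such entry follows; the path and duration are invented and the field names follow the usual NeMo manifest convention.

# One manifest line as a sketch; values are made up, field names are the
# common NeMo manifest layout rather than something defined in this PR.
import json

entry = {"audio_filepath": "/data/session1.wav", "offset": 0, "duration": None, "label": "infer"}
print(json.dumps(entry))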
@@ -338,7 +342,7 @@ def _perform_speech_activity_detection(self):
     def _extract_embeddings(self, manifest_file: str, scale_idx: int, num_scales: int):
         """
         This method extracts speaker embeddings from segments passed through manifest_file
-        Optionally you may save the intermediate speaker embeddings for debugging or any use.
+        Optionally you may save the intermediate speaker embeddings for debugging or any use.
         """
         logging.info("Extracting embeddings for Diarization")
         self._setup_spkr_test_data(manifest_file)
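The docstring notes that intermediate speaker embeddings can optionally be saved for debugging. One hypothetical way to dump them for later inspection is sketched below; the variable name, embedding shape, and output path are assumptions, not code from this file.

# Hypothetical dump of intermediate embeddings for debugging; names, shapes,
# and the output path are assumptions, not code from this file.
import pickle

embeddings = {"uniq_id_0": [[0.01] * 192]}  # e.g. utterance id -> list of embedding vectors
with open("embeddings.pkl", "wb") as f:
    pickle.dump(embeddings, f)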