Skip to content

Commit f2dffaa

Browse files
authored
Force diarizer to use CUDA if CUDA is available and device=None. (#9380)
* Fixed clustering diarizer to load MSDD to GPU by default if CUDA is available
  Signed-off-by: Taejin Park <[email protected]>
* Fixed clustering diarizer to load MSDD to GPU by default if CUDA is available
  Signed-off-by: Taejin Park <[email protected]>
* Apply isort and black reformatting
  Signed-off-by: tango4j <[email protected]>
---------
Signed-off-by: Taejin Park <[email protected]>
Signed-off-by: tango4j <[email protected]>
Co-authored-by: tango4j <[email protected]>
1 parent e999061 commit f2dffaa

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

nemo/collections/asr/models/clustering_diarizer.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,10 @@ def get_available_model_names(class_name):
7474

7575
class ClusteringDiarizer(torch.nn.Module, Model, DiarizationMixin):
7676
"""
77-
Inference model Class for offline speaker diarization.
78-
This class handles required functionality for diarization : Speech Activity Detection, Segmentation,
79-
Extract Embeddings, Clustering, Resegmentation and Scoring.
80-
All the parameters are passed through config file
77+
Inference model Class for offline speaker diarization.
78+
This class handles required functionality for diarization : Speech Activity Detection, Segmentation,
79+
Extract Embeddings, Clustering, Resegmentation and Scoring.
80+
All the parameters are passed through config file
8181
"""
8282

8383
def __init__(self, cfg: Union[DictConfig, Any], speaker_model=None):
@@ -137,7 +137,10 @@ def _init_speaker_model(self, speaker_model=None):
137137
Initialize speaker embedding model with model name or path passed through config
138138
"""
139139
if speaker_model is not None:
140-
self._speaker_model = speaker_model
140+
if self._cfg.device is None and torch.cuda.is_available():
141+
self._speaker_model = speaker_model.to(torch.device('cuda'))
142+
else:
143+
self._speaker_model = speaker_model
141144
else:
142145
model_path = self._cfg.diarizer.speaker_embeddings.model_path
143146
if model_path is not None and model_path.endswith('.nemo'):
@@ -158,7 +161,6 @@ def _init_speaker_model(self, speaker_model=None):
158161
self._speaker_model = EncDecSpeakerLabelModel.from_pretrained(
159162
model_name=model_path, map_location=self._cfg.device
160163
)
161-
162164
self.multiscale_args_dict = parse_scale_configs(
163165
self._diarizer_params.speaker_embeddings.parameters.window_length_in_sec,
164166
self._diarizer_params.speaker_embeddings.parameters.shift_length_in_sec,
@@ -171,7 +173,9 @@ def _setup_vad_test_data(self, manifest_vad_input):
171173
'sample_rate': self._cfg.sample_rate,
172174
'batch_size': self._cfg.get('batch_size'),
173175
'vad_stream': True,
174-
'labels': ['infer',],
176+
'labels': [
177+
'infer',
178+
],
175179
'window_length_in_sec': self._vad_window_length_in_sec,
176180
'shift_length_in_sec': self._vad_shift_length_in_sec,
177181
'trim_silence': False,
@@ -192,8 +196,8 @@ def _setup_spkr_test_data(self, manifest_file):
192196

193197
def _run_vad(self, manifest_file):
194198
"""
195-
Run voice activity detection.
196-
Get log probability of voice activity detection and smoothes using the post processing parameters.
199+
Run voice activity detection.
200+
Get log probability of voice activity detection and smoothes using the post processing parameters.
197201
Using generated frame level predictions generated manifest file for later speaker embedding extraction.
198202
input:
199203
manifest_file (str) : Manifest file containing path to audio file and label as infer
@@ -338,7 +342,7 @@ def _perform_speech_activity_detection(self):
338342
def _extract_embeddings(self, manifest_file: str, scale_idx: int, num_scales: int):
339343
"""
340344
This method extracts speaker embeddings from segments passed through manifest_file
341-
Optionally you may save the intermediate speaker embeddings for debugging or any use.
345+
Optionally you may save the intermediate speaker embeddings for debugging or any use.
342346
"""
343347
logging.info("Extracting embeddings for Diarization")
344348
self._setup_spkr_test_data(manifest_file)

0 commit comments

Comments
 (0)