Taejin committed
Commit 719134e · 1 Parent(s): 8b4a198

Adding example.py


Signed-off-by: taejinp <[email protected]>

Files changed (1)
  1. example.py +72 -0
example.py ADDED
@@ -0,0 +1,72 @@
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ #
+ # Licensed under the NVIDIA Open Model License (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ Load the NeMo speaker diarization model:
+ [Streaming Sortformer Diarizer v2.1](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2.1)
+ """
+ from nemo.collections.asr.models import SortformerEncLabelModel, ASRModel
+ import torch
+ # A speaker diarization model is needed for tracking the speech activity of each speaker.
+ diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2.1").eval().to(torch.device("cuda"))
+ asr_model = ASRModel.from_pretrained("nvidia/multitalker-parakeet-streaming-0.6b-v1").eval().to(torch.device("cuda"))
+
+ # Use the pre-defined dataclass template `MultitalkerTranscriptionConfig` from `multitalker_transcript_config.py`.
+ # Configure the diarization model using streaming parameters:
+ from multitalker_transcript_config import MultitalkerTranscriptionConfig
+ from omegaconf import OmegaConf
+ cfg = OmegaConf.structured(MultitalkerTranscriptionConfig())
+ cfg.audio_file = "/path/to/your/audio.wav"
+ cfg.output_path = "/path/to/output_transcription.json"
+
+ diar_model = MultitalkerTranscriptionConfig.init_diar_model(cfg, diar_model)
+
+ # Load your audio file into a streaming audio buffer to simulate a real-time audio session.
+ from nemo.collections.asr.parts.utils.streaming_utils import CacheAwareStreamingAudioBuffer
+
+ samples = [{'audio_filepath': cfg.audio_file}]
+ streaming_buffer = CacheAwareStreamingAudioBuffer(
+     model=asr_model,
+     online_normalization=cfg.online_normalization,
+     pad_and_drop_preencoded=cfg.pad_and_drop_preencoded,
+ )
+ streaming_buffer.append_audio_file(audio_filepath=cfg.audio_file, stream_id=-1)
+ streaming_buffer_iter = iter(streaming_buffer)
+
+ # Use the helper class `SpeakerTaggedASR`, which handles all ASR and diarization cache data for streaming.
+ from nemo.collections.asr.parts.utils.multispk_transcribe_utils import SpeakerTaggedASR
+ multispk_asr_streamer = SpeakerTaggedASR(cfg, asr_model, diar_model)
+
+ for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter):
+     # Keep the pre-encoded frames only on the very first chunk (when pad_and_drop_preencoded is disabled);
+     # otherwise drop the extra pre-encoded frames produced by the cache-aware encoder.
+     drop_extra_pre_encoded = (
+         0
+         if step_num == 0 and not cfg.pad_and_drop_preencoded
+         else asr_model.encoder.streaming_cfg.drop_extra_pre_encoded
+     )
+     with torch.inference_mode():
+         with torch.amp.autocast(diar_model.device.type, enabled=True):
+             with torch.no_grad():
+                 # Run ASR and speaker diarization in parallel on the current audio chunk.
+                 multispk_asr_streamer.perform_parallel_streaming_stt_spk(
+                     step_num=step_num,
+                     chunk_audio=chunk_audio,
+                     chunk_lengths=chunk_lengths,
+                     is_buffer_empty=streaming_buffer.is_buffer_empty(),
+                     drop_extra_pre_encoded=drop_extra_pre_encoded,
+                 )
+
+ # Generate the speaker-tagged transcript and print it.
+ multispk_asr_streamer.generate_seglst_dicts_from_parallel_streaming(samples=samples)
+ print(multispk_asr_streamer.instance_manager.seglst_dict_list)
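
`cfg.output_path` is configured earlier in the example but never written to; a natural follow-up is to persist the speaker-tagged segments instead of only printing them. The snippet below is a minimal sketch and not part of the committed file: it assumes `seglst_dict_list` holds plain, JSON-serializable dicts, which may not hold exactly for the objects the NeMo helper returns.

```python
# Minimal sketch (assumption: seglst_dict_list contains JSON-serializable dicts).
import json

seglst_segments = multispk_asr_streamer.instance_manager.seglst_dict_list
with open(cfg.output_path, "w", encoding="utf-8") as f:
    json.dump(seglst_segments, f, indent=2, ensure_ascii=False)
print(f"Wrote {len(seglst_segments)} speaker-tagged segments to {cfg.output_path}")
```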