Taejin committed on
Commit 0b09b5e · 1 Parent(s): a7109eb

Updating the README.md file


Signed-off-by: taejinp <tango4j@gmail.com>

Files changed (1): README.md (+15 / -14)
README.md CHANGED
@@ -239,8 +239,6 @@ Each model instance:

This architecture enables the model to handle severe speech overlap by having each instance focus exclusively on one speaker, eliminating the permutation problem that affects other multitalker ASR approaches.

-
-
## NVIDIA NeMo

To train, fine-tune, or perform multitalker ASR with this model, you will need to install [NVIDIA NeMo](https://github.com/NVIDIA/NeMo)[7]. We recommend you install it after you've installed Cython and the latest PyTorch version.
@@ -259,31 +257,36 @@ The model is available for use in the NeMo Framework[7], and can be used as a pr

### Method 1. Code snippet

+ Load a speaker diarization model [Streaming Sortformer Diarizer v2.1](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2.1) for generating speaker timestamps.
+ A speaker diarization model is needed for tracking the speech activity of each speaker.
+
```python
from nemo.collections.asr.models import ASRModel, SortformerEncLabelModel
import torch

- # Step 1: Load streaming diarization model (provides speaker activity predictions)
diar_model = SortformerEncLabelModel.from_pretrained("nvidia/diar_streaming_sortformer_4spk-v2.1")
diar_model.eval().to(torch.device("cuda"))

- # Step 2: Load streaming multitalker ASR model (transcribes each speaker separately)
asr_model = ASRModel.from_pretrained("nvidia/multitalker-parakeet-streaming-0.6b-v1.nemo")
asr_model.eval().to(torch.device("cuda"))

+ """
+ Use the pre-defined dataclass template `MultitalkerTranscriptionConfig` from `multitalker_transcript_config.py`.
+ Configure the diarization model using streaming parameters:
+ """
from multitalker_transcript_config import MultitalkerTranscriptionConfig
from omegaconf import OmegaConf
- # Step 3: Configure models with streaming parameters (latency, chunk sizes, etc.)
cfg = OmegaConf.structured(MultitalkerTranscriptionConfig())
cfg.audio_file = "/path/to/your/audio.wav"
cfg.output_path = "/path/to/output_transcription.json"

- # Initialize diarization model with streaming config (sets chunk_len, context, etc.)
diar_model = MultitalkerTranscriptionConfig.init_diar_model(cfg, diar_model)

+ """
+ Load a streaming audio buffer to simulate a real-time audio session.
+ """
from nemo.collections.asr.parts.utils.streaming_utils import CacheAwareStreamingAudioBuffer

- # Step 4: Setup streaming buffer (simulates real-time audio stream)
samples = [{'audio_filepath': cfg.audio_file}]
streaming_buffer = CacheAwareStreamingAudioBuffer(
    model=asr_model,
@@ -292,13 +295,12 @@ streaming_buffer = CacheAwareStreamingAudioBuffer(
)
streaming_buffer.append_audio_file(audio_filepath=cfg.audio_file, stream_id=-1)
streaming_buffer_iter = iter(streaming_buffer)
-
+ """
+ Use a helper class `SpeakerTaggedASR` that handles all ASR and diarization cache data for streaming.
+ """
from nemo.collections.asr.parts.utils.multispk_transcribe_utils import SpeakerTaggedASR
-
- # Step 5: Initialize multi-instance ASR streamer (manages per-speaker ASR instances)
multispk_asr_streamer = SpeakerTaggedASR(cfg, asr_model, diar_model)

- # Step 6: Process audio chunks iteratively (streaming inference loop)
for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter):
    drop_extra_pre_encoded = (
        0
@@ -315,10 +317,9 @@ for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter):
        is_buffer_empty=streaming_buffer.is_buffer_empty(),
        drop_extra_pre_encoded=drop_extra_pre_encoded,
    )
- # Step 7: Generate final transcriptions in SegLST format (speaker-tagged with timestamps)
- seglst_dict_list = multispk_asr_streamer.generate_seglst_dicts_from_parallel_streaming(samples=samples)

- # Display speaker-tagged transcriptions with timestamps
+ # Generate the speaker-tagged transcript and print it.
+ seglst_dict_list = multispk_asr_streamer.generate_seglst_dicts_from_parallel_streaming(samples=samples)
print(seglst_dict_list)
```
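Note (not part of the commit above): SegLST, as used by tools such as meeteval, represents a transcript as a list of per-segment dictionaries. The exact keys and values returned by `generate_seglst_dicts_from_parallel_streaming` depend on the NeMo release; the sketch below only illustrates the conventional shape of such entries, with assumed field names and made-up example values.

```python
# Illustrative SegLST-style entries (assumed field names and example values only;
# not output produced by running the commit's snippet).
example_seglst = [
    {"session_id": "audio", "start_time": 0.48, "end_time": 2.31,
     "speaker": "speaker_0", "words": "hello how are you"},
    {"session_id": "audio", "start_time": 1.92, "end_time": 3.75,
     "speaker": "speaker_1", "words": "doing well thanks"},
]
```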
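The snippet sets `cfg.output_path`, but the hunks shown do not include the write step. In case your NeMo version does not already persist the result there, here is a minimal standard-library sketch; the helper names `save_seglst` and `print_speaker_turns` are hypothetical and no NeMo-specific API is assumed.

```python
import json

def save_seglst(seglst_dict_list, output_path):
    # Write the speaker-tagged segments to a JSON file (e.g. cfg.output_path).
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(seglst_dict_list, f, indent=2, ensure_ascii=False)

def print_speaker_turns(seglst_dict_list):
    # Print one line per segment: speaker, time span, and words.
    for seg in seglst_dict_list:
        print(f"{seg.get('speaker')} "
              f"[{seg.get('start_time')}-{seg.get('end_time')}]: "
              f"{seg.get('words')}")
```

For example, calling `save_seglst(seglst_dict_list, cfg.output_path)` after the streaming loop would write the speaker-tagged transcript to the configured output file.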