import logging
import sys

import numpy as np
import torch
from starlette.exceptions import HTTPException
from torchaudio import functional as F
from transformers.pipelines.audio_utils import ffmpeg_read

logger = logging.getLogger(__name__)


def preprocess_inputs(inputs, sampling_rate):
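    """Decode raw audio bytes with ffmpeg and return a mono 16 kHz waveform.

    Returns a tuple of the NumPy waveform and a float32 torch tensor of shape
    (1, num_samples) for the diarization pipeline.
    """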
    inputs = ffmpeg_read(inputs, sampling_rate)

    if sampling_rate != 16000:
        # The diarization model expects 16 kHz input, so resample if needed.
        inputs = F.resample(
            torch.from_numpy(inputs), sampling_rate, 16000
        ).numpy()

    if inputs.ndim != 1:
        logger.error(f"Diarization pipeline expects single-channel audio, received shape {inputs.shape}")
        raise HTTPException(
            status_code=400,
            detail=f"Diarization pipeline expects single-channel audio, received shape {inputs.shape}",
        )

    # The diarization pipeline expects a float32 tensor of shape (channels, num_samples).
    diarizer_inputs = torch.from_numpy(inputs).float()
    diarizer_inputs = diarizer_inputs.unsqueeze(0)

    return inputs, diarizer_inputs


def diarize_audio(diarizer_inputs, diarization_pipeline, parameters):
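    """Run speaker diarization and merge consecutive same-speaker segments.

    Returns a list of dicts with "segment" start/end times and a "speaker"
    label, one entry per uninterrupted speaker turn.
    """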
    # preprocess_inputs always resamples the waveform to 16 kHz, so report
    # that rate to the pipeline rather than the original sampling rate.
    diarization = diarization_pipeline(
        {"waveform": diarizer_inputs, "sample_rate": 16000},
        num_speakers=parameters.num_speakers,
        min_speakers=parameters.min_speakers,
        max_speakers=parameters.max_speakers,
    )

    segments = []
    for segment, track, label in diarization.itertracks(yield_label=True):
        segments.append(
            {
                "segment": {"start": segment.start, "end": segment.end},
                "track": track,
                "label": label,
            }
        )

    # Guard against audio in which the diarizer finds no speech turns.
    if not segments:
        return []

    # The diarizer may emit several consecutive segments for the same speaker;
    # merge them so each entry covers one uninterrupted speaker turn.
    new_segments = []
    prev_segment = cur_segment = segments[0]

    for i in range(1, len(segments)):
        cur_segment = segments[i]

        # The speaker has changed: close out the previous speaker's turn.
        if cur_segment["label"] != prev_segment["label"]:
            new_segments.append(
                {
                    "segment": {
                        "start": prev_segment["segment"]["start"],
                        "end": cur_segment["segment"]["start"],
                    },
                    "speaker": prev_segment["label"],
                }
            )
            prev_segment = segments[i]

    # Append the final speaker turn, which the loop above never closes out.
    new_segments.append(
        {
            "segment": {
                "start": prev_segment["segment"]["start"],
                "end": cur_segment["segment"]["end"],
            },
            "speaker": prev_segment["label"],
        }
    )

    return new_segments


def post_process_segments_and_transcripts(new_segments, transcript, group_by_speaker) -> list:
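    """Pair diarization segments with ASR transcript chunks.

    For each speaker turn, every transcript chunk up to the chunk whose end
    timestamp is closest to the turn's end is attributed to that speaker,
    either merged into one entry (group_by_speaker=True) or chunk by chunk.
    """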
    # Get the end timestamp of each ASR chunk; chunks without an end timestamp
    # (e.g. a final chunk cut off mid-sentence) sort after everything else.
    end_timestamps = np.array(
        [
            chunk["timestamp"][-1] if chunk["timestamp"][-1] is not None else sys.float_info.max
            for chunk in transcript
        ]
    )
    segmented_preds = []

    # Align the diarizer timestamps with the ASR timestamps.
    for segment in new_segments:
        # Find the ASR chunk whose end timestamp is closest to the end of
        # this diarization segment; everything up to it belongs to the speaker.
        end_time = segment["segment"]["end"]
        upto_idx = np.argmin(np.abs(end_timestamps - end_time))

        if group_by_speaker:
            segmented_preds.append(
                {
                    "speaker": segment["speaker"],
                    "text": "".join(
                        [chunk["text"] for chunk in transcript[: upto_idx + 1]]
                    ),
                    "timestamp": (
                        transcript[0]["timestamp"][0],
                        transcript[upto_idx]["timestamp"][1],
                    ),
                }
            )
        else:
            for i in range(upto_idx + 1):
                segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})

        # Crop the transcript and timestamp arrays to what remains unassigned.
        transcript = transcript[upto_idx + 1:]
        end_timestamps = end_timestamps[upto_idx + 1:]

        if len(end_timestamps) == 0:
            break

    return segmented_preds


def diarize(diarization_pipeline, file, parameters, asr_outputs):
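    """End-to-end helper: preprocess the audio file, run diarization, and
    attach speaker labels to the ASR output chunks."""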
    _, diarizer_inputs = preprocess_inputs(file, parameters.sampling_rate)

    segments = diarize_audio(
        diarizer_inputs,
        diarization_pipeline,
        parameters,
    )

    return post_process_segments_and_transcripts(
        segments, asr_outputs["chunks"], group_by_speaker=False
    )
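

# ---------------------------------------------------------------------------
# Usage sketch (assumptions, not part of the server flow): a minimal example
# of wiring `diarize` up, assuming `pyannote.audio` is installed, a Hugging
# Face token has access to the `pyannote/speaker-diarization-3.1` checkpoint,
# and ASR output follows the `transformers` pipeline format with a "chunks"
# key. The model name, token, file path, and the SimpleNamespace parameter
# object are all illustrative placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    from pyannote.audio import Pipeline

    # Hypothetical parameter object; any object exposing these attributes works.
    parameters = SimpleNamespace(
        sampling_rate=16000,
        num_speakers=None,
        min_speakers=None,
        max_speakers=None,
    )

    diarization_pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1", use_auth_token="YOUR_HF_TOKEN"
    )

    with open("sample.wav", "rb") as f:
        audio_bytes = f.read()

    # `asr_outputs` would normally come from a Whisper-style ASR pipeline run
    # with `return_timestamps=True`; a stub is shown here for illustration.
    asr_outputs = {"chunks": [{"text": " Hello world.", "timestamp": (0.0, 1.5)}]}

    print(diarize(diarization_pipeline, audio_bytes, parameters, asr_outputs))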