| import os |
| import csv |
| import glob |
| from tqdm import tqdm |
| import torch |
| import torchaudio |
| from torchmetrics.audio import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio |
|
|
|
|
def calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths):
    """Compute SDR and SI-SDR between an original clip and the sum of its separated parts.

    The separated clips are summed sample-by-sample and the resulting mixture is
    scored against the original audio using torchmetrics'
    ``SignalDistortionRatio`` and ``ScaleInvariantSignalDistortionRatio``.
    All signals are trimmed to the shortest common length before summing and
    scoring, and any separated clip whose sample rate differs from the
    original's is resampled to the original's rate first.

    Args:
        original_audio_path: str, path to the original audio file.
        separated_audio_paths: list[str], paths to the separated audio clips.

    Returns:
        tuple[float, float]: ``(sdr, sisdr)`` as returned by torchmetrics.

    Raises:
        ValueError: if ``separated_audio_paths`` is empty.
    """
    if not separated_audio_paths:
        raise ValueError("separated_audio_paths must contain at least one path")

    original_waveform, sample_rate = torchaudio.load(original_audio_path)

    combined_waveform = None
    for path in separated_audio_paths:
        separated_waveform, separated_rate = torchaudio.load(path)

        # Put every clip on the original's time axis; summing or comparing
        # samples recorded at different rates would yield meaningless scores.
        if separated_rate != sample_rate:
            separated_waveform = torchaudio.functional.resample(
                separated_waveform, separated_rate, sample_rate
            )

        # Trim to the length shared by the original, the running sum and this
        # clip, so clips of differing lengths can never cause a shape mismatch.
        length = min(original_waveform.size(1), separated_waveform.size(1))
        if combined_waveform is None:
            combined_waveform = separated_waveform[:, :length]
        else:
            length = min(length, combined_waveform.size(1))
            combined_waveform = (
                combined_waveform[:, :length] + separated_waveform[:, :length]
            )

    # The mixture is never longer than the original, so trimming the original
    # to the mixture's length aligns the two signals.
    original_waveform = original_waveform[:, :combined_waveform.size(1)]

    sisdr_metric = ScaleInvariantSignalDistortionRatio()
    sisdr = sisdr_metric(combined_waveform, original_waveform).item()

    sdr_metric = SignalDistortionRatio()
    sdr = sdr_metric(combined_waveform, original_waveform).item()

    return sdr, sisdr
|
|
|
|
if __name__ == "__main__":
    # Score every separated video of one AudioSet split against its original
    # audio and write one CSV row (video, sdr, sisdr) per video.
    dset = 'balanced_train_segments'

    src_data_root = r'/data/sound/audioset/audios_32k'
    sep_data_root = r'data_engine_infer/audioset_separation_child_label'

    output_csv = os.path.join(sep_data_root, dset + '.csv')
    # newline='' is what the csv module expects for files it writes; the
    # with-block guarantees the file is flushed and closed even on error.
    with open(output_csv, 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(['video', 'sdr', 'sisdr'])

        # Sorted so the row order of the output is reproducible across runs.
        video_paths = sorted(glob.glob(os.path.join(sep_data_root, dset, '*')))
        for video_path in tqdm(video_paths):
            video = os.path.basename(video_path)
            original_audio_path = os.path.join(src_data_root, dset, video + '.wav')
            separated_audio_paths = sorted(glob.glob(os.path.join(video_path, '*')))

            # Skip unusable entries instead of aborting the whole batch.
            if not separated_audio_paths:
                tqdm.write(f'skipping {video}: no separated clips found')
                continue
            if not os.path.isfile(original_audio_path):
                tqdm.write(f'skipping {video}: missing {original_audio_path}')
                continue

            sdr, sisdr = calculate_sdr_and_sisdr(original_audio_path, separated_audio_paths)
            writer.writerow([video, f'{sdr:.3f}', f'{sisdr:.3f}'])
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|