# === Helper ===
import difflib
import io
from functools import lru_cache

import numpy as np
import torch

# The audio backends are only needed by load_audio(). Import them defensively
# so the text-scoring helpers below stay importable when a backend is missing
# or fails to load its shared library (which can surface as OSError rather
# than ImportError). The original error is kept for chaining at call time.
try:
    import soundfile as sf
except (ImportError, OSError) as _exc:
    sf = None
    _SOUNDFILE_ERROR = _exc
else:
    _SOUNDFILE_ERROR = None

try:
    import torchaudio
except (ImportError, OSError) as _exc:
    torchaudio = None
    _TORCHAUDIO_ERROR = _exc
else:
    _TORCHAUDIO_ERROR = None


def _require_backend(module, name, error):
    """Return *module*, or raise ImportError if the backend failed to import."""
    if module is None:
        raise ImportError(
            f"{name} is required by load_audio but could not be imported"
        ) from error
    return module


def _to_mono_at_rate(waveform, sr, target_sr):
    """Downmix a (channels, frames) tensor to mono and resample to target_sr."""
    # Averaging over the channel axis (instead of squeeze + conditional mean)
    # always yields a 1-D result, even for a single-sample mono clip.
    mono = waveform.mean(dim=0, keepdim=True)
    if sr != target_sr:
        ta = _require_backend(torchaudio, "torchaudio", _TORCHAUDIO_ERROR)
        mono = ta.functional.resample(mono, sr, target_sr)
    return mono.squeeze(0).cpu().numpy().astype(np.float32)


def load_audio(src, target_sr=16000):
    """Load audio as a mono, 1-D float32 NumPy array at ``target_sr`` Hz.

    Args:
        src: A file path accepted by ``torchaudio.load``, or a ``datasets``
            Audio-style dict holding encoded audio under ``"bytes"`` and/or
            a file path under ``"path"``. ``"bytes"`` takes precedence when
            both are present.
        target_sr: Output sampling rate in Hz. Defaults to 16000.

    Returns:
        np.ndarray: 1-D float32 waveform.

    Raises:
        ValueError: If a dict source has neither ``"bytes"`` nor ``"path"``.
        ImportError: If the audio backend needed for ``src`` is unavailable.
    """
    if isinstance(src, dict):
        audio_bytes = src.get("bytes")
        if audio_bytes is not None:
            soundfile = _require_backend(sf, "soundfile", _SOUNDFILE_ERROR)
            data, sr = soundfile.read(
                io.BytesIO(audio_bytes), dtype="float32", always_2d=True
            )
            # soundfile returns (frames, channels); the shared helper expects
            # the channels-first layout that torchaudio.load produces.
            waveform = torch.from_numpy(np.ascontiguousarray(data.T))
            return _to_mono_at_rate(waveform, sr, target_sr)

        path = src.get("path")
        if path is None:
            raise ValueError("Audio source missing both 'bytes' and 'path'")
        src = path

    ta = _require_backend(torchaudio, "torchaudio", _TORCHAUDIO_ERROR)
    waveform, sr = ta.load(src)
    return _to_mono_at_rate(waveform, sr, target_sr)


def _edit_distance(ref, hyp):
    """Return the Levenshtein distance between two sequences (unit costs)."""
    previous = list(range(len(hyp) + 1))
    for i, ref_item in enumerate(ref, start=1):
        current = [i]
        for j, hyp_item in enumerate(hyp, start=1):
            current.append(
                min(
                    previous[j] + 1,  # deletion
                    current[j - 1] + 1,  # insertion
                    previous[j - 1] + (ref_item != hyp_item),  # match / substitution
                )
            )
        previous = current
    return previous[-1]


def calc_per(pred, ref):
    """Return the phoneme error rate of *pred* against *ref*, in percent.

    Both arguments are whitespace-separated token strings. The rate is the
    Levenshtein distance between the token lists divided by the number of
    reference tokens, times 100, rounded to two decimals.

    Args:
        pred: Hypothesis token string.
        ref: Reference token string.

    Returns:
        float: Error rate in percent; 0.0 when the reference has no tokens.
    """
    ref_tokens = ref.split()
    if not ref_tokens:
        return 0.0
    distance = _edit_distance(ref_tokens, pred.split())
    return round(100 * distance / len(ref_tokens), 2)


def phonetic_distance(ipa1: str, ipa2: str) -> float:
    """Return an exact-match similarity score for two IPA symbols.

    Despite the name, the value is a similarity, not a distance: callers in
    this file use ``1 - phonetic_distance(...)`` as the substitution cost.

    NOTE: a feature-based variant built on panphon's feature edit distance
    was sketched here previously but is not enabled; only exact equality is
    scored.

    Args:
        ipa1: First IPA symbol (e.g. 'p').
        ipa2: Second IPA symbol (e.g. 'b').

    Returns:
        float: 1.0 if the symbols are equal, otherwise 0.0.
    """
    return 1.0 if ipa1 == ipa2 else 0.0


def phonetic_distance_cached(p1, p2):
    """Return ``phonetic_distance(p1, p2)``.

    Kept as a separate entry point for callers; memoisation is currently not
    applied, so this is a plain pass-through.
    """
    return phonetic_distance(p1, p2)


def align_sequences(seq1, seq2):
    """Globally align two sequences with unit insert/delete costs.

    The substitution cost is ``1 - phonetic_distance_cached(a, b)``. Gaps are
    written as ``'-'`` in the output, so a literal ``'-'`` symbol in the input
    cannot be told apart from a gap afterwards.

    Args:
        seq1: First sequence (indexable, e.g. a list of symbols or a string).
        seq2: Second sequence.

    Returns:
        tuple[list, list]: The two aligned sequences, equal in length.
    """
    n, m = len(seq1), len(seq2)
    dp = np.zeros((n + 1, m + 1), dtype=np.float32)
    backtrack = np.empty((n + 1, m + 1), dtype="U1")

    # Borders: aligning against an empty prefix is all deletions / insertions.
    dp[:, 0] = np.arange(n + 1)
    dp[0, :] = np.arange(m + 1)
    backtrack[:, 0] = "D"
    backtrack[0, :] = "I"
    backtrack[0, 0] = ""

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            try:
                cost = 1 - phonetic_distance_cached(seq1[i - 1], seq2[j - 1])
            except Exception as e:
                # Best effort: an unscorable pair is treated as a full mismatch.
                print(f"Error computing distance between '{seq1[i - 1]}' and '{seq2[j - 1]}': {e}")
                cost = 1.0

            # Option order matters: min() keeps the first minimum, so ties
            # resolve as deletion, then insertion, then match/substitution.
            options = [
                (dp[i - 1][j] + 1, "D"),
                (dp[i][j - 1] + 1, "I"),
                (dp[i - 1][j - 1] + cost, "M"),
            ]
            dp[i][j], backtrack[i][j] = min(options, key=lambda x: x[0])

    # Walk back from the bottom-right corner, collecting pairs in reverse.
    i, j = n, m
    aligned_seq1, aligned_seq2 = [], []
    while i > 0 or j > 0:
        move = backtrack[i][j]
        if move == "M":
            aligned_seq1.append(seq1[i - 1])
            aligned_seq2.append(seq2[j - 1])
            i, j = i - 1, j - 1
        elif move == "D":
            aligned_seq1.append(seq1[i - 1])
            aligned_seq2.append("-")
            i -= 1
        elif move == "I":
            aligned_seq1.append("-")
            aligned_seq2.append(seq2[j - 1])
            j -= 1
        else:
            break

    aligned_seq1.reverse()
    aligned_seq2.reverse()
    return aligned_seq1, aligned_seq2


def score_alignment(aligned1, aligned2):
    """Score an alignment position by position.

    A position involving a gap (``'-'``) scores 0.0; every other position
    scores ``phonetic_distance_cached(p1, p2)``.

    Args:
        aligned1: First aligned sequence.
        aligned2: Second aligned sequence.

    Returns:
        tuple[float, list[float]]: The mean score rounded to three decimals
        and the per-position scores. An empty alignment yields ``(0.0, [])``.
    """
    scores = [
        0.0 if p1 == "-" or p2 == "-" else phonetic_distance_cached(p1, p2)
        for p1, p2 in zip(aligned1, aligned2)
    ]
    if not scores:
        # Nothing was aligned; avoid dividing by zero.
        return 0.0, scores
    return round(sum(scores) / len(scores), 3), scores


def _as_symbols(seq, unit):
    """Turn a string or sequence into the list of symbols to align."""
    if isinstance(seq, str):
        if unit == "word":
            return seq.split()
        # Phoneme mode: spaces are separators only; each character is a symbol.
        return list(seq.replace(" ", ""))
    return list(seq)


def calculate_error_rate(ref_seq, hyp_seq, unit="phoneme"):
    """Calculate PER (phoneme error rate) or WER (word error rate).

    String inputs are tokenised according to ``unit``: for ``"word"`` they
    are split on whitespace; for any other value spaces are removed and every
    remaining character is one symbol. Non-string sequences are used as given.

    A literal ``'-'`` symbol collides with the gap marker used by
    ``align_sequences`` and may be miscounted.

    Args:
        ref_seq (str | Sequence[str]): Reference sequence.
        hyp_seq (str | Sequence[str]): Hypothesis sequence.
        unit (str): "phoneme" or "word".

    Returns:
        tuple[float, dict]: The error rate in percent (0.0 for an empty
        reference) and the counts ``{"S", "D", "I", "N"}``, where ``N`` is
        the reference length.
    """
    ref_symbols = _as_symbols(ref_seq, unit)
    hyp_symbols = _as_symbols(hyp_seq, unit)

    aligned_ref, aligned_hyp = align_sequences(ref_symbols, hyp_symbols)

    S = D = I = 0
    for r, h in zip(aligned_ref, aligned_hyp):
        if r == h:
            continue
        if r == "-":  # insertion in hyp
            I += 1
        elif h == "-":  # deletion in hyp
            D += 1
        else:  # substitution
            S += 1

    N = len(ref_symbols)  # reference length
    error_rate = (S + D + I) / N if N > 0 else 0.0
    return error_rate * 100, {"S": S, "D": D, "I": I, "N": N}