lataon's picture
update: interface
dba24db
# === Helper ===
import difflib
import numpy as np
from functools import lru_cache
import torchaudio
import torch
import io
import soundfile as sf
def _downmix_and_resample(arr, sr, target_sr, channel_axis):
    """Average ``arr`` over ``channel_axis`` if multi-channel, then resample ``sr`` -> ``target_sr``.

    Returns a float32 NumPy array.
    """
    arr = np.asarray(arr, dtype=np.float32)
    if arr.ndim > 1:
        # Downmix first: resampling is linear, so this matches
        # resample-then-average up to rounding but resamples only one channel.
        arr = arr.mean(axis=channel_axis)
    if sr != target_sr:
        tensor = torch.from_numpy(np.ascontiguousarray(arr)).unsqueeze(0)
        tensor = torchaudio.functional.resample(tensor, sr, target_sr)
        arr = tensor.squeeze(0).cpu().numpy()
    return arr.astype(np.float32, copy=False)


def load_audio(src, target_sr=16000):
    """
    Load audio as a mono float32 NumPy array sampled at ``target_sr`` Hz.

    Args:
        src: a path accepted by ``torchaudio.load``, or a dict (e.g. a
            ``datasets`` Audio entry) holding ``"bytes"`` (encoded audio,
            decoded with soundfile) and/or ``"path"``. ``"bytes"`` is used
            when both are present.
        target_sr: output sample rate in Hz (default 16000).
    Returns:
        np.ndarray: float32 waveform; multi-channel input is averaged to mono.
    Raises:
        ValueError: if ``src`` is a dict with neither ``"bytes"`` nor ``"path"``.
    """
    if isinstance(src, dict):
        path = src.get("path")
        audio_bytes = src.get("bytes")
        if audio_bytes is not None:
            # With always_2d=False soundfile yields (frames,) for mono and
            # (frames, channels) otherwise, hence channel_axis=1.
            data, sr = sf.read(io.BytesIO(audio_bytes), dtype='float32', always_2d=False)
            return _downmix_and_resample(data, sr, target_sr, channel_axis=1)
        if path is None:
            raise ValueError("Audio source missing both 'bytes' and 'path'")
        src = path
    # torchaudio.load returns (channels, frames). Averaging over axis 0
    # (instead of squeeze()) always yields 1D, even for single-frame clips.
    waveform, sr = torchaudio.load(src)
    return _downmix_and_resample(waveform.cpu().numpy(), sr, target_sr, channel_axis=0)
def _levenshtein(ref_tokens, hyp_tokens):
    """Return the minimum number of unit-cost substitutions, deletions and
    insertions that turn ``ref_tokens`` into ``hyp_tokens``."""
    previous = list(range(len(hyp_tokens) + 1))
    for i, ref_tok in enumerate(ref_tokens, start=1):
        current = [i]
        for j, hyp_tok in enumerate(hyp_tokens, start=1):
            current.append(min(
                previous[j] + 1,                         # deletion
                current[j - 1] + 1,                      # insertion
                previous[j - 1] + (ref_tok != hyp_tok),  # substitution / match
            ))
        previous = current
    return previous[-1]


def calc_per(pred, ref):
    """
    Phoneme (token) error rate of ``pred`` against ``ref``, in percent.

    Both inputs are whitespace-separated token strings. The error count is the
    Levenshtein distance between the token lists (substitutions + deletions +
    insertions), so the result may exceed 100 when ``pred`` has many insertions.

    Args:
        pred (str): predicted tokens.
        ref (str): reference tokens.
    Returns:
        float: 100 * edits / len(ref tokens), rounded to 2 decimals; 0.0 when
        ``ref`` has no tokens.
    """
    ref_list = ref.split()
    if not ref_list:
        return 0.0
    pred_list = pred.split()
    dist = _levenshtein(ref_list, pred_list)
    return round(100 * dist / len(ref_list), 2)
def phonetic_distance(ipa1: str, ipa2: str) -> float:
    """
    Return a similarity score between two IPA phoneme symbols.

    Despite the name, this is a *similarity* (higher = more alike), not a
    distance: the alignment code converts it to a cost with ``1 - score``.
    The current implementation is an exact-match indicator. A feature-based
    score (e.g. derived from panphon's feature edit distance) could replace it,
    provided it stays within [0, 1].

    Args:
        ipa1 (str): First IPA symbol (e.g., 'p')
        ipa2 (str): Second IPA symbol (e.g., 'b')
    Returns:
        float: 1.0 if the symbols are identical, otherwise 0.0
    """
    return 1.0 if ipa1 == ipa2 else 0.0
def phonetic_distance_cached(p1, p2):
    """Forward to ``phonetic_distance(p1, p2)`` and return its score unchanged.

    NOTE: despite the name, no memoization is applied; every call recomputes.
    This wrapper is the single place where an ``lru_cache`` could be added if
    ``phonetic_distance`` becomes expensive.
    """
    similarity = phonetic_distance(p1, p2)
    return similarity
def align_sequences(seq1, seq2):
    """
    Globally align two sequences with a Levenshtein-style dynamic program.

    Deletions and insertions cost 1; pairing ``seq1[i]`` with ``seq2[j]`` costs
    ``1 - phonetic_distance_cached(seq1[i], seq2[j])``. If that call raises, the
    error is printed and the pair is charged 1.0. On cost ties the first
    candidate wins, in the order deletion, insertion, match/substitution.

    Returns:
        tuple[list, list]: two equal-length lists in which '-' marks a gap.
    """
    rows, cols = len(seq1), len(seq2)
    cost_table = np.zeros((rows + 1, cols + 1), dtype=np.float32)
    moves = np.empty((rows + 1, cols + 1), dtype='U1')

    # Border: the first column is all deletions, the first row all insertions.
    cost_table[:, 0] = np.arange(rows + 1)
    cost_table[0, :] = np.arange(cols + 1)
    moves[:, 0] = 'D'
    moves[0, :] = 'I'
    moves[0, 0] = ''

    for r in range(1, rows + 1):
        left = seq1[r - 1]
        for c in range(1, cols + 1):
            right = seq2[c - 1]
            try:
                step = 1 - phonetic_distance_cached(left, right)
            except Exception as e:
                print(f"Error computing distance between '{left}' and '{right}': {e}")
                step = 1.0
            candidates = (
                (cost_table[r - 1, c] + 1, 'D'),
                (cost_table[r, c - 1] + 1, 'I'),
                (cost_table[r - 1, c - 1] + step, 'M'),
            )
            cost_table[r, c], moves[r, c] = min(candidates, key=lambda pair: pair[0])

    # Walk back from the bottom-right corner, collecting the aligned pairs.
    out1, out2 = [], []
    r, c = rows, cols
    while r > 0 or c > 0:
        move = moves[r, c]
        if move == 'M':
            out1.append(seq1[r - 1])
            out2.append(seq2[c - 1])
            r -= 1
            c -= 1
        elif move == 'D':
            out1.append(seq1[r - 1])
            out2.append('-')
            r -= 1
        elif move == 'I':
            out1.append('-')
            out2.append(seq2[c - 1])
            c -= 1
        else:
            break
    return out1[::-1], out2[::-1]
def score_alignment(aligned1, aligned2):
    """
    Average per-position similarity of two aligned sequences.

    Args:
        aligned1: first aligned sequence; '-' marks a gap.
        aligned2: second aligned sequence, paired position by position with
            ``aligned1`` (extra trailing items in the longer one are ignored).
    Returns:
        tuple: (mean score rounded to 3 decimals, list of per-position scores).
        A position containing a gap scores 0.0; other positions score
        ``phonetic_distance_cached(p1, p2)``. An empty alignment returns
        ``(0.0, [])`` instead of dividing by zero.
    """
    scores = [
        0.0 if '-' in (p1, p2) else phonetic_distance_cached(p1, p2)
        for p1, p2 in zip(aligned1, aligned2)
    ]
    if not scores:
        return 0.0, []
    return round(sum(scores) / len(scores), 3), scores
def calculate_error_rate(ref_seq, hyp_seq, unit="phoneme"):
    """
    Calculate PER (phoneme error rate) or WER (word error rate) as a percentage.

    Args:
        ref_seq (str | list[str]): reference. A string is tokenized according
            to ``unit``; any non-string sequence is taken as already tokenized.
        hyp_seq (str | list[str]): hypothesis, tokenized the same way.
        unit (str): "word" splits strings on whitespace (WER). Any other value,
            including the default "phoneme", removes spaces and treats every
            remaining character as one token.
    Returns:
        float: 100 * (S + D + I) / N, or 0.0 when the reference is empty.
        dict: {"S": substitutions, "D": deletions, "I": insertions,
               "N": reference length in tokens}
    """
    def _tokenize(seq):
        if not isinstance(seq, str):
            # Already tokenized (e.g. a list of phonemes or words).
            return list(seq)
        if unit == "word":
            return seq.split()
        # NOTE: every character is one token, so multi-character IPA symbols
        # (affricates, diacritics, length marks) are scored per character.
        return seq.replace(" ", "")

    ref_tokens = _tokenize(ref_seq)
    hyp_tokens = _tokenize(hyp_seq)
    aligned_ref, aligned_hyp = align_sequences(ref_tokens, hyp_tokens)

    # align_sequences marks gaps with '-'; a real token equal to '-' would be
    # counted as a gap.
    subs = dels = ins = 0
    for r, h in zip(aligned_ref, aligned_hyp):
        if r == h:
            continue
        if r == "-":  # extra token in the hypothesis
            ins += 1
        elif h == "-":  # reference token missing from the hypothesis
            dels += 1
        else:
            subs += 1
    n_ref = len(ref_tokens)
    error_rate = (subs + dels + ins) / n_ref if n_ref > 0 else 0.0
    return error_rate * 100, {"S": subs, "D": dels, "I": ins, "N": n_ref}