import pandas as pd
from utils.load_model import run_hubert_base, run_whisper, run_model, run_timit, run_wavlm_large_phoneme, run_gruut
from utils.audio_process import calculate_error_rate, load_audio
from utils.cmu_process import clean_cmu, cmu_to_ipa, text_to_phoneme
from constants import DATASETS, FINAL_SIZE
from datasets import load_dataset, Audio, Dataset
from concurrent.futures import ThreadPoolExecutor
import argparse
import json
import os
import time
# Map model names to their runner functions
MODEL_RUNNERS = {
"HuBERT-Base": run_hubert_base,
"Whisper": run_whisper,
"HuBERT fine-tuned": run_model,
"Timit": run_timit,
"WavLM": run_wavlm_large_phoneme,
"LJSpeech Gruut": run_gruut,
}
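# The keys above double as the display names stored in the results;
# update_aggregates() later replaces spaces with dashes so the names are
# safe to use in result file names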
def set_output(model, pre_pho, ref_pho, duration, per, score):
return {
"model": model,
"phonemes": pre_pho,
"ref_phonemes": ref_pho,
"duration": duration,
"PER": per,
"score": score
}
def get_output(model, wav, reference_phoneme):
"""
Run the given model, compute error rate, and return formatted output.
"""
if model not in MODEL_RUNNERS:
raise ValueError(f"Unknown model: {model}")
run_func = MODEL_RUNNERS[model]
phonemes, dur = run_func(wav)
per, score = calculate_error_rate(reference_phoneme, phonemes)
return set_output(model, phonemes, reference_phoneme, dur, per, score)
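# Illustrative shape of one get_output() row (all values invented):
# {"model": "Whisper", "phonemes": "h ɛ l oʊ", "ref_phonemes": "h ə l oʊ",
#  "duration": 0.42, "PER": 0.25, "score": 0.75}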
def benchmark_all(example):
"""
Run all models on a single dataset example in parallel.
"""
# Load waveform manually to avoid datasets' torchcodec dependency
wav = load_audio(example["audio"])
reference_phoneme = example["phonetic"]
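    # Normalize the CMU-style reference to IPA so predictions and references
    # are compared in the same phoneme alphabet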
reference_phoneme = cmu_to_ipa(clean_cmu(reference_phoneme))
    # Run every registered model in parallel; deriving the list from
    # MODEL_RUNNERS keeps it in sync with the registry at the top of the file
    models = list(MODEL_RUNNERS)
with ThreadPoolExecutor(max_workers=len(models)) as executor:
futures = [
executor.submit(get_output, model, wav, reference_phoneme)
for model in models
]
results = [future.result() for future in futures]
return pd.DataFrame(results)
def benchmark_dataset(dataset):
"""
Run benchmark_all on each sample and compute average PER and duration per model.
"""
all_results = []
for example in dataset:
df = benchmark_all(example)
all_results.append(df)
full_df = pd.concat(all_results, ignore_index=True)
# Compute average PER and duration per model
avg_stats = (
full_df.groupby("model")[["PER", "duration"]]
.mean()
.reset_index()
.rename(columns={"PER": "Average PER", "duration": "Average Duration (s)"})
)
return full_df, avg_stats
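# avg_stats comes back with one row per model, for example (numbers invented):
#        model  Average PER  Average Duration (s)
# 0    Whisper         0.21                  0.53
# 1      WavLM         0.18                  0.47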
def load_dataset_with_limits(dataset_config, max_samples=None, use_streaming=False):
"""
Load a dataset with optional size limits and streaming.
Args:
dataset_config: Dictionary containing dataset configuration
max_samples: Maximum number of samples to load (None for no limit)
use_streaming: Whether to use streaming for large datasets
Returns:
Dataset object
"""
try:
# Prepare load_dataset arguments
load_args = {
"path": dataset_config["name"],
"split": dataset_config["split"]
}
# Add config if specified
if "config" in dataset_config:
load_args["name"] = dataset_config["config"]
# Add streaming if requested
if use_streaming:
load_args["streaming"] = True
print(f"Loading {dataset_config['name']} with streaming...")
else:
print(f"Loading {dataset_config['name']}...")
dataset = load_dataset(**load_args)
# Apply size limits
if max_samples is not None:
print(f"Limiting dataset to {max_samples} samples...")
if use_streaming:
dataset = dataset.take(max_samples)
else:
dataset = dataset.select(range(min(max_samples, len(dataset))))
return dataset
except Exception as e:
print(f"[warn] skip dataset {dataset_config['name']}: {e}")
return None
def parse_cli_args():
"""
Parse and return CLI arguments for the evaluation script.
"""
parser = argparse.ArgumentParser(description='Phoneme Detection Evaluation')
parser.add_argument('--max-samples', type=int, default=None,
help='Override max_samples for all datasets')
parser.add_argument('--dataset', type=str, default=None,
help='Process only specific dataset (by name)')
return parser.parse_args()
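# Example invocations (the script name here is assumed):
#   python evaluate.py --max-samples 50
#   python evaluate.py --dataset timit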
def cast_audio_column_safely(dataset):
"""
Ensure the dataset's 'audio' column is set to non-decoding Audio.
"""
try:
dataset = dataset.cast_column("audio", Audio(decode=False))
except Exception:
pass
return dataset
def prepare_dataset_for_evaluation(dataset, dataset_config, max_samples):
"""
Normalize, deduplicate, and filter dataset examples for evaluation.
Handles both streaming and non-streaming datasets.
Returns a finalized small dataset suitable for benchmarking.
"""
field = dataset_config["field"]
use_streaming = dataset_config.get("use_streaming", False)
if use_streaming:
print("Processing streaming dataset...")
valid_samples = []
        # max_samples may be None (no limit configured), so guard the min()
        streaming_limit = FINAL_SIZE if max_samples is None else min(max_samples, FINAL_SIZE)
for example in dataset:
if field == "text":
phonetic_text = text_to_phoneme(example[field])
example = {**example, "phonetic": phonetic_text}
current_field = "phonetic"
else:
current_field = field
if current_field in example:
phoneme_tokens = example[current_field].split()
if len(phoneme_tokens) >= 10:
valid_samples.append(example)
if len(valid_samples) >= streaming_limit:
break
print(f"Found {len(valid_samples)} valid samples")
if len(valid_samples) == 0:
print("No valid samples found, skipping dataset")
return None
dataset_final = Dataset.from_list(valid_samples)
return dataset_final
else:
if field == "text":
dataset = dataset.map(lambda x: {"phonetic": text_to_phoneme(x[field])})
field = "phonetic"
        unique_texts = dataset.unique(field)
        print(f"Unique phonetic strings ({dataset_config['name']}): {len(unique_texts)}")
        # Deduplicate by keeping only the first occurrence of each phonetic
        # string; the old `x[field] in unique_texts` test kept every row, since
        # each value is by definition among the unique values
        seen = set()
        dataset_unique = dataset.filter(lambda x: not (x[field] in seen or seen.add(x[field])))
def is_valid(example):
phoneme_tokens = example[field].split()
return len(phoneme_tokens) >= 10
dataset_filtered = dataset_unique.filter(is_valid)
final_size = min(FINAL_SIZE, len(dataset_filtered))
dataset_final = dataset_filtered.shuffle(seed=42).select(range(final_size))
return dataset_final
def evaluate_dataset(dataset_final):
"""
Run benchmarking on a capped subset of the dataset and return both
the full per-example results and the aggregated stats per model.
"""
benchmark_size = min(FINAL_SIZE, len(dataset_final))
return benchmark_dataset(dataset_final.select(range(benchmark_size)))
def update_aggregates(per_model_results, avg_stats, dataset_name):
"""
Update the aggregate dictionary per model with results from one dataset.
"""
dataset_key = dataset_name.split("/")[-1]
for _, row in avg_stats.iterrows():
model_name = str(row["model"]).replace(" ", "-")
per = float(row["Average PER"]) if row["Average PER"] is not None else None
avg_dur = float(row["Average Duration (s)"]) if row["Average Duration (s)"] is not None else None
if model_name not in per_model_results:
per_model_results[model_name] = {}
per_model_results[model_name][dataset_key] = {"per": per, "avg_duration": avg_dur}
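# Illustrative aggregate shape after one dataset (numbers invented):
# {"HuBERT-Base": {"timit": {"per": 0.24, "avg_duration": 0.61}}}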
def save_leaderboard_results(per_model_results, results_dir="eval-results"):
"""
Persist one JSON file per model for the leaderboard app to consume.
"""
os.makedirs(results_dir, exist_ok=True)
timestamp = int(time.time())
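    # One JSON file per model, e.g. eval-results/results_1700000000_Whisper.json
    # (timestamp value shown is illustrative)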
for model_name, task_results in per_model_results.items():
org_model = f"{model_name}"
payload = {
"config": {
"model_name": org_model,
"model_dtype": "float32",
"model_sha": ""
},
"results": task_results
}
out_path = os.path.join(results_dir, f"results_{timestamp}_{model_name}.json")
with open(out_path, "w", encoding="utf-8") as f:
json.dump(payload, f, ensure_ascii=False, indent=2)
print(f"Saved leaderboard result: {out_path}")
def process_single_dataset(dataset_config, args, per_model_results):
"""
Load, normalize, evaluate a single dataset and update aggregates.
"""
if args.dataset and args.dataset not in dataset_config["name"]:
return
max_samples = args.max_samples if args.max_samples is not None else dataset_config.get("max_samples")
use_streaming = dataset_config.get("use_streaming", False)
dataset = load_dataset_with_limits(
dataset_config,
max_samples=max_samples,
use_streaming=use_streaming
)
if dataset is None:
return
dataset = cast_audio_column_safely(dataset)
dataset_final = prepare_dataset_for_evaluation(dataset, dataset_config, max_samples)
if dataset_final is None:
return
print(dataset_final)
print("Final size:", len(dataset_final))
    _, avg_stats = evaluate_dataset(dataset_final)  # per-example frame is unused here
    print(f"Average statistics per model ({dataset_config['name']}):")
print(avg_stats)
update_aggregates(per_model_results, avg_stats, dataset_config["name"])
def main():
args = parse_cli_args()
per_model_results = {}
for dataset_config in DATASETS:
process_single_dataset(dataset_config, args, per_model_results)
save_leaderboard_results(per_model_results)
if __name__ == "__main__":
main()