from typing import Dict, Any
import base64
import io
import os
import re
import tempfile
import traceback

import numpy as np
import soundfile as sf
import torchaudio
from huggingface_hub import hf_hub_download
from num2words import num2words

from f5_tts.infer.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
    infer_process,
    remove_silence_for_generated_wav,
)
from f5_tts.model import DiT
| |
|
class EndpointHandler:
    """Inference-endpoint handler for Spanish F5-TTS speech synthesis.

    Loads the F5-Spanish DiT checkpoint and a vocoder once at construction
    time, then serves ``__call__`` requests that turn a base64 reference
    audio clip plus target text into generated speech, returned as a
    base64-encoded WAV.
    """

    def __init__(self, path=""):
        """Download model weights and load vocoder + EMA model into memory.

        Args:
            path: Unused; kept for the endpoint-handler calling convention.
        """
        self.vocoder = load_vocoder()
        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

        model_path = hf_hub_download(
            repo_id="jpgallegoar/F5-Spanish",
            filename="model_1200000.safetensors",
        )

        self.ema_model = load_model(DiT, model_cfg, model_path)

    def traducir_numero_a_texto(self, texto):
        """Replace every standalone integer in *texto* with its Spanish words.

        Digits glued to letters (e.g. ``"abc123"``) are first separated with
        a space so each number is matched as its own ``\\b\\d+\\b`` token.
        """
        texto_separado = re.sub(r'([A-Za-z])(\d)', r'\1 \2', texto)
        texto_separado = re.sub(r'(\d)([A-Za-z])', r'\1 \2', texto_separado)

        def reemplazar_numero(match):
            return num2words(int(match.group()), lang='es')

        return re.sub(r'\b\d+\b', reemplazar_numero, texto_separado)

    @staticmethod
    def _remove_file(path):
        """Best-effort deletion of a temp file; never raises."""
        try:
            os.remove(path)
        except OSError:
            pass

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one synthesis request.

        Args:
            data: Request payload. Required keys: ``ref_audio`` (base64 WAV)
                and ``gen_text``. Optional: ``ref_text`` (default ``""``),
                ``remove_silence`` (default ``True``), ``cross_fade_duration``
                (default ``0.15``), ``speed`` (default ``1.0``).

        Returns:
            ``{"success": True, "audio_base64": ...}`` on success, or
            ``{"success": False, "error": ...}`` on any failure (this method
            never raises; errors are reported in the response dict).
        """
        temp_audio_path = None
        try:
            ref_audio_base64 = data.get("ref_audio")
            if not ref_audio_base64:
                return {
                    "success": False,
                    "error": "Missing required field: 'ref_audio'"
                }

            # Decode the reference clip to a temp WAV file on disk, since the
            # preprocessing helper takes a file path.
            try:
                audio_bytes = base64.b64decode(ref_audio_base64)
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
                    temp_audio_file.write(audio_bytes)
                    temp_audio_path = temp_audio_file.name
            except Exception as e:
                return {
                    "success": False,
                    "error": f"Invalid audio data: {type(e).__name__}: {str(e)}"
                }

            ref_text = data.get("ref_text", "")
            gen_text = data.get("gen_text", "")
            if not gen_text:
                return {
                    "success": False,
                    "error": "Missing required field: 'gen_text'"
                }

            remove_silence = data.get("remove_silence", True)
            cross_fade_duration = data.get("cross_fade_duration", 0.15)
            speed = data.get("speed", 1.0)

            ref_audio, ref_text = preprocess_ref_audio_text(temp_audio_path, ref_text, show_info=print)

            # Normalize the generation text: pad with a leading space and a
            # trailing sentence end, lowercase, and spell out numbers in
            # Spanish before feeding it to the model.
            if not gen_text.startswith(" "):
                gen_text = " " + gen_text
            if not gen_text.endswith(". "):
                gen_text += ". "
            gen_text = self.traducir_numero_a_texto(gen_text.lower())

            final_wave, final_sample_rate, _ = infer_process(
                ref_audio,
                ref_text,
                gen_text,
                self.ema_model,
                self.vocoder,
                cross_fade_duration=cross_fade_duration,
                speed=speed,
                show_info=print,
            )

            if remove_silence:
                # Silence trimming operates on a file, so round-trip the wave
                # through a scratch WAV; clean it up even if trimming fails.
                silence_tmp = None
                try:
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
                        silence_tmp = f.name
                        sf.write(f.name, final_wave, final_sample_rate)
                        remove_silence_for_generated_wav(f.name)
                        final_wave, _ = torchaudio.load(f.name)
                    final_wave = final_wave.squeeze().cpu().numpy()
                finally:
                    if silence_tmp:
                        self._remove_file(silence_tmp)

            with io.BytesIO() as buffer:
                sf.write(buffer, final_wave, final_sample_rate, format="WAV")
                buffer.seek(0)
                encoded_audio = base64.b64encode(buffer.read()).decode("utf-8")

            return {
                "success": True,
                "audio_base64": encoded_audio
            }

        except Exception as e:
            # Top-level boundary: report the failure in the response instead
            # of raising, but keep the full traceback in the endpoint logs.
            print("==== Exception Traceback ====")
            traceback.print_exc()
            print("==== End Traceback ====")
            return {
                "success": False,
                "error": f"{type(e).__name__}: {str(e)}"
            }
        finally:
            # Fix: the reference-audio temp file was created with
            # delete=False and previously never removed (leaked per request).
            if temp_audio_path:
                self._remove_file(temp_audio_path)