Spaces:
Paused
Paused
| import base64 | |
| import io | |
| import os | |
| import tempfile | |
| import wave | |
| import torch | |
| import numpy as np | |
| from typing import List | |
| from pydantic import BaseModel | |
| import spaces | |
| from TTS.tts.configs.xtts_config import XttsConfig | |
| from TTS.tts.models.xtts import Xtts | |
| from TTS.utils.generic_utils import get_user_data_dir | |
| from TTS.utils.manage import ModelManager | |
# --- Module-level model setup: runs once at import time. ---

# Set before any model download; presumably this pre-accepts the Coqui
# license prompt so the download does not block on stdin (confirm against TTS).
os.environ["COQUI_TOS_AGREED"] = "1"

# os.cpu_count() may return None on some platforms; fall back to one thread
# instead of crashing in int(None).
torch.set_num_threads(int(os.environ.get("NUM_THREADS", os.cpu_count() or 1)))

# USE_CPU=0 (the default) selects CUDA; any other value selects the CPU.
device = torch.device("cuda" if os.environ.get("USE_CPU", "0") == "0" else "cpu")

# Compare device.type rather than the device object itself: a torch.device
# compared with a plain string does not evaluate equal, which made this guard
# unreachable.
if device.type == "cuda" and not torch.cuda.is_available():
    raise RuntimeError("CUDA device unavailable, please use Dockerfile.cpu instead.")

# Prefer a custom model directory when it contains a config.json; otherwise
# download the default XTTS v2 model into the TTS user data directory.
custom_model_path = os.environ.get("CUSTOM_MODEL_PATH", "/app/tts_models")

if os.path.isfile(os.path.join(custom_model_path, "config.json")):
    model_path = custom_model_path
    print("Loading custom model from", model_path, flush=True)
else:
    print("Loading default model", flush=True)
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    print("Downloading XTTS Model:", model_name, flush=True)
    ModelManager().download_model(model_name)
    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
    print("XTTS Model downloaded", flush=True)

print("Loading XTTS", flush=True)
config = XttsConfig()
config.load_json(os.path.join(model_path, "config.json"))
model = Xtts.init_from_config(config)
# NOTE(review): `device == "cuda"` compares a torch.device with a str, so this
# expression most likely always yields False and DeepSpeed is never enabled.
# Left unchanged on purpose: switching to `device.type == "cuda"` would turn
# DeepSpeed on and requires the package to be installed - confirm before changing.
model.load_checkpoint(config, checkpoint_dir=model_path, eval=True, use_deepspeed=True if device == "cuda" else False)
model.to(device)
print("XTTS Loaded.", flush=True)

print("Running XTTS Server ...", flush=True)
| # @app.post("/clone_speaker") | |
def predict_speaker(wav_file):
    """Compute conditioning inputs from reference audio file.

    Args:
        wav_file: Binary file-like object; its ``read()`` result is written
            to a temporary file that is handed to the model by path.

    Returns:
        dict with ``"gpt_cond_latent"`` and ``"speaker_embedding"``, each the
        squeezed, half-precision tensor converted to (nested) Python lists.
    """
    # mkstemp creates the file securely in the system temp directory; the
    # previous code relied on a private tempfile helper, wrote into the
    # current working directory and never removed the file.
    fd, temp_audio_name = tempfile.mkstemp()
    try:
        # Close the file before the model opens it by name so that every
        # byte is flushed to disk (buffered writes could otherwise leave the
        # tail of the audio unwritten while the model reads it).
        with os.fdopen(fd, "wb") as temp:
            temp.write(wav_file.read())

        with torch.inference_mode():
            gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
                temp_audio_name
            )
    finally:
        os.remove(temp_audio_name)

    return {
        "gpt_cond_latent": gpt_cond_latent.cpu().squeeze().half().tolist(),
        "speaker_embedding": speaker_embedding.cpu().squeeze().half().tolist(),
    }
def postprocess(wav):
    """Convert a waveform tensor (or list of tensors) to 16-bit PCM samples.

    A list input is concatenated along dim 0 first. The data is copied to a
    NumPy array on the CPU, given a leading axis, clipped to [-1, 1], scaled
    by 32767 and cast to ``np.int16``.
    """
    tensor = torch.cat(wav, dim=0) if isinstance(wav, list) else wav
    samples = tensor.clone().detach().cpu().numpy()

    # Prepend an axis; the slice bound is the length of the first dimension.
    length = int(samples.shape[0])
    samples = samples[None, :length]

    scaled = np.clip(samples, -1, 1) * 32767
    return scaled.astype(np.int16)
def encode_audio_common(
    frame_input, encode_base64=True, sample_rate=24000, sample_width=2, channels=1
):
    """Wrap raw PCM frames in a WAV container.

    Returns the WAV file as a base64 string (UTF-8 decoded) when
    ``encode_base64`` is true, otherwise as raw bytes.
    """
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as writer:
        writer.setframerate(sample_rate)
        writer.setsampwidth(sample_width)
        writer.setnchannels(channels)
        writer.writeframes(frame_input)

    # The in-memory buffer stays open after the wave writer is closed.
    wav_bytes = buffer.getvalue()
    if not encode_base64:
        return wav_bytes
    return base64.b64encode(wav_bytes).decode("utf-8")
# Request schema for streaming synthesis. The only consumer visible in this
# file is the commented-out streaming generator/endpoint further down.
class StreamingInputs(BaseModel):
    # Flat list of floats describing the speaker.
    speaker_embedding: List[float]
    # Nested list of floats with the GPT conditioning latent.
    gpt_cond_latent: List[List[float]]
    # Text to synthesize.
    text: str
    # Language identifier for the text.
    language: str
    # The commented-out generator emits a WAV header before the first chunk
    # when this flag is set.
    add_wav_header: bool = True
    # Declared as a string; the commented-out generator converts it with int().
    stream_chunk_size: str = "20"
| # | |
| #def predict_streaming_generator(parsed_input: dict = Body(...)): | |
| # speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1) | |
| # gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0) | |
| # text = parsed_input.text | |
| # language = parsed_input.language | |
| # | |
| # stream_chunk_size = int(parsed_input.stream_chunk_size) | |
| # add_wav_header = parsed_input.add_wav_header | |
| # | |
| # | |
| # chunks = model.inference_stream( | |
| # text, | |
| # language, | |
| # gpt_cond_latent, | |
| # speaker_embedding, | |
| # stream_chunk_size=stream_chunk_size, | |
| # enable_text_splitting=True | |
| # ) | |
| # | |
| # for i, chunk in enumerate(chunks): | |
| # chunk = postprocess(chunk) | |
| # if i == 0 and add_wav_header: | |
| # yield encode_audio_common(b"", encode_base64=False) | |
| # yield chunk.tobytes() | |
| # else: | |
| # yield chunk.tobytes() | |
| # | |
| # | |
| ## @app.post("/tts_stream") | |
| #def predict_streaming_endpoint(parsed_input: StreamingInputs): | |
| # return StreamingResponse( | |
| # predict_streaming_generator(parsed_input), | |
| # media_type="audio/wav", | |
| # ) | |
# Request schema consumed by predict_speech. Every field is required
# (no defaults are declared).
class TTSInputs(BaseModel):
    # Flat list of floats; predict_speech turns it into a tensor and adds a
    # leading and a trailing singleton dimension.
    speaker_embedding: List[float]
    # Nested list of floats; predict_speech reshapes it to (-1, 1024) and adds
    # a leading dimension.
    gpt_cond_latent: List[List[float]]
    # Text to synthesize.
    text: str
    # Language identifier forwarded to the model.
    language: str
    # Sampling / generation controls forwarded to model.inference.
    temperature: float
    speed: float
    top_k: int
    top_p: float
| # @app.post("/tts") | |
def predict_speech(parsed_input: TTSInputs):
    """Synthesize speech for a request and return it as a base64 WAV string.

    Args:
        parsed_input: Request carrying the speaker conditioning
            (``speaker_embedding``, ``gpt_cond_latent``), the ``text`` and
            ``language``, and the sampling controls ``temperature``,
            ``speed``, ``top_k`` and ``top_p``.

    Returns:
        The base64-encoded WAV produced by ``encode_audio_common`` from the
        post-processed model output.
    """
    # Rebuild the conditioning tensors from their list form.
    speaker_embedding = torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
    gpt_cond_latent = torch.tensor(parsed_input.gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)

    # Fixed generation settings that are not exposed through the request.
    length_penalty = 1.0
    repetition_penalty = 2.0

    # Sampling options are passed by keyword. Passed positionally, `speed`
    # landed in whatever parameter follows `top_p` in the model signature
    # (in Coqui's Xtts.inference that slot is believed to be `do_sample` -
    # verify against the installed TTS version), so the requested speed was
    # not applied.
    out = model.inference(
        parsed_input.text,
        parsed_input.language,
        gpt_cond_latent,
        speaker_embedding,
        temperature=parsed_input.temperature,
        length_penalty=length_penalty,
        repetition_penalty=repetition_penalty,
        top_k=parsed_input.top_k,
        top_p=parsed_input.top_p,
        speed=parsed_input.speed,
    )

    wav = postprocess(torch.tensor(out["wav"]))
    return encode_audio_common(wav.tobytes())
| # @app.get("/studio_speakers") | |
def get_speakers():
    """Return the conditioning data of every speaker known to the model.

    The result maps each speaker name to a dict with ``"speaker_embedding"``
    and ``"gpt_cond_latent"`` as (nested) lists. An empty dict is returned
    when the model has no ``speaker_manager`` or the manager has no
    ``speakers`` attribute.
    """
    if not hasattr(model, "speaker_manager"):
        return {}
    if not hasattr(model.speaker_manager, "speakers"):
        return {}

    voices = {}
    for name in model.speaker_manager.speakers.keys():
        entry = model.speaker_manager.speakers[name]
        voices[name] = {
            "speaker_embedding": entry["speaker_embedding"].cpu().squeeze().half().tolist(),
            "gpt_cond_latent": entry["gpt_cond_latent"].cpu().squeeze().half().tolist(),
        }
    return voices
| # @app.get("/languages") | |
def get_languages():
    """Return the ``languages`` entry of the loaded XTTS config."""
    supported = config.languages
    return supported