QingyuLiu1 committed
Commit ad3a05c · 1 Parent(s): fbabfd7
Files changed (5):
  1. app.py +16 -7
  2. module_clf5.py +223 -0
  3. requirements.txt +2 -1
  4. utils_clf5_space.py +302 -0
  5. utils_clf5space.py +0 -0
app.py CHANGED
@@ -4,18 +4,23 @@ import spaces
 import torch
 from cached_path import cached_path
 from f5_tts.infer.utils_infer import (
-    infer_process,
     load_model,
     load_vocoder,
     preprocess_ref_audio_text,
 )
 from f5_tts.model import DiT
+from utils_clf5_space import (
+    load_model_sp,
+    infer_process_clf5,
+)


 vocoder = load_vocoder()

 # Cross-Lingual F5-TTS configuration
 model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
+model_cfg_sp = dict(dim=512, depth=6, heads=8, ff_mult=4)
+mel_spec_kwargs = dict(target_sample_rate=24000, n_mel_channels=100, hop_length=256, win_length=1024, n_fft=1024, mel_spec_type="vocos")
 vocab_path = str(cached_path("hf://QingyuLiu1/Cross-Lingual_F5-TTS/vocab.txt"))  # Using the same vocab as base model

 # Load Cross-Lingual F5-TTS model
@@ -26,6 +31,11 @@ cross_lingual_model = load_model(
     vocab_file=vocab_path,
 )

+speakingrate_model = load_model_sp(
+    model_cfg_sp,
+    str(cached_path("hf://QingyuLiu1/Cross-Lingual_F5-TTS/syllables_gce_20000.safetensors")),
+    mel_spec_kwargs,
+)

 @spaces.GPU
 def infer(
@@ -35,7 +45,7 @@ def infer(
     show_info=gr.Info,
 ):
     # Fixed reference text
-    ref_text = "Hello World! I'm Qingyu Liu."
+    ref_text = "Useless here."

     if not ref_audio_orig or not gen_text.strip():
         gr.Warning("Please ensure [Reference Audio] and [Text to Generate] are both provided.")
@@ -47,11 +57,11 @@ def infer(
     torch.manual_seed(seed)
     used_seed = seed

-    ref_audio, ref_text_processed = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
-
-    final_wave, final_sample_rate, _ = infer_process(
+    ref_audio, _ = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
+
+    final_wave, final_sample_rate, _ = infer_process_clf5(
+        speakingrate_model,
         ref_audio,
-        ref_text_processed,
         gen_text,
         cross_lingual_model,
         vocoder,
@@ -66,7 +76,6 @@ with gr.Blocks() as app_basic_tts:
     with gr.Row():
        with gr.Column():
            ref_wav_input = gr.Audio(label="Reference Audio", type="filepath")
-           # Removed ref_txt_input - using fixed text instead
            gen_txt_input = gr.Textbox(label="Text to Generate")
            generate_btn = gr.Button("Synthesize", variant="primary")
     with gr.Row():
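Summary of the app.py changes: the fixed reference transcript and infer_process are dropped; instead a syllable-rate predictor is loaded once at startup and infer_process_clf5 derives the output duration from it. Below is a minimal sketch of the new call path (illustrative only, not part of the diff; cross_lingual_model and vocoder are the objects already loaded in app.py, and the file path and texts are placeholders):

from cached_path import cached_path
from f5_tts.infer.utils_infer import preprocess_ref_audio_text
from utils_clf5_space import load_model_sp, infer_process_clf5

# New speaking-rate predictor (syllables per second), loaded once at startup.
model_cfg_sp = dict(dim=512, depth=6, heads=8, ff_mult=4)
mel_spec_kwargs = dict(target_sample_rate=24000, n_mel_channels=100, hop_length=256,
                       win_length=1024, n_fft=1024, mel_spec_type="vocos")
speakingrate_model = load_model_sp(
    model_cfg_sp,
    str(cached_path("hf://QingyuLiu1/Cross-Lingual_F5-TTS/syllables_gce_20000.safetensors")),
    mel_spec_kwargs,
)

# The reference transcript is no longer used; a placeholder string satisfies the signature
# and the processed text is discarded.
ref_audio, _ = preprocess_ref_audio_text("path/to/reference.wav", "Useless here.")
final_wave, final_sample_rate, _ = infer_process_clf5(
    speakingrate_model,
    ref_audio,
    "Text to generate.",
    cross_lingual_model,  # loaded via load_model(DiT, model_cfg, ...) as in app.py
    vocoder,              # loaded via load_vocoder() as in app.py
)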
module_clf5.py ADDED
@@ -0,0 +1,223 @@
+from __future__ import annotations
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from typing import Literal
+
+from f5_tts.model.modules import MelSpec
+from f5_tts.model.utils import (
+    default,
+    exists,
+    lens_to_mask,
+)
+
+from x_transformers.x_transformers import RotaryEmbedding
+from f5_tts.model.modules import (
+    ConvPositionEmbedding,
+    Attention,
+    AttnProcessor,
+    FeedForward
+)
+
+class SpeedPredictorLayer(nn.Module):
+    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, qk_norm=None, pe_attn_head=None):
+        super().__init__()
+
+        self.attn = Attention(
+            processor=AttnProcessor(pe_attn_head=pe_attn_head),
+            dim=dim,
+            heads=heads,
+            dim_head=dim_head,
+            dropout=dropout,
+            qk_norm=qk_norm,
+        )
+
+        self.ln1 = nn.LayerNorm(dim, elementwise_affine=True, eps=1e-6)
+        self.ln2 = nn.LayerNorm(dim, elementwise_affine=True, eps=1e-6)
+        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
+
+    def forward(self, x, mask=None, rope=None):  # x: mel-frame features
+        # MHA sublayer (pre-norm)
+        x_norm_attn = self.ln1(x)
+        attn_output = self.attn(x=x_norm_attn, mask=mask, rope=rope)
+        x = x + attn_output
+
+        # FFN sublayer (pre-norm)
+        x_norm_ffn = self.ln2(x)
+        ffn_output = self.ff(x=x_norm_ffn)
+        output = x + ffn_output
+        return output
+
+class GaussianCrossEntropyLoss(nn.Module):
+    def __init__(self, num_classes, sigma_factor=2.0):
+        super().__init__()
+        self.num_classes = num_classes
+        self.sigma_factor = sigma_factor
+
+    def forward(self, y_pred, y_true, device):  # y_pred.shape: [b, num_classes], y_true.shape: [b]
+        # ground-truth class centers
+        centers = y_true.unsqueeze(-1)  # shape: [b, 1]
+
+        # position indices
+        positions = torch.arange(self.num_classes, device=device).float()  # shape: [num_classes]
+        positions = positions.expand(y_true.shape[0], -1)  # shape: [b, num_classes]
+
+        # sigma
+        sigma = self.sigma_factor * torch.ones_like(y_true, device=device).float()
+
+        # Gaussian soft targets
+        diff = positions - centers  # (c - gt), shape: [b, num_classes]
+        y_true_soft = torch.exp(-(diff.pow(2) / (2 * sigma.pow(2).unsqueeze(-1))))  # shape: [b, num_classes]
+
+        loss = -(y_true_soft * F.log_softmax(y_pred, dim=-1)).sum(dim=-1).mean()
+
+        return loss
+
+class SpeedTransformer(nn.Module):
+    def __init__(
+        self,
+        dim,
+        depth=6,
+        heads=8,
+        dropout=0.1,
+        ff_mult=4,
+        qk_norm=None,
+        pe_attn_head=None,
+        mel_dim=100,
+        num_classes=32,
+    ):
+        super().__init__()
+        self.dim_head = dim // heads
+        self.num_classes = num_classes
+        self.mel_proj = nn.Linear(mel_dim, dim)
+        self.conv_layer = ConvPositionEmbedding(dim=dim)
+        self.rotary_embed = RotaryEmbedding(self.dim_head)
+        self.transformer_blocks = nn.ModuleList([
+            SpeedPredictorLayer(
+                dim=dim,
+                heads=heads,
+                dim_head=self.dim_head,
+                ff_mult=ff_mult,
+                dropout=dropout,
+                qk_norm=qk_norm,
+                pe_attn_head=pe_attn_head
+            ) for _ in range(depth)
+        ])
+        self.pool = nn.Sequential(
+            nn.Linear(dim, dim),
+            nn.Tanh(),
+            nn.Linear(dim, 1)
+        )
+        self.classifier = nn.Sequential(
+            nn.LayerNorm(dim),
+            nn.Linear(dim, dim),
+            nn.GELU(),  # nn.ReLU()
+            nn.Linear(dim, num_classes)
+        )
+        # self.initialize_weights()
+
+    # def initialize_weights(self):
+
+    def forward(self, x, lens):  # x.shape = [b, seq_len, d_mel]
+        seq_len = x.shape[1]
+        mask = lens_to_mask(lens, length=seq_len)  # shape = [b, seq_len]
+
+        x = self.mel_proj(x)  # shape = [b, seq_len, h]
+        x = self.conv_layer(x, mask)  # shape = [b, seq_len, h]
+
+        rope = self.rotary_embed.forward_from_seq_len(seq_len)
+        for block in self.transformer_blocks:
+            x = block(x, mask=mask, rope=rope)  # shape = [b, seq_len, h]
+
+        # sequence pooling
+        weights = self.pool(x)  # shape = [b, seq_len, 1]
+        # set weights at padding positions to -inf before the softmax
+        weights.masked_fill_(~mask.unsqueeze(-1), -torch.finfo(weights.dtype).max)
+        weights = F.softmax(weights, dim=1)  # shape = [b, seq_len, 1]
+        x = (x * weights).sum(dim=1)  # shape = [b, h]
+
+        output = self.classifier(x)  # shape: [b, num_classes]
+        return output
+
+class SpeedMapper:
+    def __init__(
+        self,
+        num_classes: Literal[32, 72],
+        delta: float = 0.25
+    ):
+        self.num_classes = num_classes
+        self.delta = delta
+
+        self.max_speed = float(num_classes) * delta
+
+        self.speed_values = torch.arange(0.25, self.max_speed + self.delta, self.delta)
+        assert len(self.speed_values) == num_classes, f"Generated {len(self.speed_values)} classes, expected {num_classes}"
+
+    def label_to_speed(self, label: torch.Tensor) -> torch.Tensor:
+        return self.speed_values.to(label.device)[label]  # label * 0.25 + 0.25
+
+class SpeedPredictor(nn.Module):
+    def __init__(
+        self,
+        speed_type: Literal["phonemes", "syllables", "words"] = "phonemes",
+        mel_spec_kwargs: dict = dict(),
+        arch_kwargs: dict | None = None,
+        sigma_factor: int = 2,
+        mel_spec_module: nn.Module | None = None,
+        num_channels: int = 100,
+    ):
+        super().__init__()
+
+        num_classes_map = {
+            "phonemes": 72,
+            "syllables": 32,
+            "words": 32
+        }
+        self.num_classes = num_classes_map[speed_type]
+
+        # mel spec
+        self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
+        num_channels = default(num_channels, self.mel_spec.n_mel_channels)
+        self.num_channels = num_channels
+        self.speed_transformer = SpeedTransformer(**arch_kwargs, num_classes=self.num_classes)
+        self.gce = GaussianCrossEntropyLoss(num_classes=self.num_classes, sigma_factor=sigma_factor)
+        self.speed_mapper = SpeedMapper(self.num_classes)
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    @torch.no_grad()
+    def predict_speed(self, audio: torch.Tensor, lens: torch.Tensor | None = None):
+        # raw wave
+        if audio.ndim == 2:
+            audio = self.mel_spec(audio).permute(0, 2, 1)
+
+        batch, seq_len, device = *audio.shape[:2], audio.device
+
+        if not exists(lens):
+            lens = torch.full((batch,), seq_len, device=device, dtype=torch.long)
+
+        logits = self.speed_transformer(audio, lens)
+        probs = F.softmax(logits, dim=-1)
+
+        pred_class = torch.argmax(probs, dim=-1)
+        pred_speed = self.speed_mapper.label_to_speed(pred_class)
+        return pred_speed
+
+    def forward(
+        self,
+        inp: float["b n d"] | float["b nw"],  # mel or raw wave  # noqa: F722
+        speed: float["b"],  # speed ground truth
+        lens: int["b"] | None = None,  # noqa: F821
+    ):
+        if inp.ndim == 2:
+            inp = self.mel_spec(inp)
+            inp = inp.permute(0, 2, 1)
+            assert inp.shape[-1] == self.num_channels
+        device = self.device
+        pred = self.speed_transformer(inp, lens)
+        loss = self.gce(pred, speed, device)
+
+        return loss
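For reference, the predictor added above is a classifier over speaking-rate bins: 32 classes spaced 0.25 apart (0.25 to 8.0 syllables per second; 72 classes for phonemes), trained with a Gaussian-smoothed cross entropy so near-miss bins are penalised less than distant ones. A small illustrative sketch using these classes (the logits and labels below are dummy values, not from the model):

import torch
from module_clf5 import GaussianCrossEntropyLoss, SpeedMapper

mapper = SpeedMapper(num_classes=32)            # speeds 0.25, 0.50, ..., 8.00 units/s
print(mapper.label_to_speed(torch.tensor(8)))   # class 8 -> 2.25 (i.e. label * 0.25 + 0.25)

# The loss compares logits against a Gaussian "soft" one-hot centered on the true class.
gce = GaussianCrossEntropyLoss(num_classes=32, sigma_factor=2.0)
logits = torch.randn(4, 32)                     # dummy predictions for a batch of 4 clips
labels = torch.tensor([3, 8, 8, 15])            # dummy true speed classes
loss = gce(logits, labels, device=logits.device)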
requirements.txt CHANGED
@@ -1 +1,2 @@
-f5-tts
+f5-tts
+pyphen
utils_clf5_space.py ADDED
@@ -0,0 +1,302 @@
+import torch
+import torchaudio
+from f5_tts.infer.utils_infer import (
+    load_checkpoint,
+    chunk_text,
+    convert_char_to_pinyin,
+)
+from module_clf5 import SpeedPredictor
+import tqdm
+from concurrent.futures import ThreadPoolExecutor
+import numpy as np
+
+import pyphen
+import re
+
+def count(text, speed_type="syllables"):
+    def count_syllables(text):
+        # initialize the pyphen dictionary
+        dic = pyphen.Pyphen(lang='en_US')
+
+        total_syllables = 0
+
+        # 1. define the regular expression
+        pattern = re.compile(r"[a-zA-Z']+|[\u4e00-\u9fff]")
+
+        # 2. find all matching tokens (English words and Chinese characters)
+        tokens = pattern.findall(text)
+
+        # 3. iterate over the tokens and count syllables
+        for token in tokens:
+            # check whether the token is a Chinese character
+            if '\u4e00' <= token <= '\u9fff':
+                # a single Chinese character counts as one syllable
+                total_syllables += 1
+            else:
+                # English word handling
+                try:
+                    # split into syllables with pyphen
+                    syllables = dic.inserted(token.lower()).split("-")
+                    total_syllables += len(syllables)
+                except Exception:
+                    # on any error, fall back to one syllable
+                    total_syllables += 1
+
+        return total_syllables
+
+    count_functions = {
+        "syllables": count_syllables,
+    }
+
+    if speed_type not in count_functions:
+        raise ValueError(f"Unknown speed_type: {speed_type}")
+
+    return count_functions[speed_type](text)
+
+device = (
+    "cuda"
+    if torch.cuda.is_available()
+    else "xpu"
+    if torch.xpu.is_available()
+    else "mps"
+    if torch.backends.mps.is_available()
+    else "cpu"
+)
+
+# -----------------------------------------
+
+target_sample_rate = 24000
+n_mel_channels = 100
+hop_length = 256
+win_length = 1024
+n_fft = 1024
+mel_spec_type = "vocos"
+target_rms = 0.1
+cross_fade_duration = 0.15
+ode_method = "euler"
+nfe_step = 32  # 16, 32
+cfg_strength = 2.0
+sway_sampling_coef = -1.0
+speed = 1.0
+fix_duration = None
+
+# -----------------------------------------
+
+def infer_process_clf5(
+    speakingrate_model,
+    ref_audio,
+    gen_text,
+    model_obj,
+    vocoder,
+    mel_spec_type=mel_spec_type,
+    show_info=print,
+    progress=tqdm,
+    target_rms=target_rms,
+    cross_fade_duration=cross_fade_duration,
+    nfe_step=nfe_step,
+    cfg_strength=cfg_strength,
+    sway_sampling_coef=sway_sampling_coef,
+    speed=speed,
+    fix_duration=fix_duration,
+    device=device,
+):
+    # Split the input text into batches
+    audio, sr = torchaudio.load(ref_audio)
+    # max_chars = int(len(ref_text.encode("utf-8")) / (audio.shape[-1] / sr) * (22 - audio.shape[-1] / sr))
+    gen_text_batches = chunk_text(gen_text)
+    for i, gen_text in enumerate(gen_text_batches):
+        print(f"gen_text {i}", gen_text)
+    print("\n")
+
+    show_info(f"Generating audio in {len(gen_text_batches)} batches...")
+    return next(
+        infer_batch_process_clf5(
+            (audio, sr),
+            speakingrate_model,
+            gen_text_batches,
+            model_obj,
+            vocoder,
+            mel_spec_type=mel_spec_type,
+            progress=progress,
+            target_rms=target_rms,
+            cross_fade_duration=cross_fade_duration,
+            nfe_step=nfe_step,
+            cfg_strength=cfg_strength,
+            sway_sampling_coef=sway_sampling_coef,
+            speed=speed,
+            fix_duration=fix_duration,
+            device=device,
+        )
+    )
+
+
+def infer_batch_process_clf5(
+    ref_audio,
+    speakingrate_model,
+    gen_text_batches,
+    model_obj,
+    vocoder,
+    mel_spec_type="vocos",
+    progress=tqdm,
+    target_rms=0.1,
+    cross_fade_duration=0.15,
+    nfe_step=32,
+    cfg_strength=2.0,
+    sway_sampling_coef=-1,
+    speed=1,
+    fix_duration=None,
+    device=None,
+    streaming=False,
+    chunk_size=2048,
+):
+    audio, sr = ref_audio
+    if audio.shape[0] > 1:
+        audio = torch.mean(audio, dim=0, keepdim=True)
+
+    rms = torch.sqrt(torch.mean(torch.square(audio)))
+    if rms < target_rms:
+        audio = audio * target_rms / rms
+    if sr != target_sample_rate:
+        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
+        audio = resampler(audio)
+    audio = audio.to(device)
+    pred_speed = speakingrate_model.predict_speed(
+        audio=audio
+    )
+
+    generated_waves = []
+    spectrograms = []
+
+    def process_batch(gen_text):
+        local_speed = speed
+        if len(gen_text.encode("utf-8")) < 10:
+            local_speed = 0.3
+
+        # Prepare the text
+        text_list = [gen_text]
+        final_text_list = convert_char_to_pinyin(text_list)
+
+        ref_audio_len = audio.shape[-1] // hop_length
+        if fix_duration is not None:
+            duration = int(fix_duration * target_sample_rate / hop_length)
+        else:
+            # Calculate duration
+            # ref_text_len = len(ref_text.encode("utf-8"))
+            # gen_text_len = len(gen_text.encode("utf-8"))
+            # duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / local_speed)
+
+            gt_num_unit = count(gen_text)
+            pred_duration = max(gt_num_unit / pred_speed.item(), 1)
+            gen_audio_len = int((pred_duration * target_sample_rate) / hop_length)
+            duration = ref_audio_len + gen_audio_len
+
+        # inference
+        with torch.inference_mode():
+            generated, _ = model_obj.sample(
+                cond=audio,
+                text=final_text_list,
+                duration=duration,
+                steps=nfe_step,
+                cfg_strength=cfg_strength,
+                sway_sampling_coef=sway_sampling_coef,
+            )
+
+            generated = generated.to(torch.float32)
+            generated = generated[:, ref_audio_len:, :]
+            generated_mel_spec = generated.permute(0, 2, 1)
+            if mel_spec_type == "vocos":
+                generated_wave = vocoder.decode(generated_mel_spec)
+            elif mel_spec_type == "bigvgan":
+                generated_wave = vocoder(generated_mel_spec)
+            if rms < target_rms:
+                generated_wave = generated_wave * rms / target_rms
+
+            # wav -> numpy
+            generated_wave = generated_wave.squeeze().cpu().numpy()
+
+            if streaming:
+                for j in range(0, len(generated_wave), chunk_size):
+                    yield generated_wave[j : j + chunk_size], target_sample_rate
+            else:
+                yield generated_wave, generated_mel_spec[0].cpu().numpy()
+
+    if streaming:
+        for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
+            for chunk in process_batch(gen_text):
+                yield chunk
+    else:
+        with ThreadPoolExecutor() as executor:
+            futures = [executor.submit(process_batch, gen_text) for gen_text in gen_text_batches]
+            for future in progress.tqdm(futures) if progress is not None else futures:
+                result = future.result()
+                if result:
+                    generated_wave, generated_mel_spec = next(result)
+                    generated_waves.append(generated_wave)
+                    spectrograms.append(generated_mel_spec)
+
+        if generated_waves:
+            if cross_fade_duration <= 0:
+                # Simply concatenate
+                final_wave = np.concatenate(generated_waves)
+            else:
+                # Combine all generated waves with cross-fading
+                final_wave = generated_waves[0]
+                for i in range(1, len(generated_waves)):
+                    prev_wave = final_wave
+                    next_wave = generated_waves[i]
+
+                    # Calculate cross-fade samples, ensuring it does not exceed wave lengths
+                    cross_fade_samples = int(cross_fade_duration * target_sample_rate)
+                    cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))
+
+                    if cross_fade_samples <= 0:
+                        # No overlap possible, concatenate
+                        final_wave = np.concatenate([prev_wave, next_wave])
+                        continue
+
+                    # Overlapping parts
+                    prev_overlap = prev_wave[-cross_fade_samples:]
+                    next_overlap = next_wave[:cross_fade_samples]
+
+                    # Fade out and fade in
+                    fade_out = np.linspace(1, 0, cross_fade_samples)
+                    fade_in = np.linspace(0, 1, cross_fade_samples)
+
+                    # Cross-faded overlap
+                    cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in
+
+                    # Combine
+                    new_wave = np.concatenate(
+                        [prev_wave[:-cross_fade_samples], cross_faded_overlap, next_wave[cross_fade_samples:]]
+                    )
+
+                    final_wave = new_wave
+
+            # Create a combined spectrogram
+            combined_spectrogram = np.concatenate(spectrograms, axis=1)
+
+            yield final_wave, target_sample_rate, combined_spectrogram
+
+        else:
+            yield None, target_sample_rate, None
+
+def load_model_sp(
+    model_cfg,
+    ckpt_path,
+    mel_spec_kwargs,
+    speed_type="syllables",
+    use_ema=True,
+    device=device,
+):
+    print("model : ", ckpt_path, "\n")
+
+    model_sp = SpeedPredictor(
+        speed_type=speed_type,
+        mel_spec_kwargs=mel_spec_kwargs,
+        arch_kwargs=model_cfg,
+    ).to(device)
+
+    dtype = torch.float32
+    model = load_checkpoint(model_sp, ckpt_path, device, dtype=dtype, use_ema=use_ema)
+
+    return model
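The duration heuristic above replaces F5-TTS's byte-length ratio: the syllable count of the text to generate, divided by the speaking rate predicted from the reference audio, gives the target speech duration in seconds, which is then converted to mel frames. A rough sketch of that computation, mirroring infer_batch_process_clf5 (pred_speed below is a placeholder standing in for SpeedPredictor.predict_speed output):

from utils_clf5_space import count

target_sample_rate, hop_length = 24000, 256

gen_text = "Hello world, this is a speaking rate test."
num_syllables = count(gen_text)       # pyphen-based count; each Chinese character counts as 1

pred_speed = 2.25                     # assumed predictor output, in syllables per second
pred_duration = max(num_syllables / pred_speed, 1)                    # seconds to generate
gen_audio_len = int(pred_duration * target_sample_rate / hop_length)  # mel frames for the new text
# model_obj.sample() then receives duration = ref_audio_len + gen_audio_len,
# where ref_audio_len = ref_audio.shape[-1] // hop_length.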
utils_clf5space.py DELETED
File without changes