playmak3r committed
Commit b1c3b85 · 1 Parent(s): 67eb0b9

refactor: detach VibeVoiceDemo class into a separate module

Files changed (2)
  1. app.py +3 -633
  2. model.py +634 -0
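
After this change, app.py keeps only the CLI and Gradio wiring and imports the demo class from the new module. A minimal sketch of the resulting usage follows; the model path value and the launch() call are illustrative assumptions, not part of this commit (see parse_args() further down for the real CLI flags):

    # Sketch only: how app.py consumes the extracted class after this refactor.
    from model import VibeVoiceDemo

    demo_instance = VibeVoiceDemo(
        model_path="path/to/VibeVoice-checkpoint",  # assumption: a local dir or hub id
        device="cuda",
        inference_steps=5,
    )
    interface = create_demo_interface(demo_instance)  # defined below in app.py
    interface.launch()
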
app.py CHANGED
@@ -2,632 +2,19 @@
2
  VibeVoice Gradio Demo - High-Quality Dialogue Generation Interface with Streaming Support
3
  """
4
 
5
- from typing import Iterator, Optional, List, Dict, Any
6
- import argparse, os, time, traceback, json, sys, tempfile
7
- from pathlib import Path
8
- from datetime import datetime
9
- import threading
10
- import numpy as np
11
- import gradio as gr
12
- import librosa
13
- import soundfile as sf
14
  import torch
 
15
 
16
- from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
17
- from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
18
- from vibevoice.modular.lora_loading import load_lora_assets
19
- from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
20
- from vibevoice.modular.streamer import AudioStreamer
21
  from transformers.utils import logging
22
  from transformers import set_seed
 
23
 
24
  logging.set_verbosity_info()
25
  logger = logging.get_logger(__name__)
26
 
27
 
28
- class VibeVoiceDemo:
29
- def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5, adapter_path: Optional[str] = None):
30
- """Initialize the VibeVoice demo with model loading."""
31
- self.model_path = model_path
32
- self.device = device
33
- self.inference_steps = inference_steps
34
- self.adapter_path = adapter_path
35
- self.loaded_adapter_root: Optional[str] = None
36
- self.is_generating = False # Track generation state
37
- self.stop_generation = False # Flag to stop generation
38
- self.current_streamer = None # Track current audio streamer
39
- self.load_model()
40
- self.setup_voice_presets()
41
- self.load_example_scripts() # Load example scripts
42
-
43
- def load_model(self):
44
- """Load the VibeVoice model and processor."""
45
- print(f"Loading processor & model from {self.model_path}")
46
- self.loaded_adapter_root = None
47
- # Normalize potential 'mpx'
48
- if self.device.lower() == "mpx":
49
- print("Note: device 'mpx' detected, treating it as 'mps'.")
50
- self.device = "mps"
51
- if self.device == "mps" and not torch.backends.mps.is_available():
52
- print("Warning: MPS not available. Falling back to CPU.")
53
- self.device = "cpu"
54
- print(f"Using device: {self.device}")
55
- # Load processor
56
- self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
57
- # Decide dtype & attention
58
- if self.device == "mps":
59
- load_dtype = torch.float32
60
- attn_impl_primary = "sdpa"
61
- elif self.device == "cuda":
62
- load_dtype = torch.bfloat16
63
- attn_impl_primary = "flash_attention_2"
64
- else:
65
- load_dtype = torch.float32
66
- attn_impl_primary = "sdpa"
67
- print(f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
68
- # Load model
69
- try:
70
- if self.device == "mps":
71
- self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
72
- self.model_path,
73
- torch_dtype=load_dtype,
74
- attn_implementation=attn_impl_primary,
75
- device_map=None,
76
- )
77
- self.model.to("mps")
78
- elif self.device == "cuda":
79
- self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
80
- self.model_path,
81
- torch_dtype=load_dtype,
82
- device_map="cuda",
83
- attn_implementation=attn_impl_primary,
84
- )
85
- else:
86
- self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
87
- self.model_path,
88
- torch_dtype=load_dtype,
89
- device_map="cpu",
90
- attn_implementation=attn_impl_primary,
91
- )
92
- except Exception as e:
93
- if attn_impl_primary == 'flash_attention_2':
94
- print(f"[ERROR] : {type(e).__name__}: {e}")
95
- print(traceback.format_exc())
96
- fallback_attn = "sdpa"
97
- print(f"Falling back to attention implementation: {fallback_attn}")
98
- self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
99
- self.model_path,
100
- torch_dtype=load_dtype,
101
- device_map=(self.device if self.device in ("cuda", "cpu") else None),
102
- attn_implementation=fallback_attn,
103
- )
104
- if self.device == "mps":
105
- self.model.to("mps")
106
- else:
107
- raise e
108
- if self.adapter_path:
109
- print(f"Loading fine-tuned assets from {self.adapter_path}")
110
- report = load_lora_assets(self.model, self.adapter_path)
111
- loaded_components = [
112
- name for name, loaded in (
113
- ("language LoRA", report.language_model),
114
- ("diffusion head LoRA", report.diffusion_head_lora),
115
- ("diffusion head weights", report.diffusion_head_full),
116
- ("acoustic connector", report.acoustic_connector),
117
- ("semantic connector", report.semantic_connector),
118
- )
119
- if loaded
120
- ]
121
- if loaded_components:
122
- print(f"Loaded components: {', '.join(loaded_components)}")
123
- else:
124
- print("Warning: no adapter components were loaded; check the checkpoint path.")
125
- if report.adapter_root is not None:
126
- self.loaded_adapter_root = str(report.adapter_root)
127
- print(f"Adapter assets resolved to: {self.loaded_adapter_root}")
128
- else:
129
- self.loaded_adapter_root = self.adapter_path
130
 
131
- self.model.eval()
132
-
133
- # Use SDE solver by default
134
- self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
135
- self.model.model.noise_scheduler.config,
136
- algorithm_type='sde-dpmsolver++',
137
- beta_schedule='squaredcos_cap_v2'
138
- )
139
- self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
140
-
141
- if hasattr(self.model.model, 'language_model'):
142
- print(f"Language model attention: {self.model.model.language_model.config._attn_implementation}")
143
-
144
- def setup_voice_presets(self):
145
- """Setup voice presets by scanning the voices directory."""
146
- voices_dir = os.path.join(os.path.dirname(__file__), "voices")
147
-
148
- # Check if voices directory exists
149
- if not os.path.exists(voices_dir):
150
- print(f"Warning: Voices directory not found at {voices_dir}")
151
- self.voice_presets = {}
152
- self.available_voices = {}
153
- return
154
-
155
- # Scan for all WAV files in the voices directory
156
- self.voice_presets = {}
157
-
158
- # Get all .wav files in the voices directory
159
- wav_files = [f for f in os.listdir(voices_dir)
160
- if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')) and os.path.isfile(os.path.join(voices_dir, f))]
161
-
162
- # Create dictionary with filename (without extension) as key
163
- for wav_file in wav_files:
164
- # Remove .wav extension to get the name
165
- name = os.path.splitext(wav_file)[0]
166
- # Create full path
167
- full_path = os.path.join(voices_dir, wav_file)
168
- self.voice_presets[name] = full_path
169
-
170
- # Sort the voice presets alphabetically by name for better UI
171
- self.voice_presets = dict(sorted(self.voice_presets.items()))
172
-
173
- # Filter out voices that don't exist (this is now redundant but kept for safety)
174
- self.available_voices = {
175
- name: path for name, path in self.voice_presets.items()
176
- if os.path.exists(path)
177
- }
178
-
179
- if not self.available_voices:
180
- raise gr.Error("No voice presets found. Please add .wav files to the demo/voices directory.")
181
-
182
- print(f"Found {len(self.available_voices)} voice files in {voices_dir}")
183
- print(f"Available voices: {', '.join(self.available_voices.keys())}")
184
-
185
- def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
186
- """Read and preprocess audio file."""
187
- try:
188
- wav, sr = sf.read(audio_path)
189
- if len(wav.shape) > 1:
190
- wav = np.mean(wav, axis=1)
191
- if sr != target_sr:
192
- wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
193
- return wav
194
- except Exception as e:
195
- print(f"Error reading audio {audio_path}: {e}")
196
- return np.array([])
197
-
198
- def generate_podcast_streaming(self,
199
- num_speakers: int,
200
- script: str,
201
- speaker_1: str = None,
202
- speaker_2: str = None,
203
- speaker_3: str = None,
204
- speaker_4: str = None,
205
- cfg_scale: float = 1.3,
206
- disable_voice_cloning: bool = False) -> Iterator[tuple]:
207
- try:
208
-
209
- # Reset stop flag and set generating state
210
- self.stop_generation = False
211
- self.is_generating = True
212
-
213
- # Validate inputs
214
- if not script.strip():
215
- self.is_generating = False
216
- raise gr.Error("Error: Please provide a script.")
217
-
218
- # Defend against common mistake
219
- script = script.replace("’", "'")
220
-
221
- if num_speakers < 1 or num_speakers > 4:
222
- self.is_generating = False
223
- raise gr.Error("Error: Number of speakers must be between 1 and 4.")
224
-
225
- # Collect selected speakers
226
- selected_speakers = [speaker_1, speaker_2, speaker_3, speaker_4][:num_speakers]
227
-
228
- # Validate speaker selections
229
- for i, speaker in enumerate(selected_speakers):
230
- if not speaker or speaker not in self.available_voices:
231
- self.is_generating = False
232
- raise gr.Error(f"Error: Please select a valid speaker for Speaker {i+1}.")
233
-
234
- voice_cloning_enabled = not disable_voice_cloning
235
-
236
- # Build initial log
237
- log = f"πŸŽ™οΈ Generating podcast with {num_speakers} speakers\n"
238
- log += f"πŸ“Š Parameters: CFG Scale={cfg_scale}, Inference Steps={self.inference_steps}\n"
239
- log += f"🎭 Speakers: {', '.join(selected_speakers)}\n"
240
- log += f"πŸ”Š Voice cloning: {'Enabled' if voice_cloning_enabled else 'Disabled'}\n"
241
- if self.loaded_adapter_root:
242
- log += f"🧩 LoRA: {self.loaded_adapter_root}\n"
243
-
244
- # Check for stop signal
245
- if self.stop_generation:
246
- self.is_generating = False
247
- yield None, "πŸ›‘ Generation stopped by user", gr.update(visible=False)
248
- return
249
-
250
- # Load voice samples when voice cloning is enabled
251
- voice_samples = None
252
- if voice_cloning_enabled:
253
- voice_samples = []
254
- for speaker_name in selected_speakers:
255
- audio_path = self.available_voices[speaker_name]
256
- audio_data = self.read_audio(audio_path)
257
- if len(audio_data) == 0:
258
- self.is_generating = False
259
- raise gr.Error(f"Error: Failed to load audio for {speaker_name}")
260
- voice_samples.append(audio_data)
261
-
262
- # log += f"βœ… Loaded {len(voice_samples)} voice samples\n"
263
-
264
- # Check for stop signal
265
- if self.stop_generation:
266
- self.is_generating = False
267
- yield None, "πŸ›‘ Generation stopped by user", gr.update(visible=False)
268
- return
269
-
270
- # Parse script to assign speaker ID's
271
- lines = script.strip().split('\n')
272
- formatted_script_lines = []
273
-
274
- for line in lines:
275
- line = line.strip()
276
- if not line:
277
- continue
278
-
279
- # Check if line already has speaker format
280
- if line.startswith('Speaker ') and ':' in line:
281
- formatted_script_lines.append(line)
282
- else:
283
- # Auto-assign to speakers in rotation
284
- speaker_id = len(formatted_script_lines) % num_speakers
285
- formatted_script_lines.append(f"Speaker {speaker_id}: {line}")
286
-
287
- formatted_script = '\n'.join(formatted_script_lines)
288
- log += f"πŸ“ Formatted script with {len(formatted_script_lines)} turns\n\n"
289
- log += "πŸ”„ Processing with VibeVoice (streaming mode)...\n"
290
-
291
- # Check for stop signal before processing
292
- if self.stop_generation:
293
- self.is_generating = False
294
- yield None, "πŸ›‘ Generation stopped by user", gr.update(visible=False)
295
- return
296
-
297
- start_time = time.time()
298
-
299
- processor_kwargs = {
300
- "text": [formatted_script],
301
- "padding": True,
302
- "return_tensors": "pt",
303
- "return_attention_mask": True,
304
- }
305
- processor_kwargs["voice_samples"] = [voice_samples] if voice_samples is not None else None
306
-
307
- inputs = self.processor(**processor_kwargs)
308
- # Move tensors to device
309
- target_device = self.device if self.device in ("cuda", "mps") else "cpu"
310
- for k, v in inputs.items():
311
- if torch.is_tensor(v):
312
- inputs[k] = v.to(target_device)
313
-
314
- # Create audio streamer
315
- audio_streamer = AudioStreamer(
316
- batch_size=1,
317
- stop_signal=None,
318
- timeout=None
319
- )
320
-
321
- # Store current streamer for potential stopping
322
- self.current_streamer = audio_streamer
323
-
324
- # Start generation in a separate thread
325
- generation_thread = threading.Thread(
326
- target=self._generate_with_streamer,
327
- args=(inputs, cfg_scale, audio_streamer, voice_cloning_enabled)
328
- )
329
- generation_thread.start()
330
-
331
- # Wait for generation to actually start producing audio
332
- time.sleep(1) # Reduced from 3 to 1 second
333
-
334
- # Check for stop signal after thread start
335
- if self.stop_generation:
336
- audio_streamer.end()
337
- generation_thread.join(timeout=5.0) # Wait up to 5 seconds for thread to finish
338
- self.is_generating = False
339
- yield None, "πŸ›‘ Generation stopped by user", gr.update(visible=False)
340
- return
341
-
342
- # Collect audio chunks as they arrive
343
- sample_rate = 24000
344
- all_audio_chunks = [] # For final statistics
345
- pending_chunks = [] # Buffer for accumulating small chunks
346
- chunk_count = 0
347
- last_yield_time = time.time()
348
- min_yield_interval = 15 # Yield every 15 seconds
349
- min_chunk_size = sample_rate * 30 # At least 2 seconds of audio
350
-
351
- # Get the stream for the first (and only) sample
352
- audio_stream = audio_streamer.get_stream(0)
353
-
354
- has_yielded_audio = False
355
- has_received_chunks = False # Track if we received any chunks at all
356
-
357
- for audio_chunk in audio_stream:
358
- # Check for stop signal in the streaming loop
359
- if self.stop_generation:
360
- audio_streamer.end()
361
- break
362
-
363
- chunk_count += 1
364
- has_received_chunks = True # Mark that we received at least one chunk
365
-
366
- # Convert tensor to numpy
367
- if torch.is_tensor(audio_chunk):
368
- # Convert bfloat16 to float32 first, then to numpy
369
- if audio_chunk.dtype == torch.bfloat16:
370
- audio_chunk = audio_chunk.float()
371
- audio_np = audio_chunk.cpu().numpy().astype(np.float32)
372
- else:
373
- audio_np = np.array(audio_chunk, dtype=np.float32)
374
-
375
- # Ensure audio is 1D and properly normalized
376
- if len(audio_np.shape) > 1:
377
- audio_np = audio_np.squeeze()
378
-
379
- # Convert to 16-bit for Gradio
380
- audio_16bit = convert_to_16_bit_wav(audio_np)
381
-
382
- # Store for final statistics
383
- all_audio_chunks.append(audio_16bit)
384
-
385
- # Add to pending chunks buffer
386
- pending_chunks.append(audio_16bit)
387
-
388
- # Calculate pending audio size
389
- pending_audio_size = sum(len(chunk) for chunk in pending_chunks)
390
- current_time = time.time()
391
- time_since_last_yield = current_time - last_yield_time
392
-
393
- # Decide whether to yield
394
- should_yield = False
395
- if not has_yielded_audio and pending_audio_size >= min_chunk_size:
396
- # First yield: wait for minimum chunk size
397
- should_yield = True
398
- has_yielded_audio = True
399
- elif has_yielded_audio and (pending_audio_size >= min_chunk_size or time_since_last_yield >= min_yield_interval):
400
- # Subsequent yields: either enough audio or enough time has passed
401
- should_yield = True
402
-
403
- if should_yield and pending_chunks:
404
- # Concatenate and yield only the new audio chunks
405
- new_audio = np.concatenate(pending_chunks)
406
- new_duration = len(new_audio) / sample_rate
407
- total_duration = sum(len(chunk) for chunk in all_audio_chunks) / sample_rate
408
-
409
- log_update = log + f"🎡 Streaming: {total_duration:.1f}s generated (chunk {chunk_count})\n"
410
-
411
- # Yield streaming audio chunk and keep complete_audio as None during streaming
412
- yield (sample_rate, new_audio), None, log_update, gr.update(visible=True)
413
-
414
- # Clear pending chunks after yielding
415
- pending_chunks = []
416
- last_yield_time = current_time
417
-
418
- # Yield any remaining chunks
419
- if pending_chunks:
420
- final_new_audio = np.concatenate(pending_chunks)
421
- total_duration = sum(len(chunk) for chunk in all_audio_chunks) / sample_rate
422
- log_update = log + f"🎡 Streaming final chunk: {total_duration:.1f}s total\n"
423
- yield (sample_rate, final_new_audio), None, log_update, gr.update(visible=True)
424
- has_yielded_audio = True # Mark that we yielded audio
425
-
426
- # Wait for generation to complete (with timeout to prevent hanging)
427
- generation_thread.join(timeout=5.0) # Increased timeout to 5 seconds
428
-
429
- # If thread is still alive after timeout, force end
430
- if generation_thread.is_alive():
431
- print("Warning: Generation thread did not complete within timeout")
432
- audio_streamer.end()
433
- generation_thread.join(timeout=5.0)
434
-
435
- # Clean up
436
- self.current_streamer = None
437
- self.is_generating = False
438
-
439
- generation_time = time.time() - start_time
440
-
441
- # Check if stopped by user
442
- if self.stop_generation:
443
- yield None, None, "πŸ›‘ Generation stopped by user", gr.update(visible=False)
444
- return
445
-
446
- # Debug logging
447
- # print(f"Debug: has_received_chunks={has_received_chunks}, chunk_count={chunk_count}, all_audio_chunks length={len(all_audio_chunks)}")
448
-
449
- # Check if we received any chunks but didn't yield audio
450
- if has_received_chunks and not has_yielded_audio and all_audio_chunks:
451
- # We have chunks but didn't meet the yield criteria, yield them now
452
- complete_audio = np.concatenate(all_audio_chunks)
453
- final_duration = len(complete_audio) / sample_rate
454
-
455
- final_log = log + f"⏱️ Generation completed in {generation_time:.2f} seconds\n"
456
- final_log += f"🎡 Final audio duration: {final_duration:.2f} seconds\n"
457
- final_log += f"πŸ“Š Total chunks: {chunk_count}\n"
458
- final_log += "✨ Generation successful! Complete audio is ready.\n"
459
- final_log += "πŸ’‘ Not satisfied? You can regenerate or adjust the CFG scale for different results."
460
-
461
- # Yield the complete audio
462
- yield None, (sample_rate, complete_audio), final_log, gr.update(visible=False)
463
- return
464
-
465
- if not has_received_chunks:
466
- error_log = log + f"\n❌ Error: No audio chunks were received from the model. Generation time: {generation_time:.2f}s"
467
- yield None, None, error_log, gr.update(visible=False)
468
- return
469
-
470
- if not has_yielded_audio:
471
- error_log = log + f"\n❌ Error: Audio was generated but not streamed. Chunk count: {chunk_count}"
472
- yield None, None, error_log, gr.update(visible=False)
473
- return
474
-
475
- # Prepare the complete audio
476
- if all_audio_chunks:
477
- complete_audio = np.concatenate(all_audio_chunks)
478
- final_duration = len(complete_audio) / sample_rate
479
-
480
- final_log = log + f"⏱️ Generation completed in {generation_time:.2f} seconds\n"
481
- final_log += f"🎡 Final audio duration: {final_duration:.2f} seconds\n"
482
- final_log += f"πŸ“Š Total chunks: {chunk_count}\n"
483
- final_log += "✨ Generation successful! Complete audio is ready in the 'Complete Audio' tab.\n"
484
- final_log += "πŸ’‘ Not satisfied? You can regenerate or adjust the CFG scale for different results."
485
-
486
- # Final yield: Clear streaming audio and provide complete audio
487
- yield None, (sample_rate, complete_audio), final_log, gr.update(visible=False)
488
- else:
489
- final_log = log + "❌ No audio was generated."
490
- yield None, None, final_log, gr.update(visible=False)
491
-
492
- except gr.Error as e:
493
- # Handle Gradio-specific errors (like input validation)
494
- self.is_generating = False
495
- self.current_streamer = None
496
- error_msg = f"❌ Input Error: {str(e)}"
497
- print(error_msg)
498
- yield None, None, error_msg, gr.update(visible=False)
499
-
500
- except Exception as e:
501
- self.is_generating = False
502
- self.current_streamer = None
503
- error_msg = f"❌ An unexpected error occurred: {str(e)}"
504
- print(error_msg)
505
- import traceback
506
- traceback.print_exc()
507
- yield None, None, error_msg, gr.update(visible=False)
508
-
509
- def _generate_with_streamer(self, inputs, cfg_scale, audio_streamer, voice_cloning_enabled: bool):
510
- """Helper method to run generation with streamer in a separate thread."""
511
- try:
512
- # Check for stop signal before starting generation
513
- if self.stop_generation:
514
- audio_streamer.end()
515
- return
516
-
517
- # Define a stop check function that can be called from generate
518
- def check_stop_generation():
519
- return self.stop_generation
520
-
521
- outputs = self.model.generate(
522
- **inputs,
523
- max_new_tokens=None,
524
- cfg_scale=cfg_scale,
525
- tokenizer=self.processor.tokenizer,
526
- generation_config={
527
- 'do_sample': False,
528
- },
529
- audio_streamer=audio_streamer,
530
- stop_check_fn=check_stop_generation, # Pass the stop check function
531
- verbose=False, # Disable verbose in streaming mode
532
- refresh_negative=True,
533
- is_prefill=voice_cloning_enabled,
534
- )
535
-
536
- except Exception as e:
537
- print(f"Error in generation thread: {e}")
538
- traceback.print_exc()
539
- # Make sure to end the stream on error
540
- audio_streamer.end()
541
-
542
- def stop_audio_generation(self):
543
- """Stop the current audio generation process."""
544
- self.stop_generation = True
545
- if self.current_streamer is not None:
546
- try:
547
- self.current_streamer.end()
548
- except Exception as e:
549
- print(f"Error stopping streamer: {e}")
550
- print("πŸ›‘ Audio generation stop requested")
551
-
552
- def load_example_scripts(self):
553
- """Load example scripts from the text_examples directory."""
554
- examples_dir = os.path.join(os.path.dirname(__file__), "text_examples")
555
- self.example_scripts = []
556
-
557
- # Check if text_examples directory exists
558
- if not os.path.exists(examples_dir):
559
- print(f"Warning: text_examples directory not found at {examples_dir}")
560
- return
561
-
562
- # Get all .txt files in the text_examples directory
563
- txt_files = sorted([f for f in os.listdir(examples_dir)
564
- if f.lower().endswith('.txt') and os.path.isfile(os.path.join(examples_dir, f))])
565
-
566
- for txt_file in txt_files:
567
- file_path = os.path.join(examples_dir, txt_file)
568
-
569
- import re
570
- # Check if filename contains a time pattern like "45min", "90min", etc.
571
- time_pattern = re.search(r'(\d+)min', txt_file.lower())
572
- if time_pattern:
573
- minutes = int(time_pattern.group(1))
574
- if minutes > 15:
575
- print(f"Skipping {txt_file}: duration {minutes} minutes exceeds 15-minute limit")
576
- continue
577
-
578
- try:
579
- with open(file_path, 'r', encoding='utf-8') as f:
580
- script_content = f.read().strip()
581
-
582
- # Remove empty lines and lines with only whitespace
583
- script_content = '\n'.join(line for line in script_content.split('\n') if line.strip())
584
-
585
- if not script_content:
586
- continue
587
-
588
- # Parse the script to determine number of speakers
589
- num_speakers = self._get_num_speakers_from_script(script_content)
590
-
591
- # Add to examples list as [num_speakers, script_content]
592
- self.example_scripts.append([num_speakers, script_content])
593
- print(f"Loaded example: {txt_file} with {num_speakers} speakers")
594
-
595
- except Exception as e:
596
- print(f"Error loading example script {txt_file}: {e}")
597
-
598
- if self.example_scripts:
599
- print(f"Successfully loaded {len(self.example_scripts)} example scripts")
600
- else:
601
- print("No example scripts were loaded")
602
-
603
- def _get_num_speakers_from_script(self, script: str) -> int:
604
- """Determine the number of unique speakers in a script."""
605
- import re
606
- speakers = set()
607
-
608
- lines = script.strip().split('\n')
609
- for line in lines:
610
- # Use regex to find speaker patterns
611
- match = re.match(r'^Speaker\s+(\d+)\s*:', line.strip(), re.IGNORECASE)
612
- if match:
613
- speaker_id = int(match.group(1))
614
- speakers.add(speaker_id)
615
-
616
- # If no speakers found, default to 1
617
- if not speakers:
618
- return 1
619
-
620
- # Return the maximum speaker ID + 1 (assuming 0-based indexing)
621
- # or the count of unique speakers if they're 1-based
622
- max_speaker = max(speakers)
623
- min_speaker = min(speakers)
624
-
625
- if min_speaker == 0:
626
- return max_speaker + 1
627
- else:
628
- # Assume 1-based indexing, return the count
629
- return len(speakers)
630
-
631
 
632
  def create_demo_interface(demo_instance: VibeVoiceDemo):
633
  """Create the Gradio interface with streaming support."""
@@ -1182,23 +569,6 @@ Potential for Deepfakes and Disinformation: High-quality synthetic speech can be
1182
  return interface
1183
 
1184
 
1185
- def convert_to_16_bit_wav(data):
1186
- # Check if data is a tensor and move to cpu
1187
- if torch.is_tensor(data):
1188
- data = data.detach().cpu().numpy()
1189
-
1190
- # Ensure data is numpy array
1191
- data = np.array(data)
1192
-
1193
- # Normalize to range [-1, 1] if it's not already
1194
- if np.max(np.abs(data)) > 1.0:
1195
- data = data / np.max(np.abs(data))
1196
-
1197
- # Scale to 16-bit integer range
1198
- data = (data * 32767).astype(np.int16)
1199
- return data
1200
-
1201
-
1202
  def parse_args():
1203
  parser = argparse.ArgumentParser(description="VibeVoice Gradio Demo")
1204
  parser.add_argument(
 
app.py after the change (new line numbers):

    2      VibeVoice Gradio Demo - High-Quality Dialogue Generation Interface with Streaming Support
    3      """
    4
    5    + import argparse
    6      import torch
    7    + import gradio as gr
    8
    9      from transformers.utils import logging
   10      from transformers import set_seed
   11    + from model import VibeVoiceDemo
   12
   13      logging.set_verbosity_info()
   14      logger = logging.get_logger(__name__)
   15
   16
   17
   18
   19      def create_demo_interface(demo_instance: VibeVoiceDemo):
   20          """Create the Gradio interface with streaming support."""

  569          return interface
  570
  571
  572      def parse_args():
  573          parser = argparse.ArgumentParser(description="VibeVoice Gradio Demo")
  574          parser.add_argument(
model.py ADDED
@@ -0,0 +1,634 @@
1
+ import threading, librosa, torch
2
+ import gradio as gr
3
+ import numpy as np
4
+ import soundfile as sf
5
+
6
+ from typing import Iterator, Optional
7
+ import os, time, traceback
8
+
9
+ from vibevoice.modular.configuration_vibevoice import VibeVoiceConfig
10
+ from vibevoice.modular.modeling_vibevoice_inference import VibeVoiceForConditionalGenerationInference
11
+ from vibevoice.modular.lora_loading import load_lora_assets
12
+ from vibevoice.processor.vibevoice_processor import VibeVoiceProcessor
13
+ from vibevoice.modular.streamer import AudioStreamer
14
+
15
+
16
+
17
+ def convert_to_16_bit_wav(data):
18
+ # Check if data is a tensor and move to cpu
19
+ if torch.is_tensor(data):
20
+ data = data.detach().cpu().numpy()
21
+
22
+ # Ensure data is numpy array
23
+ data = np.array(data)
24
+
25
+ # Normalize to range [-1, 1] if it's not already
26
+ if np.max(np.abs(data)) > 1.0:
27
+ data = data / np.max(np.abs(data))
28
+
29
+ # Scale to 16-bit integer range
30
+ data = (data * 32767).astype(np.int16)
31
+ return data
32
+
33
+ class VibeVoiceDemo:
34
+ def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5, adapter_path: Optional[str] = None):
35
+ """Initialize the VibeVoice demo with model loading."""
36
+ self.model_path = model_path
37
+ self.device = device
38
+ self.inference_steps = inference_steps
39
+ self.adapter_path = adapter_path
40
+ self.loaded_adapter_root: Optional[str] = None
41
+ self.is_generating = False # Track generation state
42
+ self.stop_generation = False # Flag to stop generation
43
+ self.current_streamer = None # Track current audio streamer
44
+ self.load_model()
45
+ self.setup_voice_presets()
46
+ self.load_example_scripts() # Load example scripts
47
+
48
+ def load_model(self):
49
+ """Load the VibeVoice model and processor."""
50
+ print(f"Loading processor & model from {self.model_path}")
51
+ self.loaded_adapter_root = None
52
+ # Normalize potential 'mpx'
53
+ if self.device.lower() == "mpx":
54
+ print("Note: device 'mpx' detected, treating it as 'mps'.")
55
+ self.device = "mps"
56
+ if self.device == "mps" and not torch.backends.mps.is_available():
57
+ print("Warning: MPS not available. Falling back to CPU.")
58
+ self.device = "cpu"
59
+ print(f"Using device: {self.device}")
60
+ # Load processor
61
+ self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
62
+ # Decide dtype & attention
63
+ if self.device == "mps":
64
+ load_dtype = torch.float32
65
+ attn_impl_primary = "sdpa"
66
+ elif self.device == "cuda":
67
+ load_dtype = torch.bfloat16
68
+ attn_impl_primary = "flash_attention_2"
69
+ else:
70
+ load_dtype = torch.float32
71
+ attn_impl_primary = "sdpa"
72
+ print(f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")
73
+ # Load model
74
+ try:
75
+ if self.device == "mps":
76
+ self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
77
+ self.model_path,
78
+ torch_dtype=load_dtype,
79
+ attn_implementation=attn_impl_primary,
80
+ device_map=None,
81
+ )
82
+ self.model.to("mps")
83
+ elif self.device == "cuda":
84
+ self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
85
+ self.model_path,
86
+ torch_dtype=load_dtype,
87
+ device_map="cuda",
88
+ attn_implementation=attn_impl_primary,
89
+ )
90
+ else:
91
+ self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
92
+ self.model_path,
93
+ torch_dtype=load_dtype,
94
+ device_map="cpu",
95
+ attn_implementation=attn_impl_primary,
96
+ )
97
+ except Exception as e:
98
+ if attn_impl_primary == 'flash_attention_2':
99
+ print(f"[ERROR] : {type(e).__name__}: {e}")
100
+ print(traceback.format_exc())
101
+ fallback_attn = "sdpa"
102
+ print(f"Falling back to attention implementation: {fallback_attn}")
103
+ self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
104
+ self.model_path,
105
+ torch_dtype=load_dtype,
106
+ device_map=(self.device if self.device in ("cuda", "cpu") else None),
107
+ attn_implementation=fallback_attn,
108
+ )
109
+ if self.device == "mps":
110
+ self.model.to("mps")
111
+ else:
112
+ raise e
113
+ if self.adapter_path:
114
+ print(f"Loading fine-tuned assets from {self.adapter_path}")
115
+ report = load_lora_assets(self.model, self.adapter_path)
116
+ loaded_components = [
117
+ name for name, loaded in (
118
+ ("language LoRA", report.language_model),
119
+ ("diffusion head LoRA", report.diffusion_head_lora),
120
+ ("diffusion head weights", report.diffusion_head_full),
121
+ ("acoustic connector", report.acoustic_connector),
122
+ ("semantic connector", report.semantic_connector),
123
+ )
124
+ if loaded
125
+ ]
126
+ if loaded_components:
127
+ print(f"Loaded components: {', '.join(loaded_components)}")
128
+ else:
129
+ print("Warning: no adapter components were loaded; check the checkpoint path.")
130
+ if report.adapter_root is not None:
131
+ self.loaded_adapter_root = str(report.adapter_root)
132
+ print(f"Adapter assets resolved to: {self.loaded_adapter_root}")
133
+ else:
134
+ self.loaded_adapter_root = self.adapter_path
135
+
136
+ self.model.eval()
137
+
138
+ # Use SDE solver by default
139
+ self.model.model.noise_scheduler = self.model.model.noise_scheduler.from_config(
140
+ self.model.model.noise_scheduler.config,
141
+ algorithm_type='sde-dpmsolver++',
142
+ beta_schedule='squaredcos_cap_v2'
143
+ )
144
+ self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
145
+
146
+ if hasattr(self.model.model, 'language_model'):
147
+ print(f"Language model attention: {self.model.model.language_model.config._attn_implementation}")
148
+
149
+ def setup_voice_presets(self):
150
+ """Setup voice presets by scanning the voices directory."""
151
+ voices_dir = os.path.join(os.path.dirname(__file__), "voices")
152
+
153
+ # Check if voices directory exists
154
+ if not os.path.exists(voices_dir):
155
+ print(f"Warning: Voices directory not found at {voices_dir}")
156
+ self.voice_presets = {}
157
+ self.available_voices = {}
158
+ return
159
+
160
+ # Scan for all WAV files in the voices directory
161
+ self.voice_presets = {}
162
+
163
+ # Get all .wav files in the voices directory
164
+ wav_files = [f for f in os.listdir(voices_dir)
165
+ if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')) and os.path.isfile(os.path.join(voices_dir, f))]
166
+
167
+ # Create dictionary with filename (without extension) as key
168
+ for wav_file in wav_files:
169
+ # Remove .wav extension to get the name
170
+ name = os.path.splitext(wav_file)[0]
171
+ # Create full path
172
+ full_path = os.path.join(voices_dir, wav_file)
173
+ self.voice_presets[name] = full_path
174
+
175
+ # Sort the voice presets alphabetically by name for better UI
176
+ self.voice_presets = dict(sorted(self.voice_presets.items()))
177
+
178
+ # Filter out voices that don't exist (this is now redundant but kept for safety)
179
+ self.available_voices = {
180
+ name: path for name, path in self.voice_presets.items()
181
+ if os.path.exists(path)
182
+ }
183
+
184
+ if not self.available_voices:
185
+ raise gr.Error("No voice presets found. Please add .wav files to the demo/voices directory.")
186
+
187
+ print(f"Found {len(self.available_voices)} voice files in {voices_dir}")
188
+ print(f"Available voices: {', '.join(self.available_voices.keys())}")
189
+
190
+ def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
191
+ """Read and preprocess audio file."""
192
+ try:
193
+ wav, sr = sf.read(audio_path)
194
+ if len(wav.shape) > 1:
195
+ wav = np.mean(wav, axis=1)
196
+ if sr != target_sr:
197
+ wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
198
+ return wav
199
+ except Exception as e:
200
+ print(f"Error reading audio {audio_path}: {e}")
201
+ return np.array([])
202
+
203
+ def generate_podcast_streaming(self,
204
+ num_speakers: int,
205
+ script: str,
206
+ speaker_1: str = None,
207
+ speaker_2: str = None,
208
+ speaker_3: str = None,
209
+ speaker_4: str = None,
210
+ cfg_scale: float = 1.3,
211
+ disable_voice_cloning: bool = False) -> Iterator[tuple]:
212
+ try:
213
+
214
+ # Reset stop flag and set generating state
215
+ self.stop_generation = False
216
+ self.is_generating = True
217
+
218
+ # Validate inputs
219
+ if not script.strip():
220
+ self.is_generating = False
221
+ raise gr.Error("Error: Please provide a script.")
222
+
223
+ # Defend against common mistake
224
+ script = script.replace("’", "'")
225
+
226
+ if num_speakers < 1 or num_speakers > 4:
227
+ self.is_generating = False
228
+ raise gr.Error("Error: Number of speakers must be between 1 and 4.")
229
+
230
+ # Collect selected speakers
231
+ selected_speakers = [speaker_1, speaker_2, speaker_3, speaker_4][:num_speakers]
232
+
233
+ # Validate speaker selections
234
+ for i, speaker in enumerate(selected_speakers):
235
+ if not speaker or speaker not in self.available_voices:
236
+ self.is_generating = False
237
+ raise gr.Error(f"Error: Please select a valid speaker for Speaker {i+1}.")
238
+
239
+ voice_cloning_enabled = not disable_voice_cloning
240
+
241
+ # Build initial log
242
+ log = f"πŸŽ™οΈ Generating podcast with {num_speakers} speakers\n"
243
+ log += f"πŸ“Š Parameters: CFG Scale={cfg_scale}, Inference Steps={self.inference_steps}\n"
244
+ log += f"🎭 Speakers: {', '.join(selected_speakers)}\n"
245
+ log += f"πŸ”Š Voice cloning: {'Enabled' if voice_cloning_enabled else 'Disabled'}\n"
246
+ if self.loaded_adapter_root:
247
+ log += f"🧩 LoRA: {self.loaded_adapter_root}\n"
248
+
249
+ # Check for stop signal
250
+ if self.stop_generation:
251
+ self.is_generating = False
252
+ yield None, "πŸ›‘ Generation stopped by user", gr.update(visible=False)
253
+ return
254
+
255
+ # Load voice samples when voice cloning is enabled
256
+ voice_samples = None
257
+ if voice_cloning_enabled:
258
+ voice_samples = []
259
+ for speaker_name in selected_speakers:
260
+ audio_path = self.available_voices[speaker_name]
261
+ audio_data = self.read_audio(audio_path)
262
+ if len(audio_data) == 0:
263
+ self.is_generating = False
264
+ raise gr.Error(f"Error: Failed to load audio for {speaker_name}")
265
+ voice_samples.append(audio_data)
266
+
267
+ # log += f"βœ… Loaded {len(voice_samples)} voice samples\n"
268
+
269
+ # Check for stop signal
270
+ if self.stop_generation:
271
+ self.is_generating = False
272
+ yield None, "πŸ›‘ Generation stopped by user", gr.update(visible=False)
273
+ return
274
+
275
+ # Parse script to assign speaker ID's
276
+ lines = script.strip().split('\n')
277
+ formatted_script_lines = []
278
+
279
+ for line in lines:
280
+ line = line.strip()
281
+ if not line:
282
+ continue
283
+
284
+ # Check if line already has speaker format
285
+ if line.startswith('Speaker ') and ':' in line:
286
+ formatted_script_lines.append(line)
287
+ else:
288
+ # Auto-assign to speakers in rotation
289
+ speaker_id = len(formatted_script_lines) % num_speakers
290
+ formatted_script_lines.append(f"Speaker {speaker_id}: {line}")
291
+
292
+ formatted_script = '\n'.join(formatted_script_lines)
293
+ log += f"πŸ“ Formatted script with {len(formatted_script_lines)} turns\n\n"
294
+ log += "πŸ”„ Processing with VibeVoice (streaming mode)...\n"
295
+
296
+ # Check for stop signal before processing
297
+ if self.stop_generation:
298
+ self.is_generating = False
299
+ yield None, "πŸ›‘ Generation stopped by user", gr.update(visible=False)
300
+ return
301
+
302
+ start_time = time.time()
303
+
304
+ processor_kwargs = {
305
+ "text": [formatted_script],
306
+ "padding": True,
307
+ "return_tensors": "pt",
308
+ "return_attention_mask": True,
309
+ }
310
+ processor_kwargs["voice_samples"] = [voice_samples] if voice_samples is not None else None
311
+
312
+ inputs = self.processor(**processor_kwargs)
313
+ # Move tensors to device
314
+ target_device = self.device if self.device in ("cuda", "mps") else "cpu"
315
+ for k, v in inputs.items():
316
+ if torch.is_tensor(v):
317
+ inputs[k] = v.to(target_device)
318
+
319
+ # Create audio streamer
320
+ audio_streamer = AudioStreamer(
321
+ batch_size=1,
322
+ stop_signal=None,
323
+ timeout=None
324
+ )
325
+
326
+ # Store current streamer for potential stopping
327
+ self.current_streamer = audio_streamer
328
+
329
+ # Start generation in a separate thread
330
+ generation_thread = threading.Thread(
331
+ target=self._generate_with_streamer,
332
+ args=(inputs, cfg_scale, audio_streamer, voice_cloning_enabled)
333
+ )
334
+ generation_thread.start()
335
+
336
+ # Wait for generation to actually start producing audio
337
+ time.sleep(1) # Reduced from 3 to 1 second
338
+
339
+ # Check for stop signal after thread start
340
+ if self.stop_generation:
341
+ audio_streamer.end()
342
+ generation_thread.join(timeout=5.0) # Wait up to 5 seconds for thread to finish
343
+ self.is_generating = False
344
+ yield None, "πŸ›‘ Generation stopped by user", gr.update(visible=False)
345
+ return
346
+
347
+ # Collect audio chunks as they arrive
348
+ sample_rate = 24000
349
+ all_audio_chunks = [] # For final statistics
350
+ pending_chunks = [] # Buffer for accumulating small chunks
351
+ chunk_count = 0
352
+ last_yield_time = time.time()
353
+ min_yield_interval = 15 # Yield every 15 seconds
354
+ min_chunk_size = sample_rate * 30 # At least 2 seconds of audio
355
+
356
+ # Get the stream for the first (and only) sample
357
+ audio_stream = audio_streamer.get_stream(0)
358
+
359
+ has_yielded_audio = False
360
+ has_received_chunks = False # Track if we received any chunks at all
361
+
362
+ for audio_chunk in audio_stream:
363
+ # Check for stop signal in the streaming loop
364
+ if self.stop_generation:
365
+ audio_streamer.end()
366
+ break
367
+
368
+ chunk_count += 1
369
+ has_received_chunks = True # Mark that we received at least one chunk
370
+
371
+ # Convert tensor to numpy
372
+ if torch.is_tensor(audio_chunk):
373
+ # Convert bfloat16 to float32 first, then to numpy
374
+ if audio_chunk.dtype == torch.bfloat16:
375
+ audio_chunk = audio_chunk.float()
376
+ audio_np = audio_chunk.cpu().numpy().astype(np.float32)
377
+ else:
378
+ audio_np = np.array(audio_chunk, dtype=np.float32)
379
+
380
+ # Ensure audio is 1D and properly normalized
381
+ if len(audio_np.shape) > 1:
382
+ audio_np = audio_np.squeeze()
383
+
384
+ # Convert to 16-bit for Gradio
385
+ audio_16bit = convert_to_16_bit_wav(audio_np)
386
+
387
+ # Store for final statistics
388
+ all_audio_chunks.append(audio_16bit)
389
+
390
+ # Add to pending chunks buffer
391
+ pending_chunks.append(audio_16bit)
392
+
393
+ # Calculate pending audio size
394
+ pending_audio_size = sum(len(chunk) for chunk in pending_chunks)
395
+ current_time = time.time()
396
+ time_since_last_yield = current_time - last_yield_time
397
+
398
+ # Decide whether to yield
399
+ should_yield = False
400
+ if not has_yielded_audio and pending_audio_size >= min_chunk_size:
401
+ # First yield: wait for minimum chunk size
402
+ should_yield = True
403
+ has_yielded_audio = True
404
+ elif has_yielded_audio and (pending_audio_size >= min_chunk_size or time_since_last_yield >= min_yield_interval):
405
+ # Subsequent yields: either enough audio or enough time has passed
406
+ should_yield = True
407
+
408
+ if should_yield and pending_chunks:
409
+ # Concatenate and yield only the new audio chunks
410
+ new_audio = np.concatenate(pending_chunks)
411
+ new_duration = len(new_audio) / sample_rate
412
+ total_duration = sum(len(chunk) for chunk in all_audio_chunks) / sample_rate
413
+
414
+ log_update = log + f"🎡 Streaming: {total_duration:.1f}s generated (chunk {chunk_count})\n"
415
+
416
+ # Yield streaming audio chunk and keep complete_audio as None during streaming
417
+ yield (sample_rate, new_audio), None, log_update, gr.update(visible=True)
418
+
419
+ # Clear pending chunks after yielding
420
+ pending_chunks = []
421
+ last_yield_time = current_time
422
+
423
+ # Yield any remaining chunks
424
+ if pending_chunks:
425
+ final_new_audio = np.concatenate(pending_chunks)
426
+ total_duration = sum(len(chunk) for chunk in all_audio_chunks) / sample_rate
427
+ log_update = log + f"🎡 Streaming final chunk: {total_duration:.1f}s total\n"
428
+ yield (sample_rate, final_new_audio), None, log_update, gr.update(visible=True)
429
+ has_yielded_audio = True # Mark that we yielded audio
430
+
431
+ # Wait for generation to complete (with timeout to prevent hanging)
432
+ generation_thread.join(timeout=5.0) # Increased timeout to 5 seconds
433
+
434
+ # If thread is still alive after timeout, force end
435
+ if generation_thread.is_alive():
436
+ print("Warning: Generation thread did not complete within timeout")
437
+ audio_streamer.end()
438
+ generation_thread.join(timeout=5.0)
439
+
440
+ # Clean up
441
+ self.current_streamer = None
442
+ self.is_generating = False
443
+
444
+ generation_time = time.time() - start_time
445
+
446
+ # Check if stopped by user
447
+ if self.stop_generation:
448
+ yield None, None, "πŸ›‘ Generation stopped by user", gr.update(visible=False)
449
+ return
450
+
451
+ # Debug logging
452
+ # print(f"Debug: has_received_chunks={has_received_chunks}, chunk_count={chunk_count}, all_audio_chunks length={len(all_audio_chunks)}")
453
+
454
+ # Check if we received any chunks but didn't yield audio
455
+ if has_received_chunks and not has_yielded_audio and all_audio_chunks:
456
+ # We have chunks but didn't meet the yield criteria, yield them now
457
+ complete_audio = np.concatenate(all_audio_chunks)
458
+ final_duration = len(complete_audio) / sample_rate
459
+
460
+ final_log = log + f"⏱️ Generation completed in {generation_time:.2f} seconds\n"
461
+ final_log += f"🎡 Final audio duration: {final_duration:.2f} seconds\n"
462
+ final_log += f"πŸ“Š Total chunks: {chunk_count}\n"
463
+ final_log += "✨ Generation successful! Complete audio is ready.\n"
464
+ final_log += "πŸ’‘ Not satisfied? You can regenerate or adjust the CFG scale for different results."
465
+
466
+ # Yield the complete audio
467
+ yield None, (sample_rate, complete_audio), final_log, gr.update(visible=False)
468
+ return
469
+
470
+ if not has_received_chunks:
471
+ error_log = log + f"\n❌ Error: No audio chunks were received from the model. Generation time: {generation_time:.2f}s"
472
+ yield None, None, error_log, gr.update(visible=False)
473
+ return
474
+
475
+ if not has_yielded_audio:
476
+ error_log = log + f"\n❌ Error: Audio was generated but not streamed. Chunk count: {chunk_count}"
477
+ yield None, None, error_log, gr.update(visible=False)
478
+ return
479
+
480
+ # Prepare the complete audio
481
+ if all_audio_chunks:
482
+ complete_audio = np.concatenate(all_audio_chunks)
483
+ final_duration = len(complete_audio) / sample_rate
484
+
485
+ final_log = log + f"⏱️ Generation completed in {generation_time:.2f} seconds\n"
486
+ final_log += f"🎡 Final audio duration: {final_duration:.2f} seconds\n"
487
+ final_log += f"πŸ“Š Total chunks: {chunk_count}\n"
488
+ final_log += "✨ Generation successful! Complete audio is ready in the 'Complete Audio' tab.\n"
489
+ final_log += "πŸ’‘ Not satisfied? You can regenerate or adjust the CFG scale for different results."
490
+
491
+ # Final yield: Clear streaming audio and provide complete audio
492
+ yield None, (sample_rate, complete_audio), final_log, gr.update(visible=False)
493
+ else:
494
+ final_log = log + "❌ No audio was generated."
495
+ yield None, None, final_log, gr.update(visible=False)
496
+
497
+ except gr.Error as e:
498
+ # Handle Gradio-specific errors (like input validation)
499
+ self.is_generating = False
500
+ self.current_streamer = None
501
+ error_msg = f"❌ Input Error: {str(e)}"
502
+ print(error_msg)
503
+ yield None, None, error_msg, gr.update(visible=False)
504
+
505
+ except Exception as e:
506
+ self.is_generating = False
507
+ self.current_streamer = None
508
+ error_msg = f"❌ An unexpected error occurred: {str(e)}"
509
+ print(error_msg)
510
+ import traceback
511
+ traceback.print_exc()
512
+ yield None, None, error_msg, gr.update(visible=False)
513
+
514
+ def _generate_with_streamer(self, inputs, cfg_scale, audio_streamer, voice_cloning_enabled: bool):
515
+ """Helper method to run generation with streamer in a separate thread."""
516
+ try:
517
+ # Check for stop signal before starting generation
518
+ if self.stop_generation:
519
+ audio_streamer.end()
520
+ return
521
+
522
+ # Define a stop check function that can be called from generate
523
+ def check_stop_generation():
524
+ return self.stop_generation
525
+
526
+ outputs = self.model.generate(
527
+ **inputs,
528
+ max_new_tokens=None,
529
+ cfg_scale=cfg_scale,
530
+ tokenizer=self.processor.tokenizer,
531
+ generation_config={
532
+ 'do_sample': False,
533
+ },
534
+ audio_streamer=audio_streamer,
535
+ stop_check_fn=check_stop_generation, # Pass the stop check function
536
+ verbose=False, # Disable verbose in streaming mode
537
+ refresh_negative=True,
538
+ is_prefill=voice_cloning_enabled,
539
+ )
540
+
541
+ except Exception as e:
542
+ print(f"Error in generation thread: {e}")
543
+ traceback.print_exc()
544
+ # Make sure to end the stream on error
545
+ audio_streamer.end()
546
+
547
+ def stop_audio_generation(self):
548
+ """Stop the current audio generation process."""
549
+ self.stop_generation = True
550
+ if self.current_streamer is not None:
551
+ try:
552
+ self.current_streamer.end()
553
+ except Exception as e:
554
+ print(f"Error stopping streamer: {e}")
555
+ print("πŸ›‘ Audio generation stop requested")
556
+
557
+ def load_example_scripts(self):
558
+ """Load example scripts from the text_examples directory."""
559
+ examples_dir = os.path.join(os.path.dirname(__file__), "text_examples")
560
+ self.example_scripts = []
561
+
562
+ # Check if text_examples directory exists
563
+ if not os.path.exists(examples_dir):
564
+ print(f"Warning: text_examples directory not found at {examples_dir}")
565
+ return
566
+
567
+ # Get all .txt files in the text_examples directory
568
+ txt_files = sorted([f for f in os.listdir(examples_dir)
569
+ if f.lower().endswith('.txt') and os.path.isfile(os.path.join(examples_dir, f))])
570
+
571
+ for txt_file in txt_files:
572
+ file_path = os.path.join(examples_dir, txt_file)
573
+
574
+ import re
575
+ # Check if filename contains a time pattern like "45min", "90min", etc.
576
+ time_pattern = re.search(r'(\d+)min', txt_file.lower())
577
+ if time_pattern:
578
+ minutes = int(time_pattern.group(1))
579
+ if minutes > 15:
580
+ print(f"Skipping {txt_file}: duration {minutes} minutes exceeds 15-minute limit")
581
+ continue
582
+
583
+ try:
584
+ with open(file_path, 'r', encoding='utf-8') as f:
585
+ script_content = f.read().strip()
586
+
587
+ # Remove empty lines and lines with only whitespace
588
+ script_content = '\n'.join(line for line in script_content.split('\n') if line.strip())
589
+
590
+ if not script_content:
591
+ continue
592
+
593
+ # Parse the script to determine number of speakers
594
+ num_speakers = self._get_num_speakers_from_script(script_content)
595
+
596
+ # Add to examples list as [num_speakers, script_content]
597
+ self.example_scripts.append([num_speakers, script_content])
598
+ print(f"Loaded example: {txt_file} with {num_speakers} speakers")
599
+
600
+ except Exception as e:
601
+ print(f"Error loading example script {txt_file}: {e}")
602
+
603
+ if self.example_scripts:
604
+ print(f"Successfully loaded {len(self.example_scripts)} example scripts")
605
+ else:
606
+ print("No example scripts were loaded")
607
+
608
+ def _get_num_speakers_from_script(self, script: str) -> int:
609
+ """Determine the number of unique speakers in a script."""
610
+ import re
611
+ speakers = set()
612
+
613
+ lines = script.strip().split('\n')
614
+ for line in lines:
615
+ # Use regex to find speaker patterns
616
+ match = re.match(r'^Speaker\s+(\d+)\s*:', line.strip(), re.IGNORECASE)
617
+ if match:
618
+ speaker_id = int(match.group(1))
619
+ speakers.add(speaker_id)
620
+
621
+ # If no speakers found, default to 1
622
+ if not speakers:
623
+ return 1
624
+
625
+ # Return the maximum speaker ID + 1 (assuming 0-based indexing)
626
+ # or the count of unique speakers if they're 1-based
627
+ max_speaker = max(speakers)
628
+ min_speaker = min(speakers)
629
+
630
+ if min_speaker == 0:
631
+ return max_speaker + 1
632
+ else:
633
+ # Assume 1-based indexing, return the count
634
+ return len(speakers)