playmak3r commited on
Commit
1e4e16f
·
1 Parent(s): b1c3b85

refactor: use class attributes for voices and examples directories in VibeVoiceDemo

Browse files
Files changed (1) hide show
  1. model.py +23 -24
model.py CHANGED
@@ -31,6 +31,9 @@ def convert_to_16_bit_wav(data):
31
  return data
32
 
33
  class VibeVoiceDemo:
 
 
 
34
  def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5, adapter_path: Optional[str] = None):
35
  """Initialize the VibeVoice demo with model loading."""
36
  self.model_path = model_path
@@ -148,43 +151,40 @@ class VibeVoiceDemo:
148
 
149
  def setup_voice_presets(self):
150
  """Setup voice presets by scanning the voices directory."""
151
- voices_dir = os.path.join(os.path.dirname(__file__), "voices")
152
-
153
  # Check if voices directory exists
154
- if not os.path.exists(voices_dir):
155
- print(f"Warning: Voices directory not found at {voices_dir}")
156
  self.voice_presets = {}
157
  self.available_voices = {}
158
  return
159
-
160
  # Scan for all WAV files in the voices directory
161
  self.voice_presets = {}
162
-
163
  # Get all .wav files in the voices directory
164
- wav_files = [f for f in os.listdir(voices_dir)
165
- if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')) and os.path.isfile(os.path.join(voices_dir, f))]
166
-
167
  # Create dictionary with filename (without extension) as key
168
  for wav_file in wav_files:
169
  # Remove .wav extension to get the name
170
  name = os.path.splitext(wav_file)[0]
171
- # Create full path
172
- full_path = os.path.join(voices_dir, wav_file)
173
  self.voice_presets[name] = full_path
174
-
175
  # Sort the voice presets alphabetically by name for better UI
176
  self.voice_presets = dict(sorted(self.voice_presets.items()))
177
-
178
  # Filter out voices that don't exist (this is now redundant but kept for safety)
179
  self.available_voices = {
180
  name: path for name, path in self.voice_presets.items()
181
  if os.path.exists(path)
182
  }
183
-
184
  if not self.available_voices:
185
  raise gr.Error("No voice presets found. Please add .wav files to the demo/voices directory.")
186
 
187
- print(f"Found {len(self.available_voices)} voice files in {voices_dir}")
188
  print(f"Available voices: {', '.join(self.available_voices.keys())}")
189
 
190
  def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
@@ -537,13 +537,13 @@ class VibeVoiceDemo:
537
  refresh_negative=True,
538
  is_prefill=voice_cloning_enabled,
539
  )
540
-
541
  except Exception as e:
542
  print(f"Error in generation thread: {e}")
543
  traceback.print_exc()
544
  # Make sure to end the stream on error
545
  audio_streamer.end()
546
-
547
  def stop_audio_generation(self):
548
  """Stop the current audio generation process."""
549
  self.stop_generation = True
@@ -553,23 +553,22 @@ class VibeVoiceDemo:
553
  except Exception as e:
554
  print(f"Error stopping streamer: {e}")
555
  print("🛑 Audio generation stop requested")
556
-
557
  def load_example_scripts(self):
558
  """Load example scripts from the text_examples directory."""
559
- examples_dir = os.path.join(os.path.dirname(__file__), "text_examples")
560
  self.example_scripts = []
561
 
562
  # Check if text_examples directory exists
563
- if not os.path.exists(examples_dir):
564
- print(f"Warning: text_examples directory not found at {examples_dir}")
565
  return
566
 
567
  # Get all .txt files in the text_examples directory
568
- txt_files = sorted([f for f in os.listdir(examples_dir)
569
- if f.lower().endswith('.txt') and os.path.isfile(os.path.join(examples_dir, f))])
570
 
571
  for txt_file in txt_files:
572
- file_path = os.path.join(examples_dir, txt_file)
573
 
574
  import re
575
  # Check if filename contains a time pattern like "45min", "90min", etc.
 
31
  return data
32
 
33
  class VibeVoiceDemo:
34
+ voices_dir = os.path.join(os.path.dirname(__file__), "voices")
35
+ examples_dir = os.path.join(os.path.dirname(__file__), "text_examples")
36
+
37
  def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5, adapter_path: Optional[str] = None):
38
  """Initialize the VibeVoice demo with model loading."""
39
  self.model_path = model_path
 
151
 
152
  def setup_voice_presets(self):
153
  """Setup voice presets by scanning the voices directory."""
 
 
154
  # Check if voices directory exists
155
+ if not os.path.exists(self.voices_dir):
156
+ print(f"Warning: Voices directory not found at {self.voices_dir}")
157
  self.voice_presets = {}
158
  self.available_voices = {}
159
  return
160
+
161
  # Scan for all WAV files in the voices directory
162
  self.voice_presets = {}
163
+
164
  # Get all .wav files in the voices directory
165
+ wav_files = [f for f in os.listdir(self.voices_dir)
166
+ if f.lower().endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a', '.aac')) and os.path.isfile(os.path.join(self.voices_dir, f))]
167
+
168
  # Create dictionary with filename (without extension) as key
169
  for wav_file in wav_files:
170
  # Remove .wav extension to get the name
171
  name = os.path.splitext(wav_file)[0]
172
+ full_path = os.path.join(self.voices_dir, wav_file)
 
173
  self.voice_presets[name] = full_path
174
+
175
  # Sort the voice presets alphabetically by name for better UI
176
  self.voice_presets = dict(sorted(self.voice_presets.items()))
177
+
178
  # Filter out voices that don't exist (this is now redundant but kept for safety)
179
  self.available_voices = {
180
  name: path for name, path in self.voice_presets.items()
181
  if os.path.exists(path)
182
  }
183
+
184
  if not self.available_voices:
185
  raise gr.Error("No voice presets found. Please add .wav files to the demo/voices directory.")
186
 
187
+ print(f"Found {len(self.available_voices)} voice files in {self.voices_dir}")
188
  print(f"Available voices: {', '.join(self.available_voices.keys())}")
189
 
190
  def read_audio(self, audio_path: str, target_sr: int = 24000) -> np.ndarray:
 
537
  refresh_negative=True,
538
  is_prefill=voice_cloning_enabled,
539
  )
540
+
541
  except Exception as e:
542
  print(f"Error in generation thread: {e}")
543
  traceback.print_exc()
544
  # Make sure to end the stream on error
545
  audio_streamer.end()
546
+
547
  def stop_audio_generation(self):
548
  """Stop the current audio generation process."""
549
  self.stop_generation = True
 
553
  except Exception as e:
554
  print(f"Error stopping streamer: {e}")
555
  print("🛑 Audio generation stop requested")
556
+
557
  def load_example_scripts(self):
558
  """Load example scripts from the text_examples directory."""
 
559
  self.example_scripts = []
560
 
561
  # Check if text_examples directory exists
562
+ if not os.path.exists(self.examples_dir):
563
+ print(f"Warning: text_examples directory not found at {self.examples_dir}")
564
  return
565
 
566
  # Get all .txt files in the text_examples directory
567
+ txt_files = sorted([f for f in os.listdir(self.examples_dir)
568
+ if f.lower().endswith('.txt') and os.path.isfile(os.path.join(self.examples_dir, f))])
569
 
570
  for txt_file in txt_files:
571
+ file_path = os.path.join(self.examples_dir, txt_file)
572
 
573
  import re
574
  # Check if filename contains a time pattern like "45min", "90min", etc.