# speechtotextv2 / app.py
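"""Speech-to-text Gradio app built on NVIDIA's Parakeet ASR model via NeMo.

Picks the best available device (CUDA, MPS, or CPU), splits long recordings
into chunks, and transcribes each chunk with a per-call timeout.

Dependency note (inferred from the imports below, not pinned by this file):
gradio, torch, pydub, and nemo_toolkit[asr]; pydub also needs ffmpeg to
decode non-WAV formats.
"""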
import gradio as gr
import torch
import nemo.collections.asr as nemo_asr
from pydub import AudioSegment
import os
import logging
import threading
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TimeoutException(Exception):
"""Wyj膮tek dla timeoutu transkrypcji."""
pass
class TranscriptionService:
"""Klasa do zarz膮dzania modelami ASR na r贸偶nych urz膮dzeniach."""
def __init__(self):
        # No eager loading here: models are loaded lazily, per device, on first use.
self.models = {
'mps': None,
'cuda': None,
'cpu': None
}
self.model_name = "nvidia/parakeet-tdt-0.6b-v3"
        self.timeout_seconds = 300  # 5-minute timeout
        self.chunk_length_minutes = 5  # split files longer than 5 minutes
def _get_optimal_device(self, audio_length_minutes: float) -> str:
"""
Wybiera optymalne urz膮dzenie na podstawie d艂ugo艣ci audio i dost臋pno艣ci sprz臋tu.
"""
if torch.cuda.is_available():
logger.info("U偶ywam CUDA (GPU) - najlepsza wydajno艣膰")
return "cuda"
if torch.backends.mps.is_available() and audio_length_minutes <= 8:
logger.info(f"Plik kr贸tki ({audio_length_minutes:.2f} min) - u偶ywam MPS")
return "mps"
        if torch.backends.mps.is_available() and audio_length_minutes > 8:
            logger.info(f"Long file ({audio_length_minutes:.2f} min) - using CPU instead of MPS")
        else:
            logger.info("No GPU/MPS available - using CPU")
return "cpu"
def _load_model(self, device: str) -> nemo_asr.models.ASRModel:
"""
艁aduje model na okre艣lonym urz膮dzeniu (z cache'owaniem).
"""
if self.models[device] is None:
logger.info(f"艁adowanie modelu na {device.upper()}...")
try:
model = nemo_asr.models.ASRModel.from_pretrained(
model_name=self.model_name
)
self.models[device] = model.to(device)
logger.info("Model za艂adowany pomy艣lnie")
except Exception as e:
logger.error(f"B艂膮d 艂adowania modelu na {device}: {e}")
raise
return self.models[device]
def _split_audio(self, audio_file_path: str, chunk_length_ms: int) -> list:
"""
Dzieli d艂ugi plik audio na mniejsze fragmenty.
"""
audio = AudioSegment.from_file(audio_file_path)
chunks = []
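        # Slicing an AudioSegment with a step is a pydub feature: it yields
        # consecutive chunks of chunk_length_ms milliseconds each.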
for i, chunk in enumerate(audio[::chunk_length_ms]):
chunk_path = f"/tmp/temp_chunk_{i}.wav"
chunk.export(chunk_path, format="wav")
chunks.append(chunk_path)
return chunks
def _transcribe_with_timeout(self, audio_file_path: str, device: str) -> str:
"""
Wykonuje transkrypcj臋 z timeoutem.
"""
# 艁adowanie modelu przeniesione tutaj
model = self._load_model(device)
result = {"text": None, "error": None}
def transcribe_worker():
try:
transcriptions = model.transcribe([audio_file_path])
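                # NeMo's transcribe() returns a list of results; recent NeMo
                # releases return Hypothesis objects that expose a .text field.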
if transcriptions and len(transcriptions) > 0:
result["text"] = transcriptions[0].text
else:
result["error"] = "Model nie zwr贸ci艂 偶adnej transkrypcji."
except Exception as e:
result["error"] = f"B艂膮d transkrypcji: {str(e)}"
thread = threading.Thread(target=transcribe_worker)
thread.start()
thread.join(timeout=self.timeout_seconds)
        if thread.is_alive():
            # Python threads cannot be forcibly killed, so the worker may keep
            # running in the background until the model call returns.
            raise TimeoutException(f"Transcription exceeded the {self.timeout_seconds}-second limit")
if result["error"]:
raise Exception(result["error"])
return result["text"]
def transcribe(self, audio_file_path: str, progress=None) -> str:
"""
G艂贸wna funkcja transkrypcji.
"""
if not audio_file_path or not os.path.exists(audio_file_path):
return "B艂膮d: Nie wybrano pliku audio lub plik nie istnieje."
temp_files = []
try:
logger.info(f"Analizuj臋 plik: {os.path.basename(audio_file_path)}")
audio = AudioSegment.from_file(audio_file_path)
length_minutes = len(audio) / (1000 * 60)
logger.info(f"D艂ugo艣膰 pliku: {length_minutes:.2f} minut")
device = self._get_optimal_device(length_minutes)
if length_minutes > self.chunk_length_minutes:
if progress:
                    progress(0.1, desc="Splitting file into chunks...")
                logger.info(f"Splitting file into {self.chunk_length_minutes}-minute chunks")
chunk_length_ms = self.chunk_length_minutes * 60 * 1000
chunks = self._split_audio(audio_file_path, chunk_length_ms)
temp_files.extend(chunks)
logger.info(f"Transkrypcja {len(chunks)} fragment贸w...")
all_transcriptions = []
for i, chunk_path in enumerate(chunks):
if progress:
progress_value = 0.1 + (0.8 * (i + 1) / len(chunks))
                        progress(progress_value, desc=f"Transcribing chunk {i+1}/{len(chunks)}...")
                    logger.info(f"Transcribing chunk {i+1}/{len(chunks)}...")
chunk_text = self._transcribe_with_timeout(chunk_path, device)
all_transcriptions.append(chunk_text)
logger.info(f"Fragment {i+1} przetworzony")
result_text = " ".join(all_transcriptions)
else:
if progress:
                    progress(0.5, desc="Starting transcription...")
                logger.info("Starting transcription...")
result_text = self._transcribe_with_timeout(audio_file_path, device)
logger.info("Transkrypcja zako艅czona pomy艣lnie")
return result_text
except FileNotFoundError:
error_msg = f"B艂膮d: Plik {audio_file_path} nie zosta艂 znaleziony."
logger.error(error_msg)
return error_msg
except TimeoutException as e:
error_msg = f"Timeout: {str(e)}"
logger.error(error_msg)
return error_msg
except Exception as e:
error_msg = f"Wyst膮pi艂 b艂膮d podczas transkrypcji: {str(e)}"
logger.error(error_msg)
return error_msg
finally:
            # Clean up temporary chunk files.
            for temp_file in temp_files:
                try:
                    os.remove(temp_file)
                except OSError:
                    pass
# Global service instance
transcription_service = TranscriptionService()
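# Quick local sanity check (hypothetical example; assumes a WAV file exists
# at ./sample.wav):
#   print(transcription_service.transcribe("./sample.wav"))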
def transcribe_audio_wrapper(audio_file_path: str, progress=gr.Progress()) -> str:
"""Wrapper dla Gradio - izoluje logik臋 od interfejsu."""
return transcription_service.transcribe(audio_file_path, progress)
def create_interface() -> gr.Interface:
"""Tworzy i konfiguruje interfejs Gradio."""
return gr.Interface(
fn=transcribe_audio_wrapper,
inputs=gr.Audio(
type="filepath",
label="Wybierz plik audio",
format="wav" # Opcjonalnie: wymu艣 konkretny format
),
outputs=gr.Textbox(
lines=10,
label="Wynik transkrypcji",
placeholder="Tutaj pojawi si臋 transkrypcja..."
),
title="馃帳 Transkrypcja mowy na tekst",
description="""
Wybierz plik audio, a model NVIDIA Parakeet wykona transkrypcj臋.
**Obs艂ugiwane formaty:** WAV, MP3, FLAC, M4A i inne
**Optymalizacja urz膮dzenia:** Automatyczny wyb贸r GPU/CPU
""",
examples=None,
cache_examples=False,
flagging_options=None,
allow_flagging="never"
)
if __name__ == "__main__":
logger.info("=== Informacje o systemie ===")
logger.info(f"CUDA dost臋pne: {torch.cuda.is_available()}")
logger.info(f"MPS dost臋pne: {torch.backends.mps.is_available()}")
if torch.cuda.is_available():
logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
interface = create_interface()
interface.launch(
        server_name="0.0.0.0",  # listen on all interfaces (changed from 127.0.0.1)
server_port=7860,
share=False,
debug=False,
show_error=True
)
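# To run locally (assuming the dependencies noted at the top are installed):
#   python app.py
# then open http://localhost:7860 in a browser.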