import torch
import numpy as np
import asyncio
from typing import List, Optional, Tuple
from snac import SNAC
from .constants import (
CODE_END_TOKEN_ID,
CODE_TOKEN_OFFSET,
SNAC_MODEL_NAME,
SNAC_SAMPLE_RATE,
SNAC_TOKENS_PER_FRAME,
)
class SNACDecoder:
    """
    SNAC Decoder for maya1.

    Unpacks 7-token SNAC frames into the codec's 3 hierarchical code levels
    and decodes them through the pretrained SNAC model to 24kHz mono audio.

    Unpacking logic is the EXACT INVERSE of the training preprocessing
    (`pack_snac_to_7_and_offset()`). CRITICAL: any mismatch in unpacking
    will produce garbage audio.

    Optionally batches concurrent async decode requests for throughput.
    """

    def __init__(
        self,
        device: str = "cuda",
        compile_decoder: bool = False,
        enable_batching: bool = False,
        max_batch_size: int = 64,
        batch_timeout_ms: int = 15,
    ):
        """
        Initialize SNAC decoder.

        Args:
            device: Device for SNAC model (cuda/cpu)
            compile_decoder: Use torch.compile for speedup
            enable_batching: Enable async batching of decode requests
            max_batch_size: Max sequences to batch together
            batch_timeout_ms: Max wait time before processing a batch
        """
        self.device = device
        self.enable_batching = enable_batching
        self.max_batch_size = max_batch_size
        self.batch_timeout_ms = batch_timeout_ms
        print(f"Loading SNAC 24kHz model to {device}...")
        self.snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(device)
        if compile_decoder:
            print("Compiling SNAC decoder with torch.compile...")
            self._compile_model()
        # Batching infrastructure is only created when enabled, so the purely
        # synchronous path carries no asyncio state.
        if enable_batching:
            self.request_queue = asyncio.Queue()
            self.batch_processor_task = None
            self._running = False
            print(f"Batching enabled (max_batch={max_batch_size}, timeout={batch_timeout_ms}ms)")
        print("SNAC decoder initialized")

    def _compile_model(self):
        """Warm up and compile the SNAC quantizer/decoder with torch.compile."""
        # Eager warm-up at a few representative frame counts so weights and
        # buffers are touched before compilation kicks in.
        for frames in [4, 16, 32]:
            dummy_codes = [
                torch.randint(0, 4096, (1, frames), device=self.device),
                torch.randint(0, 4096, (1, frames * 2), device=self.device),
                torch.randint(0, 4096, (1, frames * 4), device=self.device),
            ]
            with torch.inference_mode():
                z_q = self.snac_model.quantizer.from_codes(dummy_codes)
                _ = self.snac_model.decoder(z_q)
        # Wrap the two hot modules; subsequent calls go through compiled code.
        self.snac_model.decoder = torch.compile(
            self.snac_model.decoder,
            mode="max-autotune"
        )
        self.snac_model.quantizer = torch.compile(
            self.snac_model.quantizer,
            mode="reduce-overhead"
        )
        print("SNAC decoder compiled")

    def unpack_snac_from_7(self, vocab_ids: List[int]) -> List[List[int]]:
        """
        Unpack 7-token SNAC frames to 3 hierarchical levels.

        This is the EXACT INVERSE of the training preprocessing function
        `pack_snac_to_7_and_offset()`.

        Frame structure:
            [slot0, slot1, slot2, slot3, slot4, slot5, slot6]
        Unpacking:
            - slot0: L1[i]
            - slot1: L2[2*i]     (even index)
            - slot2: L3[4*i + 0]
            - slot3: L3[4*i + 1]
            - slot4: L2[2*i + 1] (odd index)
            - slot5: L3[4*i + 2]
            - slot6: L3[4*i + 3]

        Args:
            vocab_ids: List of SNAC token IDs (128266-156937).
                A trailing EOS token is stripped; any trailing partial
                frame (len not divisible by 7) is truncated.

        Returns:
            [L1, L2, L3] where, for n complete frames:
                L1: n elements (coarse level)
                L2: 2n elements (medium level)
                L3: 4n elements (fine level)
        """
        # Strip EOS token if present
        if vocab_ids and vocab_ids[-1] == CODE_END_TOKEN_ID:
            vocab_ids = vocab_ids[:-1]
        # Keep only complete frames (divisible by 7)
        frames = len(vocab_ids) // SNAC_TOKENS_PER_FRAME
        vocab_ids = vocab_ids[:frames * SNAC_TOKENS_PER_FRAME]
        if frames == 0:
            return [[], [], []]
        l1, l2, l3 = [], [], []
        for i in range(frames):
            # Extract the 7 slots for this frame
            slots = vocab_ids[i*7:(i+1)*7]
            # Subtract offset (128266) and mod 4096 to recover the original
            # per-level codes; each level uses 4096 codes (0-4095), and the
            # mod folds away the per-slot sub-offset applied during packing.
            l1.append((slots[0] - CODE_TOKEN_OFFSET) % 4096)
            l2.extend([
                (slots[1] - CODE_TOKEN_OFFSET) % 4096,  # Even index
                (slots[4] - CODE_TOKEN_OFFSET) % 4096,  # Odd index
            ])
            l3.extend([
                (slots[2] - CODE_TOKEN_OFFSET) % 4096,
                (slots[3] - CODE_TOKEN_OFFSET) % 4096,
                (slots[5] - CODE_TOKEN_OFFSET) % 4096,
                (slots[6] - CODE_TOKEN_OFFSET) % 4096,
            ])
        return [l1, l2, l3]

    @torch.inference_mode()
    def decode(
        self,
        snac_tokens: List[int],
        trim_warmup: bool = True,
        trim_amount: Optional[int] = None,
        use_sliding_window: bool = False
    ) -> Optional[np.ndarray]:
        """
        Decode SNAC tokens to audio waveform.

        Args:
            snac_tokens: List of SNAC token IDs (7*n tokens)
            trim_warmup: Whether to trim SNAC warmup samples (default: True)
            trim_amount: Number of samples to trim (default: 2048 for first chunk, 0 for others)
                Can be set to a smaller value (e.g., 512) for intermediate chunks
            use_sliding_window: If True, only return middle 2048 samples (for sliding window streaming)

        Returns:
            Audio waveform as numpy array (float32, 24kHz mono), shape (samples,).
            Returns None if not enough tokens for a single frame.
        """
        if len(snac_tokens) < SNAC_TOKENS_PER_FRAME:
            print(f"Not enough SNAC tokens: {len(snac_tokens)} < {SNAC_TOKENS_PER_FRAME}")
            return None
        # Unpack to 3 levels
        levels = self.unpack_snac_from_7(snac_tokens)
        if not levels[0]:  # No frames after unpacking
            return None
        # Convert each level to a [1, len] long tensor on the model's device
        codes = [
            torch.tensor(level, dtype=torch.long, device=self.device).unsqueeze(0)
            for level in levels
        ]
        # Decode through SNAC: codes -> quantized latent -> waveform
        z_q = self.snac_model.quantizer.from_codes(codes)
        audio = self.snac_model.decoder(z_q)
        # SNAC decoder outputs [batch, 1, samples]; take the single sequence
        audio = audio[0, 0].cpu().numpy()
        # Sliding window mode: only keep the middle 2048 samples.
        # This eliminates popping/cracking when using overlapping 28-token windows.
        if use_sliding_window:
            if len(audio) >= 4096:
                audio = audio[2048:4096]  # Keep middle portion only
            else:
                # For shorter audio, keep everything (final chunk)
                pass
        else:
            # Standard mode: trim warm-up samples.
            # Default: 2048 samples for first chunk, 0 for subsequent chunks;
            # customizable via trim_amount.
            if trim_warmup:
                if trim_amount is None:
                    trim_amount = 2048  # Default full trim
                if len(audio) > trim_amount:
                    audio = audio[trim_amount:]
        return audio

    def decode_to_bytes(
        self,
        snac_tokens: List[int],
        trim_warmup: bool = True,
        use_sliding_window: bool = False
    ) -> Optional[bytes]:
        """
        Decode SNAC tokens to audio bytes (int16 PCM).

        Args:
            snac_tokens: List of SNAC token IDs
            trim_warmup: Whether to trim SNAC warmup samples (default: True)
            use_sliding_window: If True, only return middle 2048 samples (for sliding window streaming)

        Returns:
            Audio as bytes (int16 PCM, 24kHz mono), or None if decode fails.
        """
        audio = self.decode(snac_tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window)
        if audio is None:
            return None
        # Clip to [-1, 1] before int16 conversion: without clipping, any
        # sample outside the valid range wraps around in int16 and produces
        # loud clicks in the output audio.
        audio = np.clip(audio, -1.0, 1.0)
        audio_int16 = (audio * 32767).astype(np.int16)
        return audio_int16.tobytes()

    def validate_tokens(self, snac_tokens: List[int]) -> bool:
        """
        Validate SNAC tokens before decoding.

        Args:
            snac_tokens: List of SNAC token IDs

        Returns:
            True if valid, False otherwise. A non-multiple-of-7 length only
            warns (the decoder truncates); out-of-range tokens fail hard.
        """
        # Check minimum length (need at least one complete frame)
        if len(snac_tokens) < SNAC_TOKENS_PER_FRAME:
            print(f"Too few tokens: {len(snac_tokens)}")
            return False
        # Check divisibility by 7 (warning only; decode truncates)
        if len(snac_tokens) % SNAC_TOKENS_PER_FRAME != 0:
            print(f" Warning: Token count {len(snac_tokens)} not divisible by 7")
            print(f" Will truncate to {(len(snac_tokens) // 7) * 7}")
        # Check token range; 156937 = CODE_TOKEN_OFFSET + 7*4096 - 1,
        # the last ID of the 7 per-slot 4096-code sub-ranges.
        for i, token_id in enumerate(snac_tokens):
            if token_id < CODE_TOKEN_OFFSET or token_id > 156937:
                print(f" Invalid token at position {i}: {token_id}")
                print(f" Expected range: [{CODE_TOKEN_OFFSET}, 156937]")
                return False
        return True

    # ========== Async Batching Methods ==========

    @property
    def is_running(self) -> bool:
        """Check if batch processor is running."""
        return self._running if self.enable_batching else False

    async def start_batch_processor(self):
        """Start the background batch processor task (no-op unless batching enabled)."""
        if not self.enable_batching:
            return
        if self._running:
            print("Batch processor already running")
            return
        self._running = True
        self.batch_processor_task = asyncio.create_task(self._batch_processor_loop())
        print("Batch processor started")

    async def stop_batch_processor(self):
        """Stop the background batch processor task."""
        if not self.enable_batching:
            return
        if not self._running:
            return
        self._running = False
        if self.batch_processor_task:
            self.batch_processor_task.cancel()
            try:
                await self.batch_processor_task
            except asyncio.CancelledError:
                pass  # Expected on cancel
        print("Batch processor stopped")

    async def decode_single_async(
        self,
        snac_tokens: List[int],
        trim_warmup: bool = True,
        use_sliding_window: bool = False
    ) -> Optional[bytes]:
        """
        Async decode for batching support.
        Queues the request and waits for batched processing.

        Args:
            snac_tokens: List of SNAC token IDs
            trim_warmup: Whether to trim SNAC warmup samples (default: True)
            use_sliding_window: If True, only return middle 2048 samples (for sliding window streaming)

        Returns:
            Audio bytes or None if decode fails
        """
        if not self.enable_batching:
            # Fallback to synchronous decode
            return self.decode_to_bytes(snac_tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window)
        # Create the result future on the running loop; bare asyncio.Future()
        # falls back to the deprecated get_event_loop() lookup and can bind
        # the wrong loop.
        result_future = asyncio.get_running_loop().create_future()
        # Queue the request alongside its per-request flags
        await self.request_queue.put((snac_tokens, trim_warmup, use_sliding_window, result_future))
        # Wait for the batch processor to fulfil the future
        return await result_future

    async def _batch_processor_loop(self):
        """Background task that processes batched decode requests."""
        while self._running:
            try:
                # Collect batch
                batch = await self._collect_batch()
                if not batch:
                    continue
                # Process batch
                await self._process_batch(batch)
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive; individual request futures carry errors
                print(f"Batch processor error: {e}")
                import traceback
                traceback.print_exc()

    async def _collect_batch(self) -> List[Tuple[List[int], bool, bool, asyncio.Future]]:
        """
        Collect requests into a batch.

        Waits up to batch_timeout_ms for the first request, then keeps
        collecting as long as each additional request arrives within that
        same inter-arrival timeout (so worst-case latency grows with batch
        size) or the batch is full.

        Returns:
            List of (tokens, trim_warmup, use_sliding_window, future) tuples
        """
        batch = []
        timeout_sec = self.batch_timeout_ms / 1000.0
        try:
            # Wait for first request (blocking up to the timeout)
            first_item = await asyncio.wait_for(
                self.request_queue.get(),
                timeout=timeout_sec
            )
            batch.append(first_item)
            # Collect more requests until full or a gap exceeds the timeout
            while len(batch) < self.max_batch_size:
                try:
                    item = await asyncio.wait_for(
                        self.request_queue.get(),
                        timeout=timeout_sec
                    )
                    batch.append(item)
                except asyncio.TimeoutError:
                    break  # Timeout reached, process what we have
        except asyncio.TimeoutError:
            # No requests at all in the timeout period
            pass
        return batch

    @torch.inference_mode()
    async def _process_batch(self, batch: List[Tuple[List[int], bool, bool, asyncio.Future]]):
        """
        Process a batch of decode requests.

        Args:
            batch: List of (tokens, trim_warmup, use_sliding_window, future) tuples
        """
        if not batch:
            return
        # Extract components
        token_sequences = [item[0] for item in batch]
        trim_warmup_flags = [item[1] for item in batch]
        sliding_window_flags = [item[2] for item in batch]
        futures = [item[3] for item in batch]
        # Batched tensor decode requires equal-length sequences
        lengths = [len(tokens) for tokens in token_sequences]
        can_batch_efficiently = len(set(lengths)) == 1
        if can_batch_efficiently and len(batch) > 1:
            # Efficient batching: all same length
            try:
                audio_bytes_list = await self._decode_batch_same_length(
                    token_sequences, trim_warmup_flags, sliding_window_flags
                )
                # Deliver per-request results
                for future, audio_bytes in zip(futures, audio_bytes_list):
                    if not future.done():
                        future.set_result(audio_bytes)
            except Exception as e:
                # Propagate the failure to every waiter in the batch
                for future in futures:
                    if not future.done():
                        future.set_exception(e)
        else:
            # Sequential decode (different lengths or single item)
            for tokens, trim_warmup, use_sliding_window, future in batch:
                try:
                    audio_bytes = self.decode_to_bytes(
                        tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window
                    )
                    if not future.done():
                        future.set_result(audio_bytes)
                except Exception as e:
                    if not future.done():
                        future.set_exception(e)

    async def _decode_batch_same_length(
        self,
        token_sequences: List[List[int]],
        trim_warmup_flags: List[bool],
        sliding_window_flags: List[bool]
    ) -> List[Optional[bytes]]:
        """
        Decode multiple sequences with same length in parallel.

        Args:
            token_sequences: List of token sequences (all same length)
            trim_warmup_flags: List of trim_warmup flags for each sequence
            sliding_window_flags: List of use_sliding_window flags for each sequence

        Returns:
            List of audio bytes, aligned with token_sequences; None for
            sequences that produced no complete frames.
        """
        if not token_sequences:
            return []
        # Unpack all sequences
        unpacked_list = [self.unpack_snac_from_7(tokens) for tokens in token_sequences]
        # Indices of sequences that produced at least one complete frame
        valid_indices = [i for i, levels in enumerate(unpacked_list) if levels[0]]
        if not valid_indices:
            return [None] * len(token_sequences)
        # Build batched codes: [batch, frames], [batch, 2*frames], [batch, 4*frames]
        codes = [
            torch.stack([
                torch.tensor(unpacked_list[i][level_idx], dtype=torch.long, device=self.device)
                for i in valid_indices
            ], dim=0)
            for level_idx in range(3)
        ]
        # Batched decode
        z_q = self.snac_model.quantizer.from_codes(codes)
        audio_batch = self.snac_model.decoder(z_q)  # [batch, 1, samples]
        # Extract per-sequence audio and convert to bytes
        audio_bytes_list = [None] * len(token_sequences)
        for batch_idx, orig_idx in enumerate(valid_indices):
            audio = audio_batch[batch_idx, 0].detach().cpu().numpy()
            # Apply sliding window or trim warmup based on per-request flags
            if sliding_window_flags[orig_idx]:
                # Sliding window mode: keep middle 2048 samples only
                if len(audio) >= 4096:
                    audio = audio[2048:4096]
            else:
                # Standard mode: trim warm-up if requested
                if trim_warmup_flags[orig_idx] and len(audio) > 2048:
                    audio = audio[2048:]
            # Clip before int16 conversion to avoid wraparound clicks from
            # out-of-range samples (matches decode_to_bytes).
            audio = np.clip(audio, -1.0, 1.0)
            audio_int16 = (audio * 32767).astype(np.int16)
            audio_bytes_list[orig_idx] = audio_int16.tobytes()
        return audio_bytes_list