from typing import Dict, List, Optional

import torch
import torchaudio as ta
from torch import nn
import pytorch_lightning as pl

from .bandsplit import BandSplitModule
from .maskestim import OverlappingMaskEstimationModule
from .tfmodel import SeqBandModellingModule
from .utils import MusicalBandsplitSpecification


class BaseEndToEndModule(pl.LightningModule):
    """Thin LightningModule base shared by the end-to-end models below."""

    def __init__(self) -> None:
        super().__init__()
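

# Illustrative sketch (not part of the original module): inspect the musical
# band layout that BaseBandit below builds via MusicalBandsplitSpecification.
# The exact structure returned by get_band_specs() is an assumption; its use
# in BandSplitModule suggests one entry per band over the n_fft // 2 + 1
# frequency bins.
def _inspect_band_specs(n_fft: int = 2048, fs: int = 44100, n_bands: int = 64):
    spec = MusicalBandsplitSpecification(nfft=n_fft, fs=fs, n_bands=n_bands)
    band_specs = spec.get_band_specs()
    freq_weights = spec.get_freq_weights()
    print(f"{len(band_specs)} bands over {n_fft // 2 + 1} frequency bins")
    return band_specs, freq_weights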


class BaseBandit(BaseEndToEndModule):
    def __init__(
        self,
        in_channels: int,
        fs: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.instantiate_spectral(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            pad_mode=pad_mode,
            onesided=onesided,
        )
        self.instantiate_bandsplit(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
            n_fft=n_fft,
            fs=fs,
        )
        self.instantiate_tf_modelling(
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )
    def instantiate_spectral(
        self,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        normalized: bool = True,
        center: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ):
        # power=None keeps the complex-valued spectrogram, which is required
        # for complex masking and for InverseSpectrogram (see the round-trip
        # sketch after this class).
        assert power is None, "a complex spectrogram (power=None) is required"
        window_fn = torch.__dict__[window_fn]
        self.stft = ta.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )
        self.istft = ta.transforms.InverseSpectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )
    def instantiate_bandsplit(
        self,
        in_channels: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        emb_dim: int = 128,
        n_fft: int = 2048,
        fs: int = 44100,
    ):
        # Only the musical band layout is supported here.
        assert band_type == "musical"
        self.band_specs = MusicalBandsplitSpecification(
            nfft=n_fft, fs=fs, n_bands=n_bands
        )
        self.band_split = BandSplitModule(
            in_channels=in_channels,
            band_specs=self.band_specs.get_band_specs(),
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )
    def instantiate_tf_modelling(
        self,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
    ):
        # torch.compile is invoked with disable=True, so it is effectively a
        # no-op; the fallback handles environments where torch.compile is
        # unavailable (e.g. older PyTorch versions).
        try:
            self.tf_model = torch.compile(
                SeqBandModellingModule(
                    n_modules=n_sqm_modules,
                    emb_dim=emb_dim,
                    rnn_dim=rnn_dim,
                    bidirectional=bidirectional,
                    rnn_type=rnn_type,
                ),
                disable=True,
            )
        except Exception:
            self.tf_model = SeqBandModellingModule(
                n_modules=n_sqm_modules,
                emb_dim=emb_dim,
                rnn_dim=rnn_dim,
                bidirectional=bidirectional,
                rnn_type=rnn_type,
            )
    def mask(self, x, m):
        # Apply the (complex) estimated mask to the mixture spectrogram.
        return x * m
    def forward(self, batch, mode="train"):
        # The model operates on mono input. For multichannel (e.g. stereo)
        # audio given as a raw tensor, each channel is processed independently
        # by folding the channel dimension into the batch dimension.
        is_tensor_input = not isinstance(batch, dict)
        if is_tensor_input:
            init_shape = batch.shape
            mono = batch.view(-1, 1, batch.shape[-1])
            batch = {"mixture": {"audio": mono}}

        with torch.no_grad():
            mixture = batch["mixture"]["audio"]
            x = self.stft(mixture)
            batch["mixture"]["spectrogram"] = x

            if "sources" in batch.keys():
                for stem in batch["sources"].keys():
                    s = batch["sources"][stem]["audio"]
                    s = self.stft(s)
                    batch["sources"][stem]["spectrogram"] = s

        batch = self.separate(batch)

        if is_tensor_input:
            b = []
            for stem in self.stems:
                # Restore the original channel layout of each estimate.
                r = batch["estimates"][stem]["audio"].view(
                    -1, init_shape[1], init_shape[2]
                )
                b.append(r)
            # Return a single tensor of shape (batch, n_stems, channels, time)
            # rather than a dict of per-stem estimates.
            batch = torch.stack(b, dim=1)
        return batch
    def encode(self, batch):
        x = batch["mixture"]["spectrogram"]
        length = batch["mixture"]["audio"].shape[-1]
        z = self.band_split(x)  # (batch, emb_dim, n_band, n_time)
        q = self.tf_model(z)  # (batch, emb_dim, n_band, n_time)
        return x, q, length
    def separate(self, batch):
        raise NotImplementedError
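

# A minimal round-trip sketch of the spectral front end configured in
# instantiate_spectral above: with power=None, torchaudio's Spectrogram
# returns a complex tensor that InverseSpectrogram can invert back to the
# waveform. The values here (44.1 kHz, one second of noise) are illustrative.
def _stft_roundtrip_demo() -> None:
    stft = ta.transforms.Spectrogram(
        n_fft=2048, hop_length=512, power=None, normalized=True
    )
    istft = ta.transforms.InverseSpectrogram(
        n_fft=2048, hop_length=512, normalized=True
    )
    audio = torch.randn(1, 44100)
    spec = stft(audio)  # (1, 1025, n_frames), complex-valued
    recon = istft(spec, 44100)  # back to (1, 44100)
    # Near-perfect reconstruction, up to floating-point error.
    print((audio - recon).abs().max())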


class Bandit(BaseBandit):
    def __init__(
        self,
        in_channels: int,
        stems: List[str],
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        use_freq_weights: bool = True,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        fs: int = 44100,
        # Per-module precision hints; accepted for config compatibility but
        # not used in this implementation.
        stft_precisions: str = "32",
        bandsplit_precisions: str = "bf16",
        tf_model_precisions: str = "bf16",
        mask_estim_precisions: str = "bf16",
    ):
        super().__init__(
            in_channels=in_channels,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            fs=fs,
        )
        self.stems = stems
        self.instantiate_mask_estim(
            in_channels=in_channels,
            stems=stems,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
        )
    def instantiate_mask_estim(
        self,
        in_channels: int,
        stems: List[str],
        emb_dim: int,
        mlp_dim: int,
        hidden_activation: str,
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = False,
    ):
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}
        assert n_freq is not None
        self.mask_estim = nn.ModuleDict(
            {
                stem: OverlappingMaskEstimationModule(
                    band_specs=self.band_specs.get_band_specs(),
                    freq_weights=self.band_specs.get_freq_weights(),
                    n_freq=n_freq,
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channels=in_channels,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    use_freq_weights=use_freq_weights,
                )
                for stem in stems
            }
        )
    def separate(self, batch):
        batch["estimates"] = {}
        x, q, length = self.encode(batch)
        for stem, mem in self.mask_estim.items():
            m = mem(q)  # one mask per stem, estimated from the TF embedding
            s = self.mask(x, m.to(x.dtype))
            s = torch.reshape(s, x.shape)
            batch["estimates"][stem] = {
                "audio": self.istft(s, length),
                "spectrogram": s,
            }
        return batch
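

# A minimal usage sketch. Assumptions: the stem names below are illustrative,
# and the relative imports at the top resolve within the package. With a raw
# waveform tensor as input, forward() folds channels into the batch, separates
# each mono channel, and stacks the per-stem estimates into one tensor.
if __name__ == "__main__":
    model = Bandit(in_channels=1, stems=["speech", "music"], fs=44100)
    model.eval()
    stereo = torch.randn(2, 2, 3 * 44100)  # (batch, channels, samples)
    with torch.no_grad():
        estimates = model(stereo)
    # (batch, n_stems, channels, samples) -> torch.Size([2, 2, 2, 132300])
    print(estimates.shape)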