Spaces:
Running
Running
| import os | |
| import time | |
| import shutil | |
| import sys | |
| import gc | |
| import argparse | |
| import json | |
| import subprocess | |
| from datetime import datetime | |
| from tabulate import tabulate | |
| SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| sys.path.append(SCRIPT_DIR) | |
| os.chdir(SCRIPT_DIR) | |
| from model_list import models_data | |
| from utils.preedit_config import conf_editor | |
| from utils.download_models import download_model | |
| MODELS_CACHE_DIR = os.path.join(SCRIPT_DIR, "separator", "models_cache") | |
| MODEL_TYPES = ["mel_band_roformer", "bs_roformer", "mdx23c", "scnet", "htdemucs", "bandit", "bandit_v2", "vr", "mdx"] | |
| OUTPUT_FORMATS = ["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "aiff"] | |
| class MVSEPLESS: | |
| def __init__(self): | |
| self.models_cache_dir = os.path.join(SCRIPT_DIR, "separator", "models_cache") | |
| self.model_types = MODEL_TYPES | |
| self.output_formats = OUTPUT_FORMATS | |
| def get_mt(self): | |
| return list(models_data.keys()) | |
| def get_mn(self, model_type): | |
| return list(models_data[model_type].keys()) | |
| def get_stems(self, model_type, model_name): | |
| stems = models_data[model_type][model_name]["stems"] | |
| return stems | |
| def get_tgt_inst(self, model_type, model_name): | |
| target_instrument = models_data[model_type][model_name]["target_instrument"] | |
| return target_instrument | |
| def display_models_info(self, filter: str = None): | |
| print("\nAvailable Models Information:") | |
| print("=" * 50) | |
| for model_type in models_data: | |
| print(f"\nModel Type: {model_type.upper()}") | |
| print("-" * 50) | |
| table_data = [] | |
| headers = ["Model Name", "Stems", "Target Instrument", "Primary Stem"] | |
| for model_name in models_data[model_type]: | |
| model_info = models_data[model_type][model_name] | |
| if filter and filter not in model_info.get('stems', []): | |
| continue | |
| stems = "\n".join(model_info.get('stems', [])) if 'stems' in model_info else "N/A" | |
| target = model_info.get('target_instrument', "N/A") | |
| primary = model_info.get('primary_stem', "N/A") | |
| table_data.append([model_name, stems, target, primary]) | |
| print(tabulate(table_data, headers=headers, tablefmt="grid")) | |
| print() | |
| def separator( | |
| self, | |
| input_file: str = None, | |
| output_dir: str = None, | |
| model_type: str = "mel_band_roformer", | |
| model_name: str = "Mel-Band-Roformer_Vocals_kimberley_jensen", | |
| ext_inst: bool = False, | |
| vr_aggr: int = 5, | |
| output_format: str = "wav", | |
| output_bitrate: str = "320k", | |
| template: str = "NAME_(STEM)_MODEL", | |
| call_method: str = "cli", | |
| selected_stems: list = None | |
| ): | |
| if selected_stems is None: | |
| selected_stems = [] | |
| if not input_file: | |
| print("Please, input path to input file") | |
| return [("None", "/none/none.mp3")] | |
| if not os.path.exists(input_file): | |
| print("Input file not exist") | |
| return [("None", "/none/none.mp3")] | |
| if "STEM" not in template: | |
| template = template + "_STEM" | |
| print(f"Starting inference: {model_type}/{model_name}, bitrate={output_bitrate}, method={call_method}, stems={selected_stems}") | |
| os.makedirs(output_dir, exist_ok=True) | |
| if model_type in ["mel_band_roformer", "bs_roformer", "mdx23c", "scnet", "htdemucs", "bandit", "bandit_v2"]: | |
| try: | |
| info = models_data[model_type][model_name] | |
| except KeyError: | |
| print("Model not exist") | |
| return [("None", "/none/none.mp3")] | |
| conf, ckpt = download_model(self.models_cache_dir, model_name, model_type, | |
| info["checkpoint_url"], info["config_url"]) | |
| if model_type != "htdemucs": | |
| conf_editor(conf) | |
| if call_method == "cli": | |
| cmd = ["python", "-m", "separator.msst_separator", f'--input "{input_file}"', | |
| f'--store_dir "{output_dir}"', f'--model_type "{model_type}"', | |
| f'--model_name "{model_name}"', f'--config_path "{conf}"', | |
| f'--start_check_point "{ckpt}"', f'--output_format "{output_format}"', | |
| f'--output_bitrate "{output_bitrate}"', f'--template "{template}"', | |
| "--save_results_info"] | |
| if ext_inst: | |
| cmd.append("--extract_instrumental") | |
| if selected_stems: | |
| instruments = " ".join(f'"{s}"' for s in selected_stems) | |
| cmd.append(f'--selected_instruments {instruments}') | |
| subprocess.run(" ".join(cmd), shell=True, check=True) | |
| results_path = os.path.join(output_dir, "results.json") | |
| if os.path.exists(results_path): | |
| with open(results_path, encoding="utf-8") as f: | |
| return json.load(f) | |
| return [("None", "/none/none.mp3")] | |
| elif call_method == "direct": | |
| from separator.msst_separator import mvsep_offline | |
| try: | |
| return mvsep_offline( | |
| input_path=input_file, store_dir=output_dir, model_type=model_type, | |
| config_path=conf, start_check_point=ckpt, extract_instrumental=ext_inst, | |
| output_format=output_format, output_bitrate=output_bitrate, | |
| model_name=model_name, template=template, selected_instruments=selected_stems | |
| ) | |
| except Exception as e: | |
| print(e) | |
| return [("None", "/none/none.mp3")] | |
| elif model_type in ["vr", "mdx"]: | |
| try: | |
| info = models_data[model_type][model_name] | |
| except KeyError: | |
| print("Model not exist") | |
| return [("None", "/none/none.mp3")] | |
| if model_type == "vr" and info.get("custom_vr", False): | |
| conf, ckpt = download_model(self.models_cache_dir, model_name, model_type, | |
| info["checkpoint_url"], info["config_url"]) | |
| primary_stem = info["primary_stem"] | |
| if call_method == "cli": | |
| cmd = ["python", "-m", "separator.uvr_sep", "custom_vr", | |
| f'--input_file "{input_file}"', f'--ckpt_path "{ckpt}"', | |
| f'--config_path "{conf}"', f'--bitrate "{output_bitrate}"', | |
| f'--model_name "{model_name}"', f'--template "{template}"', | |
| f'--output_format "{output_format}"', f'--primary_stem "{primary_stem}"', | |
| f'--aggression {vr_aggr}', f'--output_dir "{output_dir}"'] | |
| if selected_stems: | |
| instruments = " ".join(f'"{s}"' for s in selected_stems) | |
| cmd.append(f'--selected_instruments {instruments}') | |
| subprocess.run(" ".join(cmd), shell=True, check=True) | |
| results_path = os.path.join(output_dir, "results.json") | |
| if os.path.exists(results_path): | |
| with open(results_path, encoding="utf-8") as f: | |
| return json.load(f) | |
| return [("None", "/none/none.mp3")] | |
| elif call_method == "direct": | |
| from separator.uvr_sep import custom_vr_separate | |
| try: | |
| return custom_vr_separate( | |
| input_file=input_file, ckpt_path=ckpt, config_path=conf, | |
| bitrate=output_bitrate, model_name=model_name, template=template, | |
| output_format=output_format, primary_stem=primary_stem, | |
| aggression=vr_aggr, output_dir=output_dir, | |
| selected_instruments=selected_stems | |
| ) | |
| except Exception as e: | |
| print(e) | |
| return [("None", "/none/none.mp3")] | |
| else: | |
| if call_method == "cli": | |
| cmd = ["python", "-m", "separator.uvr_sep", "uvr", | |
| f'--input_file "{input_file}"', f'--output_dir "{output_dir}"', | |
| f'--template "{template}"', f'--bitrate "{output_bitrate}"', | |
| f'--model_dir "{self.models_cache_dir}"', f'--model_type "{model_type}"', | |
| f'--model_name "{model_name}"', f'--output_format "{output_format}"', | |
| f'--aggression {vr_aggr}'] | |
| if selected_stems: | |
| instruments = " ".join(f'"{s}"' for s in selected_stems) | |
| cmd.append(f'--selected_instruments {instruments}') | |
| subprocess.run(" ".join(cmd), shell=True, check=True) | |
| results_path = os.path.join(output_dir, "results.json") | |
| if os.path.exists(results_path): | |
| with open(results_path, encoding="utf-8") as f: | |
| return json.load(f) | |
| return [("None", "/none/none.mp3")] | |
| elif call_method == "direct": | |
| from separator.uvr_sep import non_custom_uvr_inference | |
| try: | |
| return non_custom_uvr_inference( | |
| input_file=input_file, output_dir=output_dir, template=template, | |
| bitrate=output_bitrate, model_dir=self.models_cache_dir, | |
| model_type=model_type, model_name=model_name, | |
| output_format=output_format, aggression=vr_aggr, | |
| selected_instruments=selected_stems | |
| ) | |
| except Exception as e: | |
| print(e) | |
| return [("None", "/none/none.mp3")] | |
| print("Unsupported model type") | |
| return [("None", "/none/none.mp3")] | |
| def parse_args(): | |
| parser = argparse.ArgumentParser(description="Multi-inference for separation audio in Google Colab") | |
| subparsers = parser.add_subparsers(dest='command', required=True, help='Sub-command help') | |
| list_models = subparsers.add_parser('list', help='List of exist models') | |
| list_models.add_argument("-l_filter", "--list_filter", type=str, default=None, help="Show models in list only with specified stem") | |
| separate = subparsers.add_parser('separate', help='Separate I/O params') | |
| separate.add_argument("-i", "--input", type=str, required=True, help="Input file or directory") | |
| separate.add_argument("-o", "--output", type=str, required=True, help="Output directory") | |
| separate.add_argument("-mt", "--model_type", type=str, required=True, choices=MODEL_TYPES, help="Model type") | |
| separate.add_argument("-mn", "--model_name", type=str, required=True, help="Model name") | |
| separate.add_argument("-inst", "--instrumental", action='store_true', help="Extract instrumental") | |
| separate.add_argument("-stems", "--stems", nargs="+", help="Select output stems") | |
| separate.add_argument("-bitrate", "--bitrate", type=str, default="320k", help="Output bitrate") | |
| separate.add_argument("-of", "--format", type=str, default="mp3", help="Output format") | |
| separate.add_argument("-vr_aggr", "--vr_arch_aggressive", type=int, default=5, help="Aggression for VR ARCH models") | |
| separate.add_argument('--template', type=str, default='NAME_STEM', help='Template naming of output files') | |
| separate.add_argument("-l_out", "--list_output", action='store_true', help="Show list output files") | |
| return parser.parse_args() | |
| if __name__ == "__main__": | |
| args = parse_args() | |
| mvsepless = MVSEPLESS() | |
| if args.command == 'list': | |
| mvsepless.display_models_info(args.list_filter) | |
| elif args.command == 'separate': | |
| if os.path.isfile(args.input): | |
| results = mvsepless.separator( | |
| input_file=args.input, | |
| output_dir=args.output, | |
| model_type=args.model_type, | |
| model_name=args.model_name, | |
| ext_inst=args.instrumental, | |
| vr_aggr=args.vr_arch_aggressive, | |
| output_format=args.format, | |
| output_bitrate=args.bitrate, | |
| template=args.template, | |
| call_method="cli", | |
| selected_stems=args.stems | |
| ) | |
| if args.list_output: | |
| print("Results\n") | |
| for stem, path in results: | |
| print(f"Stem - {stem}\nPath - {path}\n") | |
| elif os.path.isdir(args.input): | |
| batch_results = [] | |
| for file in os.listdir(args.input): | |
| abs_path_file = os.path.join(args.input, file) | |
| if os.path.isfile(abs_path_file): | |
| base_name = os.path.splitext(os.path.basename(abs_path_file))[0] | |
| output_subdir = os.path.join(args.output, base_name) | |
| results = mvsepless.separator( | |
| input_file=abs_path_file, | |
| output_dir=output_subdir, | |
| model_type=args.model_type, | |
| model_name=args.model_name, | |
| ext_inst=args.instrumental, | |
| vr_aggr=args.vr_arch_aggressive, | |
| output_format=args.format, | |
| output_bitrate=args.bitrate, | |
| template=args.template, | |
| call_method="cli", | |
| selected_stems=args.stems | |
| ) | |
| batch_results.append((base_name, results)) | |
| if args.list_output: | |
| print("Results\n") | |
| for name, stems in batch_results: | |
| print(f"Name - {name}") | |
| for stem, path in stems: | |
| print(f" Stem - {stem}\n Path - {path}\n") | |