# mvsepless_cpu_beta / multi_inference.py
# noblebarkrr's picture
# test1
# d0cd3b0
import os
import time
import shutil
import sys
import gc
import argparse
import json
import subprocess
from datetime import datetime
from tabulate import tabulate
# Absolute directory containing this script.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Make the sibling modules importable and switch the working directory to the
# script directory. NOTE: because of the chdir, relative paths handled after
# this point are resolved against SCRIPT_DIR, not the caller's original cwd.
sys.path.append(SCRIPT_DIR)
os.chdir(SCRIPT_DIR)
# Project-local imports; they must stay below the sys.path adjustment above.
from model_list import models_data
from utils.preedit_config import conf_editor
from utils.download_models import download_model
# Directory handed to the model downloader / UVR back end as the model cache.
MODELS_CACHE_DIR = os.path.join(SCRIPT_DIR, "separator", "models_cache")
# Model families accepted by the `separate` sub-command (`--model_type` choices).
MODEL_TYPES = ["mel_band_roformer", "bs_roformer", "mdx23c", "scnet", "htdemucs", "bandit", "bandit_v2", "vr", "mdx"]
# Output container/codec names exposed through MVSEPLESS.output_formats.
OUTPUT_FORMATS = ["mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "aiff"]
class MVSEPLESS:
    """Front end that dispatches audio separation to the `separator` back ends.

    Model metadata is read from the module-level ``models_data`` mapping
    (``models_data[model_type][model_name]``). Separation can run either in a
    child Python process (``call_method="cli"``) or in-process
    (``call_method="direct"``).
    """

    # Model families served by ``separator.msst_separator``.
    _MSST_MODEL_TYPES = (
        "mel_band_roformer", "bs_roformer", "mdx23c", "scnet",
        "htdemucs", "bandit", "bandit_v2",
    )
    # Model families served by ``separator.uvr_sep``.
    _UVR_MODEL_TYPES = ("vr", "mdx")
    # Execution modes understood by ``separator()``.
    _CALL_METHODS = ("cli", "direct")

    def __init__(self):
        self.models_cache_dir = os.path.join(SCRIPT_DIR, "separator", "models_cache")
        self.model_types = MODEL_TYPES
        self.output_formats = OUTPUT_FORMATS

    def get_mt(self):
        """Return the list of model types present in ``models_data``."""
        return list(models_data.keys())

    def get_mn(self, model_type):
        """Return the model names registered under ``model_type``.

        Raises:
            KeyError: if ``model_type`` is not a key of ``models_data``.
        """
        return list(models_data[model_type].keys())

    def get_stems(self, model_type, model_name):
        """Return the ``"stems"`` entry of the given model (KeyError if absent)."""
        return models_data[model_type][model_name]["stems"]

    def get_tgt_inst(self, model_type, model_name):
        """Return the ``"target_instrument"`` entry of the given model (KeyError if absent)."""
        return models_data[model_type][model_name]["target_instrument"]

    def display_models_info(self, filter: str = None):
        """Print one table per model type describing the available models.

        Args:
            filter: when truthy, only models whose ``"stems"`` list contains
                this value are listed. (The name shadows the builtin but is
                kept because it is part of the public signature.)
        """
        print("\nAvailable Models Information:")
        print("=" * 50)
        headers = ["Model Name", "Stems", "Target Instrument", "Primary Stem"]
        for model_type in models_data:
            print(f"\nModel Type: {model_type.upper()}")
            print("-" * 50)
            table_data = []
            for model_name, model_info in models_data[model_type].items():
                if filter and filter not in model_info.get('stems', []):
                    continue
                stems = "\n".join(model_info['stems']) if 'stems' in model_info else "N/A"
                target = model_info.get('target_instrument', "N/A")
                primary = model_info.get('primary_stem', "N/A")
                table_data.append([model_name, stems, target, primary])
            # A table is printed for every model type, even when the filter
            # leaves it empty.
            print(tabulate(table_data, headers=headers, tablefmt="grid"))
            print()

    @staticmethod
    def _no_result():
        """Return a fresh placeholder result used for every failure path."""
        return [("None", "/none/none.mp3")]

    def _run_cli(self, module_args, output_dir):
        """Run ``python -m <module_args...>`` and load ``results.json`` from ``output_dir``.

        The command is passed as an argument list without a shell, so file
        names and model names containing quotes, ``$`` or other shell
        metacharacters reach the child process unchanged. The current
        interpreter (``sys.executable``) is used so the child runs in the same
        environment as the caller.

        Raises:
            subprocess.CalledProcessError: if the child exits with a non-zero status.
        """
        command = [sys.executable or "python", "-m"]
        command.extend(str(arg) for arg in module_args)
        subprocess.run(command, check=True)
        results_path = os.path.join(output_dir, "results.json")
        if os.path.exists(results_path):
            with open(results_path, encoding="utf-8") as f:
                return json.load(f)
        return self._no_result()

    def _run_direct(self, func, **kwargs):
        """Call a back-end function in-process; print the error and return the placeholder on failure."""
        try:
            return func(**kwargs)
        except Exception as e:
            print(e)
            return self._no_result()

    def separator(
        self,
        input_file: str = None,
        output_dir: str = None,
        model_type: str = "mel_band_roformer",
        model_name: str = "Mel-Band-Roformer_Vocals_kimberley_jensen",
        ext_inst: bool = False,
        vr_aggr: int = 5,
        output_format: str = "wav",
        output_bitrate: str = "320k",
        template: str = "NAME_(STEM)_MODEL",
        call_method: str = "cli",
        selected_stems: list = None
    ):
        """Separate ``input_file`` with the chosen model and return the results.

        Args:
            input_file: path of the audio file to process; must exist.
            output_dir: directory for the outputs; created if missing.
            model_type: model family; selects the back end.
            model_name: key of the model inside ``models_data[model_type]``.
            ext_inst: forwarded as ``--extract_instrumental`` /
                ``extract_instrumental`` (MSST model families only).
            vr_aggr: forwarded as ``--aggression`` / ``aggression`` (VR/MDX only).
            output_format: forwarded to the back end unchanged.
            output_bitrate: forwarded to the back end unchanged.
            template: output naming template; ``"_STEM"`` is appended when the
                template does not contain ``"STEM"``.
            call_method: ``"cli"`` runs the back end in a child process and
                reads ``results.json`` from ``output_dir``; ``"direct"`` imports
                and calls the back-end function in-process.
            selected_stems: optional list of stems forwarded as
                ``--selected_instruments`` / ``selected_instruments``.

        Returns:
            Whatever the back end reports (the CLI entry point of this file
            iterates it as ``(stem, path)`` pairs), or the placeholder
            ``[("None", "/none/none.mp3")]`` when the request is invalid, the
            model is unknown, no ``results.json`` was produced, or a direct
            call raised.

        Raises:
            subprocess.CalledProcessError: in ``"cli"`` mode when the child
                process exits with a non-zero status.
        """
        if selected_stems is None:
            selected_stems = []

        if not input_file:
            print("Please, input path to input file")
            return self._no_result()
        if not os.path.exists(input_file):
            print("Input file not exist")
            return self._no_result()
        # output_dir defaults to None; without this guard os.makedirs(None)
        # would raise a TypeError instead of reporting the problem.
        if not output_dir:
            print("Please, input path to output directory")
            return self._no_result()
        # Reject unknown modes before any model download happens; previously
        # they fell through to the misleading "Unsupported model type" message.
        if call_method not in self._CALL_METHODS:
            print(f"Unsupported call method: {call_method}")
            return self._no_result()

        if "STEM" not in template:
            template = template + "_STEM"

        print(f"Starting inference: {model_type}/{model_name}, bitrate={output_bitrate}, method={call_method}, stems={selected_stems}")
        os.makedirs(output_dir, exist_ok=True)

        is_msst = model_type in self._MSST_MODEL_TYPES
        if not is_msst and model_type not in self._UVR_MODEL_TYPES:
            print("Unsupported model type")
            return self._no_result()

        try:
            info = models_data[model_type][model_name]
        except KeyError:
            print("Model not exist")
            return self._no_result()

        use_cli = call_method == "cli"
        # Each stem is its own argv entry, mirroring `--selected_instruments a b c`.
        stem_args = ["--selected_instruments", *selected_stems] if selected_stems else []

        if is_msst:
            conf, ckpt = download_model(self.models_cache_dir, model_name, model_type,
                                        info["checkpoint_url"], info["config_url"])
            if model_type != "htdemucs":
                conf_editor(conf)
            if use_cli:
                argv = [
                    "separator.msst_separator",
                    "--input", input_file,
                    "--store_dir", output_dir,
                    "--model_type", model_type,
                    "--model_name", model_name,
                    "--config_path", conf,
                    "--start_check_point", ckpt,
                    "--output_format", output_format,
                    "--output_bitrate", output_bitrate,
                    "--template", template,
                    "--save_results_info",
                ]
                if ext_inst:
                    argv.append("--extract_instrumental")
                return self._run_cli(argv + stem_args, output_dir)
            from separator.msst_separator import mvsep_offline
            return self._run_direct(
                mvsep_offline,
                input_path=input_file, store_dir=output_dir, model_type=model_type,
                config_path=conf, start_check_point=ckpt, extract_instrumental=ext_inst,
                output_format=output_format, output_bitrate=output_bitrate,
                model_name=model_name, template=template, selected_instruments=selected_stems
            )

        if model_type == "vr" and info.get("custom_vr", False):
            # Custom VR models ship their own checkpoint and config.
            conf, ckpt = download_model(self.models_cache_dir, model_name, model_type,
                                        info["checkpoint_url"], info["config_url"])
            primary_stem = info["primary_stem"]
            if use_cli:
                argv = [
                    "separator.uvr_sep", "custom_vr",
                    "--input_file", input_file,
                    "--ckpt_path", ckpt,
                    "--config_path", conf,
                    "--bitrate", output_bitrate,
                    "--model_name", model_name,
                    "--template", template,
                    "--output_format", output_format,
                    "--primary_stem", primary_stem,
                    "--aggression", vr_aggr,
                    "--output_dir", output_dir,
                ]
                return self._run_cli(argv + stem_args, output_dir)
            from separator.uvr_sep import custom_vr_separate
            return self._run_direct(
                custom_vr_separate,
                input_file=input_file, ckpt_path=ckpt, config_path=conf,
                bitrate=output_bitrate, model_name=model_name, template=template,
                output_format=output_format, primary_stem=primary_stem,
                aggression=vr_aggr, output_dir=output_dir,
                selected_instruments=selected_stems
            )

        # Remaining VR/MDX models: the back end receives the cache directory
        # and the model name instead of explicit checkpoint/config paths.
        if use_cli:
            argv = [
                "separator.uvr_sep", "uvr",
                "--input_file", input_file,
                "--output_dir", output_dir,
                "--template", template,
                "--bitrate", output_bitrate,
                "--model_dir", self.models_cache_dir,
                "--model_type", model_type,
                "--model_name", model_name,
                "--output_format", output_format,
                "--aggression", vr_aggr,
            ]
            return self._run_cli(argv + stem_args, output_dir)
        from separator.uvr_sep import non_custom_uvr_inference
        return self._run_direct(
            non_custom_uvr_inference,
            input_file=input_file, output_dir=output_dir, template=template,
            bitrate=output_bitrate, model_dir=self.models_cache_dir,
            model_type=model_type, model_name=model_name,
            output_format=output_format, aggression=vr_aggr,
            selected_instruments=selected_stems
        )
def parse_args(argv=None):
    """Parse the command line of the multi-inference tool.

    Two sub-commands are defined: ``list`` (print the model tables, optionally
    filtered by stem) and ``separate`` (run a separation on a file or on
    every file of a directory).

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]``; ``None`` (the default) keeps the original
            behaviour of reading the process arguments.

    Returns:
        argparse.Namespace with ``command`` set to ``"list"`` or
        ``"separate"`` plus the options of that sub-command.
    """
    parser = argparse.ArgumentParser(description="Multi-inference for separation audio in Google Colab")
    subparsers = parser.add_subparsers(dest='command', required=True, help='Sub-command help')

    # `list` sub-command.
    list_models = subparsers.add_parser('list', help='List of exist models')
    list_models.add_argument("-l_filter", "--list_filter", type=str, default=None, help="Show models in list only with specified stem")

    # `separate` sub-command.
    separate = subparsers.add_parser('separate', help='Separate I/O params')
    separate.add_argument("-i", "--input", type=str, required=True, help="Input file or directory")
    separate.add_argument("-o", "--output", type=str, required=True, help="Output directory")
    separate.add_argument("-mt", "--model_type", type=str, required=True, choices=MODEL_TYPES, help="Model type")
    separate.add_argument("-mn", "--model_name", type=str, required=True, help="Model name")
    separate.add_argument("-inst", "--instrumental", action='store_true', help="Extract instrumental")
    separate.add_argument("-stems", "--stems", nargs="+", help="Select output stems")
    separate.add_argument("-bitrate", "--bitrate", type=str, default="320k", help="Output bitrate")
    separate.add_argument("-of", "--format", type=str, default="mp3", help="Output format")
    separate.add_argument("-vr_aggr", "--vr_arch_aggressive", type=int, default=5, help="Aggression for VR ARCH models")
    separate.add_argument('--template', type=str, default='NAME_STEM', help='Template naming of output files')
    separate.add_argument("-l_out", "--list_output", action='store_true', help="Show list output files")

    return parser.parse_args(argv)
if __name__ == "__main__":
    # Command-line entry point: `list` prints the model tables, `separate`
    # processes a single file or every regular file of a directory.
    args = parse_args()
    mvsepless = MVSEPLESS()

    if args.command == 'list':
        mvsepless.display_models_info(args.list_filter)

    elif args.command == 'separate':
        # Options shared by the single-file and the batch invocation.
        common_options = dict(
            model_type=args.model_type,
            model_name=args.model_name,
            ext_inst=args.instrumental,
            vr_aggr=args.vr_arch_aggressive,
            output_format=args.format,
            output_bitrate=args.bitrate,
            template=args.template,
            call_method="cli",
            selected_stems=args.stems
        )

        if os.path.isfile(args.input):
            results = mvsepless.separator(
                input_file=args.input,
                output_dir=args.output,
                **common_options
            )
            if args.list_output:
                print("Results\n")
                for stem, path in results:
                    print(f"Stem - {stem}\nPath - {path}\n")

        elif os.path.isdir(args.input):
            batch_results = []
            failed_files = []
            # Sorted so the processing order does not depend on the file system.
            for file in sorted(os.listdir(args.input)):
                abs_path_file = os.path.join(args.input, file)
                if not os.path.isfile(abs_path_file):
                    continue
                base_name = os.path.splitext(os.path.basename(abs_path_file))[0]
                # Each input gets its own sub-directory named after the file.
                output_subdir = os.path.join(args.output, base_name)
                try:
                    results = mvsepless.separator(
                        input_file=abs_path_file,
                        output_dir=output_subdir,
                        **common_options
                    )
                except subprocess.CalledProcessError as err:
                    # One bad file must not abort the rest of the batch.
                    print(f"Separation failed for {abs_path_file}: {err}", file=sys.stderr)
                    failed_files.append(abs_path_file)
                    continue
                batch_results.append((base_name, results))

            if args.list_output:
                print("Results\n")
                for name, stems in batch_results:
                    print(f"Name - {name}")
                    for stem, path in stems:
                        print(f" Stem - {stem}\n Path - {path}\n")

            # Keep a non-zero exit status when any file failed.
            if failed_files:
                sys.exit(1)

        else:
            # Previously a missing input path was ignored silently.
            print(f"Input path not found: {args.input}", file=sys.stderr)
            sys.exit(1)