# Multi-Tagger / app.py
# Author: Werli
# Last commit: 8912092 ("Fix"), verified
import os, io, json, requests, spaces, argparse, traceback, tempfile, zipfile, re, ast, time
import gradio as gr
import numpy as np
import huggingface_hub
import onnxruntime as ort
import pandas as pd
from datetime import datetime, timezone
from collections import defaultdict
from PIL import Image, ImageOps
from apscheduler.schedulers.background import BackgroundScheduler
from modules.classifyTags import categorize_tags_output, generate_tags_json, process_tags_for_misc
from modules.pixai import create_pixai_interface
from modules.booru import create_booru_interface
from modules.multi_comfy import create_multi_comfy
from modules.media_handler import handle_single_media_upload, handle_multiple_media_uploads
""" For GPU install all the requirements.txt and the following:
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126 or any other Torch version
pip install onnxruntime-gpu
"""
""" It's recommended to create a venv if you want to use it offline:
python -m venv venv
venv\Scripts\activate
pip install ...
python app.py
"""
# --- UI text ---------------------------------------------------------------
TITLE: str = "Multi-Tagger v1.4"
DESCRIPTION: str = (
    "\nMulti-Tagger is a versatile application for advanced image analysis and captioning. "
    "Supports <b>CUDA</b> and <b>CPU</b>.\n"
)

# --- Hugging Face Hub repositories of the selectable tagger models ----------
# SmilingWolf, dataset v3
SWINV2_MODEL_DSV3_REPO: str = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO: str = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO: str = "SmilingWolf/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO: str = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO: str = "SmilingWolf/wd-eva02-large-tagger-v3"
# SmilingWolf, dataset v2 (wd-v1-4)
MOAT_MODEL_DSV2_REPO: str = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO: str = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO: str = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO: str = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO: str = "SmilingWolf/wd-v1-4-vit-tagger-v2"
# deepghs IdolSankaku, v1
EVA02_LARGE_MODEL_IS_DSV1_REPO: str = "deepghs/idolsankaku-eva02-large-tagger-v1"
SWINV2_MODEL_IS_DSV1_REPO: str = "deepghs/idolsankaku-swinv2-tagger-v1"

# --- Currently loaded model (one at a time, so it can be freed explicitly) --
CURRENT_MODEL = CURRENT_MODEL_NAME = CURRENT_TAGS_DF = CURRENT_TAG_NAMES = None
CURRENT_RATING_INDEXES = CURRENT_GENERAL_INDEXES = CURRENT_CHARACTER_INDEXES = None
CURRENT_MODEL_TARGET_SIZE = None
# Custom CSS for gallery styling.
# The thumbnail <img> rule is defined once (at the end). An earlier duplicate with
# the same selector was fully overridden by it and has been removed. Inside that
# rule, `width: auto` followed by `width: fit-content` is kept on purpose: a
# browser that rejects `fit-content` falls back to the earlier declaration.
css = """
#custom-gallery {--row-height: 180px;display: grid;grid-auto-rows: min-content;gap: 10px;}
#custom-gallery .thumbnail-item {height: var(--row-height);width: 100%;position: relative;overflow: hidden;border-radius: 8px;box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);transition: transform 0.2s ease, box-shadow 0.2s ease;}
#custom-gallery .thumbnail-item:hover {transform: translateY(-3px);box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);}
#custom-gallery .thumbnail-item img.portrait {max-width: 100%;}
#custom-gallery .thumbnail-item img.landscape {max-height: 100%;}
.gallery-container {max-height: 500px;overflow-y: auto;padding-right: 0px;--size-80: 500px;}
.thumbnails {display: flex;position: absolute;bottom: 0;width: 120px;overflow-x: scroll;padding-top: 320px;padding-bottom: 280px;padding-left: 4px;flex-wrap: wrap;}
#custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: initial;width: fit-content;margin: 0px auto;display: block;}
"""
# File names expected inside every tagger repository on the Hub
MODEL_FILENAME = 'model.onnx'
LABEL_FILENAME = 'selected_tags.csv'
class Timer:
    """Utility class for measuring execution time of different operations.

    Checkpoints are stored as ``(label, time.perf_counter() value)`` pairs. The
    first stored pair is the baseline that the next checkpoint is measured
    against: ``('Start', start_time)`` initially, or the fresh checkpoint that
    ``report()`` adds after clearing the history.
    """

    def __init__(self):
        self.start_time = time.perf_counter()
        self.checkpoints = [('Start', self.start_time)]

    def checkpoint(self, label='Checkpoint'):
        """Record the current time under ``label``."""
        now = time.perf_counter()
        self.checkpoints.append((label, now))

    def _print_intervals(self, prev_time, width):
        """Print, for every checkpoint after the first, the time since the previous one."""
        for label, curr_time in self.checkpoints[1:]:
            print(f"{label.ljust(width)}: {curr_time - prev_time:.3f} seconds")
            prev_time = curr_time

    def report(self, is_clear_checkpoints=True):
        """Print the elapsed time between consecutive checkpoints.

        When ``is_clear_checkpoints`` is true, the history is reset to a single
        new checkpoint, which becomes the baseline for the next report.
        """
        width = max((len(label) for label, _ in self.checkpoints), default=0)
        baseline = self.checkpoints[0][1] if self.checkpoints else self.start_time
        self._print_intervals(baseline, width)
        if is_clear_checkpoints:
            self.checkpoints.clear()
            self.checkpoint()

    def report_all(self):
        """Print the remaining checkpoint intervals and the total time since start.

        The total runs from ``start_time`` to the last recorded checkpoint. The
        checkpoint history is cleared afterwards.
        """
        print('\n> Execution Time Report:')
        width = max((len(label) for label, _ in self.checkpoints), default=0)
        # Measure from the first stored entry, as report() does. After a cleared
        # report() that entry is the fresh baseline, not start_time. Using
        # start_time would attribute all earlier time to the next checkpoint.
        baseline = self.checkpoints[0][1] if self.checkpoints else self.start_time
        self._print_intervals(baseline, width)
        total_time = self.checkpoints[-1][1] - self.start_time if self.checkpoints else 0
        print(f"{'Total Execution Time'.ljust(width)}: {total_time:.3f} seconds\n")
        self.checkpoints.clear()

    def restart(self):
        """Restart the timer with a fresh ``'Start'`` baseline."""
        self.start_time = time.perf_counter()
        self.checkpoints = [('Start', self.start_time)]
def parse_args(argv=None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before. Passing a list allows
            programmatic use and testing.

    Returns:
        Namespace with ``score_slider_step``, ``score_general_threshold``,
        ``score_character_threshold`` (floats) and ``share`` (bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--score-slider-step', type=float, default=0.05,
                        help='step of the threshold sliders in the UI')
    parser.add_argument('--score-general-threshold', type=float, default=0.35,
                        help='initial general-tag threshold')
    parser.add_argument('--score-character-threshold', type=float, default=0.85,
                        help='initial character-tag threshold')
    parser.add_argument('--share', action='store_true',
                        help='flag parsed for sharing the app')
    return parser.parse_args(argv)
def load_labels(dataframe) -> tuple:
    """Split a tag dataframe into names and per-category row positions.

    Returns ``(tag_names, rating_indexes, general_indexes, character_indexes)``.
    ``tag_names`` is the ``name`` column as a list. Each index list holds the
    row positions whose ``category`` equals 9 (rating), 0 (general) or
    4 (character), respectively.
    """
    categories = dataframe['category']

    def rows_with(code):
        # np.where on the boolean mask yields positional row indexes
        return list(np.where(categories == code)[0])

    return dataframe['name'].tolist(), rows_with(9), rows_with(0), rows_with(4)
def mcut_threshold(probs):
    """Return the MCut (maximum cut) threshold for a set of scores.

    The scores are sorted in descending order. The threshold is the midpoint of
    the largest gap between two consecutive sorted scores, so everything above
    the biggest drop in confidence is kept. This uses the first difference of
    the sorted scores, not a second derivative.

    Args:
        probs: 1-D array-like of scores (a NumPy array or a plain list).

    Returns:
        The threshold as a NumPy scalar of the input's dtype.

    Raises:
        ValueError: if fewer than two scores are given, since there is no gap
            to cut. The original code also raised ValueError here, from argmax.
    """
    sorted_probs = np.sort(np.asarray(probs))[::-1]
    if sorted_probs.size < 2:
        raise ValueError('mcut_threshold needs at least two scores')
    gaps = sorted_probs[:-1] - sorted_probs[1:]
    cut = int(gaps.argmax())
    return (sorted_probs[cut] + sorted_probs[cut + 1]) / 2
def _download_model_files(model_repo):
    """Fetch the label CSV and the ONNX model of ``model_repo`` from the Hugging Face Hub.

    Returns:
        ``(csv_path, model_path)``: the local paths returned by
        ``huggingface_hub.hf_hub_download``. The CSV is requested first.
    """
    csv_path, model_path = (
        huggingface_hub.hf_hub_download(model_repo, filename)
        for filename in (LABEL_FILENAME, MODEL_FILENAME)
    )
    return csv_path, model_path
def create_optimized_ort_session(model_path):
    """Build an ONNX Runtime ``InferenceSession`` for the model at ``model_path``.

    The session uses:
    - all graph optimizations and parallel execution mode;
    - memory-pattern and CPU-arena reuse;
    - the CUDA provider first when ONNX Runtime lists it as available, with
      the CPU provider always appended as the fallback.

    Provider availability and the providers actually used are printed. Any
    error while creating the session is printed and re-raised.
    """
    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    options.intra_op_num_threads = 0  # 0 = let ONNX Runtime choose the thread count
    options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
    options.enable_mem_pattern = True
    options.enable_cpu_mem_arena = True

    available = ort.get_available_providers()
    print(f"Available ONNX Runtime providers: {available}")

    cuda_ready = 'CUDAExecutionProvider' in available
    if cuda_ready:
        print("Using CUDA provider for ONNX inference")
    else:
        print("CUDA provider not available, falling back to CPU")
    providers = (['CUDAExecutionProvider'] if cuda_ready else []) + ['CPUExecutionProvider']

    try:
        session = ort.InferenceSession(model_path, options, providers=providers)
        print(f"Model loaded with providers: {session.get_providers()}")
        return session
    except Exception as exc:
        print(f"Failed to create ONNX session: {exc}")
        raise
def _load_model_components_optimized(model_repo):
    """Make ``model_repo`` the model held in the module-level ``CURRENT_*`` globals.

    Returns immediately when that repository is already loaded. Otherwise it:
    - downloads the label CSV and the ONNX model;
    - creates an ONNX Runtime session;
    - parses the tag categories;
    - records the model's input side length.

    Every component is built in local variables first and published to the
    globals together at the end. A failure part-way through therefore leaves
    the previous model untouched. Otherwise a new session could end up paired
    with the old labels under the old ``CURRENT_MODEL_NAME``, and a later
    request for that repository would skip reloading and use the wrong model.
    """
    global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_TAG_NAMES
    global CURRENT_RATING_INDEXES, CURRENT_GENERAL_INDEXES, CURRENT_CHARACTER_INDEXES, CURRENT_MODEL_TARGET_SIZE
    # Only reload if the model changed
    if model_repo == CURRENT_MODEL_NAME and CURRENT_MODEL is not None:
        return
    csv_path, model_path = _download_model_files(model_repo)
    session = create_optimized_ort_session(model_path)
    tags_df = pd.read_csv(csv_path)
    tag_names, rating_indexes, general_indexes, character_indexes = load_labels(tags_df)
    # The input is 4-D (batch, height, width, channels), matching the arrays
    # built by Predictor.prepare_image. The height is the square side length
    # images get resized to.
    _, height, _, _ = session.get_inputs()[0].shape

    # Publish everything together
    CURRENT_MODEL = session
    CURRENT_TAGS_DF = tags_df
    CURRENT_TAG_NAMES = tag_names
    CURRENT_RATING_INDEXES = rating_indexes
    CURRENT_GENERAL_INDEXES = general_indexes
    CURRENT_CHARACTER_INDEXES = character_indexes
    CURRENT_MODEL_TARGET_SIZE = height
    CURRENT_MODEL_NAME = model_repo
def _raw_predict(image_array, model_session):
"""Run raw prediction using the model session"""
input_name = model_session.get_inputs()[0].name
label_name = model_session.get_outputs()[0].name
preds = model_session.run([label_name], {input_name: image_array})[0]
return preds[0].astype(float)
def unload_model():
    """Forget the currently loaded model and try to give its memory back.

    Resets every ``CURRENT_*`` global to ``None`` and runs the garbage
    collector. If PyTorch can be imported and reports CUDA as available, its
    CUDA cache is emptied as well. PyTorch is optional.
    """
    global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_TAG_NAMES
    global CURRENT_RATING_INDEXES, CURRENT_GENERAL_INDEXES, CURRENT_CHARACTER_INDEXES, CURRENT_MODEL_TARGET_SIZE
    if CURRENT_MODEL is not None:
        # Drop this module's reference to the session
        CURRENT_MODEL = None
    (CURRENT_TAGS_DF, CURRENT_TAG_NAMES, CURRENT_RATING_INDEXES, CURRENT_GENERAL_INDEXES,
     CURRENT_CHARACTER_INDEXES, CURRENT_MODEL_TARGET_SIZE, CURRENT_MODEL_NAME) = (None,) * 7

    import gc
    gc.collect()

    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except ImportError:
        pass  # PyTorch is not installed; nothing else to release
def cleanup_after_processing():
    """Release model resources once a batch has been processed.

    Currently this simply unloads the model via ``unload_model``.
    """
    return unload_model()
class Predictor:
    """Main predictor class for handling image tagging.

    The loaded model lives in the module-level ``CURRENT_*`` globals, which
    ``_load_model_components_optimized`` and ``unload_model`` manage. An
    instance only remembers which repository it requested last.
    """

    def __init__(self):
        # Kept for backward compatibility. Nothing assigns it; the model
        # components are held in the CURRENT_* globals.
        self.model_components = None
        self.last_loaded_repo = None

    def load_model(self, model_repo):
        """Make ``model_repo`` the currently loaded model.

        ``_load_model_components_optimized`` already returns early when the
        repository is loaded, so this always delegates to it. A local cache
        here could go stale after ``unload_model`` clears the globals.
        """
        _load_model_components_optimized(model_repo)
        self.last_loaded_repo = model_repo

    @staticmethod
    def _snapshot_model():
        """Return references to the currently loaded model components.

        Holding these lets ``predict`` run two models without reloading one of
        them for every image, even though the globals describe only one model
        at a time.
        """
        return {
            'session': CURRENT_MODEL,
            'tag_names': CURRENT_TAG_NAMES,
            'rating_indexes': CURRENT_RATING_INDEXES,
            'general_indexes': CURRENT_GENERAL_INDEXES,
            'character_indexes': CURRENT_CHARACTER_INDEXES,
            'target_size': CURRENT_MODEL_TARGET_SIZE,
        }

    def prepare_image(self, path, target_size=None):
        """Load the image at ``path`` and convert it into a model input batch.

        Processing steps:
        - flatten transparency onto white;
        - pad with white to a centred square;
        - resize (bicubic) to ``target_size`` x ``target_size``.

        ``target_size`` defaults to the input size of the currently loaded model.

        Returns:
            float32 array of shape (1, target_size, target_size, 3) with the
            channels in BGR order.
        """
        if target_size is None:
            target_size = CURRENT_MODEL_TARGET_SIZE
        # The context manager closes the file handle that Image.open keeps
        with Image.open(path) as source:
            image = source.convert('RGBA')
        # Composite onto an opaque white background
        canvas = Image.new('RGBA', image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert('RGB')
        # Pad to square
        width, height = image.size
        max_dim = max(width, height)
        pad_left = (max_dim - width) // 2
        pad_top = (max_dim - height) // 2
        padded_image = Image.new('RGB', (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))
        if max_dim != target_size:
            padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
        image_array = np.asarray(padded_image, dtype=np.float32)
        # PIL yields RGB; reversing the channel axis gives BGR
        image_array = image_array[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)

    def create_file(self, content: str, directory: str, fileName: str) -> str:
        """Write ``content`` as UTF-8 text to ``directory/fileName``, overwriting it.

        Returns:
            The path of the written file.
        """
        file_path = os.path.join(directory, fileName)
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(content)
        return file_path

    @staticmethod
    def _score_image(image_array, components, general_thresh, general_mcut_enabled,
                     character_thresh, character_mcut_enabled):
        """Run one model on one prepared image and keep the confident tags.

        Returns:
            ``(rating, general_res, character_res)``. ``rating`` maps every
            rating tag to its score. The other two map each tag whose score is
            above the threshold to its score. With MCut enabled the threshold
            is computed from the scores; for characters it is at least 0.15.
        """
        preds = _raw_predict(image_array, components['session'])
        labels = list(zip(components['tag_names'], preds))

        rating = dict(labels[i] for i in components['rating_indexes'])

        general_names = [labels[i] for i in components['general_indexes']]
        if general_mcut_enabled:
            general_thresh = mcut_threshold(np.array([score for _, score in general_names]))
        general_res = dict(item for item in general_names if item[1] > general_thresh)

        character_names = [labels[i] for i in components['character_indexes']]
        if character_mcut_enabled:
            character_thresh = max(0.15, mcut_threshold(np.array([score for _, score in character_names])))
        character_res = dict(item for item in character_names if item[1] > character_thresh)
        return rating, general_res, character_res

    @staticmethod
    def _format_result(result):
        """Turn one ``tag_results`` entry into the six values shown in the UI."""
        character_tags_formatted = ", ".join(
            name.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
            for name in result['character_res'].keys()
        )
        return (
            result['strings'],
            result['rating'],
            character_tags_formatted,
            result['general_res'],
            result.get('categorized_strings', ''),
            result.get('categorized_json', {}),
        )

    def predict(self, gallery, model_repo, model_repo_2, general_thresh, general_mcut_enabled,
                character_thresh, character_mcut_enabled, characters_merge_enabled,
                additional_tags_prepend, additional_tags_append, tag_results, progress=gr.Progress()):
        """Tag every image in ``gallery`` and package the results for download.

        Args:
            gallery: Gallery value; each item's first element is an image path.
                ``None`` is treated as an empty gallery.
            model_repo: Repository of the model that is always run.
            model_repo_2: Optional second repository. When set and different
                from ``model_repo``, its tags are merged after the first
                model's tags.
            general_thresh / character_thresh: Fixed score thresholds, used
                when the matching ``*_mcut_enabled`` flag is false.
            characters_merge_enabled: Put the character tags in front of the
                output string.
            additional_tags_prepend / additional_tags_append: Comma-separated
                tags forced to the start / end of the general tags.
            tag_results: Dict (Gradio state). It is cleared, then refilled with
                one entry per successfully processed image path.
            progress: Gradio progress tracker.

        Returns:
            ``(download, strings, rating, character_tags, general_res,
            categorized_strings, categorized_json, tag_results)``.
            ``download`` is a list holding the ZIP path (empty when nothing was
            produced). The middle six values describe the first image.
        """
        tag_results.clear()
        gallery = gallery or []
        gallery_len = len(gallery)
        print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")
        timer = Timer()
        progressRatio = 1
        progressTotal = gallery_len + 1
        current_progress = 0
        txt_infos = []
        output_dir = tempfile.mkdtemp()  # mkdtemp() creates the directory itself

        # Load each model once and keep references to its components, instead
        # of swapping the single global model back and forth for every image.
        self.load_model(model_repo)
        primary = self._snapshot_model()
        secondary = None
        if model_repo_2 and model_repo_2 != model_repo and gallery:
            self.load_model(model_repo_2)
            secondary = self._snapshot_model()
        current_progress += progressRatio / progressTotal
        progress(current_progress, desc='Initialize wd model finished')
        timer.checkpoint("Initialize wd model")
        timer.report()

        # The extra tags are the same for every image, so parse them once
        prepend_list = [tag.strip() for tag in (additional_tags_prepend or '').split(',') if tag.strip()]
        append_list = [tag.strip() for tag in (additional_tags_append or '').split(',') if tag.strip()]
        if prepend_list and append_list:
            # A tag requested in both places is only prepended
            append_list = [item for item in append_list if item not in prepend_list]
        forced_tags = set(prepend_list) | set(append_list)

        thresholds = (general_thresh, general_mcut_enabled, character_thresh, character_mcut_enabled)
        name_counters = defaultdict(int)
        for idx, value in enumerate(gallery):
            try:
                image_path = value[0]
                # Repeated file names get a numeric suffix: photo, photo_02, ...
                image_name = os.path.splitext(os.path.basename(image_path))[0]
                name_counters[image_name] += 1
                if name_counters[image_name] > 1:
                    image_name = f"{image_name}_{name_counters[image_name]:02d}"

                image = self.prepare_image(image_path, primary['target_size'])
                print(f"Gallery {idx:02d}: Starting run first model ({model_repo})...")
                rating, general_res, character_res = self._score_image(image, primary, *thresholds)
                character_list = list(character_res)
                general_list = sorted(general_res, key=general_res.get, reverse=True)

                if secondary is not None:
                    print(f"Gallery {idx:02d}: Starting run second model ({model_repo_2})...")
                    if secondary['target_size'] != primary['target_size']:
                        # The second model needs its own input size
                        image = self.prepare_image(image_path, secondary['target_size'])
                    _, general_res_2, character_res_2 = self._score_image(image, secondary, *thresholds)
                    # dict.fromkeys de-duplicates while keeping a deterministic
                    # order: first-model tags (by confidence), then new ones.
                    character_list = list(dict.fromkeys(character_list + list(character_res_2)))
                    general_list = list(dict.fromkeys(
                        general_list + sorted(general_res_2, key=general_res_2.get, reverse=True)))
                    # Keep the score dicts consistent with the merged tag lists.
                    # The first model's score wins for tags both models found.
                    for tag, score in character_res_2.items():
                        character_res.setdefault(tag, score)
                    for tag, score in general_res_2.items():
                        general_res.setdefault(tag, score)

                general_list = prepend_list + [tag for tag in general_list if tag not in forced_tags] + append_list

                # Format output string
                output_tags = (character_list if characters_merge_enabled else []) + general_list
                sorted_general_strings = ', '.join(output_tags).replace('(', '\\(').replace(')', '\\)').replace('_', ' ')
                # NOTE(review): sorted_general_strings is already escaped. If
                # categorize_tags_output returns tags verbatim, '\(' becomes '\\(' here. Confirm.
                categorized_strings = categorize_tags_output(sorted_general_strings, character_res).replace('(', '\\(').replace(')', '\\)')
                categorized_json = generate_tags_json(sorted_general_strings, character_res)

                # Collect this image's files and add them to the ZIP list only
                # once all of them were written successfully.
                image_files = []
                txt_content = f"Output (string): {sorted_general_strings}\n\nCategorized Output: {categorized_strings}"
                output_txt_file = self.create_file(txt_content, output_dir, f"{image_name}_output.txt")
                image_files.append({'path': output_txt_file, 'name': f"{image_name}_output.txt"})

                png_path = os.path.join(output_dir, f"{image_name}.png")
                with Image.open(image_path) as original_image:
                    original_image.save(png_path, format='PNG')
                image_files.append({'path': png_path, 'name': f"{image_name}.png"})

                txt_file = self.create_file(sorted_general_strings, output_dir, image_name + '.txt')
                categorized_file = self.create_file(categorized_strings, output_dir, f"{image_name}_categorized.txt")
                image_files.append({'path': categorized_file, 'name': f"{image_name}_categorized.txt"})
                image_files.append({'path': txt_file, 'name': image_name + '.txt'})

                json_content = json.dumps(categorized_json, indent=2, ensure_ascii=False)
                json_file = self.create_file(json_content, output_dir, f"{image_name}_categorized.json")
                image_files.append({'path': json_file, 'name': f"{image_name}_categorized.json"})
                txt_infos.extend(image_files)

                tag_results[image_path] = {
                    'strings': sorted_general_strings,
                    'categorized_strings': categorized_strings,
                    'categorized_json': categorized_json,
                    'rating': rating,
                    'character_res': character_res,
                    'general_res': general_res
                }

                current_progress += progressRatio / progressTotal
                progress(current_progress, desc=f"image{idx:02d}, predict finished")
                timer.checkpoint(f"image{idx:02d}, predict finished")
                timer.report()
            except Exception as e:
                print(traceback.format_exc())
                print('Error predict: ' + str(e))

        # Drop the local model references so unloading below can release them
        primary = secondary = None

        download = []
        if txt_infos:
            downloadZipPath = os.path.join(
                output_dir,
                'Multi-Tagger-' + datetime.now().strftime('%Y%m%d-%H%M%S') + '.zip'
            )
            with zipfile.ZipFile(downloadZipPath, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
                for info in txt_infos:
                    taggers_zip.write(info['path'], arcname=info['name'])
            download.append(downloadZipPath)
        # Unload the model after every run (GPU or CPU, even if no image succeeded)
        cleanup_after_processing()  # Comment here to turn off this behavior
        progress(1, desc="Predict completed")
        timer.report_all()
        print('Predict is complete.')

        # Show the first image's results by default
        first_image_results = ('', {}, '', {}, '', {})
        if gallery and gallery[0][0] in tag_results:
            first_image_results = self._format_result(tag_results[gallery[0][0]])
        return (download, *first_image_results, tag_results)
def _format_selection_result(result: dict) -> tuple:
    """Turn one stored ``tag_results`` entry into the six values shown in the UI."""
    character_tags_formatted = ", ".join(
        name.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
        for name in result['character_res'].keys()
    )
    return (
        result['strings'],
        result['rating'],
        character_tags_formatted,
        result['general_res'],
        result.get('categorized_strings', ''),
        result.get('categorized_json', {}),
    )


def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
    """Return the stored tagging results of the gallery image the user selected.

    The image is looked up in ``tag_results`` by the path carried in the select
    event. If that path is not a stored key, the event's index into
    ``gallery`` is tried. Without a selection, the first gallery image's
    results are returned when available.

    Returns:
        ``(strings, rating, character_tags, general_res, categorized_strings,
        categorized_json)``; empty values when nothing matches.
    """
    empty = ('', {}, '', {}, '', {})
    if not selected_state:
        if gallery and gallery[0][0] in tag_results:
            return _format_selection_result(tag_results[gallery[0][0]])
        return empty

    # The selected value's shape varies: dict with an 'image' entry, a sequence, or a plain path
    selected_value = selected_state.value
    if isinstance(selected_value, dict) and 'image' in selected_value:
        image_path = selected_value['image']['path']
    elif isinstance(selected_value, (list, tuple)) and len(selected_value) > 0:
        image_path = selected_value[0]
    else:
        image_path = str(selected_value)

    result = tag_results.get(image_path)
    if result is None:
        # Fall back to the selection index when the event path is not a known key
        index = getattr(selected_state, 'index', None)
        if isinstance(index, int) and gallery and 0 <= index < len(gallery):
            result = tag_results.get(gallery[index][0])
    if result is None:
        return empty
    return _format_selection_result(result)
def append_gallery(gallery: list, image: str):
    """Add one uploaded media file (image or video) to the gallery.

    Thin wrapper: the result of ``handle_single_media_upload`` is handed back
    to Gradio unchanged.
    """
    updated = handle_single_media_upload(image, gallery)
    return updated
def extend_gallery(gallery: list, images):
    """Add several uploaded media files (images or videos) to the gallery.

    Thin wrapper: the result of ``handle_multiple_media_uploads`` is handed
    back to Gradio unchanged.
    """
    updated = handle_multiple_media_uploads(images, gallery)
    return updated
# Command line options and the predictor shared by all requests
args = parse_args()
predictor = Predictor()
# Model choices offered in the dropdowns, in display order
dropdown_list = [
    EVA02_LARGE_MODEL_DSV3_REPO,
    VIT_LARGE_MODEL_DSV3_REPO,
    SWINV2_MODEL_DSV3_REPO,
    CONV_MODEL_DSV3_REPO,
    VIT_MODEL_DSV3_REPO,
    MOAT_MODEL_DSV2_REPO,
    SWIN_MODEL_DSV2_REPO,
    CONV_MODEL_DSV2_REPO,
    CONV2_MODEL_DSV2_REPO,
    VIT_MODEL_DSV2_REPO,
    EVA02_LARGE_MODEL_IS_DSV1_REPO,
    SWINV2_MODEL_IS_DSV1_REPO,
]
def _restart_space():
"""Restart the HuggingFace Space periodically for stability"""
HF_TOKEN = os.getenv('HF_TOKEN')
if not HF_TOKEN:
raise ValueError('HF_TOKEN environment variable is not set.')
huggingface_hub.HfApi().restart_space(
repo_id='Werli/Multi-Tagger',
token=HF_TOKEN,
factory_reboot=False
)
# Background job that restarts the Space every two days (2 * 24 * 60 * 60 s = 172800 s)
scheduler = BackgroundScheduler()
restart_space_job = scheduler.add_job(_restart_space, 'interval', seconds=2 * 24 * 60 * 60)
scheduler.start()
# Text for the footer, built from the job's scheduled time (read after start())
next_run_time_utc = restart_space_job.next_run_time.astimezone(timezone.utc)
NEXT_RESTART = (
    f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - "
    "The space will restart every 2 days to ensure stability and performance. "
    "It uses a background scheduler to handle the restart process."
)
with gr.Blocks(title=TITLE, css=css, theme="Werli/Purple-Crimson-Gradio-Theme", fill_width=True) as demo:
    # Builds the whole UI. Gradio places components in creation order inside
    # the enclosing Tab/Row/Column contexts, so statement order here is layout.
    # NOTE(review): this source had lost its indentation; the nesting below is
    # a reconstruction. Confirm it against the running Space.
    gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
    gr.Markdown(value=f"<p style='text-align: center;'>{DESCRIPTION}</p>")
    # --- Waifu Diffusion tagger tab ------------------------------------------
    with gr.Tab(label='Waifu Diffusion'):
        with gr.Row():
            # Left side: image inputs (first panel) and tagging settings (second panel)
            with gr.Column():
                with gr.Column(variant='panel'):
                    image_input = gr.Image(
                        label='Upload an Image (or paste from clipboard)',
                        type='filepath',
                        sources=['upload', 'clipboard'],
                        height=150
                    )
                    with gr.Row():
                        upload_button = gr.UploadButton(
                            'Upload multiple images or videos',
                            file_types=['image', 'video'],
                            file_count='multiple',
                            size='md'
                        )
                    # Images to tag; styled by the #custom-gallery rules in `css`
                    gallery = gr.Gallery(
                        columns=2,
                        show_share_button=False,
                        interactive=True,
                        height='auto',
                        label='Grid of images',
                        preview=False,
                        elem_id='custom-gallery'
                    )
                    submit = gr.Button(value='Analyze Images', variant='primary', size='lg')
                    # Clears only the gallery; `clear` is rebound to the "Clear Everything" button below
                    clear = gr.ClearButton(components=[gallery], value='Clear Gallery', variant='secondary', size='sm')
                with gr.Column(variant='panel'):
                    model_repo = gr.Dropdown(
                        dropdown_list,
                        value=EVA02_LARGE_MODEL_DSV3_REPO,
                        label='1st Model'
                    )
                    # NOTE(review): '+?' may be a mangled symbol; confirm the intended label text
                    PLUS = '+?'
                    gr.Markdown(value=f"<p style='text-align: center;'>{PLUS}</p>")
                    # Optional second model; None (the default) means only the first model runs
                    model_repo_2 = gr.Dropdown(
                        [None] + dropdown_list,
                        value=None,
                        label='2nd Model (Optional)',
                        info='Select another model for diversified results.'
                    )
                    # Threshold sliders start from the command line defaults
                    with gr.Row():
                        general_thresh = gr.Slider(
                            0, 1,
                            step=args.score_slider_step,
                            value=args.score_general_threshold,
                            label='General Tags Threshold',
                            scale=3
                        )
                        general_mcut_enabled = gr.Checkbox(
                            value=False,
                            label='Use MCut threshold',
                            scale=1
                        )
                    with gr.Row():
                        character_thresh = gr.Slider(
                            0, 1,
                            step=args.score_slider_step,
                            value=args.score_character_threshold,
                            label='Character Tags Threshold',
                            scale=3
                        )
                        character_mcut_enabled = gr.Checkbox(
                            value=False,
                            label='Use MCut threshold',
                            scale=1
                        )
                    with gr.Row():
                        characters_merge_enabled = gr.Checkbox(
                            value=False,
                            label='Merge characters into the string output',
                            scale=1
                        )
                    with gr.Row():
                        additional_tags_prepend = gr.Text(
                            label='Prepend Additional tags (comma split)'
                        )
                        additional_tags_append = gr.Text(
                            label='Append Additional tags (comma split)'
                        )
                    with gr.Row():
                        # Resets the gallery and all settings above except the 2nd model dropdown
                        clear = gr.ClearButton(
                            components=[
                                gallery, model_repo, general_thresh, general_mcut_enabled,
                                character_thresh, character_mcut_enabled, characters_merge_enabled,
                                additional_tags_prepend, additional_tags_append
                            ],
                            value='Clear Everything',
                            variant='secondary',
                            size='lg'
                        )
            # Right side: download link and the results of the selected image
            with gr.Column(variant='panel'):
                download_file = gr.File(label='Download')
                character_res = gr.Textbox(
                    label="Character tags",
                    show_copy_button=True,
                    lines=3
                )
                sorted_general_strings = gr.Textbox(
                    label='Output',
                    show_label=True,
                    show_copy_button=True,
                    lines=5
                )
                categorized_strings = gr.Textbox(
                    label='Categorized',
                    show_label=True,
                    show_copy_button=True,
                    lines=5
                )
                tags_json = gr.JSON(
                    label='Categorized Tags (JSON)',
                    visible=True
                )
                rating = gr.Label(label='Rating')
                general_res = gr.Textbox(
                    label="General tags",
                    show_copy_button=True,
                    lines=3,
                    visible=False # Temp: hidden for now
                )
        # Per-image results of the last run, keyed by image path (filled by predictor.predict)
        tag_results = gr.State({})
        # Event handlers
        image_input.change(
            append_gallery,
            inputs=[gallery, image_input],
            outputs=[gallery, image_input]
        )
        upload_button.upload(
            extend_gallery,
            inputs=[gallery, upload_button],
            outputs=gallery
        )
        # Clicking a gallery image shows that image's stored results
        gallery.select(
            get_selection_from_gallery,
            inputs=[gallery, tag_results],
            outputs=[sorted_general_strings, rating, character_res, general_res, categorized_strings, tags_json]
        )
        # Output order must match the tuple returned by Predictor.predict
        submit.click(
            predictor.predict,
            inputs=[
                gallery, model_repo, model_repo_2, general_thresh, general_mcut_enabled,
                character_thresh, character_mcut_enabled, characters_merge_enabled,
                additional_tags_prepend, additional_tags_append, tag_results
            ],
            outputs=[download_file, sorted_general_strings, rating, character_res, general_res, categorized_strings, tags_json, tag_results]
        )
        gr.Markdown('[Based on SmilingWolf/wd-tagger](https://huggingface.co/spaces/SmilingWolf/wd-tagger) <p style="text-align:right"><a href="https://huggingface.co/spaces/John6666/danbooru-tags-transformer-v2-with-wd-tagger-b">Prompt Enhancer</a></p>')
    # --- Tabs built by the project's modules package --------------------------
    with gr.Tab("PixAI"):
        pixai_interface = create_pixai_interface()
    with gr.Tab("Booru Image Fetcher"):
        booru_interface = create_booru_interface()
    with gr.Tab("ComfyUI Extractor"):
        comfy_interface = create_multi_comfy()
    # --- Misc: clean and categorize a pasted tag list --------------------------
    with gr.Tab(label="Misc"):
        with gr.Row():
            with gr.Column(variant="panel"):
                tag_string = gr.Textbox(
                    label="Input Tags",
                    placeholder="1girl, cat, horns, blue hair, ...\nor\n? 1girl 1234567? cat 1234567? horns 1234567? blue hair 1234567? ...",
                    lines=4
                )
                submit_button = gr.Button(value="START", variant="primary", size="lg")
            with gr.Column(variant="panel"):
                cleaned_tags_output = gr.Textbox(
                    label="Cleaned Tags",
                    show_label=True,
                    show_copy_button=True,
                    lines=4,
                    info="Tags with ? and numbers removed, formatted with commas. Useful for clearing tags from Booru sites."
                )
                classify_tags_for_display = gr.Textbox(
                    label="Categorized (string)",
                    show_label=True,
                    show_copy_button=True,
                    lines=8,
                    info="Tags organized by categories"
                )
                generate_categorized_json = gr.JSON(
                    label="Categorized JSON (tags)"
                )
        # process_tags_for_misc fills the three output components in this order
        submit_button.click(
            process_tags_for_misc,
            inputs=[tag_string],
            outputs=[cleaned_tags_output, classify_tags_for_display, generate_categorized_json]
        )
    # Footer with the next scheduled Space restart
    gr.Markdown(NEXT_RESTART)
# Queue at most 5 pending requests; show errors in the UI; hide the API page
demo.queue(max_size=5).launch(show_error=True, show_api=False)