Spaces:
Running
Running
| import os, io, json, requests, spaces, argparse, traceback, tempfile, zipfile, re, ast, time | |
| import gradio as gr | |
| import numpy as np | |
| import huggingface_hub | |
| import onnxruntime as ort | |
| import pandas as pd | |
| from datetime import datetime, timezone | |
| from collections import defaultdict | |
| from PIL import Image, ImageOps | |
| from apscheduler.schedulers.background import BackgroundScheduler | |
| from modules.classifyTags import categorize_tags_output, generate_tags_json, process_tags_for_misc | |
| from modules.pixai import create_pixai_interface | |
| from modules.booru import create_booru_interface | |
| from modules.multi_comfy import create_multi_comfy | |
| from modules.media_handler import handle_single_media_upload, handle_multiple_media_uploads | |
# For GPU, install everything in requirements.txt plus:
#   pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu126
#   (or any other Torch version)
#   pip install onnxruntime-gpu
#
# It's recommended to create a venv if you want to use it offline:
#   python -m venv venv
#   venv\Scripts\activate        (Windows; use `source venv/bin/activate` on Linux/macOS)
#   pip install ...
#   python app.py
#
# (Kept as comments: inside a normal string literal "\S" is an invalid escape
# sequence and "\a" turns into a BEL control character.)
# Text shown in the page header.
TITLE = "Multi-Tagger v1.4"
DESCRIPTION = "\nMulti-Tagger is a versatile application for advanced image analysis and captioning. Supports <b>CUDA</b> and <b>CPU</b>.\n"

# Hugging Face Hub repositories of the selectable tagger models.
# SmilingWolf WD taggers, "-tagger-v3" series
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
# SmilingWolf WD v1.4 taggers, "-tagger-v2" series
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
# deepghs IdolSankaku taggers, "-tagger-v1" series
EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"
# Module-level cache of the single active model. It is filled by
# _load_model_components_optimized and reset to None by unload_model.
CURRENT_MODEL = None               # onnxruntime InferenceSession
CURRENT_MODEL_NAME = None          # Hub repo id of CURRENT_MODEL
CURRENT_TAGS_DF = None             # DataFrame read from selected_tags.csv
CURRENT_TAG_NAMES = None           # tag names, zipped position-wise with the model outputs
CURRENT_RATING_INDEXES = None      # output positions whose CSV category is 9
CURRENT_GENERAL_INDEXES = None     # output positions whose CSV category is 0
CURRENT_CHARACTER_INDEXES = None   # output positions whose CSV category is 4
CURRENT_MODEL_TARGET_SIZE = None   # edge length (pixels) images are resized to

# Custom CSS for gallery styling.
# The `#custom-gallery .thumbnail-item img` rule used to be declared twice with
# conflicting values. The later copy overrode every property of the earlier one,
# so only that effective rule is kept. `width: auto` stays in front of
# `width: fit-content` as the fallback for browsers without fit-content.
css = """
#custom-gallery {--row-height: 180px;display: grid;grid-auto-rows: min-content;gap: 10px;}
#custom-gallery .thumbnail-item {height: var(--row-height);width: 100%;position: relative;overflow: hidden;border-radius: 8px;box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);transition: transform 0.2s ease, box-shadow 0.2s ease;}
#custom-gallery .thumbnail-item:hover {transform: translateY(-3px);box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);}
#custom-gallery .thumbnail-item img {width: auto;width: fit-content;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: initial;margin: 0px auto;display: block;}
#custom-gallery .thumbnail-item img.portrait {max-width: 100%;}
#custom-gallery .thumbnail-item img.landscape {max-height: 100%;}
.gallery-container {max-height: 500px;overflow-y: auto;padding-right: 0px;--size-80: 500px;}
.thumbnails {display: flex;position: absolute;bottom: 0;width: 120px;overflow-x: scroll;padding-top: 320px;padding-bottom: 280px;padding-left: 4px;flex-wrap: wrap;}
"""

# File names expected in every tagger repository on the Hub.
MODEL_FILENAME = 'model.onnx'
LABEL_FILENAME = 'selected_tags.csv'
class Timer:
    """Measure and print wall-clock durations between labelled checkpoints.

    Timestamps come from ``time.perf_counter``. ``checkpoints`` is a list of
    ``(label, timestamp)`` pairs. Every printed duration is the gap between a
    checkpoint and the checkpoint recorded just before it.
    """

    def __init__(self):
        self.start_time = time.perf_counter()
        self.checkpoints = [('Start', self.start_time)]

    def checkpoint(self, label='Checkpoint'):
        """Record the current time under ``label``."""
        now = time.perf_counter()
        self.checkpoints.append((label, now))

    def _print_intervals(self, prev_time, width):
        """Print every checkpoint after the first with the time since its predecessor."""
        for label, curr_time in self.checkpoints[1:]:
            print(f"{label.ljust(width)}: {curr_time - prev_time:.3f} seconds")
            prev_time = curr_time

    def report(self, is_clear_checkpoints=True):
        """Print the intervals between the recorded checkpoints.

        Args:
            is_clear_checkpoints: when true (the default), the printed checkpoints
                are discarded and a fresh checkpoint is recorded, so the next
                report starts from now.
        """
        width = max((len(label) for label, _ in self.checkpoints), default=0)
        prev_time = self.checkpoints[0][1] if self.checkpoints else self.start_time
        self._print_intervals(prev_time, width)
        if is_clear_checkpoints:
            self.checkpoints.clear()
            self.checkpoint()

    def report_all(self):
        """Print the remaining intervals plus the total time since ``start_time``.

        The total runs from ``start_time`` to the last recorded checkpoint.
        Afterwards all checkpoints are discarded.
        """
        print('\n> Execution Time Report:')
        width = max((len(label) for label, _ in self.checkpoints), default=0)
        # Start from the oldest retained checkpoint, as report() does. report()
        # may have dropped earlier checkpoints, and measuring from start_time
        # would turn the first interval into a cumulative time.
        prev_time = self.checkpoints[0][1] if self.checkpoints else self.start_time
        self._print_intervals(prev_time, width)
        total_time = self.checkpoints[-1][1] - self.start_time if self.checkpoints else 0
        print(f"{'Total Execution Time'.ljust(width)}: {total_time:.3f} seconds\n")
        self.checkpoints.clear()

    def restart(self):
        """Reset the start time and drop every recorded checkpoint."""
        self.start_time = time.perf_counter()
        self.checkpoints = [('Start', self.start_time)]
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the application's command-line options.

    Args:
        argv: list of arguments to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the zero-argument call did.

    Returns:
        Namespace with ``score_slider_step``, ``score_general_threshold``,
        ``score_character_threshold`` and ``share``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--score-slider-step', type=float, default=0.05,
                        help='Step size of the threshold sliders.')
    parser.add_argument('--score-general-threshold', type=float, default=0.35,
                        help='Initial general-tag threshold.')
    parser.add_argument('--score-character-threshold', type=float, default=0.85,
                        help='Initial character-tag threshold.')
    parser.add_argument('--share', action='store_true',
                        help='Create a public Gradio share link.')
    return parser.parse_args(argv)
def load_labels(dataframe) -> tuple:
    """Split a tagger's ``selected_tags.csv`` table into names and category positions.

    Args:
        dataframe: table with at least ``name`` and ``category`` columns.

    Returns:
        ``(tag_names, rating_indexes, general_indexes, character_indexes)``. The
        index lists hold the row positions whose ``category`` is 9, 0 and 4
        respectively.
    """
    categories = dataframe['category']

    def rows_in(category_id):
        return list(np.where(categories == category_id)[0])

    tag_names = dataframe['name'].tolist()
    return tag_names, rows_in(9), rows_in(0), rows_in(4)
def mcut_threshold(probs):
    """Return the MCut (maximum cut) threshold for a set of scores.

    The scores are sorted in descending order and the threshold is placed in
    the middle of the widest gap between two consecutive scores (Largeron et
    al., 2012).

    Args:
        probs: 1-D array-like of scores.

    Returns:
        The midpoint of the widest gap. With fewer than two scores there is no
        gap to cut, so ``0.0`` is returned; callers keep scores strictly above it.
    """
    sorted_probs = np.sort(np.asarray(probs, dtype=float).ravel())[::-1]
    if sorted_probs.size < 2:
        return 0.0
    gaps = sorted_probs[:-1] - sorted_probs[1:]
    t = int(gaps.argmax())
    return (sorted_probs[t] + sorted_probs[t + 1]) / 2
def _download_model_files(model_repo):
    """Fetch a tagger's label CSV and ONNX weights from the Hugging Face Hub.

    Args:
        model_repo: Hub repo id of the tagger.

    Returns:
        ``(csv_path, model_path)``, the local paths returned by ``hf_hub_download``.
        The CSV is fetched first.
    """
    fetch = huggingface_hub.hf_hub_download
    csv_path = fetch(model_repo, LABEL_FILENAME)
    model_path = fetch(model_repo, MODEL_FILENAME)
    return csv_path, model_path
def create_optimized_ort_session(model_path):
    """Build an ONNX Runtime ``InferenceSession`` for the model at ``model_path``.

    All graph optimizations are enabled, and the parallel execution mode and
    memory pattern/arena options are switched on. ``CUDAExecutionProvider`` is
    requested first when the installed onnxruntime build reports it.
    ``CPUExecutionProvider`` is always appended as the fallback. A failure while
    creating the session is printed and then re-raised.
    """
    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    options.intra_op_num_threads = 0  # 0 leaves the thread count to ONNX Runtime's default
    options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
    options.enable_mem_pattern = True
    options.enable_cpu_mem_arena = True

    available = ort.get_available_providers()
    print(f"Available ONNX Runtime providers: {available}")

    has_cuda = 'CUDAExecutionProvider' in available
    if has_cuda:
        print("Using CUDA provider for ONNX inference")
    else:
        print("CUDA provider not available, falling back to CPU")
    providers = (['CUDAExecutionProvider'] if has_cuda else []) + ['CPUExecutionProvider']

    try:
        session = ort.InferenceSession(model_path, options, providers=providers)
        print(f"Model loaded with providers: {session.get_providers()}")
        return session
    except Exception as exc:
        print(f"Failed to create ONNX session: {exc}")
        raise
def _load_model_components_optimized(model_repo):
    """Make ``model_repo`` the active model stored in the CURRENT_* globals.

    Returns immediately when that repo is already loaded. Otherwise the files
    are fetched with _download_model_files, and the ONNX session and tag table
    are built in local variables. The globals are only updated after every step
    has succeeded. A failure part-way (download, session creation, CSV parsing)
    therefore leaves the previously active model and its tag metadata
    consistent, instead of pairing a new session with old tag names.

    Args:
        model_repo: Hugging Face Hub repo id of the tagger.
    """
    global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_TAG_NAMES
    global CURRENT_RATING_INDEXES, CURRENT_GENERAL_INDEXES, CURRENT_CHARACTER_INDEXES, CURRENT_MODEL_TARGET_SIZE
    if model_repo == CURRENT_MODEL_NAME and CURRENT_MODEL is not None:
        return

    csv_path, model_path = _download_model_files(model_repo)
    session = create_optimized_ort_session(model_path)
    tags_df = pd.read_csv(csv_path)
    tag_names, rating_indexes, general_indexes, character_indexes = load_labels(tags_df)
    # The first input is (batch, height, width, channels), the layout prepare_image
    # produces. The height is used as the square target size
    # (assumes width == height; TODO confirm for newly added models).
    _, height, _width, _ = session.get_inputs()[0].shape

    # Commit everything together only after all steps have succeeded.
    CURRENT_MODEL = session
    CURRENT_TAGS_DF = tags_df
    CURRENT_TAG_NAMES = tag_names
    CURRENT_RATING_INDEXES = rating_indexes
    CURRENT_GENERAL_INDEXES = general_indexes
    CURRENT_CHARACTER_INDEXES = character_indexes
    CURRENT_MODEL_TARGET_SIZE = height
    CURRENT_MODEL_NAME = model_repo
| def _raw_predict(image_array, model_session): | |
| """Run raw prediction using the model session""" | |
| input_name = model_session.get_inputs()[0].name | |
| label_name = model_session.get_outputs()[0].name | |
| preds = model_session.run([label_name], {input_name: image_array})[0] | |
| return preds[0].astype(float) | |
def unload_model():
    """Drop every reference to the active model so its memory can be reclaimed.

    Resets all CURRENT_* globals to None and runs the garbage collector. If
    PyTorch is importable and reports CUDA, it also empties PyTorch's CUDA
    cache. NOTE(review): that call only affects PyTorch's own caching
    allocator. ONNX Runtime's memory is released when the session object itself
    is collected.
    """
    global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_TAG_NAMES
    global CURRENT_RATING_INDEXES, CURRENT_GENERAL_INDEXES, CURRENT_CHARACTER_INDEXES, CURRENT_MODEL_TARGET_SIZE
    # Release the session first, then the tag metadata that belongs to it.
    CURRENT_MODEL = None
    CURRENT_TAGS_DF = CURRENT_TAG_NAMES = None
    CURRENT_RATING_INDEXES = CURRENT_GENERAL_INDEXES = CURRENT_CHARACTER_INDEXES = None
    CURRENT_MODEL_TARGET_SIZE = None
    CURRENT_MODEL_NAME = None

    import gc
    gc.collect()

    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except ImportError:
        pass  # PyTorch is optional for this app
def cleanup_after_processing():
    """Free model resources once a batch has been processed (delegates to unload_model)."""
    return unload_model()
class Predictor:
    """Tag a gallery of images with one or two WD-style ONNX taggers and package the results.

    The active model lives in the module-level CURRENT_* globals. They are
    filled by _load_model_components_optimized and cleared by unload_model, and
    only one model is held at a time. predict() therefore runs each selected
    model over the whole gallery before switching to the next one, instead of
    rebuilding both ONNX sessions for every image.
    """

    def __init__(self):
        # model_components is never assigned in this class, so load_model's early
        # return never fires. The effective "already loaded" check is the one in
        # _load_model_components_optimized.
        self.model_components = None
        self.last_loaded_repo = None

    def load_model(self, model_repo):
        """Make ``model_repo`` (a Hugging Face Hub repo id) the active model."""
        if model_repo == self.last_loaded_repo and self.model_components is not None:
            return
        _load_model_components_optimized(model_repo)
        self.last_loaded_repo = model_repo

    def prepare_image(self, path):
        """Load the image at ``path`` as an input batch for the active model.

        Transparency is flattened onto white and the image is centred on a white
        square canvas. The canvas is resized to CURRENT_MODEL_TARGET_SIZE when
        needed. The result is a float32 array of shape (1, size, size, 3) with
        channels in BGR order.
        """
        target_size = CURRENT_MODEL_TARGET_SIZE
        # The context manager closes the file handle. convert() returns a fully loaded copy.
        with Image.open(path) as source:
            image = source.convert('RGBA')
        canvas = Image.new('RGBA', image.size, (255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert('RGB')
        # Pad to a square so resizing keeps the aspect ratio.
        width, height = image.size
        max_dim = max(width, height)
        pad_left = (max_dim - width) // 2
        pad_top = (max_dim - height) // 2
        padded_image = Image.new('RGB', (max_dim, max_dim), (255, 255, 255))
        padded_image.paste(image, (pad_left, pad_top))
        if max_dim != target_size:
            padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
        image_array = np.asarray(padded_image, dtype=np.float32)
        # PIL gives RGB; reversing the channel axis produces BGR for the model.
        image_array = image_array[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)

    def create_file(self, content: str, directory: str, fileName: str) -> str:
        """Write ``content`` as UTF-8 text to ``directory/fileName`` and return that path."""
        file_path = os.path.join(directory, fileName)
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(content)
        return file_path

    @staticmethod
    def _split_tags(text):
        """Split a comma-separated string into stripped, non-empty tags (None -> [])."""
        return [tag.strip() for tag in (text or '').split(',') if tag.strip()]

    @staticmethod
    def _select_tags(scored_tags, threshold, use_mcut, mcut_floor=None):
        """Keep the ``(name, score)`` pairs whose score is strictly above the threshold.

        With ``use_mcut`` the threshold comes from mcut_threshold and is raised
        to at least ``mcut_floor`` when one is given. MCut needs two scores to
        find a gap, so with fewer the given ``threshold`` is used instead.
        """
        if use_mcut and len(scored_tags) >= 2:
            threshold = mcut_threshold(np.array([score for _, score in scored_tags]))
            if mcut_floor is not None:
                threshold = max(mcut_floor, threshold)
        return {name: score for name, score in scored_tags if score > threshold}

    @staticmethod
    def _merge_scores(primary, secondary):
        """Union two tag->score dicts, keeping the higher score for shared tags.

        Tags from ``primary`` keep their order; tags only in ``secondary`` follow.
        """
        merged = dict(primary)
        for name, score in secondary.items():
            merged[name] = max(score, merged.get(name, score))
        return merged

    @staticmethod
    def _format_character_tags(character_res):
        """Join character tag names for display: escape parentheses, '_' -> ' '."""
        return ", ".join(name.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
                         for name in character_res)

    def _score_image(self, image_array, general_thresh, general_mcut_enabled,
                     character_thresh, character_mcut_enabled):
        """Run the active model on one prepared image.

        Returns:
            ``(rating, general_res, character_res)``, dicts mapping tag name to
            score. ``rating`` holds every rating tag; the other two only hold
            tags above their threshold. A floor of 0.15 applies to the
            character MCut threshold.
        """
        preds = _raw_predict(image_array, CURRENT_MODEL)
        labels = list(zip(CURRENT_TAG_NAMES, preds))
        rating = dict(labels[i] for i in CURRENT_RATING_INDEXES)
        general_res = self._select_tags(
            [labels[i] for i in CURRENT_GENERAL_INDEXES], general_thresh, general_mcut_enabled)
        character_res = self._select_tags(
            [labels[i] for i in CURRENT_CHARACTER_INDEXES], character_thresh,
            character_mcut_enabled, mcut_floor=0.15)
        return rating, general_res, character_res

    def _compose_tags(self, scores, characters_merge_enabled, prepend_list, append_list):
        """Combine one image's per-model scores into its final tag string.

        Args:
            scores: list of ``(rating, general_res, character_res)`` tuples, with
                the first model's tuple first.
            characters_merge_enabled: put character tags in front of the string.
            prepend_list / append_list: user tags placed at the start / end. They
                are removed from the detected tags to avoid duplicates.

        Returns:
            ``(tag_string, rating, general_res, character_res)``. ``rating`` comes
            from the first model. ``general_res`` and ``character_res`` merge all
            models, keeping the higher score, and general tags are ordered by
            confidence.
        """
        rating, general_res, character_res = scores[0]
        for _, more_general, more_character in scores[1:]:
            general_res = self._merge_scores(general_res, more_general)
            character_res = self._merge_scores(character_res, more_character)
        ranked_general = [name for name, _ in
                          sorted(general_res.items(), key=lambda item: item[1], reverse=True)]
        pinned = set(prepend_list) | set(append_list)
        general_list = prepend_list + [tag for tag in ranked_general if tag not in pinned] + append_list
        character_list = list(character_res) if characters_merge_enabled else []
        tag_string = ', '.join(character_list + general_list).replace('(', '\\(').replace(')', '\\)').replace('_', ' ')
        return tag_string, rating, general_res, character_res

    def _write_image_files(self, output_dir, image_path, image_name, tag_string,
                           categorized_strings, categorized_json):
        """Write one image's result files into ``output_dir``.

        Returns:
            A list of ``{'path', 'name'}`` archive entries in the order: summary
            .txt, PNG copy, categorized .txt, tags .txt, categorized .json.
        """
        entries = []

        def add_text(content, file_name):
            entries.append({'path': self.create_file(content, output_dir, file_name), 'name': file_name})

        add_text(f"Output (string): {tag_string}\n\nCategorized Output: {categorized_strings}",
                 f"{image_name}_output.txt")
        png_name = f"{image_name}.png"
        png_path = os.path.join(output_dir, png_name)
        with Image.open(image_path) as original:
            original.save(png_path, format='PNG')
        entries.append({'path': png_path, 'name': png_name})
        add_text(categorized_strings, f"{image_name}_categorized.txt")
        add_text(tag_string, image_name + '.txt')
        add_text(json.dumps(categorized_json, indent=2, ensure_ascii=False), f"{image_name}_categorized.json")
        return entries

    def predict(self, gallery, model_repo, model_repo_2, general_thresh, general_mcut_enabled,
                character_thresh, character_mcut_enabled, characters_merge_enabled,
                additional_tags_prepend, additional_tags_append, tag_results, progress=gr.Progress()):
        """Tag every gallery image and build the downloadable results archive.

        Args:
            gallery: gallery items whose first element is the image file path
                (None or empty means no images).
            model_repo: Hub repo id of the first model (required).
            model_repo_2: optional second model; ignored when it equals the first.
            general_thresh / character_thresh: score thresholds used unless MCut is on.
            general_mcut_enabled / character_mcut_enabled: use mcut_threshold instead.
            characters_merge_enabled: put character tags in front of the output string.
            additional_tags_prepend / additional_tags_append: comma-separated extra tags.
            tag_results: dict (Gradio State) that is cleared and refilled with
                per-image results, keyed by image path.
            progress: Gradio progress tracker.

        Returns:
            ``(download paths, output string, rating, character tags string,
            general_res, categorized string, categorized JSON, tag_results)``. The
            middle six values belong to the first gallery image (empty values if
            it has no results).

        Raises:
            gr.Error: when no first model is selected.
        """
        if not model_repo:
            # The dropdown can be empty (presumably after "Clear Everything").
            raise gr.Error('Please select the 1st model before analyzing images.')
        gallery = gallery or []
        tag_results.clear()
        gallery_len = len(gallery)
        print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")
        timer = Timer()
        use_second_model = bool(model_repo_2) and model_repo_2 != model_repo
        thresholds = (general_thresh, general_mcut_enabled, character_thresh, character_mcut_enabled)
        # One progress step per model load, plus one per image per model.
        progress_step = 1 / ((gallery_len + 1) * (2 if use_second_model else 1))
        current_progress = 0

        def advance(desc):
            nonlocal current_progress
            current_progress += progress_step
            progress(current_progress, desc=desc)

        output_dir = tempfile.mkdtemp()  # mkdtemp creates the directory itself

        # Pass 1: run the first model over every image.
        self.load_model(model_repo)
        advance('Initialize wd model finished')
        timer.checkpoint("Initialize wd model")
        timer.report()
        name_counters = defaultdict(int)
        records = {}  # gallery index -> {'path', 'name', 'scores'}; failed images are left out
        for idx, value in enumerate(gallery):
            try:
                image_path = value[0]
                # Repeated base names get a _02, _03, ... suffix so output files don't collide.
                image_name = os.path.splitext(os.path.basename(image_path))[0]
                name_counters[image_name] += 1
                if name_counters[image_name] > 1:
                    image_name = f"{image_name}_{name_counters[image_name]:02d}"
                print(f"Gallery {idx:02d}: Starting run first model ({model_repo})...")
                scores = self._score_image(self.prepare_image(image_path), *thresholds)
                records[idx] = {'path': image_path, 'name': image_name, 'scores': [scores]}
            except Exception as e:
                print(traceback.format_exc())
                print('Error predict: ' + str(e))
            advance(f"image{idx:02d}, predict finished")
            timer.checkpoint(f"image{idx:02d}, predict finished")
            timer.report()

        # Pass 2 (optional): run the second model over the images that succeeded.
        if use_second_model and records:
            try:
                self.load_model(model_repo_2)
            except Exception as e:
                # A second-model failure leaves the images without results, like per-image errors.
                print(traceback.format_exc())
                print('Error predict: ' + str(e))
                records.clear()
            else:
                advance('Initialize second wd model finished')
                timer.checkpoint("Initialize second wd model")
                timer.report()
                for idx in list(records):
                    record = records[idx]
                    try:
                        print(f"Gallery {idx:02d}: Starting run second model ({model_repo_2})...")
                        # Prepared again: the second model may expect a different input size.
                        image_array = self.prepare_image(record['path'])
                        record['scores'].append(self._score_image(image_array, *thresholds))
                    except Exception as e:
                        print(traceback.format_exc())
                        print('Error predict: ' + str(e))
                        del records[idx]
                    advance(f"image{idx:02d}, second model finished")
                    timer.checkpoint(f"image{idx:02d}, second model finished")
                    timer.report()

        # Pass 3: merge the per-model scores, write the files and store the results.
        prepend_list = self._split_tags(additional_tags_prepend)
        append_list = [tag for tag in self._split_tags(additional_tags_append) if tag not in prepend_list]
        txt_infos = []
        for record in records.values():
            image_path, image_name = record['path'], record['name']
            try:
                tag_string, rating, general_res, character_res = self._compose_tags(
                    record['scores'], characters_merge_enabled, prepend_list, append_list)
                categorized_strings = categorize_tags_output(tag_string, character_res).replace('(', '\\(').replace(')', '\\)')
                categorized_json = generate_tags_json(tag_string, character_res)
                txt_infos.extend(self._write_image_files(
                    output_dir, image_path, image_name, tag_string, categorized_strings, categorized_json))
                tag_results[image_path] = {
                    'strings': tag_string,
                    'categorized_strings': categorized_strings,
                    'categorized_json': categorized_json,
                    'rating': rating,
                    'character_res': character_res,
                    'general_res': general_res,
                }
            except Exception as e:
                print(traceback.format_exc())
                print('Error predict: ' + str(e))
        timer.checkpoint("Write output files")
        timer.report()

        # Bundle every written file into a single zip for download.
        download = []
        if txt_infos:
            download_zip_path = os.path.join(
                output_dir,
                'Multi-Tagger-' + datetime.now().strftime('%Y%m%d-%H%M%S') + '.zip'
            )
            with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as taggers_zip:
                for info in txt_infos:
                    taggers_zip.write(info['path'], arcname=info['name'])
            download.append(download_zip_path)
        # Unload the model after every run, even when no image succeeded.
        # Comment this call out to keep the model loaded between runs.
        cleanup_after_processing()
        progress(1, desc="Predict completed")
        timer.report_all()
        print('Predict is complete.')

        # Show the first image's results by default. The character slot is a Textbox, hence ''.
        first_image_results = ('', {}, '', {}, '', {})
        if gallery:
            first_result = tag_results.get(gallery[0][0])
            if first_result is not None:
                first_image_results = (
                    first_result['strings'],
                    first_result['rating'],
                    self._format_character_tags(first_result['character_res']),
                    first_result['general_res'],
                    first_result.get('categorized_strings', ''),
                    first_result.get('categorized_json', {}),
                )
        return (download, *first_image_results, tag_results)
def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
    """Return the stored tagging results for the gallery image the user selected.

    With no selection, the first gallery image's results are returned if they
    exist. A selection is resolved through the path carried in the event value.
    If that path is not a key of ``tag_results``, the event's ``index`` into
    ``gallery`` is tried as well.

    Returns:
        ``(tag string, rating dict, character tags string, general_res,
        categorized string, categorized JSON)``, or empty values when nothing is
        stored for the image.
    """
    empty = ('', {}, '', {}, '', {})
    tag_results = tag_results or {}

    def as_outputs(result):
        character_tags = ", ".join(
            name.replace("(", "\\(").replace(")", "\\)").replace("_", " ")
            for name in result['character_res'].keys()
        )
        return (
            result['strings'],
            result['rating'],
            character_tags,
            result['general_res'],
            result.get('categorized_strings', ''),
            result.get('categorized_json', {}),
        )

    if not selected_state:
        if gallery:
            first_result = tag_results.get(gallery[0][0])
            if first_result is not None:
                return as_outputs(first_result)
        return empty

    # The event value shape varies: {'image': {'path': ...}}, (path, caption), or a plain path.
    selected_value = selected_state.value
    if isinstance(selected_value, dict) and 'image' in selected_value:
        image_info = selected_value['image']
        image_path = image_info.get('path') if isinstance(image_info, dict) else image_info
    elif isinstance(selected_value, (list, tuple)) and selected_value:
        image_path = selected_value[0]
    else:
        image_path = str(selected_value)

    result = tag_results.get(image_path)
    if result is None:
        # Fall back to the gallery position when the event path is not a stored key.
        index = getattr(selected_state, 'index', None)
        if isinstance(index, int) and gallery and 0 <= index < len(gallery):
            result = tag_results.get(gallery[index][0])
    return as_outputs(result) if result is not None else empty
def append_gallery(gallery: list, image: str):
    """Add one uploaded image or video to ``gallery``.

    Passes the upload first and the current gallery second to
    handle_single_media_upload and returns its result unchanged.
    """
    new_media = image
    return handle_single_media_upload(new_media, gallery)
def extend_gallery(gallery: list, images):
    """Add several uploaded images or videos to ``gallery``.

    Passes the uploads first and the current gallery second to
    handle_multiple_media_uploads and returns its result unchanged.
    """
    new_media = images
    return handle_multiple_media_uploads(new_media, gallery)
# Command-line options and the shared predictor used by the UI callbacks.
args = parse_args()
predictor = Predictor()

# Models offered in both model dropdowns, in display order.
dropdown_list = [
    EVA02_LARGE_MODEL_DSV3_REPO,
    VIT_LARGE_MODEL_DSV3_REPO,
    SWINV2_MODEL_DSV3_REPO,
    CONV_MODEL_DSV3_REPO,
    VIT_MODEL_DSV3_REPO,
    MOAT_MODEL_DSV2_REPO,
    SWIN_MODEL_DSV2_REPO,
    CONV_MODEL_DSV2_REPO,
    CONV2_MODEL_DSV2_REPO,
    VIT_MODEL_DSV2_REPO,
    EVA02_LARGE_MODEL_IS_DSV1_REPO,
    SWINV2_MODEL_IS_DSV1_REPO,
]
| def _restart_space(): | |
| """Restart the HuggingFace Space periodically for stability""" | |
| HF_TOKEN = os.getenv('HF_TOKEN') | |
| if not HF_TOKEN: | |
| raise ValueError('HF_TOKEN environment variable is not set.') | |
| huggingface_hub.HfApi().restart_space( | |
| repo_id='Werli/Multi-Tagger', | |
| token=HF_TOKEN, | |
| factory_reboot=False | |
| ) | |
# Restart the Space every 2 days (172800 s) from a background scheduler thread,
# and show the next scheduled restart (in UTC) at the bottom of the page.
scheduler = BackgroundScheduler()
restart_space_job = scheduler.add_job(_restart_space, 'interval', seconds=2 * 24 * 60 * 60)
scheduler.start()
next_run_time_utc = restart_space_job.next_run_time.astimezone(timezone.utc)
NEXT_RESTART = (
    f"Next Restart: {next_run_time_utc.strftime('%Y-%m-%d %H:%M:%S')} (UTC) - "
    "The space will restart every 2 days to ensure stability and performance. "
    "It uses a background scheduler to handle the restart process."
)
# Gradio page: Waifu Diffusion tagger tab plus the PixAI, Booru, ComfyUI and Misc tools.
with gr.Blocks(title=TITLE, css=css, theme="Werli/Purple-Crimson-Gradio-Theme", fill_width=True) as demo:
    gr.Markdown(value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>")
    gr.Markdown(value=f"<p style='text-align: center;'>{DESCRIPTION}</p>")

    with gr.Tab(label='Waifu Diffusion'):
        with gr.Row():
            # Left: image inputs and tagging options.
            with gr.Column():
                with gr.Column(variant='panel'):
                    image_input = gr.Image(
                        label='Upload an Image (or paste from clipboard)',
                        type='filepath',
                        sources=['upload', 'clipboard'],
                        height=150
                    )
                    with gr.Row():
                        upload_button = gr.UploadButton(
                            'Upload multiple images or videos',
                            file_types=['image', 'video'],
                            file_count='multiple',
                            size='md'
                        )
                    gallery = gr.Gallery(
                        columns=2,
                        show_share_button=False,
                        interactive=True,
                        height='auto',
                        label='Grid of images',
                        preview=False,
                        elem_id='custom-gallery'
                    )
                    submit = gr.Button(value='Analyze Images', variant='primary', size='lg')
                    clear_gallery = gr.ClearButton(components=[gallery], value='Clear Gallery', variant='secondary', size='sm')
                with gr.Column(variant='panel'):
                    model_repo = gr.Dropdown(
                        dropdown_list,
                        value=EVA02_LARGE_MODEL_DSV3_REPO,
                        label='1st Model'
                    )
                    PLUS = '+?'
                    gr.Markdown(value=f"<p style='text-align: center;'>{PLUS}</p>")
                    model_repo_2 = gr.Dropdown(
                        [None] + dropdown_list,
                        value=None,
                        label='2nd Model (Optional)',
                        info='Select another model for diversified results.'
                    )
                    with gr.Row():
                        general_thresh = gr.Slider(
                            0, 1,
                            step=args.score_slider_step,
                            value=args.score_general_threshold,
                            label='General Tags Threshold',
                            scale=3
                        )
                        general_mcut_enabled = gr.Checkbox(
                            value=False,
                            label='Use MCut threshold',
                            scale=1
                        )
                    with gr.Row():
                        character_thresh = gr.Slider(
                            0, 1,
                            step=args.score_slider_step,
                            value=args.score_character_threshold,
                            label='Character Tags Threshold',
                            scale=3
                        )
                        character_mcut_enabled = gr.Checkbox(
                            value=False,
                            label='Use MCut threshold',
                            scale=1
                        )
                    with gr.Row():
                        characters_merge_enabled = gr.Checkbox(
                            value=False,
                            label='Merge characters into the string output',
                            scale=1
                        )
                    with gr.Row():
                        additional_tags_prepend = gr.Text(
                            label='Prepend Additional tags (comma split)'
                        )
                        additional_tags_append = gr.Text(
                            label='Append Additional tags (comma split)'
                        )
                    with gr.Row():
                        # Resets the gallery and every option above, including the optional 2nd model.
                        clear = gr.ClearButton(
                            components=[
                                gallery, model_repo, model_repo_2, general_thresh, general_mcut_enabled,
                                character_thresh, character_mcut_enabled, characters_merge_enabled,
                                additional_tags_prepend, additional_tags_append
                            ],
                            value='Clear Everything',
                            variant='secondary',
                            size='lg'
                        )
            # Right: results for the selected (or first) image.
            with gr.Column(variant='panel'):
                download_file = gr.File(label='Download')
                character_res = gr.Textbox(
                    label="Character tags",
                    show_copy_button=True,
                    lines=3
                )
                sorted_general_strings = gr.Textbox(
                    label='Output',
                    show_label=True,
                    show_copy_button=True,
                    lines=5
                )
                categorized_strings = gr.Textbox(
                    label='Categorized',
                    show_label=True,
                    show_copy_button=True,
                    lines=5
                )
                tags_json = gr.JSON(
                    label='Categorized Tags (JSON)',
                    visible=True
                )
                rating = gr.Label(label='Rating')
                general_res = gr.Textbox(
                    label="General tags",
                    show_copy_button=True,
                    lines=3,
                    visible=False  # Temp
                )
        # Per-session store of results keyed by image path (filled by predictor.predict).
        tag_results = gr.State({})
        # Event handlers
        image_input.change(
            append_gallery,
            inputs=[gallery, image_input],
            outputs=[gallery, image_input]
        )
        upload_button.upload(
            extend_gallery,
            inputs=[gallery, upload_button],
            outputs=gallery
        )
        gallery.select(
            get_selection_from_gallery,
            inputs=[gallery, tag_results],
            outputs=[sorted_general_strings, rating, character_res, general_res, categorized_strings, tags_json]
        )
        submit.click(
            predictor.predict,
            inputs=[
                gallery, model_repo, model_repo_2, general_thresh, general_mcut_enabled,
                character_thresh, character_mcut_enabled, characters_merge_enabled,
                additional_tags_prepend, additional_tags_append, tag_results
            ],
            outputs=[download_file, sorted_general_strings, rating, character_res, general_res, categorized_strings, tags_json, tag_results]
        )
        gr.Markdown('[Based on SmilingWolf/wd-tagger](https://huggingface.co/spaces/SmilingWolf/wd-tagger) <p style="text-align:right"><a href="https://huggingface.co/spaces/John6666/danbooru-tags-transformer-v2-with-wd-tagger-b">Prompt Enhancer</a></p>')

    # Tabs built by the project's modules package.
    with gr.Tab("PixAI"):
        pixai_interface = create_pixai_interface()
    with gr.Tab("Booru Image Fetcher"):
        booru_interface = create_booru_interface()
    with gr.Tab("ComfyUI Extractor"):
        comfy_interface = create_multi_comfy()

    with gr.Tab(label="Misc"):
        with gr.Row():
            with gr.Column(variant="panel"):
                tag_string = gr.Textbox(
                    label="Input Tags",
                    placeholder="1girl, cat, horns, blue hair, ...\nor\n? 1girl 1234567? cat 1234567? horns 1234567? blue hair 1234567? ...",
                    lines=4
                )
                submit_button = gr.Button(value="START", variant="primary", size="lg")
            with gr.Column(variant="panel"):
                cleaned_tags_output = gr.Textbox(
                    label="Cleaned Tags",
                    show_label=True,
                    show_copy_button=True,
                    lines=4,
                    info="Tags with ? and numbers removed, formatted with commas. Useful for clearing tags from Booru sites."
                )
                classify_tags_for_display = gr.Textbox(
                    label="Categorized (string)",
                    show_label=True,
                    show_copy_button=True,
                    lines=8,
                    info="Tags organized by categories"
                )
                generate_categorized_json = gr.JSON(
                    label="Categorized JSON (tags)"
                )
        # process_tags_for_misc fills all three outputs from the single input string.
        submit_button.click(
            process_tags_for_misc,
            inputs=[tag_string],
            outputs=[cleaned_tags_output, classify_tags_for_display, generate_categorized_json]
        )
    gr.Markdown(NEXT_RESTART)
# Hold at most 5 queued events; --share requests a public Gradio link when run locally.
demo.queue(max_size=5).launch(show_error=True, show_api=False, share=args.share)