Spaces:
Running
Running
| import os, json, zipfile, tempfile, time, traceback | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import onnxruntime as ort | |
| from collections import defaultdict | |
| from typing import Union, Dict, Any, Tuple, List | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
| from huggingface_hub.errors import EntryNotFoundError | |
| from datetime import datetime | |
| # Global variables for model components (for memory management) | |
| # Module-level cache of the model that is currently loaded, plus the data that | |
| # belongs to it. _load_model_components_optimized() fills these in and | |
| # unload_model() resets every one of them to None. | |
| CURRENT_MODEL = None  # session returned by create_optimized_ort_session() | |
| CURRENT_MODEL_NAME = None  # model name the cached components were loaded for | |
| CURRENT_TAGS_DF = None  # DataFrame read from the model's selected_tags.csv | |
| CURRENT_D_IPS = None  # tag name -> non-empty 'ips' entry from the tags table | |
| CURRENT_PREPROCESS_FUNC = None  # callable turning a PIL image into the model input array | |
| CURRENT_THRESHOLDS = None  # category -> default confidence threshold | |
| CURRENT_CATEGORY_NAMES = None  # category -> human-readable category name | |
| # Custom stylesheet passed to gr.Blocks(css=css) in create_pixai_interface(). | |
| # The '#custom-gallery' rules target the gallery created with | |
| # elem_id='custom-gallery'. | |
| # NOTE(review): the last rule repeats the '#custom-gallery .thumbnail-item img' | |
| # selector that already appears a few rules earlier, with different | |
| # object-fit/width values; with equal specificity the later declaration should | |
| # win -- confirm whether the earlier rule is still needed. | |
| css = """ | |
| #custom-gallery {--row-height: 180px;display: grid;grid-auto-rows: min-content;gap: 10px;} | |
| #custom-gallery .thumbnail-item {height: var(--row-height);width: 100%;position: relative;overflow: hidden;border-radius: 8px;box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);transition: transform 0.2s ease, box-shadow 0.2s ease;} | |
| #custom-gallery .thumbnail-item:hover {transform: translateY(-3px);box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);} | |
| #custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: contain;margin: 0 auto;display: block;} | |
| #custom-gallery .thumbnail-item img.portrait {max-width: 100%;} | |
| #custom-gallery .thumbnail-item img.landscape {max-height: 100%;} | |
| .gallery-container {max-height: 500px;overflow-y: auto;padding-right: 0px;--size-80: 500px;} | |
| .thumbnails {display: flex;position: absolute;bottom: 0;width: 120px;overflow-x: scroll;padding-top: 320px;padding-bottom: 280px;padding-left: 4px;flex-wrap: wrap;} | |
| #custom-gallery .thumbnail-item img {width: auto;height: 100%;max-width: 100%;max-height: var(--row-height);object-fit: initial;width: fit-content;margin: 0px auto;display: block;} | |
| """ | |
def preprocess_on_gpu(img, device='cuda'):
    """Preprocess a PIL image with torchvision and return it as a NumPy batch.

    The image is converted to RGB, resized to 448x448, turned into a tensor
    and normalized per channel, then given a leading batch axis.

    Args:
        img: PIL image to preprocess.
        device: Torch device the tensor is moved to when CUDA is available.

    Returns:
        ``numpy.ndarray`` of shape ``(1, 3, 448, 448)``.
    """
    import torch
    import torchvision.transforms as transforms

    # Normalize() below carries exactly three channel statistics, so RGBA,
    # grayscale or palette images must be converted first. The NumPy
    # preprocessing used for inference in this file does the same.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    transform = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711],
        ),
    ])
    tensor_img = transform(img).unsqueeze(0)

    # NOTE(review): resize/normalize have already run by this point, so the
    # transfer below does not move any of that work onto the GPU; the result
    # is copied straight back to host memory for NumPy.
    if torch.cuda.is_available():
        tensor_img = tensor_img.to(device)
    return tensor_img.cpu().numpy()
class Timer:
    """Collect labelled timestamps and print the time spent between them.

    ``checkpoints`` is a list of ``(label, perf_counter_value)`` pairs. The
    first entry is always a baseline: it is never printed itself, it only
    marks where the first reported interval starts.
    """

    def __init__(self):
        self.start_time = time.perf_counter()
        self.checkpoints = [('Start', self.start_time)]

    def checkpoint(self, label='Checkpoint'):
        """Record the current time under ``label``."""
        self.checkpoints.append((label, time.perf_counter()))

    def _label_width(self):
        """Return the length of the longest recorded label (0 when empty)."""
        return max((len(label) for label, _ in self.checkpoints), default=0)

    def _print_intervals(self, width):
        """Print the elapsed time between consecutive checkpoints."""
        if not self.checkpoints:
            return
        prev_time = self.checkpoints[0][1]
        for label, curr_time in self.checkpoints[1:]:
            elapsed = curr_time - prev_time
            print(f"{label.ljust(width)}: {elapsed:.3f} seconds")
            prev_time = curr_time

    def report(self, is_clear_checkpoints=True):
        """Print per-checkpoint timings.

        Args:
            is_clear_checkpoints: When true, drop the reported checkpoints and
                start a new baseline at the current time.
        """
        self._print_intervals(self._label_width())
        if is_clear_checkpoints:
            self.checkpoints.clear()
            self.checkpoint()

    def report_all(self):
        """Print per-checkpoint timings followed by the total run time.

        The total is measured from ``start_time``, which only ``restart()``
        resets. Afterwards the checkpoint list is reduced to a single baseline
        at the last reported timestamp. Leaving it empty would make the next
        ``checkpoint()`` the unprinted first entry, so it would silently
        vanish from the following report.
        """
        print('\n> Execution Time Report:')
        width = self._label_width()
        # Intervals start at the list's own baseline (same rule as report()),
        # so a preceding report() does not inflate the first interval.
        self._print_intervals(width)
        last_time = self.checkpoints[-1][1] if self.checkpoints else self.start_time
        total_time = last_time - self.start_time
        print(f"{'Total Execution Time'.ljust(width)}: {total_time:.3f} seconds\n")  # Performance tests
        self.checkpoints = [('Start', last_time)]

    def restart(self):
        """Reset the start time and forget all checkpoints."""
        self.start_time = time.perf_counter()
        self.checkpoints = [('Start', self.start_time)]
| def _get_repo_id(model_name: str) -> str: | |
| """Get the repository ID for the specified model name.""" | |
| if '/' in model_name: | |
| return model_name | |
| else: | |
| return f'deepghs/pixai-tagger-{model_name}-onnx' | |
def _download_model_files(model_name: str):
    """Fetch the model's files from the Hugging Face Hub.

    Returns:
        ``(model_path, tags_path, preprocess_path, thresholds_path)``.
        ``thresholds_path`` is ``None`` when the repository has no
        ``thresholds.csv``.
    """
    repo_id = _get_repo_id(model_name)

    def fetch(filename):
        # Every download targets the same repository with the same library tag.
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            library_name="pixai-tagger"
        )

    model_path = fetch('model.onnx')
    tags_path = fetch('selected_tags.csv')
    preprocess_path = fetch('preprocess.json')

    # thresholds.csv is the only optional file: a missing entry is tolerated.
    try:
        thresholds_path = fetch('thresholds.csv')
    except EntryNotFoundError:
        thresholds_path = None

    return model_path, tags_path, preprocess_path, thresholds_path
def create_optimized_ort_session(model_path):
    """Build an ONNX Runtime session, using CUDA when the provider is present.

    Args:
        model_path: Path to the ``.onnx`` file.

    Returns:
        The created ``ort.InferenceSession``.

    Raises:
        Exception: Whatever session creation raised, re-raised after logging.
    """
    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    options.intra_op_num_threads = 0  # 0 leaves the thread count to ONNX Runtime
    options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
    options.enable_mem_pattern = True
    options.enable_cpu_mem_arena = True

    available_providers = ort.get_available_providers()
    print(f"Available ONNX Runtime providers: {available_providers}")

    # Providers are listed in order of preference; CPU always comes last as
    # the fallback.
    if 'CUDAExecutionProvider' in available_providers:
        cuda_options = {
            'device_id': 0,
            'arena_extend_strategy': 'kNextPowerOfTwo',
            'gpu_mem_limit': 4 * 1024 * 1024 * 1024,  # 4 GiB cap
            'cudnn_conv_algo_search': 'EXHAUSTIVE',
            'do_copy_in_default_stream': True,
        }
        providers = [('CUDAExecutionProvider', cuda_options), 'CPUExecutionProvider']
        print("Using CUDA provider for ONNX inference")
    else:
        providers = ['CPUExecutionProvider']
        print("CUDA provider not available, falling back to CPU")

    try:
        session = ort.InferenceSession(model_path, options, providers=providers)
        print(f"Model loaded with providers: {session.get_providers()}")
    except Exception as e:
        print(f"Failed to create ONNX session: {e}")
        raise
    return session
def _pixai_preprocess(img):
    """Convert a PIL image into the float32 ``(3, 448, 448)`` model input."""
    # Ensure image is in RGB mode
    if img.mode != 'RGB':
        img = img.convert('RGB')
    # Resize to 448x448 <- Very important.
    img = img.resize((448, 448), Image.Resampling.LANCZOS)
    # Scale pixel values to [0, 1]
    img_array = np.array(img).astype(np.float32) / 255.0
    # Per-channel mean/std. These are the constants commonly published for
    # OpenAI CLIP preprocessing, not the classic ImageNet statistics.
    mean = np.array([0.48145466, 0.4578275, 0.40821073]).astype(np.float32)
    std = np.array([0.26862954, 0.26130258, 0.27577711]).astype(np.float32)
    img_array = (img_array - mean) / std
    # Transpose (H, W, C) -> (C, H, W)
    return np.transpose(img_array, (2, 0, 1))


def _load_model_components_optimized(model_name: str):
    """Return the cached components for ``model_name``, loading them if needed.

    Components are only (re)loaded when ``model_name`` differs from the cached
    model's name. Everything is built in local variables and committed to the
    module globals in one step, so a failure half-way through cannot leave a
    new session paired with the previous model's tags under the old name.

    Returns:
        ``(model, tags_df, d_ips, preprocess_func, thresholds, category_names)``.
    """
    global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_D_IPS
    global CURRENT_PREPROCESS_FUNC, CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES

    # Only reload if model changed
    if CURRENT_MODEL_NAME != model_name:
        # preprocess.json is still downloaded with the other files, but the
        # preprocessing pipeline is fixed in _pixai_preprocess, so it is not
        # parsed here.
        model_path, tags_path, _preprocess_path, thresholds_path = _download_model_files(model_name)

        model = create_optimized_ort_session(model_path)

        # Tag table, plus the tag -> IPs lookup when the table has an 'ips' column.
        tags_df = pd.read_csv(tags_path)
        d_ips = {}
        if 'ips' in tags_df.columns:
            tags_df['ips'] = tags_df['ips'].fillna('{}').map(json.loads)
            for name, ips in zip(tags_df['name'], tags_df['ips']):
                if ips:
                    d_ips[name] = ips

        # Per-category thresholds and names; the first row per category wins.
        thresholds = {}
        category_names = {}
        if thresholds_path and os.path.exists(thresholds_path):
            df_category_thresholds = pd.read_csv(thresholds_path, keep_default_na=False)
            for item in df_category_thresholds.to_dict('records'):
                if item['category'] not in thresholds:
                    thresholds[item['category']] = item['threshold']
                    category_names[item['category']] = item['name']
        else:
            # Default thresholds if file doesn't exist
            thresholds = {0: 0.3, 4: 0.85, 9: 0.85}
            category_names = {0: 'general', 4: 'character', 9: 'rating'}

        # Commit: nothing above touched the globals, so the cache is always
        # either the complete old model or the complete new one.
        CURRENT_MODEL = model
        CURRENT_TAGS_DF = tags_df
        CURRENT_D_IPS = d_ips
        CURRENT_PREPROCESS_FUNC = _pixai_preprocess
        CURRENT_THRESHOLDS = thresholds
        CURRENT_CATEGORY_NAMES = category_names
        CURRENT_MODEL_NAME = model_name

    return (CURRENT_MODEL, CURRENT_TAGS_DF, CURRENT_D_IPS, CURRENT_PREPROCESS_FUNC,
            CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES)
def _raw_predict(image: Image.Image, model_name: str):
    """Run the tagger on one image and return its raw outputs.

    Args:
        image: PIL image to tag.
        model_name: Model to load (or reuse from the cache).

    Returns:
        Dict mapping each model output name to that output with the batch
        axis removed.

    Raises:
        RuntimeError: On any failure, including a non-PIL ``image``; the
            original exception is attached as ``__cause__``.
    """
    try:
        # Ensure we have a PIL Image
        if not isinstance(image, Image.Image):
            raise ValueError("Input must be a PIL Image")

        model, _, _, preprocess_func, _, _ = _load_model_components_optimized(model_name)

        input_tensor = preprocess_func(image)
        # Add batch dimension
        if input_tensor.ndim == 3:
            input_tensor = np.expand_dims(input_tensor, axis=0)

        # Ask the session for its input name instead of assuming 'input', so
        # an export with a differently named input still works.
        input_name = model.get_inputs()[0].name
        output_names = [output.name for output in model.get_outputs()]
        output_values = model.run(output_names, {input_name: input_tensor.astype(np.float32)})
        return {name: value[0] for name, value in zip(output_names, output_values)}
    except Exception as e:
        raise RuntimeError(f"Error processing image: {str(e)}") from e
def _resolve_category_threshold(thresholds, category, category_names, default_thresholds):
    """Pick the confidence cut-off for one tag category."""
    # A bare number applies to every category. Integers such as 0 or 1 count
    # too; bool is excluded because it is an int subclass.
    if isinstance(thresholds, (int, float)) and not isinstance(thresholds, bool):
        return thresholds
    if isinstance(thresholds, dict):
        # The category key takes precedence over the category's name.
        if category in thresholds:
            return thresholds[category]
        category_name = category_names.get(category, '')
        if category_name in thresholds:
            return thresholds[category_name]
    return default_thresholds.get(category, 0.85)


def get_pixai_tags(
        image: Union[str, Image.Image],
        model_name: str = 'deepghs/pixai-tagger-v0.9-onnx',
        thresholds: Union[float, Dict[Any, float]] = None,
        fmt='all'
):
    """Tag an image and return the tags that pass the per-category thresholds.

    Args:
        image: File path or PIL image.
        model_name: Model to load (or reuse from the cache).
        thresholds: A single number applied to every category, or a dict keyed
            by category or by category name. Categories not covered fall back
            to the model's default thresholds (0.85 when the model has none).
        fmt: ``'all'`` returns a tuple with one ``{tag: score}`` dict per
            category, ordered by category. Any other value that is a key of
            the result dict returns just that entry. Anything else returns
            the whole result dict.

    Returns:
        See ``fmt``. The result dict holds the raw model outputs, one entry
        per category name, ``'tag'`` (all kept tags) and, when the tag table
        has an ``ips`` column, ``'ips_mapping'``, ``'ips_count'`` and ``'ips'``.

    Raises:
        RuntimeError: On any failure; the original exception is attached as
            ``__cause__``.
    """
    opened_image = None  # only set when this function opened the file itself
    try:
        # Load image if it's a path
        if isinstance(image, str):
            opened_image = pil_image = Image.open(image)
        elif isinstance(image, Image.Image):
            pil_image = image
        else:
            raise ValueError("Image must be a file path or PIL Image")

        _, df_tags, d_ips, _, default_thresholds, category_names = _load_model_components_optimized(model_name)
        values = _raw_predict(pil_image, model_name)
        prediction = values.get('prediction', np.array([]))
        if prediction.size == 0:
            raise RuntimeError("Model did not return valid predictions")

        categories = sorted(set(df_tags['category'].tolist()))
        tags = {}
        # Process tags by category
        for category in categories:
            mask = (df_tags['category'] == category).to_numpy()
            tag_names = df_tags.loc[mask, 'name']
            category_pred = prediction[mask]
            category_threshold = _resolve_category_threshold(
                thresholds, category, category_names, default_thresholds
            )

            # Apply threshold
            pred_mask = category_pred >= category_threshold
            filtered_tag_names = tag_names[pred_mask].tolist()
            filtered_predictions = category_pred[pred_mask].tolist()

            # Highest confidence first; ties broken alphabetically.
            cate_tags = dict(sorted(
                zip(filtered_tag_names, filtered_predictions),
                key=lambda x: (-x[1], x[0])
            ))
            values[category_names.get(category, f"category_{category}")] = cate_tags
            tags.update(cate_tags)
        values['tag'] = tags

        # Handle IPs if available
        if 'ips' in df_tags.columns:
            ips_mapping, ips_counts = {}, defaultdict(int)
            for tag in tags:
                if tag in d_ips:
                    ips_mapping[tag] = d_ips[tag]
                    for ip_name in d_ips[tag]:
                        ips_counts[ip_name] += 1
            values['ips_mapping'] = ips_mapping
            values['ips_count'] = dict(ips_counts)
            values['ips'] = [x for x, _ in sorted(ips_counts.items(), key=lambda x: (-x[1], x[0]))]

        # Return based on format
        if fmt == 'all':
            return tuple(
                values.get(category_names.get(cat, f"category_{cat}"), {})
                for cat in categories
            )
        if fmt in values:
            return values[fmt]
        return values
    except Exception as e:
        raise RuntimeError(f"Error processing image: {str(e)}") from e
    finally:
        # Release the file handle of an image opened from a path; images
        # passed in by the caller stay under the caller's control.
        if opened_image is not None:
            opened_image.close()
def format_ips_output(ips_result, ips_mapping):
    """Render detected IPs and the character-to-IP mapping as one string.

    Parentheses are backslash-escaped and underscores become spaces. The
    detected IPs come first, followed by one ``character: ip, ip`` entry per
    mapping item, all joined with ``", "``.
    """
    if not ips_result and not ips_mapping:
        return ""

    def pretty(text):
        return text.replace("(", "\\(").replace(")", "\\)").replace("_", " ")

    pieces = []

    detected = [pretty(ip) for ip in ips_result] if ips_result else []
    if detected:
        pieces.append(", ".join(detected))

    if ips_mapping:
        for character, character_ips in ips_mapping.items():
            joined_ips = ", ".join(pretty(ip) for ip in character_ips)
            pieces.append(f"{pretty(character)}: {joined_ips}")

    return ", ".join(pieces)
def _escape_tag(name):
    """Backslash-escape parentheses and turn underscores into spaces."""
    return name.replace("(", "\\(").replace(")", "\\)").replace("_", " ")


def process_single_image(
        image_path,
        model_name="deepghs/pixai-tagger-v0.9-onnx",
        general_threshold=0.3,
        character_threshold=0.85,
        progress=None,
        idx=0,
        total_images=1
):
    """Tag one image file and return all formatted outputs.

    Args:
        image_path: Path of the image; ``None`` yields empty outputs.
        model_name: Model to load (or reuse from the cache).
        general_threshold: Threshold for the ``'general'`` category.
        character_threshold: Threshold for the ``'character'`` category.
        progress: Optional callable ``progress(fraction, desc=...)``.
        idx: Zero-based position of this image in the batch.
        total_images: Batch size, used for the progress fraction.

    Returns:
        6-tuple ``(character_tags, general_tags, ips_detection, combined_tags,
        json_data, rating)``. The first four are strings, the last two dicts.
        On failure the four strings carry ``"Error: ..."`` and both dicts are
        empty; no exception propagates.
    """
    try:
        if image_path is None:
            return "", "", "", "", {}, {}

        # Compare against None rather than testing truthiness, so the
        # progress object's own __len__/__bool__ is never consulted.
        if progress is not None:
            progress((idx)/total_images, desc=f"Processing image {idx+1}/{total_images}")

        thresholds = {
            'general': general_threshold,
            'character': character_threshold
        }

        # Run inference ONCE. A fmt that is neither 'all' nor a key of the
        # result makes get_pixai_tags return its complete result dict, from
        # which the categories and the IP data are read. Previously the
        # model was run three times per image, once for each of them.
        with Image.open(image_path) as pil_image:
            values = get_pixai_tags(pil_image, model_name, thresholds, fmt=None)

        # Rebuild the category order of fmt='all' (cached lookup, no reload):
        # positions 0/1/2 are read as general/character/rating.
        _, df_tags, _, _, _, category_names = _load_model_components_optimized(model_name)
        ordered_names = [category_names.get(cat, f"category_{cat}")
                         for cat in sorted(set(df_tags['category'].tolist()))]
        all_categories = [values.get(name, {}) for name in ordered_names]
        # Ensure we have at least 3 categories (general, character, rating)
        all_categories += [{}] * max(0, 3 - len(all_categories))
        general_tags, character_tags, rating_tags = all_categories[:3]

        # IP data is absent when the tag table has no 'ips' column.
        ips_result = values.get('ips') or []
        ips_mapping = values.get('ips_mapping') or {}

        # Names only, in confidence order.
        character_names = [_escape_tag(name) for name in character_tags]
        general_names = [_escape_tag(name) for name in general_tags]
        character_output = ", ".join(character_names)
        general_output = ", ".join(general_names)
        ips_output = format_ips_output(ips_result, ips_mapping)

        # Combined: character tags first, then general tags, then IP tags.
        combined_parts = [part for part in (character_output, general_output, ips_output) if part]
        combined_output = ", ".join(combined_parts)

        json_data = {
            "character_tags": character_tags,
            "general_tags": general_tags,
            "rating_tags": rating_tags,
            "ips_result": ips_result,
            "ips_mapping": ips_mapping
        }

        # Rating as a label-compatible {name: score} dict.
        rating_output = {_escape_tag(k): v for k, v in rating_tags.items()}

        return (
            character_output,  # Character tags
            general_output,    # General tags
            ips_output,        # IP Detection
            combined_output,   # Combined tags
            json_data,         # Detailed JSON
            rating_output      # Rating
        )
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        # Return error message for the four text outputs, empty dicts for the rest
        return error_msg, error_msg, error_msg, error_msg, {}, {}
| """GPU""" | |
def unload_model():
    """Drop every cached model component and ask the runtime to free memory.

    All module-level cache variables are reset to ``None``, the garbage
    collector is run, and the CUDA cache is emptied when torch is importable
    and CUDA is available.
    """
    global CURRENT_MODEL, CURRENT_MODEL_NAME, CURRENT_TAGS_DF, CURRENT_D_IPS
    global CURRENT_PREPROCESS_FUNC, CURRENT_THRESHOLDS, CURRENT_CATEGORY_NAMES
    import gc

    # Rebinding to None drops the references to the session and its data.
    CURRENT_MODEL = None
    CURRENT_TAGS_DF = CURRENT_D_IPS = None
    CURRENT_PREPROCESS_FUNC = None
    CURRENT_THRESHOLDS = CURRENT_CATEGORY_NAMES = None
    CURRENT_MODEL_NAME = None

    # Force garbage collection
    gc.collect()

    # torch is optional; without it there is no CUDA cache to clear.
    try:
        import torch
    except ImportError:
        return
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # print("Model unloaded and memory cleared")
| def cleanup_after_processing(): | |
| """Release the cached model after a batch; thin wrapper around unload_model().""" | |
| unload_model() | |
def process_gallery_images(
        gallery,
        model_name,
        general_threshold,
        character_threshold,
        progress=gr.Progress()
):
    """Tag every gallery image, write per-image files and bundle them in a zip.

    Args:
        gallery: Gallery items; each is a path or a list/tuple whose first
            element is the path.
        model_name: Model to use.
        general_threshold: Threshold for general tags.
        character_threshold: Threshold for character tags.
        progress: Gradio progress tracker.

    Returns:
        8-tuple ``(tag_results, character_tags, general_tags, combined_tags,
        json_data, rating, ips_detection, zip_path)``. ``tag_results`` maps
        image path to that image's outputs; the six middle values belong to
        the first gallery image. With an empty gallery or on failure the
        outputs are blank and ``zip_path`` is ``None``.
    """
    if not gallery:
        # Same types as every other return: dict for the results state and a
        # string for the IPs textbox.
        return {}, "", "", "", {}, {}, "", None

    import shutil

    tag_results = {}
    txt_infos = []
    used_names = set()
    output_dir = tempfile.mkdtemp()  # mkdtemp() already creates the directory
    total_images = len(gallery)
    timer = Timer()
    try:
        for idx, image_data in enumerate(gallery):
            image_path = None  # keeps the error message below well-defined
            try:
                image_path = image_data[0] if isinstance(image_data, (list, tuple)) else image_data
                results = process_single_image(
                    image_path, model_name, general_threshold, character_threshold,
                    progress, idx, total_images
                )
                tag_results[image_path] = {
                    'character_tags': results[0],
                    'general_tags': results[1],
                    'ips_detection': results[2],
                    'combined_tags': results[3],
                    'json_data': results[4],
                    'rating': results[5]
                }

                stem, extension = os.path.splitext(os.path.basename(image_path))
                # Different uploads can share a basename (for example several
                # files all called image.png). Make the name unique so files
                # are not overwritten and the zip gets no duplicate entries.
                image_name = stem
                counter = 1
                while image_name in used_names:
                    image_name = f"{stem}_{counter}"
                    counter += 1
                used_names.add(image_name)

                # Save all output files with descriptive prefixes
                files_to_create = [
                    (f"character_tags-{image_name}.txt", results[0]),
                    (f"general_tags-{image_name}.txt", results[1]),
                    (f"ips_detection-{image_name}.txt", results[2]),
                    (f"combined_tags-{image_name}.txt", results[3]),
                    (f"detailed_json-{image_name}.json", json.dumps(results[4], indent=4, ensure_ascii=False))
                ]
                for file_name, content in files_to_create:
                    file_path = os.path.join(output_dir, file_name)
                    with open(file_path, 'w', encoding='utf-8') as f:
                        f.write(content)
                    txt_infos.append({'path': file_path, 'name': file_name})

                # Byte-for-byte copy of the original image. Re-saving through
                # PIL would re-encode it and leave a file handle open.
                image_copy_name = f"{image_name}{extension}"
                image_copy_path = os.path.join(output_dir, image_copy_name)
                shutil.copyfile(image_path, image_copy_path)
                txt_infos.append({'path': image_copy_path, 'name': image_copy_name})

                timer.checkpoint(f"image{idx:02d}, processed")
            except Exception as e:
                # One bad image must not abort the rest of the batch.
                print(f"Error processing image {image_path}: {str(e)}")
                print(traceback.format_exc())
                continue

        # Create zip file
        download_zip_path = os.path.join(output_dir, f"Multi-Tagger-{datetime.now().strftime('%Y%m%d-%H%M%S')}.zip")
        with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for info in txt_infos:
                zipf.write(info['path'], arcname=info['name'])

        # The model is unloaded after the zip has been written.
        cleanup_after_processing()  # Comment here to turn off this behavior
        progress(1.0, desc="Processing complete")
        timer.report_all()
        print('Processing is complete.')

        # The UI shows the first image's results, however many were tagged.
        first_image_results = ("", "", "", {}, {}, "")
        first_image_path = gallery[0][0] if isinstance(gallery[0], (list, tuple)) else gallery[0]
        if first_image_path in tag_results:
            result = tag_results[first_image_path]
            first_image_results = (
                result['character_tags'],
                result['general_tags'],
                result['combined_tags'],
                result['json_data'],
                result['rating'],
                result['ips_detection']
            )
        return (tag_results, *first_image_results, download_zip_path)
    except Exception as e:
        print(f"Error in process_gallery_images: {str(e)}")
        print(traceback.format_exc())
        progress(1.0, desc="Processing failed")
        return {}, "", "", "", {}, {}, "", None
def get_selection_from_gallery(gallery, tag_results, selected_state: gr.SelectData):
    """Return the stored outputs for the gallery image that was just selected.

    The result is ``(character_tags, general_tags, combined_tags, json_data,
    rating, ips_detection)``; blank values are returned when nothing is
    selected, nothing has been processed, or the image has no stored result.
    """
    blank = ("", "", "", {}, {}, "")
    if not selected_state or not tag_results:
        return blank

    # The selection payload can be a dict with an 'image' entry, a sequence
    # whose first element is the path, or something that is stringified.
    chosen = selected_state.value
    if isinstance(chosen, dict) and 'image' in chosen:
        image_path = chosen['image']['path']
    elif isinstance(chosen, (list, tuple)) and len(chosen) > 0:
        image_path = chosen[0]
    else:
        image_path = str(chosen)

    if image_path not in tag_results:
        return blank

    stored = tag_results[image_path]
    output_keys = ('character_tags', 'general_tags', 'combined_tags',
                   'json_data', 'rating', 'ips_detection')
    return tuple(stored[key] for key in output_keys)
def append_gallery(gallery, image):
    """Append one image to the gallery list.

    Returns ``(gallery, None)``; the ``None`` is the second output value.
    A missing gallery is treated as empty and a falsy ``image`` is ignored.
    """
    items = [] if gallery is None else gallery
    if image:
        items.append(image)
    return items, None
def extend_gallery(gallery, images):
    """Append several images to the gallery list and return it.

    A missing gallery is treated as empty; a falsy ``images`` leaves it as is.
    """
    items = [] if gallery is None else gallery
    if images:
        items.extend(images)
    return items
| def create_pixai_interface(): | |
| """Create the PixAI Gradio interface""" | |
| # Builds the gr.Blocks app: inputs (image upload, multi-upload button, | |
| # gallery, model dropdown, threshold sliders, clear button), outputs | |
| # (download file, tag text boxes, rating label, detailed JSON) and the | |
| # event wiring between them. Returns the Blocks object. | |
| with gr.Blocks(css=css, fill_width=True) as demo: | |
| # gr.Markdown("Upload anime-style images to extract tags using PixAI") | |
| # State to store results | |
| # Maps image path -> stored outputs; written by process_gallery_images and | |
| # read by get_selection_from_gallery. | |
| tag_results = gr.State({}) | |
| # NOTE(review): this hidden textbox is not referenced by any handler in this | |
| # function -- confirm whether it is still needed. | |
| selected_image = gr.Textbox(label='Selected Image', visible=False) | |
| with gr.Row(): | |
| with gr.Column(): | |
| # Image upload section | |
| with gr.Column(variant='panel'): | |
| image_input = gr.Image( | |
| label='Upload an Image or clicking paste from clipboard button', | |
| type='filepath', | |
| sources=['upload', 'clipboard'], | |
| height=150 | |
| ) | |
| with gr.Row(): | |
| upload_button = gr.UploadButton( | |
| 'Upload multiple images', | |
| file_types=['image'], | |
| file_count='multiple', | |
| size='sm' | |
| ) | |
| gallery = gr.Gallery( | |
| columns=2, | |
| show_share_button=False, | |
| interactive=True, | |
| height='auto', | |
| label='Grid of images', | |
| preview=False, | |
| elem_id='custom-gallery' | |
| ) | |
| run_button = gr.Button("Analyze Images", variant="primary", size='lg') | |
| model_dropdown = gr.Dropdown( | |
| choices=["deepghs/pixai-tagger-v0.9-onnx"], | |
| value="deepghs/pixai-tagger-v0.9-onnx", | |
| label="Model" | |
| ) | |
| # Threshold controls | |
| # Slider defaults (0.30 / 0.85) match the defaults of process_single_image. | |
| with gr.Row(): | |
| general_threshold = gr.Slider( | |
| minimum=0.0, maximum=1.0, value=0.30, step=0.05, | |
| label="General Tags Threshold", scale=3 | |
| ) | |
| character_threshold = gr.Slider( | |
| minimum=0.0, maximum=1.0, value=0.85, step=0.05, | |
| label="Character Tags Threshold", scale=3 | |
| ) | |
| # NOTE(review): the clear button also targets the model dropdown and both | |
| # threshold sliders -- confirm that resetting those is intended. | |
| with gr.Row(): | |
| clear = gr.ClearButton( | |
| components=[gallery, model_dropdown, general_threshold, character_threshold], | |
| variant='secondary', | |
| size='lg' | |
| ) | |
| clear.add([tag_results]) | |
| detailed_json_output = gr.JSON(label="Detailed JSON") | |
| with gr.Column(variant='panel'): | |
| download_file = gr.File(label="Download") | |
| # Output blocks | |
| character_tags_output = gr.Textbox( | |
| label="Character tags", | |
| show_copy_button=True, | |
| lines=3 | |
| ) | |
| general_tags_output = gr.Textbox( | |
| label="General tags", | |
| show_copy_button=True, | |
| lines=3 | |
| ) | |
| ips_detection_output = gr.Textbox( | |
| label="IPs Detection", | |
| show_copy_button=True, | |
| lines=5 | |
| ) | |
| combined_tags_output = gr.Textbox( | |
| label="Combined tags", | |
| show_copy_button=True, | |
| lines=6 | |
| ) | |
| rating_output = gr.Label(label="Rating") | |
| # Clear button targets | |
| clear.add([ | |
| download_file, | |
| character_tags_output, | |
| general_tags_output, | |
| ips_detection_output, | |
| combined_tags_output, | |
| rating_output, | |
| detailed_json_output | |
| ]) | |
| # Event handlers | |
| # A newly provided single image is appended to the gallery; append_gallery | |
| # returns None as its second value, which is sent back to image_input. | |
| image_input.change( | |
| append_gallery, | |
| inputs=[gallery, image_input], | |
| outputs=[gallery, image_input] | |
| ) | |
| upload_button.upload( | |
| extend_gallery, | |
| inputs=[gallery, upload_button], | |
| outputs=gallery | |
| ) | |
| # The order of `outputs` must match the tuple returned by | |
| # get_selection_from_gallery: character, general, combined, JSON, rating, IPs. | |
| gallery.select( | |
| get_selection_from_gallery, | |
| inputs=[gallery, tag_results], | |
| outputs=[ | |
| character_tags_output, | |
| general_tags_output, | |
| combined_tags_output, | |
| detailed_json_output, | |
| rating_output, | |
| ips_detection_output | |
| ] | |
| ) | |
| # The order of `outputs` must match the 8-tuple returned by | |
| # process_gallery_images: results state first, download file last. | |
| run_button.click( | |
| process_gallery_images, | |
| inputs=[gallery, model_dropdown, general_threshold, character_threshold], | |
| outputs=[ | |
| tag_results, | |
| character_tags_output, | |
| general_tags_output, | |
| combined_tags_output, | |
| detailed_json_output, | |
| rating_output, | |
| ips_detection_output, | |
| download_file | |
| ] | |
| ) | |
| gr.Markdown('[Based on Source code for imgutils.tagging.pixai](https://dghs-imgutils.deepghs.org/main/_modules/imgutils/tagging/pixai.html) & [pixai-labs/pixai-tagger-demo](https://huggingface.co/spaces/pixai-labs/pixai-tagger-demo)') | |
| return demo | |
| # Export public API | |
| # Only these names are exported by `from <module> import *`; the | |
| # underscore-prefixed loaders and the UI callbacks are left out. | |
| __all__ = [ | |
| 'get_pixai_tags', | |
| 'process_single_image', | |
| 'process_gallery_images', | |
| 'create_pixai_interface', | |
| 'unload_model', | |
| 'cleanup_after_processing' | |
| ] | |