Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import gc | |
| from PIL import Image | |
| import numpy as np | |
| import logging | |
| import io | |
| import os | |
| import requests | |
| from spandrel import ModelLoader | |
| from abc import ABC, abstractmethod | |
| from typing import Optional, Tuple, Dict | |
| import psutil | |
| import time | |
| import traceback | |
# --- Configuration ---
class Config:
    """Configuration settings for the application."""

    # Directory the model checkpoints are downloaded to / loaded from.
    MODEL_DIR = "."
    REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
    REALESRGAN_FILENAME = "RealESRGAN_x2plus.pth"
    # SOTA Models (2025)
    SPAN_URL = "https://huggingface.co/Phips/2xNomosUni_span_multijpg/resolve/main/2xNomosUni_span_multijpg.safetensors"
    SPAN_FILENAME = "2xNomosUni_span_multijpg.safetensors"
    HATS_URL = "https://huggingface.co/Phips/4xNomos8kSCHAT-S/resolve/main/4xNomos8kSCHAT-S.safetensors"
    HATS_FILENAME = "4xNomos8kSCHAT-S.safetensors"
    DEVICE = "cpu"  # Force CPU for this demo, can be "cuda" if available

    @staticmethod
    def ensure_model_dir():
        """Create ``Config.MODEL_DIR`` (and any missing parents) if it does not exist.

        Safe to call repeatedly and from either the class or an instance.
        ``exist_ok=True`` avoids the check-then-create race of a separate
        ``os.path.exists`` test.
        """
        os.makedirs(Config.MODEL_DIR, exist_ok=True)


# Backward-compatible module-level alias for the directory helper.
ensure_model_dir = Config.ensure_model_dir
# --- Logging Setup ---
class LogCapture(io.StringIO):
    """In-memory log sink that retains only the most recent output.

    The "UpscalerApp" logger's stream handler writes here and ``get_logs``
    reads it back. Without a cap the buffer would grow for the whole lifetime
    of the process, and the entire history would be returned on every request.
    """

    def __init__(self, max_chars: int = 200_000):
        """Create an empty buffer.

        Args:
            max_chars: Approximate upper bound on retained characters; a falsy
                value (0/None) disables trimming.
        """
        super().__init__()
        self.max_chars = max_chars

    def write(self, s: str) -> int:
        """Append ``s``; if the buffer now exceeds ``max_chars``, drop the oldest text."""
        written = super().write(s)
        # tell() is the end of the buffer here because the handler only appends.
        if self.max_chars and self.tell() > self.max_chars:
            value = self.getvalue()
            cut = len(value) - self.max_chars
            # Move the cut forward to a line boundary so no partial log line
            # survives; if that would discard everything, keep the raw tail.
            newline = value.find("\n", cut - 1)
            start = newline + 1 if newline != -1 else cut
            if start >= len(value):
                start = cut
            self.seek(0)
            self.truncate()
            super().write(value[start:])
        return written


log_capture_string = LogCapture()
ch = logging.StreamHandler(log_capture_string)
ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
logger = logging.getLogger("UpscalerApp")
logger.setLevel(logging.INFO)
logger.addHandler(ch)


def get_logs() -> str:
    """Return the retained (most recent) captured log text."""
    return log_capture_string.getvalue()
# --- System Monitoring ---
def get_system_usage() -> str:
    """Return a one-line CPU / RAM usage summary for the status panel.

    ``psutil.cpu_percent()`` is called without an interval, so it reports usage
    since the previous call (per the psutil docs, the first call in a process
    returns a meaningless 0.0).
    """
    cpu_percent = psutil.cpu_percent()
    # One memory snapshot, so the percentage and the absolute figure agree.
    mem = psutil.virtual_memory()
    ram_used_gb = mem.used / (1024 ** 3)
    return f"CPU: {cpu_percent}% | RAM: {mem.percent}% ({ram_used_gb:.1f} GB used)"
# --- Abstract Base Class for Models ---
class UpscalerStrategy(ABC):
    """Interface shared by all upscaling back-ends.

    Subclasses must implement ``load`` and ``upscale``; ``unload`` is shared and
    drops the reference held in ``self.model``.
    """

    def __init__(self):
        # Loaded model object, or None while unloaded.
        self.model = None
        # Human-readable name used in log messages.
        self.name = "Unknown"

    @abstractmethod
    def load(self) -> None:
        """Load the model into memory."""

    @abstractmethod
    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale the given image."""

    def unload(self) -> None:
        """Release the loaded model (if any) and reclaim its memory."""
        if self.model is None:
            return
        self.model = None
        gc.collect()
        # gc alone does not hand cached CUDA blocks back to the driver;
        # empty_cache is a no-op when CUDA was never initialised.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info(f"Unloaded {self.name}")
# --- Helper Functions for Optimization ---
def manual_tile_upscale(model, img_tensor, tile_size=256, tile_pad=10, scale=2):
    """
    Upscale ``img_tensor`` tile by tile to bound peak memory (prevents OOM).

    Each ``tile_size`` x ``tile_size`` tile is passed to ``model`` together with
    a ``tile_pad``-pixel halo of surrounding context (clipped at the image
    border). The halo is cropped from the model output, so every output pixel
    is produced with context from its neighbours.

    Args:
        model: Callable mapping a (B, C, h, w) tensor to (B, C_out, h*scale, w*scale).
        img_tensor: Input batch, indexed as (B, C, H, W).
        tile_size: Tile edge length in input pixels; must be positive.
        tile_pad: Halo width in input pixels added on each side of a tile.
        scale: Integer upscaling factor of ``model``.

    Returns:
        Tensor of shape (B, C_out, H*scale, W*scale) on the input's device and
        with the input's dtype. C_out is taken from the model's output (it
        equals C for an empty input).

    Raises:
        ValueError: if ``tile_size`` is not positive.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")
    batch, channels, height, width = img_tensor.shape
    output = None
    with torch.no_grad():
        for top in range(0, height, tile_size):
            bottom = min(top + tile_size, height)
            top_pad = max(0, top - tile_pad)
            bottom_pad = min(height, bottom + tile_pad)
            for left in range(0, width, tile_size):
                right = min(left + tile_size, width)
                left_pad = max(0, left - tile_pad)
                right_pad = min(width, right + tile_pad)

                tile_out = model(img_tensor[:, :, top_pad:bottom_pad, left_pad:right_pad])

                if output is None:
                    # Size the result from the model output so models whose
                    # channel count differs from the input are placed correctly.
                    output = torch.zeros(
                        batch, tile_out.shape[1], height * scale, width * scale,
                        device=img_tensor.device, dtype=img_tensor.dtype,
                    )

                # Offset of the tile's own region inside the halo-padded output.
                crop_top = (top - top_pad) * scale
                crop_left = (left - left_pad) * scale
                output[:, :, top * scale:bottom * scale, left * scale:right * scale] = tile_out[
                    :, :,
                    crop_top:crop_top + (bottom - top) * scale,
                    crop_left:crop_left + (right - left) * scale,
                ]
    if output is None:
        # Empty spatial input: nothing was run through the model.
        output = torch.zeros(batch, channels, height * scale, width * scale,
                             device=img_tensor.device, dtype=img_tensor.dtype)
    return output
def select_tile_config(height, width):
    """
    Pick tiling parameters for an input of ``height`` x ``width`` pixels.

    Size is measured in binary megapixels (pixels / 1024**2). Larger inputs
    get smaller tiles and a wider halo. Landmarks: 1920x1080 ~= 1.98 (first
    tier), 2560x1440 ~= 3.52 (second), 3840x2160 ~= 7.91 (third),
    7680x4320 ~= 31.6 (last).

    Returns:
        dict with ``'tile'`` (tile edge in pixels) and ``'tile_pad'`` (halo width).
    """
    size_mp = (height * width) / (1024 ** 2)
    # (exclusive upper bound in binary megapixels, tile edge, halo width)
    tiers = (
        (2, 512, 10),
        (6, 384, 15),
        (16, 256, 20),
    )
    for limit, tile, pad in tiers:
        if size_mp < limit:
            return {'tile': tile, 'tile_pad': pad}
    # Anything from 16 binary megapixels upward.
    return {'tile': 128, 'tile_pad': 25}
# --- Concrete Implementations ---
class RealESRGANStrategy(UpscalerStrategy):
    """Real-ESRGAN x2 upscaler loaded through spandrel, optionally torch.compile'd."""

    def __init__(self):
        super().__init__()
        self.name = "RealESRGAN x2"
        # True once the torch.compile decision has been made for the currently
        # loaded model; re-evaluated on every fresh load.
        self.compiled = False

    @staticmethod
    def _download(url: str, dest: str) -> None:
        """Stream ``url`` into ``dest`` through a ``.part`` file.

        The checkpoint only appears at ``dest`` once it is complete, so an
        interrupted download is retried next time instead of being loaded as a
        truncated file.
        """
        tmp_path = dest + ".part"
        try:
            # The timeout bounds connect and per-read stalls; without it a hung
            # server blocks model loading indefinitely.
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(tmp_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(tmp_path, dest)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def _maybe_compile(self, model):
        """Return ``model`` wrapped in torch.compile where that should help, else unchanged.

        Never raises. NOTE: torch.compile compiles lazily, so backend errors may
        only surface on the first forward call in ``upscale``.
        """
        self.compiled = True  # one attempt per loaded model
        try:
            if Config.DEVICE == 'cuda':
                # 'reduce-overhead' uses CUDA graphs, so only use it on CUDA
                model = torch.compile(model, mode='reduce-overhead')
                logger.info("[INFO] torch.compile enabled (reduce-overhead mode)")
            elif os.name == 'nt' and Config.DEVICE == 'cpu':
                # Windows requires MSVC for Inductor (default cpu backend)
                logger.info("[INFO] Skipping torch.compile on Windows CPU to avoid MSVC requirement.")
            elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                # Weak CPUs (e.g. HF Spaces Free Tier) would spend too long compiling
                logger.info("[INFO] Skipping torch.compile on low-core CPU to prevent timeout.")
            else:
                model = torch.compile(model)
                logger.info("[SUCCESS] torch.compile enabled (default mode)")
        except Exception as e:
            logger.warning(f"[WARNING] torch.compile not available or failed: {e}")
        return model

    def load(self) -> None:
        """Download the checkpoint if missing, load it, and optionally compile it.

        No-op when a model is already loaded. Download/load errors are logged
        and re-raised.
        """
        if self.model is not None:
            return
        logger.info(f"Loading {self.name}...")
        Config.ensure_model_dir()
        model_path = os.path.join(Config.MODEL_DIR, Config.REALESRGAN_FILENAME)
        if not os.path.exists(model_path):
            logger.info(f"Downloading {Config.REALESRGAN_FILENAME}...")
            try:
                self._download(Config.REALESRGAN_URL, model_path)
                logger.info("Download complete.")
            except Exception as e:
                logger.error(f"Failed to download model: {e}")
                raise
        try:
            model = ModelLoader().load_from_file(model_path)
            model.eval()
            model.to(Config.DEVICE)
            # Publish only a fully initialised model, so a failure above does
            # not leave a half-loaded model that later calls would reuse.
            self.model = self._maybe_compile(model)
            logger.info(f"{self.name} loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model architecture: {e}")
            raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Return ``image`` upscaled 2x, loading the model on first use.

        Non-RGB inputs (RGBA, L, P, ...) are converted to RGB first, so any
        alpha channel is dropped. Extra ``kwargs`` are accepted and ignored.
        """
        if self.model is None:
            self.load()
        logger.info(f"Starting inference with {self.name}...")
        start_time = time.perf_counter()
        # The network takes 3-channel input; RGBA or 2-D grayscale arrays would
        # break the permute/forward below.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_np = np.array(image).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)
        # Optimization: Dynamic Tiling
        h, w = img_np.shape[:2]
        tile_config = select_tile_config(h, w)
        logger.info(f"Using tile config: {tile_config}")
        # Mixed precision is only used off-CPU: float16 on CUDA, bfloat16 otherwise.
        dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16
        try:
            # Autocast is disabled on CPU to avoid "PythonFallbackKernel" errors
            # seen with some ops in recent PyTorch versions.
            context = torch.autocast(device_type=Config.DEVICE, dtype=dtype) if Config.DEVICE != 'cpu' else torch.no_grad()
            with context:
                if tile_config['tile'] > 0:
                    output_tensor = manual_tile_upscale(
                        self.model,
                        img_tensor,
                        tile_size=tile_config['tile'],
                        tile_pad=tile_config['tile_pad'],
                        scale=2,
                    )
                else:
                    output_tensor = self.model(img_tensor)  # type: ignore
        except Exception as e:
            logger.warning(f"AMP/Tiling failed, falling back to standard FP32: {e}")
            # Whole-image pass. no_grad is required: without it the output may
            # carry an autograd graph, wasting memory and making .numpy() raise.
            with torch.no_grad():
                output_tensor = self.model(img_tensor)  # type: ignore
        output_np = output_tensor.detach().squeeze(0).permute(1, 2, 0).clamp(0, 1).float().cpu().numpy()
        output_np = (output_np * 255.0).round().astype(np.uint8)
        elapsed = time.perf_counter() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")
        # Benchmark info: output binary megapixels per second of inference.
        output_megapixels = (output_np.shape[0] * output_np.shape[1]) / (1024 ** 2)
        throughput = output_megapixels / max(elapsed, 1e-9)  # guard a zero timer delta
        logger.info(f"Speed: {throughput:.2f} MP/s")
        return Image.fromarray(output_np)
class SpanStrategy(UpscalerStrategy):
    """SPAN (NomosUni) x2 upscaler loaded through spandrel.

    torch.compile is only attempted on CUDA. On CPU it is skipped because the
    SPAN forward pass uses ``.data.clone()``, which breaks torch.compile/inductor.
    """

    def __init__(self):
        super().__init__()
        self.name = "SPAN (NomosUni) x2"
        # True once the torch.compile decision has been made for the currently
        # loaded model; re-evaluated on every fresh load.
        self.compiled = False

    @staticmethod
    def _download(url: str, dest: str) -> None:
        """Stream ``url`` into ``dest`` through a ``.part`` file.

        The checkpoint only appears at ``dest`` once it is complete, so an
        interrupted download is retried next time instead of being loaded as a
        truncated file.
        """
        tmp_path = dest + ".part"
        try:
            # The timeout bounds connect and per-read stalls.
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(tmp_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(tmp_path, dest)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def _maybe_compile(self, model):
        """Return ``model`` wrapped in torch.compile on CUDA, else unchanged. Never raises."""
        self.compiled = True  # one attempt per loaded model
        try:
            if Config.DEVICE == 'cuda':
                model = torch.compile(model, mode='reduce-overhead')
                logger.info("[INFO] torch.compile enabled (reduce-overhead mode)")
            elif os.name == 'nt' and Config.DEVICE == 'cpu':
                logger.info("[INFO] Skipping torch.compile on Windows CPU.")
            elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                logger.info("[INFO] Skipping torch.compile on low-core CPU.")
            else:
                # SPAN architecture uses .data.clone() in forward pass which breaks torch.compile/inductor
                logger.info("[INFO] Skipping torch.compile for SPAN (incompatible architecture).")
        except Exception as e:
            logger.warning(f"[WARNING] torch.compile not available or failed: {e}")
        return model

    def load(self) -> None:
        """Download the checkpoint if missing, load it, and optionally compile it.

        No-op when a model is already loaded. Download/load errors are logged
        and re-raised.
        """
        if self.model is not None:
            return
        logger.info(f"Loading {self.name}...")
        Config.ensure_model_dir()
        model_path = os.path.join(Config.MODEL_DIR, Config.SPAN_FILENAME)
        if not os.path.exists(model_path):
            logger.info(f"Downloading {Config.SPAN_FILENAME}...")
            try:
                self._download(Config.SPAN_URL, model_path)
                logger.info("Download complete.")
            except Exception as e:
                logger.error(f"Failed to download model: {e}")
                raise
        try:
            model = ModelLoader().load_from_file(model_path)
            model.eval()
            model.to(Config.DEVICE)
            # Publish only a fully initialised model.
            self.model = self._maybe_compile(model)
            logger.info(f"{self.name} loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model architecture: {e}")
            raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Return ``image`` upscaled 2x, loading the model on first use.

        Non-RGB inputs (RGBA, L, P, ...) are converted to RGB first, so any
        alpha channel is dropped. Extra ``kwargs`` are accepted and ignored.
        """
        if self.model is None:
            self.load()
        logger.info(f"Starting inference with {self.name}...")
        start_time = time.perf_counter()
        # The network takes 3-channel input.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_np = np.array(image).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)
        # SPAN is very efficient, but tiling still bounds memory on huge images.
        h, w = img_np.shape[:2]
        tile_config = select_tile_config(h, w)
        logger.info(f"Using tile config: {tile_config}")
        # No AMP on CPU: SPAN hit "UntypedStorage" weakref errors in inductor
        # with autocast there. float16 autocast is used on other devices.
        dtype = torch.float32 if Config.DEVICE == 'cpu' else torch.float16
        try:
            context = torch.autocast(device_type=Config.DEVICE, dtype=dtype) if Config.DEVICE != 'cpu' else torch.no_grad()
            with context:
                if tile_config['tile'] > 0:
                    output_tensor = manual_tile_upscale(
                        self.model,
                        img_tensor,
                        tile_size=tile_config['tile'],
                        tile_pad=tile_config['tile_pad'],
                        scale=2,
                    )
                else:
                    output_tensor = self.model(img_tensor)  # type: ignore
        except Exception as e:
            logger.warning(f"AMP/Tiling failed, falling back: {e}")
            # Whole-image pass without autograd; otherwise .numpy() below may raise.
            with torch.no_grad():
                output_tensor = self.model(img_tensor)  # type: ignore
        output_np = output_tensor.detach().squeeze(0).permute(1, 2, 0).clamp(0, 1).float().cpu().numpy()
        output_np = (output_np * 255.0).round().astype(np.uint8)
        elapsed = time.perf_counter() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")
        output_megapixels = (output_np.shape[0] * output_np.shape[1]) / (1024 ** 2)
        logger.info(f"Speed: {output_megapixels / max(elapsed, 1e-9):.2f} MP/s")
        return Image.fromarray(output_np)
class HatsStrategy(UpscalerStrategy):
    """HAT-S x4 upscaler loaded through spandrel.

    torch.compile is only attempted on CUDA. On CPU it is skipped because HAT
    triggered "UntypedStorage" weakref errors with inductor there.
    """

    def __init__(self):
        super().__init__()
        self.name = "HAT-S x4"
        # True once the torch.compile decision has been made for the currently
        # loaded model; re-evaluated on every fresh load.
        self.compiled = False

    @staticmethod
    def _download(url: str, dest: str) -> None:
        """Stream ``url`` into ``dest`` through a ``.part`` file.

        The checkpoint only appears at ``dest`` once it is complete, so an
        interrupted download is retried next time instead of being loaded as a
        truncated file.
        """
        tmp_path = dest + ".part"
        try:
            # The timeout bounds connect and per-read stalls.
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(tmp_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(tmp_path, dest)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def _maybe_compile(self, model):
        """Return ``model`` wrapped in torch.compile on CUDA, else unchanged. Never raises."""
        self.compiled = True  # one attempt per loaded model
        try:
            if Config.DEVICE == 'cuda':
                model = torch.compile(model, mode='reduce-overhead')
                logger.info("[INFO] torch.compile enabled (reduce-overhead mode)")
            elif os.name == 'nt' and Config.DEVICE == 'cpu':
                logger.info("[INFO] Skipping torch.compile on Windows CPU.")
            elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                logger.info("[INFO] Skipping torch.compile on low-core CPU.")
            else:
                # HAT architecture also triggers "UntypedStorage" weakref errors with inductor on CPU
                logger.info("[INFO] Skipping torch.compile for HAT-S (incompatible architecture).")
        except Exception as e:
            logger.warning(f"[WARNING] torch.compile not available or failed: {e}")
        return model

    def load(self) -> None:
        """Download the checkpoint if missing, load it, and optionally compile it.

        No-op when a model is already loaded. Download/load errors are logged
        and re-raised.
        """
        if self.model is not None:
            return
        logger.info(f"Loading {self.name}...")
        Config.ensure_model_dir()
        model_path = os.path.join(Config.MODEL_DIR, Config.HATS_FILENAME)
        if not os.path.exists(model_path):
            logger.info(f"Downloading {Config.HATS_FILENAME}...")
            try:
                self._download(Config.HATS_URL, model_path)
                logger.info("Download complete.")
            except Exception as e:
                logger.error(f"Failed to download model: {e}")
                raise
        try:
            model = ModelLoader().load_from_file(model_path)
            model.eval()
            model.to(Config.DEVICE)
            # Publish only a fully initialised model.
            self.model = self._maybe_compile(model)
            logger.info(f"{self.name} loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model architecture: {e}")
            raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Return ``image`` upscaled 4x, loading the model on first use.

        Non-RGB inputs (RGBA, L, P, ...) are converted to RGB first, so any
        alpha channel is dropped. Extra ``kwargs`` are accepted and ignored.
        """
        if self.model is None:
            self.load()
        logger.info(f"Starting inference with {self.name}...")
        start_time = time.perf_counter()
        # The network takes 3-channel input.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_np = np.array(image).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)
        h, w = img_np.shape[:2]
        tile_config = select_tile_config(h, w)
        logger.info(f"Using tile config: {tile_config}")
        # float16 autocast on CUDA; no autocast on CPU.
        dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.float32
        try:
            context = torch.autocast(device_type=Config.DEVICE, dtype=dtype) if Config.DEVICE != 'cpu' else torch.no_grad()
            with context:
                if tile_config['tile'] > 0:
                    output_tensor = manual_tile_upscale(
                        self.model,
                        img_tensor,
                        tile_size=tile_config['tile'],
                        tile_pad=tile_config['tile_pad'],
                        scale=4,  # HAT-S is x4
                    )
                else:
                    output_tensor = self.model(img_tensor)  # type: ignore
        except Exception as e:
            logger.warning(f"AMP/Tiling failed, falling back: {e}")
            # Whole-image pass without autograd; otherwise .numpy() below may raise.
            with torch.no_grad():
                output_tensor = self.model(img_tensor)  # type: ignore
        output_np = output_tensor.detach().squeeze(0).permute(1, 2, 0).clamp(0, 1).float().cpu().numpy()
        output_np = (output_np * 255.0).round().astype(np.uint8)
        elapsed = time.perf_counter() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")
        output_megapixels = (output_np.shape[0] * output_np.shape[1]) / (1024 ** 2)
        logger.info(f"Speed: {output_megapixels / max(elapsed, 1e-9):.2f} MP/s")
        return Image.fromarray(output_np)
# --- Model Manager (Singleton-ish) ---
class UpscalerManager:
    """Registry of upscaling strategies that keeps at most one model active.

    The registry keys are the display names offered in the UI dropdown.
    """

    def __init__(self):
        span = SpanStrategy()
        esrgan = RealESRGANStrategy()
        hat = HatsStrategy()
        self.strategies: Dict[str, UpscalerStrategy] = {
            "SPAN (NomosUni) x2": span,
            "RealESRGAN x2": esrgan,
            "HAT-S x4": hat,
        }
        # Name of the strategy handed out most recently, or None before first use.
        self.current_model_name: Optional[str] = None

    def get_strategy(self, name: str) -> UpscalerStrategy:
        """Return the strategy registered as ``name``.

        When ``name`` differs from the previously requested strategy, the
        previous one is unloaded first to keep memory within free-tier limits
        (16GB RAM).

        Raises:
            ValueError: if ``name`` is not registered.
        """
        selected = self.strategies.get(name)
        if selected is None:
            raise ValueError(f"Model {name} not found.")
        previous = self.current_model_name
        if previous == name:
            return selected
        if previous is not None:
            logger.info(f"Switching models: Unloading {previous}...")
            self.strategies[previous].unload()
        self.current_model_name = name
        return selected

    def unload_all(self):
        """Unload every registered model, then force a garbage-collection pass."""
        for registered in self.strategies.values():
            registered.unload()
        gc.collect()
        logger.info("All models unloaded.")


manager = UpscalerManager()
# --- Gradio Interface Logic ---
def process_image(input_img: Image.Image, model_name: str, output_format: str) -> Tuple[Optional[str], str, str]:
    """Gradio handler: upscale ``input_img`` with ``model_name`` and save it.

    Args:
        input_img: Uploaded image, or None when nothing was uploaded.
        model_name: Key of the strategy registered in ``manager``.
        output_format: PIL format name (e.g. "PNG", "JPEG", "WEBP"); its lowercase
            form is used as the file extension.

    Returns:
        (path of the saved image or None, captured log text, system usage string).
        Errors are logged with their traceback and surfaced through the log
        text instead of being raised.
    """
    if input_img is None:
        return None, get_logs(), get_system_usage()
    result_path: Optional[str] = None
    try:
        strategy = manager.get_strategy(model_name)
        output_img = strategy.upscale(input_img)
        fmt = output_format.lower()
        # Save to a file whose extension matches the requested format.
        output_path = f"output.{fmt}"
        # JPEG cannot store alpha or palette images; convert anything that is
        # not already RGB or grayscale.
        if fmt in ('jpeg', 'jpg') and output_img.mode not in ('RGB', 'L'):
            output_img = output_img.convert('RGB')
        output_img.save(output_path, format=output_format)
        del output_img  # drop the large image before the GC pass below
        result_path = output_path
    except Exception as e:
        # Logged once; get_logs() below already includes it, so it is not
        # appended a second time.
        logger.error(f"Critical Error: {str(e)}\n{traceback.format_exc()}")
    # Explicit GC after heavy operations, on success and failure alike, before
    # sampling memory usage for the status panel.
    gc.collect()
    return result_path, get_logs(), get_system_usage()
def unload_models():
    """Gradio handler: drop every loaded model, then report logs and resource usage."""
    manager.unload_all()
    logs = get_logs()
    usage = get_system_usage()
    return logs, usage
# --- UI Construction ---
# Markdown shown at the top of the Gradio page (model overview and credits).
desc = """
# Universal Upscaler Pro (CPU Optimized)
This application provides state-of-the-art (SOTA) image upscaling running entirely on CPU, optimized for free-tier cloud environments.
### Available Models
| Model | Scale | Best For | License |
| :--- | :--- | :--- | :--- |
| **SPAN (NomosUni)** | x2 | **Speed & General Use**. Extremely fast, parameter-free attention network. | Apache 2.0 |
| **RealESRGAN** | x2 | **Robustness**. Excellent at removing JPEG artifacts and noise. | BSD 3-Clause |
| **HAT-S** | x4 | **Texture Detail**. Hybrid Attention Transformer for high-fidelity restoration. | MIT |
### Attributions & Credits
* **Real-ESRGAN**: [Wang et al., 2021](https://github.com/xinntao/Real-ESRGAN). *Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data*.
* **SPAN**: [Wan et al., 2023](https://github.com/hongyuanyu/SPAN). *Swift Parameter-free Attention Network for Efficient Super-Resolution*.
* **HAT**: [Chen et al., 2023](https://github.com/XPixelGroup/HAT). *Activating More Pixels in Image Super-Resolution Transformer*.
* **NomosUni**: Custom SPAN training by [Phhofm](https://github.com/Phhofm).
* **HAT-S weights**: 4xNomos8kSCHAT-S from [Phips on Hugging Face](https://huggingface.co/Phips/4xNomos8kSCHAT-S).
"""
with gr.Blocks(title="Universal Upscaler Pro") as iface:
    gr.Markdown(desc)
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            input_image = gr.Image(type="pil", label="Input Image", height=400)
            with gr.Row():
                model_selector = gr.Dropdown(
                    choices=list(manager.strategies.keys()),
                    value="SPAN (NomosUni) x2",
                    label="Model Architecture",
                    scale=2,
                )
                output_format = gr.Dropdown(
                    choices=["PNG", "JPEG", "WEBP"],
                    value="PNG",
                    label="Output Format",
                    scale=1,
                )
            submit_btn = gr.Button("Upscale Image", variant="primary", size="lg")
            with gr.Accordion("Advanced Settings", open=False):
                unload_btn = gr.Button("Unload All Models (Free RAM)", variant="secondary")
                # Status snapshot taken at build time; refreshed by the handlers below.
                system_info = gr.Label(value=get_system_usage(), label="System Status")
        with gr.Column(scale=1, min_width=300):
            output_image = gr.Image(type="filepath", label="Upscaled Result", height=400)
            logs_output = gr.TextArea(label="Execution Logs", interactive=False, lines=8)
    # Event Wiring (listeners must be attached inside the Blocks context)
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, model_selector, output_format],
        outputs=[output_image, logs_output, system_info],
    )
    unload_btn.click(
        fn=unload_models,
        inputs=[],
        outputs=[logs_output, system_info],
    )

# Only start the server when run as a script (e.g. `python app.py`), so that
# importing this module does not launch Gradio.
if __name__ == "__main__":
    iface.launch()