import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderTiny
# For Image-to-Image, you would also import:
# from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image
import os  # For better logging/debugging
from typing import Literal  # For type hinting the gender choices
# --- Configuration ---
# 1. Force CPU usage for compatibility on machines without a GPU
device = "cpu"

# 2. Choose a smaller/distilled Stable Diffusion model for CPU speed.
#    'nota-ai/bk-sdm-small' offers a good balance of size, speed, and reasonable quality on CPU.
#    If higher quality is essential and you can tolerate much longer generation times,
#    you might consider 'runwayml/stable-diffusion-v1-5', but be prepared for significant slowdowns
#    and higher memory consumption that might require `enable_sequential_cpu_offload()`.
model_id = "nota-ai/bk-sdm-small"

# 3. Tiny VAE for drastically faster encoding/decoding on CPU. This is a crucial optimization.
tiny_vae_id = "sayakpaul/taesd-diffusers"
# --- Model Loading ---
# Load the pipeline globally when the application starts to avoid reloading on each request.
print(f"[{os.getpid()}] Loading model: {model_id} on {device}...")
try:
    # Use StableDiffusionPipeline for Text-to-Image generation (generates a new person in a style).
    # To transform an uploaded image instead (Image-to-Image), uncomment the line below
    # and replace `StableDiffusionPipeline` with `StableDiffusionImg2ImgPipeline`.
    pipe_class = StableDiffusionPipeline
    # pipe_class = StableDiffusionImg2ImgPipeline  # Uncomment for Image-to-Image functionality

    pipe = pipe_class.from_pretrained(
        model_id,
        torch_dtype=torch.float32,  # CPU usually performs best with float32
        low_cpu_mem_usage=True,     # Helps reduce peak memory usage during loading
        safety_checker=None         # Disable safety checker to save CPU cycles and memory
    )
    print(f"[{os.getpid()}] Main pipeline loaded.")

    # Load and assign the Tiny VAE to speed up the VAE encode/decode step significantly.
    print(f"[{os.getpid()}] Loading Tiny VAE from {tiny_vae_id}...")
    try:
        pipe.vae = AutoencoderTiny.from_pretrained(tiny_vae_id, torch_dtype=torch.float32)
        print(f"[{os.getpid()}] Tiny VAE loaded successfully.")
    except Exception as vae_e:
        print(f"[{os.getpid()}] Warning: Could not load Tiny VAE '{tiny_vae_id}': {vae_e}. Using default VAE (this will be slower).")
        # Ensure the default VAE is explicitly moved to CPU if the Tiny VAE fails to load
        pipe.vae.to(device)

    # Move the entire pipeline to the CPU explicitly
    pipe.to(device)

    # Set up the scheduler. DDIMScheduler is a good general-purpose choice.
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    # Optional: enable CPU offload if you encounter out-of-memory errors, especially with
    # larger models. Be aware that this makes generation significantly slower.
    # pipe.enable_sequential_cpu_offload()
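    # Another knob worth trying (an assumption, not part of the original setup): attention
    # slicing trades a little speed for lower peak memory during the UNet forward pass,
    # which can help on small CPU instances without the heavy cost of full CPU offload.
    # pipe.enable_attention_slicing()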
    print(f"[{os.getpid()}] Model fully loaded and configured on {device}.")
except Exception as e:
    print(f"[{os.getpid()}] FATAL ERROR: Failed to load models: {e}")
    # Raise an exception to prevent the application from starting if model loading fails
    raise RuntimeError(f"Failed to load Stable Diffusion model: {e}")
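# Optional tuning (an assumption about the host, not part of the original app): on a shared
# CPU instance it can help to pin PyTorch's thread count to the number of vCPUs actually
# available, to avoid oversubscription, e.g.:
#   torch.set_num_threads(2)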
# --- Preset Styles ---
styles = {
    "Pixar": "pixar style portrait of",
    "Anime": "anime style portrait of",
    "Cyberpunk": "cyberpunk futuristic avatar of",
    "Disney": "disney movie character of",
    "Sketch": "pencil sketch portrait of",
    "Astronaut": "realistic astronaut with helmet, portrait of"
}
# --- Generation Function ---
def generate_avatar(image_input: Image.Image, style: str, gender: Literal["male", "female", "unspecified"]):
    """
    Generates an avatar based on the chosen style and gender.

    - With StableDiffusionPipeline (Text-to-Image): the uploaded `image_input` only
      triggers generation and does NOT influence the avatar's appearance. A new person
      is generated from the text prompt.
    - With StableDiffusionImg2ImgPipeline (Image-to-Image, commented out by default):
      `image_input` would be used as the base image for the transformation.
    """
    if image_input is None:
        gr.Warning("Please upload an image to enable avatar generation. (Even though it is not used for content, it acts as a trigger.)")
        return None

    # Base prompt from the selected style
    base_prompt = styles[style]

    # Construct the subject part of the prompt based on the gender selection
    if gender == "male":
        gender_subject = "a man"
    elif gender == "female":
        gender_subject = "a woman"
    else:  # unspecified: the model will fall back on its own biases
        gender_subject = "a person"

    # Enhance the prompt for better quality and detail in text-to-image generation
    prompt = (
        f"{base_prompt} {gender_subject}, high quality, detailed, professional photography, "
        "studio lighting, volumetric lighting, 4k, cinematic, sharp focus"
    )

    # Stronger negative prompt to avoid common issues such as low quality, distortions, and unwanted artifacts
    negative_prompt = (
        "low resolution, blurry, distorted, bad quality, ugly, cartoon, sketch, duplicate, "
        "out of frame, bad anatomy, deformed, extra limbs, malformed hands, missing fingers, "
        "watermark, text, signature, low contrast, oversaturated"
    )

    # Inference parameters (tuned for a balance of speed and quality on CPU)
    num_inference_steps = 25  # 20-30 steps is a good quality/speed range on CPU
    guidance_scale = 7.5      # Higher values follow the prompt more closely but can reduce diversity
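    # Optional (an assumption, not in the original flow): for reproducible results you can
    # create a seeded torch.Generator and pass it to the pipeline call below via the
    # `generator` keyword, e.g.:
    #   generator = torch.Generator(device="cpu").manual_seed(42)
    #   pipe(..., generator=generator)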
    print(f"[{os.getpid()}] Generating for style: '{style}', gender: '{gender}', prompt: '{prompt}' "
          f"(Steps: {num_inference_steps}, Guidance: {guidance_scale})")

    try:
        # Disable gradient tracking during inference; this saves memory and speeds up computation.
        with torch.no_grad():  # On PyTorch >= 1.9, torch.inference_mode() is also an option
            if isinstance(pipe, StableDiffusionPipeline):
                # Text-to-Image generation: image_input is ignored for content
                generated_image = pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    height=512,  # Stable Diffusion 1.x models are usually trained at 512x512
                    width=512
                ).images[0]
            # elif isinstance(pipe, StableDiffusionImg2ImgPipeline):
            #     # Image-to-Image generation: uncomment this block if you switch to the Img2Img pipeline.
            #     # The 'strength' parameter controls how much noise is added to the input image:
            #     # 0.0 means no change, 1.0 means a complete re-imagining (like text-to-image).
            #     # A value around 0.7-0.8 is typical for style transfer.
            #     strength = 0.75
            #     generated_image = pipe(
            #         prompt=prompt,
            #         image=image_input,  # Pass the uploaded image here for img2img
            #         negative_prompt=negative_prompt,
            #         num_inference_steps=num_inference_steps,
            #         guidance_scale=guidance_scale,
            #         strength=strength
            #     ).images[0]
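            # Note (an assumption, relevant only if you enable the Img2Img branch above):
            # img2img pipelines work best when the input matches the model's training
            # resolution, so you would typically normalize the upload first, e.g.:
            #   image_input = image_input.convert("RGB").resize((512, 512))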
            else:
                raise ValueError("Unsupported pipeline type. Please check model loading.")

        print(f"[{os.getpid()}] Image generation complete.")
        return generated_image
    except Exception as e:
        print(f"[{os.getpid()}] Error during image generation: {e}")
        # Surface the error in the Gradio interface (gr.Error must be raised, not just called)
        raise gr.Error(f"An error occurred during image generation: {e}")
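# Optional local smoke test (an assumption, not part of the original app): calling the
# function directly is a quick way to exercise the pipeline without starting the UI.
# A blank RGB image works as the trigger input in text-to-image mode, e.g.:
#   test_img = generate_avatar(Image.new("RGB", (512, 512)), "Anime", "female")
#   test_img.save("test_avatar.png")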
# --- Gradio Interface Definition ---
with gr.Blocks() as demo:
    gr.Markdown("## 🎨 Stable Diffusion Avatar Generator with Preset Styles (CPU Optimized)")
    gr.Markdown(
        "This demo uses a smaller, distilled Stable Diffusion model and is optimized for CPU inference. "
        "Generation still takes longer on CPU than on GPU (roughly 20-60 seconds per image, depending on the CPU and parameters).<br>"
        "**Note:** The uploaded image is currently used only to trigger generation and does **not influence the avatar's appearance**. "
        "It is here for your reference or for potential future Image-to-Image features. You will get a new person in the chosen style."
    )
    with gr.Row():
        with gr.Column():
            # Image input component. type="pil" ensures a PIL Image object is passed to the function.
            image_input = gr.Image(
                label="Upload your photo",
                type="pil",
                sources=["upload", "webcam"],  # Allow file upload or webcam capture
                # Optional: set a default visual by pointing `value` at a placeholder image
                # value="assets/placeholder.jpg"
            )
            style_selector = gr.Radio(
                choices=list(styles.keys()),
                label="Choose a style",
                value="Anime",  # Default selected style
                info="Select the artistic style for your avatar."
            )
            gender_selector = gr.Radio(
                choices=["male", "female", "unspecified"],
                label="Choose a gender",
                value="male",  # Default selected gender
                info="Explicitly set the gender of the generated person. 'Unspecified' may lead to biased results from the model."
            )
            generate_btn = gr.Button("Generate Avatar", variant="primary")
        with gr.Column():
            output_image = gr.Image(label="Generated Avatar")

    # Connect the button click to the generation function, passing all inputs
    generate_btn.click(
        fn=generate_avatar,
        inputs=[image_input, style_selector, gender_selector],
        outputs=output_image
    )

    # Optional: add examples for quick testing
    gr.Examples(
        examples=[
            # Example format: [image_path_or_None, style_name, gender]
            # Use None for image_input, since it does not influence the output in text-to-image mode
            [None, "Pixar", "male"],
            [None, "Anime", "female"],
            [None, "Cyberpunk", "unspecified"],  # Shows what 'unspecified' might produce
            [None, "Disney", "male"],
            [None, "Sketch", "female"],
            [None, "Astronaut", "male"]
        ],
        inputs=[image_input, style_selector, gender_selector],
        # fn=generate_avatar,   # Uncomment if you want examples to run the generation live
        # outputs=output_image,
        cache_examples=False,  # Keep False for live generation; True would precompute and store example outputs
        label="Quick Examples (generates new images each time)"
    )
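# Worth considering on CPU hardware (an assumption, not part of the original app): with
# generations taking tens of seconds, Gradio's built-in queue keeps concurrent requests
# from piling up on the worker, e.g.:
#   demo.queue(max_size=8)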
# Launch the Gradio application.
# share=True would generate a temporary public link (useful for sharing demos)
# auth=("username", "password") adds basic authentication
demo.launch(inbrowser=True, show_error=True)