import numpy as np
import gradio as gr
import spaces
import os
import random
import torch
from PIL import Image
import cv2
from huggingface_hub import login
from diffusers import FluxControlNetPipeline, FluxControlNetModel
from diffusers.models import FluxMultiControlNetModel
| """ | |
| FLUX‑1 ControlNet demo | |
| ---------------------- | |
| This script rebuilds the Gradio interface shown in your screenshot with **one** control‑image upload | |
| slot and integrates the FLUX.1‑dev‑ControlNet‑Union‑Pro model. | |
| Key points | |
| ~~~~~~~~~~ | |
| * Single *control image* input (left). | |
| * *Result* and *Pre‑processed Cond* previews side‑by‑side (center & right). | |
| * *Prompt* textbox plus a dedicated **ControlNet** panel for choosing the mode and strength. | |
| * Seed handling with optional randomisation. | |
| * Advanced sliders for *Guidance scale* and *Inference steps*. | |
| * Works on CUDA (bfloat16) or CPU (float32). | |
| * Minimal Canny preview implementation when the *canny* mode is selected (extend as you like for the | |
| other modes). | |
| Before running, set the `HUGGINGFACE_TOKEN` environment variable **or** call | |
| `login("<YOUR_HF_TOKEN>")` explicitly. | |
| """ | |

# --------------------------------------------------
# Model & pipeline setup
# --------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN_NEW")
if HF_TOKEN:  # login(None) would fall back to an interactive prompt and fail headless
    login(HF_TOKEN)
# If you prefer to hard-code the token, uncomment:
# login("hf_your_token_here")

BASE_MODEL = "black-forest-labs/FLUX.1-dev"
CONTROLNET_MODEL = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro"

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

controlnet_single = FluxControlNetModel.from_pretrained(
    CONTROLNET_MODEL, torch_dtype=dtype
)
# Wrap in a multi-ControlNet container so the pipeline accepts list-valued
# control images, modes, and conditioning scales.
controlnet = FluxMultiControlNetModel([controlnet_single])
pipe = FluxControlNetPipeline.from_pretrained(
    BASE_MODEL, controlnet=controlnet, torch_dtype=dtype
).to(device)
pipe.set_progress_bar_config(disable=True)
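
# Optional (hedged): on smaller GPUs you can trade speed for memory by
# offloading submodules to CPU between forward passes. This is a standard
# diffusers API but replaces the .to(device) call above, so it is left
# commented out here.
# pipe.enable_model_cpu_offload()  # requires `accelerate`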

# --------------------------------------------------
# UI -> model value mapping
# --------------------------------------------------
# Mode indices follow the FLUX.1-dev-ControlNet-Union-Pro model card:
# 0 canny, 1 tile, 2 depth, 3 blur, 4 pose, 5 gray, 6 low quality.
MODE_MAPPING = {
    "canny": 0,
    "tile": 1,
    "depth": 2,
    "blur": 3,
    "openpose": 4,
    "gray": 5,
    "low quality": 6,
}

MAX_SEED = 100

# --------------------------------------------------
# Helper: quick-n-dirty Canny preview (only for UI display)
# --------------------------------------------------
def _preview_canny(pil_img: Image.Image) -> Image.Image:
    # Canny expects an 8-bit single-channel image, so convert to grayscale first.
    gray = cv2.cvtColor(np.array(pil_img.convert("RGB")), cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, 100, 200)
    edges_rgb = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(edges_rgb)

def _make_preview(control_image: Image.Image, mode: str) -> Image.Image:
    if mode == "canny":
        return _preview_canny(control_image)
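    # Hedged sketch: a preview for the *blur* mode. Approximating the
    # conditioning input with a Gaussian blur (kernel size 25) is an
    # assumption for display purposes, not taken from the model card.
    if mode == "blur":
        arr = np.array(control_image.convert("RGB"))
        return Image.fromarray(cv2.GaussianBlur(arr, (25, 25), 0))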
    # For the other modes you can plug in your own visualiser later
    return control_image

# --------------------------------------------------
# Inference function
# --------------------------------------------------
@spaces.GPU  # the unused `spaces` import implies a ZeroGPU Space; this is a no-op elsewhere
def infer(
    control_image: Image.Image,
    prompt: str,
    mode: str,
    control_strength: float,
    seed: int,
    randomize_seed: bool,
    guidance_scale: float,
    num_inference_steps: int,
):
    if control_image is None:
        raise gr.Error("Please upload a control image first.")
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    gen = torch.Generator(device).manual_seed(seed)
    w, h = control_image.size
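    # FLUX packs latents 2x2 on top of the 8x VAE downsampling, so width and
    # height should be multiples of 16; round down (assumes images >= 16 px).
    w, h = (w // 16) * 16, (h // 16) * 16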
    result = pipe(
        prompt=prompt,
        control_image=[control_image],
        control_mode=[MODE_MAPPING[mode]],
        width=w,
        height=h,
        controlnet_conditioning_scale=[control_strength],
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=gen,
    ).images[0]
    preview = _make_preview(control_image, mode)
    return result, seed, preview
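
# Hedged smoke test: calling infer directly, outside Gradio. "canny.png" is an
# assumed local file; strength 0.65 mirrors the model card's canny suggestion.
# out, used_seed, cond = infer(Image.open("canny.png"), "best quality",
#                              "canny", 0.65, 42, False, 3.5, 24)
# out.save("result.png")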

# --------------------------------------------------
# Gradio UI
# --------------------------------------------------
# gr.Blocks() takes no elem_id argument, so style Gradio's root container instead.
css = """.gradio-container {max-width: 960px; margin: 0 auto;}"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("## FLUX.1-dev-ControlNet-Union-Pro")
    gr.Markdown(
        "A unified ControlNet for **FLUX.1-dev** from the InstantX team and Shakker Labs. "
        + "Recommended strengths: *canny 0.65*, *tile 0.45*, *depth 0.55*, *blur 0.45*, "
        + "*openpose 0.55*, *gray 0.45*, *low quality 0.40*. Long prompts usually help."
    )

    # ------------ Image panel row ------------
    with gr.Row():
        control_image = gr.Image(
            label="Upload a processed control image",
            type="pil",
            height=512,
        )
        result_image = gr.Image(label="Result", height=512)
        preview_image = gr.Image(label="Pre-processed Cond", height=512)

    # ------------ Prompt ------------
    prompt_txt = gr.Textbox(label="Prompt", value="best quality", lines=1)

    # ------------ ControlNet settings ------------
    with gr.Row():
        with gr.Column():
            gr.Markdown("### ControlNet")
            mode_radio = gr.Radio(
                choices=list(MODE_MAPPING.keys()), value="gray", label="Mode"
            )
            strength_slider = gr.Slider(
                0.0, 1.0, value=0.5, step=0.01, label="Control strength"
            )
        with gr.Column():
            seed_slider = gr.Slider(0, MAX_SEED, step=1, value=42, label="Seed")
            randomize_chk = gr.Checkbox(label="Randomize seed", value=True)
            guidance_slider = gr.Slider(
                0.0, 10.0, step=0.1, value=3.5, label="Guidance scale"
            )
            steps_slider = gr.Slider(1, 50, step=1, value=24, label="Inference steps")

    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=infer,
        inputs=[
            control_image,
            prompt_txt,
            mode_radio,
            strength_slider,
            seed_slider,
            randomize_chk,
            guidance_slider,
            steps_slider,
        ],
        outputs=[result_image, seed_slider, preview_image],
    )
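
# Optional (hedged): queuing keeps long generations from hitting request
# timeouts and is generally recommended for GPU Spaces.
# demo.queue()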

if __name__ == "__main__":
    demo.launch()