"""Gradio app that compiles the QwenImage transformer and pushes it to the Hub.

At import time the script loads the ``Qwen/Qwen-Image`` pipeline and installs
the FA3 attention processor on its transformer. The UI lets a logged-in user
pick a ``repo_id`` and a ``.pt2`` filename; clicking the button compiles the
transformer ahead of time and uploads the resulting archive to that repo.
"""
import gradio as gr
import spaces
import torch
from diffusers import QwenImagePipeline
from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
from optimization import compile_transformer
from hub_utils import _push_compiled_graph_to_hub
from huggingface_hub import whoami

# --- Model Loading ---
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model pipeline
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=dtype).to(device)
pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())


@spaces.GPU(duration=120)
def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
    """Compile the QwenImage transformer and upload the archive to the Hub.

    Args:
        repo_id: Target Hub repository, e.g. ``"user/qwen-aot"``.
        filename: Path of the uploaded file inside the repo; must end in
            ``.pt2``.
        oauth_token: OAuth token of the logged-in user, injected by Gradio.

    Returns:
        A Markdown link to the commit when the upload helper returns an object
        exposing ``commit_url``; otherwise whatever the helper returned.

    Raises:
        NotImplementedError: If ``filename`` does not end with ``.pt2``.
        gr.Error: If the token cannot be validated, or if compilation or the
            upload fails.
    """
    if not filename.endswith(".pt2"):
        raise NotImplementedError("The filename must end with a `.pt2` extension.")

    # Validate the token on its own so that only authentication problems are
    # reported as "forgot to login". Reading `.token` stays inside the `try`
    # so a missing token is reported the same way.
    try:
        # this will throw if token is invalid
        whoami(oauth_token.token)
    except Exception as e:
        raise gr.Error(
            "Oops, you forgot to login. Please use the login button on the "
            f"top left to migrate your repo {e}"
        ) from e

    # Compilation and upload failures are unrelated to login, so they get
    # their own message that carries the underlying cause.
    try:
        # --- Ahead-of-time compilation ---
        compiled_transformer = compile_transformer(pipe, prompt="prompt")
        out = _push_compiled_graph_to_hub(
            compiled_transformer.archive_file,
            repo_id=repo_id,
            token=oauth_token.token,
            path_in_repo=filename,
        )
    except gr.Error:
        # Already user-facing; do not re-wrap.
        raise
    except Exception as e:
        raise gr.Error(f"Failed to compile or push the graph: {e}") from e

    # The helper may return a plain string or an object carrying the commit
    # URL; render the latter as a clickable Markdown link.
    if not isinstance(out, str) and hasattr(out, "commit_url"):
        commit_url = out.commit_url
        return f"[{commit_url}]({commit_url})"
    return out


css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

# Main UI: two text inputs, a button, and a Markdown area for the result.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## Compile QwenImage graph ahead of time & push to the Hub")
        gr.Markdown("Enter a **repo_id** and **filename**. This repo automatically compiles the [QwenImage](https://hf.co/Qwen/Qwen-Image) model on start.")

        repo_id = gr.Textbox(label="repo_id", placeholder="e.g. sayakpaul/qwen-aot")
        filename = gr.Textbox(label="filename", placeholder="e.g. compiled.pt2")
        run = gr.Button("Push graph to Hub", variant="primary")

        markdown_out = gr.Markdown()

    # `oauth_token` is not listed in `inputs`; it is supplied through the
    # `gr.OAuthToken` annotation on `push_to_hub`.
    run.click(push_to_hub, inputs=[repo_id, filename], outputs=[markdown_out])


def swap_visibilty(profile: gr.OAuthProfile | None):
    """Return an update that sets the wrapper's CSS class from login state.

    A truthy ``profile`` selects ``main_ui_logged_in``; otherwise
    ``main_ui_logged_out`` is applied, which ``css_login`` styles as dimmed
    and non-interactive.
    """
    if profile:
        return gr.update(elem_classes=["main_ui_logged_in"])
    return gr.update(elem_classes=["main_ui_logged_out"])


css_login = '''
.main_ui_logged_out{opacity: 0.3; pointer-events: none; margin: 0 auto; max-width: 520px}
'''

# Outer app: login button plus the main UI rendered inside a column whose
# class is switched by `swap_visibilty` on page load.
with gr.Blocks(css=css_login) as demo_login:
    gr.LoginButton()
    with gr.Column(elem_classes="main_ui_logged_out") as main_ui:
        demo.render()
    demo_login.load(fn=swap_visibilty, outputs=main_ui)

demo_login.queue()
demo_login.launch()