multimodalart (HF Staff) committed
Commit 989c44e · verified · 1 Parent(s): 06529b5

Add prompt upsampling
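
This wires a VLM-based prompt rewriter into the Space: before text encoding, the user prompt can be expanded by baidu/ERNIE-4.5-VL-424B-A47B-Base-PT via huggingface_hub.InferenceClient, and the GPU-bound work moves into a dedicated @spaces.GPU function. A rough outline of the new infer() flow, using the names introduced in the diff below (illustrative sketch, not a verbatim excerpt of app.py):

    # Outline of the new infer() flow (names from the diff below; progress callbacks omitted)
    final_prompt = prompt
    if prompt_upsampling:                                         # new "Prompt Upsampling" checkbox
        final_prompt = upsample_prompt_logic(prompt, image_list)  # VLM rewrite, network-bound
    prompt_embeds = remote_text_encoder(final_prompt)             # remote encode, returns CPU tensors
    image = generate_image(prompt_embeds, image_list, width, height,
                           num_inference_steps, guidance_scale,
                           seed, force_dimensions)                # ZeroGPU section, moves embeds to GPU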

Files changed (1)
  1. app.py +150 -25
app.py CHANGED
@@ -13,6 +13,8 @@ from optimization import optimize_pipeline_
 import requests
 from PIL import Image
 import json
+import base64
+from huggingface_hub import InferenceClient

 dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -20,6 +22,34 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024

+# Setup VLM Client
+hf_client = InferenceClient(
+    api_key=os.environ.get("HF_TOKEN"),
+)
+VLM_MODEL = "baidu/ERNIE-4.5-VL-424B-A47B-Base-PT"
+
+SYSTEM_PROMPT_TEXT_ONLY = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent.
+
+Guidelines:
+1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed paragraphs.
+2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), shadows, spatial relationships, and environmental context.
+3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates gibberish.
+
+Output only the revised prompt and nothing else."""
+
+SYSTEM_PROMPT_WITH_IMAGES = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests).
+
+Rules:
+- Single instruction only, no commentary
+- Use clear, analytical language (avoid "whimsical," "cascading," etc.)
+- Specify what changes AND what stays the same (face, lighting, composition)
+- Reference actual image elements
+- Turn negatives into positives ("don't change X" → "keep X")
+- Make abstractions concrete ("futuristic" → "glowing cyan neon, metallic panels")
+- Keep content PG-13
+
+Output only the final instruction in plain text and nothing else."""
+
 def remote_text_encoder(prompts):
     from gradio_client import Client

@@ -29,8 +59,8 @@ def remote_text_encoder(prompts):
         api_name="/encode_text"
     )

+    # Load returns a tensor, usually on CPU by default
     prompt_embeds = torch.load(result[0])
-
     return prompt_embeds

 # Load model
@@ -48,56 +78,136 @@ pipe = Flux2Pipeline.from_pretrained(
     transformer=dit,
     torch_dtype=torch.bfloat16
 )
-pipe.to("cuda")
+pipe.to(device)

 pipe.transformer.set_attention_backend("_flash_3_hub")

+# Optimization runs once at startup
 optimize_pipeline_(
     pipe,
     image=[Image.new("RGB", (1024, 1024))],
-    prompt_embeds = remote_text_encoder("prompt").to("cuda"),
+    prompt_embeds = remote_text_encoder("prompt").to(device),
     guidance_scale=2.5,
     width=1024,
     height=1024,
     num_inference_steps=1
 )

+def image_to_data_uri(img):
+    buffered = io.BytesIO()
+    img.save(buffered, format="PNG")
+    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+    return f"data:image/png;base64,{img_str}"

-def get_duration(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)):
-    num_images = 0 if input_images is None else len(input_images)
-    step_duration = 1 + 0.7 * num_images
-    return max(65, num_inference_steps * step_duration + 10)
+def upsample_prompt_logic(prompt, image_list):
+    try:
+        if image_list and len(image_list) > 0:
+            # Image + Text Editing Mode
+            system_content = SYSTEM_PROMPT_WITH_IMAGES
+
+            # Construct user message with text and images
+            user_content = [{"type": "text", "text": prompt}]
+
+            for img in image_list:
+                data_uri = image_to_data_uri(img)
+                user_content.append({
+                    "type": "image_url",
+                    "image_url": {"url": data_uri}
+                })
+
+            messages = [
+                {"role": "system", "content": system_content},
+                {"role": "user", "content": user_content}
+            ]
+        else:
+            # Text Only Mode
+            system_content = SYSTEM_PROMPT_TEXT_ONLY
+            messages = [
+                {"role": "system", "content": system_content},
+                {"role": "user", "content": prompt}
+            ]

+        completion = hf_client.chat.completions.create(
+            model=VLM_MODEL,
+            messages=messages,
+            max_tokens=1024
+        )
+
+        return completion.choices[0].message.content
+    except Exception as e:
+        print(f"Upsampling failed: {e}")
+        return prompt
+
+# Updated duration function to match generate_image arguments (including progress)
+def get_duration(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, force_dimensions, progress=gr.Progress(track_tqdm=True)):
+    num_images = 0 if image_list is None else len(image_list)
+    step_duration = 1 + 0.8 * num_images
+    return max(65, num_inference_steps * step_duration + 10)

 @spaces.GPU(duration=get_duration)
-def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=2.5, progress=gr.Progress(track_tqdm=True)):
+def generate_image(prompt_embeds, image_list, width, height, num_inference_steps, guidance_scale, seed, force_dimensions, progress=gr.Progress(track_tqdm=True)):
+    # Move embeddings to GPU only when inside the GPU decorated function
+    prompt_embeds = prompt_embeds.to(device)
+
+    generator = torch.Generator(device=device).manual_seed(seed)
+
+    pipe_kwargs = {
+        "prompt_embeds": prompt_embeds,
+        "image": image_list,
+        "num_inference_steps": num_inference_steps,
+        "guidance_scale": guidance_scale,
+        "generator": generator,
+    }
+
+    if image_list is None or force_dimensions:
+        pipe_kwargs["width"] = width
+        pipe_kwargs["height"] = height
+
+    # Progress bar for the actual generation steps
+    if progress:
+        progress(0, desc="Starting generation...")
+
+    image = pipe(**pipe_kwargs).images[0]
+    return image
+
+def infer(prompt, input_images=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=50, guidance_scale=2.5, force_dimensions=False, prompt_upsampling=False, progress=gr.Progress(track_tqdm=True)):

     if randomize_seed:
         seed = random.randint(0, MAX_SEED)

-    # Get prompt embeddings from remote text encoder
-    progress(0.1, desc="Encoding prompt...")
-    prompt_embeds = remote_text_encoder(prompt).to("cuda")
-
     # Prepare image list (convert None or empty gallery to None)
     image_list = None
     if input_images is not None and len(input_images) > 0:
         image_list = []
         for item in input_images:
             image_list.append(item[0])
+
+    # 1. Upsampling (Network bound - No GPU needed)
+    final_prompt = prompt
+    if prompt_upsampling:
+        progress(0.05, desc="Upsampling prompt...")
+        final_prompt = upsample_prompt_logic(prompt, image_list)
+        print(f"Original Prompt: {prompt}")
+        print(f"Upsampled Prompt: {final_prompt}")
+
+    # 2. Text Encoding (Network bound - No GPU needed)
+    progress(0.1, desc="Encoding prompt...")
+    # This returns CPU tensors
+    prompt_embeds = remote_text_encoder(final_prompt)

-    # Generate image
-    progress(0.3, desc="Generating image...")
-    generator = torch.Generator(device=device).manual_seed(seed)
-    image = pipe(
-        prompt_embeds=prompt_embeds,
-        image=image_list,
-        width=width,
-        height=height,
-        num_inference_steps=num_inference_steps,
-        guidance_scale=guidance_scale,
-        generator=generator,
-    ).images[0]
+    # 3. Image Generation (GPU bound)
+    progress(0.3, desc="Waiting for GPU...")
+    image = generate_image(
+        prompt_embeds,
+        image_list,
+        width,
+        height,
+        num_inference_steps,
+        guidance_scale,
+        seed,
+        force_dimensions,
+        progress
+    )

     return image, seed

@@ -118,6 +228,9 @@ css="""
     margin: 0 auto;
     max-width: 620px;
 }
+.gallery-container img{
+    object-fit: contain;
+}
 """

 with gr.Blocks() as demo:
@@ -152,6 +265,12 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co

         with gr.Accordion("Advanced Settings", open=False):

+            prompt_upsampling = gr.Checkbox(
+                label="Prompt Upsampling",
+                value=True,
+                info="Automatically enhance the prompt using a VLM"
+            )
+
             seed = gr.Slider(
                 label="Seed",
                 minimum=0,
@@ -180,6 +299,12 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
                     value=1024,
                 )

+            force_dimensions = gr.Checkbox(
+                label="Force width/height when image input",
+                value=False,
+                info="When unchecked, width/height settings are ignored if input images are provided"
+            )
+
             with gr.Row():

                 num_inference_steps = gr.Slider(
@@ -219,7 +344,7 @@ FLUX.2 [dev] is a 32B model rectified flow capable of generating, editing and co
     gr.on(
         triggers=[run_button.click, prompt.submit],
         fn=infer,
-        inputs=[prompt, input_images, seed, randomize_seed, width, height, num_inference_steps, guidance_scale],
+        inputs=[prompt, input_images, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, force_dimensions, prompt_upsampling],
         outputs=[result, seed]
     )

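
Note on the @spaces.GPU(duration=get_duration) pattern above: on ZeroGPU Spaces, duration may be a callable rather than a fixed number of seconds; it is called with the same arguments as the decorated function and returns the GPU time to request, which is why get_duration mirrors the full generate_image signature (including progress). A minimal sketch under that assumption, with hypothetical names (estimate_seconds, run_on_gpu):

    import spaces

    def estimate_seconds(image_list, num_inference_steps, progress=None):
        # Mirrors this commit's estimate: each reference image adds ~0.8 s per step
        num_images = 0 if image_list is None else len(image_list)
        return max(65, num_inference_steps * (1 + 0.8 * num_images) + 10)

    @spaces.GPU(duration=estimate_seconds)  # called with the same args as run_on_gpu
    def run_on_gpu(image_list, num_inference_steps, progress=None):
        ...  # GPU-bound pipeline call goes here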