Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -72,6 +72,8 @@ def create_training_demo(trainer: Trainer,
|
|
| 72 |
concept_images = gr.Files(label='Images for your concept')
|
| 73 |
concept_prompt = gr.Textbox(label='Concept Prompt',
|
| 74 |
max_lines=1)
|
|
|
|
|
|
|
| 75 |
gr.Markdown('''
|
| 76 |
- Upload images of the style you are planning on training on.
|
| 77 |
- For a concept prompt, use a unique, made up word to avoid collisions.
|
|
@@ -80,11 +82,13 @@ def create_training_demo(trainer: Trainer,
|
|
| 80 |
gr.Markdown('Training Parameters')
|
| 81 |
num_training_steps = gr.Number(
|
| 82 |
label='Number of Training Steps', value=1000, precision=0)
|
| 83 |
-
learning_rate = gr.Number(label='Learning Rate', value=0.
|
| 84 |
train_text_encoder = gr.Checkbox(label='Train Text Encoder',
|
| 85 |
value=True)
|
|
|
|
|
|
|
| 86 |
learning_rate_text = gr.Number(
|
| 87 |
-
label='Learning Rate for Text Encoder', value=0.
|
| 88 |
gradient_accumulation = gr.Number(
|
| 89 |
label='Number of Gradient Accumulation',
|
| 90 |
value=1,
|
|
@@ -145,7 +149,7 @@ def find_weight_files() -> list[str]:
|
|
| 145 |
return [path.relative_to(curr_dir).as_posix() for path in paths]
|
| 146 |
|
| 147 |
|
| 148 |
-
def
|
| 149 |
return gr.update(choices=find_weight_files())
|
| 150 |
|
| 151 |
|
|
@@ -159,23 +163,13 @@ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
|
|
| 159 |
label='Base Model',
|
| 160 |
visible=False)
|
| 161 |
reload_button = gr.Button('Reload Weight List')
|
| 162 |
-
|
| 163 |
-
value='
|
| 164 |
-
label='
|
| 165 |
prompt = gr.Textbox(
|
| 166 |
label='Prompt',
|
| 167 |
max_lines=1,
|
| 168 |
-
placeholder='Example: "
|
| 169 |
-
alpha = gr.Slider(label='Alpha',
|
| 170 |
-
minimum=0,
|
| 171 |
-
maximum=2,
|
| 172 |
-
step=0.05,
|
| 173 |
-
value=1)
|
| 174 |
-
alpha_for_text = gr.Slider(label='Alpha for Text Encoder',
|
| 175 |
-
minimum=0,
|
| 176 |
-
maximum=2,
|
| 177 |
-
step=0.05,
|
| 178 |
-
value=1)
|
| 179 |
seed = gr.Slider(label='Seed',
|
| 180 |
minimum=0,
|
| 181 |
maximum=100000,
|
|
@@ -184,52 +178,53 @@ def create_inference_demo(pipe: InferencePipeline) -> gr.Blocks:
|
|
| 184 |
with gr.Accordion('Other Parameters', open=False):
|
| 185 |
num_steps = gr.Slider(label='Number of Steps',
|
| 186 |
minimum=0,
|
| 187 |
-
maximum=
|
| 188 |
step=1,
|
| 189 |
-
value=
|
| 190 |
guidance_scale = gr.Slider(label='CFG Scale',
|
| 191 |
minimum=0,
|
| 192 |
maximum=50,
|
| 193 |
step=0.1,
|
| 194 |
-
value=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
run_button = gr.Button('Generate')
|
| 197 |
|
| 198 |
gr.Markdown('''
|
| 199 |
-
- Models with names starting with "
|
| 200 |
- After training, you can press "Reload Weight List" button to load your trained model names.
|
| 201 |
-
- The pretrained models for "disney", "illust" and "pop" are trained with the concept prompt "style of sks".
|
| 202 |
-
- The pretrained model for "kiriko" is trained with the concept prompt "game character bnha". For this model, the text encoder is also trained.
|
| 203 |
''')
|
| 204 |
with gr.Column():
|
| 205 |
result = gr.Image(label='Result')
|
| 206 |
|
| 207 |
-
reload_button.click(fn=
|
| 208 |
inputs=None,
|
| 209 |
-
outputs=
|
| 210 |
prompt.submit(fn=pipe.run,
|
| 211 |
inputs=[
|
| 212 |
base_model,
|
| 213 |
-
|
| 214 |
prompt,
|
| 215 |
-
alpha,
|
| 216 |
-
alpha_for_text,
|
| 217 |
seed,
|
| 218 |
num_steps,
|
| 219 |
guidance_scale,
|
|
|
|
| 220 |
],
|
| 221 |
outputs=result,
|
| 222 |
queue=False)
|
| 223 |
run_button.click(fn=pipe.run,
|
| 224 |
inputs=[
|
| 225 |
base_model,
|
| 226 |
-
|
| 227 |
prompt,
|
| 228 |
-
alpha,
|
| 229 |
-
alpha_for_text,
|
| 230 |
seed,
|
| 231 |
num_steps,
|
| 232 |
guidance_scale,
|
|
|
|
| 233 |
],
|
| 234 |
outputs=result,
|
| 235 |
queue=False)
|
|
|
|
| 72 |
concept_images = gr.Files(label='Images for your concept')
|
| 73 |
concept_prompt = gr.Textbox(label='Concept Prompt',
|
| 74 |
max_lines=1)
|
| 75 |
+
class_prompt = gr.Textbox(label='Regularization set Prompt',
|
| 76 |
+
max_lines=1)
|
| 77 |
gr.Markdown('''
|
| 78 |
- Upload images of the style you are planning on training on.
|
| 79 |
- For a concept prompt, use a unique, made up word to avoid collisions.
|
|
|
|
| 82 |
gr.Markdown('Training Parameters')
|
| 83 |
num_training_steps = gr.Number(
|
| 84 |
label='Number of Training Steps', value=1000, precision=0)
|
| 85 |
+
learning_rate = gr.Number(label='Learning Rate', value=0.00001)
|
| 86 |
train_text_encoder = gr.Checkbox(label='Train Text Encoder',
|
| 87 |
value=True)
|
| 88 |
+
modifier_token = gr.Checkbox(label='modifier token',
|
| 89 |
+
value=True)
|
| 90 |
learning_rate_text = gr.Number(
|
| 91 |
+
label='Learning Rate for Text Encoder', value=0.00001)
|
| 92 |
gradient_accumulation = gr.Number(
|
| 93 |
label='Number of Gradient Accumulation',
|
| 94 |
value=1,
|
|
|
|
| 149 |
return [path.relative_to(curr_dir).as_posix() for path in paths]
|
| 150 |
|
| 151 |
|
| 152 |
+
def reload_custom_diffusion_weight_list() -> dict:
|
| 153 |
return gr.update(choices=find_weight_files())
|
| 154 |
|
| 155 |
|
|
|
|
| 163 |
label='Base Model',
|
| 164 |
visible=False)
|
| 165 |
reload_button = gr.Button('Reload Weight List')
|
| 166 |
+
weight_name = gr.Dropdown(choices=find_weight_files(),
|
| 167 |
+
value='custom-diffusion/cat.ckpt',
|
| 168 |
+
label='Custom Diffusion Weight File')
|
| 169 |
prompt = gr.Textbox(
|
| 170 |
label='Prompt',
|
| 171 |
max_lines=1,
|
| 172 |
+
placeholder='Example: "<new1> cat swimming in a pool"')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
seed = gr.Slider(label='Seed',
|
| 174 |
minimum=0,
|
| 175 |
maximum=100000,
|
|
|
|
| 178 |
with gr.Accordion('Other Parameters', open=False):
|
| 179 |
num_steps = gr.Slider(label='Number of Steps',
|
| 180 |
minimum=0,
|
| 181 |
+
maximum=500,
|
| 182 |
step=1,
|
| 183 |
+
value=200)
|
| 184 |
guidance_scale = gr.Slider(label='CFG Scale',
|
| 185 |
minimum=0,
|
| 186 |
maximum=50,
|
| 187 |
step=0.1,
|
| 188 |
+
value=6
|
| 189 |
+
eta = gr.Slider(label='CFG Scale',
|
| 190 |
+
minimum=0,
|
| 191 |
+
maximum=1.,
|
| 192 |
+
step=0.1,
|
| 193 |
+
value=1.)
|
| 194 |
|
| 195 |
run_button = gr.Button('Generate')
|
| 196 |
|
| 197 |
gr.Markdown('''
|
| 198 |
+
- Models with names starting with "custom-diffusion/" are the pretrained models provided in the [original repo](https://github.com/adobe-research/custom-diffusion), and the ones with names starting with "results/" are your trained models.
|
| 199 |
- After training, you can press "Reload Weight List" button to load your trained model names.
|
|
|
|
|
|
|
| 200 |
''')
|
| 201 |
with gr.Column():
|
| 202 |
result = gr.Image(label='Result')
|
| 203 |
|
| 204 |
+
reload_button.click(fn=reload_custom_diffusion_weight_list,
|
| 205 |
inputs=None,
|
| 206 |
+
outputs=weight_name)
|
| 207 |
prompt.submit(fn=pipe.run,
|
| 208 |
inputs=[
|
| 209 |
base_model,
|
| 210 |
+
weight_name,
|
| 211 |
prompt,
|
|
|
|
|
|
|
| 212 |
seed,
|
| 213 |
num_steps,
|
| 214 |
guidance_scale,
|
| 215 |
+
eta,
|
| 216 |
],
|
| 217 |
outputs=result,
|
| 218 |
queue=False)
|
| 219 |
run_button.click(fn=pipe.run,
|
| 220 |
inputs=[
|
| 221 |
base_model,
|
| 222 |
+
weight_name,
|
| 223 |
prompt,
|
|
|
|
|
|
|
| 224 |
seed,
|
| 225 |
num_steps,
|
| 226 |
guidance_scale,
|
| 227 |
+
eta,
|
| 228 |
],
|
| 229 |
outputs=result,
|
| 230 |
queue=False)
|