multimodalart HF Staff committed on
Commit
9524258
·
verified ·
1 Parent(s): 89d78ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -54
app.py CHANGED
@@ -27,10 +27,26 @@ def get_default_negative_prompt(existing_json: dict) -> str:
27
  return negative_prompt
28
 
29
  @spaces.GPU(duration=300)
30
- def infer(
31
  prompt,
32
- prompt_refine,
33
  prompt_inspire_image,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  prompt_in_json,
35
  negative_prompt="",
36
  seed=42,
@@ -39,28 +55,66 @@ def infer(
39
  height=1024,
40
  guidance_scale=5,
41
  num_inference_steps=50,
42
- mode="generate",
43
  ):
44
  if randomize_seed:
45
  seed = random.randint(0, MAX_SEED)
46
 
47
  with torch.inference_mode():
48
-
49
- if mode == "refine":
50
- json_prompt_str = (
 
 
 
 
 
 
 
51
  json.dumps(prompt_in_json)
52
  if isinstance(prompt_in_json, (dict, list))
53
  else str(prompt_in_json)
54
  )
55
- output = vlm_pipe(json_prompt=json_prompt_str, prompt=prompt_refine)
56
-
57
- elif mode == "inspire":
58
- if prompt_inspire_image is None:
59
- raise ValueError("Please upload an image to inspire the model.")
60
- output = vlm_pipe(image=prompt_inspire_image, prompt="")
61
 
 
 
 
62
  else:
63
- output = vlm_pipe(prompt=prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  json_prompt = output.values["json_prompt"]
65
 
66
  if negative_prompt:
@@ -77,8 +131,9 @@ def infer(
77
  height=height,
78
  guidance_scale=guidance_scale,
79
  ).images[0]
 
80
  print(neg_json_prompt)
81
- return image, seed, json_prompt, json.dumps(neg_json_prompt), gr.update(open=True)
82
 
83
 
84
  css = """
@@ -104,28 +159,31 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="violet")) as demo:
104
 
105
  with gr.Row(elem_id="col-container"):
106
  with gr.Column(scale=1):
107
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  with gr.Row():
109
- with gr.Tab("generate") as tab_generate:
110
- prompt_generate = gr.Textbox(
111
- label="Prompt",
112
- placeholder="a man holding a goose screaming"
113
- )
114
-
115
- with gr.Tab("refine") as tab_refine:
116
- prompt_refine = gr.Textbox(
117
- label="Prompt",
118
- info="describe the change you want to make",
119
- placeholder="make the cat white"
120
- )
121
-
122
- with gr.Tab("inspire") as tab_inspire:
123
- prompt_inspire_image = gr.Image(
124
- label="Inspiration Image",
125
- type="pil",
126
- )
127
-
128
- submit_btn = gr.Button("Generate", variant="primary")
129
 
130
  with gr.Accordion("Advanced Settings", open=False):
131
  with gr.Row():
@@ -144,27 +202,42 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="violet")) as demo:
144
 
145
  with gr.Column(scale=1):
146
  result = gr.Image(label="output")
147
- with gr.Accordion("Structured Prompt", open=False) as structured_accordion:
148
- prompt_in_json = gr.JSON(label="json structured prompt")
149
-
150
- # Track active tab
151
- current_mode = gr.State("generate")
152
- # When "generate" is selected — just set mode
153
- tab_generate.select(lambda: ("generate", gr.update(value=True)), outputs=[current_mode, randomize_seed])
154
-
155
- # When "refine" is selected — set mode and turn off randomize_seed
156
- tab_refine.select(lambda: ("refine", gr.update(value=False)), outputs=[current_mode, randomize_seed])
157
-
158
- # When "inspire" is selected — normal
159
- tab_inspire.select(lambda: ("inspire", gr.update(value=True)), outputs=[current_mode, randomize_seed])
160
 
 
 
 
 
 
 
 
 
 
 
161
 
162
- submit_btn.click(
163
- fn=infer,
 
164
  inputs=[
165
  prompt_generate,
166
- prompt_refine,
167
  prompt_inspire_image,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  prompt_in_json,
169
  negative_prompt,
170
  seed,
@@ -173,9 +246,8 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="violet")) as demo:
173
  height,
174
  guidance_scale,
175
  num_inference_steps,
176
- current_mode,
177
  ],
178
- outputs=[result, seed, prompt_in_json, negative_prompt_json, structured_accordion],
179
  )
180
 
181
- demo.queue().launch()
 
27
  return negative_prompt
28
 
29
  @spaces.GPU(duration=300)
30
+ def generate_json_prompt(
31
  prompt,
 
32
  prompt_inspire_image,
33
+ use_inspire,
34
+ ):
35
+ with torch.inference_mode():
36
+ if use_inspire and prompt_inspire_image is not None:
37
+ output = vlm_pipe(image=prompt_inspire_image, prompt="")
38
+ else:
39
+ output = vlm_pipe(prompt=prompt)
40
+
41
+ json_prompt = output.values["json_prompt"]
42
+
43
+ return json_prompt
44
+
45
+ @spaces.GPU(duration=300)
46
+ def generate_image(
47
+ prompt,
48
+ prompt_inspire_image,
49
+ use_inspire,
50
  prompt_in_json,
51
  negative_prompt="",
52
  seed=42,
 
55
  height=1024,
56
  guidance_scale=5,
57
  num_inference_steps=50,
 
58
  ):
59
  if randomize_seed:
60
  seed = random.randint(0, MAX_SEED)
61
 
62
  with torch.inference_mode():
63
+ # If JSON prompt is empty or None, generate it first
64
+ if not prompt_in_json or prompt_in_json == "":
65
+ if use_inspire and prompt_inspire_image is not None:
66
+ output = vlm_pipe(image=prompt_inspire_image, prompt="")
67
+ else:
68
+ output = vlm_pipe(prompt=prompt)
69
+ json_prompt = output.values["json_prompt"]
70
+ else:
71
+ # Use the provided JSON prompt
72
+ json_prompt = (
73
  json.dumps(prompt_in_json)
74
  if isinstance(prompt_in_json, (dict, list))
75
  else str(prompt_in_json)
76
  )
 
 
 
 
 
 
77
 
78
+ if negative_prompt:
79
+ neg_output = vlm_pipe(prompt=negative_prompt)
80
+ neg_json_prompt = neg_output.values["json_prompt"]
81
  else:
82
+ neg_json_prompt = get_default_negative_prompt(json.loads(json_prompt))
83
+
84
+ image = pipe(
85
+ prompt=json_prompt,
86
+ num_inference_steps=num_inference_steps,
87
+ negative_prompt=neg_json_prompt,
88
+ width=width,
89
+ height=height,
90
+ guidance_scale=guidance_scale,
91
+ ).images[0]
92
+
93
+ print(neg_json_prompt)
94
+ return image, seed, json_prompt, json.dumps(neg_json_prompt), gr.update(visible=True)
95
+
96
+ @spaces.GPU(duration=300)
97
+ def refine_prompt(
98
+ refine_instruction,
99
+ prompt_in_json,
100
+ negative_prompt="",
101
+ seed=42,
102
+ randomize_seed=False,
103
+ width=1024,
104
+ height=1024,
105
+ guidance_scale=5,
106
+ num_inference_steps=50,
107
+ ):
108
+ if randomize_seed:
109
+ seed = random.randint(0, MAX_SEED)
110
+
111
+ with torch.inference_mode():
112
+ json_prompt_str = (
113
+ json.dumps(prompt_in_json)
114
+ if isinstance(prompt_in_json, (dict, list))
115
+ else str(prompt_in_json)
116
+ )
117
+ output = vlm_pipe(json_prompt=json_prompt_str, prompt=refine_instruction)
118
  json_prompt = output.values["json_prompt"]
119
 
120
  if negative_prompt:
 
131
  height=height,
132
  guidance_scale=guidance_scale,
133
  ).images[0]
134
+
135
  print(neg_json_prompt)
136
+ return image, seed, json_prompt, json.dumps(neg_json_prompt)
137
 
138
 
139
  css = """
 
159
 
160
  with gr.Row(elem_id="col-container"):
161
  with gr.Column(scale=1):
162
+
163
+ with gr.Accordion("Inspire from Image", open=False) as inspire_accordion:
164
+ prompt_inspire_image = gr.Image(
165
+ label="Inspiration Image",
166
+ type="pil",
167
+ )
168
+ use_inspire = gr.Checkbox(label="Use inspiration image", value=False)
169
+
170
+ prompt_generate = gr.Textbox(
171
+ label="Prompt",
172
+ placeholder="a man holding a goose screaming",
173
+ lines=3
174
+ )
175
+
176
+ prompt_in_json = gr.JSON(
177
+ label="Structured JSON Prompt (editable)",
178
+ value=None
179
+ )
180
+
181
  with gr.Row():
182
+ generate_json_btn = gr.Button("Generate JSON Prompt", variant="secondary")
183
+ generate_image_btn = gr.Button("Generate Image", variant="primary")
184
+
185
+ # Refine button - initially hidden
186
+ refine_btn = gr.Button("Refine existing image", variant="primary", visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  with gr.Accordion("Advanced Settings", open=False):
189
  with gr.Row():
 
202
 
203
  with gr.Column(scale=1):
204
  result = gr.Image(label="output")
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
+ # Generate JSON prompt only
207
+ generate_json_btn.click(
208
+ fn=generate_json_prompt,
209
+ inputs=[
210
+ prompt_generate,
211
+ prompt_inspire_image,
212
+ use_inspire,
213
+ ],
214
+ outputs=[prompt_in_json],
215
+ )
216
 
217
+ # Generate image (generates JSON first if needed)
218
+ generate_image_btn.click(
219
+ fn=generate_image,
220
  inputs=[
221
  prompt_generate,
 
222
  prompt_inspire_image,
223
+ use_inspire,
224
+ prompt_in_json,
225
+ negative_prompt,
226
+ seed,
227
+ randomize_seed,
228
+ width,
229
+ height,
230
+ guidance_scale,
231
+ num_inference_steps,
232
+ ],
233
+ outputs=[result, seed, prompt_in_json, negative_prompt_json, refine_btn],
234
+ )
235
+
236
+ # Refine image (reuses the main prompt box)
237
+ refine_btn.click(
238
+ fn=refine_prompt,
239
+ inputs=[
240
+ prompt_generate, # Reuse the main prompt box
241
  prompt_in_json,
242
  negative_prompt,
243
  seed,
 
246
  height,
247
  guidance_scale,
248
  num_inference_steps,
 
249
  ],
250
+ outputs=[result, seed, prompt_in_json, negative_prompt_json],
251
  )
252
 
253
+ demo.queue().launch()