# Hugging Face Space file-view header (author / last commit), kept as comments:
# linoyts's picture
# linoyts HF Staff
# Update app.py
# f88bac4 verified
# raw
# history blame
# 5.71 kB
import gradio as gr
import numpy as np
import random
import torch
import spaces
from PIL import Image
from diffusers import FlowMatchEulerDiscreteScheduler
from optimization import optimize_pipeline_
from diffusers import QwenImageEditPlusPipeline, QwenImageTransformer2DModel
# from qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
# from qwenimage.transformer_qwenimage import QwenImageTransformer2DModel
# from qwenimage.qwen_fa3_processor import QwenDoubleStreamAttnProcessorFA3
import math
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image
import os
import gradio as gr
from gradio_client import Client, handle_file
import tempfile
# --- Model Loading ---
# Weights are loaded in bfloat16; falls back to CPU when CUDA is unavailable.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# FlowMatch Euler scheduler settings used with the Lightning (few-step) LoRA.
# base_shift and max_shift are both log(3); presumably the pipeline interpolates
# between them, which would make the dynamic shift constant — TODO confirm.
scheduler_config = dict(
    base_image_seq_len=256,
    base_shift=math.log(3),
    invert_sigmas=False,
    max_image_seq_len=8192,
    max_shift=math.log(3),
    num_train_timesteps=1000,
    shift=1.0,
    shift_terminal=None,
    stochastic_sampling=False,
    time_shift_type="exponential",
    use_beta_sigmas=False,
    use_dynamic_shifting=True,
    use_exponential_sigmas=False,
    use_karras_sigmas=False,
)
scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)

pipe = QwenImageEditPlusPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2509",
    scheduler=scheduler,
    torch_dtype=dtype,
).to(device)

# (repo id, weight file, adapter name) for every LoRA, in load order.
_LORA_SPECS = (
    ("lightx2v/Qwen-Image-Lightning", "Qwen-Image-Lightning-4steps-V2.0.safetensors", "fast"),
    ("dx8152/Qwen-Image-Edit-2509-Fusion", "溶图.safetensors", "fusion"),
)
for _repo_id, _weight_file, _adapter in _LORA_SPECS:
    pipe.load_lora_weights(_repo_id, weight_name=_weight_file, adapter_name=_adapter)

# Activate every adapter at full strength, bake each one into the base weights
# (one adapter per fuse call), then drop the now-redundant LoRA modules.
_adapter_names = [spec[2] for spec in _LORA_SPECS]
pipe.set_adapters(_adapter_names, adapter_weights=[1.0] * len(_adapter_names))
for _adapter in _adapter_names:
    pipe.fuse_lora(adapter_names=[_adapter])
pipe.unload_lora_weights()

# Optional FA3 attention / AoT optimization path, currently disabled:
# pipe.transformer.__class__ = QwenImageTransformer2DModel
# pipe.transformer.set_attn_processor(QwenDoubleStreamAttnProcessorFA3())
# optimize_pipeline_(pipe, image=[Image.new("RGB", (1024, 1024)), Image.new("RGB", (1024, 1024))], prompt="prompt")

# Upper bound for randomly drawn seeds (largest signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max
@spaces.GPU
def infer(
    image_subject,
    image_background=None,
    prompt="",
    seed=42,
    randomize_seed=True,
    true_guidance_scale=1,
    num_inference_steps=4,
    height=None,
    width=None,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the fused Lightning + Fusion edit pipeline on a single input image.

    Args:
        image_subject: Input image to edit (the UI supplies a PIL image).
        image_background: Accepted for interface compatibility; currently unused.
        prompt: Edit instruction forwarded to the pipeline.
        seed: RNG seed; replaced by a random value when ``randomize_seed`` is set.
        randomize_seed: If true, draw a fresh seed in ``[0, MAX_SEED]``.
        true_guidance_scale: Forwarded to the pipeline as ``true_cfg_scale``.
        num_inference_steps: Number of denoising steps (the Lightning LoRA is a
            4-step variant).
        height, width: Accepted for interface compatibility but NOT forwarded,
            so the output size is left to the pipeline.
        progress: Gradio progress tracker that mirrors tqdm bars.

    Returns:
        ``([image_subject, result], seed)``: a before/after pair for the image
        slider and the seed that was actually used.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Fail with a readable UI message instead of an opaque pipeline traceback.
    if image_subject is None:
        raise gr.Error("Please upload an input image.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; torch.Generator.manual_seed needs an int.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    result = pipe(
        image=image_subject,
        prompt=prompt,
        # height/width are intentionally not forwarded (disabled upstream).
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        true_cfg_scale=true_guidance_scale,
        num_images_per_prompt=1,
    ).images[0]
    return [image_subject, result], seed
# --- UI ---
css = '''#col-container { max-width: 800px; margin: 0 auto; }
.dark .progress-text{color: white !important}
#examples{max-width: 800px; margin: 0 auto; }'''

with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("## Qwen Image Edit — Fusion")
        # Markdown body is kept flush-left: indented lines would render as a code block.
        gr.Markdown(
            """
Qwen Image Edit 2509 ✨
Using [dx8152's Qwen-Image-Edit-2509 Fusion LoRA](https://huggingface.co/dx8152/Qwen-Image-Edit-2509-Fusion) and [lightx2v Qwen-Image-Lightning LoRA](https://huggingface.co/lightx2v/Qwen-Image-Lightning) for 4-step inference 💨
"""
        )
        with gr.Row():
            # Left column: inputs and settings.
            with gr.Column():
                with gr.Row():
                    image_subject = gr.Image(label="input image", type="pil")
                    # Hidden: infer() does not currently use a background image.
                    image_background = gr.Image(label="background Image", type="pil", visible=False)
                prompt = gr.Textbox(label="prompt")
                run_button = gr.Button("Fuse", variant="primary")
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                    true_guidance_scale = gr.Slider(label="True Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=1.0)
                    num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=40, step=1, value=4)
                    # NOTE(review): infer() does not forward height/width to the
                    # pipeline, so these two sliders currently have no effect.
                    height = gr.Slider(label="Height", minimum=256, maximum=2048, step=8, value=1024)
                    width = gr.Slider(label="Width", minimum=256, maximum=2048, step=8, value=1024)
            # Right column: before/after comparison of the result.
            with gr.Column():
                result = gr.ImageSlider(label="Output Image", interactive=False)
                prompt_preview = gr.Textbox(label="Processed Prompt", interactive=False, visible=False)

    # Examples only fill the subject image; the remaining infer() arguments use
    # their defaults. Results are computed and cached the first time each is used.
    gr.Examples(
        examples=[
            ["fusion_car.png"], ["fusion_shoes.png"],
        ],
        inputs=[image_subject],
        outputs=[result, seed],
        fn=infer,
        cache_examples="lazy",
        elem_id="examples",
    )

    # Positional order must match infer()'s parameter list.
    inputs = [
        image_subject, image_background, prompt,
        seed, randomize_seed, true_guidance_scale, num_inference_steps, height, width,
    ]
    outputs = [result, seed]
    run_event = run_button.click(
        fn=infer,
        inputs=inputs,
        outputs=outputs,
    )

if __name__ == "__main__":
    demo.launch(share=True)