Bertoin's picture
new model and prompts
2cabcff verified
import gradio as gr
import numpy as np
import random
import spaces
import torch
import types
from diffusers.pipelines.prx import PRXPipeline
# monkey patch to add 1024 aspect ratios
import diffusers.pipelines.prx.pipeline_prx as prx_mod
import math
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """Create sinusoidal timestep embeddings (as in Denoising Diffusion Probabilistic Models).

    Args:
        timesteps: 1-D tensor of N indices, one per batch element. These may be
            fractional.
        embedding_dim: Size of each output embedding. When odd, the last
            column is zero-padded.
        flip_sin_to_cos: If True the embedding order is ``cos, sin``;
            otherwise ``sin, cos``.
        downscale_freq_shift: Subtracted from ``embedding_dim // 2`` in the
            denominator of the frequency exponent; controls the spacing between
            per-dimension frequencies.
        scale: Multiplier applied to ``timesteps * frequency`` before taking
            sin/cos. Defaults to 1. (A default of 0 collapsed every embedding
            to the constant ``[0, ..., 0, 1, ..., 1]``.)
        max_period: Sets the lowest frequency. Frequencies decay geometrically
            from 1 toward roughly ``1 / max_period``.

    Returns:
        float32 tensor of shape ``(N, embedding_dim)``.

    Raises:
        AssertionError: if ``timesteps`` is not 1-D (checked with ``assert``,
            so the check is skipped under ``python -O``).
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
    half_dim = embedding_dim // 2

    # Per-dimension frequencies exp(-log(max_period) * i / (half_dim - shift)), i in [0, half_dim).
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)
    freqs = torch.exp(exponent)

    # Outer product: one row per timestep, one column per frequency, then scaled.
    args = scale * (timesteps[:, None].float() * freqs[None, :])
    sin, cos = torch.sin(args), torch.cos(args)

    # Emit the halves directly in the requested order instead of concatenating and re-slicing.
    emb = torch.cat([cos, sin] if flip_sin_to_cos else [sin, cos], dim=-1)

    # Odd embedding_dim: half_dim * 2 is one short, so pad a zero column on the right.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Embed diffusion timesteps and project them with ``self.time_in``.

    Builds a 256-wide sinusoidal embedding (cos half first). It uses the
    owning object's ``time_factor`` as the scale and ``time_max_period`` as
    the period, then casts the result to ``dtype`` before the projection.
    This is a free function taking ``self``. It becomes a method of
    ``pipe.transformer`` through the ``types.MethodType`` binding made after
    the pipeline is loaded.
    """
    # downscale_freq_shift=0.0 deliberately differs from get_timestep_embedding's
    # default of 1. NOTE(review): presumably required to match the checkpoint — confirm.
    sinusoid = get_timestep_embedding(
        timesteps=timestep,
        embedding_dim=256,
        flip_sin_to_cos=True,
        downscale_freq_shift=0.0,
        scale=self.time_factor,
        max_period=self.time_max_period,
    )
    return self.time_in(sinusoid.to(dtype))
# Replacement resolution-bin table (roughly 1 megapixel per entry, every side a
# multiple of 32). Each key is first/second of its value, rounded to two decimals.
# NOTE(review): values are presumably [height, width], as in diffusers'
# PixArt-style ASPECT_RATIO_*_BIN tables — confirm against pipeline_prx.
CUSTOM_ASPECT_RATIO_512_BIN = dict(
    [
        # first side < second side
        ("0.49", [704, 1440]),
        ("0.52", [736, 1408]),
        ("0.53", [736, 1376]),
        ("0.57", [768, 1344]),
        ("0.59", [768, 1312]),
        ("0.62", [800, 1280]),
        ("0.67", [832, 1248]),
        ("0.68", [832, 1216]),
        ("0.78", [896, 1152]),
        ("0.83", [928, 1120]),
        ("0.94", [992, 1056]),
        # square
        ("1.0", [1024, 1024]),
        # first side > second side
        ("1.06", [1056, 992]),
        ("1.13", [1088, 960]),
        ("1.21", [1120, 928]),
        ("1.29", [1152, 896]),
        ("1.37", [1184, 864]),
        ("1.46", [1216, 832]),
        ("1.5", [1248, 832]),
        ("1.71", [1312, 768]),
        ("1.75", [1344, 768]),
        ("1.87", [1376, 736]),
        ("1.91", [1408, 736]),
        ("2.05", [1440, 704]),
    ]
)
# Swap the pipeline module's 512 bin table for the ~1 MP table above.
# NOTE(review): assumes pipeline_prx reads ASPECT_RATIO_512_BIN as a module
# global at call time, so rebinding it here takes effect — confirm.
prx_mod.ASPECT_RATIO_512_BIN = CUSTOM_ASPECT_RATIO_512_BIN

# Run on CUDA when available, otherwise CPU; weights are loaded in bfloat16 either way.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

pipe = PRXPipeline.from_pretrained("Photoroom/prx-1024-t2i-beta", torch_dtype=dtype)
pipe = pipe.to(device)

# Attach the replacement to this transformer instance only. MethodType makes
# `self` inside _compute_timestep_embedding refer to pipe.transformer.
pipe.transformer._compute_timestep_embedding = types.MethodType(
    _compute_timestep_embedding, pipe.transformer
)

# Largest seed offered to users (max signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max
# NOTE(review): not referenced elsewhere in this file; the width/height sliders hard-code 1440.
MAX_IMAGE_SIZE = 1024
@spaces.GPU()
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    num_inference_steps=28,
    guidance_scale=4.0,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the PRX pipeline.

    Args:
        prompt: Text prompt for the image.
        negative_prompt: Text forwarded to the pipeline as ``negative_prompt``.
        seed: Seed for the torch generator; replaced when ``randomize_seed`` is true.
        randomize_seed: If true, draw a new seed with ``random.randint(0, MAX_SEED)``.
        width: Requested output width in pixels.
        height: Requested output height in pixels.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        progress: Gradio progress tracker (``track_tqdm=True``). It is not used
            in the body; Gradio injects it so the pipeline's progress shows in the UI.

    Returns:
        ``(image, seed)``: the first image of the pipeline output and the
        integer seed actually used, so the UI can show it.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio documents Slider input values as float. torch.Generator.manual_seed
    # requires an int, and the size/step arguments must be integral, so coerce
    # them here.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=int(width),
        height=int(height),
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        guidance_scale=float(guidance_scale),
    ).images[0]
    # Returning the seed lets Gradio write it back to the seed slider.
    return image, seed
# Example gallery rows as [prompt, negative_prompt]; every row uses an empty
# negative prompt. Commented-out strings are disabled alternatives.
examples = [
    [example_prompt, ""]
    for example_prompt in (
        # "A massive black monolith standing alone in a mirror-like salt flat after rainfall, horizon dissolving into pastel pink and cyan, reflections perfect and infinite, minimalist 2.39:1 frame, cinematic atmosphere of silence, RED Komodo 6K capture, 35 mm lens, ND filter, high dynamic range, ultra-clean tones and soft ambient light.",
        "A turtle covered in vibrant ceramic mosaic tiles, tiny geometric patterns, resting on weathered stone steps in a Mediterranean town square, warm daylight, artisanal feel",
        "Hundreds of paper lanterns drifting along a quiet river at dusk, soft orange light piercing cold blue mist, reflections trembling across rippled water, camera at water level with shallow DOF, cinematic color contrast of warm and cool tones",
        "A woman standing ankle-deep in the ocean at dawn, gentle waves touching her feet, mist and pastel horizon, cinematic wide composition, calm and contemplative mood, filmic color grading reminiscent of Terrence Malick's imagery.",
        # "In the courtyard of a coastal house, white sheets flap slowly in the wind, a woman pauses between hanging clothes, eyes closed, light flickering through the fabric. A flock of seagulls turns sharply overhead, casting moving shadows on the walls. The sound of waves faintly audible, palette of whites, greys, and sun-bleached blues, evokes transience and memory.",
        # "A close-up portrait in a photography studio, multiple soft light sources creating gradients of shadow on her face, minimal background, cinematic 4 K realism, artistic focus on light and emotion rather than glamour.",
        "A cat sculpted from fine white porcelain with delicate blue floral motifs, standing gracefully in a minimalist contemporary art gallery, polished marble floor reflections, soft museum lighting, ultra-detailed ceramic gloss",
        "A whimsical fantasy dog made entirely from layered paper cutouts, textured handmade paper, watercolor patterns on its body, pastel tones, enchanted meadow, soft glow, playful mood, highly detailed illustration.",
        "A front-facing portrait of a lion on the golden savanna at sunset.",
        "An owl sculpted from layered book pages, faint text visible on feathers, perched on a wooden reading desk in a grand library, golden lamplight, quiet scholarly ambience",
        # French-language prompt.
        "Une peinture numérique d’un vieux tram rouillé reposant sur une plage de sable balayée par le vent, ses couleurs délavées brillant doucement sous la lumière dorée du soir",
        "A digital painting depicts a herd of African elephants traversing a dry, grassy savanna.",
        "A fox constructed from tightly coiled metal wire strands, intricate loops, semi-transparent silhouette, perched on an old brick rooftop in a calm city evening, soft warm window lights, poetic mood",
    )
]
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
# Gradio UI: prompt and negative-prompt boxes, a Run button, the result image,
# and an "Advanced Settings" accordion holding the sampling controls.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# PRX 1.2B Image Generator")
        gr.Markdown("Generate high-quality images using the beta-preview of PRX.")
        gr.Markdown("Works best with very detailed prompts in natural language.")
        prompt = gr.Text(
            label="Prompt",
            show_label=True,
            max_lines=2,
            placeholder="Enter your prompt",
        )
        # Optional; passed to infer() as its negative_prompt argument.
        negative_prompt = gr.Text(
            label="Negative prompt",
            max_lines=2,
            placeholder="Things to avoid (e.g., blurry, low-res, extra limbs...)",
            value=""
        )
        with gr.Row():
            run_button = gr.Button("Run", scale=0)
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            # value=0 only matters when "Randomize seed" is unchecked; infer()
            # replaces the seed when randomize_seed is true.
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                # NOTE(review): the 700-1440 range is hard-coded and does not use
                # MAX_IMAGE_SIZE (1024); the CUSTOM_ASPECT_RATIO_512_BIN sizes span 704-1440.
                width = gr.Slider(
                    label="Width",
                    minimum=700,
                    maximum=1440,
                    step=1,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=700,
                    maximum=1440,
                    step=1,
                    value=1024,
                )
            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=1.0,
                    maximum=7.0,
                    step=0.1,
                    value=4.0,
                )
        # Example rows fill only prompt and negative_prompt. infer() therefore
        # runs with its own defaults (seed=42, not randomized, 1024x1024,
        # 28 steps, guidance 4.0). Its (image, seed) result goes to
        # result/seed. NOTE(review): "lazy" presumably caches each example's
        # output on first use — confirm for the installed Gradio version.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, negative_prompt],
            outputs=[result, seed],
            cache_examples="lazy"
        )
    # Generate on Run click or Enter in the prompt box. infer() returns
    # (image, seed_used), which fills the result image and updates the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            num_inference_steps,
            guidance_scale,
        ],
        outputs=[result, seed]
    )
# Start the Gradio server; this runs whenever the module is executed.
demo.launch()