Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import numpy as np | |
| import random | |
| import spaces | |
| import torch | |
| import types | |
| from diffusers.pipelines.prx import PRXPipeline | |
| # monkey patch to add 1024 aspect ratios | |
| import diffusers.pipelines.prx.pipeline_prx as prx_mod | |
| import math | |
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """Create sinusoidal timestep embeddings (DDPM-style).

    For ``half_dim = embedding_dim // 2`` the frequencies are
    ``exp(-log(max_period) * i / (half_dim - downscale_freq_shift))`` for
    ``i`` in ``[0, half_dim)``. Each timestep is multiplied by ``scale`` and
    by every frequency, and the sine and cosine of those products are
    concatenated along the last dimension.

    Args:
        timesteps: 1-D tensor with one (possibly fractional) value per batch
            element.
        embedding_dim: Width of the returned embedding.
        flip_sin_to_cos: If True the layout is ``[cos, sin]``; otherwise
            ``[sin, cos]``.
        downscale_freq_shift: Subtracted from ``half_dim`` in the exponent's
            denominator, i.e. controls the spacing between frequencies.
        scale: Factor applied to ``timestep * frequency`` before sin/cos.
            The default is 1; a default of 0 would collapse every embedding
            to the same constant ``[0..., 1...]`` vector regardless of the
            timestep.
        max_period: Base of the geometric frequency progression.

    Returns:
        A float32 tensor of shape ``[len(timesteps), embedding_dim]``; when
        ``embedding_dim`` is odd the last column is zero padding.

    Raises:
        AssertionError: If ``timesteps`` is not 1-D.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2

    # Geometric frequency ladder, built on the same device as the input.
    # NOTE: if half_dim == downscale_freq_shift the division below is by zero.
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)
    freqs = torch.exp(exponent)

    # Outer product: one row per timestep, one column per frequency.
    emb = timesteps[:, None].float() * freqs[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad the last column when the requested width is odd
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Build the 256-wide sinusoidal embedding of ``timestep`` and feed it to ``self.time_in``.

    Written as a plain function taking ``self`` so it can be bound onto an
    existing object; it reads ``self.time_max_period``, ``self.time_factor``
    and ``self.time_in``. The sinusoidal features are cast to ``dtype``
    before being passed to ``self.time_in``, whose result is returned.
    """
    sinusoid = get_timestep_embedding(
        timesteps=timestep,
        embedding_dim=256,
        flip_sin_to_cos=True,  # Match original cos, sin order
        downscale_freq_shift=0.0,
        scale=self.time_factor,
        max_period=self.time_max_period,
    )
    return self.time_in(sinusoid.to(dtype))
# Resolution buckets around one megapixel. Each row is (ratio key, first, second),
# where the key is first / second rounded to at most two decimals.
# NOTE(review): presumably the pairs are [height, width] — confirm against the
# pipeline's own ASPECT_RATIO_*_BIN tables.
_ASPECT_RATIO_ROWS = (
    ("0.49", 704, 1440),
    ("0.52", 736, 1408),
    ("0.53", 736, 1376),
    ("0.57", 768, 1344),
    ("0.59", 768, 1312),
    ("0.62", 800, 1280),
    ("0.67", 832, 1248),
    ("0.68", 832, 1216),
    ("0.78", 896, 1152),
    ("0.83", 928, 1120),
    ("0.94", 992, 1056),
    ("1.0", 1024, 1024),
    ("1.06", 1056, 992),
    ("1.13", 1088, 960),
    ("1.21", 1120, 928),
    ("1.29", 1152, 896),
    ("1.37", 1184, 864),
    ("1.46", 1216, 832),
    ("1.5", 1248, 832),
    ("1.71", 1312, 768),
    ("1.75", 1344, 768),
    ("1.87", 1376, 736),
    ("1.91", 1408, 736),
    ("2.05", 1440, 704),
)

# Mapping from ratio key to a two-element list, in the row order above.
CUSTOM_ASPECT_RATIO_512_BIN = {
    ratio_key: [first, second] for ratio_key, first, second in _ASPECT_RATIO_ROWS
}
# Replace the pipeline module's ASPECT_RATIO_512_BIN table with the custom
# buckets so the module-level lookup sees the larger resolutions.
prx_mod.ASPECT_RATIO_512_BIN = CUSTOM_ASPECT_RATIO_512_BIN

# UI limits: the seed slider tops out at the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# Prefer the GPU when torch reports one; weights are loaded in bfloat16.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

pipe = PRXPipeline.from_pretrained("Photoroom/prx-1024-t2i-beta", torch_dtype=dtype)
pipe = pipe.to(device)

# Properly bind the method to the instance using types.MethodType, so the
# replacement receives the transformer as ``self`` when it is called.
pipe.transformer._compute_timestep_embedding = types.MethodType(
    _compute_timestep_embedding, pipe.transformer
)
@spaces.GPU
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    num_inference_steps=28,
    guidance_scale=4.0,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the module-level ``pipe`` and report the seed used.

    The function is wrapped with ``spaces.GPU`` so that a ZeroGPU Space
    attaches a GPU for the duration of the call.

    Args:
        prompt: Text description of the desired image.
        negative_prompt: Text describing what to avoid; forwarded to the
            pipeline unchanged.
        seed: Seed for the torch generator; ignored when ``randomize_seed``
            is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Requested image width, forwarded to the pipeline.
        height: Requested image height, forwarded to the pipeline.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance strength forwarded to the pipeline.
        progress: Gradio progress tracker; not referenced in the body
            (presumably it lets Gradio mirror the pipeline's tqdm bar).

    Returns:
        A ``(image, seed)`` tuple: the first image of the pipeline output and
        the integer seed that produced it, matching the two Gradio outputs.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # UI widgets may hand numbers over as floats; manual_seed needs an int,
    # and sizes / step counts are integral quantities.
    seed = int(seed)
    generator = torch.Generator(device=device).manual_seed(seed)

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=int(width),
        height=int(height),
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        guidance_scale=guidance_scale,
    ).images[0]

    # Both values are returned because Gradio maps them onto [result, seed].
    return image, seed
# Prompts offered as one-click examples in the UI. Entries that are commented
# out are kept for reference but are not shown.
_EXAMPLE_PROMPTS = [
    # "A massive black monolith standing alone in a mirror-like salt flat after rainfall, horizon dissolving into pastel pink and cyan, reflections perfect and infinite, minimalist 2.39:1 frame, cinematic atmosphere of silence, RED Komodo 6K capture, 35 mm lens, ND filter, high dynamic range, ultra-clean tones and soft ambient light.",
    "A turtle covered in vibrant ceramic mosaic tiles, tiny geometric patterns, resting on weathered stone steps in a Mediterranean town square, warm daylight, artisanal feel",
    "Hundreds of paper lanterns drifting along a quiet river at dusk, soft orange light piercing cold blue mist, reflections trembling across rippled water, camera at water level with shallow DOF, cinematic color contrast of warm and cool tones",
    "A woman standing ankle-deep in the ocean at dawn, gentle waves touching her feet, mist and pastel horizon, cinematic wide composition, calm and contemplative mood, filmic color grading reminiscent of Terrence Malick's imagery.",
    # "In the courtyard of a coastal house, white sheets flap slowly in the wind, a woman pauses between hanging clothes, eyes closed, light flickering through the fabric. A flock of seagulls turns sharply overhead, casting moving shadows on the walls. The sound of waves faintly audible, palette of whites, greys, and sun-bleached blues, evokes transience and memory.",
    # "A close-up portrait in a photography studio, multiple soft light sources creating gradients of shadow on her face, minimal background, cinematic 4 K realism, artistic focus on light and emotion rather than glamour.",
    "A cat sculpted from fine white porcelain with delicate blue floral motifs, standing gracefully in a minimalist contemporary art gallery, polished marble floor reflections, soft museum lighting, ultra-detailed ceramic gloss",
    "A whimsical fantasy dog made entirely from layered paper cutouts, textured handmade paper, watercolor patterns on its body, pastel tones, enchanted meadow, soft glow, playful mood, highly detailed illustration.",
    "A front-facing portrait of a lion on the golden savanna at sunset.",
    "An owl sculpted from layered book pages, faint text visible on feathers, perched on a wooden reading desk in a grand library, golden lamplight, quiet scholarly ambience",
    "Une peinture numérique d’un vieux tram rouillé reposant sur une plage de sable balayée par le vent, ses couleurs délavées brillant doucement sous la lumière dorée du soir",
    "A digital painting depicts a herd of African elephants traversing a dry, grassy savanna.",
    "A fox constructed from tightly coiled metal wire strands, intricate loops, semi-transparent silhouette, perched on an old brick rooftop in a calm city evening, soft warm window lights, poetic mood",
]

# Each example row is [prompt, negative_prompt]; every example uses an empty
# negative prompt.
examples = [[example_prompt, ""] for example_prompt in _EXAMPLE_PROMPTS]
# Page-level CSS: centre the main column and cap its width.
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""

# NOTE(review): the pasted source had lost its indentation, so the nesting of
# the layout containers below is reconstructed from the statement order —
# confirm it against the deployed app.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Header text.
        gr.Markdown("# PRX 1.2B Image Generator")
        gr.Markdown("Generate high-quality images using the beta-preview of PRX.")
        gr.Markdown("Works best with very detailed prompts in natural language.")

        # Text inputs.
        prompt = gr.Text(
            label="Prompt",
            placeholder="Enter your prompt",
            show_label=True,
            max_lines=2,
        )
        negative_prompt = gr.Text(
            label="Negative prompt",
            placeholder="Things to avoid (e.g., blurry, low-res, extra limbs...)",
            value="",
            max_lines=2,
        )

        with gr.Row():
            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Result", show_label=False)

        # Generation parameters, collapsed by default.
        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Slider(
                label="Seed",
                value=0,
                minimum=0,
                maximum=MAX_SEED,
                step=1,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    value=1024,
                    minimum=700,
                    maximum=1440,
                    step=1,
                )
                height = gr.Slider(
                    label="Height",
                    value=1024,
                    minimum=700,
                    maximum=1440,
                    step=1,
                )

            with gr.Row():
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    value=28,
                    minimum=1,
                    maximum=50,
                    step=1,
                )
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    value=4.0,
                    minimum=1.0,
                    maximum=7.0,
                    step=0.1,
                )

        # Examples supply only the two text inputs, so infer() runs with its
        # own defaults for every other parameter.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, negative_prompt],
            outputs=[result, seed],
            cache_examples="lazy",
        )

    # The order of this list must match infer()'s positional parameters.
    generation_inputs = [
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        num_inference_steps,
        guidance_scale,
    ]

    # Clicking "Run" or submitting the prompt box both trigger generation.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=generation_inputs,
        outputs=[result, seed],
    )

demo.launch()