Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- app.py +166 -28
- requirements.txt +4 -1
app.py
CHANGED
|
@@ -1,51 +1,189 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
from diffusers import StableDiffusionUpscalePipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
#
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
)
|
|
|
|
| 11 |
|
| 12 |
-
|
| 13 |
-
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
|
| 17 |
|
| 18 |
-
def
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
generator = torch.manual_seed(42)
|
| 24 |
output = pipe(
|
| 25 |
prompt=prompt,
|
| 26 |
image=input_img,
|
| 27 |
-
num_inference_steps=20,
|
| 28 |
guidance_scale=7.0,
|
| 29 |
generator=generator
|
| 30 |
).images[0]
|
| 31 |
|
| 32 |
return output
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
desc = """
|
| 35 |
-
###
|
| 36 |
-
|
| 37 |
-
*
|
|
|
|
|
|
|
| 38 |
"""
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
gr.
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
iface.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
from diffusers import StableDiffusionUpscalePipeline
|
| 4 |
+
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
|
| 5 |
+
import gc
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import numpy as np
|
| 8 |
+
import logging
|
| 9 |
+
import io
|
| 10 |
+
import os
|
| 11 |
+
import requests
|
| 12 |
+
from spandrel import ModelLoader
|
| 13 |
|
| 14 |
+
# Setup logging
|
| 15 |
+
log_capture_string = io.StringIO()
|
| 16 |
+
ch = logging.StreamHandler(log_capture_string)
|
| 17 |
+
ch.setLevel(logging.INFO)
|
| 18 |
+
logger = logging.getLogger()
|
| 19 |
+
logger.setLevel(logging.INFO)
|
| 20 |
+
logger.addHandler(ch)
|
| 21 |
|
| 22 |
+
def get_logs():
|
| 23 |
+
return log_capture_string.getvalue()
|
| 24 |
|
| 25 |
+
# Global models cache
|
| 26 |
+
models = {}
|
| 27 |
|
| 28 |
+
def download_file(url, filename):
|
| 29 |
+
if not os.path.exists(filename):
|
| 30 |
+
logger.info(f"Downloading {filename}...")
|
| 31 |
+
response = requests.get(url, stream=True)
|
| 32 |
+
with open(filename, 'wb') as f:
|
| 33 |
+
for chunk in response.iter_content(chunk_size=8192):
|
| 34 |
+
f.write(chunk)
|
| 35 |
+
logger.info(f"Downloaded {filename}.")
|
| 36 |
+
return filename
|
| 37 |
+
|
| 38 |
+
def load_realesrgan_x2():
|
| 39 |
+
if "realesrgan_x2" not in models:
|
| 40 |
+
logger.info("Loading RealESRGAN x2plus model...")
|
| 41 |
+
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
|
| 42 |
+
model_path = download_file(url, "RealESRGAN_x2plus.pth")
|
| 43 |
+
|
| 44 |
+
model = ModelLoader().load_from_file(model_path)
|
| 45 |
+
model.eval()
|
| 46 |
+
|
| 47 |
+
# Move to CPU (or CUDA if available, but we focus on CPU here)
|
| 48 |
+
device = torch.device("cpu")
|
| 49 |
+
model.to(device)
|
| 50 |
+
|
| 51 |
+
models["realesrgan_x2"] = model
|
| 52 |
+
logger.info("RealESRGAN x2plus loaded.")
|
| 53 |
+
return models["realesrgan_x2"]
|
| 54 |
+
|
| 55 |
+
def load_swin2sr_x2():
|
| 56 |
+
if "swin2sr_x2" not in models:
|
| 57 |
+
logger.info("Loading Swin2SR x2 model...")
|
| 58 |
+
model_id = "caidas/swin2SR-classical-sr-x2-64"
|
| 59 |
+
processor = AutoImageProcessor.from_pretrained(model_id)
|
| 60 |
+
model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
|
| 61 |
+
models["swin2sr_x2"] = (processor, model)
|
| 62 |
+
logger.info("Swin2SR x2 loaded.")
|
| 63 |
+
return models["swin2sr_x2"]
|
| 64 |
+
|
| 65 |
+
def load_sd_x4():
|
| 66 |
+
if "sd_x4" not in models:
|
| 67 |
+
logger.info("Loading Stable Diffusion x4 model (this might take a while)...")
|
| 68 |
+
model_id = "stabilityai/stable-diffusion-x4-upscaler"
|
| 69 |
+
pipe = StableDiffusionUpscalePipeline.from_pretrained(
|
| 70 |
+
model_id,
|
| 71 |
+
torch_dtype=torch.float32,
|
| 72 |
+
low_cpu_mem_usage=True
|
| 73 |
+
)
|
| 74 |
+
pipe.enable_attention_slicing("max")
|
| 75 |
+
pipe.enable_vae_tiling()
|
| 76 |
+
models["sd_x4"] = pipe
|
| 77 |
+
logger.info("Stable Diffusion x4 loaded.")
|
| 78 |
+
return models["sd_x4"]
|
| 79 |
+
|
| 80 |
+
def upscale_realesrgan(input_img):
|
| 81 |
+
model = load_realesrgan_x2()
|
| 82 |
+
|
| 83 |
+
# Convert PIL to Tensor
|
| 84 |
+
img_np = np.array(input_img).astype(np.float32) / 255.0
|
| 85 |
+
img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0)
|
| 86 |
+
|
| 87 |
+
with torch.no_grad():
|
| 88 |
+
output_tensor = model(img_tensor)
|
| 89 |
+
|
| 90 |
+
# Convert Tensor back to PIL
|
| 91 |
+
output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).numpy()
|
| 92 |
+
output_np = (output_np * 255.0).round().astype(np.uint8)
|
| 93 |
+
|
| 94 |
+
return Image.fromarray(output_np)
|
| 95 |
+
|
| 96 |
+
def upscale_swin2sr(input_img, scale=2):
|
| 97 |
+
processor, model = load_swin2sr_x2()
|
| 98 |
+
|
| 99 |
+
inputs = processor(images=input_img, return_tensors="pt")
|
| 100 |
|
| 101 |
+
with torch.no_grad():
|
| 102 |
+
outputs = model(**inputs)
|
| 103 |
+
|
| 104 |
+
output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
| 105 |
+
output = np.moveaxis(output, source=0, destination=-1)
|
| 106 |
+
output = (output * 255.0).round().astype(np.uint8)
|
| 107 |
+
|
| 108 |
+
return Image.fromarray(output)
|
| 109 |
+
|
| 110 |
+
def upscale_diffusion_cpu(input_img, prompt):
|
| 111 |
+
pipe = load_sd_x4()
|
| 112 |
+
|
| 113 |
+
# Resize input if too large to prevent OOM
|
| 114 |
+
max_size = 512
|
| 115 |
+
if max(input_img.size) > max_size:
|
| 116 |
+
ratio = max_size / max(input_img.size)
|
| 117 |
+
new_size = (int(input_img.size[0] * ratio), int(input_img.size[1] * ratio))
|
| 118 |
+
input_img = input_img.resize(new_size, Image.Resampling.LANCZOS)
|
| 119 |
+
logger.warning(f"Resized input to {new_size} to prevent OOM")
|
| 120 |
+
|
| 121 |
generator = torch.manual_seed(42)
|
| 122 |
output = pipe(
|
| 123 |
prompt=prompt,
|
| 124 |
image=input_img,
|
| 125 |
+
num_inference_steps=20,
|
| 126 |
guidance_scale=7.0,
|
| 127 |
generator=generator
|
| 128 |
).images[0]
|
| 129 |
|
| 130 |
return output
|
| 131 |
|
| 132 |
+
def process_image(input_img, model_name, prompt):
|
| 133 |
+
if input_img is None:
|
| 134 |
+
return None, get_logs()
|
| 135 |
+
|
| 136 |
+
logger.info(f"Processing image with {model_name}...")
|
| 137 |
+
try:
|
| 138 |
+
if model_name == "RealESRGAN x2":
|
| 139 |
+
output = upscale_realesrgan(input_img)
|
| 140 |
+
elif model_name == "Swin2SR x2":
|
| 141 |
+
output = upscale_swin2sr(input_img, scale=2)
|
| 142 |
+
elif model_name == "Stable Diffusion x4":
|
| 143 |
+
output = upscale_diffusion_cpu(input_img, prompt)
|
| 144 |
+
else:
|
| 145 |
+
output = input_img # Fallback
|
| 146 |
+
|
| 147 |
+
gc.collect()
|
| 148 |
+
logger.info("Processing complete.")
|
| 149 |
+
return output, get_logs()
|
| 150 |
+
except Exception as e:
|
| 151 |
+
logger.error(f"Error: {str(e)}")
|
| 152 |
+
return None, get_logs()
|
| 153 |
+
|
| 154 |
desc = """
|
| 155 |
+
### Multi-Model Upscaler
|
| 156 |
+
Select a model to upscale your image.
|
| 157 |
+
* **RealESRGAN x2**: Very fast, sharp results. Best for general photos.
|
| 158 |
+
* **Swin2SR x2**: Accurate, good for compressed images. Slower than RealESRGAN.
|
| 159 |
+
* **Stable Diffusion x4**: Slow, creative, high memory usage. Adds details but may hallucinate.
|
| 160 |
"""
|
| 161 |
|
| 162 |
+
with gr.Blocks(title="Universal Upscaler") as iface:
|
| 163 |
+
gr.Markdown(desc)
|
| 164 |
+
|
| 165 |
+
with gr.Row():
|
| 166 |
+
with gr.Column():
|
| 167 |
+
input_image = gr.Image(type="pil", label="Input Image")
|
| 168 |
+
model_selector = gr.Dropdown(
|
| 169 |
+
choices=["RealESRGAN x2", "Swin2SR x2", "Stable Diffusion x4"],
|
| 170 |
+
value="RealESRGAN x2",
|
| 171 |
+
label="Select Model"
|
| 172 |
+
)
|
| 173 |
+
prompt_input = gr.Textbox(
|
| 174 |
+
label="Prompt (for Stable Diffusion only)",
|
| 175 |
+
value="highly detailed, 4k, sharp"
|
| 176 |
+
)
|
| 177 |
+
submit_btn = gr.Button("Upscale")
|
| 178 |
+
|
| 179 |
+
with gr.Column():
|
| 180 |
+
output_image = gr.Image(type="pil", label="Upscaled Image")
|
| 181 |
+
logs_output = gr.TextArea(label="Logs", interactive=False)
|
| 182 |
+
|
| 183 |
+
submit_btn.click(
|
| 184 |
+
fn=process_image,
|
| 185 |
+
inputs=[input_image, model_selector, prompt_input],
|
| 186 |
+
outputs=[output_image, logs_output]
|
| 187 |
+
)
|
| 188 |
|
| 189 |
iface.launch()
|
requirements.txt
CHANGED
|
@@ -4,4 +4,7 @@ transformers
|
|
| 4 |
accelerate
|
| 5 |
scipy
|
| 6 |
pillow
|
| 7 |
-
gradio
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
accelerate
|
| 5 |
scipy
|
| 6 |
pillow
|
| 7 |
+
gradio
|
| 8 |
+
opencv-python
|
| 9 |
+
spandrel
|
| 10 |
+
requests
|