jgitsolutions committed on
Commit
b2759ab
·
verified ·
1 Parent(s): 16c0cde

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +166 -28
  2. requirements.txt +4 -1
app.py CHANGED
@@ -1,51 +1,189 @@
1
  import gradio as gr
2
  import torch
3
  from diffusers import StableDiffusionUpscalePipeline
 
 
 
 
 
 
 
 
 
4
 
5
- # Load pipeline efficiently for CPU
6
- model_id = "stabilityai/stable-diffusion-x4-upscaler"
7
- pipe = StableDiffusionUpscalePipeline.from_pretrained(
8
- model_id,
9
- torch_dtype=torch.float32
10
- )
 
11
 
12
- # 1. SLICING: Cuts attention computation into chunks to save RAM
13
- pipe.enable_attention_slicing("max")
14
 
15
- # 2. OFFLOADING: Moves unused model parts to RAM (critical for low VRAM/CPU)
16
- # pipe.enable_sequential_cpu_offload() # Only works with GPU to save VRAM. On CPU-only machines, this is not needed/supported.
17
 
18
- def upscale_diffusion_cpu(input_img, prompt="high quality, detailed"):
19
- # Resize for the specific pipeline requirements if needed,
20
- # but x4 upscaler handles low-res inputs naturally.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # CPU Inference is slow, so we limit steps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  generator = torch.manual_seed(42)
24
  output = pipe(
25
  prompt=prompt,
26
  image=input_img,
27
- num_inference_steps=20, # Lower steps for CPU speed (usually 50+)
28
  guidance_scale=7.0,
29
  generator=generator
30
  ).images[0]
31
 
32
  return output
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  desc = """
35
- ### Memory Efficient Diffusion Upscaling (CPU)
36
- This demo uses **Attention Slicing** and **Sequential Offloading** to run a heavy Latent Diffusion model on CPU.
37
- *Note: Diffusion on CPU is significantly slower than CNNs (EDSR) but generates hallucinations for missing details.*
 
 
38
  """
39
 
40
- iface = gr.Interface(
41
- fn=upscale_diffusion_cpu,
42
- inputs=[
43
- gr.Image(type="pil", label="Low Res Input"),
44
- gr.Textbox(label="Prompt (Optional)", value="highly detailed, 4k, sharp")
45
- ],
46
- outputs=gr.Image(type="pil", label="Diffusion Upscaled"),
47
- title="Memory Efficient Diffusion Upscaler",
48
- description=desc
49
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  iface.launch()
 
1
  import gradio as gr
2
  import torch
3
  from diffusers import StableDiffusionUpscalePipeline
4
+ from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
5
+ import gc
6
+ from PIL import Image
7
+ import numpy as np
8
+ import logging
9
+ import io
10
+ import os
11
+ import requests
12
+ from spandrel import ModelLoader
13
 
14
# In-memory log capture: every record is appended to a string buffer so the
# UI can show the log text next to the result.
log_capture_string = io.StringIO()

# The handler is attached to the root logger, so INFO-and-above records from
# any logger that propagates to root are written to the buffer too.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

ch = logging.StreamHandler(log_capture_string)
ch.setLevel(logging.INFO)
logger.addHandler(ch)
21
 
22
def get_logs():
    """Return all text written to the in-memory log buffer so far."""
    captured = log_capture_string.getvalue()
    return captured
24
 
25
# Process-wide cache of loaded models, keyed by a short model name. The
# load_* functions fill it on first use so each model is loaded only once.
models = dict()
27
 
28
def download_file(url, filename):
    """Download ``url`` to ``filename`` unless the file already exists.

    The response body is streamed to a temporary ``<filename>.part`` file and
    moved into place only after the transfer has finished, so a failed or
    interrupted download never leaves a truncated file that a later call
    would mistake for a complete one.

    Args:
        url: HTTP(S) address to fetch.
        filename: Destination path on the local filesystem.

    Returns:
        ``filename``, whether it was just downloaded or was already present.

    Raises:
        requests.RequestException: On connection errors, timeouts, or a
            non-success HTTP status.
        OSError: If the destination cannot be written.
    """
    if os.path.exists(filename):
        return filename

    logger.info(f"Downloading {filename}...")
    partial = f"{filename}.part"
    try:
        # (connect, read) timeouts so a stalled server cannot hang the app;
        # the context manager releases the connection when done.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            # Without this check an HTML error page would be written to disk
            # and later loaded as if it were the requested file.
            response.raise_for_status()
            with open(partial, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(partial, filename)
    finally:
        # After a successful os.replace the partial file is gone; this only
        # removes leftovers from a failed transfer.
        if os.path.exists(partial):
            os.remove(partial)

    logger.info(f"Downloaded {filename}.")
    return filename
37
+
38
def load_realesrgan_x2():
    """Return the RealESRGAN x2plus model, loading it on first use.

    The first call downloads the weights file (if it is not already on disk),
    loads it through ``ModelLoader``, switches it to eval mode, moves it to
    the CPU and stores it in ``models``. Later calls return the stored object.
    """
    if "realesrgan_x2" in models:
        return models["realesrgan_x2"]

    logger.info("Loading RealESRGAN x2plus model...")
    weights_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
    weights_path = download_file(weights_url, "RealESRGAN_x2plus.pth")

    net = ModelLoader().load_from_file(weights_path)
    net.eval()
    # This app always runs the model on the CPU.
    net.to(torch.device("cpu"))

    models["realesrgan_x2"] = net
    logger.info("RealESRGAN x2plus loaded.")
    return net
54
+
55
def load_swin2sr_x2():
    """Return the cached ``(processor, model)`` pair for Swin2SR x2.

    On the first call both objects are created from the
    ``caidas/swin2SR-classical-sr-x2-64`` checkpoint and stored in ``models``;
    later calls return the stored tuple.
    """
    key = "swin2sr_x2"
    if key in models:
        return models[key]

    logger.info("Loading Swin2SR x2 model...")
    model_id = "caidas/swin2SR-classical-sr-x2-64"
    pair = (
        AutoImageProcessor.from_pretrained(model_id),
        Swin2SRForImageSuperResolution.from_pretrained(model_id),
    )
    models[key] = pair
    logger.info("Swin2SR x2 loaded.")
    return pair
64
+
65
def load_sd_x4():
    """Return the cached Stable Diffusion x4 upscale pipeline.

    On the first call the pipeline is created from
    ``stabilityai/stable-diffusion-x4-upscaler`` in float32, attention slicing
    and VAE tiling are switched on, and the pipeline is stored in ``models``.
    Later calls return the stored pipeline.
    """
    key = "sd_x4"
    if key in models:
        return models[key]

    logger.info("Loading Stable Diffusion x4 model (this might take a while)...")
    upscaler = StableDiffusionUpscalePipeline.from_pretrained(
        "stabilityai/stable-diffusion-x4-upscaler",
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    # Attention slicing computes attention in chunks to save RAM.
    upscaler.enable_attention_slicing("max")
    # NOTE(review): VAE tiling is presumably also enabled to lower peak
    # memory -- confirm against the diffusers documentation.
    upscaler.enable_vae_tiling()

    models[key] = upscaler
    logger.info("Stable Diffusion x4 loaded.")
    return upscaler
79
+
80
def upscale_realesrgan(input_img):
    """Upscale a PIL image with the RealESRGAN x2plus model.

    Args:
        input_img: PIL image in any mode; it is converted to RGB first.

    Returns:
        A new RGB ``PIL.Image`` holding the model output.
    """
    model = load_realesrgan_x2()

    # Normalise the mode up front: a grayscale image yields a 2-D array, on
    # which permute(2, 0, 1) raises, and RGBA/palette images would be passed
    # on with a channel layout other than plain RGB.
    rgb_img = input_img.convert("RGB")

    # HWC uint8 [0, 255] -> 1xCxHxW float32 [0, 1]
    img_np = np.array(rgb_img).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0)

    with torch.no_grad():
        output_tensor = model(img_tensor)

    # 1xCxHxW float [0, 1] -> HWC uint8 [0, 255]
    output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).numpy()
    output_np = (output_np * 255.0).round().astype(np.uint8)

    return Image.fromarray(output_np)
95
+
96
def upscale_swin2sr(input_img, scale=2):
    """Upscale a PIL image with the Swin2SR x2 model.

    Args:
        input_img: PIL image in any mode; it is converted to RGB first.
        scale: Upscaling factor of the loaded checkpoint. It is used to crop
            the result to ``scale`` times the input size.

    Returns:
        A new RGB ``PIL.Image`` of at most ``scale`` times the input width
        and height.
    """
    processor, model = load_swin2sr_x2()

    rgb_img = input_img.convert("RGB")
    width, height = rgb_img.size

    inputs = processor(images=rgb_img, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    # Drop only the batch axis; a bare squeeze() would also remove a spatial
    # axis of size 1.
    output = outputs.reconstruction.detach().squeeze(0).float().cpu().clamp_(0, 1).numpy()
    output = np.moveaxis(output, source=0, destination=-1)

    # The image processor may pad the input at the bottom/right so its sides
    # are a multiple of the model's window size; cut the upscaled padding
    # away so the result matches the original aspect ratio.
    output = output[: height * scale, : width * scale]

    output = (output * 255.0).round().astype(np.uint8)

    return Image.fromarray(np.ascontiguousarray(output))
109
+
110
def upscale_diffusion_cpu(input_img, prompt):
    """Upscale a PIL image with the Stable Diffusion x4 upscale pipeline.

    Inputs whose longest side exceeds 512 px are shrunk first (keeping the
    aspect ratio) to limit memory use. Sampling uses a fixed seed, so the
    same input and prompt give the same output.

    Args:
        input_img: PIL image in any mode; it is converted to RGB first.
        prompt: Text prompt passed to the pipeline.

    Returns:
        The first image produced by the pipeline.
    """
    pipe = load_sd_x4()

    if input_img.mode != "RGB":
        # Alpha/palette/grayscale inputs would otherwise reach the pipeline
        # with an unexpected channel count.
        input_img = input_img.convert("RGB")

    # Resize input if too large to prevent OOM
    max_size = 512
    longest_side = max(input_img.size)
    if longest_side > max_size:
        ratio = max_size / longest_side
        # Clamp to 1 px: for extreme aspect ratios int() can round the short
        # side down to 0, which resize() rejects.
        new_size = tuple(max(1, int(side * ratio)) for side in input_img.size)
        input_img = input_img.resize(new_size, Image.Resampling.LANCZOS)
        logger.warning(f"Resized input to {new_size} to prevent OOM")

    # A dedicated generator keeps the result reproducible without reseeding
    # torch's global RNG for the rest of the process.
    generator = torch.Generator(device="cpu").manual_seed(42)
    output = pipe(
        prompt=prompt,
        image=input_img,
        num_inference_steps=20,
        guidance_scale=7.0,
        generator=generator,
    ).images[0]

    return output
131
 
132
def process_image(input_img, model_name, prompt):
    """Run the selected upscaler and return the result with the log text.

    Args:
        input_img: PIL image from the UI, or ``None`` when nothing was
            uploaded.
        model_name: One of ``"RealESRGAN x2"``, ``"Swin2SR x2"`` or
            ``"Stable Diffusion x4"``. Any other value returns the input
            image unchanged.
        prompt: Text prompt; only used by the Stable Diffusion model.

    Returns:
        ``(image, logs)``. ``image`` is ``None`` when there was no input or
        when processing failed; ``logs`` is the captured log text.
    """
    if input_img is None:
        return None, get_logs()

    logger.info(f"Processing image with {model_name}...")
    try:
        if model_name == "RealESRGAN x2":
            output = upscale_realesrgan(input_img)
        elif model_name == "Swin2SR x2":
            output = upscale_swin2sr(input_img, scale=2)
        elif model_name == "Stable Diffusion x4":
            output = upscale_diffusion_cpu(input_img, prompt)
        else:
            # Fallback: keep returning the input, but say so in the log
            # instead of silently reporting success.
            logger.warning(f"Unknown model {model_name!r}; returning input unchanged.")
            output = input_img

        logger.info("Processing complete.")
        return output, get_logs()
    except Exception as e:
        # logger.exception records the traceback too, so failures can be
        # diagnosed from the log panel.
        logger.exception(f"Error: {str(e)}")
        return None, get_logs()
    finally:
        # Release intermediate buffers on both the success and failure paths.
        gc.collect()
153
+
154
# Markdown shown at the top of the page.
desc = """
### Multi-Model Upscaler
Select a model to upscale your image.
* **RealESRGAN x2**: Very fast, sharp results. Best for general photos.
* **Swin2SR x2**: Accurate, good for compressed images. Slower than RealESRGAN.
* **Stable Diffusion x4**: Slow, creative, high memory usage. Adds details but may hallucinate.
"""

with gr.Blocks(title="Universal Upscaler") as iface:
    gr.Markdown(desc)

    with gr.Row():
        # Left column: image, model choice, prompt and the trigger button.
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            model_selector = gr.Dropdown(
                label="Select Model",
                choices=["RealESRGAN x2", "Swin2SR x2", "Stable Diffusion x4"],
                value="RealESRGAN x2",
            )
            prompt_input = gr.Textbox(
                value="highly detailed, 4k, sharp",
                label="Prompt (for Stable Diffusion only)",
            )
            submit_btn = gr.Button("Upscale")

        # Right column: result image and the captured log text.
        with gr.Column():
            output_image = gr.Image(label="Upscaled Image", type="pil")
            logs_output = gr.TextArea(interactive=False, label="Logs")

    # process_image returns (image, logs), matching the two outputs below.
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, model_selector, prompt_input],
        outputs=[output_image, logs_output],
    )

iface.launch()
requirements.txt CHANGED
@@ -4,4 +4,7 @@ transformers
4
  accelerate
5
  scipy
6
  pillow
7
- gradio
 
 
 
 
4
  accelerate
5
  scipy
6
  pillow
7
+ gradio
8
+ opencv-python
9
+ spandrel
10
+ requests