jgitsolutions committed on
Commit
e58a107
·
verified ·
1 Parent(s): b2759ab

Upload 3 files

Browse files
Files changed (3) hide show
  1. RealESRGAN_x2plus.pth +3 -0
  2. app.py +610 -141
  3. requirements.txt +8 -1
RealESRGAN_x2plus.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49fafd45f8fd7aa8d31ab2a22d14d91b536c34494a5cfe31eb5d89c2fa266abb
3
+ size 67061725
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import StableDiffusionUpscalePipeline
4
  from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
5
  import gc
6
  from PIL import Image
@@ -10,180 +10,649 @@ import io
10
  import os
11
  import requests
12
  from spandrel import ModelLoader
 
 
 
 
 
13
 
14
- # Setup logging
15
- log_capture_string = io.StringIO()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  ch = logging.StreamHandler(log_capture_string)
17
  ch.setLevel(logging.INFO)
18
- logger = logging.getLogger()
 
 
 
19
  logger.setLevel(logging.INFO)
20
  logger.addHandler(ch)
21
 
22
- def get_logs():
 
23
  return log_capture_string.getvalue()
24
 
25
- # Global models cache
26
- models = {}
27
-
28
- def download_file(url, filename):
29
- if not os.path.exists(filename):
30
- logger.info(f"Downloading {filename}...")
31
- response = requests.get(url, stream=True)
32
- with open(filename, 'wb') as f:
33
- for chunk in response.iter_content(chunk_size=8192):
34
- f.write(chunk)
35
- logger.info(f"Downloaded {filename}.")
36
- return filename
37
-
38
- def load_realesrgan_x2():
39
- if "realesrgan_x2" not in models:
40
- logger.info("Loading RealESRGAN x2plus model...")
41
- url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
42
- model_path = download_file(url, "RealESRGAN_x2plus.pth")
43
-
44
- model = ModelLoader().load_from_file(model_path)
45
- model.eval()
46
-
47
- # Move to CPU (or CUDA if available, but we focus on CPU here)
48
- device = torch.device("cpu")
49
- model.to(device)
50
-
51
- models["realesrgan_x2"] = model
52
- logger.info("RealESRGAN x2plus loaded.")
53
- return models["realesrgan_x2"]
54
-
55
- def load_swin2sr_x2():
56
- if "swin2sr_x2" not in models:
57
- logger.info("Loading Swin2SR x2 model...")
58
- model_id = "caidas/swin2SR-classical-sr-x2-64"
59
- processor = AutoImageProcessor.from_pretrained(model_id)
60
- model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
61
- models["swin2sr_x2"] = (processor, model)
62
- logger.info("Swin2SR x2 loaded.")
63
- return models["swin2sr_x2"]
64
-
65
- def load_sd_x4():
66
- if "sd_x4" not in models:
67
- logger.info("Loading Stable Diffusion x4 model (this might take a while)...")
68
- model_id = "stabilityai/stable-diffusion-x4-upscaler"
69
- pipe = StableDiffusionUpscalePipeline.from_pretrained(
70
- model_id,
71
- torch_dtype=torch.float32,
72
- low_cpu_mem_usage=True
73
- )
74
- pipe.enable_attention_slicing("max")
75
- pipe.enable_vae_tiling()
76
- models["sd_x4"] = pipe
77
- logger.info("Stable Diffusion x4 loaded.")
78
- return models["sd_x4"]
79
-
80
- def upscale_realesrgan(input_img):
81
- model = load_realesrgan_x2()
82
-
83
- # Convert PIL to Tensor
84
- img_np = np.array(input_img).astype(np.float32) / 255.0
85
- img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0)
86
-
87
- with torch.no_grad():
88
- output_tensor = model(img_tensor)
89
 
90
- # Convert Tensor back to PIL
91
- output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).numpy()
92
- output_np = (output_np * 255.0).round().astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
93
 
94
- return Image.fromarray(output_np)
 
 
 
 
 
 
95
 
96
- def upscale_swin2sr(input_img, scale=2):
97
- processor, model = load_swin2sr_x2()
 
 
 
 
 
98
 
99
- inputs = processor(images=input_img, return_tensors="pt")
 
 
100
 
101
- with torch.no_grad():
102
- outputs = model(**inputs)
103
 
104
- output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
105
- output = np.moveaxis(output, source=0, destination=-1)
106
- output = (output * 255.0).round().astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
- return Image.fromarray(output)
109
 
110
- def upscale_diffusion_cpu(input_img, prompt):
111
- pipe = load_sd_x4()
112
-
113
- # Resize input if too large to prevent OOM
114
- max_size = 512
115
- if max(input_img.size) > max_size:
116
- ratio = max_size / max(input_img.size)
117
- new_size = (int(input_img.size[0] * ratio), int(input_img.size[1] * ratio))
118
- input_img = input_img.resize(new_size, Image.Resampling.LANCZOS)
119
- logger.warning(f"Resized input to {new_size} to prevent OOM")
120
-
121
- generator = torch.manual_seed(42)
122
- output = pipe(
123
- prompt=prompt,
124
- image=input_img,
125
- num_inference_steps=20,
126
- guidance_scale=7.0,
127
- generator=generator
128
- ).images[0]
129
 
130
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- def process_image(input_img, model_name, prompt):
 
 
 
 
 
 
 
 
 
 
133
  if input_img is None:
134
- return None, get_logs()
135
 
136
- logger.info(f"Processing image with {model_name}...")
137
  try:
138
- if model_name == "RealESRGAN x2":
139
- output = upscale_realesrgan(input_img)
140
- elif model_name == "Swin2SR x2":
141
- output = upscale_swin2sr(input_img, scale=2)
142
- elif model_name == "Stable Diffusion x4":
143
- output = upscale_diffusion_cpu(input_img, prompt)
144
- else:
145
- output = input_img # Fallback
146
-
147
  gc.collect()
148
- logger.info("Processing complete.")
149
- return output, get_logs()
150
  except Exception as e:
151
- logger.error(f"Error: {str(e)}")
152
- return None, get_logs()
 
 
 
 
 
 
153
 
 
154
  desc = """
155
- ### Multi-Model Upscaler
156
- Select a model to upscale your image.
157
- * **RealESRGAN x2**: Very fast, sharp results. Best for general photos.
158
- * **Swin2SR x2**: Accurate, good for compressed images. Slower than RealESRGAN.
159
- * **Stable Diffusion x4**: Slow, creative, high memory usage. Adds details but may hallucinate.
 
 
160
  """
161
 
162
- with gr.Blocks(title="Universal Upscaler") as iface:
163
  gr.Markdown(desc)
164
 
165
  with gr.Row():
166
- with gr.Column():
167
  input_image = gr.Image(type="pil", label="Input Image")
168
- model_selector = gr.Dropdown(
169
- choices=["RealESRGAN x2", "Swin2SR x2", "Stable Diffusion x4"],
170
- value="RealESRGAN x2",
171
- label="Select Model"
172
- )
173
- prompt_input = gr.Textbox(
174
- label="Prompt (for Stable Diffusion only)",
175
- value="highly detailed, 4k, sharp"
176
- )
177
- submit_btn = gr.Button("Upscale")
178
-
179
- with gr.Column():
180
- output_image = gr.Image(type="pil", label="Upscaled Image")
181
- logs_output = gr.TextArea(label="Logs", interactive=False)
182
-
 
 
 
 
 
 
 
 
 
 
183
  submit_btn.click(
184
  fn=process_image,
185
  inputs=[input_image, model_selector, prompt_input],
186
- outputs=[output_image, logs_output]
187
  )
 
 
 
 
 
 
 
 
 
188
 
189
  iface.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
4
  from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
5
  import gc
6
  from PIL import Image
 
10
  import os
11
  import requests
12
  from spandrel import ModelLoader
13
+ from abc import ABC, abstractmethod
14
+ from typing import Optional, Tuple, Dict
15
+ import psutil
16
+ import time
17
+ import traceback
18
 
19
+ # --- Configuration ---
20
class Config:
    """Configuration settings for the application."""

    # Local directory where downloaded weight files are stored.
    MODEL_DIR = "weights"
    REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
    REALESRGAN_FILENAME = "RealESRGAN_x2plus.pth"
    SWIN2SR_ID = "caidas/swin2SR-classical-sr-x2-64"
    SD_ID = "stabilityai/stable-diffusion-x4-upscaler"

    # SOTA Models (2025)
    SPAN_URL = "https://huggingface.co/Phips/2xNomosUni_span_multijpg/resolve/main/2xNomosUni_span_multijpg.safetensors"
    SPAN_FILENAME = "2xNomosUni_span_multijpg.safetensors"
    HATS_URL = "https://huggingface.co/Phips/4xNomos8kSCHAT-S/resolve/main/4xNomos8kSCHAT-S.safetensors"
    HATS_FILENAME = "4xNomos8kSCHAT-S.safetensors"

    MAX_IMAGE_SIZE_SD = 512  # Max dimension for SD input to prevent OOM
    DEVICE = "cpu"  # Force CPU for this demo, can be "cuda" if available

    @staticmethod
    def ensure_model_dir():
        """Create ``Config.MODEL_DIR`` (and any missing parents) if needed.

        ``exist_ok=True`` makes the call idempotent and avoids the
        check-then-create race of testing ``os.path.exists`` first.
        """
        os.makedirs(Config.MODEL_DIR, exist_ok=True)
41
+
42
+ # --- Logging Setup ---
43
class LogCapture(io.StringIO):
    """In-memory text buffer that receives every formatted log record."""


# Single module-wide buffer; get_logs() reads it back for the UI.
log_capture_string = LogCapture()

# Handler that writes INFO-and-above records into the buffer.
ch = logging.StreamHandler(log_capture_string)
ch.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)

# Named application logger (not the root logger) with the capture handler attached.
logger = logging.getLogger("UpscalerApp")
logger.setLevel(logging.INFO)
logger.addHandler(ch)
56
 
57
def get_logs() -> str:
    """Return the full text accumulated in the log capture buffer."""
    captured = log_capture_string.getvalue()
    return captured
60
 
61
+ # --- System Monitoring ---
62
def get_system_usage() -> str:
    """Return a one-line summary of current CPU and RAM usage.

    Memory statistics are read once so that the percentage and the
    used-GB figure describe the same snapshot.
    """
    cpu_percent = psutil.cpu_percent()
    memory = psutil.virtual_memory()
    ram_percent = memory.percent
    ram_used_gb = memory.used / (1024 ** 3)
    return f"CPU: {cpu_percent}% | RAM: {ram_percent}% ({ram_used_gb:.1f} GB used)"
68
+
69
+ # --- Abstract Base Class for Models ---
70
class UpscalerStrategy(ABC):
    """Interface shared by every upscaling backend.

    Subclasses provide ``load`` and ``upscale``; ``unload`` is implemented
    here and drops whatever object is stored in ``self.model``.
    """

    def __init__(self):
        self.name = "Unknown"
        self.model = None

    @abstractmethod
    def load(self) -> None:
        """Load the model into memory."""

    @abstractmethod
    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale the given image."""

    def unload(self) -> None:
        """Release the loaded model, if any, and run the garbage collector."""
        if self.model is None:
            return
        # Dropping the only reference lets gc.collect() reclaim the weights.
        self.model = None
        gc.collect()
        logger.info(f"Unloaded {self.name}")
94
 
95
+ # --- Helper Functions for Optimization ---
96
def manual_tile_upscale(model, img_tensor, tile_size=256, tile_pad=10, scale=2):
    """
    Upscale ``img_tensor`` tile by tile to bound peak memory.

    The image is split into ``tile_size`` x ``tile_size`` tiles. Each tile is
    extended by up to ``tile_pad`` pixels of surrounding context (a halo),
    passed through ``model``, and the halo is cropped off again before the
    result is written into the output tensor.

    Args:
        model: Callable mapping a ``(B, C, h, w)`` tensor to a
            ``(B, C, h * scale, w * scale)`` tensor.
        img_tensor: 4-D input tensor indexed as (batch, channel, row, column).
        tile_size: Edge length of a tile in input pixels; must be positive.
        tile_pad: Halo width in input pixels; must not be negative.
        scale: Integer upscaling factor of ``model``.

    Returns:
        Tensor of shape ``(B, C, H * scale, W * scale)`` on the same device
        and with the same dtype as ``img_tensor``.

    Raises:
        ValueError: If ``tile_size``/``tile_pad`` are invalid, or if the model
            output for a tile does not have the size implied by ``scale``.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")
    if tile_pad < 0:
        raise ValueError(f"tile_pad must not be negative, got {tile_pad}")

    batch, channels, height, width = img_tensor.shape

    output = torch.zeros(batch, channels, height * scale, width * scale,
                         device=img_tensor.device, dtype=img_tensor.dtype)

    # A single no_grad scope covers every tile; no autograd graph is needed.
    with torch.no_grad():
        for top in range(0, height, tile_size):
            bottom = min(top + tile_size, height)
            # Halo for context, clamped to the image borders.
            top_pad = max(0, top - tile_pad)
            bottom_pad = min(height, bottom + tile_pad)

            for left in range(0, width, tile_size):
                right = min(left + tile_size, width)
                left_pad = max(0, left - tile_pad)
                right_pad = min(width, right + tile_pad)

                tile = img_tensor[:, :, top_pad:bottom_pad, left_pad:right_pad]
                tile_out = model(tile)

                # Without this check a model with a different real scale would
                # be cropped silently into a plausible-looking but wrong image.
                expected = ((bottom_pad - top_pad) * scale,
                            (right_pad - left_pad) * scale)
                if tuple(tile_out.shape[-2:]) != expected:
                    raise ValueError(
                        f"Model output size {tuple(tile_out.shape[-2:])} does not "
                        f"match expected {expected} for scale={scale}"
                    )

                # Offsets of the un-padded tile inside the padded model output.
                crop_top = (top - top_pad) * scale
                crop_left = (left - left_pad) * scale
                crop_bottom = crop_top + (bottom - top) * scale
                crop_right = crop_left + (right - left) * scale

                output[:, :, top * scale:bottom * scale, left * scale:right * scale] = \
                    tile_out[:, :, crop_top:crop_bottom, crop_left:crop_right]

    return output
142
 
143
def select_tile_config(height, width):
    """
    Pick tile size and halo padding from the image resolution.

    The pixel count is expressed in units of 1024 * 1024 pixels; larger
    images get smaller tiles and wider padding.

    Returns:
        Dict with integer entries ``'tile'`` and ``'tile_pad'``.
    """
    megapixels = (height * width) / (1024 ** 2)

    # (exclusive upper bound, tile, tile_pad), checked in ascending order.
    tiers = (
        (2, 512, 10),   # < 1080p
        (6, 384, 15),   # < 4K
        (16, 256, 20),  # < 8K
    )
    for upper_bound, tile, tile_pad in tiers:
        if megapixels < upper_bound:
            return {'tile': tile, 'tile_pad': tile_pad}

    # 8K+
    return {'tile': 128, 'tile_pad': 25}
157
+
158
+ # --- Concrete Implementations ---
159
+
160
class RealESRGANStrategy(UpscalerStrategy):
    """RealESRGAN x2plus backend; weights are loaded through ``ModelLoader``."""

    def __init__(self):
        super().__init__()
        self.name = "RealESRGAN x2"
        # True once torch.compile has been attempted (successfully or not).
        # NOTE(review): unload() does not reset this flag, so a model that is
        # reloaded after unload() is not compiled again — confirm intended.
        self.compiled = False

    def _ensure_weights(self) -> str:
        """Return the local weights path, downloading the file if missing."""
        Config.ensure_model_dir()
        model_path = os.path.join(Config.MODEL_DIR, Config.REALESRGAN_FILENAME)
        if os.path.exists(model_path):
            return model_path

        logger.info(f"Downloading {Config.REALESRGAN_FILENAME}...")
        # Stream into a temporary name and rename on success, so that an
        # interrupted download is never mistaken for a complete weights file.
        partial_path = model_path + ".part"
        try:
            with requests.get(Config.REALESRGAN_URL, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(partial_path, model_path)
            logger.info("Download complete.")
        except Exception as e:
            logger.error(f"Failed to download model: {e}")
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        return model_path

    def _maybe_compile(self) -> None:
        """Wrap the model with torch.compile unless the platform rules it out."""
        if self.compiled:
            return
        try:
            # 'reduce-overhead' uses CUDA graphs, so only use it on CUDA
            if Config.DEVICE == 'cuda':
                self.model = torch.compile(self.model, mode='reduce-overhead')
                logger.info("✓ torch.compile enabled (reduce-overhead mode)")
            elif os.name == 'nt' and Config.DEVICE == 'cpu':
                # Windows requires MSVC for Inductor (default cpu backend);
                # skip to avoid a "Compiler: cl is not found" error.
                logger.info("ℹ Skipping torch.compile on Windows CPU to avoid MSVC requirement.")
            elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                # Skip compilation on weak CPUs to avoid long startup times.
                logger.info("ℹ Skipping torch.compile on low-core CPU to prevent timeout.")
            else:
                self.model = torch.compile(self.model)
                logger.info("✓ torch.compile enabled (default mode)")
        except Exception as e:
            logger.warning(f"⚠ torch.compile not available or failed: {e}")
        # Mark as tried in every case so the attempt is not repeated.
        self.compiled = True

    def load(self) -> None:
        """Download (if needed) and load the model; no-op when already loaded."""
        if self.model is not None:
            return

        logger.info(f"Loading {self.name}...")
        model_path = self._ensure_weights()

        try:
            self.model = ModelLoader().load_from_file(model_path)
            self.model.eval()
            self.model.to(Config.DEVICE)
            self._maybe_compile()
            logger.info(f"{self.name} loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model architecture: {e}")
            # Do not keep a half-initialised model around; a retry reloads it.
            self.model = None
            raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale ``image`` by 2x and return the result as an RGB PIL image."""
        if self.model is None:
            self.load()

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.time()

        # The tensor conversion below needs a 3-channel H x W x C array.
        if image.mode != "RGB":
            image = image.convert("RGB")
        img_np = np.array(image).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)

        # Optimization: Dynamic Tiling
        h, w = img_np.shape[:2]
        tile_config = select_tile_config(h, w)
        logger.info(f"Using tile config: {tile_config}")

        # Optimization: Mixed Precision (AMP): float16 on CUDA, bfloat16 otherwise.
        dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16

        # no_grad covers the fallback too: without it the output would track
        # gradients and the later .numpy() call would fail.
        with torch.no_grad():
            try:
                with torch.autocast(device_type=Config.DEVICE, dtype=dtype):
                    if tile_config['tile'] > 0:
                        output_tensor = manual_tile_upscale(
                            self.model,
                            img_tensor,
                            tile_size=tile_config['tile'],
                            tile_pad=tile_config['tile_pad'],
                            scale=2
                        )
                    else:
                        output_tensor = self.model(img_tensor)  # type: ignore
            except Exception as e:
                logger.warning(f"AMP/Tiling failed, falling back to standard FP32: {e}")
                # Fallback to standard execution on the whole image.
                output_tensor = self.model(img_tensor)  # type: ignore

        output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).float().cpu().numpy()
        output_np = (output_np * 255.0).round().astype(np.uint8)

        elapsed = time.time() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")

        # Benchmark info; guard against a zero reading from a coarse clock.
        output_megapixels = (output_np.shape[0] * output_np.shape[1]) / (1024 ** 2)
        throughput = output_megapixels / elapsed if elapsed > 0 else float("inf")
        logger.info(f"Speed: {throughput:.2f} MP/s")

        return Image.fromarray(output_np)
268
+
269
class Swin2SRStrategy(UpscalerStrategy):
    """Swin2SR x2 backend using the Hugging Face processor/model pair."""

    def __init__(self):
        super().__init__()
        self.name = "Swin2SR x2"
        self.processor = None

    def load(self) -> None:
        """Load processor and model; no-op when a model is already loaded."""
        if self.model is not None:
            return

        logger.info(f"Loading {self.name}...")
        try:
            self.processor = AutoImageProcessor.from_pretrained(Config.SWIN2SR_ID)
            model = Swin2SRForImageSuperResolution.from_pretrained(Config.SWIN2SR_ID)
            self.model = model.to(Config.DEVICE)  # type: ignore
            logger.info(f"{self.name} loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load Swin2SR: {e}")
            raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale ``image`` and return a PIL image cropped to the scaled input size.

        Raises:
            ValueError: If the processor is still missing after ``load()``.
        """
        if self.model is None or self.processor is None:
            self.load()

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.time()

        if self.processor is None:
            raise ValueError("Processor not loaded")

        if image.mode != "RGB":
            image = image.convert("RGB")
        in_w, in_h = image.size

        inputs = self.processor(images=image, return_tensors="pt").to(Config.DEVICE)

        with torch.no_grad():
            outputs = self.model(**inputs)

        output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = np.moveaxis(output, source=0, destination=-1)

        # The image processor pads the input on the bottom/right edge before
        # inference, so the reconstruction is larger than scale * input.
        # Crop the padding away (a no-op if nothing was padded).
        # NOTE(review): scale is read from model.config.upscale, falling back
        # to 2 for this x2 checkpoint — confirm the attribute name.
        scale = getattr(getattr(self.model, "config", None), "upscale", 2)
        output = output[:in_h * scale, :in_w * scale]

        output = (output * 255.0).round().astype(np.uint8)

        logger.info(f"Inference finished in {time.time() - start_time:.2f}s")
        return Image.fromarray(output)
308
+
309
class StableDiffusionStrategy(UpscalerStrategy):
    """Stable Diffusion x4 upscaler backend (diffusers pipeline)."""

    def __init__(self):
        super().__init__()
        self.name = "Stable Diffusion x4"

    def load(self) -> None:
        """Load the diffusion pipeline; no-op when it is already loaded."""
        if self.model is not None:
            return

        logger.info(f"Loading {self.name} (this may take time)...")
        try:
            pipe = StableDiffusionUpscalePipeline.from_pretrained(
                Config.SD_ID,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True
            )
            # Optimizations for CPU
            pipe.enable_attention_slicing("max")
            pipe.enable_vae_tiling()
            # Honour the configured device like the other strategies do.
            self.model = pipe.to(Config.DEVICE)
            logger.info(f"{self.name} loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load Stable Diffusion: {e}")
            raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Run prompt-guided x4 upscaling.

        Keyword Args:
            prompt: Text prompt for the pipeline; defaults to
                ``"high quality, detailed"``.
        """
        if self.model is None:
            self.load()

        prompt = kwargs.get("prompt", "high quality, detailed")

        if image.mode != "RGB":
            image = image.convert("RGB")

        # Pre-check size
        if max(image.size) > Config.MAX_IMAGE_SIZE_SD:
            ratio = Config.MAX_IMAGE_SIZE_SD / max(image.size)
            # max(1, ...) keeps extreme aspect ratios from collapsing to 0 px.
            new_size = (max(1, int(image.size[0] * ratio)), max(1, int(image.size[1] * ratio)))
            image = image.resize(new_size, Image.Resampling.LANCZOS)
            logger.warning(f"Resized input to {new_size} to prevent OOM on CPU.")

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.time()

        # A dedicated generator keeps results reproducible without reseeding
        # the process-wide default RNG.
        generator = torch.Generator(device="cpu").manual_seed(42)
        output = self.model(
            prompt=prompt,
            image=image,
            num_inference_steps=20,
            guidance_scale=7.0,
            generator=generator
        ).images[0]  # type: ignore

        logger.info(f"Inference finished in {time.time() - start_time:.2f}s")
        return output
358
+
359
class SpanStrategy(UpscalerStrategy):
    """SPAN (NomosUni) x2 backend; weights are loaded through ``ModelLoader``."""

    def __init__(self):
        super().__init__()
        self.name = "SPAN (NomosUni) x2"
        # True once torch.compile has been attempted (successfully or not).
        # NOTE(review): unload() does not reset this flag, so a model that is
        # reloaded after unload() is not compiled again — confirm intended.
        self.compiled = False

    def _ensure_weights(self) -> str:
        """Return the local weights path, downloading the file if missing."""
        Config.ensure_model_dir()
        model_path = os.path.join(Config.MODEL_DIR, Config.SPAN_FILENAME)
        if os.path.exists(model_path):
            return model_path

        logger.info(f"Downloading {Config.SPAN_FILENAME}...")
        # Stream into a temporary name and rename on success, so that an
        # interrupted download is never mistaken for a complete weights file.
        partial_path = model_path + ".part"
        try:
            with requests.get(Config.SPAN_URL, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(partial_path, model_path)
            logger.info("Download complete.")
        except Exception as e:
            logger.error(f"Failed to download model: {e}")
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        return model_path

    def _maybe_compile(self) -> None:
        """Wrap the model with torch.compile unless the platform rules it out."""
        if self.compiled:
            return
        try:
            if Config.DEVICE == 'cuda':
                self.model = torch.compile(self.model, mode='reduce-overhead')
                logger.info("✓ torch.compile enabled (reduce-overhead mode)")
            elif os.name == 'nt' and Config.DEVICE == 'cpu':
                logger.info("ℹ Skipping torch.compile on Windows CPU.")
            elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                logger.info("ℹ Skipping torch.compile on low-core CPU.")
            else:
                self.model = torch.compile(self.model)
                logger.info("✓ torch.compile enabled (default mode)")
        except Exception as e:
            logger.warning(f"⚠ torch.compile failed: {e}")
        # Mark as tried in every case so the attempt is not repeated.
        self.compiled = True

    def load(self) -> None:
        """Download (if needed) and load the model; no-op when already loaded."""
        if self.model is not None:
            return

        logger.info(f"Loading {self.name}...")
        model_path = self._ensure_weights()

        try:
            self.model = ModelLoader().load_from_file(model_path)
            self.model.eval()
            self.model.to(Config.DEVICE)
            self._maybe_compile()
            logger.info(f"{self.name} loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model architecture: {e}")
            # Do not keep a half-initialised model around; a retry reloads it.
            self.model = None
            raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale ``image`` by 2x and return the result as an RGB PIL image."""
        if self.model is None:
            self.load()

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.time()

        # The tensor conversion below needs a 3-channel H x W x C array.
        if image.mode != "RGB":
            image = image.convert("RGB")
        img_np = np.array(image).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)

        # Tiling is still used for safety on huge images.
        h, w = img_np.shape[:2]
        tile_config = select_tile_config(h, w)

        dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16

        # no_grad covers the fallback too: without it the output would track
        # gradients and the later .numpy() call would fail.
        with torch.no_grad():
            try:
                with torch.autocast(device_type=Config.DEVICE, dtype=dtype):
                    if tile_config['tile'] > 0:
                        output_tensor = manual_tile_upscale(
                            self.model,
                            img_tensor,
                            tile_size=tile_config['tile'],
                            tile_pad=tile_config['tile_pad'],
                            scale=2
                        )
                    else:
                        output_tensor = self.model(img_tensor)  # type: ignore
            except Exception as e:
                logger.warning(f"AMP/Tiling failed, falling back: {e}")
                output_tensor = self.model(img_tensor)  # type: ignore

        output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).float().cpu().numpy()
        output_np = (output_np * 255.0).round().astype(np.uint8)

        elapsed = time.time() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")
        return Image.fromarray(output_np)
450
+
451
class HatsStrategy(UpscalerStrategy):
    """HAT-S x4 backend; weights are loaded through ``ModelLoader``."""

    def __init__(self):
        super().__init__()
        self.name = "HAT-S x4"
        # True once torch.compile has been attempted (successfully or not).
        # NOTE(review): unload() does not reset this flag, so a model that is
        # reloaded after unload() is not compiled again — confirm intended.
        self.compiled = False

    def _ensure_weights(self) -> str:
        """Return the local weights path, downloading the file if missing."""
        Config.ensure_model_dir()
        model_path = os.path.join(Config.MODEL_DIR, Config.HATS_FILENAME)
        if os.path.exists(model_path):
            return model_path

        logger.info(f"Downloading {Config.HATS_FILENAME}...")
        # Stream into a temporary name and rename on success, so that an
        # interrupted download is never mistaken for a complete weights file.
        partial_path = model_path + ".part"
        try:
            with requests.get(Config.HATS_URL, stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(partial_path, model_path)
            logger.info("Download complete.")
        except Exception as e:
            logger.error(f"Failed to download model: {e}")
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        return model_path

    def _maybe_compile(self) -> None:
        """Wrap the model with torch.compile unless the platform rules it out.

        Skips and failures are logged so this strategy reports the same
        information as the other compiled strategies.
        """
        if self.compiled:
            return
        try:
            if Config.DEVICE == 'cuda':
                self.model = torch.compile(self.model, mode='reduce-overhead')
                logger.info("✓ torch.compile enabled (reduce-overhead mode)")
            elif os.name == 'nt' and Config.DEVICE == 'cpu':
                logger.info("ℹ Skipping torch.compile on Windows CPU.")
            elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                logger.info("ℹ Skipping torch.compile on low-core CPU.")
            else:
                self.model = torch.compile(self.model)
                logger.info("✓ torch.compile enabled (default mode)")
        except Exception as e:
            # Best effort: keep running uncompiled, but say why.
            logger.warning(f"⚠ torch.compile failed: {e}")
        # Mark as tried in every case so the attempt is not repeated.
        self.compiled = True

    def load(self) -> None:
        """Download (if needed) and load the model; no-op when already loaded."""
        if self.model is not None:
            return

        logger.info(f"Loading {self.name}...")
        model_path = self._ensure_weights()

        try:
            self.model = ModelLoader().load_from_file(model_path)
            self.model.eval()
            self.model.to(Config.DEVICE)
            self._maybe_compile()
            logger.info(f"{self.name} loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load model architecture: {e}")
            # Do not keep a half-initialised model around; a retry reloads it.
            self.model = None
            raise

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale ``image`` by 4x and return the result as an RGB PIL image."""
        if self.model is None:
            self.load()

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.time()

        # The tensor conversion below needs a 3-channel H x W x C array.
        if image.mode != "RGB":
            image = image.convert("RGB")
        img_np = np.array(image).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)

        h, w = img_np.shape[:2]
        tile_config = select_tile_config(h, w)

        dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16

        # no_grad covers the fallback too: without it the output would track
        # gradients and the later .numpy() call would fail.
        with torch.no_grad():
            try:
                with torch.autocast(device_type=Config.DEVICE, dtype=dtype):
                    if tile_config['tile'] > 0:
                        output_tensor = manual_tile_upscale(
                            self.model,
                            img_tensor,
                            tile_size=tile_config['tile'],
                            tile_pad=tile_config['tile_pad'],
                            scale=4  # HAT-S is x4
                        )
                    else:
                        output_tensor = self.model(img_tensor)  # type: ignore
            except Exception as e:
                logger.warning(f"AMP/Tiling failed, falling back: {e}")
                output_tensor = self.model(img_tensor)  # type: ignore

        output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).float().cpu().numpy()
        output_np = (output_np * 255.0).round().astype(np.uint8)

        elapsed = time.time() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")
        return Image.fromarray(output_np)
537
+
538
+ # --- Model Manager (Singleton-ish) ---
539
class UpscalerManager:
    """Manages model lifecycle and selection.

    Holds one strategy instance per model name and remembers which one was
    selected last, so that selecting a different model unloads the previous
    one first.
    """

    def __init__(self):
        self.strategies: Dict[str, UpscalerStrategy] = {
            "SPAN (NomosUni) x2": SpanStrategy(),
            "RealESRGAN x2": RealESRGANStrategy(),
            "HAT-S x4": HatsStrategy(),
            "Swin2SR x2": Swin2SRStrategy(),
            "Stable Diffusion x4": StableDiffusionStrategy()
        }
        # Name of the most recently selected strategy, or None if none is active.
        self.current_model_name: Optional[str] = None

    def get_strategy(self, name: str) -> UpscalerStrategy:
        """Return the strategy registered as ``name``, unloading the previous one.

        Raises:
            ValueError: If ``name`` is not a registered model name.
        """
        if name not in self.strategies:
            raise ValueError(f"Model {name} not found.")

        # Memory Optimization for Free Tier (16GB RAM limit):
        # Ensure only one model is loaded at a time.
        if self.current_model_name != name:
            if self.current_model_name is not None:
                logger.info(f"Switching models: Unloading {self.current_model_name}...")
                self.strategies[self.current_model_name].unload()
            self.current_model_name = name

        return self.strategies[name]

    def unload_all(self):
        """Unload all models to free memory."""
        for strategy in self.strategies.values():
            strategy.unload()
        # Nothing is loaded any more; clearing the name avoids a misleading
        # "Switching models" log and a redundant unload on the next selection.
        self.current_model_name = None
        gc.collect()
        logger.info("All models unloaded.")
571
+
572
+ manager = UpscalerManager()
573
+
574
+ # --- Gradio Interface Logic ---
575
def process_image(input_img: Optional[Image.Image], model_name: str, prompt: str) -> Tuple[Optional[Image.Image], str, str]:
    """Upscale ``input_img`` with the selected model for the Gradio UI.

    Args:
        input_img: Image to upscale, or None when nothing was provided.
        model_name: Key of the strategy to use.
        prompt: Text prompt forwarded to the strategy as ``prompt=``.

    Returns:
        ``(output image or None, captured logs, system usage string)``.
        On failure the image is None and the error, including its traceback,
        is part of the returned logs.
    """
    if input_img is None:
        return None, get_logs(), get_system_usage()

    output = None
    try:
        strategy = manager.get_strategy(model_name)
        output = strategy.upscale(input_img, prompt=prompt)
    except Exception as e:
        # The logger writes into the captured log buffer, so the user sees
        # the error via get_logs(); appending it again would show it twice.
        logger.error(f"Critical Error: {str(e)}\n{traceback.format_exc()}")
    finally:
        # Explicit GC after heavy operations, on success and on failure.
        gc.collect()

    return output, get_logs(), get_system_usage()
596
+
597
def unload_models():
    """Unload every model, then return ``(logs, system usage)`` for the UI."""
    manager.unload_all()
    logs = get_logs()
    usage = get_system_usage()
    return logs, usage
600
 
601
+ # --- UI Construction ---
602
  desc = """
603
+ ### 🚀 Enterprise-Grade Universal Upscaler (SOTA 2025)
604
+ Select a specialized model to upscale your image.
605
+ * **SPAN (NomosUni) x2**: **SOTA Speed**. Fastest CPU model. Best for general use.
606
+ * **RealESRGAN x2**: 🛡️ **Robust**. Best for removing JPEG artifacts and noise.
607
+ * **HAT-S x4**: 💎 **SOTA Quality**. Best texture detail (slower).
608
+ * **Swin2SR x2**: 🎯 High fidelity, removes compression artifacts.
609
+ * **Stable Diffusion x4**: 🎨 Generative upscaling. Adds missing details (slow, high RAM).
610
  """
611
 
612
# Build the Gradio page: controls on the left, results on the right.
# NOTE(review): SOURCE lost its indentation, so the nesting of the layout
# containers below is reconstructed — verify against the real app.py.
with gr.Blocks(title="Universal Upscaler Pro") as iface:
    gr.Markdown(desc)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")

            with gr.Group():
                # Choices come from the manager so the dropdown always
                # matches the registered strategies.
                model_selector = gr.Dropdown(
                    choices=list(manager.strategies.keys()),
                    value="SPAN (NomosUni) x2",
                    label="Select Model Architecture"
                )
                prompt_input = gr.Textbox(
                    label="Prompt (Stable Diffusion Only)",
                    value="highly detailed, 4k, sharp",
                    placeholder="Describe the image content..."
                )

            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown("Memory Management")
                unload_btn = gr.Button("Unload All Models (Free RAM)", variant="secondary")

            submit_btn = gr.Button("✨ Upscale Image", variant="primary", size="lg")
            # Initial value is computed once while the UI is being built.
            system_info = gr.Label(value=get_system_usage(), label="System Status")

        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Upscaled Result")
            logs_output = gr.TextArea(label="Execution Logs", interactive=False, lines=10)

    # Event Wiring
    # process_image returns (image, logs, system usage), in output order.
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, model_selector, prompt_input],
        outputs=[output_image, logs_output, system_info]
    )

    # unload_models returns (logs, system usage), in output order.
    unload_btn.click(
        fn=unload_models,
        inputs=[],
        outputs=[logs_output, system_info]
    )

    # Auto-refresh system info every 2 seconds (optional, can be heavy on UI)
    # iface.load(get_system_usage, None, system_info, every=2)

# Start the Gradio server when the module is executed.
iface.launch()
requirements.txt CHANGED
@@ -7,4 +7,11 @@ pillow
7
  gradio
8
  opencv-python
9
  spandrel
10
- requests
 
 
 
 
 
 
 
 
7
  gradio
8
  opencv-python
9
  spandrel
10
+ requests
11
+ psutil
12
+ onnx
13
+ onnxruntime
14
+ basicsr
15
+ realesrgan
16
+ openvino
17
+ optimum