|
|
import gradio as gr
|
|
|
import torch
|
|
|
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
|
|
|
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
|
|
|
import gc
|
|
|
from PIL import Image
|
|
|
import numpy as np
|
|
|
import logging
|
|
|
import io
|
|
|
import os
|
|
|
import requests
|
|
|
from spandrel import ModelLoader
|
|
|
from abc import ABC, abstractmethod
|
|
|
from typing import Optional, Tuple, Dict
|
|
|
import psutil
|
|
|
import time
|
|
|
import traceback
|
|
|
|
|
|
|
|
|
class Config:
    """Static configuration shared by every upscaling strategy."""

    # Directory (relative to the working directory) where weight files are cached.
    MODEL_DIR = "weights"

    # RealESRGAN weights: download location and local file name.
    REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
    REALESRGAN_FILENAME = "RealESRGAN_x2plus.pth"

    # Identifiers handed to the respective ``from_pretrained`` loaders.
    SWIN2SR_ID = "caidas/swin2SR-classical-sr-x2-64"
    SD_ID = "stabilityai/stable-diffusion-x4-upscaler"

    # SPAN weights: download location and local file name.
    SPAN_URL = "https://huggingface.co/Phips/2xNomosUni_span_multijpg/resolve/main/2xNomosUni_span_multijpg.safetensors"
    SPAN_FILENAME = "2xNomosUni_span_multijpg.safetensors"

    # HAT-S weights: download location and local file name.
    HATS_URL = "https://huggingface.co/Phips/4xNomos8kSCHAT-S/resolve/main/4xNomos8kSCHAT-S.safetensors"
    HATS_FILENAME = "4xNomos8kSCHAT-S.safetensors"

    # Longest side (in pixels) accepted by the Stable Diffusion strategy;
    # larger inputs are downscaled to this size before inference.
    MAX_IMAGE_SIZE_SD = 512

    # Device string used for tensors, models and autocast.
    DEVICE = "cpu"

    @staticmethod
    def ensure_model_dir():
        """Create ``MODEL_DIR`` (including parents) if it does not exist yet.

        ``exist_ok=True`` makes the call idempotent and avoids the race
        between checking for the directory and creating it.
        """
        os.makedirs(Config.MODEL_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
class LogCapture(io.StringIO):
    """In-memory text buffer used as the destination stream for log output."""
|
|
|
|
|
|
# Application logger; INFO and above are formatted and written to an
# in-memory buffer so the captured text can be retrieved later.
logger = logging.getLogger("UpscalerApp")
logger.setLevel(logging.INFO)

log_capture_string = LogCapture()
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

ch = logging.StreamHandler(log_capture_string)
ch.setFormatter(formatter)
ch.setLevel(logging.INFO)

logger.addHandler(ch)
|
|
|
|
|
|
def get_logs() -> str:
    """Return all text written to the in-memory log buffer so far."""
    captured: str = log_capture_string.getvalue()
    return captured
|
|
|
|
|
|
|
|
|
def get_system_usage() -> str:
    """Return a one-line summary of current CPU and RAM usage.

    Returns:
        A string of the form ``"CPU: <p>% | RAM: <p>% (<x.y> GB used)"``.
    """
    # NOTE(review): called without an interval, so per the psutil docs the
    # value is relative to the previous call (the very first call reports 0.0).
    cpu_percent = psutil.cpu_percent()

    # Sample memory once so the percentage and the absolute figure describe
    # the same instant instead of two separate snapshots.
    memory = psutil.virtual_memory()
    ram_used_gb = memory.used / (1024 ** 3)

    return f"CPU: {cpu_percent}% | RAM: {memory.percent}% ({ram_used_gb:.1f} GB used)"
|
|
|
|
|
|
|
|
|
class UpscalerStrategy(ABC):
    """Interface every upscaling backend implements.

    Subclasses provide ``load`` and ``upscale``; ``unload`` is shared.
    """

    def __init__(self):
        # ``model`` stays ``None`` until ``load`` has populated it.
        self.name = "Unknown"
        self.model = None

    @abstractmethod
    def load(self) -> None:
        """Bring the model into memory."""

    @abstractmethod
    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Return an upscaled version of ``image``."""

    def unload(self) -> None:
        """Drop the loaded model, if any, and trigger garbage collection."""
        if self.model is None:
            # Nothing is loaded, so there is nothing to release.
            return

        del self.model
        self.model = None
        gc.collect()
        logger.info(f"Unloaded {self.name}")
|
|
|
|
|
|
|
|
|
def manual_tile_upscale(model, img_tensor, tile_size=256, tile_pad=10, scale=2):
    """Upscale ``img_tensor`` tile by tile to keep peak memory bounded.

    The input is split into tiles of at most ``tile_size`` x ``tile_size``
    pixels. Each tile is enlarged by up to ``tile_pad`` pixels of surrounding
    context before it is passed to ``model``; the upscaled context is cropped
    away again before the result is written into the output.

    Args:
        model: Callable that maps a ``(B, C, h, w)`` tensor to a
            ``(B, C, h * scale, w * scale)`` tensor.
        img_tensor: Input tensor with four dimensions ``(B, C, H, W)``.
        tile_size: Maximum tile edge length in input pixels; must be > 0.
        tile_pad: Context pixels added on each side of a tile; must be >= 0.
        scale: Upscaling factor applied by ``model``; must be > 0.

    Returns:
        A tensor of shape ``(B, C, H * scale, W * scale)`` on the same device
        and with the same dtype as ``img_tensor``.

    Raises:
        ValueError: If ``tile_size`` or ``scale`` is not positive, or if
            ``tile_pad`` is negative.
    """
    # A zero tile size used to raise ZeroDivisionError and a negative one
    # silently produced an all-zero image, so reject both explicitly.
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")
    if tile_pad < 0:
        raise ValueError(f"tile_pad must be non-negative, got {tile_pad}")
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale}")

    batch, channels, height, width = img_tensor.shape

    output = torch.zeros(
        batch, channels, height * scale, width * scale,
        device=img_tensor.device, dtype=img_tensor.dtype,
    )

    with torch.no_grad():
        for top in range(0, height, tile_size):
            bottom = min(top + tile_size, height)
            # Extend the tile with context rows, clamped to the image bounds.
            top_pad = max(0, top - tile_pad)
            bottom_pad = min(height, bottom + tile_pad)

            for left in range(0, width, tile_size):
                right = min(left + tile_size, width)
                left_pad = max(0, left - tile_pad)
                right_pad = min(width, right + tile_pad)

                tile = img_tensor[:, :, top_pad:bottom_pad, left_pad:right_pad]
                tile_out = model(tile)

                # Offset of the un-padded tile inside the upscaled padded tile.
                crop_top = (top - top_pad) * scale
                crop_left = (left - left_pad) * scale
                crop_bottom = crop_top + (bottom - top) * scale
                crop_right = crop_left + (right - left) * scale

                output[:, :, top * scale:bottom * scale, left * scale:right * scale] = \
                    tile_out[:, :, crop_top:crop_bottom, crop_left:crop_right]

    return output
|
|
|
|
|
|
def select_tile_config(height, width):
    """Pick tile size and tile padding from the image resolution.

    The area is measured in units of 1024 * 1024 pixels; larger images get
    smaller tiles and more padding.

    Returns:
        A dict with the keys ``'tile'`` and ``'tile_pad'``.
    """
    area = (height * width) / (1024 ** 2)

    # (exclusive upper bound on the area, tile size, tile padding)
    tiers = (
        (2, 512, 10),
        (6, 384, 15),
        (16, 256, 20),
    )
    for limit, tile, pad in tiers:
        if area < limit:
            return {'tile': tile, 'tile_pad': pad}

    # Everything at or above the last bound uses the smallest tiles.
    return {'tile': 128, 'tile_pad': 25}
|
|
|
|
|
|
|
|
|
|
|
|
class RealESRGANStrategy(UpscalerStrategy):
    """RealESRGAN x2 strategy; weights are downloaded on demand and loaded via spandrel."""

    def __init__(self):
        super().__init__()
        self.name = "RealESRGAN x2"
        # True once ``load`` has gone through the torch.compile decision.
        self.compiled = False

    @staticmethod
    def _download(url: str, destination: str) -> None:
        """Stream ``url`` into ``destination`` without leaving a partial file behind."""
        partial_path = destination + ".part"
        try:
            # The timeout keeps a stalled connection from blocking forever and
            # the context manager releases the connection when done.
            with requests.get(url, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            # Publish the file only when complete; a truncated file at the
            # final path would make every later load skip the download.
            os.replace(partial_path, destination)
        except Exception:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise

    @staticmethod
    def _compile_if_supported(model):
        """Return ``model`` wrapped by ``torch.compile`` where applicable, else unchanged."""
        try:
            if Config.DEVICE == 'cuda':
                model = torch.compile(model, mode='reduce-overhead')
                logger.info("✓ torch.compile enabled (reduce-overhead mode)")
            elif os.name == 'nt' and Config.DEVICE == 'cpu':
                logger.info("ℹ Skipping torch.compile on Windows CPU to avoid MSVC requirement.")
            elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                logger.info("ℹ Skipping torch.compile on low-core CPU to prevent timeout.")
            else:
                model = torch.compile(model)
                logger.info("✓ torch.compile enabled (default mode)")
        except Exception as e:
            # Compilation is an optimisation only; keep the eager model.
            logger.warning(f"⚠ torch.compile not available or failed: {e}")
        return model

    def load(self) -> None:
        """Download the weights if missing and load the model.

        Does nothing when a model is already loaded.

        Raises:
            Exception: Re-raises any download or model-loading error after
                logging it.
        """
        if self.model is not None:
            return

        logger.info(f"Loading {self.name}...")
        Config.ensure_model_dir()
        model_path = os.path.join(Config.MODEL_DIR, Config.REALESRGAN_FILENAME)

        if not os.path.exists(model_path):
            logger.info(f"Downloading {Config.REALESRGAN_FILENAME}...")
            try:
                self._download(Config.REALESRGAN_URL, model_path)
                logger.info("Download complete.")
            except Exception as e:
                logger.error(f"Failed to download model: {e}")
                raise

        try:
            model = ModelLoader().load_from_file(model_path)
            model.eval()
            model.to(Config.DEVICE)
        except Exception as e:
            logger.error(f"Failed to load model architecture: {e}")
            raise

        # ``unload`` discards the model object, so the compile decision is
        # repeated on every load rather than skipped via a stale flag.
        self.model = self._compile_if_supported(model)
        self.compiled = True
        logger.info(f"{self.name} loaded successfully.")

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale ``image`` by a factor of two.

        Args:
            image: Input image; converted to RGB before inference.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The upscaled image.
        """
        if self.model is None:
            self.load()

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.time()

        # permute(2, 0, 1) needs an H x W x C array; grayscale images yield a
        # 2-D array, so normalise the mode first.
        img_np = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)

        h, w = img_np.shape[:2]
        tile_config = select_tile_config(h, w)
        logger.info(f"Using tile config: {tile_config}")

        dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16

        # no_grad covers the fallback too, which used to run with autograd
        # enabled and retain activations for the whole image.
        with torch.no_grad():
            try:
                with torch.autocast(device_type=Config.DEVICE, dtype=dtype):
                    if tile_config['tile'] > 0:
                        output_tensor = manual_tile_upscale(
                            self.model,
                            img_tensor,
                            tile_size=tile_config['tile'],
                            tile_pad=tile_config['tile_pad'],
                            scale=2,
                        )
                    else:
                        output_tensor = self.model(img_tensor)
            except Exception as e:
                logger.warning(f"AMP/Tiling failed, falling back to standard FP32: {e}")
                output_tensor = self.model(img_tensor)

        output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).float().cpu().numpy()
        output_np = (output_np * 255.0).round().astype(np.uint8)

        elapsed = time.time() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")

        output_megapixels = (output_np.shape[0] * output_np.shape[1]) / (1024 ** 2)
        # Guard against a zero reading on coarse clocks.
        if elapsed > 0:
            throughput = output_megapixels / elapsed
            logger.info(f"Speed: {throughput:.2f} MP/s")

        return Image.fromarray(output_np)
|
|
|
|
|
|
class Swin2SRStrategy(UpscalerStrategy):
    """Swin2SR x2 strategy backed by the transformers image processor and model."""

    def __init__(self):
        super().__init__()
        self.name = "Swin2SR x2"
        self.processor = None

    def load(self) -> None:
        """Load the image processor and the model.

        Does nothing when both are already loaded.

        Raises:
            Exception: Re-raises any loading error after logging it.
        """
        # Same condition as ``upscale`` uses, so a missing processor is
        # reloaded instead of surfacing later as an error.
        if self.model is not None and self.processor is not None:
            return

        logger.info(f"Loading {self.name}...")
        try:
            processor = AutoImageProcessor.from_pretrained(Config.SWIN2SR_ID)
            model = Swin2SRForImageSuperResolution.from_pretrained(Config.SWIN2SR_ID)
            model = model.to(Config.DEVICE)
            model.eval()
        except Exception as e:
            logger.error(f"Failed to load Swin2SR: {e}")
            raise

        # Assign only after both parts loaded so state is never half-set.
        self.processor = processor
        self.model = model
        logger.info(f"{self.name} loaded successfully.")

    def unload(self) -> None:
        """Release the image processor together with the model."""
        self.processor = None
        super().unload()

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale ``image`` with Swin2SR.

        Args:
            image: Input image; converted to RGB before inference.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The upscaled image, cropped so that any padding added by the
            image processor is not part of the result.
        """
        if self.model is None or self.processor is None:
            self.load()

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.time()

        rgb_image = image.convert("RGB")
        inputs = self.processor(images=rgb_image, return_tensors="pt").to(Config.DEVICE)

        with torch.no_grad():
            outputs = self.model(**inputs)

        output = outputs.reconstruction.data.squeeze(0).float().cpu().clamp_(0, 1).numpy()
        output = np.moveaxis(output, source=0, destination=-1)

        # The processor may pad the input (bottom/right) before inference.
        # Derive the scale from the padded input and crop the output back to
        # the true image extent so the padding does not show up in the result.
        padded_height = inputs["pixel_values"].shape[-2]
        factor = max(1, output.shape[0] // padded_height)
        output = output[:rgb_image.height * factor, :rgb_image.width * factor]

        output = np.ascontiguousarray((output * 255.0).round().astype(np.uint8))

        logger.info(f"Inference finished in {time.time() - start_time:.2f}s")
        return Image.fromarray(output)
|
|
|
|
|
|
class StableDiffusionStrategy(UpscalerStrategy):
    """Stable Diffusion x4 strategy built on the diffusers upscale pipeline."""

    _DEFAULT_PROMPT = "high quality, detailed"

    def __init__(self):
        super().__init__()
        self.name = "Stable Diffusion x4"

    def load(self) -> None:
        """Load the upscale pipeline.

        Does nothing when a pipeline is already loaded.

        Raises:
            Exception: Re-raises any loading error after logging it.
        """
        if self.model is not None:
            return

        logger.info(f"Loading {self.name} (this may take time)...")
        try:
            pipeline = StableDiffusionUpscalePipeline.from_pretrained(
                Config.SD_ID,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
            )
            # Honour the configured device like the other strategies do.
            pipeline = pipeline.to(Config.DEVICE)
            # NOTE(review): both options are understood to lower peak memory
            # at some cost in speed - confirm against the diffusers docs.
            pipeline.enable_attention_slicing("max")
            pipeline.enable_vae_tiling()
        except Exception as e:
            logger.error(f"Failed to load Stable Diffusion: {e}")
            raise

        # Assign only once fully configured so a failure leaves no half-set model.
        self.model = pipeline
        logger.info(f"{self.name} loaded successfully.")

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale ``image`` with the diffusion pipeline.

        Args:
            image: Input image; converted to RGB, and downscaled first when
                its longest side exceeds ``Config.MAX_IMAGE_SIZE_SD``.
            **kwargs: ``prompt`` (str) guides generation; a default prompt is
                used when it is missing or ``None``.

        Returns:
            The first image produced by the pipeline.
        """
        if self.model is None:
            self.load()

        prompt = kwargs.get("prompt")
        if prompt is None:
            prompt = self._DEFAULT_PROMPT

        image = image.convert("RGB")

        longest_side = max(image.size)
        if longest_side > Config.MAX_IMAGE_SIZE_SD:
            ratio = Config.MAX_IMAGE_SIZE_SD / longest_side
            # Clamp to one pixel: extreme aspect ratios would otherwise
            # truncate the short side to zero and make resize fail.
            new_size = (
                max(1, int(image.size[0] * ratio)),
                max(1, int(image.size[1] * ratio)),
            )
            image = image.resize(new_size, Image.Resampling.LANCZOS)
            logger.warning(f"Resized input to {new_size} to prevent OOM on CPU.")

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.time()

        # A dedicated generator keeps results reproducible without reseeding
        # the process-wide RNG as ``torch.manual_seed`` would.
        generator = torch.Generator(device="cpu").manual_seed(42)
        output = self.model(
            prompt=prompt,
            image=image,
            num_inference_steps=20,
            guidance_scale=7.0,
            generator=generator,
        ).images[0]

        logger.info(f"Inference finished in {time.time() - start_time:.2f}s")
        return output
|
|
|
|
|
|
class SpanStrategy(UpscalerStrategy):
    """SPAN (NomosUni) x2 strategy; weights are downloaded on demand and loaded via spandrel."""

    def __init__(self):
        super().__init__()
        self.name = "SPAN (NomosUni) x2"
        # True once ``load`` has gone through the torch.compile decision.
        self.compiled = False

    @staticmethod
    def _download(url: str, destination: str) -> None:
        """Stream ``url`` into ``destination`` without leaving a partial file behind."""
        partial_path = destination + ".part"
        try:
            # The timeout keeps a stalled connection from blocking forever and
            # the context manager releases the connection when done.
            with requests.get(url, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            # Publish the file only when complete; a truncated file at the
            # final path would make every later load skip the download.
            os.replace(partial_path, destination)
        except Exception:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise

    @staticmethod
    def _compile_if_supported(model):
        """Return ``model`` wrapped by ``torch.compile`` where applicable, else unchanged."""
        try:
            if Config.DEVICE == 'cuda':
                model = torch.compile(model, mode='reduce-overhead')
                logger.info("✓ torch.compile enabled (reduce-overhead mode)")
            elif os.name == 'nt' and Config.DEVICE == 'cpu':
                logger.info("ℹ Skipping torch.compile on Windows CPU.")
            elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                logger.info("ℹ Skipping torch.compile on low-core CPU.")
            else:
                model = torch.compile(model)
                logger.info("✓ torch.compile enabled (default mode)")
        except Exception as e:
            # Compilation is an optimisation only; keep the eager model.
            logger.warning(f"⚠ torch.compile failed: {e}")
        return model

    def load(self) -> None:
        """Download the weights if missing and load the model.

        Does nothing when a model is already loaded.

        Raises:
            Exception: Re-raises any download or model-loading error after
                logging it.
        """
        if self.model is not None:
            return

        logger.info(f"Loading {self.name}...")
        Config.ensure_model_dir()
        model_path = os.path.join(Config.MODEL_DIR, Config.SPAN_FILENAME)

        if not os.path.exists(model_path):
            logger.info(f"Downloading {Config.SPAN_FILENAME}...")
            try:
                self._download(Config.SPAN_URL, model_path)
                logger.info("Download complete.")
            except Exception as e:
                logger.error(f"Failed to download model: {e}")
                raise

        try:
            model = ModelLoader().load_from_file(model_path)
            model.eval()
            model.to(Config.DEVICE)
        except Exception as e:
            logger.error(f"Failed to load model architecture: {e}")
            raise

        # ``unload`` discards the model object, so the compile decision is
        # repeated on every load rather than skipped via a stale flag.
        self.model = self._compile_if_supported(model)
        self.compiled = True
        logger.info(f"{self.name} loaded successfully.")

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale ``image`` by a factor of two.

        Args:
            image: Input image; converted to RGB before inference.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The upscaled image.
        """
        if self.model is None:
            self.load()

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.time()

        # permute(2, 0, 1) needs an H x W x C array; grayscale images yield a
        # 2-D array, so normalise the mode first.
        img_np = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)

        h, w = img_np.shape[:2]
        tile_config = select_tile_config(h, w)

        dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16

        # no_grad covers the fallback too, which used to run with autograd
        # enabled and retain activations for the whole image.
        with torch.no_grad():
            try:
                with torch.autocast(device_type=Config.DEVICE, dtype=dtype):
                    if tile_config['tile'] > 0:
                        output_tensor = manual_tile_upscale(
                            self.model,
                            img_tensor,
                            tile_size=tile_config['tile'],
                            tile_pad=tile_config['tile_pad'],
                            scale=2,
                        )
                    else:
                        output_tensor = self.model(img_tensor)
            except Exception as e:
                logger.warning(f"AMP/Tiling failed, falling back: {e}")
                output_tensor = self.model(img_tensor)

        output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).float().cpu().numpy()
        output_np = (output_np * 255.0).round().astype(np.uint8)

        elapsed = time.time() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")
        return Image.fromarray(output_np)
|
|
|
|
|
|
class HatsStrategy(UpscalerStrategy):
    """HAT-S x4 strategy; weights are downloaded on demand and loaded via spandrel."""

    def __init__(self):
        super().__init__()
        self.name = "HAT-S x4"
        # True once ``load`` has gone through the torch.compile decision.
        self.compiled = False

    @staticmethod
    def _download(url: str, destination: str) -> None:
        """Stream ``url`` into ``destination`` without leaving a partial file behind."""
        partial_path = destination + ".part"
        try:
            # The timeout keeps a stalled connection from blocking forever and
            # the context manager releases the connection when done.
            with requests.get(url, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            # Publish the file only when complete; a truncated file at the
            # final path would make every later load skip the download.
            os.replace(partial_path, destination)
        except Exception:
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise

    @staticmethod
    def _compile_if_supported(model):
        """Return ``model`` wrapped by ``torch.compile`` where applicable, else unchanged."""
        try:
            if Config.DEVICE == 'cuda':
                model = torch.compile(model, mode='reduce-overhead')
                logger.info("✓ torch.compile enabled (reduce-overhead mode)")
            elif os.name == 'nt' and Config.DEVICE == 'cpu':
                logger.info("ℹ Skipping torch.compile on Windows CPU.")
            elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
                logger.info("ℹ Skipping torch.compile on low-core CPU.")
            else:
                model = torch.compile(model)
                logger.info("✓ torch.compile enabled (default mode)")
        except Exception as e:
            # Previously swallowed silently; compilation is optional, so keep
            # the eager model but record why it was not compiled.
            logger.warning(f"⚠ torch.compile failed: {e}")
        return model

    def load(self) -> None:
        """Download the weights if missing and load the model.

        Does nothing when a model is already loaded.

        Raises:
            Exception: Re-raises any download or model-loading error after
                logging it.
        """
        if self.model is not None:
            return

        logger.info(f"Loading {self.name}...")
        Config.ensure_model_dir()
        model_path = os.path.join(Config.MODEL_DIR, Config.HATS_FILENAME)

        if not os.path.exists(model_path):
            logger.info(f"Downloading {Config.HATS_FILENAME}...")
            try:
                self._download(Config.HATS_URL, model_path)
                logger.info("Download complete.")
            except Exception as e:
                logger.error(f"Failed to download model: {e}")
                raise

        try:
            model = ModelLoader().load_from_file(model_path)
            model.eval()
            model.to(Config.DEVICE)
        except Exception as e:
            logger.error(f"Failed to load model architecture: {e}")
            raise

        # ``unload`` discards the model object, so the compile decision is
        # repeated on every load rather than skipped via a stale flag.
        self.model = self._compile_if_supported(model)
        self.compiled = True
        logger.info(f"{self.name} loaded successfully.")

    def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
        """Upscale ``image`` by a factor of four.

        Args:
            image: Input image; converted to RGB before inference.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The upscaled image.
        """
        if self.model is None:
            self.load()

        logger.info(f"Starting inference with {self.name}...")
        start_time = time.time()

        # permute(2, 0, 1) needs an H x W x C array; grayscale images yield a
        # 2-D array, so normalise the mode first.
        img_np = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(Config.DEVICE)

        h, w = img_np.shape[:2]
        tile_config = select_tile_config(h, w)

        dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16

        # no_grad covers the fallback too, which used to run with autograd
        # enabled and retain activations for the whole image.
        with torch.no_grad():
            try:
                with torch.autocast(device_type=Config.DEVICE, dtype=dtype):
                    if tile_config['tile'] > 0:
                        output_tensor = manual_tile_upscale(
                            self.model,
                            img_tensor,
                            tile_size=tile_config['tile'],
                            tile_pad=tile_config['tile_pad'],
                            scale=4,
                        )
                    else:
                        output_tensor = self.model(img_tensor)
            except Exception as e:
                logger.warning(f"AMP/Tiling failed, falling back: {e}")
                output_tensor = self.model(img_tensor)

        output_np = output_tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).float().cpu().numpy()
        output_np = (output_np * 255.0).round().astype(np.uint8)

        elapsed = time.time() - start_time
        logger.info(f"Inference finished in {elapsed:.2f}s")
        return Image.fromarray(output_np)
|
|
|
|
|
|
|
|
|
class UpscalerManager:
    """Owns one strategy per model name and keeps at most one model active."""

    def __init__(self):
        # Keys are the display names; each matches the strategy's ``name``.
        self.strategies: Dict[str, UpscalerStrategy] = {
            "SPAN (NomosUni) x2": SpanStrategy(),
            "RealESRGAN x2": RealESRGANStrategy(),
            "HAT-S x4": HatsStrategy(),
            "Swin2SR x2": Swin2SRStrategy(),
            "Stable Diffusion x4": StableDiffusionStrategy(),
        }
        # Name of the strategy handed out most recently, if any.
        self.current_model_name: Optional[str] = None

    def get_strategy(self, name: str) -> UpscalerStrategy:
        """Return the strategy registered under ``name``.

        When a different strategy was active before, it is unloaded first.
        The returned strategy is not loaded here.

        Raises:
            ValueError: If no strategy is registered under ``name``.
        """
        if name not in self.strategies:
            raise ValueError(f"Model {name} not found.")

        previous = self.current_model_name
        if previous != name:
            if previous is not None:
                logger.info(f"Switching models: Unloading {previous}...")
                self.strategies[previous].unload()
            self.current_model_name = name

        return self.strategies[name]

    def unload_all(self):
        """Unload all models to free memory."""
        for strategy in self.strategies.values():
            strategy.unload()
        # Nothing is active any more; without this reset the next switch
        # would report unloading a model that is already gone.
        self.current_model_name = None
        gc.collect()
        logger.info("All models unloaded.")
|
|
|
|
|
|
# Module-wide manager instance; the UI callbacks use it to select and unload models.
manager = UpscalerManager()
|
|
|
|
|
|
|
|
|
def process_image(input_img: Optional[Image.Image], model_name: str, prompt: str) -> Tuple[Optional[Image.Image], str, str]:
    """Upscale ``input_img`` with the selected model.

    Args:
        input_img: Image to upscale, or ``None`` when nothing was provided.
        model_name: Key of the strategy to use.
        prompt: Text passed to the strategy as the ``prompt`` keyword.

    Returns:
        A tuple ``(image, logs, system_usage)``. ``image`` is ``None`` when no
        input was given or when processing failed; failures are reported
        through the log text rather than raised.
    """
    if input_img is None:
        return None, get_logs(), get_system_usage()

    output: Optional[Image.Image] = None
    try:
        strategy = manager.get_strategy(model_name)
        output = strategy.upscale(input_img, prompt=prompt)
    except Exception as e:
        error_msg = f"Critical Error: {str(e)}\n{traceback.format_exc()}"
        # Logging puts the message into the captured log, so it is not
        # appended a second time to the returned text.
        logger.error(error_msg)

    # Collect after failures as well, when freeing memory matters most.
    gc.collect()

    return output, get_logs(), get_system_usage()
|
|
|
|
|
|
def unload_models():
    """Unload every model and return the current log text and system usage."""
    manager.unload_all()

    logs_text = get_logs()
    usage_text = get_system_usage()
    return logs_text, usage_text
|
|
|
|
|
|
|
|
|
# Markdown shown at the top of the page.
desc = """
### 🚀 Enterprise-Grade Universal Upscaler (SOTA 2025)
Select a specialized model to upscale your image.
* **SPAN (NomosUni) x2**: ⚡ **SOTA Speed**. Fastest CPU model. Best for general use.
* **RealESRGAN x2**: 🛡️ **Robust**. Best for removing JPEG artifacts and noise.
* **HAT-S x4**: 💎 **SOTA Quality**. Best texture detail (slower).
* **Swin2SR x2**: 🎯 High fidelity, removes compression artifacts.
* **Stable Diffusion x4**: 🎨 Generative upscaling. Adds missing details (slow, high RAM).
"""

# Two-column layout: inputs and controls on the left, result and logs on the right.
with gr.Blocks(title="Universal Upscaler Pro") as iface:
    gr.Markdown(desc)

    with gr.Row():
        # Left column: input image, model choice, prompt and actions.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Input Image")

            with gr.Group():
                model_selector = gr.Dropdown(
                    choices=list(manager.strategies),
                    value="SPAN (NomosUni) x2",
                    label="Select Model Architecture",
                )
                prompt_input = gr.Textbox(
                    label="Prompt (Stable Diffusion Only)",
                    value="highly detailed, 4k, sharp",
                    placeholder="Describe the image content...",
                )

            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown("Memory Management")
                unload_btn = gr.Button(
                    "Unload All Models (Free RAM)",
                    variant="secondary",
                )

            submit_btn = gr.Button(
                "✨ Upscale Image",
                variant="primary",
                size="lg",
            )
            system_info = gr.Label(
                value=get_system_usage(),
                label="System Status",
            )

        # Right column: upscaled image and captured logs.
        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Upscaled Result")
            logs_output = gr.TextArea(
                label="Execution Logs",
                interactive=False,
                lines=10,
            )

    # Upscale: image + model + prompt in, result + logs + system status out.
    submit_btn.click(
        fn=process_image,
        inputs=[input_image, model_selector, prompt_input],
        outputs=[output_image, logs_output, system_info],
    )

    # Unload: refreshes the log view and the system status only.
    unload_btn.click(
        fn=unload_models,
        inputs=[],
        outputs=[logs_output, system_info],
    )

iface.launch()