Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files
app.py
CHANGED
|
@@ -237,7 +237,11 @@ class RealESRGANStrategy(UpscalerStrategy):
|
|
| 237 |
dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16
|
| 238 |
|
| 239 |
try:
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
if tile_config['tile'] > 0:
|
| 242 |
output_tensor = manual_tile_upscale(
|
| 243 |
self.model,
|
|
@@ -282,6 +286,8 @@ class Swin2SRStrategy(UpscalerStrategy):
|
|
| 282 |
logger.info(f"{self.name} loaded successfully.")
|
| 283 |
except Exception as e:
|
| 284 |
logger.error(f"Failed to load Swin2SR: {e}")
|
|
|
|
|
|
|
| 285 |
raise
|
| 286 |
|
| 287 |
def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
|
|
@@ -296,7 +302,11 @@ class Swin2SRStrategy(UpscalerStrategy):
|
|
| 296 |
|
| 297 |
inputs = self.processor(images=image, return_tensors="pt").to(Config.DEVICE)
|
| 298 |
|
| 299 |
-
with
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
outputs = self.model(**inputs)
|
| 301 |
|
| 302 |
output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
@@ -400,8 +410,7 @@ class SpanStrategy(UpscalerStrategy):
|
|
| 400 |
self.model = torch.compile(self.model)
|
| 401 |
logger.info("✓ torch.compile enabled (default mode)")
|
| 402 |
self.compiled = True
|
| 403 |
-
except Exception
|
| 404 |
-
logger.warning(f"⚠ torch.compile failed: {e}")
|
| 405 |
self.compiled = True
|
| 406 |
|
| 407 |
logger.info(f"{self.name} loaded successfully.")
|
|
@@ -423,10 +432,15 @@ class SpanStrategy(UpscalerStrategy):
|
|
| 423 |
h, w = img_np.shape[:2]
|
| 424 |
tile_config = select_tile_config(h, w)
|
| 425 |
|
| 426 |
-
|
|
|
|
|
|
|
| 427 |
|
| 428 |
try:
|
| 429 |
-
|
|
|
|
|
|
|
|
|
|
| 430 |
if tile_config['tile'] > 0:
|
| 431 |
output_tensor = manual_tile_upscale(
|
| 432 |
self.model,
|
|
|
|
| 237 |
dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16
|
| 238 |
|
| 239 |
try:
|
| 240 |
+
# Explicitly disable autocast on CPU for RealESRGAN to avoid "PythonFallbackKernel" errors
|
| 241 |
+
# This seems to be a regression in recent PyTorch versions on CPU with some ops
|
| 242 |
+
context = torch.autocast(device_type=Config.DEVICE, dtype=dtype) if Config.DEVICE != 'cpu' else torch.no_grad()
|
| 243 |
+
|
| 244 |
+
with context:
|
| 245 |
if tile_config['tile'] > 0:
|
| 246 |
output_tensor = manual_tile_upscale(
|
| 247 |
self.model,
|
|
|
|
| 286 |
logger.info(f"{self.name} loaded successfully.")
|
| 287 |
except Exception as e:
|
| 288 |
logger.error(f"Failed to load Swin2SR: {e}")
|
| 289 |
+
# Swin2SR loading failure is often due to transformers version mismatch or device issues
|
| 290 |
+
# We re-raise to let the UI handle it, but log the specific error
|
| 291 |
raise
|
| 292 |
|
| 293 |
def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
|
|
|
|
| 302 |
|
| 303 |
inputs = self.processor(images=image, return_tensors="pt").to(Config.DEVICE)
|
| 304 |
|
| 305 |
+
# Swin2SR on CPU can be finicky with autocast/tracing.
|
| 306 |
+
# Explicitly disable autocast for Swin2SR on CPU to avoid "PythonFallbackKernel" errors
|
| 307 |
+
context = torch.no_grad()
|
| 308 |
+
|
| 309 |
+
with context:
|
| 310 |
outputs = self.model(**inputs)
|
| 311 |
|
| 312 |
output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
|
|
|
| 410 |
self.model = torch.compile(self.model)
|
| 411 |
logger.info("✓ torch.compile enabled (default mode)")
|
| 412 |
self.compiled = True
|
| 413 |
+
except Exception:
|
|
|
|
| 414 |
self.compiled = True
|
| 415 |
|
| 416 |
logger.info(f"{self.name} loaded successfully.")
|
|
|
|
| 432 |
h, w = img_np.shape[:2]
|
| 433 |
tile_config = select_tile_config(h, w)
|
| 434 |
|
| 435 |
+
# Disable AMP for SPAN on CPU to avoid "UntypedStorage" weakref errors in inductor
|
| 436 |
+
# SPAN architecture seems sensitive to autocast + compile on CPU
|
| 437 |
+
dtype = torch.float32 if Config.DEVICE == 'cpu' else torch.float16
|
| 438 |
|
| 439 |
try:
|
| 440 |
+
# Only use autocast if not CPU or if explicitly desired
|
| 441 |
+
context = torch.autocast(device_type=Config.DEVICE, dtype=dtype) if Config.DEVICE != 'cpu' else torch.no_grad()
|
| 442 |
+
|
| 443 |
+
with context:
|
| 444 |
if tile_config['tile'] > 0:
|
| 445 |
output_tensor = manual_tile_upscale(
|
| 446 |
self.model,
|