jgitsolutions commited on
Commit
745c2c7
·
verified ·
1 Parent(s): e58a107

Upload 2 files

Browse files
Files changed (1) hide show
  1. app.py +20 -6
app.py CHANGED
@@ -237,7 +237,11 @@ class RealESRGANStrategy(UpscalerStrategy):
237
  dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16
238
 
239
  try:
240
- with torch.autocast(device_type=Config.DEVICE, dtype=dtype):
 
 
 
 
241
  if tile_config['tile'] > 0:
242
  output_tensor = manual_tile_upscale(
243
  self.model,
@@ -282,6 +286,8 @@ class Swin2SRStrategy(UpscalerStrategy):
282
  logger.info(f"{self.name} loaded successfully.")
283
  except Exception as e:
284
  logger.error(f"Failed to load Swin2SR: {e}")
 
 
285
  raise
286
 
287
  def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
@@ -296,7 +302,11 @@ class Swin2SRStrategy(UpscalerStrategy):
296
 
297
  inputs = self.processor(images=image, return_tensors="pt").to(Config.DEVICE)
298
 
299
- with torch.no_grad():
 
 
 
 
300
  outputs = self.model(**inputs)
301
 
302
  output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
@@ -400,8 +410,7 @@ class SpanStrategy(UpscalerStrategy):
400
  self.model = torch.compile(self.model)
401
  logger.info("✓ torch.compile enabled (default mode)")
402
  self.compiled = True
403
- except Exception as e:
404
- logger.warning(f"⚠ torch.compile failed: {e}")
405
  self.compiled = True
406
 
407
  logger.info(f"{self.name} loaded successfully.")
@@ -423,10 +432,15 @@ class SpanStrategy(UpscalerStrategy):
423
  h, w = img_np.shape[:2]
424
  tile_config = select_tile_config(h, w)
425
 
426
- dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16
 
 
427
 
428
  try:
429
- with torch.autocast(device_type=Config.DEVICE, dtype=dtype):
 
 
 
430
  if tile_config['tile'] > 0:
431
  output_tensor = manual_tile_upscale(
432
  self.model,
 
237
  dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16
238
 
239
  try:
240
+ # Explicitly disable autocast on CPU for RealESRGAN to avoid "PythonFallbackKernel" errors
241
+ # This seems to be a regression in recent PyTorch versions on CPU with some ops
242
+ context = torch.autocast(device_type=Config.DEVICE, dtype=dtype) if Config.DEVICE != 'cpu' else torch.no_grad()
243
+
244
+ with context:
245
  if tile_config['tile'] > 0:
246
  output_tensor = manual_tile_upscale(
247
  self.model,
 
286
  logger.info(f"{self.name} loaded successfully.")
287
  except Exception as e:
288
  logger.error(f"Failed to load Swin2SR: {e}")
289
+ # Swin2SR loading failure is often due to transformers version mismatch or device issues
290
+ # We re-raise to let the UI handle it, but log the specific error
291
  raise
292
 
293
  def upscale(self, image: Image.Image, **kwargs) -> Image.Image:
 
302
 
303
  inputs = self.processor(images=image, return_tensors="pt").to(Config.DEVICE)
304
 
305
+ # Swin2SR on CPU can be finicky with autocast/tracing.
306
+ # Explicitly disable autocast for Swin2SR on CPU to avoid "PythonFallbackKernel" errors
307
+ context = torch.no_grad()
308
+
309
+ with context:
310
  outputs = self.model(**inputs)
311
 
312
  output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
 
410
  self.model = torch.compile(self.model)
411
  logger.info("✓ torch.compile enabled (default mode)")
412
  self.compiled = True
413
+ except Exception:
 
414
  self.compiled = True
415
 
416
  logger.info(f"{self.name} loaded successfully.")
 
432
  h, w = img_np.shape[:2]
433
  tile_config = select_tile_config(h, w)
434
 
435
+ # Disable AMP for SPAN on CPU to avoid "UntypedStorage" weakref errors in inductor
436
+ # SPAN architecture seems sensitive to autocast + compile on CPU
437
+ dtype = torch.float32 if Config.DEVICE == 'cpu' else torch.float16
438
 
439
  try:
440
+ # Only use autocast if not CPU or if explicitly desired
441
+ context = torch.autocast(device_type=Config.DEVICE, dtype=dtype) if Config.DEVICE != 'cpu' else torch.no_grad()
442
+
443
+ with context:
444
  if tile_config['tile'] > 0:
445
  output_tensor = manual_tile_upscale(
446
  self.model,