jgitsolutions commited on
Commit
3ee9494
·
verified ·
1 Parent(s): 745c2c7

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -407,8 +407,9 @@ class SpanStrategy(UpscalerStrategy):
407
  elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
408
  logger.info("ℹ Skipping torch.compile on low-core CPU.")
409
  else:
410
- self.model = torch.compile(self.model)
411
- logger.info(" torch.compile enabled (default mode)")
 
412
  self.compiled = True
413
  except Exception:
414
  self.compiled = True
@@ -524,10 +525,11 @@ class HatsStrategy(UpscalerStrategy):
524
  h, w = img_np.shape[:2]
525
  tile_config = select_tile_config(h, w)
526
 
527
- dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.bfloat16
528
 
529
  try:
530
- with torch.autocast(device_type=Config.DEVICE, dtype=dtype):
 
531
  if tile_config['tile'] > 0:
532
  output_tensor = manual_tile_upscale(
533
  self.model,
 
407
  elif (psutil.cpu_count(logical=False) or 0) < 4 and Config.DEVICE == 'cpu':
408
  logger.info("ℹ Skipping torch.compile on low-core CPU.")
409
  else:
410
+ # SPAN architecture uses .data.clone() in forward pass which breaks torch.compile/inductor
411
+ logger.info(" Skipping torch.compile for SPAN (incompatible architecture).")
412
+ # self.model = torch.compile(self.model)
413
  self.compiled = True
414
  except Exception:
415
  self.compiled = True
 
525
  h, w = img_np.shape[:2]
526
  tile_config = select_tile_config(h, w)
527
 
528
+ dtype = torch.float16 if Config.DEVICE == 'cuda' else torch.float32
529
 
530
  try:
531
+ context = torch.autocast(device_type=Config.DEVICE, dtype=dtype) if Config.DEVICE != 'cpu' else torch.no_grad()
532
+ with context:
533
  if tile_config['tile'] > 0:
534
  output_tensor = manual_tile_upscale(
535
  self.model,