BorisEm commited on
Commit
6cc2b3b
·
1 Parent(s): 4993aa4

Simple refactor for readibility

Browse files
Files changed (1) hide show
  1. app.py +30 -21
app.py CHANGED
@@ -10,6 +10,12 @@ import glob
10
  import base64
11
  from io import BytesIO
12
 
 
 
 
 
 
 
13
 
14
  def to_2tuple(x):
15
  """Convert input to tuple of length 2."""
@@ -685,13 +691,10 @@ model = HAT(
685
  )
686
 
687
  # Load the fine-tuned weights
688
- checkpoint = torch.load('net_g_150000.pth', map_location=device)
689
- if 'params_ema' in checkpoint:
690
- model.load_state_dict(checkpoint['params_ema'])
691
- elif 'params' in checkpoint:
692
- model.load_state_dict(checkpoint['params'])
693
- else:
694
- model.load_state_dict(checkpoint)
695
 
696
  model.to(device)
697
  model.eval()
@@ -706,8 +709,8 @@ def upscale_image(image):
706
  h, w = img_tensor.shape[2], img_tensor.shape[3]
707
 
708
  # Pad if necessary
709
- pad_h = (16 - h % 16) % 16
710
- pad_w = (16 - w % 16) % 16
711
 
712
  if pad_h > 0 or pad_w > 0:
713
  img_tensor = torch.nn.functional.pad(img_tensor, (0, pad_w, 0, pad_h), mode='reflect')
@@ -717,7 +720,7 @@ def upscale_image(image):
717
 
718
  # Remove padding if it was added
719
  if pad_h > 0 or pad_w > 0:
720
- output = output[:, :, :h*4, :w*4]
721
 
722
  # Convert back to PIL image
723
  output_np = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
@@ -736,13 +739,14 @@ def get_sample_images():
736
 
737
  # Gradio interface using Blocks for better layout control
738
  def validate_image_size(image):
739
- """Validate that the image is exactly 130x130 pixels"""
740
  if image is None:
741
  return False, "No image provided"
742
 
743
  width, height = image.size
744
- if width != 130 or height != 130:
745
- return False, f"Image must be exactly 130x130 pixels. Your image is {width}x{height} pixels."
 
746
 
747
  return True, "Valid image size"
748
 
@@ -811,11 +815,16 @@ def generate_css():
811
  }
812
  """
813
 
814
- # Add background images for each sample
815
  sample_images = get_sample_images()
816
- for i, img_path in enumerate(sample_images):
817
- base64_img = image_to_base64(img_path)
818
- base_css += f"#sample_btn_{i} {{ background-image: url('{base64_img}'); }}\n"
 
 
 
 
 
819
 
820
  return base_css
821
 
@@ -823,8 +832,8 @@ css = generate_css()
823
 
824
  with gr.Blocks(css=css, title="HATSAT - Super-Resolution for Satellite Images") as iface:
825
  gr.Markdown("# HATSAT - Super-Resolution for Satellite Images")
826
- gr.Markdown("Upload a satellite image or select a sample to enhance its resolution by 4x.")
827
- gr.Markdown("⚠️ **Important**: Images must be exactly **130x130 pixels** for the model to work properly.")
828
 
829
  # Acknowledgments section
830
  with gr.Accordion("Acknowledgments", open=False):
@@ -858,7 +867,7 @@ with gr.Blocks(css=css, title="HATSAT - Super-Resolution for Satellite Images")
858
  with gr.Row():
859
  input_image = gr.Image(
860
  type="pil",
861
- label="Input Image (must be 130x130 pixels)",
862
  elem_classes="image-container",
863
  sources=["upload"],
864
  height=500,
@@ -867,7 +876,7 @@ with gr.Blocks(css=css, title="HATSAT - Super-Resolution for Satellite Images")
867
 
868
  output_image = gr.Image(
869
  type="pil",
870
- label="Enhanced Output (4x)",
871
  elem_classes="image-container",
872
  interactive=False,
873
  height=500,
 
10
  import base64
11
  from io import BytesIO
12
 
13
+ # Constants
14
+ MODEL_CHECKPOINT = 'net_g_150000.pth'
15
+ REQUIRED_IMAGE_SIZE = (130, 130)
16
+ WINDOW_SIZE = 16
17
+ UPSCALE_FACTOR = 4
18
+
19
 
20
  def to_2tuple(x):
21
  """Convert input to tuple of length 2."""
 
691
  )
692
 
693
  # Load the fine-tuned weights
694
+ checkpoint = torch.load(MODEL_CHECKPOINT, map_location=device)
695
+ # Try different checkpoint formats
696
+ state_dict = checkpoint.get('params_ema') or checkpoint.get('params') or checkpoint
697
+ model.load_state_dict(state_dict)
 
 
 
698
 
699
  model.to(device)
700
  model.eval()
 
709
  h, w = img_tensor.shape[2], img_tensor.shape[3]
710
 
711
  # Pad if necessary
712
+ pad_h = (WINDOW_SIZE - h % WINDOW_SIZE) % WINDOW_SIZE
713
+ pad_w = (WINDOW_SIZE - w % WINDOW_SIZE) % WINDOW_SIZE
714
 
715
  if pad_h > 0 or pad_w > 0:
716
  img_tensor = torch.nn.functional.pad(img_tensor, (0, pad_w, 0, pad_h), mode='reflect')
 
720
 
721
  # Remove padding if it was added
722
  if pad_h > 0 or pad_w > 0:
723
+ output = output[:, :, :h*UPSCALE_FACTOR, :w*UPSCALE_FACTOR]
724
 
725
  # Convert back to PIL image
726
  output_np = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
 
739
 
740
  # Gradio interface using Blocks for better layout control
741
  def validate_image_size(image):
742
+ """Validate that the image is exactly the required size"""
743
  if image is None:
744
  return False, "No image provided"
745
 
746
  width, height = image.size
747
+ req_width, req_height = REQUIRED_IMAGE_SIZE
748
+ if width != req_width or height != req_height:
749
+ return False, f"Image must be exactly {req_width}x{req_height} pixels. Your image is {width}x{height} pixels."
750
 
751
  return True, "Valid image size"
752
 
 
815
  }
816
  """
817
 
818
+ # Add background images for each sample (only if samples exist)
819
  sample_images = get_sample_images()
820
+ if sample_images:
821
+ for i, img_path in enumerate(sample_images):
822
+ try:
823
+ base64_img = image_to_base64(img_path)
824
+ base_css += f"#sample_btn_{i} {{ background-image: url('{base64_img}'); }}\n"
825
+ except Exception:
826
+ # Skip invalid images
827
+ continue
828
 
829
  return base_css
830
 
 
832
 
833
  with gr.Blocks(css=css, title="HATSAT - Super-Resolution for Satellite Images") as iface:
834
  gr.Markdown("# HATSAT - Super-Resolution for Satellite Images")
835
+ gr.Markdown(f"Upload a satellite image or select a sample to enhance its resolution by {UPSCALE_FACTOR}x.")
836
+ gr.Markdown(f"⚠️ **Important**: Images must be exactly **{REQUIRED_IMAGE_SIZE[0]}x{REQUIRED_IMAGE_SIZE[1]} pixels** for the model to work properly.")
837
 
838
  # Acknowledgments section
839
  with gr.Accordion("Acknowledgments", open=False):
 
867
  with gr.Row():
868
  input_image = gr.Image(
869
  type="pil",
870
+ label=f"Input Image (must be {REQUIRED_IMAGE_SIZE[0]}x{REQUIRED_IMAGE_SIZE[1]} pixels)",
871
  elem_classes="image-container",
872
  sources=["upload"],
873
  height=500,
 
876
 
877
  output_image = gr.Image(
878
  type="pil",
879
+ label=f"Enhanced Output ({UPSCALE_FACTOR}x)",
880
  elem_classes="image-container",
881
  interactive=False,
882
  height=500,