Simple refactor for readability
app.py
CHANGED
@@ -10,6 +10,12 @@ import glob
 import base64
 from io import BytesIO
 
+# Constants
+MODEL_CHECKPOINT = 'net_g_150000.pth'
+REQUIRED_IMAGE_SIZE = (130, 130)
+WINDOW_SIZE = 16
+UPSCALE_FACTOR = 4
+
 
 def to_2tuple(x):
     """Convert input to tuple of length 2."""
@@ -685,13 +691,10 @@ model = HAT(
 )
 
 # Load the fine-tuned weights
-checkpoint = torch.load(…
-…
-…
-…
-    model.load_state_dict(checkpoint['params'])
-else:
-    model.load_state_dict(checkpoint)
+checkpoint = torch.load(MODEL_CHECKPOINT, map_location=device)
+# Try different checkpoint formats
+state_dict = checkpoint.get('params_ema') or checkpoint.get('params') or checkpoint
+model.load_state_dict(state_dict)
 
 model.to(device)
 model.eval()
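The `.get()` chain above replaces the removed multi-branch loader. A minimal sketch of how it resolves the three checkpoint layouts, using toy dicts in place of real state dicts:

```python
# Toy stand-ins for the three layouts the loader has to handle.
ema_ckpt   = {'params_ema': {'w': 1}, 'params': {'w': 0}}  # EMA weights present
plain_ckpt = {'params': {'w': 0}}                          # plain training weights
bare_ckpt  = {'w': 2}                                      # bare state dict

for ckpt in (ema_ckpt, plain_ckpt, bare_ckpt):
    state_dict = ckpt.get('params_ema') or ckpt.get('params') or ckpt
    print(state_dict)  # {'w': 1}, then {'w': 0}, then {'w': 2}
```

One behavioural caveat: `or` also falls through on an empty dict, not just a missing key; that is harmless here, since a real state dict is never empty.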
@@ -706,8 +709,8 @@ def upscale_image(image):
     h, w = img_tensor.shape[2], img_tensor.shape[3]
 
     # Pad if necessary
-    pad_h = (…
-    pad_w = (…
+    pad_h = (WINDOW_SIZE - h % WINDOW_SIZE) % WINDOW_SIZE
+    pad_w = (WINDOW_SIZE - w % WINDOW_SIZE) % WINDOW_SIZE
 
     if pad_h > 0 or pad_w > 0:
         img_tensor = torch.nn.functional.pad(img_tensor, (0, pad_w, 0, pad_h), mode='reflect')
@@ -717,7 +720,7 @@ def upscale_image(image):
 
     # Remove padding if it was added
     if pad_h > 0 or pad_w > 0:
-        output = output[:, :, :h*…
+        output = output[:, :, :h*UPSCALE_FACTOR, :w*UPSCALE_FACTOR]
 
     # Convert back to PIL image
     output_np = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
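Together with the padding hunk above, this gives a pad-to-multiple / crop-back round trip. A self-contained sketch of the shape arithmetic, with nearest-neighbour interpolation standing in for the actual HAT model:

```python
import torch

WINDOW_SIZE = 16       # values from the constants added at the top of app.py
UPSCALE_FACTOR = 4

x = torch.rand(1, 3, 130, 130)                          # a 130x130 input, as the app requires
h, w = x.shape[2], x.shape[3]
pad_h = (WINDOW_SIZE - h % WINDOW_SIZE) % WINDOW_SIZE   # 130 % 16 == 2, so pad_h == 14
pad_w = (WINDOW_SIZE - w % WINDOW_SIZE) % WINDOW_SIZE
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')    # -> 1x3x144x144

out = torch.nn.functional.interpolate(x, scale_factor=UPSCALE_FACTOR)   # model stand-in
out = out[:, :, :h * UPSCALE_FACTOR, :w * UPSCALE_FACTOR]               # crop the padding back off
print(out.shape)  # torch.Size([1, 3, 520, 520])
```

The outer `% WINDOW_SIZE` makes the formula yield 0 rather than 16 when the dimension is already a multiple of the window size.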
@@ -736,13 +739,14 @@ def get_sample_images():
 
 # Gradio interface using Blocks for better layout control
 def validate_image_size(image):
-    """Validate that the image is exactly…
+    """Validate that the image is exactly the required size"""
     if image is None:
         return False, "No image provided"
 
     width, height = image.size
-    …
-    …
+    req_width, req_height = REQUIRED_IMAGE_SIZE
+    if width != req_width or height != req_height:
+        return False, f"Image must be exactly {req_width}x{req_height} pixels. Your image is {width}x{height} pixels."
 
     return True, "Valid image size"
 
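For reference, how the rewritten helper behaves at hypothetical call sites:

```python
from PIL import Image

# Hypothetical inputs; the tuples in comments are the helper's return values.
validate_image_size(Image.new('RGB', (130, 130)))  # (True, "Valid image size")
validate_image_size(Image.new('RGB', (256, 256)))  # (False, "Image must be exactly 130x130 pixels. Your image is 256x256 pixels.")
validate_image_size(None)                          # (False, "No image provided")
```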
@@ -811,11 +815,16 @@ def generate_css():
     }
     """
 
-    # Add background images for each sample
+    # Add background images for each sample (only if samples exist)
     sample_images = get_sample_images()
-    …
-    …
-    …
+    if sample_images:
+        for i, img_path in enumerate(sample_images):
+            try:
+                base64_img = image_to_base64(img_path)
+                base_css += f"#sample_btn_{i} {{ background-image: url('{base64_img}'); }}\n"
+            except Exception:
+                # Skip invalid images
+                continue
 
     return base_css
 
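`image_to_base64` is called here but not defined in this diff. Given the `base64` and `BytesIO` imports at the top of app.py and the `url('…')` interpolation, it presumably returns a data URL; a sketch of such a helper (an assumption, not the app's actual code):

```python
import base64
from io import BytesIO
from PIL import Image

def image_to_base64(img_path):
    """Assumed shape of the helper: encode an image file as a CSS-usable data URL."""
    buf = BytesIO()
    Image.open(img_path).save(buf, format='PNG')
    encoded = base64.b64encode(buf.getvalue()).decode('ascii')
    return f"data:image/png;base64,{encoded}"
```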
@@ -823,8 +832,8 @@ css = generate_css()
 
 with gr.Blocks(css=css, title="HATSAT - Super-Resolution for Satellite Images") as iface:
     gr.Markdown("# HATSAT - Super-Resolution for Satellite Images")
-    gr.Markdown("Upload a satellite image or select a sample to enhance its resolution by…
-    gr.Markdown("⚠️ **Important**: Images must be exactly **…
+    gr.Markdown(f"Upload a satellite image or select a sample to enhance its resolution by {UPSCALE_FACTOR}x.")
+    gr.Markdown(f"⚠️ **Important**: Images must be exactly **{REQUIRED_IMAGE_SIZE[0]}x{REQUIRED_IMAGE_SIZE[1]} pixels** for the model to work properly.")
 
     # Acknowledgments section
     with gr.Accordion("Acknowledgments", open=False):
@@ -858,7 +867,7 @@ with gr.Blocks(css=css, title="HATSAT - Super-Resolution for Satellite Images")
     with gr.Row():
         input_image = gr.Image(
             type="pil",
-            label="Input Image (must be…
+            label=f"Input Image (must be {REQUIRED_IMAGE_SIZE[0]}x{REQUIRED_IMAGE_SIZE[1]} pixels)",
             elem_classes="image-container",
             sources=["upload"],
             height=500,
@@ -867,7 +876,7 @@ with gr.Blocks(css=css, title="HATSAT - Super-Resolution for Satellite Images")
 
         output_image = gr.Image(
             type="pil",
-            label="Enhanced Output (…
+            label=f"Enhanced Output ({UPSCALE_FACTOR}x)",
             elem_classes="image-container",
             interactive=False,
             height=500,