File size: 4,377 Bytes
0def483 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
"""
Gradio interface for HATSAT application.
"""
import gradio as gr
from PIL import Image
from config import REQUIRED_IMAGE_SIZE, UPSCALE_FACTOR
from utils.image_utils import validate_image_size, upscale_image
from interface.css_styles import generate_css, get_sample_images
def upscale_and_display(image, model, device):
    """Run super-resolution on a submitted image and report the outcome.

    Args:
        image: The input image, or None when nothing was supplied.
        model: Passed through unchanged to ``upscale_image``.
        device: Passed through unchanged to ``upscale_image``.

    Returns:
        A ``(result, status)`` pair. ``result`` is whatever
        ``upscale_image`` returned on success and None otherwise;
        ``status`` is a human-readable message describing what happened.
    """
    if image is None:
        return None, "Please upload an image or select a sample image."

    # Reject inputs that fail the size check before invoking the model.
    size_ok, reason = validate_image_size(image)
    if not size_ok:
        return None, f"❌ Error: {reason}"

    # Any failure during inference is reported through the status message
    # rather than propagated to the caller.
    try:
        enhanced = upscale_image(image, model, device)
    except Exception as exc:
        return None, f"❌ Error processing image: {exc!s}"
    return enhanced, "✅ Image successfully enhanced!"
def select_sample_image(image_path):
    """Open the sample image stored at ``image_path``.

    Args:
        image_path: Location of the sample image file. A falsy value
            (None or an empty string) means there is nothing to open.

    Returns:
        The object returned by ``Image.open`` for the path, or None when
        ``image_path`` is falsy.
    """
    return Image.open(image_path) if image_path else None
def create_interface(model, device):
    """Build the Gradio Blocks UI for HATSAT and wire up its events.

    Args:
        model: Forwarded to ``upscale_and_display`` each time the
            "Enhance Image" button is clicked.
        device: Forwarded to ``upscale_and_display`` together with
            ``model``.

    Returns:
        The configured ``gr.Blocks`` instance. It is not launched here.
    """
    stylesheet = generate_css()
    app_title = "HATSAT - Super-Resolution for Satellite Images"
    required_size = f"{REQUIRED_IMAGE_SIZE[0]}x{REQUIRED_IMAGE_SIZE[1]}"

    # Keyword arguments common to the input and output image panes.
    image_pane = {
        "type": "pil",
        "elem_classes": "image-container",
        "height": 500,
        "width": 500,
    }

    # The literal's content is kept flush-left so its text stays unchanged.
    acknowledgments_md = """
### Base Model: HAT (Hybrid Attention Transformer)
This model is a fine tuned version of **HAT**:
- **GitHub Repository**: [https://github.com/XPixelGroup/HAT](https://github.com/XPixelGroup/HAT)
- **Paper**: [Activating More Pixels in Image Super-Resolution Transformer](https://arxiv.org/abs/2205.04437)
- **Authors**: Xiangyu Chen, Xintao Wang, Jiantao Zhou, Yu Qiao, Chao Dong
### Training Dataset: SEN2NAIPv2
The model was fine-tuned using the **SEN2NAIPv2** dataset:
- **HuggingFace Dataset**: [https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2](https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2)
- **Description**: High-resolution satellite imagery dataset for super-resolution tasks
"""

    with gr.Blocks(css=stylesheet, title=app_title) as iface:
        # Header and usage notes.
        gr.Markdown(f"# {app_title}")
        gr.Markdown(f"Upload a satellite image or select a sample to enhance its resolution by {UPSCALE_FACTOR}x.")
        gr.Markdown(f"⚠️ **Important**: Images must be exactly **{required_size} pixels** for the model to work properly.")

        # Collapsed-by-default credits for the base model and dataset.
        with gr.Accordion("Acknowledgments", open=False):
            gr.Markdown(acknowledgments_md)

        # One blank-label button per sample image; each button is paired
        # with its path so the click handler can be attached later, once
        # the input image component exists.
        sample_images = get_sample_images()
        sample_buttons = []
        if sample_images:
            gr.Markdown("**Sample Images (click to select):**")
            with gr.Row():
                sample_buttons = [
                    (
                        gr.Button(
                            "",
                            elem_id=f"sample_btn_{idx}",
                            elem_classes="sample-image-btn",
                        ),
                        sample_path,
                    )
                    for idx, sample_path in enumerate(sample_images)
                ]

        # Input and output panes side by side.
        with gr.Row():
            input_image = gr.Image(
                label=f"Input Image (must be {required_size} pixels)",
                sources=["upload"],
                **image_pane,
            )
            output_image = gr.Image(
                label=f"Enhanced Output ({UPSCALE_FACTOR}x)",
                interactive=False,
                show_download_button=True,
                **image_pane,
            )

        submit_btn = gr.Button("Enhance Image", variant="primary")

        # Read-only box that shows the status string from the handler.
        status_message = gr.Textbox(
            label="Status",
            interactive=False,
            show_label=True,
        )

        # Event wiring. The default argument freezes each button's own
        # path at definition time; a plain closure would see only the
        # last path of the loop.
        for btn, img_path in sample_buttons:
            btn.click(fn=lambda path=img_path: select_sample_image(path), outputs=input_image)

        submit_btn.click(
            fn=lambda img: upscale_and_display(img, model, device),
            inputs=input_image,
            outputs=[output_image, status_message],
        )

    return iface