|
|
""" |
|
|
Gradio interface for the HATSAT satellite-image super-resolution application.
|
|
""" |
|
|
|
|
|
import logging

import gradio as gr
from PIL import Image

from config import REQUIRED_IMAGE_SIZE, UPSCALE_FACTOR
from utils.image_utils import validate_image_size, upscale_image
from interface.css_styles import generate_css, get_sample_images
|
|
|
|
|
|
|
|
def upscale_and_display(image, model, device):
    """Validate an uploaded image, upscale it, and report the outcome.

    Used as the "Enhance Image" click handler, so it always returns a
    ``(result_image, status_text)`` pair rather than letting an upscaling
    failure propagate into the UI.

    Args:
        image: Image from the input component, or ``None`` when nothing has
            been uploaded or selected.
        model: Super-resolution model, passed through to ``upscale_image``.
        device: Device argument, passed through to ``upscale_image``.

    Returns:
        ``(upscaled_image, "✅ Image successfully enhanced!")`` on success;
        otherwise ``(None, message)`` when no image was given, size
        validation fails, or ``upscale_image`` raises.
    """
    if image is None:
        return None, "Please upload an image or select a sample image."

    is_valid, message = validate_image_size(image)
    if not is_valid:
        return None, f"❌ Error: {message}"

    try:
        upscaled = upscale_image(image, model, device)
    except Exception as exc:  # UI boundary: report to the user, don't crash the handler
        # The UI only shows the message; keep the full traceback in the server log.
        logging.getLogger(__name__).exception("Image upscaling failed")
        # Some exceptions stringify to "" -- fall back to the class name so the
        # status line still says what went wrong.
        detail = str(exc) or type(exc).__name__
        return None, f"❌ Error processing image: {detail}"
    return upscaled, "✅ Image successfully enhanced!"
|
|
|
|
|
|
|
|
def select_sample_image(image_path):
    """Load a sample image fully into memory and return it.

    ``Image.open`` is lazy: it keeps the file handle open until the pixel
    data is actually read. Opening the file in a ``with`` block and returning
    a copy reads the data immediately and closes the file before the image
    is handed back to the caller.

    Args:
        image_path: Path to the sample image file. A falsy value (``None``
            or an empty string) means nothing was selected.

    Returns:
        An in-memory ``PIL.Image.Image``, or ``None`` when ``image_path`` is
        falsy.

    Raises:
        Whatever ``PIL.Image.open`` raises for a missing or unreadable file.
    """
    if not image_path:
        return None
    with Image.open(image_path) as img:
        # copy() forces the pixel data to load, so the returned image no
        # longer depends on the file handle that the ``with`` block closes.
        return img.copy()
|
|
|
|
|
|
|
|
def create_interface(model, device):
    """Assemble the HATSAT Gradio Blocks UI and wire up its events.

    Layout, top to bottom: title and usage notes, a collapsed
    "Acknowledgments" accordion, a row of sample-image buttons (only when
    ``get_sample_images()`` returns a truthy value), the input/output image
    pair, the "Enhance Image" button, and a read-only status box.

    Clicking a sample button loads that sample into the input image.
    Clicking "Enhance Image" runs ``upscale_and_display`` with ``model`` and
    ``device`` and fills the output image and status box.

    Args:
        model: Forwarded unchanged to ``upscale_and_display``.
        device: Forwarded unchanged to ``upscale_and_display``.

    Returns:
        The constructed ``gr.Blocks`` instance; it is not launched here.
    """
    stylesheet = generate_css()
    app_title = "HATSAT - Super-Resolution for Satellite Images"
    size_label = f"{REQUIRED_IMAGE_SIZE[0]}x{REQUIRED_IMAGE_SIZE[1]}"

    with gr.Blocks(css=stylesheet, title=app_title) as demo:
        gr.Markdown(f"# {app_title}")
        gr.Markdown(
            "Upload a satellite image or select a sample to enhance its "
            f"resolution by {UPSCALE_FACTOR}x."
        )
        gr.Markdown(
            f"⚠️ **Important**: Images must be exactly **{size_label} pixels** "
            "for the model to work properly."
        )

        with gr.Accordion("Acknowledgments", open=False):
            # This literal is rendered text; its spacing is kept verbatim.
            gr.Markdown("""


### Base Model: HAT (Hybrid Attention Transformer)


This model is a fine tuned version of **HAT**:


- **GitHub Repository**: [https://github.com/XPixelGroup/HAT](https://github.com/XPixelGroup/HAT)


- **Paper**: [Activating More Pixels in Image Super-Resolution Transformer](https://arxiv.org/abs/2205.04437)


- **Authors**: Xiangyu Chen, Xintao Wang, Jiantao Zhou, Yu Qiao, Chao Dong





### Training Dataset: SEN2NAIPv2


The model was fine-tuned using the **SEN2NAIPv2** dataset:


- **HuggingFace Dataset**: [https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2](https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2)


- **Description**: High-resolution satellite imagery dataset for super-resolution tasks


            """)

        sample_paths = get_sample_images()
        sample_buttons = []
        if sample_paths:
            gr.Markdown("**Sample Images (click to select):**")
            with gr.Row():
                # Label-less buttons; their look comes from the elem_id/class styling.
                sample_buttons = [
                    (
                        gr.Button(
                            "",
                            elem_classes="sample-image-btn",
                            elem_id=f"sample_btn_{idx}",
                        ),
                        path,
                    )
                    for idx, path in enumerate(sample_paths)
                ]

        # Settings shared by the input and output image panels.
        panel_kwargs = {
            "type": "pil",
            "elem_classes": "image-container",
            "height": 500,
            "width": 500,
        }
        with gr.Row():
            input_image = gr.Image(
                label=f"Input Image (must be {size_label} pixels)",
                sources=["upload"],
                **panel_kwargs,
            )
            output_image = gr.Image(
                label=f"Enhanced Output ({UPSCALE_FACTOR}x)",
                interactive=False,
                show_download_button=True,
                **panel_kwargs,
            )

        submit_btn = gr.Button("Enhance Image", variant="primary")
        status_message = gr.Textbox(label="Status", interactive=False, show_label=True)

        # The path is bound through a default argument so each button keeps its
        # own sample rather than the loop's last one. The list is empty when
        # there are no samples, so no extra guard is needed.
        for button, path in sample_buttons:
            button.click(fn=lambda path=path: select_sample_image(path), outputs=input_image)

        submit_btn.click(
            fn=lambda img: upscale_and_display(img, model, device),
            inputs=input_image,
            outputs=[output_image, status_message],
        )

    return demo