""" Gradio interface for HATSAT application. """ import gradio as gr from PIL import Image from config import REQUIRED_IMAGE_SIZE, UPSCALE_FACTOR from utils.image_utils import validate_image_size, upscale_image from interface.css_styles import generate_css, get_sample_images def upscale_and_display(image, model, device): """Process image upload and return upscaled result.""" if image is None: return None, "Please upload an image or select a sample image." # Validate image size is_valid, message = validate_image_size(image) if not is_valid: return None, f"❌ Error: {message}" try: # Get the super-resolution output upscaled = upscale_image(image, model, device) return upscaled, "✅ Image successfully enhanced!" except Exception as e: return None, f"❌ Error processing image: {str(e)}" def select_sample_image(image_path): """Load and return a sample image.""" if image_path: return Image.open(image_path) return None def create_interface(model, device): """Create and configure the Gradio interface.""" css = generate_css() with gr.Blocks(css=css, title="HATSAT - Super-Resolution for Satellite Images") as iface: gr.Markdown("# HATSAT - Super-Resolution for Satellite Images") gr.Markdown(f"Upload a satellite image or select a sample to enhance its resolution by {UPSCALE_FACTOR}x.") gr.Markdown(f"⚠️ **Important**: Images must be exactly **{REQUIRED_IMAGE_SIZE[0]}x{REQUIRED_IMAGE_SIZE[1]} pixels** for the model to work properly.") # Acknowledgments section with gr.Accordion("Acknowledgments", open=False): gr.Markdown(""" ### Base Model: HAT (Hybrid Attention Transformer) This model is a fine tuned version of **HAT**: - **GitHub Repository**: [https://github.com/XPixelGroup/HAT](https://github.com/XPixelGroup/HAT) - **Paper**: [Activating More Pixels in Image Super-Resolution Transformer](https://arxiv.org/abs/2205.04437) - **Authors**: Xiangyu Chen, Xintao Wang, Jiantao Zhou, Yu Qiao, Chao Dong ### Training Dataset: SEN2NAIPv2 The model was fine-tuned using the **SEN2NAIPv2** dataset: - **HuggingFace Dataset**: [https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2](https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2) - **Description**: High-resolution satellite imagery dataset for super-resolution tasks """) # Sample images sample_images = get_sample_images() sample_buttons = [] if sample_images: gr.Markdown("**Sample Images (click to select):**") with gr.Row(): for i, img_path in enumerate(sample_images): btn = gr.Button( "", elem_id=f"sample_btn_{i}", elem_classes="sample-image-btn" ) sample_buttons.append((btn, img_path)) with gr.Row(): input_image = gr.Image( type="pil", label=f"Input Image (must be {REQUIRED_IMAGE_SIZE[0]}x{REQUIRED_IMAGE_SIZE[1]} pixels)", elem_classes="image-container", sources=["upload"], height=500, width=500 ) output_image = gr.Image( type="pil", label=f"Enhanced Output ({UPSCALE_FACTOR}x)", elem_classes="image-container", interactive=False, height=500, width=500, show_download_button=True ) submit_btn = gr.Button("Enhance Image", variant="primary") # Status message status_message = gr.Textbox( label="Status", interactive=False, show_label=True ) # Event handlers if sample_images: for btn, img_path in sample_buttons: btn.click(fn=lambda path=img_path: select_sample_image(path), outputs=input_image) submit_btn.click( fn=lambda img: upscale_and_display(img, model, device), inputs=input_image, outputs=[output_image, status_message] ) return iface