File size: 4,377 Bytes
0def483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
Gradio interface for HATSAT application.
"""

import gradio as gr
from PIL import Image

from config import REQUIRED_IMAGE_SIZE, UPSCALE_FACTOR
from utils.image_utils import validate_image_size, upscale_image
from interface.css_styles import generate_css, get_sample_images


def upscale_and_display(image, model, device):
    """Run super-resolution on an uploaded image.

    Returns a ``(result_image, status_message)`` pair; ``result_image``
    is ``None`` whenever the input is missing, invalid, or processing
    fails, and the status string explains why.
    """
    if image is None:
        return None, "Please upload an image or select a sample image."

    # Reject inputs whose dimensions the model cannot handle.
    ok, detail = validate_image_size(image)
    if not ok:
        return None, f"❌ Error: {detail}"

    try:
        result = upscale_image(image, model, device)
    except Exception as exc:
        # Surface the failure to the UI instead of crashing the app.
        return None, f"❌ Error processing image: {str(exc)}"
    return result, "✅ Image successfully enhanced!"


def select_sample_image(image_path):
    """Open the sample image at *image_path*, or return None for an empty path."""
    return Image.open(image_path) if image_path else None


def create_interface(model, device):
    """Create and configure the Gradio interface.

    Builds the full Blocks layout (header text, acknowledgments,
    sample-image buttons, input/output panes, status box) and wires the
    event handlers. `model` and `device` are captured in the submit
    callback's closure and passed through to `upscale_and_display`.
    Returns the assembled `gr.Blocks` object (caller launches it).
    """
    css = generate_css()

    with gr.Blocks(css=css, title="HATSAT - Super-Resolution for Satellite Images") as iface:
        gr.Markdown("# HATSAT - Super-Resolution for Satellite Images")
        gr.Markdown(f"Upload a satellite image or select a sample to enhance its resolution by {UPSCALE_FACTOR}x.")
        gr.Markdown(f"⚠️ **Important**: Images must be exactly **{REQUIRED_IMAGE_SIZE[0]}x{REQUIRED_IMAGE_SIZE[1]} pixels** for the model to work properly.")

        # Acknowledgments section
        with gr.Accordion("Acknowledgments", open=False):
            gr.Markdown("""
            ### Base Model: HAT (Hybrid Attention Transformer)
            This model is a fine tuned version of **HAT**:
            - **GitHub Repository**: [https://github.com/XPixelGroup/HAT](https://github.com/XPixelGroup/HAT)
            - **Paper**: [Activating More Pixels in Image Super-Resolution Transformer](https://arxiv.org/abs/2205.04437)
            - **Authors**: Xiangyu Chen, Xintao Wang, Jiantao Zhou, Yu Qiao, Chao Dong

            ### Training Dataset: SEN2NAIPv2
            The model was fine-tuned using the **SEN2NAIPv2** dataset:
            - **HuggingFace Dataset**: [https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2](https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2)
            - **Description**: High-resolution satellite imagery dataset for super-resolution tasks
            """)

        # Sample images: one styled button per path returned by
        # get_sample_images(); buttons are visual-only here (no label),
        # styled via the "sample-image-btn" CSS class.
        sample_images = get_sample_images()
        sample_buttons = []
        if sample_images:
            gr.Markdown("**Sample Images (click to select):**")
            with gr.Row():
                for i, img_path in enumerate(sample_images):
                    btn = gr.Button(
                        "",
                        elem_id=f"sample_btn_{i}",
                        elem_classes="sample-image-btn"
                    )
                    # Keep (button, path) pairs so the click handlers can
                    # be wired after the layout context below.
                    sample_buttons.append((btn, img_path))

        with gr.Row():
            input_image = gr.Image(
                type="pil",
                label=f"Input Image (must be {REQUIRED_IMAGE_SIZE[0]}x{REQUIRED_IMAGE_SIZE[1]} pixels)",
                elem_classes="image-container",
                sources=["upload"],  # disable webcam/clipboard sources
                height=500,
                width=500
            )

            output_image = gr.Image(
                type="pil",
                label=f"Enhanced Output ({UPSCALE_FACTOR}x)",
                elem_classes="image-container",
                interactive=False,  # display-only result pane
                height=500,
                width=500,
                show_download_button=True
            )

        submit_btn = gr.Button("Enhance Image", variant="primary")

        # Status message
        status_message = gr.Textbox(
            label="Status",
            interactive=False,
            show_label=True
        )

        # Event handlers
        if sample_images:
            for btn, img_path in sample_buttons:
                # `path=img_path` default-arg binds the path at definition
                # time, avoiding Python's late-binding closure pitfall.
                btn.click(fn=lambda path=img_path: select_sample_image(path), outputs=input_image)

        # model/device are closed over so the Gradio callback signature
        # only exposes the image input.
        submit_btn.click(
            fn=lambda img: upscale_and_display(img, model, device),
            inputs=input_image,
            outputs=[output_image, status_message]
        )

    return iface