BorisEm Claude committed on
Commit 841d16c · 1 Parent(s): 154404e

Add HAT satellite super-resolution model with Gradio interface


- Added fine-tuned HAT model weights (170MB) via Git LFS
- Implemented complete Gradio web interface for 4x satellite image upscaling
- Added model architecture implementation with all required components
- Included comprehensive documentation and usage instructions

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (4)
  1. README.md +74 -8
  2. app.py +677 -0
  3. net_g_20000.pth +3 -0
  4. requirements.txt +7 -0
README.md CHANGED
@@ -1,14 +1,80 @@
1
  ---
2
- title: HATSAT
3
- emoji: 🏆
4
- colorFrom: purple
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 5.46.1
8
  app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
- short_description: A fined tuned version of HAT on satellite images.
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: HAT Super-Resolution for Satellite Images
3
+ emoji: 🛰️
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 3.50.2
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
+ # HAT Super-Resolution for Satellite Images
13
+
14
+ This Hugging Face Space demonstrates a fine-tuned **Hybrid Attention Transformer (HAT)** model for satellite image super-resolution. The model performs 4x upscaling of satellite imagery, enhancing the resolution while preserving important geographical and structural details.
15
+
16
+ ## Model Details
17
+
18
+ - **Architecture**: HAT (Hybrid Attention Transformer)
19
+ - **Upscaling Factor**: 4x
20
+ - **Input Channels**: 3 (RGB)
21
+ - **Training**: Fine-tuned on satellite imagery dataset
22
+ - **Base Model**: HAT model pre-trained on ImageNet
23
+
24
+ ## Model Configuration
25
+
26
+ - **Window Size**: 16
27
+ - **Embed Dimension**: 180
28
+ - **Depths**: [6, 6, 6, 6, 6, 6]
29
+ - **Number of Heads**: [6, 6, 6, 6, 6, 6]
30
+ - **Compress Ratio**: 3
31
+ - **Squeeze Factor**: 30
32
+ - **Overlap Ratio**: 0.5
33
+
34
+ ## Usage
35
+
36
+ 1. Upload a satellite image (RGB format)
37
+ 2. The model will automatically upscale it by 4x
38
+ 3. Download the enhanced high-resolution result
39
+
40
+ ## Training Details
41
+
42
+ The model was fine-tuned using:
43
+ - **Loss Function**: L1Loss
44
+ - **Optimizer**: Adam (lr=2e-5)
45
+ - **Training Iterations**: 20,000
46
+ - **Scheduler**: MultiStepLR with milestones at [10000, 50000, 100000, 130000, 140000]
47
+
48
+ ## Applications
49
+
50
+ This model is particularly useful for:
51
+ - Enhancing low-resolution satellite imagery
52
+ - Geographic analysis and mapping
53
+ - Environmental monitoring
54
+ - Urban planning and development
55
+ - Agricultural monitoring
56
+
57
+ ## Technical Implementation
58
+
59
+ The model implements several key architectural components:
60
+ - **Hybrid Attention Blocks (HAB)**: Combining window-based self-attention with channel attention
61
+ - **Overlapping Cross-Attention Blocks (OCAB)**: For enhanced feature extraction
62
+ - **Residual Hybrid Attention Groups (RHAG)**: Stacked attention layers with residual connections
63
+ - **Channel Attention Blocks (CAB)**: For feature refinement
64
+
65
+ ## Performance
66
+
67
+ The model has been trained for 20,000 iterations with careful monitoring of PSNR and SSIM metrics on satellite imagery validation data.
68
+
69
+ ## Citation
70
+
71
+ If you use this model in your research, please cite the original HAT paper:
72
+
73
+ ```bibtex
74
+ @article{chen2023hat,
75
+ title={Activating More Pixels in Image Super-Resolution Transformer},
76
+ author={Chen, Xiangyu and Wang, Xintao and Zhou, Jiantao and Qiao, Yu and Dong, Chao},
77
+ journal={arXiv preprint arXiv:2205.04437},
78
+ year={2022}
79
+ }
80
+ ```
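Once the Space is running, the same 4x endpoint described in the Usage section above can also be called programmatically with `gradio_client`. This is only a sketch: the Space id `BorisEm/HATSAT` and the file names are placeholders, `/predict` is the default endpoint name for a single `gr.Interface`, and an older (0.x) `gradio_client` that accepts plain file paths is assumed to match the pinned Gradio 3.50 SDK (newer 1.x clients wrap file inputs with `handle_file`).

```python
# Sketch: calling the deployed Space remotely (assumes an older gradio_client 0.x,
# compatible with the Gradio 3.50 SDK pinned above; ids and paths are placeholders).
from gradio_client import Client

client = Client("BorisEm/HATSAT")          # hypothetical Space id
result_path = client.predict(
    "low_res_satellite.png",               # path to a local RGB satellite crop
    api_name="/predict",                   # default endpoint for a single gr.Interface
)
print("4x result saved at:", result_path)
```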
app.py ADDED
@@ -0,0 +1,677 @@
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from PIL import Image
6
+ import cv2
7
+ import math
8
+ from einops import rearrange
9
+
10
+
11
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
12
+ if drop_prob == 0. or not training:
13
+ return x
14
+ keep_prob = 1 - drop_prob
15
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
16
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
17
+ random_tensor.floor_()
18
+ output = x.div(keep_prob) * random_tensor
19
+ return output
20
+
21
+
22
+ class DropPath(nn.Module):
23
+ def __init__(self, drop_prob=None):
24
+ super(DropPath, self).__init__()
25
+ self.drop_prob = drop_prob
26
+
27
+ def forward(self, x):
28
+ return drop_path(x, self.drop_prob, self.training)
29
+
30
+
31
+ class ChannelAttention(nn.Module):
32
+ def __init__(self, num_feat, squeeze_factor=16):
33
+ super(ChannelAttention, self).__init__()
34
+ self.attention = nn.Sequential(
35
+ nn.AdaptiveAvgPool2d(1),
36
+ nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
37
+ nn.ReLU(inplace=True),
38
+ nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
39
+ nn.Sigmoid())
40
+
41
+ def forward(self, x):
42
+ y = self.attention(x)
43
+ return x * y
44
+
45
+
46
+ class CAB(nn.Module):
47
+ def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
48
+ super(CAB, self).__init__()
49
+ self.cab = nn.Sequential(
50
+ nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
51
+ nn.GELU(),
52
+ nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
53
+ ChannelAttention(num_feat, squeeze_factor)
54
+ )
55
+
56
+ def forward(self, x):
57
+ return self.cab(x)
58
+
59
+
60
+ class Mlp(nn.Module):
61
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
62
+ super().__init__()
63
+ out_features = out_features or in_features
64
+ hidden_features = hidden_features or in_features
65
+ self.fc1 = nn.Linear(in_features, hidden_features)
66
+ self.act = act_layer()
67
+ self.fc2 = nn.Linear(hidden_features, out_features)
68
+ self.drop = nn.Dropout(drop)
69
+
70
+ def forward(self, x):
71
+ x = self.fc1(x)
72
+ x = self.act(x)
73
+ x = self.drop(x)
74
+ x = self.fc2(x)
75
+ x = self.drop(x)
76
+ return x
77
+
78
+
79
+ def window_partition(x, window_size):
80
+ B, H, W, C = x.shape
81
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
82
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
83
+ return windows
84
+
85
+
86
+ def window_reverse(windows, window_size, H, W):
87
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
88
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
89
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
90
+ return x
91
+
92
+
93
+ class WindowAttention(nn.Module):
94
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
95
+ super().__init__()
96
+ self.dim = dim
97
+ self.window_size = window_size
98
+ self.num_heads = num_heads
99
+ head_dim = dim // num_heads
100
+ self.scale = qk_scale or head_dim ** -0.5
101
+
102
+ self.relative_position_bias_table = nn.Parameter(
103
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
104
+
105
+ coords_h = torch.arange(self.window_size[0])
106
+ coords_w = torch.arange(self.window_size[1])
107
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
108
+ coords_flatten = torch.flatten(coords, 1)
109
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
110
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
111
+ relative_coords[:, :, 0] += self.window_size[0] - 1
112
+ relative_coords[:, :, 1] += self.window_size[1] - 1
113
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
114
+ relative_position_index = relative_coords.sum(-1)
115
+ self.register_buffer("relative_position_index", relative_position_index)
116
+
117
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
118
+ self.attn_drop = nn.Dropout(attn_drop)
119
+ self.proj = nn.Linear(dim, dim)
120
+ self.proj_drop = nn.Dropout(proj_drop)
121
+
122
+ nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
123
+ self.softmax = nn.Softmax(dim=-1)
124
+
125
+ def forward(self, x, mask=None):
126
+ B_, N, C = x.shape
127
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
128
+ q, k, v = qkv[0], qkv[1], qkv[2]
129
+
130
+ q = q * self.scale
131
+ attn = (q @ k.transpose(-2, -1))
132
+
133
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
134
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
135
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
136
+ attn = attn + relative_position_bias.unsqueeze(0)
137
+
138
+ if mask is not None:
139
+ nW = mask.shape[0]
140
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
141
+ attn = attn.view(-1, self.num_heads, N, N)
142
+ attn = self.softmax(attn)
143
+ else:
144
+ attn = self.softmax(attn)
145
+
146
+ attn = self.attn_drop(attn)
147
+
148
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
149
+ x = self.proj(x)
150
+ x = self.proj_drop(x)
151
+ return x
152
+
153
+
154
+ class HAB(nn.Module):
155
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
156
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
157
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, compress_ratio=3, squeeze_factor=30):
158
+ super().__init__()
159
+ self.dim = dim
160
+ self.input_resolution = input_resolution
161
+ self.num_heads = num_heads
162
+ self.window_size = window_size
163
+ self.shift_size = shift_size
164
+ self.mlp_ratio = mlp_ratio
165
+ if min(self.input_resolution) <= self.window_size:
166
+ self.shift_size = 0
167
+ self.window_size = min(self.input_resolution)
168
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
169
+
170
+ self.norm1 = norm_layer(dim)
171
+ self.attn = WindowAttention(
172
+ dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
173
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
174
+
175
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
176
+ self.norm2 = norm_layer(dim)
177
+ mlp_hidden_dim = int(dim * mlp_ratio)
178
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
179
+
180
+ self.conv_scale = nn.Parameter(torch.ones(1))
181
+ self.conv_block = CAB(dim, compress_ratio, squeeze_factor)
182
+
183
+ if self.shift_size > 0:
184
+ H, W = self.input_resolution
185
+ img_mask = torch.zeros((1, H, W, 1))
186
+ h_slices = (slice(0, -self.window_size),
187
+ slice(-self.window_size, -self.shift_size),
188
+ slice(-self.shift_size, None))
189
+ w_slices = (slice(0, -self.window_size),
190
+ slice(-self.window_size, -self.shift_size),
191
+ slice(-self.shift_size, None))
192
+ cnt = 0
193
+ for h in h_slices:
194
+ for w in w_slices:
195
+ img_mask[:, h, w, :] = cnt
196
+ cnt += 1
197
+
198
+ mask_windows = window_partition(img_mask, self.window_size)
199
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
200
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
201
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
202
+ else:
203
+ attn_mask = None
204
+
205
+ self.register_buffer("attn_mask", attn_mask)
206
+
207
+ def forward(self, x):
208
+ H, W = self.input_resolution
209
+ B, L, C = x.shape
210
+ assert L == H * W, "input feature has wrong size"
211
+
212
+ shortcut = x
213
+ x = self.norm1(x)
214
+ x = x.view(B, H, W, C)
215
+
216
+ if self.shift_size > 0:
217
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
218
+ else:
219
+ shifted_x = x
220
+
221
+ x_windows = window_partition(shifted_x, self.window_size)
222
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
223
+
224
+ attn_windows = self.attn(x_windows, mask=self.attn_mask)
225
+
226
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
227
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W)
228
+
229
+ if self.shift_size > 0:
230
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
231
+ else:
232
+ x = shifted_x
233
+ x = x.view(B, H * W, C)
234
+
235
+ x = shortcut + self.drop_path(x)
236
+
237
+ y = x
238
+ x = self.norm2(x)
239
+ x = self.mlp(x)
240
+ x = y + self.drop_path(x)
241
+
242
+ conv_x = self.conv_block(x.view(B, H, W, C).permute(0, 3, 1, 2))
243
+ conv_x = conv_x.permute(0, 2, 3, 1).view(B, H * W, C)
244
+
245
+ x = x + self.conv_scale * conv_x
246
+
247
+ return x
248
+
249
+
250
+ class OCAB(nn.Module):
251
+ def __init__(self, dim, input_resolution, window_size, overlap_ratio, num_heads,
252
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
253
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, compress_ratio=3,
254
+ squeeze_factor=30):
255
+ super().__init__()
256
+ self.dim = dim
257
+ self.input_resolution = input_resolution
258
+ self.window_size = window_size
259
+ self.num_heads = num_heads
260
+ self.shift_size = round(overlap_ratio * window_size)
261
+ self.mlp_ratio = mlp_ratio
262
+
263
+ if min(self.input_resolution) <= self.window_size:
264
+ self.shift_size = 0
265
+ self.window_size = min(self.input_resolution)
266
+
267
+ assert 0 <= self.shift_size, "shift_size >= 0 is required"
268
+
269
+ self.norm1 = norm_layer(dim)
270
+ self.attn = WindowAttention(
271
+ dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
272
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
273
+
274
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
275
+ self.norm2 = norm_layer(dim)
276
+ mlp_hidden_dim = int(dim * mlp_ratio)
277
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
278
+
279
+ self.conv_scale = nn.Parameter(torch.ones(1))
280
+ self.conv_block = CAB(dim, compress_ratio, squeeze_factor)
281
+
282
+ def forward(self, x):
283
+ H, W = self.input_resolution
284
+ B, L, C = x.shape
285
+ assert L == H * W, "input feature has wrong size"
286
+
287
+ shortcut = x
288
+ x = self.norm1(x)
289
+ x = x.view(B, H, W, C)
290
+
291
+ pad_l = pad_t = 0
292
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
293
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
294
+ x = torch.nn.functional.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
295
+ _, Hp, Wp, _ = x.shape
296
+
297
+ if self.shift_size > 0:
298
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
299
+ else:
300
+ shifted_x = x
301
+
302
+ x_windows = window_partition(shifted_x, self.window_size)
303
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
304
+
305
+ attn_windows = self.attn(x_windows, mask=None)
306
+
307
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
308
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)
309
+
310
+ if self.shift_size > 0:
311
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
312
+ else:
313
+ x = shifted_x
314
+
315
+ if pad_r > 0 or pad_b > 0:
316
+ x = x[:, :H, :W, :].contiguous()
317
+
318
+ x = x.view(B, H * W, C)
319
+ x = shortcut + self.drop_path(x)
320
+
321
+ y = x
322
+ x = self.norm2(x)
323
+ x = self.mlp(x)
324
+ x = y + self.drop_path(x)
325
+
326
+ conv_x = self.conv_block(x.view(B, H, W, C).permute(0, 3, 1, 2))
327
+ conv_x = conv_x.permute(0, 2, 3, 1).view(B, H * W, C)
328
+
329
+ x = x + self.conv_scale * conv_x
330
+
331
+ return x
332
+
333
+
334
+ class PatchEmbed(nn.Module):
335
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
336
+ super().__init__()
337
+ img_size = (img_size, img_size)
338
+ patch_size = (patch_size, patch_size)
339
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
340
+ self.img_size = img_size
341
+ self.patch_size = patch_size
342
+ self.patches_resolution = patches_resolution
343
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
344
+
345
+ self.in_chans = in_chans
346
+ self.embed_dim = embed_dim
347
+
348
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
349
+ if norm_layer is not None:
350
+ self.norm = norm_layer(embed_dim)
351
+ else:
352
+ self.norm = None
353
+
354
+ def forward(self, x):
355
+ B, C, H, W = x.shape
356
+ assert H == self.img_size[0] and W == self.img_size[1], \
357
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
358
+ x = self.proj(x).flatten(2).transpose(1, 2)
359
+ if self.norm is not None:
360
+ x = self.norm(x)
361
+ return x
362
+
363
+
364
+ class PatchUnEmbed(nn.Module):
365
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
366
+ super().__init__()
367
+ img_size = (img_size, img_size)
368
+ patch_size = (patch_size, patch_size)
369
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
370
+ self.img_size = img_size
371
+ self.patch_size = patch_size
372
+ self.patches_resolution = patches_resolution
373
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
374
+
375
+ self.in_chans = in_chans
376
+ self.embed_dim = embed_dim
377
+
378
+ def forward(self, x, x_size):
379
+ H, W = x_size
380
+ B, HW, C = x.shape
381
+ x = x.transpose(1, 2).view(B, self.embed_dim, H, W)
382
+ return x
383
+
384
+
385
+ class RHAG(nn.Module):
386
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size, compress_ratio,
387
+ squeeze_factor, conv_scale, overlap_ratio, mlp_ratio=4., qkv_bias=True, qk_scale=None,
388
+ drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
389
+ use_checkpoint=False):
390
+ super().__init__()
391
+ self.dim = dim
392
+ self.input_resolution = input_resolution
393
+ self.depth = depth
394
+ self.use_checkpoint = use_checkpoint
395
+
396
+ self.blocks_1 = nn.ModuleList([
397
+ HAB(dim=dim, input_resolution=input_resolution,
398
+ num_heads=num_heads, window_size=window_size,
399
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
400
+ mlp_ratio=mlp_ratio,
401
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
402
+ drop=drop, attn_drop=attn_drop,
403
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
404
+ norm_layer=norm_layer, compress_ratio=compress_ratio,
405
+ squeeze_factor=squeeze_factor)
406
+ for i in range(depth // 2)])
407
+
408
+ self.blocks_2 = nn.ModuleList([
409
+ OCAB(dim=dim, input_resolution=input_resolution,
410
+ window_size=window_size, overlap_ratio=overlap_ratio,
411
+ num_heads=num_heads, mlp_ratio=mlp_ratio,
412
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
413
+ drop=drop, attn_drop=attn_drop,
414
+ drop_path=drop_path[i + depth//2] if isinstance(drop_path, list) else drop_path,
415
+ norm_layer=norm_layer, compress_ratio=compress_ratio,
416
+ squeeze_factor=squeeze_factor)
417
+ for i in range(depth // 2)])
418
+
419
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
420
+ self.conv_scale = conv_scale
421
+
422
+ if downsample is not None:
423
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
424
+ else:
425
+ self.downsample = None
426
+
427
+ def forward(self, x, x_size):
428
+ H, W = x_size
429
+ res = x
430
+ for blk in self.blocks_1:
431
+ if self.use_checkpoint:
432
+ x = torch.utils.checkpoint.checkpoint(blk, x)
433
+ else:
434
+ x = blk(x)
435
+ for blk in self.blocks_2:
436
+ if self.use_checkpoint:
437
+ x = torch.utils.checkpoint.checkpoint(blk, x)
438
+ else:
439
+ x = blk(x)
440
+
441
+ conv_x = self.conv(x.transpose(1, 2).view(-1, self.dim, H, W)).view(-1, self.dim, H * W).transpose(1, 2)
442
+ x = res + x + conv_x * self.conv_scale
443
+
444
+ if self.downsample is not None:
445
+ x = self.downsample(x)
446
+ return x
447
+
448
+
449
+ class Upsample(nn.Sequential):
450
+ def __init__(self, scale, num_feat):
451
+ m = []
452
+ if (scale & (scale - 1)) == 0:
453
+ for _ in range(int(math.log(scale, 2))):
454
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
455
+ m.append(nn.PixelShuffle(2))
456
+ elif scale == 3:
457
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
458
+ m.append(nn.PixelShuffle(3))
459
+ else:
460
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
461
+ super(Upsample, self).__init__(*m)
462
+
463
+
464
+ class HAT(nn.Module):
465
+ def __init__(self, img_size=64, patch_size=1, in_chans=3, embed_dim=180, depths=[6, 6, 6, 6, 6, 6],
466
+ num_heads=[6, 6, 6, 6, 6, 6], window_size=16, compress_ratio=3, squeeze_factor=30,
467
+ conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=4., qkv_bias=True, qk_scale=None,
468
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm,
469
+ ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1.,
470
+ upsampler='', resi_connection='1conv', **kwargs):
471
+ super(HAT, self).__init__()
472
+
473
+ self.window_size = window_size
474
+ self.shift_size = window_size // 2
475
+ self.overlap_ratio = overlap_ratio
476
+ num_in_ch = in_chans
477
+ num_out_ch = in_chans
478
+ num_feat = 64
479
+ self.img_range = img_range
480
+ if in_chans == 3:
481
+ rgb_mean = (0.4488, 0.4371, 0.4040)
482
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
483
+ else:
484
+ self.mean = torch.zeros(1, 1, 1, 1)
485
+ self.upscale = upscale
486
+ self.upsampler = upsampler
487
+
488
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
489
+
490
+ self.num_layers = len(depths)
491
+ self.embed_dim = embed_dim
492
+ self.ape = ape
493
+ self.patch_norm = patch_norm
494
+ self.num_features = embed_dim
495
+ self.mlp_ratio = mlp_ratio
496
+
497
+ self.patch_embed = PatchEmbed(
498
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
499
+ norm_layer=norm_layer if self.patch_norm else None)
500
+ num_patches = self.patch_embed.num_patches
501
+ patches_resolution = self.patch_embed.patches_resolution
502
+ self.patches_resolution = patches_resolution
503
+
504
+ self.patch_unembed = PatchUnEmbed(
505
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
506
+ norm_layer=norm_layer if self.patch_norm else None)
507
+
508
+ if self.ape:
509
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
510
+ nn.init.trunc_normal_(self.absolute_pos_embed, std=.02)
511
+
512
+ self.pos_drop = nn.Dropout(p=drop_rate)
513
+
514
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
515
+
516
+ self.layers = nn.ModuleList()
517
+ for i_layer in range(self.num_layers):
518
+ layer = RHAG(dim=embed_dim,
519
+ input_resolution=(patches_resolution[0],
520
+ patches_resolution[1]),
521
+ depth=depths[i_layer],
522
+ num_heads=num_heads[i_layer],
523
+ window_size=window_size,
524
+ compress_ratio=compress_ratio,
525
+ squeeze_factor=squeeze_factor,
526
+ conv_scale=conv_scale,
527
+ overlap_ratio=overlap_ratio,
528
+ mlp_ratio=self.mlp_ratio,
529
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
530
+ drop=drop_rate, attn_drop=attn_drop_rate,
531
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
532
+ norm_layer=norm_layer,
533
+ downsample=None,
534
+ use_checkpoint=use_checkpoint)
535
+ self.layers.append(layer)
536
+ self.norm = norm_layer(self.num_features)
537
+
538
+ if resi_connection == '1conv':
539
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
540
+ elif resi_connection == '3conv':
541
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
542
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
543
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
544
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
545
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
546
+
547
+ if upsampler == 'pixelshuffle':
548
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
549
+ nn.LeakyReLU(inplace=True))
550
+ self.upsample = Upsample(upscale, num_feat)
551
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
552
+
553
+ self.apply(self._init_weights)
554
+
555
+ def _init_weights(self, m):
556
+ if isinstance(m, nn.Linear):
557
+ nn.init.trunc_normal_(m.weight, std=.02)
558
+ if isinstance(m, nn.Linear) and m.bias is not None:
559
+ nn.init.constant_(m.bias, 0)
560
+ elif isinstance(m, nn.LayerNorm):
561
+ nn.init.constant_(m.bias, 0)
562
+ nn.init.constant_(m.weight, 1.0)
563
+
564
+ @torch.jit.ignore
565
+ def no_weight_decay(self):
566
+ return {'absolute_pos_embed'}
567
+
568
+ @torch.jit.ignore
569
+ def no_weight_decay_keywords(self):
570
+ return {'relative_position_bias_table'}
571
+
572
+ def forward_features(self, x):
573
+ x_size = (x.shape[2], x.shape[3])
574
+ x = self.patch_embed(x)
575
+ if self.ape:
576
+ x = x + self.absolute_pos_embed
577
+ x = self.pos_drop(x)
578
+
579
+ for layer in self.layers:
580
+ x = layer(x, x_size)
581
+
582
+ x = self.norm(x)
583
+ x = self.patch_unembed(x, x_size)
584
+
585
+ return x
586
+
587
+ def forward(self, x):
588
+ self.mean = self.mean.type_as(x)
589
+ x = (x - self.mean) * self.img_range
590
+
591
+ x_first = self.conv_first(x)
592
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
593
+ if self.upsampler == 'pixelshuffle':
594
+ x = self.conv_before_upsample(res)
595
+ x = self.conv_last(self.upsample(x))
596
+
597
+ x = x / self.img_range + self.mean
598
+
599
+ return x
600
+
601
+
602
+ # Load the model
603
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
604
+
605
+ model = HAT(
606
+ upscale=4,
607
+ in_chans=3,
608
+ img_size=128,
609
+ window_size=16,
610
+ compress_ratio=3,
611
+ squeeze_factor=30,
612
+ conv_scale=0.01,
613
+ overlap_ratio=0.5,
614
+ img_range=1.,
615
+ depths=[6, 6, 6, 6, 6, 6],
616
+ embed_dim=180,
617
+ num_heads=[6, 6, 6, 6, 6, 6],
618
+ mlp_ratio=2,
619
+ upsampler='pixelshuffle',
620
+ resi_connection='1conv'
621
+ )
622
+
623
+ # Load the fine-tuned weights
624
+ checkpoint = torch.load('net_g_20000.pth', map_location=device)
625
+ if 'params_ema' in checkpoint:
626
+ model.load_state_dict(checkpoint['params_ema'])
627
+ elif 'params' in checkpoint:
628
+ model.load_state_dict(checkpoint['params'])
629
+ else:
630
+ model.load_state_dict(checkpoint)
631
+
632
+ model.to(device)
633
+ model.eval()
634
+
635
+
636
+ def upscale_image(image):
637
+ # Convert PIL image to tensor
638
+ img_np = np.array(image).astype(np.float32) / 255.0
639
+ img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)
640
+
641
+ # Ensure the image dimensions are multiples of window_size
642
+ h, w = img_tensor.shape[2], img_tensor.shape[3]
643
+
644
+ # Pad if necessary
645
+ pad_h = (16 - h % 16) % 16
646
+ pad_w = (16 - w % 16) % 16
647
+
648
+ if pad_h > 0 or pad_w > 0:
649
+ img_tensor = torch.nn.functional.pad(img_tensor, (0, pad_w, 0, pad_h), mode='reflect')
650
+
651
+ with torch.no_grad():
652
+ output = model(img_tensor)
653
+
654
+ # Remove padding if it was added
655
+ if pad_h > 0 or pad_w > 0:
656
+ output = output[:, :, :h*4, :w*4]
657
+
658
+ # Convert back to PIL image
659
+ output_np = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
660
+ output_np = np.clip(output_np * 255.0, 0, 255).astype(np.uint8)
661
+
662
+ return Image.fromarray(output_np)
663
+
664
+
665
+ # Gradio interface
666
+ iface = gr.Interface(
667
+ fn=upscale_image,
668
+ inputs=gr.Image(type="pil", label="Input Satellite Image"),
669
+ outputs=gr.Image(type="pil", label="Super-Resolution Output (4x)"),
670
+ title="HAT Super-Resolution for Satellite Images",
671
+ description="Upload a satellite image to enhance its resolution by 4x using a fine-tuned HAT model. This model has been specifically trained on satellite imagery to provide high-quality super-resolution results.",
672
+ examples=None,
673
+ cache_examples=False
674
+ )
675
+
676
+ if __name__ == "__main__":
677
+ iface.launch()
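As a quick sanity check outside the Gradio UI, the `upscale_image` function defined above can be driven directly from Python. A minimal sketch, assuming `app.py` and `net_g_20000.pth` are in the working directory and `sample_lr.png` is any RGB satellite crop (note that importing `app` builds the model and loads the checkpoint at import time):

```python
# Minimal local smoke test for the inference path (sketch, not part of this commit).
from PIL import Image

from app import upscale_image  # importing app builds the HAT model and loads the weights

lr = Image.open("sample_lr.png").convert("RGB")  # placeholder path; force 3-channel RGB
sr = upscale_image(lr)                           # returns a 4x-upscaled PIL image
print(lr.size, "->", sr.size)                    # e.g. (128, 128) -> (512, 512)
sr.save("sample_sr.png")
```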
net_g_20000.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94d74fbc11a23bec569dd31b9f10b1b1033fd839cadabbef155d2beab3a3ffeb
3
+ size 170284129
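Note that `net_g_20000.pth` is stored as a Git LFS pointer, so cloning without LFS support yields only this three-line stub rather than the 170 MB checkpoint. One way to resolve it to a real file is `huggingface_hub`; a sketch, with the Space id `BorisEm/HATSAT` as a placeholder:

```python
# Sketch: download the LFS-tracked checkpoint via huggingface_hub (placeholder repo id).
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="BorisEm/HATSAT",   # hypothetical Space id -- replace with the actual one
    filename="net_g_20000.pth",
    repo_type="space",          # the file lives in a Space repo, not a model repo
)
print("checkpoint resolved to:", ckpt_path)
```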
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ torch>=1.13.0
2
+ torchvision>=0.14.0
3
+ gradio>=3.0.0
4
+ numpy>=1.21.0
5
+ Pillow>=8.0.0
6
+ opencv-python>=4.5.0
7
+ einops>=0.4.0
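Before the Space builds, it can be worth confirming that the pinned dependencies import cleanly and whether a GPU is visible; a small sketch:

```python
# Quick environment check against requirements.txt (sketch).
import importlib

import torch

for module in ("torchvision", "gradio", "numpy", "PIL", "cv2", "einops"):
    importlib.import_module(module)  # raises ImportError if a dependency is missing
    print(f"{module}: ok")

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
```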