BorisEm committed
Commit 4993aa4 · Parent: 16cb92d

Change title

Files changed (3):
  1. README.md +1 -1
  2. app.py +2 -2
  3. app_old.py +0 -700
README.md CHANGED
@@ -9,7 +9,7 @@ app_file: app.py
 pinned: false
 ---
 
-# HAT Super-Resolution for Satellite Images
+# HATSAT - Super-Resolution for Satellite Images
 
 This Hugging Face Space demonstrates a fine-tuned **Hybrid Attention Transformer (HAT)** model for satellite image super-resolution. The model performs 4x upscaling of satellite imagery, enhancing the resolution while preserving important geographical and structural details.
 

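To try the Space programmatically rather than through the web UI, a minimal `gradio_client` sketch is shown below. The Space id and `api_name` are placeholders, not confirmed values (the actual ones are listed on the Space's "Use via API" panel); the 130x130-pixel input requirement comes from the app itself, as noted in the app.py change below.

```python
# Hypothetical client-side sketch: the Space id and api_name are placeholders
# to be replaced with the values shown on the Space's "Use via API" panel.
from gradio_client import Client, handle_file

client = Client("BorisEm/HATSAT")        # placeholder Space id
result = client.predict(
    handle_file("tile_130x130.png"),     # input must be exactly 130x130 px
    api_name="/predict",                 # placeholder endpoint name
)
print(result)  # local path to the 4x-upscaled output image
```
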
app.py CHANGED
@@ -821,8 +821,8 @@ def generate_css():
 
 css = generate_css()
 
-with gr.Blocks(css=css, title="HAT Super-Resolution for Satellite Images") as iface:
-    gr.Markdown("# HAT Super-Resolution for Satellite Images")
+with gr.Blocks(css=css, title="HATSAT - Super-Resolution for Satellite Images") as iface:
+    gr.Markdown("# HATSAT - Super-Resolution for Satellite Images")
     gr.Markdown("Upload a satellite image or select a sample to enhance its resolution by 4x.")
     gr.Markdown("⚠️ **Important**: Images must be exactly **130x130 pixels** for the model to work properly.")
 

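Because the interface rejects anything other than a 130x130-pixel input, a small client-side preparation step can help. The sketch below is illustrative only: the file names are placeholders and the center-crop/resize strategy is an assumption, not part of the app. With the model's 4x factor, the enhanced output is 520x520.

```python
# Illustrative preprocessing sketch: crop/resize a scene to the 130x130 input
# size the app requires. File names are placeholders.
from PIL import Image

SIZE = 130  # required input size stated by the app

img = Image.open("satellite_scene.png").convert("RGB")

# Center-crop to a square, then resize to exactly 130x130.
side = min(img.size)
left = (img.width - side) // 2
top = (img.height - side) // 2
tile = img.crop((left, top, left + side, top + side)).resize((SIZE, SIZE), Image.BICUBIC)
tile.save("tile_130x130.png")

print(tile.size, "->", (SIZE * 4, SIZE * 4))  # 4x super-resolution yields 520x520
```
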
app_old.py DELETED
@@ -1,700 +0,0 @@
-import gradio as gr
-import torch
-import torch.nn as nn
-import numpy as np
-from PIL import Image
-import cv2
-import math
-from einops import rearrange
-
-
-def to_2tuple(x):
-    """Convert input to tuple of length 2."""
-    if isinstance(x, (tuple, list)):
-        return tuple(x)
-    return (x, x)
-
-
-def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
-    """Truncated normal initialization."""
-    def norm_cdf(x):
-        return (1. + math.erf(x / math.sqrt(2.))) / 2.
-
-    with torch.no_grad():
-        l = norm_cdf((a - mean) / std)
-        u = norm_cdf((b - mean) / std)
-        tensor.uniform_(2 * l - 1, 2 * u - 1)
-        tensor.erfinv_()
-        tensor.mul_(std * math.sqrt(2.))
-        tensor.add_(mean)
-        tensor.clamp_(min=a, max=b)
-        return tensor
-
-
-def drop_path(x, drop_prob: float = 0., training: bool = False):
-    if drop_prob == 0. or not training:
-        return x
-    keep_prob = 1 - drop_prob
-    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
-    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
-    random_tensor.floor_()
-    output = x.div(keep_prob) * random_tensor
-    return output
-
-
-class DropPath(nn.Module):
-    def __init__(self, drop_prob=None):
-        super(DropPath, self).__init__()
-        self.drop_prob = drop_prob
-
-    def forward(self, x):
-        return drop_path(x, self.drop_prob, self.training)
-
-
-class ChannelAttention(nn.Module):
-    def __init__(self, num_feat, squeeze_factor=16):
-        super(ChannelAttention, self).__init__()
-        self.attention = nn.Sequential(
-            nn.AdaptiveAvgPool2d(1),
-            nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
-            nn.Sigmoid())
-
-    def forward(self, x):
-        y = self.attention(x)
-        return x * y
-
-
-class CAB(nn.Module):
-    def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
-        super(CAB, self).__init__()
-        self.cab = nn.Sequential(
-            nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
-            nn.GELU(),
-            nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
-            ChannelAttention(num_feat, squeeze_factor)
-        )
-
-    def forward(self, x):
-        return self.cab(x)
-
-
-class Mlp(nn.Module):
-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
-        super().__init__()
-        out_features = out_features or in_features
-        hidden_features = hidden_features or in_features
-        self.fc1 = nn.Linear(in_features, hidden_features)
-        self.act = act_layer()
-        self.fc2 = nn.Linear(hidden_features, out_features)
-        self.drop = nn.Dropout(drop)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.act(x)
-        x = self.drop(x)
-        x = self.fc2(x)
-        x = self.drop(x)
-        return x
-
-
-def window_partition(x, window_size):
-    B, H, W, C = x.shape
-    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
-    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
-    return windows
-
-
-def window_reverse(windows, window_size, H, W):
-    B = int(windows.shape[0] / (H * W / window_size / window_size))
-    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
-    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
-    return x
-
-
-class WindowAttention(nn.Module):
-    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
-        super().__init__()
-        self.dim = dim
-        self.window_size = window_size
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = qk_scale or head_dim ** -0.5
-
-        self.relative_position_bias_table = nn.Parameter(
-            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
-
-        coords_h = torch.arange(self.window_size[0])
-        coords_w = torch.arange(self.window_size[1])
-        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
-        coords_flatten = torch.flatten(coords, 1)
-        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
-        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
-        relative_coords[:, :, 0] += self.window_size[0] - 1
-        relative_coords[:, :, 1] += self.window_size[1] - 1
-        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
-        relative_position_index = relative_coords.sum(-1)
-        self.register_buffer("relative_position_index", relative_position_index)
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
-        self.softmax = nn.Softmax(dim=-1)
-
-    def forward(self, x, mask=None):
-        B_, N, C = x.shape
-        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]
-
-        q = q * self.scale
-        attn = (q @ k.transpose(-2, -1))
-
-        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
-            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
-        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
-        attn = attn + relative_position_bias.unsqueeze(0)
-
-        if mask is not None:
-            nW = mask.shape[0]
-            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
-            attn = attn.view(-1, self.num_heads, N, N)
-            attn = self.softmax(attn)
-        else:
-            attn = self.softmax(attn)
-
-        attn = self.attn_drop(attn)
-
-        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
-
-
-class HAB(nn.Module):
-    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
-                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
-                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, compress_ratio=3, squeeze_factor=30):
-        super().__init__()
-        self.dim = dim
-        self.input_resolution = input_resolution
-        self.num_heads = num_heads
-        self.window_size = window_size
-        self.shift_size = shift_size
-        self.mlp_ratio = mlp_ratio
-        if min(self.input_resolution) <= self.window_size:
-            self.shift_size = 0
-            self.window_size = min(self.input_resolution)
-        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
-
-        self.norm1 = norm_layer(dim)
-        self.attn = WindowAttention(
-            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
-            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
-
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.norm2 = norm_layer(dim)
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
-        self.conv_scale = nn.Parameter(torch.ones(1))
-        self.conv_block = CAB(dim, compress_ratio, squeeze_factor)
-
-        if self.shift_size > 0:
-            H, W = self.input_resolution
-            img_mask = torch.zeros((1, H, W, 1))
-            h_slices = (slice(0, -self.window_size),
-                        slice(-self.window_size, -self.shift_size),
-                        slice(-self.shift_size, None))
-            w_slices = (slice(0, -self.window_size),
-                        slice(-self.window_size, -self.shift_size),
-                        slice(-self.shift_size, None))
-            cnt = 0
-            for h in h_slices:
-                for w in w_slices:
-                    img_mask[:, h, w, :] = cnt
-                    cnt += 1
-
-            mask_windows = window_partition(img_mask, self.window_size)
-            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
-            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
-            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-        else:
-            attn_mask = None
-
-        self.register_buffer("attn_mask", attn_mask)
-
-    def forward(self, x):
-        H, W = self.input_resolution
-        B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
-
-        shortcut = x
-        x = self.norm1(x)
-        x = x.view(B, H, W, C)
-
-        if self.shift_size > 0:
-            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
-        else:
-            shifted_x = x
-
-        x_windows = window_partition(shifted_x, self.window_size)
-        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
-
-        attn_windows = self.attn(x_windows, mask=self.attn_mask)
-
-        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
-        shifted_x = window_reverse(attn_windows, self.window_size, H, W)
-
-        if self.shift_size > 0:
-            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
-        else:
-            x = shifted_x
-        x = x.view(B, H * W, C)
-
-        x = shortcut + self.drop_path(x)
-
-        y = x
-        x = self.norm2(x)
-        x = self.mlp(x)
-        x = y + self.drop_path(x)
-
-        conv_x = self.conv_block(x.view(B, H, W, C).permute(0, 3, 1, 2))
-        conv_x = conv_x.permute(0, 2, 3, 1).view(B, H * W, C)
-
-        x = x + self.conv_scale * conv_x
-
-        return x
-
-
-class OCAB(nn.Module):
-    def __init__(self, dim, input_resolution, window_size, overlap_ratio, num_heads,
-                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
-                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, compress_ratio=3,
-                 squeeze_factor=30):
-        super().__init__()
-        self.dim = dim
-        self.input_resolution = input_resolution
-        self.window_size = window_size
-        self.num_heads = num_heads
-        self.shift_size = round(overlap_ratio * window_size)
-        self.mlp_ratio = mlp_ratio
-
-        if min(self.input_resolution) <= self.window_size:
-            self.shift_size = 0
-            self.window_size = min(self.input_resolution)
-
-        assert 0 <= self.shift_size, "shift_size >= 0 is required"
-
-        self.norm1 = norm_layer(dim)
-        self.attn = WindowAttention(
-            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
-            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
-
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
-        self.norm2 = norm_layer(dim)
-        mlp_hidden_dim = int(dim * mlp_ratio)
-        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
-        self.conv_scale = nn.Parameter(torch.ones(1))
-        self.conv_block = CAB(dim, compress_ratio, squeeze_factor)
-
-    def forward(self, x):
-        H, W = self.input_resolution
-        B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"
-
-        shortcut = x
-        x = self.norm1(x)
-        x = x.view(B, H, W, C)
-
-        pad_l = pad_t = 0
-        pad_r = (self.window_size - W % self.window_size) % self.window_size
-        pad_b = (self.window_size - H % self.window_size) % self.window_size
-        x = torch.nn.functional.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
-        _, Hp, Wp, _ = x.shape
-
-        if self.shift_size > 0:
-            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
-        else:
-            shifted_x = x
-
-        x_windows = window_partition(shifted_x, self.window_size)
-        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
-
-        attn_windows = self.attn(x_windows, mask=None)
-
-        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
-        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)
-
-        if self.shift_size > 0:
-            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
-        else:
-            x = shifted_x
-
-        if pad_r > 0 or pad_b > 0:
-            x = x[:, :H, :W, :].contiguous()
-
-        x = x.view(B, H * W, C)
-        x = shortcut + self.drop_path(x)
-
-        y = x
-        x = self.norm2(x)
-        x = self.mlp(x)
-        x = y + self.drop_path(x)
-
-        conv_x = self.conv_block(x.view(B, H, W, C).permute(0, 3, 1, 2))
-        conv_x = conv_x.permute(0, 2, 3, 1).view(B, H * W, C)
-
-        x = x + self.conv_scale * conv_x
-
-        return x
-
-
-class PatchEmbed(nn.Module):
-    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
-        super().__init__()
-        img_size = (img_size, img_size)
-        patch_size = (patch_size, patch_size)
-        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
-        self.img_size = img_size
-        self.patch_size = patch_size
-        self.patches_resolution = patches_resolution
-        self.num_patches = patches_resolution[0] * patches_resolution[1]
-
-        self.in_chans = in_chans
-        self.embed_dim = embed_dim
-
-        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
-        if norm_layer is not None:
-            self.norm = norm_layer(embed_dim)
-        else:
-            self.norm = None
-
-    def forward(self, x):
-        B, C, H, W = x.shape
-        assert H == self.img_size[0] and W == self.img_size[1], \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
-        x = self.proj(x).flatten(2).transpose(1, 2)
-        if self.norm is not None:
-            x = self.norm(x)
-        return x
-
-
-class PatchUnEmbed(nn.Module):
-    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
-        super().__init__()
-        img_size = (img_size, img_size)
-        patch_size = (patch_size, patch_size)
-        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
-        self.img_size = img_size
-        self.patch_size = patch_size
-        self.patches_resolution = patches_resolution
-        self.num_patches = patches_resolution[0] * patches_resolution[1]
-
-        self.in_chans = in_chans
-        self.embed_dim = embed_dim
-
-    def forward(self, x, x_size):
-        H, W = x_size
-        B, HW, C = x.shape
-        x = x.transpose(1, 2).view(B, self.embed_dim, H, W)
-        return x
-
-
-class RHAG(nn.Module):
-    def __init__(self, dim, input_resolution, depth, num_heads, window_size, compress_ratio,
-                 squeeze_factor, conv_scale, overlap_ratio, mlp_ratio=4., qkv_bias=True, qk_scale=None,
-                 drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
-                 use_checkpoint=False):
-        super().__init__()
-        self.dim = dim
-        self.input_resolution = input_resolution
-        self.depth = depth
-        self.use_checkpoint = use_checkpoint
-
-        self.blocks_1 = nn.ModuleList([
-            HAB(dim=dim, input_resolution=input_resolution,
-                num_heads=num_heads, window_size=window_size,
-                shift_size=0 if (i % 2 == 0) else window_size // 2,
-                mlp_ratio=mlp_ratio,
-                qkv_bias=qkv_bias, qk_scale=qk_scale,
-                drop=drop, attn_drop=attn_drop,
-                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
-                norm_layer=norm_layer, compress_ratio=compress_ratio,
-                squeeze_factor=squeeze_factor)
-            for i in range(depth // 2)])
-
-        self.blocks_2 = nn.ModuleList([
-            OCAB(dim=dim, input_resolution=input_resolution,
-                 window_size=window_size, overlap_ratio=overlap_ratio,
-                 num_heads=num_heads, mlp_ratio=mlp_ratio,
-                 qkv_bias=qkv_bias, qk_scale=qk_scale,
-                 drop=drop, attn_drop=attn_drop,
-                 drop_path=drop_path[i + depth//2] if isinstance(drop_path, list) else drop_path,
-                 norm_layer=norm_layer, compress_ratio=compress_ratio,
-                 squeeze_factor=squeeze_factor)
-            for i in range(depth // 2)])
-
-        self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
-        self.conv_scale = conv_scale
-
-        if downsample is not None:
-            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
-        else:
-            self.downsample = None
-
-    def forward(self, x, x_size):
-        H, W = x_size
-        res = x
-        for blk in self.blocks_1:
-            if self.use_checkpoint:
-                x = torch.utils.checkpoint.checkpoint(blk, x)
-            else:
-                x = blk(x)
-        for blk in self.blocks_2:
-            if self.use_checkpoint:
-                x = torch.utils.checkpoint.checkpoint(blk, x)
-            else:
-                x = blk(x)
-
-        conv_x = self.conv(x.transpose(1, 2).view(-1, self.dim, H, W)).view(-1, self.dim, H * W).transpose(1, 2)
-        x = res + x + conv_x * self.conv_scale
-
-        if self.downsample is not None:
-            x = self.downsample(x)
-        return x
-
-
-class Upsample(nn.Sequential):
-    def __init__(self, scale, num_feat):
-        m = []
-        if (scale & (scale - 1)) == 0:
-            for _ in range(int(math.log(scale, 2))):
-                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
-                m.append(nn.PixelShuffle(2))
-        elif scale == 3:
-            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
-            m.append(nn.PixelShuffle(3))
-        else:
-            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
-        super(Upsample, self).__init__(*m)
-
-
-class HAT(nn.Module):
-    def __init__(self, img_size=64, patch_size=1, in_chans=3, embed_dim=180, depths=[6, 6, 6, 6, 6, 6],
-                 num_heads=[6, 6, 6, 6, 6, 6], window_size=16, compress_ratio=3, squeeze_factor=30,
-                 conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=4., qkv_bias=True, qk_scale=None,
-                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm,
-                 ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1.,
-                 upsampler='', resi_connection='1conv', **kwargs):
-        super(HAT, self).__init__()
-
-        self.window_size = window_size
-        self.shift_size = window_size // 2
-        self.overlap_ratio = overlap_ratio
-        num_in_ch = in_chans
-        num_out_ch = in_chans
-        num_feat = 64
-        self.img_range = img_range
-        if in_chans == 3:
-            rgb_mean = (0.4488, 0.4371, 0.4040)
-            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
-        else:
-            self.mean = torch.zeros(1, 1, 1, 1)
-        self.upscale = upscale
-        self.upsampler = upsampler
-
-        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
-
-        self.num_layers = len(depths)
-        self.embed_dim = embed_dim
-        self.ape = ape
-        self.patch_norm = patch_norm
-        self.num_features = embed_dim
-        self.mlp_ratio = mlp_ratio
-
-        self.patch_embed = PatchEmbed(
-            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
-            norm_layer=norm_layer if self.patch_norm else None)
-        num_patches = self.patch_embed.num_patches
-        patches_resolution = self.patch_embed.patches_resolution
-        self.patches_resolution = patches_resolution
-
-        self.patch_unembed = PatchUnEmbed(
-            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
-            norm_layer=norm_layer if self.patch_norm else None)
-
-        if self.ape:
-            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
-            nn.init.trunc_normal_(self.absolute_pos_embed, std=.02)
-
-        self.pos_drop = nn.Dropout(p=drop_rate)
-
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
-
-        self.layers = nn.ModuleList()
-        for i_layer in range(self.num_layers):
-            layer = RHAG(dim=embed_dim,
-                         input_resolution=(patches_resolution[0],
-                                           patches_resolution[1]),
-                         depth=depths[i_layer],
-                         num_heads=num_heads[i_layer],
-                         window_size=window_size,
-                         compress_ratio=compress_ratio,
-                         squeeze_factor=squeeze_factor,
-                         conv_scale=conv_scale,
-                         overlap_ratio=overlap_ratio,
-                         mlp_ratio=self.mlp_ratio,
-                         qkv_bias=qkv_bias, qk_scale=qk_scale,
-                         drop=drop_rate, attn_drop=attn_drop_rate,
-                         drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
-                         norm_layer=norm_layer,
-                         downsample=None,
-                         use_checkpoint=use_checkpoint)
-            self.layers.append(layer)
-        self.norm = norm_layer(self.num_features)
-
-        if resi_connection == '1conv':
-            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
-        elif resi_connection == '3conv':
-            self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
-                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
-                                                 nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
-                                                 nn.LeakyReLU(negative_slope=0.2, inplace=True),
-                                                 nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
-
-        if upsampler == 'pixelshuffle':
-            self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
-                                                      nn.LeakyReLU(inplace=True))
-            self.upsample = Upsample(upscale, num_feat)
-            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            nn.init.trunc_normal_(m.weight, std=.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
-
-    @torch.jit.ignore
-    def no_weight_decay(self):
-        return {'absolute_pos_embed'}
-
-    @torch.jit.ignore
-    def no_weight_decay_keywords(self):
-        return {'relative_position_bias_table'}
-
-    def forward_features(self, x):
-        x_size = (x.shape[2], x.shape[3])
-        x = self.patch_embed(x)
-        if self.ape:
-            x = x + self.absolute_pos_embed
-        x = self.pos_drop(x)
-
-        for layer in self.layers:
-            x = layer(x, x_size)
-
-        x = self.norm(x)
-        x = self.patch_unembed(x, x_size)
-
-        return x
-
-    def forward(self, x):
-        self.mean = self.mean.type_as(x)
-        x = (x - self.mean) * self.img_range
-
-        x_first = self.conv_first(x)
-        res = self.conv_after_body(self.forward_features(x_first)) + x_first
-        if self.upsampler == 'pixelshuffle':
-            x = self.conv_before_upsample(res)
-            x = self.conv_last(self.upsample(x))
-
-        x = x / self.img_range + self.mean
-
-        return x
-
-
-# Load the model
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-model = HAT(
-    upscale=4,
-    in_chans=3,
-    img_size=128,
-    window_size=16,
-    compress_ratio=3,
-    squeeze_factor=30,
-    conv_scale=0.01,
-    overlap_ratio=0.5,
-    img_range=1.,
-    depths=[6, 6, 6, 6, 6, 6],
-    embed_dim=180,
-    num_heads=[6, 6, 6, 6, 6, 6],
-    mlp_ratio=2,
-    upsampler='pixelshuffle',
-    resi_connection='1conv'
-)
-
-# Load the fine-tuned weights
-checkpoint = torch.load('net_g_20000.pth', map_location=device)
-if 'params_ema' in checkpoint:
-    model.load_state_dict(checkpoint['params_ema'])
-elif 'params' in checkpoint:
-    model.load_state_dict(checkpoint['params'])
-else:
-    model.load_state_dict(checkpoint)
-
-model.to(device)
-model.eval()
-
-
-def upscale_image(image):
-    # Convert PIL image to tensor
-    img_np = np.array(image).astype(np.float32) / 255.0
-    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)
-
-    # Ensure the image dimensions are multiples of window_size
-    h, w = img_tensor.shape[2], img_tensor.shape[3]
-
-    # Pad if necessary
-    pad_h = (16 - h % 16) % 16
-    pad_w = (16 - w % 16) % 16
-
-    if pad_h > 0 or pad_w > 0:
-        img_tensor = torch.nn.functional.pad(img_tensor, (0, pad_w, 0, pad_h), mode='reflect')
-
-    with torch.no_grad():
-        output = model(img_tensor)
-
-    # Remove padding if it was added
-    if pad_h > 0 or pad_w > 0:
-        output = output[:, :, :h*4, :w*4]
-
-    # Convert back to PIL image
-    output_np = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
-    output_np = np.clip(output_np * 255.0, 0, 255).astype(np.uint8)
-
-    return Image.fromarray(output_np)
-
-
-# Gradio interface
-iface = gr.Interface(
-    fn=upscale_image,
-    inputs=gr.Image(type="pil", label="Input Satellite Image"),
-    outputs=gr.Image(type="pil", label="Super-Resolution Output (4x)"),
-    title="HAT Super-Resolution for Satellite Images",
-    description="Upload a satellite image to enhance its resolution by 4x using a fine-tuned HAT model. This model has been specifically trained on satellite imagery to provide high-quality super-resolution results.",
-    examples=None,
-    cache_examples=False
-)
-
-if __name__ == "__main__":
-    iface.launch()