BorisEm committed on
Commit 0def483 · 1 Parent(s): 6cc2b3b

Broke down code base into smaller files for readability

.gitignore ADDED
@@ -0,0 +1,159 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be added to the global gitignore or merged into this project gitignore. For a PyCharm
158
+ # project, it is possible to include .idea/directory entries, you may need to remove them.
159
+ .idea/
app.py CHANGED
@@ -1,904 +1,21 @@
1
- import gradio as gr
2
- import torch
3
- import torch.nn as nn
4
- import numpy as np
5
- from PIL import Image
6
- import math
7
- from einops import rearrange
8
- import os
9
- import glob
10
- import base64
11
- from io import BytesIO
12
-
13
- # Constants
14
- MODEL_CHECKPOINT = 'net_g_150000.pth'
15
- REQUIRED_IMAGE_SIZE = (130, 130)
16
- WINDOW_SIZE = 16
17
- UPSCALE_FACTOR = 4
18
-
19
-
20
- def to_2tuple(x):
21
- """Convert input to tuple of length 2."""
22
- if isinstance(x, (tuple, list)):
23
- return tuple(x)
24
- return (x, x)
25
-
26
-
27
- def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
28
- """Truncated normal initialization."""
29
- def norm_cdf(x):
30
- return (1. + math.erf(x / math.sqrt(2.))) / 2.
31
-
32
- with torch.no_grad():
33
- l = norm_cdf((a - mean) / std)
34
- u = norm_cdf((b - mean) / std)
35
- tensor.uniform_(2 * l - 1, 2 * u - 1)
36
- tensor.erfinv_()
37
- tensor.mul_(std * math.sqrt(2.))
38
- tensor.add_(mean)
39
- tensor.clamp_(min=a, max=b)
40
- return tensor
41
-
42
-
43
- def drop_path(x, drop_prob: float = 0., training: bool = False):
44
- if drop_prob == 0. or not training:
45
- return x
46
- keep_prob = 1 - drop_prob
47
- shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
48
- random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
49
- random_tensor.floor_()
50
- output = x.div(keep_prob) * random_tensor
51
- return output
52
-
53
-
54
- class DropPath(nn.Module):
55
- def __init__(self, drop_prob=None):
56
- super(DropPath, self).__init__()
57
- self.drop_prob = drop_prob
58
-
59
- def forward(self, x):
60
- return drop_path(x, self.drop_prob, self.training)
61
-
62
-
63
- class ChannelAttention(nn.Module):
64
- def __init__(self, num_feat, squeeze_factor=16):
65
- super(ChannelAttention, self).__init__()
66
- self.attention = nn.Sequential(
67
- nn.AdaptiveAvgPool2d(1),
68
- nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
69
- nn.ReLU(inplace=True),
70
- nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
71
- nn.Sigmoid())
72
-
73
- def forward(self, x):
74
- y = self.attention(x)
75
- return x * y
76
-
77
-
78
- class CAB(nn.Module):
79
- def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
80
- super(CAB, self).__init__()
81
- self.cab = nn.Sequential(
82
- nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
83
- nn.GELU(),
84
- nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
85
- ChannelAttention(num_feat, squeeze_factor)
86
- )
87
-
88
- def forward(self, x):
89
- return self.cab(x)
90
-
91
-
92
- class Mlp(nn.Module):
93
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
94
- super().__init__()
95
- out_features = out_features or in_features
96
- hidden_features = hidden_features or in_features
97
- self.fc1 = nn.Linear(in_features, hidden_features)
98
- self.act = act_layer()
99
- self.fc2 = nn.Linear(hidden_features, out_features)
100
- self.drop = nn.Dropout(drop)
101
-
102
- def forward(self, x):
103
- x = self.fc1(x)
104
- x = self.act(x)
105
- x = self.drop(x)
106
- x = self.fc2(x)
107
- x = self.drop(x)
108
- return x
109
-
110
-
111
- def window_partition(x, window_size):
112
- b, h, w, c = x.shape
113
- x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
114
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
115
- return windows
116
-
117
-
118
- def window_reverse(windows, window_size, h, w):
119
- b = int(windows.shape[0] / (h * w / window_size / window_size))
120
- x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
121
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
122
- return x
123
-
124
-
125
- class WindowAttention(nn.Module):
126
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
127
- super().__init__()
128
- self.dim = dim
129
- self.window_size = window_size
130
- self.num_heads = num_heads
131
- head_dim = dim // num_heads
132
- self.scale = qk_scale or head_dim**-0.5
133
-
134
- self.relative_position_bias_table = nn.Parameter(
135
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
136
-
137
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
138
- self.attn_drop = nn.Dropout(attn_drop)
139
- self.proj = nn.Linear(dim, dim)
140
- self.proj_drop = nn.Dropout(proj_drop)
141
-
142
- trunc_normal_(self.relative_position_bias_table, std=.02)
143
- self.softmax = nn.Softmax(dim=-1)
144
-
145
- def forward(self, x, rpi, mask=None):
146
- b_, n, c = x.shape
147
- qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
148
- q, k, v = qkv[0], qkv[1], qkv[2]
149
-
150
- q = q * self.scale
151
- attn = (q @ k.transpose(-2, -1))
152
-
153
- relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
154
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
155
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
156
- attn = attn + relative_position_bias.unsqueeze(0)
157
-
158
- if mask is not None:
159
- nw = mask.shape[0]
160
- attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
161
- attn = attn.view(-1, self.num_heads, n, n)
162
- attn = self.softmax(attn)
163
- else:
164
- attn = self.softmax(attn)
165
-
166
- attn = self.attn_drop(attn)
167
-
168
- x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
169
- x = self.proj(x)
170
- x = self.proj_drop(x)
171
- return x
172
-
173
-
174
- class HAB(nn.Module):
175
- def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
176
- compress_ratio=3, squeeze_factor=30, conv_scale=0.01, mlp_ratio=4.,
177
- qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
178
- act_layer=nn.GELU, norm_layer=nn.LayerNorm):
179
- super().__init__()
180
- self.dim = dim
181
- self.input_resolution = input_resolution
182
- self.num_heads = num_heads
183
- self.window_size = window_size
184
- self.shift_size = shift_size
185
- self.mlp_ratio = mlp_ratio
186
- if min(self.input_resolution) <= self.window_size:
187
- self.shift_size = 0
188
- self.window_size = min(self.input_resolution)
189
- assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'
190
-
191
- self.norm1 = norm_layer(dim)
192
- self.attn = WindowAttention(
193
- dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
194
- qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
195
-
196
- self.conv_scale = conv_scale
197
- self.conv_block = CAB(num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor)
198
-
199
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
200
- self.norm2 = norm_layer(dim)
201
- mlp_hidden_dim = int(dim * mlp_ratio)
202
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
203
-
204
- def forward(self, x, x_size, rpi_sa, attn_mask):
205
- h, w = x_size
206
- b, _, c = x.shape
207
-
208
- shortcut = x
209
- x = self.norm1(x)
210
- x = x.view(b, h, w, c)
211
-
212
- # Conv_X
213
- conv_x = self.conv_block(x.permute(0, 3, 1, 2))
214
- conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c)
215
-
216
- # cyclic shift
217
- if self.shift_size > 0:
218
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
219
- attn_mask = attn_mask
220
- else:
221
- shifted_x = x
222
- attn_mask = None
223
-
224
- # partition windows
225
- x_windows = window_partition(shifted_x, self.window_size)
226
- x_windows = x_windows.view(-1, self.window_size * self.window_size, c)
227
-
228
- # W-MSA/SW-MSA
229
- attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask)
230
-
231
- # merge windows
232
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
233
- shifted_x = window_reverse(attn_windows, self.window_size, h, w)
234
-
235
- # reverse cyclic shift
236
- if self.shift_size > 0:
237
- attn_x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
238
- else:
239
- attn_x = shifted_x
240
- attn_x = attn_x.view(b, h * w, c)
241
-
242
- # FFN
243
- x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale
244
- x = x + self.drop_path(self.mlp(self.norm2(x)))
245
-
246
- return x
247
-
248
-
249
- class OCAB(nn.Module):
250
- def __init__(self, dim, input_resolution, window_size, overlap_ratio, num_heads,
251
- qkv_bias=True, qk_scale=None, mlp_ratio=2, norm_layer=nn.LayerNorm):
252
- super().__init__()
253
- self.dim = dim
254
- self.input_resolution = input_resolution
255
- self.window_size = window_size
256
- self.num_heads = num_heads
257
- head_dim = dim // num_heads
258
- self.scale = qk_scale or head_dim**-0.5
259
- self.overlap_win_size = int(window_size * overlap_ratio) + window_size
260
-
261
- self.norm1 = norm_layer(dim)
262
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
263
- self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size),
264
- stride=window_size, padding=(self.overlap_win_size-window_size)//2)
265
-
266
- self.relative_position_bias_table = nn.Parameter(
267
- torch.zeros((window_size + self.overlap_win_size - 1) * (window_size + self.overlap_win_size - 1), num_heads))
268
-
269
- trunc_normal_(self.relative_position_bias_table, std=.02)
270
- self.softmax = nn.Softmax(dim=-1)
271
-
272
- self.proj = nn.Linear(dim,dim)
273
-
274
- self.norm2 = norm_layer(dim)
275
- mlp_hidden_dim = int(dim * mlp_ratio)
276
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU)
277
-
278
- def forward(self, x, x_size, rpi):
279
- h, w = x_size
280
- b, _, c = x.shape
281
-
282
- shortcut = x
283
- x = self.norm1(x)
284
- x = x.view(b, h, w, c)
285
-
286
- qkv = self.qkv(x).reshape(b, h, w, 3, c).permute(3, 0, 4, 1, 2)
287
- q = qkv[0].permute(0, 2, 3, 1)
288
- kv = torch.cat((qkv[1], qkv[2]), dim=1)
289
-
290
- # partition windows
291
- q_windows = window_partition(q, self.window_size)
292
- q_windows = q_windows.view(-1, self.window_size * self.window_size, c)
293
-
294
- kv_windows = self.unfold(kv)
295
- kv_windows = rearrange(kv_windows, 'b (nc ch owh oww) nw -> nc (b nw) (owh oww) ch',
296
- nc=2, ch=c, owh=self.overlap_win_size, oww=self.overlap_win_size).contiguous()
297
- k_windows, v_windows = kv_windows[0], kv_windows[1]
298
-
299
- b_, nq, _ = q_windows.shape
300
- _, n, _ = k_windows.shape
301
- d = self.dim // self.num_heads
302
- q = q_windows.reshape(b_, nq, self.num_heads, d).permute(0, 2, 1, 3)
303
- k = k_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3)
304
- v = v_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3)
305
-
306
- q = q * self.scale
307
- attn = (q @ k.transpose(-2, -1))
308
-
309
- relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
310
- self.window_size * self.window_size, self.overlap_win_size * self.overlap_win_size, -1)
311
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
312
- attn = attn + relative_position_bias.unsqueeze(0)
313
-
314
- attn = self.softmax(attn)
315
- attn_windows = (attn @ v).transpose(1, 2).reshape(b_, nq, self.dim)
316
-
317
- # merge windows
318
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.dim)
319
- x = window_reverse(attn_windows, self.window_size, h, w)
320
- x = x.view(b, h * w, self.dim)
321
-
322
- x = self.proj(x) + shortcut
323
- x = x + self.mlp(self.norm2(x))
324
- return x
325
-
326
-
327
- class AttenBlocks(nn.Module):
328
- def __init__(self, dim, input_resolution, depth, num_heads, window_size, compress_ratio,
329
- squeeze_factor, conv_scale, overlap_ratio, mlp_ratio=4., qkv_bias=True, qk_scale=None,
330
- drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
331
- use_checkpoint=False):
332
- super().__init__()
333
- self.dim = dim
334
- self.input_resolution = input_resolution
335
- self.depth = depth
336
- self.use_checkpoint = use_checkpoint
337
-
338
- # build blocks
339
- self.blocks = nn.ModuleList([
340
- HAB(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size,
341
- shift_size=0 if (i % 2 == 0) else window_size // 2, compress_ratio=compress_ratio,
342
- squeeze_factor=squeeze_factor, conv_scale=conv_scale, mlp_ratio=mlp_ratio,
343
- qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
344
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
345
- norm_layer=norm_layer) for i in range(depth)
346
- ])
347
-
348
- # OCAB
349
- self.overlap_attn = OCAB(dim=dim, input_resolution=input_resolution, window_size=window_size,
350
- overlap_ratio=overlap_ratio, num_heads=num_heads, qkv_bias=qkv_bias,
351
- qk_scale=qk_scale, mlp_ratio=mlp_ratio, norm_layer=norm_layer)
352
-
353
- # patch merging layer
354
- if downsample is not None:
355
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
356
- else:
357
- self.downsample = None
358
-
359
- def forward(self, x, x_size, params):
360
- for blk in self.blocks:
361
- x = blk(x, x_size, params['rpi_sa'], params['attn_mask'])
362
-
363
- x = self.overlap_attn(x, x_size, params['rpi_oca'])
364
-
365
- if self.downsample is not None:
366
- x = self.downsample(x)
367
- return x
368
-
369
-
370
- class RHAG(nn.Module):
371
- def __init__(self, dim, input_resolution, depth, num_heads, window_size, compress_ratio,
372
- squeeze_factor, conv_scale, overlap_ratio, mlp_ratio=4., qkv_bias=True, qk_scale=None,
373
- drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
374
- use_checkpoint=False, img_size=224, patch_size=4, resi_connection='1conv'):
375
- super(RHAG, self).__init__()
376
-
377
- self.dim = dim
378
- self.input_resolution = input_resolution
379
-
380
- self.residual_group = AttenBlocks(
381
- dim=dim, input_resolution=input_resolution, depth=depth, num_heads=num_heads,
382
- window_size=window_size, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor,
383
- conv_scale=conv_scale, overlap_ratio=overlap_ratio, mlp_ratio=mlp_ratio,
384
- qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
385
- drop_path=drop_path, norm_layer=norm_layer, downsample=downsample,
386
- use_checkpoint=use_checkpoint)
387
-
388
- if resi_connection == '1conv':
389
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
390
- elif resi_connection == 'identity':
391
- self.conv = nn.Identity()
392
-
393
- self.patch_embed = PatchEmbed(
394
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
395
-
396
- self.patch_unembed = PatchUnEmbed(
397
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
398
-
399
- def forward(self, x, x_size, params):
400
- return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size, params), x_size))) + x
401
-
402
-
403
- class PatchEmbed(nn.Module):
404
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
405
- super().__init__()
406
- img_size = to_2tuple(img_size)
407
- patch_size = to_2tuple(patch_size)
408
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
409
- self.img_size = img_size
410
- self.patch_size = patch_size
411
- self.patches_resolution = patches_resolution
412
- self.num_patches = patches_resolution[0] * patches_resolution[1]
413
-
414
- self.in_chans = in_chans
415
- self.embed_dim = embed_dim
416
-
417
- if norm_layer is not None:
418
- self.norm = norm_layer(embed_dim)
419
- else:
420
- self.norm = None
421
-
422
- def forward(self, x):
423
- x = x.flatten(2).transpose(1, 2)
424
- if self.norm is not None:
425
- x = self.norm(x)
426
- return x
427
-
428
-
429
- class PatchUnEmbed(nn.Module):
430
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
431
- super().__init__()
432
- img_size = to_2tuple(img_size)
433
- patch_size = to_2tuple(patch_size)
434
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
435
- self.img_size = img_size
436
- self.patch_size = patch_size
437
- self.patches_resolution = patches_resolution
438
- self.num_patches = patches_resolution[0] * patches_resolution[1]
439
-
440
- self.in_chans = in_chans
441
- self.embed_dim = embed_dim
442
-
443
- def forward(self, x, x_size):
444
- x = x.transpose(1, 2).contiguous().view(x.shape[0], self.embed_dim, x_size[0], x_size[1])
445
- return x
446
-
447
-
448
- class Upsample(nn.Sequential):
449
- def __init__(self, scale, num_feat):
450
- m = []
451
- if (scale & (scale - 1)) == 0:
452
- for _ in range(int(math.log(scale, 2))):
453
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
454
- m.append(nn.PixelShuffle(2))
455
- elif scale == 3:
456
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
457
- m.append(nn.PixelShuffle(3))
458
- else:
459
- raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
460
- super(Upsample, self).__init__(*m)
461
-
462
-
463
- class HAT(nn.Module):
464
- def __init__(self, img_size=64, patch_size=1, in_chans=3, embed_dim=96, depths=(6, 6, 6, 6),
465
- num_heads=(6, 6, 6, 6), window_size=7, compress_ratio=3, squeeze_factor=30,
466
- conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=4., qkv_bias=True, qk_scale=None,
467
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm,
468
- ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1.,
469
- upsampler='', resi_connection='1conv', **kwargs):
470
- super(HAT, self).__init__()
471
-
472
- self.window_size = window_size
473
- self.shift_size = window_size // 2
474
- self.overlap_ratio = overlap_ratio
475
-
476
- num_in_ch = in_chans
477
- num_out_ch = in_chans
478
- num_feat = 64
479
- self.img_range = img_range
480
- if in_chans == 3:
481
- rgb_mean = (0.4488, 0.4371, 0.4040)
482
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
483
- else:
484
- self.mean = torch.zeros(1, 1, 1, 1)
485
- self.upscale = upscale
486
- self.upsampler = upsampler
487
-
488
- # relative position index
489
- relative_position_index_SA = self.calculate_rpi_sa()
490
- relative_position_index_OCA = self.calculate_rpi_oca()
491
- self.register_buffer('relative_position_index_SA', relative_position_index_SA)
492
- self.register_buffer('relative_position_index_OCA', relative_position_index_OCA)
493
-
494
- # shallow feature extraction
495
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
496
-
497
- # deep feature extraction
498
- self.num_layers = len(depths)
499
- self.embed_dim = embed_dim
500
- self.ape = ape
501
- self.patch_norm = patch_norm
502
- self.num_features = embed_dim
503
- self.mlp_ratio = mlp_ratio
504
-
505
- # split image into non-overlapping patches
506
- self.patch_embed = PatchEmbed(
507
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
508
- norm_layer=norm_layer if self.patch_norm else None)
509
- num_patches = self.patch_embed.num_patches
510
- patches_resolution = self.patch_embed.patches_resolution
511
- self.patches_resolution = patches_resolution
512
-
513
- # merge non-overlapping patches into image
514
- self.patch_unembed = PatchUnEmbed(
515
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
516
- norm_layer=norm_layer if self.patch_norm else None)
517
-
518
- # absolute position embedding
519
- if self.ape:
520
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
521
- trunc_normal_(self.absolute_pos_embed, std=.02)
522
-
523
- self.pos_drop = nn.Dropout(p=drop_rate)
524
-
525
- # stochastic depth
526
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
527
-
528
- # build Residual Hybrid Attention Groups (RHAG)
529
- self.layers = nn.ModuleList()
530
- for i_layer in range(self.num_layers):
531
- layer = RHAG(
532
- dim=embed_dim,
533
- input_resolution=(patches_resolution[0], patches_resolution[1]),
534
- depth=depths[i_layer],
535
- num_heads=num_heads[i_layer],
536
- window_size=window_size,
537
- compress_ratio=compress_ratio,
538
- squeeze_factor=squeeze_factor,
539
- conv_scale=conv_scale,
540
- overlap_ratio=overlap_ratio,
541
- mlp_ratio=self.mlp_ratio,
542
- qkv_bias=qkv_bias,
543
- qk_scale=qk_scale,
544
- drop=drop_rate,
545
- attn_drop=attn_drop_rate,
546
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
547
- norm_layer=norm_layer,
548
- downsample=None,
549
- use_checkpoint=use_checkpoint,
550
- img_size=img_size,
551
- patch_size=patch_size,
552
- resi_connection=resi_connection)
553
- self.layers.append(layer)
554
- self.norm = norm_layer(self.num_features)
555
-
556
- # build the last conv layer in deep feature extraction
557
- if resi_connection == '1conv':
558
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
559
- elif resi_connection == 'identity':
560
- self.conv_after_body = nn.Identity()
561
-
562
- # high quality image reconstruction
563
- if self.upsampler == 'pixelshuffle':
564
- self.conv_before_upsample = nn.Sequential(
565
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
566
- self.upsample = Upsample(upscale, num_feat)
567
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
568
-
569
- self.apply(self._init_weights)
570
-
571
- def _init_weights(self, m):
572
- if isinstance(m, nn.Linear):
573
- trunc_normal_(m.weight, std=.02)
574
- if isinstance(m, nn.Linear) and m.bias is not None:
575
- nn.init.constant_(m.bias, 0)
576
- elif isinstance(m, nn.LayerNorm):
577
- nn.init.constant_(m.bias, 0)
578
- nn.init.constant_(m.weight, 1.0)
579
-
580
- def calculate_rpi_sa(self):
581
- coords_h = torch.arange(self.window_size)
582
- coords_w = torch.arange(self.window_size)
583
- coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
584
- coords_flatten = torch.flatten(coords, 1)
585
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
586
- relative_coords = relative_coords.permute(1, 2, 0).contiguous()
587
- relative_coords[:, :, 0] += self.window_size - 1
588
- relative_coords[:, :, 1] += self.window_size - 1
589
- relative_coords[:, :, 0] *= 2 * self.window_size - 1
590
- relative_position_index = relative_coords.sum(-1)
591
- return relative_position_index
592
-
593
- def calculate_rpi_oca(self):
594
- window_size_ori = self.window_size
595
- window_size_ext = self.window_size + int(self.overlap_ratio * self.window_size)
596
-
597
- coords_h = torch.arange(window_size_ori)
598
- coords_w = torch.arange(window_size_ori)
599
- coords_ori = torch.stack(torch.meshgrid([coords_h, coords_w]))
600
- coords_ori_flatten = torch.flatten(coords_ori, 1)
601
-
602
- coords_h = torch.arange(window_size_ext)
603
- coords_w = torch.arange(window_size_ext)
604
- coords_ext = torch.stack(torch.meshgrid([coords_h, coords_w]))
605
- coords_ext_flatten = torch.flatten(coords_ext, 1)
606
-
607
- relative_coords = coords_ext_flatten[:, None, :] - coords_ori_flatten[:, :, None]
608
- relative_coords = relative_coords.permute(1, 2, 0).contiguous()
609
- relative_coords[:, :, 0] += window_size_ori - window_size_ext + 1
610
- relative_coords[:, :, 1] += window_size_ori - window_size_ext + 1
611
- relative_coords[:, :, 0] *= window_size_ori + window_size_ext - 1
612
- relative_position_index = relative_coords.sum(-1)
613
- return relative_position_index
614
-
615
- def calculate_mask(self, x_size):
616
- h, w = x_size
617
- img_mask = torch.zeros((1, h, w, 1))
618
- h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
619
- w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
620
- cnt = 0
621
- for h in h_slices:
622
- for w in w_slices:
623
- img_mask[:, h, w, :] = cnt
624
- cnt += 1
625
-
626
- mask_windows = window_partition(img_mask, self.window_size)
627
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
628
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
629
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
630
- return attn_mask
631
-
632
- @torch.jit.ignore
633
- def no_weight_decay(self):
634
- return {'absolute_pos_embed'}
635
-
636
- @torch.jit.ignore
637
- def no_weight_decay_keywords(self):
638
- return {'relative_position_bias_table'}
639
-
640
- def forward_features(self, x):
641
- x_size = (x.shape[2], x.shape[3])
642
-
643
- attn_mask = self.calculate_mask(x_size).to(x.device)
644
- params = {'attn_mask': attn_mask, 'rpi_sa': self.relative_position_index_SA, 'rpi_oca': self.relative_position_index_OCA}
645
-
646
- x = self.patch_embed(x)
647
- if self.ape:
648
- x = x + self.absolute_pos_embed
649
- x = self.pos_drop(x)
650
-
651
- for layer in self.layers:
652
- x = layer(x, x_size, params)
653
-
654
- x = self.norm(x)
655
- x = self.patch_unembed(x, x_size)
656
- return x
657
-
658
- def forward(self, x):
659
- self.mean = self.mean.type_as(x)
660
- x = (x - self.mean) * self.img_range
661
-
662
- if self.upsampler == 'pixelshuffle':
663
- x = self.conv_first(x)
664
- x = self.conv_after_body(self.forward_features(x)) + x
665
- x = self.conv_before_upsample(x)
666
- x = self.conv_last(self.upsample(x))
667
-
668
- x = x / self.img_range + self.mean
669
- return x
670
-
671
-
672
- # Load the model
673
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
674
-
675
- model = HAT(
676
- upscale=4,
677
- in_chans=3,
678
- img_size=128,
679
- window_size=16,
680
- compress_ratio=3,
681
- squeeze_factor=30,
682
- conv_scale=0.01,
683
- overlap_ratio=0.5,
684
- img_range=1.,
685
- depths=[6, 6, 6, 6, 6, 6],
686
- embed_dim=180,
687
- num_heads=[6, 6, 6, 6, 6, 6],
688
- mlp_ratio=2,
689
- upsampler='pixelshuffle',
690
- resi_connection='1conv'
691
- )
692
-
693
- # Load the fine-tuned weights
694
- checkpoint = torch.load(MODEL_CHECKPOINT, map_location=device)
695
- # Try different checkpoint formats
696
- state_dict = checkpoint.get('params_ema') or checkpoint.get('params') or checkpoint
697
- model.load_state_dict(state_dict)
698
-
699
- model.to(device)
700
- model.eval()
701
-
702
-
703
- def upscale_image(image):
704
- # Convert PIL image to tensor
705
- img_np = np.array(image).astype(np.float32) / 255.0
706
- img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)
707
-
708
- # Ensure the image dimensions are multiples of window_size
709
- h, w = img_tensor.shape[2], img_tensor.shape[3]
710
-
711
- # Pad if necessary
712
- pad_h = (WINDOW_SIZE - h % WINDOW_SIZE) % WINDOW_SIZE
713
- pad_w = (WINDOW_SIZE - w % WINDOW_SIZE) % WINDOW_SIZE
714
-
715
- if pad_h > 0 or pad_w > 0:
716
- img_tensor = torch.nn.functional.pad(img_tensor, (0, pad_w, 0, pad_h), mode='reflect')
717
-
718
- with torch.no_grad():
719
- output = model(img_tensor)
720
-
721
- # Remove padding if it was added
722
- if pad_h > 0 or pad_w > 0:
723
- output = output[:, :, :h*UPSCALE_FACTOR, :w*UPSCALE_FACTOR]
724
-
725
- # Convert back to PIL image
726
- output_np = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
727
- output_np = np.clip(output_np * 255.0, 0, 255).astype(np.uint8)
728
-
729
- return Image.fromarray(output_np)
730
-
731
-
732
- # Get sample images
733
- def get_sample_images():
734
- sample_dir = "sample_images"
735
- if os.path.exists(sample_dir):
736
- image_files = glob.glob(os.path.join(sample_dir, "*.png")) + glob.glob(os.path.join(sample_dir, "*.jpg"))
737
- return sorted(image_files)
738
- return []
739
-
740
- # Gradio interface using Blocks for better layout control
741
- def validate_image_size(image):
742
- """Validate that the image is exactly the required size"""
743
- if image is None:
744
- return False, "No image provided"
745
-
746
- width, height = image.size
747
- req_width, req_height = REQUIRED_IMAGE_SIZE
748
- if width != req_width or height != req_height:
749
- return False, f"Image must be exactly {req_width}x{req_height} pixels. Your image is {width}x{height} pixels."
750
-
751
- return True, "Valid image size"
752
-
753
- def upscale_and_display(image):
754
- if image is None:
755
- return None, "Please upload an image or select a sample image."
756
-
757
- # Validate image size
758
- is_valid, message = validate_image_size(image)
759
- if not is_valid:
760
- return None, f"❌ Error: {message}"
761
-
762
- try:
763
- # Get the super-resolution output
764
- upscaled = upscale_image(image)
765
- return upscaled, "✅ Image successfully enhanced!"
766
- except Exception as e:
767
- return None, f"❌ Error processing image: {str(e)}"
768
-
769
- def select_sample_image(image_path):
770
- if image_path:
771
- return Image.open(image_path)
772
- return None
773
-
774
- def image_to_base64(image_path):
775
- """Convert image to base64 data URL for CSS background"""
776
- img = Image.open(image_path)
777
- img.thumbnail((120, 120), Image.Resampling.LANCZOS)
778
- buffer = BytesIO()
779
- img.save(buffer, format='PNG')
780
- img_str = base64.b64encode(buffer.getvalue()).decode()
781
- return f"data:image/png;base64,{img_str}"
782
-
783
- # Generate CSS with base64 images
784
- def generate_css():
785
- base_css = """
786
- /* Target only the image display area, not the whole component */
787
- .image-container [data-testid="image"] {
788
- height: 500px !important;
789
- min-height: 500px !important;
790
- }
791
-
792
- /* Make images fill their containers */
793
- .image-container img {
794
- width: 500px !important;
795
- height: 500px !important;
796
- object-fit: contain !important;
797
- object-position: center !important;
798
- }
799
-
800
- /* Sample image buttons with background images */
801
- .sample-image-btn {
802
- height: 120px !important;
803
- width: 120px !important;
804
- background-size: cover !important;
805
- background-position: center !important;
806
- border: 2px solid #ddd !important;
807
- border-radius: 8px !important;
808
- cursor: pointer !important;
809
- transition: border-color 0.2s !important;
810
- margin: 5px !important;
811
- }
812
-
813
- .sample-image-btn:hover {
814
- border-color: #007acc !important;
815
- }
816
  """
817
 
818
- # Add background images for each sample (only if samples exist)
819
- sample_images = get_sample_images()
820
- if sample_images:
821
- for i, img_path in enumerate(sample_images):
822
- try:
823
- base64_img = image_to_base64(img_path)
824
- base_css += f"#sample_btn_{i} {{ background-image: url('{base64_img}'); }}\n"
825
- except Exception:
826
- # Skip invalid images
827
- continue
828
-
829
- return base_css
830
-
831
- css = generate_css()
832
-
833
- with gr.Blocks(css=css, title="HATSAT - Super-Resolution for Satellite Images") as iface:
834
- gr.Markdown("# HATSAT - Super-Resolution for Satellite Images")
835
- gr.Markdown(f"Upload a satellite image or select a sample to enhance its resolution by {UPSCALE_FACTOR}x.")
836
- gr.Markdown(f"⚠️ **Important**: Images must be exactly **{REQUIRED_IMAGE_SIZE[0]}x{REQUIRED_IMAGE_SIZE[1]} pixels** for the model to work properly.")
837
-
838
- # Acknowledgments section
839
- with gr.Accordion("Acknowledgments", open=False):
840
- gr.Markdown("""
841
- ### Base Model: HAT (Hybrid Attention Transformer)
842
- This model is a fine-tuned version of **HAT**:
843
- - **GitHub Repository**: [https://github.com/XPixelGroup/HAT](https://github.com/XPixelGroup/HAT)
844
- - **Paper**: [Activating More Pixels in Image Super-Resolution Transformer](https://arxiv.org/abs/2205.04437)
845
- - **Authors**: Xiangyu Chen, Xintao Wang, Jiantao Zhou, Yu Qiao, Chao Dong
846
-
847
- ### Training Dataset: SEN2NAIPv2
848
- The model was fine-tuned using the **SEN2NAIPv2** dataset:
849
- - **HuggingFace Dataset**: [https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2](https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2)
850
- - **Description**: High-resolution satellite imagery dataset for super-resolution tasks
851
- """)
852
-
853
- # Sample images
854
- sample_images = get_sample_images()
855
- sample_buttons = []
856
- if sample_images:
857
- gr.Markdown("**Sample Images (click to select):**")
858
- with gr.Row():
859
- for i, img_path in enumerate(sample_images):
860
- btn = gr.Button(
861
- "",
862
- elem_id=f"sample_btn_{i}",
863
- elem_classes="sample-image-btn"
864
- )
865
- sample_buttons.append((btn, img_path))
866
-
867
- with gr.Row():
868
- input_image = gr.Image(
869
- type="pil",
870
- label=f"Input Image (must be {REQUIRED_IMAGE_SIZE[0]}x{REQUIRED_IMAGE_SIZE[1]} pixels)",
871
- elem_classes="image-container",
872
- sources=["upload"],
873
- height=500,
874
- width=500
875
- )
876
-
877
- output_image = gr.Image(
878
- type="pil",
879
- label=f"Enhanced Output ({UPSCALE_FACTOR}x)",
880
- elem_classes="image-container",
881
- interactive=False,
882
- height=500,
883
- width=500,
884
- show_download_button=True
885
- )
886
 
887
- submit_btn = gr.Button("Enhance Image", variant="primary")
888
 
889
- # Status message
890
- status_message = gr.Textbox(
891
- label="Status",
892
- interactive=False,
893
- show_label=True
894
- )
895
 
896
- # Event handlers
897
- if sample_images:
898
- for btn, img_path in sample_buttons:
899
- btn.click(fn=lambda path=img_path: select_sample_image(path), outputs=input_image)
900
 
901
- submit_btn.click(fn=upscale_and_display, inputs=input_image, outputs=[output_image, status_message])
902
 
903
  if __name__ == "__main__":
904
- iface.launch()
 
+ """
+ HATSAT - Super-Resolution for Satellite Images
+ Main application entry point.
  """

+ from utils.model_utils import load_model
+ from interface.gradio_app import create_interface


+ def main():
+     """Initialize and launch the HATSAT application."""
+     # Load model and get device
+     model, device = load_model()

+     # Create and launch Gradio interface
+     iface = create_interface(model, device)
+     iface.launch()


  if __name__ == "__main__":
+     main()
config.py ADDED
@@ -0,0 +1,28 @@
+ """
+ Configuration constants for HATSAT application.
+ """
+
+ # Model configuration
+ MODEL_CHECKPOINT = 'net_g_150000.pth'
+ REQUIRED_IMAGE_SIZE = (130, 130)
+ WINDOW_SIZE = 16
+ UPSCALE_FACTOR = 4
+
+ # Model architecture parameters
+ MODEL_CONFIG = {
+     'upscale': 4,
+     'in_chans': 3,
+     'img_size': 128,
+     'window_size': 16,
+     'compress_ratio': 3,
+     'squeeze_factor': 30,
+     'conv_scale': 0.01,
+     'overlap_ratio': 0.5,
+     'img_range': 1.,
+     'depths': [6, 6, 6, 6, 6, 6],
+     'embed_dim': 180,
+     'num_heads': [6, 6, 6, 6, 6, 6],
+     'mlp_ratio': 2,
+     'upsampler': 'pixelshuffle',
+     'resi_connection': '1conv'
+ }
interface/__init__.py ADDED
@@ -0,0 +1,8 @@
+ """
+ Gradio interface components for HATSAT application.
+ """
+
+ from .gradio_app import create_interface
+ from .css_styles import generate_css, get_sample_images
+
+ __all__ = ['create_interface', 'generate_css', 'get_sample_images']
interface/css_styles.py ADDED
@@ -0,0 +1,65 @@
1
+ """
2
+ CSS styling and sample image utilities for Gradio interface.
3
+ """
4
+
5
+ import os
6
+ import glob
7
+ from utils.image_utils import image_to_base64
8
+
9
+
10
+ def get_sample_images():
11
+ """Get list of sample images."""
12
+ sample_dir = "sample_images"
13
+ if os.path.exists(sample_dir):
14
+ image_files = glob.glob(os.path.join(sample_dir, "*.png")) + glob.glob(os.path.join(sample_dir, "*.jpg"))
15
+ return sorted(image_files)
16
+ return []
17
+
18
+
19
+ def generate_css():
20
+ """Generate CSS with base64 images for sample buttons."""
21
+ base_css = """
22
+ /* Target only the image display area, not the whole component */
23
+ .image-container [data-testid="image"] {
24
+ height: 500px !important;
25
+ min-height: 500px !important;
26
+ }
27
+
28
+ /* Make images fill their containers */
29
+ .image-container img {
30
+ width: 500px !important;
31
+ height: 500px !important;
32
+ object-fit: contain !important;
33
+ object-position: center !important;
34
+ }
35
+
36
+ /* Sample image buttons with background images */
37
+ .sample-image-btn {
38
+ height: 120px !important;
39
+ width: 120px !important;
40
+ background-size: cover !important;
41
+ background-position: center !important;
42
+ border: 2px solid #ddd !important;
43
+ border-radius: 8px !important;
44
+ cursor: pointer !important;
45
+ transition: border-color 0.2s !important;
46
+ margin: 5px !important;
47
+ }
48
+
49
+ .sample-image-btn:hover {
50
+ border-color: #007acc !important;
51
+ }
52
+ """
53
+
54
+ # Add background images for each sample (only if samples exist)
55
+ sample_images = get_sample_images()
56
+ if sample_images:
57
+ for i, img_path in enumerate(sample_images):
58
+ try:
59
+ base64_img = image_to_base64(img_path)
60
+ base_css += f"#sample_btn_{i} {{ background-image: url('{base64_img}'); }}\n"
61
+ except Exception:
62
+ # Skip invalid images
63
+ continue
64
+
65
+ return base_css
interface/gradio_app.py ADDED
@@ -0,0 +1,116 @@
1
+ """
2
+ Gradio interface for HATSAT application.
3
+ """
4
+
5
+ import gradio as gr
6
+ from PIL import Image
7
+
8
+ from config import REQUIRED_IMAGE_SIZE, UPSCALE_FACTOR
9
+ from utils.image_utils import validate_image_size, upscale_image
10
+ from interface.css_styles import generate_css, get_sample_images
11
+
12
+
13
+ def upscale_and_display(image, model, device):
14
+ """Process image upload and return upscaled result."""
15
+ if image is None:
16
+ return None, "Please upload an image or select a sample image."
17
+
18
+ # Validate image size
19
+ is_valid, message = validate_image_size(image)
20
+ if not is_valid:
21
+ return None, f"❌ Error: {message}"
22
+
23
+ try:
24
+ # Get the super-resolution output
25
+ upscaled = upscale_image(image, model, device)
26
+ return upscaled, "✅ Image successfully enhanced!"
27
+ except Exception as e:
28
+ return None, f"❌ Error processing image: {str(e)}"
29
+
30
+
31
+ def select_sample_image(image_path):
32
+ """Load and return a sample image."""
33
+ if image_path:
34
+ return Image.open(image_path)
35
+ return None
36
+
37
+
38
+ def create_interface(model, device):
39
+ """Create and configure the Gradio interface."""
40
+ css = generate_css()
41
+
42
+ with gr.Blocks(css=css, title="HATSAT - Super-Resolution for Satellite Images") as iface:
43
+ gr.Markdown("# HATSAT - Super-Resolution for Satellite Images")
44
+ gr.Markdown(f"Upload a satellite image or select a sample to enhance its resolution by {UPSCALE_FACTOR}x.")
45
+ gr.Markdown(f"⚠️ **Important**: Images must be exactly **{REQUIRED_IMAGE_SIZE[0]}x{REQUIRED_IMAGE_SIZE[1]} pixels** for the model to work properly.")
46
+
47
+ # Acknowledgments section
48
+ with gr.Accordion("Acknowledgments", open=False):
49
+ gr.Markdown("""
50
+ ### Base Model: HAT (Hybrid Attention Transformer)
51
+ This model is a fine-tuned version of **HAT**:
52
+ - **GitHub Repository**: [https://github.com/XPixelGroup/HAT](https://github.com/XPixelGroup/HAT)
53
+ - **Paper**: [Activating More Pixels in Image Super-Resolution Transformer](https://arxiv.org/abs/2205.04437)
54
+ - **Authors**: Xiangyu Chen, Xintao Wang, Jiantao Zhou, Yu Qiao, Chao Dong
55
+
56
+ ### Training Dataset: SEN2NAIPv2
57
+ The model was fine-tuned using the **SEN2NAIPv2** dataset:
58
+ - **HuggingFace Dataset**: [https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2](https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2)
59
+ - **Description**: High-resolution satellite imagery dataset for super-resolution tasks
60
+ """)
61
+
62
+ # Sample images
63
+ sample_images = get_sample_images()
64
+ sample_buttons = []
65
+ if sample_images:
66
+ gr.Markdown("**Sample Images (click to select):**")
67
+ with gr.Row():
68
+ for i, img_path in enumerate(sample_images):
69
+ btn = gr.Button(
70
+ "",
71
+ elem_id=f"sample_btn_{i}",
72
+ elem_classes="sample-image-btn"
73
+ )
74
+ sample_buttons.append((btn, img_path))
75
+
76
+ with gr.Row():
77
+ input_image = gr.Image(
78
+ type="pil",
79
+ label=f"Input Image (must be {REQUIRED_IMAGE_SIZE[0]}x{REQUIRED_IMAGE_SIZE[1]} pixels)",
80
+ elem_classes="image-container",
81
+ sources=["upload"],
82
+ height=500,
83
+ width=500
84
+ )
85
+
86
+ output_image = gr.Image(
87
+ type="pil",
88
+ label=f"Enhanced Output ({UPSCALE_FACTOR}x)",
89
+ elem_classes="image-container",
90
+ interactive=False,
91
+ height=500,
92
+ width=500,
93
+ show_download_button=True
94
+ )
95
+
96
+ submit_btn = gr.Button("Enhance Image", variant="primary")
97
+
98
+ # Status message
99
+ status_message = gr.Textbox(
100
+ label="Status",
101
+ interactive=False,
102
+ show_label=True
103
+ )
104
+
105
+ # Event handlers
106
+ if sample_images:
107
+ for btn, img_path in sample_buttons:
108
+ btn.click(fn=lambda path=img_path: select_sample_image(path), outputs=input_image)
109
+
110
+ submit_btn.click(
111
+ fn=lambda img: upscale_and_display(img, model, device),
112
+ inputs=input_image,
113
+ outputs=[output_image, status_message]
114
+ )
115
+
116
+ return iface
model/__init__.py ADDED
@@ -0,0 +1,16 @@
+ """
+ HAT model architecture components.
+ """
+
+ from .hat_model import HAT
+ from .components import (
+     DropPath, ChannelAttention, CAB, Mlp,
+     WindowAttention, HAB, OCAB, AttenBlocks,
+     RHAG, PatchEmbed, PatchUnEmbed, Upsample
+ )
+
+ __all__ = [
+     'HAT', 'DropPath', 'ChannelAttention', 'CAB', 'Mlp',
+     'WindowAttention', 'HAB', 'OCAB', 'AttenBlocks',
+     'RHAG', 'PatchEmbed', 'PatchUnEmbed', 'Upsample'
+ ]
model/components.py ADDED
@@ -0,0 +1,451 @@
1
+ """
2
+ HAT model components and building blocks.
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import math
8
+ from einops import rearrange
9
+
10
+
11
+ def to_2tuple(x):
12
+ """Convert input to tuple of length 2."""
13
+ if isinstance(x, (tuple, list)):
14
+ return tuple(x)
15
+ return (x, x)
16
+
17
+
18
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
19
+ """Truncated normal initialization."""
20
+ def norm_cdf(x):
21
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
22
+
23
+ with torch.no_grad():
24
+ l = norm_cdf((a - mean) / std)
25
+ u = norm_cdf((b - mean) / std)
26
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
27
+ tensor.erfinv_()
28
+ tensor.mul_(std * math.sqrt(2.))
29
+ tensor.add_(mean)
30
+ tensor.clamp_(min=a, max=b)
31
+ return tensor
32
+
33
+
34
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
35
+ if drop_prob == 0. or not training:
36
+ return x
37
+ keep_prob = 1 - drop_prob
38
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
39
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
40
+ random_tensor.floor_()
41
+ output = x.div(keep_prob) * random_tensor
42
+ return output
43
+
44
+
45
+ class DropPath(nn.Module):
46
+ def __init__(self, drop_prob=None):
47
+ super(DropPath, self).__init__()
48
+ self.drop_prob = drop_prob
49
+
50
+ def forward(self, x):
51
+ return drop_path(x, self.drop_prob, self.training)
52
+
53
+
54
+ class ChannelAttention(nn.Module):
55
+ def __init__(self, num_feat, squeeze_factor=16):
56
+ super(ChannelAttention, self).__init__()
57
+ self.attention = nn.Sequential(
58
+ nn.AdaptiveAvgPool2d(1),
59
+ nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
60
+ nn.ReLU(inplace=True),
61
+ nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
62
+ nn.Sigmoid())
63
+
64
+ def forward(self, x):
65
+ y = self.attention(x)
66
+ return x * y
67
+
68
+
69
+ class CAB(nn.Module):
70
+ def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
71
+ super(CAB, self).__init__()
72
+ self.cab = nn.Sequential(
73
+ nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
74
+ nn.GELU(),
75
+ nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
76
+ ChannelAttention(num_feat, squeeze_factor)
77
+ )
78
+
79
+ def forward(self, x):
80
+ return self.cab(x)
81
+
82
+
83
+ class Mlp(nn.Module):
84
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
85
+ super().__init__()
86
+ out_features = out_features or in_features
87
+ hidden_features = hidden_features or in_features
88
+ self.fc1 = nn.Linear(in_features, hidden_features)
89
+ self.act = act_layer()
90
+ self.fc2 = nn.Linear(hidden_features, out_features)
91
+ self.drop = nn.Dropout(drop)
92
+
93
+ def forward(self, x):
94
+ x = self.fc1(x)
95
+ x = self.act(x)
96
+ x = self.drop(x)
97
+ x = self.fc2(x)
98
+ x = self.drop(x)
99
+ return x
100
+
101
+
102
+ def window_partition(x, window_size):
103
+ b, h, w, c = x.shape
104
+ x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
105
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
106
+ return windows
107
+
108
+
109
+ def window_reverse(windows, window_size, h, w):
110
+ b = int(windows.shape[0] / (h * w / window_size / window_size))
111
+ x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
112
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
113
+ return x
114
+
115
+
116
+ class WindowAttention(nn.Module):
117
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
118
+ super().__init__()
119
+ self.dim = dim
120
+ self.window_size = window_size
121
+ self.num_heads = num_heads
122
+ head_dim = dim // num_heads
123
+ self.scale = qk_scale or head_dim**-0.5
124
+
125
+ self.relative_position_bias_table = nn.Parameter(
126
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
127
+
128
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
129
+ self.attn_drop = nn.Dropout(attn_drop)
130
+ self.proj = nn.Linear(dim, dim)
131
+ self.proj_drop = nn.Dropout(proj_drop)
132
+
133
+ trunc_normal_(self.relative_position_bias_table, std=.02)
134
+ self.softmax = nn.Softmax(dim=-1)
135
+
136
+ def forward(self, x, rpi, mask=None):
137
+ b_, n, c = x.shape
138
+ qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
139
+ q, k, v = qkv[0], qkv[1], qkv[2]
140
+
141
+ q = q * self.scale
142
+ attn = (q @ k.transpose(-2, -1))
143
+
144
+ relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
145
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
146
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
147
+ attn = attn + relative_position_bias.unsqueeze(0)
148
+
149
+ if mask is not None:
150
+ nw = mask.shape[0]
151
+ attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
152
+ attn = attn.view(-1, self.num_heads, n, n)
153
+ attn = self.softmax(attn)
154
+ else:
155
+ attn = self.softmax(attn)
156
+
157
+ attn = self.attn_drop(attn)
158
+
159
+ x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
160
+ x = self.proj(x)
161
+ x = self.proj_drop(x)
162
+ return x
163
+
164
+
165
+ class HAB(nn.Module):
166
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
167
+ compress_ratio=3, squeeze_factor=30, conv_scale=0.01, mlp_ratio=4.,
168
+ qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
169
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
170
+ super().__init__()
171
+ self.dim = dim
172
+ self.input_resolution = input_resolution
173
+ self.num_heads = num_heads
174
+ self.window_size = window_size
175
+ self.shift_size = shift_size
176
+ self.mlp_ratio = mlp_ratio
177
+ if min(self.input_resolution) <= self.window_size:
178
+ self.shift_size = 0
179
+ self.window_size = min(self.input_resolution)
180
+ assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'
181
+
182
+ self.norm1 = norm_layer(dim)
183
+ self.attn = WindowAttention(
184
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
185
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
186
+
187
+ self.conv_scale = conv_scale
188
+ self.conv_block = CAB(num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor)
189
+
190
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
191
+ self.norm2 = norm_layer(dim)
192
+ mlp_hidden_dim = int(dim * mlp_ratio)
193
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
194
+
195
+ def forward(self, x, x_size, rpi_sa, attn_mask):
196
+ h, w = x_size
197
+ b, _, c = x.shape
198
+
199
+ shortcut = x
200
+ x = self.norm1(x)
201
+ x = x.view(b, h, w, c)
202
+
203
+ # Conv_X
204
+ conv_x = self.conv_block(x.permute(0, 3, 1, 2))
205
+ conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c)
206
+
207
+ # cyclic shift
208
+ if self.shift_size > 0:
209
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
210
+ attn_mask = attn_mask
211
+ else:
212
+ shifted_x = x
213
+ attn_mask = None
214
+
215
+ # partition windows
216
+ x_windows = window_partition(shifted_x, self.window_size)
217
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, c)
218
+
219
+ # W-MSA/SW-MSA
220
+ attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask)
221
+
222
+ # merge windows
223
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
224
+ shifted_x = window_reverse(attn_windows, self.window_size, h, w)
225
+
226
+ # reverse cyclic shift
227
+ if self.shift_size > 0:
228
+ attn_x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
229
+ else:
230
+ attn_x = shifted_x
231
+ attn_x = attn_x.view(b, h * w, c)
232
+
233
+ # FFN
234
+ x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale
235
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
236
+
237
+ return x
238
+
239
+
240
+ class OCAB(nn.Module):
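+ """Overlapping Cross-Attention Block: queries come from regular (non-overlapping) windows while keys and values are unfolded from larger overlapping windows of size overlap_win_size."""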
241
+ def __init__(self, dim, input_resolution, window_size, overlap_ratio, num_heads,
242
+ qkv_bias=True, qk_scale=None, mlp_ratio=2, norm_layer=nn.LayerNorm):
243
+ super().__init__()
244
+ self.dim = dim
245
+ self.input_resolution = input_resolution
246
+ self.window_size = window_size
247
+ self.num_heads = num_heads
248
+ head_dim = dim // num_heads
249
+ self.scale = qk_scale or head_dim**-0.5
250
+ self.overlap_win_size = int(window_size * overlap_ratio) + window_size
251
+
252
+ self.norm1 = norm_layer(dim)
253
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
254
+ self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size),
255
+ stride=window_size, padding=(self.overlap_win_size-window_size)//2)
256
+
257
+ self.relative_position_bias_table = nn.Parameter(
258
+ torch.zeros((window_size + self.overlap_win_size - 1) * (window_size + self.overlap_win_size - 1), num_heads))
259
+
260
+ trunc_normal_(self.relative_position_bias_table, std=.02)
261
+ self.softmax = nn.Softmax(dim=-1)
262
+
263
+ self.proj = nn.Linear(dim, dim)
264
+
265
+ self.norm2 = norm_layer(dim)
266
+ mlp_hidden_dim = int(dim * mlp_ratio)
267
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU)
268
+
269
+ def forward(self, x, x_size, rpi):
270
+ h, w = x_size
271
+ b, _, c = x.shape
272
+
273
+ shortcut = x
274
+ x = self.norm1(x)
275
+ x = x.view(b, h, w, c)
276
+
277
+ qkv = self.qkv(x).reshape(b, h, w, 3, c).permute(3, 0, 4, 1, 2)
278
+ q = qkv[0].permute(0, 2, 3, 1)
279
+ kv = torch.cat((qkv[1], qkv[2]), dim=1)
280
+
281
+ # partition windows
282
+ q_windows = window_partition(q, self.window_size)
283
+ q_windows = q_windows.view(-1, self.window_size * self.window_size, c)
284
+
285
+ kv_windows = self.unfold(kv)
286
+ kv_windows = rearrange(kv_windows, 'b (nc ch owh oww) nw -> nc (b nw) (owh oww) ch',
287
+ nc=2, ch=c, owh=self.overlap_win_size, oww=self.overlap_win_size).contiguous()
288
+ k_windows, v_windows = kv_windows[0], kv_windows[1]
289
+
290
+ b_, nq, _ = q_windows.shape
291
+ _, n, _ = k_windows.shape
292
+ d = self.dim // self.num_heads
293
+ q = q_windows.reshape(b_, nq, self.num_heads, d).permute(0, 2, 1, 3)
294
+ k = k_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3)
295
+ v = v_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3)
296
+
297
+ q = q * self.scale
298
+ attn = (q @ k.transpose(-2, -1))
299
+
300
+ relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
301
+ self.window_size * self.window_size, self.overlap_win_size * self.overlap_win_size, -1)
302
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
303
+ attn = attn + relative_position_bias.unsqueeze(0)
304
+
305
+ attn = self.softmax(attn)
306
+ attn_windows = (attn @ v).transpose(1, 2).reshape(b_, nq, self.dim)
307
+
308
+ # merge windows
309
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.dim)
310
+ x = window_reverse(attn_windows, self.window_size, h, w)
311
+ x = x.view(b, h * w, self.dim)
312
+
313
+ x = self.proj(x) + shortcut
314
+ x = x + self.mlp(self.norm2(x))
315
+ return x
316
+
317
+
318
+ class AttenBlocks(nn.Module):
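+ """A stack of `depth` HAB blocks followed by a single OCAB, with an optional patch-merging downsample layer."""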
319
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size, compress_ratio,
320
+ squeeze_factor, conv_scale, overlap_ratio, mlp_ratio=4., qkv_bias=True, qk_scale=None,
321
+ drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
322
+ use_checkpoint=False):
323
+ super().__init__()
324
+ self.dim = dim
325
+ self.input_resolution = input_resolution
326
+ self.depth = depth
327
+ self.use_checkpoint = use_checkpoint
328
+
329
+ # build blocks
330
+ self.blocks = nn.ModuleList([
331
+ HAB(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size,
332
+ shift_size=0 if (i % 2 == 0) else window_size // 2, compress_ratio=compress_ratio,
333
+ squeeze_factor=squeeze_factor, conv_scale=conv_scale, mlp_ratio=mlp_ratio,
334
+ qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
335
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
336
+ norm_layer=norm_layer) for i in range(depth)
337
+ ])
338
+
339
+ # OCAB
340
+ self.overlap_attn = OCAB(dim=dim, input_resolution=input_resolution, window_size=window_size,
341
+ overlap_ratio=overlap_ratio, num_heads=num_heads, qkv_bias=qkv_bias,
342
+ qk_scale=qk_scale, mlp_ratio=mlp_ratio, norm_layer=norm_layer)
343
+
344
+ # patch merging layer
345
+ if downsample is not None:
346
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
347
+ else:
348
+ self.downsample = None
349
+
350
+ def forward(self, x, x_size, params):
351
+ for blk in self.blocks:
352
+ x = blk(x, x_size, params['rpi_sa'], params['attn_mask'])
353
+
354
+ x = self.overlap_attn(x, x_size, params['rpi_oca'])
355
+
356
+ if self.downsample is not None:
357
+ x = self.downsample(x)
358
+ return x
359
+
360
+
361
+ class RHAG(nn.Module):
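+ """Residual Hybrid Attention Group: AttenBlocks followed by a conv layer, with a residual connection back to the group input."""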
362
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size, compress_ratio,
363
+ squeeze_factor, conv_scale, overlap_ratio, mlp_ratio=4., qkv_bias=True, qk_scale=None,
364
+ drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
365
+ use_checkpoint=False, img_size=224, patch_size=4, resi_connection='1conv'):
366
+ super(RHAG, self).__init__()
367
+
368
+ self.dim = dim
369
+ self.input_resolution = input_resolution
370
+
371
+ self.residual_group = AttenBlocks(
372
+ dim=dim, input_resolution=input_resolution, depth=depth, num_heads=num_heads,
373
+ window_size=window_size, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor,
374
+ conv_scale=conv_scale, overlap_ratio=overlap_ratio, mlp_ratio=mlp_ratio,
375
+ qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
376
+ drop_path=drop_path, norm_layer=norm_layer, downsample=downsample,
377
+ use_checkpoint=use_checkpoint)
378
+
379
+ if resi_connection == '1conv':
380
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
381
+ elif resi_connection == 'identity':
382
+ self.conv = nn.Identity()
383
+
384
+ self.patch_embed = PatchEmbed(
385
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
386
+
387
+ self.patch_unembed = PatchUnEmbed(
388
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
389
+
390
+ def forward(self, x, x_size, params):
391
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size, params), x_size))) + x
392
+
393
+
394
+ class PatchEmbed(nn.Module):
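+ """Flattens (b, c, h, w) feature maps into a (b, h*w, c) token sequence, with optional normalization."""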
395
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
396
+ super().__init__()
397
+ img_size = to_2tuple(img_size)
398
+ patch_size = to_2tuple(patch_size)
399
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
400
+ self.img_size = img_size
401
+ self.patch_size = patch_size
402
+ self.patches_resolution = patches_resolution
403
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
404
+
405
+ self.in_chans = in_chans
406
+ self.embed_dim = embed_dim
407
+
408
+ if norm_layer is not None:
409
+ self.norm = norm_layer(embed_dim)
410
+ else:
411
+ self.norm = None
412
+
413
+ def forward(self, x):
414
+ x = x.flatten(2).transpose(1, 2)
415
+ if self.norm is not None:
416
+ x = self.norm(x)
417
+ return x
418
+
419
+
420
+ class PatchUnEmbed(nn.Module):
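+ """Reshapes a (b, h*w, c) token sequence back into (b, c, h, w) feature maps."""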
421
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
422
+ super().__init__()
423
+ img_size = to_2tuple(img_size)
424
+ patch_size = to_2tuple(patch_size)
425
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
426
+ self.img_size = img_size
427
+ self.patch_size = patch_size
428
+ self.patches_resolution = patches_resolution
429
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
430
+
431
+ self.in_chans = in_chans
432
+ self.embed_dim = embed_dim
433
+
434
+ def forward(self, x, x_size):
435
+ x = x.transpose(1, 2).contiguous().view(x.shape[0], self.embed_dim, x_size[0], x_size[1])
436
+ return x
437
+
438
+
439
+ class Upsample(nn.Sequential):
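+ """Pixel-shuffle upsampler supporting scale factors 2^n and 3."""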
440
+ def __init__(self, scale, num_feat):
441
+ m = []
442
+ if (scale & (scale - 1)) == 0:
443
+ for _ in range(int(math.log(scale, 2))):
444
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
445
+ m.append(nn.PixelShuffle(2))
446
+ elif scale == 3:
447
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
448
+ m.append(nn.PixelShuffle(3))
449
+ else:
450
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
451
+ super(Upsample, self).__init__(*m)
model/hat_model.py ADDED
@@ -0,0 +1,221 @@
1
+ """
2
+ HAT (Hybrid Attention Transformer) main model implementation.
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import math
8
+
9
+ from .components import (
10
+ RHAG, PatchEmbed, PatchUnEmbed, Upsample,
11
+ trunc_normal_, window_partition, to_2tuple
12
+ )
13
+
14
+
15
+ class HAT(nn.Module):
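+ """Hybrid Attention Transformer for image super-resolution: shallow conv feature extraction, deep feature extraction with Residual Hybrid Attention Groups (RHAG), and pixel-shuffle reconstruction."""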
16
+ def __init__(self, img_size=64, patch_size=1, in_chans=3, embed_dim=96, depths=(6, 6, 6, 6),
17
+ num_heads=(6, 6, 6, 6), window_size=7, compress_ratio=3, squeeze_factor=30,
18
+ conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=4., qkv_bias=True, qk_scale=None,
19
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm,
20
+ ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1.,
21
+ upsampler='', resi_connection='1conv', **kwargs):
22
+ super(HAT, self).__init__()
23
+
24
+ self.window_size = window_size
25
+ self.shift_size = window_size // 2
26
+ self.overlap_ratio = overlap_ratio
27
+
28
+ num_in_ch = in_chans
29
+ num_out_ch = in_chans
30
+ num_feat = 64
31
+ self.img_range = img_range
32
+ if in_chans == 3:
33
+ rgb_mean = (0.4488, 0.4371, 0.4040)
34
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
35
+ else:
36
+ self.mean = torch.zeros(1, 1, 1, 1)
37
+ self.upscale = upscale
38
+ self.upsampler = upsampler
39
+
40
+ # relative position index
41
+ relative_position_index_SA = self.calculate_rpi_sa()
42
+ relative_position_index_OCA = self.calculate_rpi_oca()
43
+ self.register_buffer('relative_position_index_SA', relative_position_index_SA)
44
+ self.register_buffer('relative_position_index_OCA', relative_position_index_OCA)
45
+
46
+ # shallow feature extraction
47
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
48
+
49
+ # deep feature extraction
50
+ self.num_layers = len(depths)
51
+ self.embed_dim = embed_dim
52
+ self.ape = ape
53
+ self.patch_norm = patch_norm
54
+ self.num_features = embed_dim
55
+ self.mlp_ratio = mlp_ratio
56
+
57
+ # split image into non-overlapping patches
58
+ self.patch_embed = PatchEmbed(
59
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
60
+ norm_layer=norm_layer if self.patch_norm else None)
61
+ num_patches = self.patch_embed.num_patches
62
+ patches_resolution = self.patch_embed.patches_resolution
63
+ self.patches_resolution = patches_resolution
64
+
65
+ # merge non-overlapping patches into image
66
+ self.patch_unembed = PatchUnEmbed(
67
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
68
+ norm_layer=norm_layer if self.patch_norm else None)
69
+
70
+ # absolute position embedding
71
+ if self.ape:
72
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
73
+ trunc_normal_(self.absolute_pos_embed, std=.02)
74
+
75
+ self.pos_drop = nn.Dropout(p=drop_rate)
76
+
77
+ # stochastic depth
78
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
79
+
80
+ # build Residual Hybrid Attention Groups (RHAG)
81
+ self.layers = nn.ModuleList()
82
+ for i_layer in range(self.num_layers):
83
+ layer = RHAG(
84
+ dim=embed_dim,
85
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
86
+ depth=depths[i_layer],
87
+ num_heads=num_heads[i_layer],
88
+ window_size=window_size,
89
+ compress_ratio=compress_ratio,
90
+ squeeze_factor=squeeze_factor,
91
+ conv_scale=conv_scale,
92
+ overlap_ratio=overlap_ratio,
93
+ mlp_ratio=self.mlp_ratio,
94
+ qkv_bias=qkv_bias,
95
+ qk_scale=qk_scale,
96
+ drop=drop_rate,
97
+ attn_drop=attn_drop_rate,
98
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
99
+ norm_layer=norm_layer,
100
+ downsample=None,
101
+ use_checkpoint=use_checkpoint,
102
+ img_size=img_size,
103
+ patch_size=patch_size,
104
+ resi_connection=resi_connection)
105
+ self.layers.append(layer)
106
+ self.norm = norm_layer(self.num_features)
107
+
108
+ # build the last conv layer in deep feature extraction
109
+ if resi_connection == '1conv':
110
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
111
+ elif resi_connection == 'identity':
112
+ self.conv_after_body = nn.Identity()
113
+
114
+ # high quality image reconstruction
115
+ if self.upsampler == 'pixelshuffle':
116
+ self.conv_before_upsample = nn.Sequential(
117
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
118
+ self.upsample = Upsample(upscale, num_feat)
119
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
120
+
121
+ self.apply(self._init_weights)
122
+
123
+ def _init_weights(self, m):
124
+ if isinstance(m, nn.Linear):
125
+ trunc_normal_(m.weight, std=.02)
126
+ if isinstance(m, nn.Linear) and m.bias is not None:
127
+ nn.init.constant_(m.bias, 0)
128
+ elif isinstance(m, nn.LayerNorm):
129
+ nn.init.constant_(m.bias, 0)
130
+ nn.init.constant_(m.weight, 1.0)
131
+
132
+ def calculate_rpi_sa(self):
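+ # relative position index for window self-attention (as in Swin Transformer)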
133
+ coords_h = torch.arange(self.window_size)
134
+ coords_w = torch.arange(self.window_size)
135
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
136
+ coords_flatten = torch.flatten(coords, 1)
137
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
138
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
139
+ relative_coords[:, :, 0] += self.window_size - 1
140
+ relative_coords[:, :, 1] += self.window_size - 1
141
+ relative_coords[:, :, 0] *= 2 * self.window_size - 1
142
+ relative_position_index = relative_coords.sum(-1)
143
+ return relative_position_index
144
+
145
+ def calculate_rpi_oca(self):
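+ # relative position index between a regular window and its enlarged overlapping window (used by OCAB)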
146
+ window_size_ori = self.window_size
147
+ window_size_ext = self.window_size + int(self.overlap_ratio * self.window_size)
148
+
149
+ coords_h = torch.arange(window_size_ori)
150
+ coords_w = torch.arange(window_size_ori)
151
+ coords_ori = torch.stack(torch.meshgrid([coords_h, coords_w]))
152
+ coords_ori_flatten = torch.flatten(coords_ori, 1)
153
+
154
+ coords_h = torch.arange(window_size_ext)
155
+ coords_w = torch.arange(window_size_ext)
156
+ coords_ext = torch.stack(torch.meshgrid([coords_h, coords_w]))
157
+ coords_ext_flatten = torch.flatten(coords_ext, 1)
158
+
159
+ relative_coords = coords_ext_flatten[:, None, :] - coords_ori_flatten[:, :, None]
160
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
161
+ relative_coords[:, :, 0] += window_size_ori - window_size_ext + 1
162
+ relative_coords[:, :, 1] += window_size_ori - window_size_ext + 1
163
+ relative_coords[:, :, 0] *= window_size_ori + window_size_ext - 1
164
+ relative_position_index = relative_coords.sum(-1)
165
+ return relative_position_index
166
+
167
+ def calculate_mask(self, x_size):
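+ # attention mask for shifted-window attention (SW-MSA)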
168
+ h, w = x_size
169
+ img_mask = torch.zeros((1, h, w, 1))
170
+ h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
171
+ w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
172
+ cnt = 0
173
+ for h_slice in h_slices:
174
+ for w_slice in w_slices:
175
+ img_mask[:, h_slice, w_slice, :] = cnt
176
+ cnt += 1
177
+
178
+ mask_windows = window_partition(img_mask, self.window_size)
179
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
180
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
181
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
182
+ return attn_mask
183
+
184
+ @torch.jit.ignore
185
+ def no_weight_decay(self):
186
+ return {'absolute_pos_embed'}
187
+
188
+ @torch.jit.ignore
189
+ def no_weight_decay_keywords(self):
190
+ return {'relative_position_bias_table'}
191
+
192
+ def forward_features(self, x):
193
+ x_size = (x.shape[2], x.shape[3])
194
+
195
+ attn_mask = self.calculate_mask(x_size).to(x.device)
196
+ params = {'attn_mask': attn_mask, 'rpi_sa': self.relative_position_index_SA, 'rpi_oca': self.relative_position_index_OCA}
197
+
198
+ x = self.patch_embed(x)
199
+ if self.ape:
200
+ x = x + self.absolute_pos_embed
201
+ x = self.pos_drop(x)
202
+
203
+ for layer in self.layers:
204
+ x = layer(x, x_size, params)
205
+
206
+ x = self.norm(x)
207
+ x = self.patch_unembed(x, x_size)
208
+ return x
209
+
210
+ def forward(self, x):
211
+ self.mean = self.mean.type_as(x)
212
+ x = (x - self.mean) * self.img_range
213
+
214
+ if self.upsampler == 'pixelshuffle':
215
+ x = self.conv_first(x)
216
+ x = self.conv_after_body(self.forward_features(x)) + x
217
+ x = self.conv_before_upsample(x)
218
+ x = self.conv_last(self.upsample(x))
219
+
220
+ x = x / self.img_range + self.mean
221
+ return x
utils/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ """
2
+ Utility functions for HATSAT application.
3
+ """
4
+
5
+ from .image_utils import upscale_image, validate_image_size, image_to_base64
6
+ from .model_utils import load_model, get_device
7
+
8
+ __all__ = [
9
+ 'upscale_image', 'validate_image_size', 'image_to_base64',
10
+ 'load_model', 'get_device'
11
+ ]
utils/image_utils.py ADDED
@@ -0,0 +1,64 @@
1
+ """
2
+ Image processing utilities.
3
+ """
4
+
5
+ import torch
6
+ import numpy as np
7
+ from PIL import Image
8
+ import base64
9
+ from io import BytesIO
10
+
11
+ from config import REQUIRED_IMAGE_SIZE, WINDOW_SIZE, UPSCALE_FACTOR
12
+
13
+
14
+ def validate_image_size(image):
15
+ """Validate that the image is exactly the required size."""
16
+ if image is None:
17
+ return False, "No image provided"
18
+
19
+ width, height = image.size
20
+ req_width, req_height = REQUIRED_IMAGE_SIZE
21
+ if width != req_width or height != req_height:
22
+ return False, f"Image must be exactly {req_width}x{req_height} pixels. Your image is {width}x{height} pixels."
23
+
24
+ return True, "Valid image size"
25
+
26
+
27
+ def upscale_image(image, model, device):
28
+ """Upscale an image using the HAT model."""
29
+ # Convert PIL image to tensor
30
+ img_np = np.array(image).astype(np.float32) / 255.0
31
+ img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)
32
+
33
+ # Ensure the image dimensions are multiples of window_size
34
+ h, w = img_tensor.shape[2], img_tensor.shape[3]
35
+
36
+ # Pad if necessary
37
+ pad_h = (WINDOW_SIZE - h % WINDOW_SIZE) % WINDOW_SIZE
38
+ pad_w = (WINDOW_SIZE - w % WINDOW_SIZE) % WINDOW_SIZE
39
+
40
+ if pad_h > 0 or pad_w > 0:
41
+ img_tensor = torch.nn.functional.pad(img_tensor, (0, pad_w, 0, pad_h), mode='reflect')
42
+
43
+ with torch.no_grad():
44
+ output = model(img_tensor)
45
+
46
+ # Remove padding if it was added
47
+ if pad_h > 0 or pad_w > 0:
48
+ output = output[:, :, :h*UPSCALE_FACTOR, :w*UPSCALE_FACTOR]
49
+
50
+ # Convert back to PIL image
51
+ output_np = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
52
+ output_np = np.clip(output_np * 255.0, 0, 255).astype(np.uint8)
53
+
54
+ return Image.fromarray(output_np)
55
+
56
+
57
+ def image_to_base64(image_path):
58
+ """Convert image to base64 data URL for CSS background."""
59
+ img = Image.open(image_path)
60
+ img.thumbnail((120, 120), Image.Resampling.LANCZOS)
61
+ buffer = BytesIO()
62
+ img.save(buffer, format='PNG')
63
+ img_str = base64.b64encode(buffer.getvalue()).decode()
64
+ return f"data:image/png;base64,{img_str}"
utils/model_utils.py ADDED
@@ -0,0 +1,31 @@
1
+ """
2
+ Model loading and device utilities.
3
+ """
4
+
5
+ import torch
6
+ from model import HAT
7
+ from config import MODEL_CHECKPOINT, MODEL_CONFIG
8
+
9
+
10
+ def get_device():
11
+ """Get the appropriate device for model inference."""
12
+ return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+
14
+
15
+ def load_model():
16
+ """Load and initialize the HAT model with pre-trained weights."""
17
+ device = get_device()
18
+
19
+ # Initialize model
20
+ model = HAT(**MODEL_CONFIG)
21
+
22
+ # Load the fine-tuned weights
23
+ checkpoint = torch.load(MODEL_CHECKPOINT, map_location=device)
24
+ # Try different checkpoint formats
25
+ state_dict = checkpoint.get('params_ema') or checkpoint.get('params') or checkpoint
26
+ model.load_state_dict(state_dict)
27
+
28
+ model.to(device)
29
+ model.eval()
30
+
31
+ return model, device
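+
+
+ # Minimal usage sketch (the hypothetical `lr_image` below is a PIL.Image loaded by the app):
+ #   model, device = load_model()
+ #   ok, msg = validate_image_size(lr_image)   # from utils.image_utils
+ #   if ok:
+ #       sr_image = upscale_image(lr_image, model, device)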