# HATSAT / app.py
# Gradio demo: 4x super-resolution for satellite imagery with a fine-tuned HAT model.
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import math
from einops import rearrange
import os
import glob
import base64
from io import BytesIO
def to_2tuple(x):
"""Convert input to tuple of length 2."""
if isinstance(x, (tuple, list)):
return tuple(x)
return (x, x)
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
"""Truncated normal initialization."""
def norm_cdf(x):
return (1. + math.erf(x / math.sqrt(2.))) / 2.
with torch.no_grad():
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
tensor.uniform_(2 * l - 1, 2 * u - 1)
tensor.erfinv_()
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
tensor.clamp_(min=a, max=b)
return tensor
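# Stochastic depth ("drop path"): randomly drops entire residual branches per sample during training.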
def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
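# Squeeze-and-excitation style channel attention: global average pooling followed by a
# two-conv bottleneck that produces per-channel scaling weights.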
class ChannelAttention(nn.Module):
def __init__(self, num_feat, squeeze_factor=16):
super(ChannelAttention, self).__init__()
self.attention = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
nn.ReLU(inplace=True),
nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
nn.Sigmoid())
def forward(self, x):
y = self.attention(x)
return x * y
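# Channel Attention Block (CAB): 3x3 conv -> GELU -> 3x3 conv, followed by channel attention.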
class CAB(nn.Module):
def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
super(CAB, self).__init__()
self.cab = nn.Sequential(
nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
nn.GELU(),
nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
ChannelAttention(num_feat, squeeze_factor)
)
def forward(self, x):
return self.cab(x)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
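# Split a (b, h, w, c) feature map into non-overlapping windows of shape
# (num_windows * b, window_size, window_size, c).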
def window_partition(x, window_size):
b, h, w, c = x.shape
x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
return windows
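# Inverse of window_partition: reassemble windows back into a (b, h, w, c) feature map.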
def window_reverse(windows, window_size, h, w):
b = int(windows.shape[0] / (h * w / window_size / window_size))
x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
return x
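# Window-based multi-head self-attention (W-MSA) with a learned relative position bias
# (rpi indexes the bias table); the optional mask handles shifted windows.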
class WindowAttention(nn.Module):
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, rpi, mask=None):
b_, n, c = x.shape
qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nw = mask.shape[0]
attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, n, n)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
x = self.proj(x)
x = self.proj_drop(x)
return x
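# Hybrid Attention Block (HAB): (shifted-)window self-attention plus a parallel
# channel-attention conv branch (CAB, weighted by conv_scale), followed by an MLP.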
class HAB(nn.Module):
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
compress_ratio=3, squeeze_factor=30, conv_scale=0.01, mlp_ratio=4.,
qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
self.shift_size = 0
self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, 'shift_size must be in [0, window_size)'
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.conv_scale = conv_scale
self.conv_block = CAB(num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x, x_size, rpi_sa, attn_mask):
h, w = x_size
b, _, c = x.shape
shortcut = x
x = self.norm1(x)
x = x.view(b, h, w, c)
# Conv_X
conv_x = self.conv_block(x.permute(0, 3, 1, 2))
conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c)
        # cyclic shift (the precomputed attn_mask is only used in the shifted case)
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x
            attn_mask = None
# partition windows
x_windows = window_partition(shifted_x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, c)
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask)
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
shifted_x = window_reverse(attn_windows, self.window_size, h, w)
# reverse cyclic shift
if self.shift_size > 0:
attn_x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
attn_x = shifted_x
attn_x = attn_x.view(b, h * w, c)
# FFN
x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
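# Overlapping Cross-Attention Block (OCAB): queries come from non-overlapping windows while
# keys/values are unfolded from larger, overlapping windows.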
class OCAB(nn.Module):
def __init__(self, dim, input_resolution, window_size, overlap_ratio, num_heads,
qkv_bias=True, qk_scale=None, mlp_ratio=2, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.window_size = window_size
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.overlap_win_size = int(window_size * overlap_ratio) + window_size
self.norm1 = norm_layer(dim)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size),
stride=window_size, padding=(self.overlap_win_size-window_size)//2)
self.relative_position_bias_table = nn.Parameter(
torch.zeros((window_size + self.overlap_win_size - 1) * (window_size + self.overlap_win_size - 1), num_heads))
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
self.proj = nn.Linear(dim,dim)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU)
def forward(self, x, x_size, rpi):
h, w = x_size
b, _, c = x.shape
shortcut = x
x = self.norm1(x)
x = x.view(b, h, w, c)
qkv = self.qkv(x).reshape(b, h, w, 3, c).permute(3, 0, 4, 1, 2)
q = qkv[0].permute(0, 2, 3, 1)
kv = torch.cat((qkv[1], qkv[2]), dim=1)
# partition windows
q_windows = window_partition(q, self.window_size)
q_windows = q_windows.view(-1, self.window_size * self.window_size, c)
kv_windows = self.unfold(kv)
kv_windows = rearrange(kv_windows, 'b (nc ch owh oww) nw -> nc (b nw) (owh oww) ch',
nc=2, ch=c, owh=self.overlap_win_size, oww=self.overlap_win_size).contiguous()
k_windows, v_windows = kv_windows[0], kv_windows[1]
b_, nq, _ = q_windows.shape
_, n, _ = k_windows.shape
d = self.dim // self.num_heads
q = q_windows.reshape(b_, nq, self.num_heads, d).permute(0, 2, 1, 3)
k = k_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3)
v = v_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
self.window_size * self.window_size, self.overlap_win_size * self.overlap_win_size, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
attn = attn + relative_position_bias.unsqueeze(0)
attn = self.softmax(attn)
attn_windows = (attn @ v).transpose(1, 2).reshape(b_, nq, self.dim)
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.dim)
x = window_reverse(attn_windows, self.window_size, h, w)
x = x.view(b, h * w, self.dim)
x = self.proj(x) + shortcut
x = x + self.mlp(self.norm2(x))
return x
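# A group of HABs (alternating regular and shifted windows) followed by a single OCAB.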
class AttenBlocks(nn.Module):
def __init__(self, dim, input_resolution, depth, num_heads, window_size, compress_ratio,
squeeze_factor, conv_scale, overlap_ratio, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
HAB(dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2, compress_ratio=compress_ratio,
squeeze_factor=squeeze_factor, conv_scale=conv_scale, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer) for i in range(depth)
])
# OCAB
self.overlap_attn = OCAB(dim=dim, input_resolution=input_resolution, window_size=window_size,
overlap_ratio=overlap_ratio, num_heads=num_heads, qkv_bias=qkv_bias,
qk_scale=qk_scale, mlp_ratio=mlp_ratio, norm_layer=norm_layer)
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, x_size, params):
for blk in self.blocks:
x = blk(x, x_size, params['rpi_sa'], params['attn_mask'])
x = self.overlap_attn(x, x_size, params['rpi_oca'])
if self.downsample is not None:
x = self.downsample(x)
return x
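# Residual Hybrid Attention Group (RHAG): AttenBlocks plus a convolution, wrapped in a residual connection.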
class RHAG(nn.Module):
def __init__(self, dim, input_resolution, depth, num_heads, window_size, compress_ratio,
squeeze_factor, conv_scale, overlap_ratio, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
use_checkpoint=False, img_size=224, patch_size=4, resi_connection='1conv'):
super(RHAG, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.residual_group = AttenBlocks(
dim=dim, input_resolution=input_resolution, depth=depth, num_heads=num_heads,
window_size=window_size, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor,
conv_scale=conv_scale, overlap_ratio=overlap_ratio, mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
drop_path=drop_path, norm_layer=norm_layer, downsample=downsample,
use_checkpoint=use_checkpoint)
if resi_connection == '1conv':
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
elif resi_connection == 'identity':
self.conv = nn.Identity()
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
def forward(self, x, x_size, params):
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size, params), x_size))) + x
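# Flatten a (b, c, h, w) feature map into a token sequence of shape (b, h*w, c), with optional LayerNorm.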
class PatchEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x = x.flatten(2).transpose(1, 2)
if self.norm is not None:
x = self.norm(x)
return x
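# Inverse of PatchEmbed: reshape the token sequence back into a (b, c, h, w) feature map.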
class PatchUnEmbed(nn.Module):
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
def forward(self, x, x_size):
x = x.transpose(1, 2).contiguous().view(x.shape[0], self.embed_dim, x_size[0], x_size[1])
return x
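# Pixel-shuffle upsampling for scale factors 2^n and 3.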
class Upsample(nn.Sequential):
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0:
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
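# HAT (Hybrid Attention Transformer): shallow conv feature extraction, a stack of RHAGs
# for deep features, and a pixel-shuffle reconstruction head.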
class HAT(nn.Module):
def __init__(self, img_size=64, patch_size=1, in_chans=3, embed_dim=96, depths=(6, 6, 6, 6),
num_heads=(6, 6, 6, 6), window_size=7, compress_ratio=3, squeeze_factor=30,
conv_scale=0.01, overlap_ratio=0.5, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm,
ape=False, patch_norm=True, use_checkpoint=False, upscale=2, img_range=1.,
upsampler='', resi_connection='1conv', **kwargs):
super(HAT, self).__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.overlap_ratio = overlap_ratio
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
self.upscale = upscale
self.upsampler = upsampler
# relative position index
relative_position_index_SA = self.calculate_rpi_sa()
relative_position_index_OCA = self.calculate_rpi_oca()
self.register_buffer('relative_position_index_SA', relative_position_index_SA)
self.register_buffer('relative_position_index_OCA', relative_position_index_OCA)
# shallow feature extraction
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
# deep feature extraction
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = embed_dim
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# merge non-overlapping patches into image
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
# build Residual Hybrid Attention Groups (RHAG)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = RHAG(
dim=embed_dim,
input_resolution=(patches_resolution[0], patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
compress_ratio=compress_ratio,
squeeze_factor=squeeze_factor,
conv_scale=conv_scale,
overlap_ratio=overlap_ratio,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
if resi_connection == '1conv':
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
elif resi_connection == 'identity':
self.conv_after_body = nn.Identity()
# high quality image reconstruction
if self.upsampler == 'pixelshuffle':
self.conv_before_upsample = nn.Sequential(
nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def calculate_rpi_sa(self):
coords_h = torch.arange(self.window_size)
coords_w = torch.arange(self.window_size)
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.window_size - 1
relative_coords[:, :, 1] += self.window_size - 1
relative_coords[:, :, 0] *= 2 * self.window_size - 1
relative_position_index = relative_coords.sum(-1)
return relative_position_index
def calculate_rpi_oca(self):
window_size_ori = self.window_size
window_size_ext = self.window_size + int(self.overlap_ratio * self.window_size)
coords_h = torch.arange(window_size_ori)
coords_w = torch.arange(window_size_ori)
coords_ori = torch.stack(torch.meshgrid([coords_h, coords_w]))
coords_ori_flatten = torch.flatten(coords_ori, 1)
coords_h = torch.arange(window_size_ext)
coords_w = torch.arange(window_size_ext)
coords_ext = torch.stack(torch.meshgrid([coords_h, coords_w]))
coords_ext_flatten = torch.flatten(coords_ext, 1)
relative_coords = coords_ext_flatten[:, None, :] - coords_ori_flatten[:, :, None]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += window_size_ori - window_size_ext + 1
relative_coords[:, :, 1] += window_size_ori - window_size_ext + 1
relative_coords[:, :, 0] *= window_size_ori + window_size_ext - 1
relative_position_index = relative_coords.sum(-1)
return relative_position_index
def calculate_mask(self, x_size):
h, w = x_size
img_mask = torch.zeros((1, h, w, 1))
h_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size), slice(-self.window_size, -self.shift_size), slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def forward_features(self, x):
x_size = (x.shape[2], x.shape[3])
attn_mask = self.calculate_mask(x_size).to(x.device)
params = {'attn_mask': attn_mask, 'rpi_sa': self.relative_position_index_SA, 'rpi_oca': self.relative_position_index_OCA}
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x, x_size, params)
x = self.norm(x)
x = self.patch_unembed(x, x_size)
return x
def forward(self, x):
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
if self.upsampler == 'pixelshuffle':
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.conv_last(self.upsample(x))
x = x / self.img_range + self.mean
return x
# Instantiate the model (hyperparameters must match the fine-tuned checkpoint)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HAT(
upscale=4,
in_chans=3,
img_size=128,
window_size=16,
compress_ratio=3,
squeeze_factor=30,
conv_scale=0.01,
overlap_ratio=0.5,
img_range=1.,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler='pixelshuffle',
resi_connection='1conv'
)
# Load the fine-tuned weights
checkpoint = torch.load('net_g_150000.pth', map_location=device)
if 'params_ema' in checkpoint:
model.load_state_dict(checkpoint['params_ema'])
elif 'params' in checkpoint:
model.load_state_dict(checkpoint['params'])
else:
model.load_state_dict(checkpoint)
model.to(device)
model.eval()
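# Run one image through the model: normalize to [0, 1], pad spatial dims to a multiple of the
# window size (16), infer, then crop the 4x output and convert back to a PIL image.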
def upscale_image(image):
    # Convert the PIL image to a normalized float tensor (force RGB in case of RGBA/grayscale input)
    img_np = np.array(image.convert('RGB')).astype(np.float32) / 255.0
    img_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(device)
# Ensure the image dimensions are multiples of window_size
h, w = img_tensor.shape[2], img_tensor.shape[3]
# Pad if necessary
pad_h = (16 - h % 16) % 16
pad_w = (16 - w % 16) % 16
if pad_h > 0 or pad_w > 0:
img_tensor = torch.nn.functional.pad(img_tensor, (0, pad_w, 0, pad_h), mode='reflect')
with torch.no_grad():
output = model(img_tensor)
# Remove padding if it was added
if pad_h > 0 or pad_w > 0:
output = output[:, :, :h*4, :w*4]
# Convert back to PIL image
output_np = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
output_np = np.clip(output_np * 255.0, 0, 255).astype(np.uint8)
return Image.fromarray(output_np)
# Get sample images
def get_sample_images():
sample_dir = "sample_images"
if os.path.exists(sample_dir):
image_files = glob.glob(os.path.join(sample_dir, "*.png")) + glob.glob(os.path.join(sample_dir, "*.jpg"))
return sorted(image_files)
return []
# Helper functions for the Gradio Blocks interface
def validate_image_size(image):
"""Validate that the image is exactly 130x130 pixels"""
if image is None:
return False, "No image provided"
width, height = image.size
if width != 130 or height != 130:
return False, f"Image must be exactly 130x130 pixels. Your image is {width}x{height} pixels."
return True, "Valid image size"
def upscale_and_display(image):
if image is None:
return None, "Please upload an image or select a sample image."
# Validate image size
is_valid, message = validate_image_size(image)
if not is_valid:
return None, f"❌ Error: {message}"
try:
# Get the super-resolution output
upscaled = upscale_image(image)
return upscaled, "✅ Image successfully enhanced!"
except Exception as e:
return None, f"❌ Error processing image: {str(e)}"
def select_sample_image(image_path):
if image_path:
return Image.open(image_path)
return None
def image_to_base64(image_path):
"""Convert image to base64 data URL for CSS background"""
img = Image.open(image_path)
img.thumbnail((120, 120), Image.Resampling.LANCZOS)
buffer = BytesIO()
img.save(buffer, format='PNG')
img_str = base64.b64encode(buffer.getvalue()).decode()
return f"data:image/png;base64,{img_str}"
# Generate CSS with base64 images
def generate_css():
base_css = """
/* Target only the image display area, not the whole component */
.image-container [data-testid="image"] {
height: 500px !important;
min-height: 500px !important;
}
/* Make images fill their containers */
.image-container img {
width: 500px !important;
height: 500px !important;
object-fit: contain !important;
object-position: center !important;
}
/* Sample image buttons with background images */
.sample-image-btn {
height: 120px !important;
width: 120px !important;
background-size: cover !important;
background-position: center !important;
border: 2px solid #ddd !important;
border-radius: 8px !important;
cursor: pointer !important;
transition: border-color 0.2s !important;
margin: 5px !important;
}
.sample-image-btn:hover {
border-color: #007acc !important;
}
"""
# Add background images for each sample
sample_images = get_sample_images()
for i, img_path in enumerate(sample_images):
base64_img = image_to_base64(img_path)
base_css += f"#sample_btn_{i} {{ background-image: url('{base64_img}'); }}\n"
return base_css
css = generate_css()
with gr.Blocks(css=css, title="HATSAT - Super-Resolution for Satellite Images") as iface:
gr.Markdown("# HATSAT - Super-Resolution for Satellite Images")
gr.Markdown("Upload a satellite image or select a sample to enhance its resolution by 4x.")
gr.Markdown("⚠️ **Important**: Images must be exactly **130x130 pixels** for the model to work properly.")
# Acknowledgments section
with gr.Accordion("Acknowledgments", open=False):
gr.Markdown("""
### Base Model: HAT (Hybrid Attention Transformer)
        This model is a fine-tuned version of **HAT**:
- **GitHub Repository**: [https://github.com/XPixelGroup/HAT](https://github.com/XPixelGroup/HAT)
- **Paper**: [Activating More Pixels in Image Super-Resolution Transformer](https://arxiv.org/abs/2205.04437)
- **Authors**: Xiangyu Chen, Xintao Wang, Jiantao Zhou, Yu Qiao, Chao Dong
### Training Dataset: SEN2NAIPv2
The model was fine-tuned using the **SEN2NAIPv2** dataset:
- **HuggingFace Dataset**: [https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2](https://huggingface.co/datasets/tacofoundation/SEN2NAIPv2)
- **Description**: High-resolution satellite imagery dataset for super-resolution tasks
""")
# Sample images
sample_images = get_sample_images()
sample_buttons = []
if sample_images:
gr.Markdown("**Sample Images (click to select):**")
with gr.Row():
for i, img_path in enumerate(sample_images):
btn = gr.Button(
"",
elem_id=f"sample_btn_{i}",
elem_classes="sample-image-btn"
)
sample_buttons.append((btn, img_path))
with gr.Row():
input_image = gr.Image(
type="pil",
label="Input Image (must be 130x130 pixels)",
elem_classes="image-container",
sources=["upload"],
height=500,
width=500
)
output_image = gr.Image(
type="pil",
label="Enhanced Output (4x)",
elem_classes="image-container",
interactive=False,
height=500,
width=500,
show_download_button=True
)
submit_btn = gr.Button("Enhance Image", variant="primary")
# Status message
status_message = gr.Textbox(
label="Status",
interactive=False,
show_label=True
)
# Event handlers
if sample_images:
for btn, img_path in sample_buttons:
btn.click(fn=lambda path=img_path: select_sample_image(path), outputs=input_image)
submit_btn.click(fn=upscale_and_display, inputs=input_image, outputs=[output_image, status_message])
if __name__ == "__main__":
iface.launch()