update dependencies

Files changed:
- .gitignore +2 -1
- annotator/dsine/__init__.py +0 -0
- annotator/dsine/dsine.py +303 -0
- annotator/dsine/rotation.py +85 -0
- annotator/dsine/submodules.py +237 -0
- annotator/dsine/utils.py +104 -0
- annotator/dsine_local.py +63 -0
- app.py +5 -2
.gitignore
CHANGED

@@ -159,4 +159,5 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+#.idea/
+gradio/
annotator/dsine/__init__.py
ADDED
File without changes
annotator/dsine/dsine.py
ADDED
@@ -0,0 +1,303 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .submodules import (
    Encoder,
    ConvGRU,
    UpSampleBN,
    UpSampleGN,
    RayReLU,
    convex_upsampling,
    get_unfold,
    get_prediction_head,
    INPUT_CHANNELS_DICT,
)
from .rotation import axis_angle_to_matrix


class Decoder(nn.Module):
    def __init__(self, output_dims, B=5, NF=2048, BN=False, downsample_ratio=8):
        super(Decoder, self).__init__()
        input_channels = INPUT_CHANNELS_DICT[B]
        output_dim, feature_dim, hidden_dim = output_dims
        features = bottleneck_features = NF
        self.downsample_ratio = downsample_ratio

        UpSample = UpSampleBN if BN else UpSampleGN
        self.conv2 = nn.Conv2d(
            bottleneck_features + 2, features, kernel_size=1, stride=1, padding=0
        )
        self.up1 = UpSample(
            skip_input=features // 1 + input_channels[1] + 2,
            output_features=features // 2,
            align_corners=False,
        )
        self.up2 = UpSample(
            skip_input=features // 2 + input_channels[2] + 2,
            output_features=features // 4,
            align_corners=False,
        )

        # prediction heads
        i_dim = features // 4
        h_dim = 128
        self.normal_head = get_prediction_head(i_dim + 2, h_dim, output_dim)
        self.feature_head = get_prediction_head(i_dim + 2, h_dim, feature_dim)
        self.hidden_head = get_prediction_head(i_dim + 2, h_dim, hidden_dim)

    def forward(self, features, uvs):
        _, _, x_block2, x_block3, x_block4 = (
            features[4],
            features[5],
            features[6],
            features[8],
            features[11],
        )
        uv_32, uv_16, uv_8 = uvs

        x_d0 = self.conv2(torch.cat([x_block4, uv_32], dim=1))
        x_d1 = self.up1(x_d0, torch.cat([x_block3, uv_16], dim=1))
        x_feat = self.up2(x_d1, torch.cat([x_block2, uv_8], dim=1))
        x_feat = torch.cat([x_feat, uv_8], dim=1)

        normal = self.normal_head(x_feat)
        normal = F.normalize(normal, dim=1)
        f = self.feature_head(x_feat)
        h = self.hidden_head(x_feat)
        return normal, f, h


class DSINE(nn.Module):
    def __init__(self):
        super(DSINE, self).__init__()
        self.downsample_ratio = 8
        self.ps = 5  # patch size
        self.num_iter = 5  # num iterations

        # define encoder
        self.encoder = Encoder(
            B=5,
            pretrained=True,
        )

        # define decoder
        self.output_dim = output_dim = 3
        self.feature_dim = feature_dim = 64
        self.hidden_dim = hidden_dim = 64
        self.decoder = Decoder(
            [output_dim, feature_dim, hidden_dim], B=5, NF=2048, BN=False
        )

        # ray direction-based ReLU
        self.ray_relu = RayReLU(eps=1e-2)

        # pixel_coords (1, 3, H, W)
        # NOTE: this is set to some arbitrarily high number;
        # if your input is 2000+ pixels wide/tall, increase these values
        h = 2000
        w = 2000
        pixel_coords = np.ones((3, h, w)).astype(np.float32)
        x_range = np.concatenate([np.arange(w).reshape(1, w)] * h, axis=0)
        y_range = np.concatenate([np.arange(h).reshape(h, 1)] * w, axis=1)
        pixel_coords[0, :, :] = x_range + 0.5
        pixel_coords[1, :, :] = y_range + 0.5
        self.pixel_coords = torch.from_numpy(pixel_coords).unsqueeze(0)

        # define ConvGRU cell
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=feature_dim + 2, ks=self.ps)

        # padding used during NRN
        self.pad = (self.ps - 1) // 2

        # prediction heads
        self.prob_head = get_prediction_head(
            self.hidden_dim + 2, 64, self.ps * self.ps
        )  # weights assigned for each nghbr pixel
        self.xy_head = get_prediction_head(
            self.hidden_dim + 2, 64, self.ps * self.ps * 2
        )  # rotation axis for each nghbr pixel
        self.angle_head = get_prediction_head(
            self.hidden_dim + 2, 64, self.ps * self.ps
        )  # rotation angle for each nghbr pixel

        # prediction heads - weights used for upsampling the coarse resolution output
        self.up_prob_head = get_prediction_head(
            self.hidden_dim + 2, 64, 9 * self.downsample_ratio * self.downsample_ratio
        )

    def get_ray(self, intrins, H, W, orig_H, orig_W, return_uv=False):
        B, _, _ = intrins.shape
        fu = intrins[:, 0, 0][:, None, None] * (W / orig_W)
        cu = intrins[:, 0, 2][:, None, None] * (W / orig_W)
        fv = intrins[:, 1, 1][:, None, None] * (H / orig_H)
        cv = intrins[:, 1, 2][:, None, None] * (H / orig_H)

        # (B, 2, H, W)
        ray = self.pixel_coords[:, :, :H, :W].repeat(B, 1, 1, 1)
        ray[:, 0, :, :] = (ray[:, 0, :, :] - cu) / fu
        ray[:, 1, :, :] = (ray[:, 1, :, :] - cv) / fv

        if return_uv:
            return ray[:, :2, :, :]
        else:
            return F.normalize(ray, dim=1)

    def upsample(self, h, pred_norm, uv_8):
        up_mask = self.up_prob_head(torch.cat([h, uv_8], dim=1))
        up_pred_norm = convex_upsampling(pred_norm, up_mask, self.downsample_ratio)
        up_pred_norm = F.normalize(up_pred_norm, dim=1)
        return up_pred_norm

    def refine(self, h, feat_map, pred_norm, intrins, orig_H, orig_W, uv_8, ray_8):
        B, C, H, W = pred_norm.shape
        fu = intrins[:, 0, 0][:, None, None, None] * (W / orig_W)  # (B, 1, 1, 1)
        cu = intrins[:, 0, 2][:, None, None, None] * (W / orig_W)
        fv = intrins[:, 1, 1][:, None, None, None] * (H / orig_H)
        cv = intrins[:, 1, 2][:, None, None, None] * (H / orig_H)

        h_new = self.gru(h, feat_map)

        # get nghbr prob (B, 1, ps*ps, h, w)
        nghbr_prob = self.prob_head(torch.cat([h_new, uv_8], dim=1)).unsqueeze(1)
        nghbr_prob = torch.sigmoid(nghbr_prob)

        # get nghbr normals (B, 3, ps*ps, h, w)
        nghbr_normals = get_unfold(pred_norm, ps=self.ps, pad=self.pad)

        # get nghbr xy (B, 2, ps*ps, h, w)
        nghbr_xys = self.xy_head(torch.cat([h_new, uv_8], dim=1))
        nghbr_xs, nghbr_ys = torch.split(
            nghbr_xys, [self.ps * self.ps, self.ps * self.ps], dim=1
        )
        nghbr_xys = torch.cat([nghbr_xs.unsqueeze(1), nghbr_ys.unsqueeze(1)], dim=1)
        nghbr_xys = F.normalize(nghbr_xys, dim=1)

        # get nghbr theta (B, 1, ps*ps, h, w)
        nghbr_angle = self.angle_head(torch.cat([h_new, uv_8], dim=1)).unsqueeze(1)
        nghbr_angle = torch.sigmoid(nghbr_angle) * np.pi

        # get nghbr pixel coord (1, 3, ps*ps, h, w)
        nghbr_pixel_coord = get_unfold(
            self.pixel_coords[:, :, :H, :W], ps=self.ps, pad=self.pad
        )

        # nghbr axes (B, 3, ps*ps, h, w)
        nghbr_axes = torch.zeros_like(nghbr_normals)

        du_over_fu = nghbr_xys[:, 0, ...] / fu  # (B, ps*ps, h, w)
        dv_over_fv = nghbr_xys[:, 1, ...] / fv  # (B, ps*ps, h, w)

        term_u = (
            nghbr_pixel_coord[:, 0, ...] + nghbr_xys[:, 0, ...] - cu
        ) / fu  # (B, ps*ps, h, w)
        term_v = (
            nghbr_pixel_coord[:, 1, ...] + nghbr_xys[:, 1, ...] - cv
        ) / fv  # (B, ps*ps, h, w)

        nx = nghbr_normals[:, 0, ...]  # (B, ps*ps, h, w)
        ny = nghbr_normals[:, 1, ...]  # (B, ps*ps, h, w)
        nz = nghbr_normals[:, 2, ...]  # (B, ps*ps, h, w)

        nghbr_delta_z_num = -(du_over_fu * nx + dv_over_fv * ny)
        nghbr_delta_z_denom = term_u * nx + term_v * ny + nz
        nghbr_delta_z_denom[torch.abs(nghbr_delta_z_denom) < 1e-8] = 1e-8 * torch.sign(
            nghbr_delta_z_denom[torch.abs(nghbr_delta_z_denom) < 1e-8]
        )
        nghbr_delta_z = nghbr_delta_z_num / nghbr_delta_z_denom

        nghbr_axes[:, 0, ...] = du_over_fu + nghbr_delta_z * term_u
        nghbr_axes[:, 1, ...] = dv_over_fv + nghbr_delta_z * term_v
        nghbr_axes[:, 2, ...] = nghbr_delta_z
        nghbr_axes = F.normalize(nghbr_axes, dim=1)  # (B, 3, ps*ps, h, w)

        # make sure axes are all valid
        invalid = (
            torch.sum(
                torch.logical_or(
                    torch.isnan(nghbr_axes), torch.isinf(nghbr_axes)
                ).float(),
                dim=1,
            )
            > 0.5
        )  # (B, ps*ps, h, w)
        nghbr_axes[:, 0, ...][invalid] = 0.0
        nghbr_axes[:, 1, ...][invalid] = 0.0
        nghbr_axes[:, 2, ...][invalid] = 0.0

        # nghbr_axes_angle (B, 3, ps*ps, h, w)
        nghbr_axes_angle = nghbr_axes * nghbr_angle
        nghbr_axes_angle = nghbr_axes_angle.permute(
            0, 2, 3, 4, 1
        )  # (B, ps*ps, h, w, 3)
        nghbr_R = axis_angle_to_matrix(nghbr_axes_angle)  # (B, ps*ps, h, w, 3, 3)

        # (B, 3, ps*ps, h, w)
        nghbr_normals_rot = (
            torch.bmm(
                nghbr_R.reshape(B * self.ps * self.ps * H * W, 3, 3),
                nghbr_normals.permute(0, 2, 3, 4, 1)
                .reshape(B * self.ps * self.ps * H * W, 3)
                .unsqueeze(-1),
            )
            .reshape(B, self.ps * self.ps, H, W, 3, 1)
            .squeeze(-1)
            .permute(0, 4, 1, 2, 3)
        )  # (B, 3, ps*ps, h, w)
        nghbr_normals_rot = F.normalize(nghbr_normals_rot, dim=1)

        # ray ReLU
        nghbr_normals_rot = torch.cat(
            [
                self.ray_relu(nghbr_normals_rot[:, :, i, :, :], ray_8).unsqueeze(2)
                for i in range(nghbr_normals_rot.size(2))
            ],
            dim=2,
        )

        # (B, 1, ps*ps, h, w) * (B, 3, ps*ps, h, w)
        pred_norm = torch.sum(nghbr_prob * nghbr_normals_rot, dim=2)  # (B, C, H, W)
        pred_norm = F.normalize(pred_norm, dim=1)

        up_mask = self.up_prob_head(torch.cat([h_new, uv_8], dim=1))
        up_pred_norm = convex_upsampling(pred_norm, up_mask, self.downsample_ratio)
        up_pred_norm = F.normalize(up_pred_norm, dim=1)

        return h_new, pred_norm, up_pred_norm

    def forward(self, img, intrins=None):
        # Step 1. encoder
        features = self.encoder(img)

        # Step 2. get uv encoding
        B, _, orig_H, orig_W = img.shape
        intrins[:, 0, 2] += 0.5
        intrins[:, 1, 2] += 0.5
        uv_32 = self.get_ray(
            intrins, orig_H // 32, orig_W // 32, orig_H, orig_W, return_uv=True
        )
        uv_16 = self.get_ray(
            intrins, orig_H // 16, orig_W // 16, orig_H, orig_W, return_uv=True
        )
        uv_8 = self.get_ray(
            intrins, orig_H // 8, orig_W // 8, orig_H, orig_W, return_uv=True
        )
        ray_8 = self.get_ray(intrins, orig_H // 8, orig_W // 8, orig_H, orig_W)

        # Step 3. decoder - initial prediction
        pred_norm, feat_map, h = self.decoder(features, uvs=(uv_32, uv_16, uv_8))
        pred_norm = self.ray_relu(pred_norm, ray_8)

        # Step 4. add ray direction encoding
        feat_map = torch.cat([feat_map, uv_8], dim=1)

        # iterative refinement
        up_pred_norm = self.upsample(h, pred_norm, uv_8)
        pred_list = [up_pred_norm]
        for i in range(self.num_iter):
            h, pred_norm, up_pred_norm = self.refine(
                h, feat_map, pred_norm.detach(), intrins, orig_H, orig_W, uv_8, ray_8
            )
            pred_list.append(up_pred_norm)
        return pred_list
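Note (illustration, not part of the commit): a minimal sketch of driving the DSINE module above. It assumes input height and width are multiples of 32 (the encoder's coarsest stride) and that checkpoint weights are loaded separately via annotator/dsine/utils.load_checkpoint; without that, the decoder heads are randomly initialized.

# Minimal sketch: one forward pass on a dummy 480x640 image.
import torch
from annotator.dsine.dsine import DSINE
from annotator.dsine.utils import get_intrins_from_fov

model = DSINE().eval()                    # pulls pretrained EfficientNet-B5 encoder weights
img = torch.rand(1, 3, 480, 640)          # (B, 3, H, W); H and W divisible by 32
intrins = get_intrins_from_fov(60.0, 480, 640, "cpu").unsqueeze(0)  # (B, 3, 3)
with torch.no_grad():
    pred_list = model(img, intrins=intrins)  # one (B, 3, H, W) normal map per refinement step
normals = pred_list[-1]                   # final unit normals, components in [-1, 1]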
annotator/dsine/rotation.py
ADDED
@@ -0,0 +1,85 @@
import torch
import numpy as np


# NOTE: from PyTorch3D
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as axis/angle to quaternions.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_angles = angles * 0.5
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    quaternions = torch.cat(
        [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
    )
    return quaternions


# NOTE: from PyTorch3D
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


# NOTE: from PyTorch3D
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as axis/angle to rotation matrices.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
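Sanity check (illustration, not part of the commit): with the axis-angle convention above, a rotation of pi/2 about the z-axis should carry the x-axis onto the y-axis.

# Illustration only: pi/2 anticlockwise about z maps x to y.
import math
import torch
from annotator.dsine.rotation import axis_angle_to_matrix

axis_angle = torch.tensor([0.0, 0.0, math.pi / 2])  # axis = z, angle = pi/2
R = axis_angle_to_matrix(axis_angle)                # (3, 3)
print(R @ torch.tensor([1.0, 0.0, 0.0]))            # ~ tensor([0., 1., 0.])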
annotator/dsine/submodules.py
ADDED
@@ -0,0 +1,237 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import geffnet


INPUT_CHANNELS_DICT = {
    0: [1280, 112, 40, 24, 16],
    1: [1280, 112, 40, 24, 16],
    2: [1408, 120, 48, 24, 16],
    3: [1536, 136, 48, 32, 24],
    4: [1792, 160, 56, 32, 24],
    5: [2048, 176, 64, 40, 24],
    6: [2304, 200, 72, 40, 32],
    7: [2560, 224, 80, 48, 32],
}


class Encoder(nn.Module):
    def __init__(self, B=5, pretrained=True):
        """e.g. B=5 will return EfficientNet-B5"""
        super(Encoder, self).__init__()
        basemodel_name = 'tf_efficientnet_b%s_ap' % B
        basemodel = geffnet.create_model(basemodel_name, pretrained=pretrained)
        # Remove last layer
        basemodel.global_pool = nn.Identity()
        basemodel.classifier = nn.Identity()
        self.original_model = basemodel

    def forward(self, x):
        features = [x]
        for k, v in self.original_model._modules.items():
            if k == "blocks":
                for ki, vi in v._modules.items():
                    features.append(vi(features[-1]))
            else:
                features.append(v(features[-1]))
        return features


class ConvGRU(nn.Module):
    def __init__(self, hidden_dim, input_dim, ks=3):
        super(ConvGRU, self).__init__()
        p = (ks - 1) // 2
        self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, ks, padding=p)
        self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, ks, padding=p)
        self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, ks, padding=p)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz(hx))
        r = torch.sigmoid(self.convr(hx))
        q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
        h = (1 - z) * h + z * q
        return h


class RayReLU(nn.Module):
    def __init__(self, eps=1e-2):
        super(RayReLU, self).__init__()
        self.eps = eps

    def forward(self, pred_norm, ray):
        # angle between the predicted normal and ray direction
        cos = torch.cosine_similarity(pred_norm, ray, dim=1).unsqueeze(
            1
        )  # (B, 1, H, W)

        # component of pred_norm along view
        norm_along_view = ray * cos

        # cos should be bigger than eps
        norm_along_view_relu = ray * (torch.relu(cos - self.eps) + self.eps)

        # difference
        diff = norm_along_view_relu - norm_along_view

        # updated pred_norm
        new_pred_norm = pred_norm + diff
        new_pred_norm = F.normalize(new_pred_norm, dim=1)

        return new_pred_norm


class UpSampleBN(nn.Module):
    def __init__(self, skip_input, output_features, align_corners=True):
        super(UpSampleBN, self).__init__()
        self._net = nn.Sequential(
            nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(output_features),
            nn.LeakyReLU(),
            nn.Conv2d(
                output_features, output_features, kernel_size=3, stride=1, padding=1
            ),
            nn.BatchNorm2d(output_features),
            nn.LeakyReLU(),
        )
        self.align_corners = align_corners

    def forward(self, x, concat_with):
        up_x = F.interpolate(
            x,
            size=[concat_with.size(2), concat_with.size(3)],
            mode="bilinear",
            align_corners=self.align_corners,
        )
        f = torch.cat([up_x, concat_with], dim=1)
        return self._net(f)


class Conv2d_WS(nn.Conv2d):
    """weight standardization"""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super(Conv2d_WS, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def forward(self, x):
        weight = self.weight
        weight_mean = (
            weight.mean(dim=1, keepdim=True)
            .mean(dim=2, keepdim=True)
            .mean(dim=3, keepdim=True)
        )
        weight = weight - weight_mean
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv2d(
            x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )


class UpSampleGN(nn.Module):
    """UpSample with GroupNorm"""

    def __init__(self, skip_input, output_features, align_corners=True):
        super(UpSampleGN, self).__init__()
        self._net = nn.Sequential(
            Conv2d_WS(skip_input, output_features, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(8, output_features),
            nn.LeakyReLU(),
            Conv2d_WS(
                output_features, output_features, kernel_size=3, stride=1, padding=1
            ),
            nn.GroupNorm(8, output_features),
            nn.LeakyReLU(),
        )
        self.align_corners = align_corners

    def forward(self, x, concat_with):
        up_x = F.interpolate(
            x,
            size=[concat_with.size(2), concat_with.size(3)],
            mode="bilinear",
            align_corners=self.align_corners,
        )
        f = torch.cat([up_x, concat_with], dim=1)
        return self._net(f)


def upsample_via_bilinear(out, up_mask, downsample_ratio):
    """bilinear upsampling (up_mask is a dummy variable)"""
    return F.interpolate(
        out, scale_factor=downsample_ratio, mode="bilinear", align_corners=True
    )


def upsample_via_mask(out, up_mask, downsample_ratio):
    """convex upsampling"""
    # out: low-resolution output (B, o_dim, H, W)
    # up_mask: (B, 9*k*k, H, W)
    k = downsample_ratio

    N, o_dim, H, W = out.shape
    up_mask = up_mask.view(N, 1, 9, k, k, H, W)
    up_mask = torch.softmax(up_mask, dim=2)  # (B, 1, 9, k, k, H, W)

    up_out = F.unfold(out, [3, 3], padding=1)  # (B, 2, H, W) -> (B, 2 X 3*3, H*W)
    up_out = up_out.view(N, o_dim, 9, 1, 1, H, W)  # (B, 2, 3*3, 1, 1, H, W)
    up_out = torch.sum(up_mask * up_out, dim=2)  # (B, 2, k, k, H, W)

    up_out = up_out.permute(0, 1, 4, 2, 5, 3)  # (B, 2, H, k, W, k)
    return up_out.reshape(N, o_dim, k * H, k * W)  # (B, 2, kH, kW)


def convex_upsampling(out, up_mask, k):
    # out: low-resolution output (B, C, H, W)
    # up_mask: (B, 9*k*k, H, W)
    B, C, H, W = out.shape
    up_mask = up_mask.view(B, 1, 9, k, k, H, W)
    up_mask = torch.softmax(up_mask, dim=2)  # (B, 1, 9, k, k, H, W)

    out = F.pad(out, pad=(1, 1, 1, 1), mode="replicate")
    up_out = F.unfold(out, [3, 3], padding=0)  # (B, C, H, W) -> (B, C X 3*3, H*W)
    up_out = up_out.view(B, C, 9, 1, 1, H, W)  # (B, C, 9, 1, 1, H, W)

    up_out = torch.sum(up_mask * up_out, dim=2)  # (B, C, k, k, H, W)
    up_out = up_out.permute(0, 1, 4, 2, 5, 3)  # (B, C, H, k, W, k)
    return up_out.reshape(B, C, k * H, k * W)  # (B, C, kH, kW)


def get_unfold(pred_norm, ps, pad):
    B, C, H, W = pred_norm.shape
    pred_norm = F.pad(
        pred_norm, pad=(pad, pad, pad, pad), mode="replicate"
    )  # (B, C, h, w)
    pred_norm_unfold = F.unfold(pred_norm, [ps, ps], padding=0)  # (B, C X ps*ps, h*w)
    pred_norm_unfold = pred_norm_unfold.view(B, C, ps * ps, H, W)  # (B, C, ps*ps, h, w)
    return pred_norm_unfold


def get_prediction_head(input_dim, hidden_dim, output_dim):
    return nn.Sequential(
        nn.Conv2d(input_dim, hidden_dim, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(hidden_dim, hidden_dim, 1),
        nn.ReLU(inplace=True),
        nn.Conv2d(hidden_dim, output_dim, 1),
    )
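Shape contract (illustration, not part of the commit): convex_upsampling blends each fine-resolution pixel from the 3x3 coarse neighbourhood around it using softmax weights, turning a (B, C, H, W) map into (B, C, k*H, k*W). With the model's downsample_ratio of 8:

# Illustration only: (B, C, H, W) -> (B, C, 8H, 8W).
import torch
from annotator.dsine.submodules import convex_upsampling

k = 8                                       # downsample_ratio used by DSINE
out = torch.rand(2, 3, 60, 80)              # coarse normals
up_mask = torch.rand(2, 9 * k * k, 60, 80)  # unnormalized weights; softmaxed over the 9 neighbours
up = convex_upsampling(out, up_mask, k)
print(up.shape)                             # torch.Size([2, 3, 480, 640])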
annotator/dsine/utils.py
ADDED
@@ -0,0 +1,104 @@
""" utils
"""

import os
import torch
import numpy as np


def load_checkpoint(fpath, model):
    print("loading checkpoint... {}".format(fpath))

    ckpt = torch.load(fpath, map_location="cpu")["model"]

    load_dict = {}
    for k, v in ckpt.items():
        if k.startswith("module."):
            k_ = k.replace("module.", "")
            load_dict[k_] = v
        else:
            load_dict[k] = v

    model.load_state_dict(load_dict)
    print("loading checkpoint... / done")
    return model


def compute_normal_error(pred_norm, gt_norm):
    pred_error = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
    pred_error = torch.clamp(pred_error, min=-1.0, max=1.0)
    pred_error = torch.acos(pred_error) * 180.0 / np.pi
    pred_error = pred_error.unsqueeze(1)  # (B, 1, H, W)
    return pred_error


def compute_normal_metrics(total_normal_errors):
    total_normal_errors = total_normal_errors.detach().cpu().numpy()
    num_pixels = total_normal_errors.shape[0]

    metrics = {
        "mean": np.average(total_normal_errors),
        "median": np.median(total_normal_errors),
        "rmse": np.sqrt(np.sum(total_normal_errors * total_normal_errors) / num_pixels),
        "a1": 100.0 * (np.sum(total_normal_errors < 5) / num_pixels),
        "a2": 100.0 * (np.sum(total_normal_errors < 7.5) / num_pixels),
        "a3": 100.0 * (np.sum(total_normal_errors < 11.25) / num_pixels),
        "a4": 100.0 * (np.sum(total_normal_errors < 22.5) / num_pixels),
        "a5": 100.0 * (np.sum(total_normal_errors < 30) / num_pixels),
    }

    return metrics


def pad_input(orig_H, orig_W):
    if orig_W % 32 == 0:
        l = 0
        r = 0
    else:
        new_W = 32 * ((orig_W // 32) + 1)
        l = (new_W - orig_W) // 2
        r = (new_W - orig_W) - l

    if orig_H % 32 == 0:
        t = 0
        b = 0
    else:
        new_H = 32 * ((orig_H // 32) + 1)
        t = (new_H - orig_H) // 2
        b = (new_H - orig_H) - t
    return l, r, t, b


def get_intrins_from_fov(new_fov, H, W, device):
    # NOTE: top-left pixel should be (0,0)
    if W >= H:
        new_fu = (W / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
        new_fv = (W / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
    else:
        new_fu = (H / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
        new_fv = (H / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))

    new_cu = (W / 2.0) - 0.5
    new_cv = (H / 2.0) - 0.5

    new_intrins = torch.tensor(
        [[new_fu, 0, new_cu], [0, new_fv, new_cv], [0, 0, 1]],
        dtype=torch.float32,
        device=device,
    )

    return new_intrins


def get_intrins_from_txt(intrins_path, device):
    # NOTE: top-left pixel should be (0,0)
    with open(intrins_path, "r") as f:
        intrins_ = f.readlines()[0].split()[0].split(",")
        intrins_ = [float(i) for i in intrins_]
        fx, fy, cx, cy = intrins_

    intrins = torch.tensor(
        [[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=torch.float32, device=device
    )

    return intrins
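Worked example (illustration, not part of the commit): pad_input returns left/right/top/bottom padding that lifts each dimension to the next multiple of 32, splitting the slack as evenly as possible.

# Illustration only: a 481x641 input is padded to 512x672.
from annotator.dsine.utils import pad_input

l, r, t, b = pad_input(481, 641)  # (orig_H, orig_W)
print(l, r, t, b)                 # 15 16 15 16 -> width 641+31=672, height 481+31=512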
annotator/dsine_local.py
ADDED
@@ -0,0 +1,63 @@
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
from .dsine.dsine import DSINE
from .dsine import utils as dsine_utils


class NormalDetector:
    def __init__(self, model_path):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DSINE()
        self.model = dsine_utils.load_checkpoint(model_path, self.model)
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.fov = 60

    @torch.no_grad()
    def __call__(self, image):
        self.model.to(self.device)
        self.model.pixel_coords = self.model.pixel_coords.to(self.device)

        img = np.array(image).astype(np.float32) / 255.0
        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(self.device)
        _, _, orig_H, orig_W = img.shape
        l, r, t, b = dsine_utils.pad_input(orig_H, orig_W)
        img = F.pad(img, (l, r, t, b), mode="constant", value=0.0)
        img = self.normalize(img)
        intrinsics = dsine_utils.get_intrins_from_fov(
            new_fov=self.fov, H=orig_H, W=orig_W, device=self.device
        ).unsqueeze(0)

        intrinsics[:, 0, 2] += l
        intrinsics[:, 1, 2] += t

        pred_norm = self.model(img, intrins=intrinsics)[-1]
        pred_norm = pred_norm[:, :, t : t + orig_H, l : l + orig_W]
        pred_norm_np = (
            pred_norm.cpu().detach().numpy()[0, :, :, :].transpose(1, 2, 0)
        )  # (H, W, 3)
        pred_norm_np = ((pred_norm_np + 1.0) / 2.0 * 255.0).astype(np.uint8)
        normal_img = Image.fromarray(pred_norm_np).resize((orig_W, orig_H))

        self.model.to("cpu")
        self.model.pixel_coords = self.model.pixel_coords.to("cpu")
        return normal_img


if __name__ == "__main__":
    from diffusers.utils import load_image

    image = load_image(
        "https://qhstaticssl.kujiale.com/image/jpeg/1716177580588/9AAA49344B9CE33512C4EBD0A287495F.jpg"
    )
    image = np.asarray(image)
    # NOTE: __init__ only accepts model_path. The committed snippet also passed
    # efficientnet_path="/juicefs/training/models/open_source/dsine/tf_efficientnet_b5_ap-9e82fae8.pth",
    # which __init__ does not accept and which would raise a TypeError, so it is dropped here.
    normal_detector = NormalDetector(
        model_path="/juicefs/training/models/open_source/dsine/dsine.pt",
    )
    normal_image = normal_detector(image)
    normal_image.save("normal_image.jpg")
app.py
CHANGED
@@ -7,10 +7,11 @@ from diffusers import (
     UniPCMultistepScheduler,
 )
 import gradio as gr
+from huggingface_hub import hf_hub_download

 from annotator.util import resize_image, HWC3
 from annotator.midas import DepthDetector
-from annotator.
+from annotator.dsine_local import NormalDetector
 from annotator.upernet import SegmDetector

 controlnet_checkpoint = "kujiale-ai/controlnet"

@@ -26,7 +27,9 @@ pipe = StableDiffusionControlNetPipeline.from_pretrained(
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

 apply_depth = DepthDetector()
-apply_normal = NormalDetector(
+apply_normal = NormalDetector(
+    hf_hub_download("camenduru/DSINE", filename="dsine.pt")
+)
 apply_segm = SegmDetector()
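Note (illustration, not part of the commit): the new app.py lines replace a hard-coded local checkpoint path with hf_hub_download, which downloads the file once and then resolves it from the local Hugging Face cache, so NormalDetector still receives an ordinary filesystem path.

# Illustration only: the pattern adopted by the app.py hunk above.
from huggingface_hub import hf_hub_download
from annotator.dsine_local import NormalDetector

ckpt_path = hf_hub_download("camenduru/DSINE", filename="dsine.pt")  # cached local path
apply_normal = NormalDetector(ckpt_path)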