Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class FPA(nn.Module):
    """Feature Pyramid Attention (FPA) module.

    Three branches are computed from the same input:

    * master branch: 1x1 conv + BN at the input resolution;
    * pyramid branch: a U-shaped stack of 7x7 / 5x5 / 3x3 convolutions
      (stride 2 on the way down, bilinear upsampling on the way up) whose
      output multiplies the master branch element-wise;
    * global pooling branch: global average pooling + 1x1 conv, added to the
      result and broadcast over the spatial dimensions.
    """

    def __init__(self, channels=2048):
        """
        :param channels: Number of input channels. The output has the same
            number of channels; the pyramid branch uses ``channels // 4``.
        :type channels: int
        """
        super(FPA, self).__init__()
        channels_mid = channels // 4
        self.channels_cond = channels

        # Master branch
        self.conv_master = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)
        self.bn_master = nn.BatchNorm2d(channels)

        # Global pooling branch. No BatchNorm: the pooled map is 1x1, and
        # training-mode BatchNorm fails on a 1x1 map with a batch of one.
        self.conv_gpb = nn.Conv2d(self.channels_cond, channels, kernel_size=1, bias=False)

        # Pyramid branch, downsampling convs: 7x7 -> 5x5 -> 3x3, stride 2 each,
        # so the coarsest level is ceil(h / 8) x ceil(w / 8).
        self.conv7x7_1 = nn.Conv2d(self.channels_cond, channels_mid, kernel_size=(7, 7), stride=2, padding=3, bias=False)
        self.bn1_1 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), stride=2, padding=2, bias=False)
        self.bn2_1 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_1 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), stride=2, padding=1, bias=False)
        self.bn3_1 = nn.BatchNorm2d(channels_mid)

        # Same-resolution refinement convs, one per pyramid level.
        self.conv7x7_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(7, 7), stride=1, padding=3, bias=False)
        self.bn1_2 = nn.BatchNorm2d(channels_mid)
        self.conv5x5_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(5, 5), stride=1, padding=2, bias=False)
        self.bn2_2 = nn.BatchNorm2d(channels_mid)
        self.conv3x3_2 = nn.Conv2d(channels_mid, channels_mid, kernel_size=(3, 3), stride=1, padding=1, bias=False)
        self.bn3_2 = nn.BatchNorm2d(channels_mid)

        # Projects the merged pyramid from channels_mid back to `channels`
        # so it can multiply the master branch.
        self.bn_upsample_1 = nn.BatchNorm2d(channels)
        self.conv1x1_up1 = nn.Conv2d(channels_mid, channels, kernel_size=(1, 1), stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        :param x: Input feature maps. Shape: [b, channels, h, w]
        :return: out: Feature maps. Shape: [b, channels, h, w]
        """
        # Master branch
        x_master = self.bn_master(self.conv_master(x))

        # Global pooling branch: [b, channels, 1, 1]; broadcast in the final sum.
        x_gpb = self.conv_gpb(F.adaptive_avg_pool2d(x, 1))

        # Pyramid, downsampling path (x?_1) plus per-level refinement (x?_2).
        x1_1 = self.relu(self.bn1_1(self.conv7x7_1(x)))
        x1_2 = self.bn1_2(self.conv7x7_2(x1_1))
        x2_1 = self.relu(self.bn2_1(self.conv5x5_1(x1_1)))
        x2_2 = self.bn2_2(self.conv5x5_2(x2_1))
        x3_1 = self.relu(self.bn3_1(self.conv3x3_1(x2_1)))
        x3_2 = self.bn3_2(self.conv3x3_2(x3_1))

        # Pyramid, upsampling path: merge coarse-to-fine, resizing to the
        # exact size of the finer level (handles odd spatial sizes).
        x3_upsample = F.interpolate(x3_2, size=x2_2.shape[-2:],
                                    mode='bilinear', align_corners=False)
        x2_merge = self.relu(x2_2 + x3_upsample)
        x2_upsample = F.interpolate(x2_merge, size=x1_2.shape[-2:],
                                    mode='bilinear', align_corners=False)
        x1_merge = self.relu(x1_2 + x2_upsample)
        x1_merge_upsample = F.interpolate(x1_merge, size=x_master.shape[-2:],
                                          mode='bilinear', align_corners=False)
        attention = self.relu(self.bn_upsample_1(self.conv1x1_up1(x1_merge_upsample)))

        # Re-weight the master branch, then add the global context.
        x_master = x_master * attention
        out = self.relu(x_master + x_gpb)
        return out
class GAU(nn.Module):
    """Global Attention Upsample (GAU) module.

    Globally pooled high-level features act as channel-wise weights for the
    low-level features. The high-level features, projected to the low-level
    channel count (and upsampled 2x when ``upsample`` is true), are then added.
    """

    def __init__(self, channels_high, channels_low, upsample=True):
        """
        :param channels_high: Channels of the high-level input.
        :param channels_low: Channels of the low-level input; also the output channels.
        :param upsample: If True, project the high-level features with a
            stride-2 transposed conv (doubling h and w); otherwise with a 1x1 conv.
        """
        super(GAU, self).__init__()
        self.upsample = upsample
        self.conv3x3 = nn.Conv2d(channels_low, channels_low, kernel_size=3, padding=1, bias=False)
        self.bn_low = nn.BatchNorm2d(channels_low)
        # Projects the pooled global context; no BN since its spatial size is 1x1.
        self.conv1x1 = nn.Conv2d(channels_high, channels_low, kernel_size=1, padding=0, bias=False)
        if upsample:
            # kernel 4, stride 2, padding 1 -> output is exactly 2h x 2w.
            self.conv_upsample = nn.ConvTranspose2d(channels_high, channels_low, kernel_size=4, stride=2, padding=1, bias=False)
            self.bn_upsample = nn.BatchNorm2d(channels_low)
        else:
            self.conv_reduction = nn.Conv2d(channels_high, channels_low, kernel_size=1, padding=0, bias=False)
            self.bn_reduction = nn.BatchNorm2d(channels_low)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, fms_high, fms_low, fm_mask=None):
        """
        Use the high-level features, which carry abundant category information,
        to weight the low-level features, which carry pixel localization information.

        :param fms_high: Features of high level. Tensor [b, channels_high, h, w].
        :param fms_low: Features of low level. Tensor [b, channels_low, h_low, w_low].
        :param fm_mask: Unused; kept for interface compatibility.
        :return: Tensor [b, channels_low, h_low, w_low].
        """
        # Global context: [b, channels_high, 1, 1] -> [b, channels_low, 1, 1].
        # No BatchNorm here: BN cannot be used when the spatial size is 1x1.
        fms_high_gp = self.relu(self.conv1x1(F.adaptive_avg_pool2d(fms_high, 1)))

        # Low-level features, re-weighted channel-wise by the global context.
        fms_low_mask = self.bn_low(self.conv3x3(fms_low))
        fms_att = fms_low_mask * fms_high_gp

        if self.upsample:
            fms_high_proj = self.bn_upsample(self.conv_upsample(fms_high))
        else:
            fms_high_proj = self.bn_reduction(self.conv_reduction(fms_high))

        # The projection is exactly 2x (or 1x) the high-level size, which does not
        # match the low-level map when the sizes are not an exact multiple (e.g.
        # odd input resolutions). Resize instead of failing on the addition.
        if fms_high_proj.shape[-2:] != fms_low.shape[-2:]:
            fms_high_proj = F.interpolate(fms_high_proj, size=fms_low.shape[-2:],
                                          mode='bilinear', align_corners=False)

        out = self.relu(fms_high_proj + fms_att)
        return out
class PAN(nn.Module):
    """Pyramid Attention Network head: an FPA on the deepest feature map,
    followed by a chain of GAU modules toward the shallower feature maps."""

    def __init__(self, channels_blocks=(2048, 1024, 512, 256)):
        """
        :param channels_blocks: Channel counts of the feature maps, ordered
            from the deepest (fed to FPA) to the shallowest. One GAU is built
            per consecutive pair.
        """
        super(PAN, self).__init__()
        channels_blocks = tuple(channels_blocks)
        self.fpa = FPA(channels=channels_blocks[0])
        gau_blocks = []
        # Register as gau_block1, gau_block2, ... so state_dict keys match
        # the previously hard-coded attributes.
        pairs = zip(channels_blocks[:-1], channels_blocks[1:])
        for idx, (ch_high, ch_low) in enumerate(pairs, start=1):
            block = GAU(ch_high, ch_low)
            setattr(self, 'gau_block%d' % idx, block)
            gau_blocks.append(block)
        # Plain list rather than nn.ModuleList: the blocks are already registered
        # above, and a ModuleList would add duplicate 'gau.N.*' state_dict keys.
        self.gau = gau_blocks

    def forward(self, fms):
        """
        :param fms: Sequence of feature maps [b, c, h, w] ordered from shallowest
            to deepest; the last one must have ``channels_blocks[0]`` channels,
            the one before it ``channels_blocks[1]``, and so on.
        :return: Tuple of refined feature maps in the same order as ``fms``; the
            last entry is the FPA output and each earlier entry is a GAU output.
        """
        feats = []
        fm_high = None
        # Walk from the deepest map to the shallowest, feeding each output forward.
        for level, fm_low in enumerate(fms[::-1]):
            if level == 0:
                fm_high = self.fpa(fm_low)
            else:
                fm_high = self.gau[level - 1](fm_high, fm_low)
            feats.append(fm_high)
        feats.reverse()
        return tuple(feats)