# -*- coding: utf-8 -*-
# @Author : Lintao Peng
# @File : SGFMT.py
# Design based on the ViT (Vision Transformer)
import torch.nn as nn

from net.IntmdSequential import IntermediateSequential


# Implements the self-attention mechanism; serves as the bottleneck layer of the U-Net.
class SelfAttention(nn.Module):
    def __init__(
        self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0
    ):
        super().__init__()
        self.num_heads = heads
        head_dim = dim // heads
        # Attention scale factor: 1/sqrt(head_dim) unless overridden by qk_scale.
        self.scale = qk_scale or head_dim ** -0.5

        # A single linear projection produces queries, keys and values in one pass.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(dropout_rate)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout_rate)

    def forward(self, x):
        B, N, C = x.shape  # (batch, tokens, embedding dim)
        # Project to q/k/v and split into heads:
        # (B, N, 3, num_heads, C // num_heads) -> (3, B, num_heads, N, C // num_heads)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        # Scaled dot-product attention.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back to (B, N, C) and apply the output projection.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# Residual (skip) connection around an arbitrary sub-module.
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


# LayerNorm applied before the wrapped module (pre-norm Transformer layout).
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x))


# Pre-norm wrapper that also applies dropout to the wrapped module's output.
class PreNormDrop(nn.Module):
    def __init__(self, dim, dropout_rate, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fn = fn

    def forward(self, x):
        return self.dropout(self.fn(self.norm(x)))


# Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout_rate):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(p=dropout_rate),
        )

    def forward(self, x):
        return self.net(x)


# Stack of `depth` Transformer encoder blocks: each block is a pre-norm
# self-attention layer followed by a pre-norm feed-forward layer, both
# wrapped in residual connections.
class TransformerModel(nn.Module):
    def __init__(
        self,
        dim,  # 512
        depth,  # 4
        heads,  # 8
        mlp_dim,  # 4096
        dropout_rate=0.1,
        attn_dropout_rate=0.1,
    ):
        super().__init__()
        layers = []
        for _ in range(depth):
            layers.extend(
                [
                    Residual(
                        PreNormDrop(
                            dim,
                            dropout_rate,
                            SelfAttention(dim, heads=heads, dropout_rate=attn_dropout_rate),
                        )
                    ),
                    Residual(
                        PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))
                    ),
                ]
            )
            # dim = dim / 2
        # IntermediateSequential chains the blocks (and, judging by its name,
        # can also expose intermediate outputs; see net/IntmdSequential.py).
        self.net = IntermediateSequential(*layers)

    def forward(self, x):
        return self.net(x)
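

# Minimal usage sketch: builds the Transformer with the dimensions noted in the
# constructor comments (dim=512, depth=4, heads=8, mlp_dim=4096) and runs a
# random token sequence through it. The exact return value of forward() depends
# on IntermediateSequential (it may also return intermediate layer outputs), so
# only the type of the result is printed here.
if __name__ == "__main__":
    import torch

    model = TransformerModel(dim=512, depth=4, heads=8, mlp_dim=4096)
    tokens = torch.randn(2, 196, 512)  # (batch, number of tokens, embedding dim)
    out = model(tokens)
    print(type(out))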