| """ | |
| FBNet model builder | |
| """ | |
| from __future__ import absolute_import, division, print_function, unicode_literals | |
| import copy | |
| import logging | |
| import math | |
| from collections import OrderedDict | |
| import torch | |
| import torch.nn as nn | |
| from maskrcnn_benchmark.layers import ( | |
| BatchNorm2d, | |
| Conv2d, | |
| FrozenBatchNorm2d, | |
| interpolate, | |
| ) | |
| from maskrcnn_benchmark.layers.misc import _NewEmptyTensorOp | |
| logger = logging.getLogger(__name__) | |


def _py2_round(x):
    # Round half away from zero, matching Python 2's round() semantics.
    return math.floor(x + 0.5) if x >= 0.0 else math.ceil(x - 0.5)


def _get_divisible_by(num, divisible_by, min_val):
    ret = int(num)
    if divisible_by > 0 and num % divisible_by != 0:
        ret = int((_py2_round(num / divisible_by) or min_val) * divisible_by)
    return ret
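
# A minimal sketch of the rounding behavior above (values are illustrative
# only): counts are snapped to the nearest multiple of `divisible_by`, with
# `min_val` substituted as the multiplier when rounding would reach zero.
#   _get_divisible_by(67, 8, 8)  # -> 64  (67 / 8 rounds down to 8)
#   _get_divisible_by(68, 8, 8)  # -> 72  (68 / 8 rounds up to 9)
#   _get_divisible_by(64, 8, 8)  # -> 64  (already a multiple of 8)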


PRIMITIVES = {
    "skip": lambda C_in, C_out, expansion, stride, **kwargs: Identity(
        C_in, C_out, stride
    ),
    "ir_k3": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, expansion, stride, **kwargs
    ),
    "ir_k5": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, expansion, stride, kernel=5, **kwargs
    ),
    "ir_k7": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, expansion, stride, kernel=7, **kwargs
    ),
    "ir_k1": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, expansion, stride, kernel=1, **kwargs
    ),
    "shuffle": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, expansion, stride, shuffle_type="mid", pw_group=4, **kwargs
    ),
    "basic_block": lambda C_in, C_out, expansion, stride, **kwargs: CascadeConv3x3(
        C_in, C_out, stride
    ),
    "shift_5x5": lambda C_in, C_out, expansion, stride, **kwargs: ShiftBlock5x5(
        C_in, C_out, expansion, stride
    ),
    # layer search 2
    "ir_k3_e1": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 1, stride, kernel=3, **kwargs
    ),
    "ir_k3_e3": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 3, stride, kernel=3, **kwargs
    ),
    "ir_k3_e6": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 6, stride, kernel=3, **kwargs
    ),
    "ir_k3_s4": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 4, stride, kernel=3, shuffle_type="mid", pw_group=4, **kwargs
    ),
    "ir_k5_e1": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 1, stride, kernel=5, **kwargs
    ),
    "ir_k5_e3": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 3, stride, kernel=5, **kwargs
    ),
    "ir_k5_e6": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 6, stride, kernel=5, **kwargs
    ),
    "ir_k5_s4": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 4, stride, kernel=5, shuffle_type="mid", pw_group=4, **kwargs
    ),
    # layer search se
    "ir_k3_e1_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 1, stride, kernel=3, se=True, **kwargs
    ),
    "ir_k3_e3_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 3, stride, kernel=3, se=True, **kwargs
    ),
    "ir_k3_e6_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 6, stride, kernel=3, se=True, **kwargs
    ),
    "ir_k3_s4_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in,
        C_out,
        4,
        stride,
        kernel=3,
        shuffle_type="mid",
        pw_group=4,
        se=True,
        **kwargs
    ),
    "ir_k5_e1_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 1, stride, kernel=5, se=True, **kwargs
    ),
    "ir_k5_e3_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 3, stride, kernel=5, se=True, **kwargs
    ),
    "ir_k5_e6_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 6, stride, kernel=5, se=True, **kwargs
    ),
    "ir_k5_s4_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in,
        C_out,
        4,
        stride,
        kernel=5,
        shuffle_type="mid",
        pw_group=4,
        se=True,
        **kwargs
    ),
    # layer search 3 (in addition to layer search 2)
    "ir_k3_s2": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 1, stride, kernel=3, shuffle_type="mid", pw_group=2, **kwargs
    ),
    "ir_k5_s2": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 1, stride, kernel=5, shuffle_type="mid", pw_group=2, **kwargs
    ),
    "ir_k3_s2_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in,
        C_out,
        1,
        stride,
        kernel=3,
        shuffle_type="mid",
        pw_group=2,
        se=True,
        **kwargs
    ),
    "ir_k5_s2_se": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in,
        C_out,
        1,
        stride,
        kernel=5,
        shuffle_type="mid",
        pw_group=2,
        se=True,
        **kwargs
    ),
    # layer search 4 (in addition to layer search 3)
    "ir_k3_sep": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, expansion, stride, kernel=3, cdw=True, **kwargs
    ),
    "ir_k33_e1": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 1, stride, kernel=3, cdw=True, **kwargs
    ),
    "ir_k33_e3": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 3, stride, kernel=3, cdw=True, **kwargs
    ),
    "ir_k33_e6": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 6, stride, kernel=3, cdw=True, **kwargs
    ),
    # layer search 5 (in addition to layer search 4)
    "ir_k7_e1": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 1, stride, kernel=7, **kwargs
    ),
    "ir_k7_e3": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 3, stride, kernel=7, **kwargs
    ),
    "ir_k7_e6": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 6, stride, kernel=7, **kwargs
    ),
    "ir_k7_sep": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, expansion, stride, kernel=7, cdw=True, **kwargs
    ),
    "ir_k7_sep_e1": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 1, stride, kernel=7, cdw=True, **kwargs
    ),
    "ir_k7_sep_e3": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 3, stride, kernel=7, cdw=True, **kwargs
    ),
    "ir_k7_sep_e6": lambda C_in, C_out, expansion, stride, **kwargs: IRFBlock(
        C_in, C_out, 6, stride, kernel=7, cdw=True, **kwargs
    ),
}
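
# A minimal usage sketch (illustrative, not part of the original API surface):
# every PRIMITIVES entry shares the signature
# (C_in, C_out, expansion, stride, **kwargs), so a searched op name can be
# instantiated uniformly. Note that the "_e*" variants hard-code the expansion
# ratio and ignore the `expansion` argument.
#   block = PRIMITIVES["ir_k3_e6"](32, 64, expansion=6, stride=2, bn_type="bn")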


class Identity(nn.Module):
    def __init__(self, C_in, C_out, stride):
        super(Identity, self).__init__()
        self.conv = (
            ConvBNRelu(
                C_in,
                C_out,
                kernel=1,
                stride=stride,
                pad=0,
                no_bias=1,
                use_relu="relu",
                bn_type="bn",
            )
            if C_in != C_out or stride != 1
            else None
        )
        # Recorded so FBNetBuilder._add_ir_block can read the block's output depth.
        self.output_depth = C_out

    def forward(self, x):
        if self.conv is not None:
            out = self.conv(x)
        else:
            out = x
        return out


class CascadeConv3x3(nn.Sequential):
    def __init__(self, C_in, C_out, stride):
        assert stride in [1, 2]
        ops = [
            Conv2d(C_in, C_in, 3, stride, 1, bias=False),
            BatchNorm2d(C_in),
            nn.ReLU(inplace=True),
            Conv2d(C_in, C_out, 3, 1, 1, bias=False),
            BatchNorm2d(C_out),
        ]
        super(CascadeConv3x3, self).__init__(*ops)
        self.res_connect = (stride == 1) and (C_in == C_out)
        # Recorded so FBNetBuilder._add_ir_block can read the block's output depth.
        self.output_depth = C_out

    def forward(self, x):
        y = super(CascadeConv3x3, self).forward(x)
        if self.res_connect:
            y += x
        return y


class Shift(nn.Module):
    def __init__(self, C, kernel_size, stride, padding):
        super(Shift, self).__init__()
        self.C = C
        kernel = torch.zeros((C, 1, kernel_size, kernel_size), dtype=torch.float32)
        ch_idx = 0

        assert stride in [1, 2]
        self.stride = stride
        self.padding = padding
        self.kernel_size = kernel_size
        self.dilation = 1

        hks = kernel_size // 2
        ksq = kernel_size ** 2

        # Build a fixed one-hot depthwise kernel: each group of channels gets
        # shifted to a different spatial offset, and the center tap absorbs
        # the remainder channels.
        for i in range(kernel_size):
            for j in range(kernel_size):
                if i == hks and j == hks:
                    num_ch = C // ksq + C % ksq
                else:
                    num_ch = C // ksq
                kernel[ch_idx : ch_idx + num_ch, 0, i, j] = 1
                ch_idx += num_ch

        self.register_parameter("bias", None)
        self.kernel = nn.Parameter(kernel, requires_grad=False)

    def forward(self, x):
        if x.numel() > 0:
            return nn.functional.conv2d(
                x,
                self.kernel,
                self.bias,
                (self.stride, self.stride),
                (self.padding, self.padding),
                self.dilation,
                self.C,  # groups
            )

        # Empty-input path: compute the output shape analytically. Note the
        # width dimension must use the padding, not the dilation.
        output_shape = [
            (i + 2 * p - (di * (k - 1) + 1)) // d + 1
            for i, p, di, k, d in zip(
                x.shape[-2:],
                (self.padding, self.padding),
                (self.dilation, self.dilation),
                (self.kernel_size, self.kernel_size),
                (self.stride, self.stride),
            )
        ]
        output_shape = [x.shape[0], self.C] + output_shape
        return _NewEmptyTensorOp.apply(x, output_shape)


class ShiftBlock5x5(nn.Sequential):
    def __init__(self, C_in, C_out, expansion, stride):
        assert stride in [1, 2]
        self.res_connect = (stride == 1) and (C_in == C_out)

        C_mid = _get_divisible_by(C_in * expansion, 8, 8)

        ops = [
            # pw
            Conv2d(C_in, C_mid, 1, 1, 0, bias=False),
            BatchNorm2d(C_mid),
            nn.ReLU(inplace=True),
            # shift
            Shift(C_mid, 5, stride, 2),
            # pw-linear
            Conv2d(C_mid, C_out, 1, 1, 0, bias=False),
            BatchNorm2d(C_out),
        ]
        super(ShiftBlock5x5, self).__init__(*ops)
        # Recorded so FBNetBuilder._add_ir_block can read the block's output depth.
        self.output_depth = C_out

    def forward(self, x):
        y = super(ShiftBlock5x5, self).forward(x)
        if self.res_connect:
            y += x
        return y


class ChannelShuffle(nn.Module):
    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]"""
        N, C, H, W = x.size()
        g = self.groups
        assert C % g == 0, "Incompatible group size {} for input channel {}".format(
            g, C
        )
        return (
            x.view(N, g, C // g, H, W)
            .permute(0, 2, 1, 3, 4)
            .contiguous()
            .view(N, C, H, W)
        )
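
# A small illustration of the shuffle (values are hypothetical): with C=4
# channels and groups=2, the group-major order [0, 1, 2, 3] is interleaved to
# [0, 2, 1, 3], so information mixes across the pointwise-conv groups.
#   x = torch.arange(4.0).reshape(1, 4, 1, 1)
#   ChannelShuffle(2)(x).flatten()  # -> tensor([0., 2., 1., 3.])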


class ConvBNRelu(nn.Sequential):
    def __init__(
        self,
        input_depth,
        output_depth,
        kernel,
        stride,
        pad,
        no_bias,
        use_relu,
        bn_type,
        group=1,
        *args,
        **kwargs
    ):
        super(ConvBNRelu, self).__init__()

        assert use_relu in ["relu", None]
        if isinstance(bn_type, (list, tuple)):
            assert len(bn_type) == 2
            assert bn_type[0] == "gn"
            gn_group = bn_type[1]
            bn_type = bn_type[0]
        assert bn_type in ["bn", "af", "gn", None]
        assert stride in [1, 2, 4]

        op = Conv2d(
            input_depth,
            output_depth,
            kernel_size=kernel,
            stride=stride,
            padding=pad,
            bias=not no_bias,
            groups=group,
            *args,
            **kwargs
        )
        nn.init.kaiming_normal_(op.weight, mode="fan_out", nonlinearity="relu")
        if op.bias is not None:
            nn.init.constant_(op.bias, 0.0)
        self.add_module("conv", op)

        if bn_type == "bn":
            bn_op = BatchNorm2d(output_depth)
        elif bn_type == "gn":
            bn_op = nn.GroupNorm(num_groups=gn_group, num_channels=output_depth)
        elif bn_type == "af":
            bn_op = FrozenBatchNorm2d(output_depth)
        if bn_type is not None:
            self.add_module("bn", bn_op)

        if use_relu == "relu":
            self.add_module("relu", nn.ReLU(inplace=True))


class SEModule(nn.Module):
    reduction = 4

    def __init__(self, C):
        super(SEModule, self).__init__()
        mid = max(C // self.reduction, 8)
        conv1 = Conv2d(C, mid, 1, 1, 0)
        conv2 = Conv2d(mid, C, 1, 1, 0)
        self.op = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), conv1, nn.ReLU(inplace=True), conv2, nn.Sigmoid()
        )

    def forward(self, x):
        return x * self.op(x)


class Upsample(nn.Module):
    def __init__(self, scale_factor, mode, align_corners=None):
        super(Upsample, self).__init__()
        self.scale = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        return interpolate(
            x, scale_factor=self.scale, mode=self.mode,
            align_corners=self.align_corners
        )


def _get_upsample_op(stride):
    assert (
        stride in [1, 2, 4]
        or stride in [-1, -2, -4]
        or (isinstance(stride, tuple) and all(x in [-1, -2, -4] for x in stride))
    )

    scales = stride
    ret = None
    if isinstance(stride, tuple) or stride < 0:
        scales = [-x for x in stride] if isinstance(stride, tuple) else -stride
        stride = 1
        ret = Upsample(scale_factor=scales, mode="nearest", align_corners=None)

    return ret, stride
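
# A quick sketch of the stride convention used here (values illustrative):
# positive strides pass through unchanged, while negative strides become a
# nearest-neighbor Upsample and the conv stride collapses to 1.
#   _get_upsample_op(2)   # -> (None, 2)
#   _get_upsample_op(-2)  # -> (Upsample(scale_factor=2, mode="nearest"), 1)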


class IRFBlock(nn.Module):
    def __init__(
        self,
        input_depth,
        output_depth,
        expansion,
        stride,
        bn_type="bn",
        kernel=3,
        width_divisor=1,
        shuffle_type=None,
        pw_group=1,
        se=False,
        cdw=False,
        dw_skip_bn=False,
        dw_skip_relu=False,
    ):
        super(IRFBlock, self).__init__()

        assert kernel in [1, 3, 5, 7], kernel

        self.use_res_connect = stride == 1 and input_depth == output_depth
        self.output_depth = output_depth

        mid_depth = int(input_depth * expansion)
        mid_depth = _get_divisible_by(mid_depth, width_divisor, width_divisor)

        # pw
        self.pw = ConvBNRelu(
            input_depth,
            mid_depth,
            kernel=1,
            stride=1,
            pad=0,
            no_bias=1,
            use_relu="relu",
            bn_type=bn_type,
            group=pw_group,
        )

        # negative stride to do upsampling
        self.upscale, stride = _get_upsample_op(stride)

        # dw
        if kernel == 1:
            self.dw = nn.Sequential()
        elif cdw:
            dw1 = ConvBNRelu(
                mid_depth,
                mid_depth,
                kernel=kernel,
                stride=stride,
                pad=(kernel // 2),
                group=mid_depth,
                no_bias=1,
                use_relu="relu",
                bn_type=bn_type,
            )
            dw2 = ConvBNRelu(
                mid_depth,
                mid_depth,
                kernel=kernel,
                stride=1,
                pad=(kernel // 2),
                group=mid_depth,
                no_bias=1,
                use_relu="relu" if not dw_skip_relu else None,
                bn_type=bn_type if not dw_skip_bn else None,
            )
            self.dw = nn.Sequential(OrderedDict([("dw1", dw1), ("dw2", dw2)]))
        else:
            self.dw = ConvBNRelu(
                mid_depth,
                mid_depth,
                kernel=kernel,
                stride=stride,
                pad=(kernel // 2),
                group=mid_depth,
                no_bias=1,
                use_relu="relu" if not dw_skip_relu else None,
                bn_type=bn_type if not dw_skip_bn else None,
            )

        # pw-linear
        self.pwl = ConvBNRelu(
            mid_depth,
            output_depth,
            kernel=1,
            stride=1,
            pad=0,
            no_bias=1,
            use_relu=None,
            bn_type=bn_type,
            group=pw_group,
        )

        self.shuffle_type = shuffle_type
        if shuffle_type is not None:
            self.shuffle = ChannelShuffle(pw_group)

        self.se4 = SEModule(output_depth) if se else nn.Sequential()

    def forward(self, x):
        y = self.pw(x)
        if self.shuffle_type == "mid":
            y = self.shuffle(y)
        if self.upscale is not None:
            y = self.upscale(y)
        y = self.dw(y)
        y = self.pwl(y)
        if self.use_res_connect:
            y += x
        y = self.se4(y)
        return y
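
# A minimal construction sketch (shapes hypothetical): an expansion-6, 5x5
# inverted-residual block that halves the spatial resolution.
#   block = IRFBlock(32, 64, expansion=6, stride=2, kernel=5)
#   y = block(torch.randn(1, 32, 56, 56))  # -> shape (1, 64, 28, 28)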


def _expand_block_cfg(block_cfg):
    assert isinstance(block_cfg, list)
    ret = []
    for idx in range(block_cfg[2]):
        cur = copy.deepcopy(block_cfg)
        cur[2] = 1
        cur[3] = 1 if idx >= 1 else cur[3]
        ret.append(cur)
    return ret
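
# A short illustration (values hypothetical): a block config is
# [expansion, channels, num_repeats, stride]; the stride applies only to the
# first repeat, and later repeats keep stride 1.
#   _expand_block_cfg([6, 32, 3, 2])
#   # -> [[6, 32, 1, 2], [6, 32, 1, 1], [6, 32, 1, 1]]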


def expand_stage_cfg(stage_cfg):
    """ For a single stage """
    assert isinstance(stage_cfg, list)
    ret = []
    for x in stage_cfg:
        ret += _expand_block_cfg(x)
    return ret


def expand_stages_cfg(stage_cfgs):
    """ For a list of stages """
    assert isinstance(stage_cfgs, list)
    ret = []
    for x in stage_cfgs:
        ret.append(expand_stage_cfg(x))
    return ret


def _block_cfgs_to_list(block_cfgs):
    assert isinstance(block_cfgs, list)
    ret = []
    for stage_idx, stage in enumerate(block_cfgs):
        stage = expand_stage_cfg(stage)
        for block_idx, block in enumerate(stage):
            cur = {"stage_idx": stage_idx, "block_idx": block_idx, "block": block}
            ret.append(cur)
    return ret


def _add_to_arch(arch, info, name):
    """ arch = [{block_0}, {block_1}, ...]
        info = [
            # stage 0
            [
                block0_info,
                block1_info,
                ...
            ], ...
        ]
        convert to:
            arch = [
                {
                    block_0,
                    name: block0_info,
                },
                {
                    block_1,
                    name: block1_info,
                }, ...
            ]
    """
    assert isinstance(arch, list) and all(isinstance(x, dict) for x in arch)
    assert isinstance(info, list) and all(isinstance(x, list) for x in info)
    idx = 0
    for stage_idx, stage in enumerate(info):
        for block_idx, block in enumerate(stage):
            assert (
                arch[idx]["stage_idx"] == stage_idx
                and arch[idx]["block_idx"] == block_idx
            ), "Index ({}, {}) does not match for block {}".format(
                stage_idx, block_idx, arch[idx]
            )
            assert name not in arch[idx]
            arch[idx][name] = block
            idx += 1


def unify_arch_def(arch_def):
    """ unify the arch_def to:
        {
            ...,
            "arch": [
                {
                    "stage_idx": idx,
                    "block_idx": idx,
                    ...
                },
                {}, ...
            ]
        }
    """
    ret = copy.deepcopy(arch_def)

    assert "block_cfg" in arch_def and "stages" in arch_def["block_cfg"]
    assert "stages" not in ret
    # copy 'first', 'last' etc. inside arch_def['block_cfg'] to ret
    ret.update({x: arch_def["block_cfg"][x] for x in arch_def["block_cfg"]})
    ret["stages"] = _block_cfgs_to_list(arch_def["block_cfg"]["stages"])
    del ret["block_cfg"]

    assert "block_op_type" in arch_def
    _add_to_arch(ret["stages"], arch_def["block_op_type"], "block_op_type")
    del ret["block_op_type"]

    return ret
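
# A minimal input/output sketch for unify_arch_def (all field values here are
# hypothetical):
#   arch_def = {
#       "block_op_type": [["ir_k3"], ["ir_k5", "ir_k5"]],
#       "block_cfg": {
#           "first": [16, 2],
#           "stages": [[[1, 16, 1, 1]], [[6, 24, 2, 2]]],
#           "last": [1984, 1.0],
#       },
#   }
#   unified = unify_arch_def(arch_def)
#   # unified["stages"] is a flat list of dicts, each carrying "stage_idx",
#   # "block_idx", "block" (a [t, c, 1, s] list), and "block_op_type".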


def get_num_stages(arch_def):
    ret = 0
    for x in arch_def["stages"]:
        ret = max(x["stage_idx"], ret)
    return ret + 1


def get_blocks(arch_def, stage_indices=None, block_indices=None):
    """ Keep only blocks whose stage_idx / block_idx appear in the given
        index lists; None or [] keeps every block for that axis.
    """
    ret = copy.deepcopy(arch_def)
    ret["stages"] = []
    for block in arch_def["stages"]:
        keep = True
        if stage_indices not in (None, []) and block["stage_idx"] not in stage_indices:
            keep = False
        if block_indices not in (None, []) and block["block_idx"] not in block_indices:
            keep = False
        if keep:
            ret["stages"].append(block)
    return ret


class FBNetBuilder(object):
    def __init__(
        self,
        width_ratio,
        bn_type="bn",
        width_divisor=1,
        dw_skip_bn=False,
        dw_skip_relu=False,
    ):
        self.width_ratio = width_ratio
        self.last_depth = -1
        self.bn_type = bn_type
        self.width_divisor = width_divisor
        self.dw_skip_bn = dw_skip_bn
        self.dw_skip_relu = dw_skip_relu

    def add_first(self, stage_info, dim_in=3, pad=True):
        # stage_info: [c, s, kernel]
        assert len(stage_info) >= 2
        channel = stage_info[0]
        stride = stage_info[1]
        out_depth = self._get_divisible_width(int(channel * self.width_ratio))
        kernel = 3
        if len(stage_info) > 2:
            kernel = stage_info[2]

        out = ConvBNRelu(
            dim_in,
            out_depth,
            kernel=kernel,
            stride=stride,
            pad=kernel // 2 if pad else 0,
            no_bias=1,
            use_relu="relu",
            bn_type=self.bn_type,
        )
        self.last_depth = out_depth
        return out

    def add_blocks(self, blocks):
        """ blocks: [{}, {}, ...]
        """
        assert isinstance(blocks, list) and all(
            isinstance(x, dict) for x in blocks
        ), blocks

        modules = OrderedDict()
        for block in blocks:
            stage_idx = block["stage_idx"]
            block_idx = block["block_idx"]
            block_op_type = block["block_op_type"]
            tcns = block["block"]
            n = tcns[2]
            assert n == 1
            nnblock = self.add_ir_block(tcns, [block_op_type])
            nn_name = "xif{}_{}".format(stage_idx, block_idx)
            assert nn_name not in modules
            modules[nn_name] = nnblock
        ret = nn.Sequential(modules)
        return ret

    def add_last(self, stage_info):
        """ skip last layer if channel_scale == 0
            use the same output channel if channel_scale < 0
        """
        assert len(stage_info) == 2
        channels = stage_info[0]
        channel_scale = stage_info[1]

        if channel_scale == 0.0:
            return nn.Sequential()

        if channel_scale > 0:
            last_channel = (
                int(channels * self.width_ratio) if self.width_ratio > 1.0 else channels
            )
            last_channel = int(last_channel * channel_scale)
        else:
            last_channel = int(self.last_depth * (-channel_scale))
        last_channel = self._get_divisible_width(last_channel)

        if last_channel == 0:
            return nn.Sequential()

        dim_in = self.last_depth
        ret = ConvBNRelu(
            dim_in,
            last_channel,
            kernel=1,
            stride=1,
            pad=0,
            no_bias=1,
            use_relu="relu",
            bn_type=self.bn_type,
        )
        self.last_depth = last_channel
        return ret

    # def add_final_pool(self, model, blob_in, kernel_size):
    #     ret = model.AveragePool(blob_in, "final_avg", kernel=kernel_size, stride=1)
    #     return ret

    def _add_ir_block(
        self, dim_in, dim_out, stride, expand_ratio, block_op_type, **kwargs
    ):
        ret = PRIMITIVES[block_op_type](
            dim_in,
            dim_out,
            expansion=expand_ratio,
            stride=stride,
            bn_type=self.bn_type,
            width_divisor=self.width_divisor,
            dw_skip_bn=self.dw_skip_bn,
            dw_skip_relu=self.dw_skip_relu,
            **kwargs
        )
        return ret, ret.output_depth

    def add_ir_block(self, tcns, block_op_types, **kwargs):
        t, c, n, s = tcns
        assert n == 1
        out_depth = self._get_divisible_width(int(c * self.width_ratio))
        dim_in = self.last_depth
        op, ret_depth = self._add_ir_block(
            dim_in,
            out_depth,
            stride=s,
            expand_ratio=t,
            block_op_type=block_op_types[0],
            **kwargs
        )
        self.last_depth = ret_depth
        return op

    def _get_divisible_width(self, width):
        ret = _get_divisible_by(int(width), self.width_divisor, self.width_divisor)
        return ret
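

# An end-to-end construction sketch (reusing the hypothetical `unified` arch
# from the unify_arch_def example above; not part of the original module):
#   builder = FBNetBuilder(width_ratio=1.0, bn_type="bn", width_divisor=8)
#   stem = builder.add_first(unified["first"])       # stem conv
#   stages = builder.add_blocks(unified["stages"])   # searched IRF blocks
#   head = builder.add_last(unified["last"])         # optional 1x1 conv head
#   model = nn.Sequential(stem, stages, head)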