import torch
from mmcv.cnn import ConvModule, constant_init
from torch import nn
from torch.nn import functional as F
class SelfAttentionBlock(nn.Module):
    """General self-attention block/non-local block.

    Please refer to https://arxiv.org/abs/1706.03762 for details about key,
    query and value.

    Args:
        key_in_channels (int): Input channels of key feature.
        query_in_channels (int): Input channels of query feature.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether to share projection weights between
            key and query projection.
        query_downsample (nn.Module): Query downsample module.
        key_downsample (nn.Module): Key downsample module.
        key_query_num_convs (int): Number of convs for key/query projection.
        value_out_num_convs (int): Number of convs for value/out projection.
        key_query_norm (bool): Whether to build the key/query projection with
            ConvModule (i.e. with norm/activation).
        value_out_norm (bool): Whether to build the value/out projection with
            ConvModule (i.e. with norm/activation).
        matmul_norm (bool): Whether to normalize the attention map by
            1/sqrt(channels).
        with_out (bool): Whether to use an out projection.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """
    def __init__(self, key_in_channels, query_in_channels, channels,
                 out_channels, share_key_query, query_downsample,
                 key_downsample, key_query_num_convs, value_out_num_convs,
                 key_query_norm, value_out_norm, matmul_norm, with_out,
                 conv_cfg, norm_cfg, act_cfg):
        super(SelfAttentionBlock, self).__init__()
        if share_key_query:
            assert key_in_channels == query_in_channels
        self.key_in_channels = key_in_channels
        self.query_in_channels = query_in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.share_key_query = share_key_query
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        # Key/query/value/out projections are stacks of 1x1 convs, optionally
        # wrapped in ConvModule so they carry norm/activation layers.
        self.key_project = self.build_project(
            key_in_channels,
            channels,
            num_convs=key_query_num_convs,
            use_conv_module=key_query_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if share_key_query:
            self.query_project = self.key_project
        else:
            self.query_project = self.build_project(
                query_in_channels,
                channels,
                num_convs=key_query_num_convs,
                use_conv_module=key_query_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        self.value_project = self.build_project(
            key_in_channels,
            channels if with_out else out_channels,
            num_convs=value_out_num_convs,
            use_conv_module=value_out_norm,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        if with_out:
            self.out_project = self.build_project(
                channels,
                out_channels,
                num_convs=value_out_num_convs,
                use_conv_module=value_out_norm,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        else:
            self.out_project = None

        self.query_downsample = query_downsample
        self.key_downsample = key_downsample
        self.matmul_norm = matmul_norm

        self.init_weights()
    def init_weights(self):
        """Zero-initialize ``out_project`` when it is a plain conv layer.

        ConvModule-based projections keep their own default initialization.
        """
        if self.out_project is not None:
            if not isinstance(self.out_project, ConvModule):
                constant_init(self.out_project, 0)
    def build_project(self, in_channels, channels, num_convs, use_conv_module,
                      conv_cfg, norm_cfg, act_cfg):
        """Build projection layer for key/query/value/out."""
        if use_conv_module:
            convs = [
                ConvModule(
                    in_channels,
                    channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg)
            ]
            for _ in range(num_convs - 1):
                convs.append(
                    ConvModule(
                        channels,
                        channels,
                        1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg))
        else:
            convs = [nn.Conv2d(in_channels, channels, 1)]
            for _ in range(num_convs - 1):
                convs.append(nn.Conv2d(channels, channels, 1))
        if len(convs) > 1:
            convs = nn.Sequential(*convs)
        else:
            convs = convs[0]
        return convs
    def forward(self, query_feats, key_feats):
        """Forward function."""
        batch_size = query_feats.size(0)
        query = self.query_project(query_feats)
        if self.query_downsample is not None:
            query = self.query_downsample(query)
        # Flatten spatial dims: (B, C, H, W) -> (B, N_query, C)
        query = query.reshape(*query.shape[:2], -1)
        query = query.permute(0, 2, 1).contiguous()

        key = self.key_project(key_feats)
        value = self.value_project(key_feats)
        if self.key_downsample is not None:
            # Key and value share the same downsample so their spatial
            # positions stay aligned.
            key = self.key_downsample(key)
            value = self.key_downsample(value)
        # key: (B, C, N_key), value: (B, N_key, C_value)
        key = key.reshape(*key.shape[:2], -1)
        value = value.reshape(*value.shape[:2], -1)
        value = value.permute(0, 2, 1).contiguous()

        # Scaled dot-product attention map: (B, N_query, N_key)
        sim_map = torch.matmul(query, key)
        if self.matmul_norm:
            sim_map = (self.channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        # Aggregate values and restore the query's spatial layout.
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.reshape(batch_size, -1, *query_feats.shape[2:])
        if self.out_project is not None:
            context = self.out_project(context)
        return context
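

# A minimal usage sketch, not part of the original module: the argument
# values below are illustrative assumptions only, roughly following how
# decode heads built on this block tend to configure it.
if __name__ == '__main__':
    block = SelfAttentionBlock(
        key_in_channels=512,
        query_in_channels=512,
        channels=256,
        out_channels=512,
        share_key_query=False,
        query_downsample=None,
        key_downsample=None,
        key_query_num_convs=2,
        value_out_num_convs=1,
        key_query_norm=True,
        value_out_norm=True,
        matmul_norm=True,
        with_out=True,
        conv_cfg=None,
        norm_cfg=dict(type='BN'),
        act_cfg=dict(type='ReLU'))
    x = torch.rand(2, 512, 32, 32)
    # Plain self-attention: query and key features are the same tensor.
    out = block(x, x)
    assert out.shape == (2, 512, 32, 32)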