Transformer and Attention (Part 2): Various Attention Modules

This post surveys the various attention modules in use today (spatial and channel attention) together with reference code.

Squeeze-and-Excitation Networks 2018

[Figure: Squeeze-and-Excitation (SE) block]

  1. SENet learns the correlations between channels to produce channel-wise attention; it adds only a small amount of computation but brings a clear performance gain.
  2. The Squeeze-and-Excitation (SE) block is a sub-module that can be conveniently embedded into other classification or detection models.
  3. The core idea of SENet is to let the network learn feature-map channel weights through the loss, so that the model achieves better results.
  4. The SE module is essentially an attention mechanism.

[Figure]

import numpy as np
import torch
from torch import nn
from torch.nn import init


# SE (Squeeze-and-Excitation) channel attention
class SEAttention(nn.Module):
    def __init__(self, channel=512, reduction=16):
        super(SEAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)   # squeeze: global average pooling
        self.fc = nn.Sequential(                  # excitation: bottleneck MLP + sigmoid
            nn.Linear(channel, channel // reduction),
            nn.ReLU(),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)      # (b, c)
        y = self.fc(y).view(b, c, 1, 1)      # per-channel weights in (0, 1)
        return x * y.expand_as(x)            # rescale the input feature map
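
A quick sanity check of the module above on a random feature map (shapes are illustrative):

import torch

x = torch.randn(2, 512, 7, 7)               # (batch, channels, H, W)
se = SEAttention(channel=512, reduction=16)
out = se(x)
print(out.shape)                            # torch.Size([2, 512, 7, 7]), same shape as the input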

Bottleneck Attention Module (BAM) 2018

[Figure: Bottleneck Attention Module (BAM)]

import torch
import torch.nn as nn
from torch.nn import init


class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)


class ChannelAttention(nn.Module):
    def __init__(self, channel, reduction: int = 16, num_layer: int = 3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        gate_channels = [channel]
        gate_channels += [channel // reduction] * num_layer
        gate_channels += [channel]

        self.ca = nn.Sequential()
        self.ca.add_module('flatten', Flatten())
        for i in range(num_layer):
            self.ca.add_module('fc{}'.format(i), nn.Linear(gate_channels[i], gate_channels[i + 1]))
            self.ca.add_module('bn%d' % i, nn.BatchNorm1d(gate_channels[i + 1]))
            self.ca.add_module('relu{}'.format(i), nn.ReLU())

        self.ca.add_module('last_fc', nn.Linear(gate_channels[-2], gate_channels[-1]))

    def forward(self, x):
        res = self.avg_pool(x)    # (b, c, 1, 1)
        res = self.ca(res)        # (b, c)
        return res.unsqueeze(-1).unsqueeze(-1).expand_as(x)


class SpatialAttention(nn.Module):
    def __init__(self, channel, reduction=16, num_layers=3, dia_val=2):
        super().__init__()
        self.sa = nn.Sequential()
        self.sa.add_module('conv_reduce1', nn.Conv2d(in_channels=channel, out_channels=channel // reduction, kernel_size=1))
        self.sa.add_module('bn_reduce1', nn.BatchNorm2d(channel // reduction))
        self.sa.add_module('relu_reduce1', nn.ReLU())
        for i in range(num_layers):
            # padding must match the dilation rate so the spatial size is preserved
            self.sa.add_module('conv_%d' % i, nn.Conv2d(in_channels=channel // reduction, out_channels=channel // reduction, kernel_size=3, padding=dia_val, dilation=dia_val))
            self.sa.add_module('bn_%d' % i, nn.BatchNorm2d(channel // reduction))
            self.sa.add_module('relu_%d' % i, nn.ReLU())
        self.sa.add_module('conv_last', nn.Conv2d(in_channels=channel // reduction, out_channels=channel, kernel_size=1))

    def forward(self, x):
        res = self.sa(x)          # (b, c, h, w) spatial attention map
        return res.expand_as(x)


class BAMBlock(nn.Module):
    def __init__(self, channel: int = 512, reduction: int = 16, dia_val: int = 2):
        super().__init__()
        self.ca = ChannelAttention(channel=channel, reduction=reduction)
        self.sa = SpatialAttention(channel=channel, reduction=reduction, dia_val=dia_val)
        self.sigmoid = nn.Sigmoid()
        self.init_weights()

    def forward(self, x):
        b, c, _, _ = x.size()
        sa_out = self.sa(x)
        ca_out = self.ca(x)
        weight = self.sigmoid(sa_out + ca_out)   # combine the two branches
        out = (1 + weight) * x                   # residual-style reweighting
        return out

    def init_weights(self):
        # initialize weights for the model
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
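
A minimal usage sketch for the block above (shapes are illustrative):

import torch

x = torch.randn(2, 512, 14, 14)
bam = BAMBlock(channel=512, reduction=16, dia_val=2)
out = bam(x)
print(out.shape)    # torch.Size([2, 512, 14, 14])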

[Figure]

DANet: Dual Attention Network 2018

[Figure: DANet dual attention structure]

import torch
from torch import nn

# ScaledDotProductAttention / SimplifiedScaledDotProductAttention come from the
# External-Attention-pytorch repository referenced at the end of this post.


class PositionAttentionModule(nn.Module):

    def __init__(self, d_model=512, kernel_size=3, H=7, W=7):
        super().__init__()
        self.cnn = nn.Conv2d(d_model, d_model, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.pa = ScaledDotProductAttention(d_model, d_k=d_model, d_v=d_model, h=1)

    def forward(self, x):
        bs, c, h, w = x.shape
        y = self.cnn(x)
        y = y.view(bs, c, -1).permute(0, 2, 1)  # bs, h*w, c
        y = self.pa(y, y, y)                    # bs, h*w, c
        return y


class ChannelAttentionModule(nn.Module):

    def __init__(self, d_model=512, kernel_size=3, H=7, W=7):
        super().__init__()
        self.cnn = nn.Conv2d(d_model, d_model, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.pa = SimplifiedScaledDotProductAttention(H * W, h=1)

    def forward(self, x):
        bs, c, h, w = x.shape
        y = self.cnn(x)
        y = y.view(bs, c, -1)   # bs, c, h*w
        y = self.pa(y, y, y)    # bs, c, h*w
        return y


class DAModule(nn.Module):

    def __init__(self, d_model=512, kernel_size=3, H=7, W=7):
        super().__init__()
        # pass the constructor arguments through instead of hard-coding them
        self.position_attention_module = PositionAttentionModule(d_model=d_model, kernel_size=kernel_size, H=H, W=W)
        self.channel_attention_module = ChannelAttentionModule(d_model=d_model, kernel_size=kernel_size, H=H, W=W)

    def forward(self, input):
        bs, c, h, w = input.shape
        p_out = self.position_attention_module(input)
        c_out = self.channel_attention_module(input)
        p_out = p_out.permute(0, 2, 1).view(bs, c, h, w)
        c_out = c_out.view(bs, c, h, w)
        return p_out + c_out
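
The two modules above rely on ScaledDotProductAttention and SimplifiedScaledDotProductAttention from the External-Attention-pytorch repository. For a self-contained version, the following is a minimal single-head sketch of my own (not the repository's exact code) that matches the forward(q, k, v) interface used above:

import torch
from torch import nn


class SimpleScaledDotProductAttention(nn.Module):
    """Minimal stand-in for ScaledDotProductAttention with h=1 (hypothetical simplification)."""

    def __init__(self, d_model, d_k=None, d_v=None, h=1):
        super().__init__()
        d_k = d_k or d_model
        d_v = d_v or d_model
        self.fc_q = nn.Linear(d_model, d_k)
        self.fc_k = nn.Linear(d_model, d_k)
        self.fc_v = nn.Linear(d_model, d_v)
        self.fc_o = nn.Linear(d_v, d_model)
        self.scale = d_k ** -0.5

    def forward(self, q, k, v):
        # q, k, v: (batch, seq_len, d_model)
        Q, K, V = self.fc_q(q), self.fc_k(k), self.fc_v(v)
        attn = torch.softmax(Q @ K.transpose(-2, -1) * self.scale, dim=-1)  # (batch, seq, seq)
        return self.fc_o(attn @ V)                                          # (batch, seq, d_model)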


CBAM: Convolutional Block Attention Module 2018

[Figure: CBAM overview]

Channel attention

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # shared MLP implemented with 1x1 convolutions; use `ratio` instead of a hard-coded 16
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

[Figure]

Spatial attention

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

[Figure]

[Figure]
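
CBAM applies the two sub-modules sequentially: channel attention first, then spatial attention, each multiplied onto the feature map. A minimal wrapper along those lines (assuming the two classes above):

class CBAM(nn.Module):
    """Chains the channel and spatial attention defined above (minimal sketch)."""

    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(in_planes, ratio=ratio)
        self.spatial_attention = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        x = x * self.channel_attention(x)   # reweight channels
        x = x * self.spatial_attention(x)   # reweight spatial positions
        return x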

Non-Local 2018

[Figure: Non-local block]

import torch
import torch.nn as nn


class NonLocalNet(nn.Module):
    def __init__(self, input_dim=64, output_dim=64):
        super(NonLocalNet, self).__init__()
        intermediate_dim = input_dim // 2
        self.to_q = nn.Conv2d(input_dim, intermediate_dim, 1)
        self.to_k = nn.Conv2d(input_dim, intermediate_dim, 1)
        self.to_v = nn.Conv2d(input_dim, intermediate_dim, 1)

        self.conv = nn.Conv2d(intermediate_dim, output_dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        # flatten the spatial dimensions so that every pixel attends to every other pixel
        q = self.to_q(x).view(b, -1, h * w).permute(0, 2, 1)  # (b, h*w, c/2)
        k = self.to_k(x).view(b, -1, h * w)                   # (b, c/2, h*w)
        v = self.to_v(x).view(b, -1, h * w).permute(0, 2, 1)  # (b, h*w, c/2)

        u = torch.softmax(torch.bmm(q, k), dim=-1)            # (b, h*w, h*w) affinity matrix
        out = torch.bmm(u, v)                                 # (b, h*w, c/2)
        out = out.permute(0, 2, 1).view(b, -1, h, w)          # back to (b, c/2, h, w)
        out = self.conv(out)                                  # restore the channel dimension
        return out + x                                        # residual connection
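
A quick check of the forward pass above (illustrative shapes):

x = torch.randn(2, 64, 16, 16)
nl = NonLocalNet(input_dim=64, output_dim=64)
print(nl(x).shape)   # torch.Size([2, 64, 16, 16])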

[Figure]

SKNet 2019

[Figure: Selective Kernel (SK) convolution]

import torch
from torch import nn


class SKConv(nn.Module):
    """
    https://arxiv.org/pdf/1903.06586.pdf
    """
    def __init__(self, feature_dim, WH, M, G, r, stride=1, L=32):
        """ Constructor
        Args:
            feature_dim: input channel dimensionality.
            WH: input spatial dimensionality, used for GAP kernel size.
            M: the number of branches.
            G: number of convolution groups.
            r: the ratio for computing d, the length of z.
            stride: stride, default 1.
            L: the minimum dim of the vector z in the paper, default 32.
        """
        super().__init__()
        d = max(int(feature_dim / r), L)
        self.M = M
        self.feature_dim = feature_dim
        self.convs = nn.ModuleList()
        for i in range(M):
            # each branch uses a different kernel size (3, 5, ...) with matching padding 1 + i
            self.convs.append(nn.Sequential(
                nn.Conv2d(feature_dim, feature_dim, kernel_size=3 + i * 2, stride=stride, padding=1 + i, groups=G),
                nn.BatchNorm2d(feature_dim),
                nn.ReLU(inplace=False)
            ))
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(feature_dim, d)
        self.fcs = nn.ModuleList()
        for i in range(M):
            self.fcs.append(
                nn.Linear(d, feature_dim)
            )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # "split": run every branch and stack the results on a new branch dimension
        for i, conv in enumerate(self.convs):
            feat = conv(x).unsqueeze_(dim=1)
            if i == 0:
                feas = feat
            else:
                feas = torch.cat((feas, feat), dim=1)

        # "fuse": sum the branches, pool spatially, project to the compact vector z
        fea_U = torch.sum(feas, dim=1)
        fea_s = self.gap(fea_U).view(fea_U.size(0), -1)   # avoid squeeze_() so batch size 1 also works
        fea_z = self.fc(fea_s)
        # "select": one softmax-normalized attention vector per branch
        for i, fc in enumerate(self.fcs):
            vector = fc(fea_z).unsqueeze_(dim=1)
            if i == 0:
                attention_vectors = vector
            else:
                attention_vectors = torch.cat((attention_vectors, vector), dim=1)
        attention_vectors = self.softmax(attention_vectors)
        attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1)
        fea_v = (feas * attention_vectors).sum(dim=1)
        return fea_v
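
A usage sketch for the module above (parameter values are illustrative; WH is unused by this implementation because the pooling is adaptive):

x = torch.randn(2, 64, 32, 32)
sk = SKConv(feature_dim=64, WH=32, M=2, G=8, r=2)
print(sk(x).shape)   # torch.Size([2, 64, 32, 32])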

CC-Net and Axial Attention

While reading papers I came across CC-Net, which uses criss-cross attention.

This section follows the blog post "Axial Attention 和 Criss-Cross Attention及其代码实现" on 码农家园 (codenong.com), which is well written.

Axial Attention

Axial attention: the receptive field of Axial Attention is the W (or H) pixels lying in the same row (or the same column) as the target pixel.

For example, row attention:

# Row attention for axial attention
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmax

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


class RowAttention(nn.Module):

    def __init__(self, in_dim, q_k_dim, device):
        '''
        Parameters
        ----------
        in_dim : int
            channel of input img tensor
        q_k_dim: int
            channel of Q, K vector
        device : torch.device
        '''
        super(RowAttention, self).__init__()
        self.in_dim = in_dim
        self.q_k_dim = q_k_dim
        self.device = device

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.in_dim, kernel_size=1)
        self.softmax = Softmax(dim=2)
        # keep gamma as a registered Parameter; it moves with the module on .to(device)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        '''
        Parameters
        ----------
        x : Tensor
            4-D, (batch, in_dims, height, width) -- (b, c1, h, w)
        '''

        ## c1 = in_dims; c2 = q_k_dim
        b, _, h, w = x.size()

        Q = self.query_conv(x)  # size = (b, c2, h, w)
        K = self.key_conv(x)    # size = (b, c2, h, w)
        V = self.value_conv(x)  # size = (b, c1, h, w)

        Q = Q.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w).permute(0, 2, 1)  # size = (b*h, w, c2)
        K = K.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w)                   # size = (b*h, c2, w)
        V = V.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w)                   # size = (b*h, c1, w)

        # size = (b*h, w, w); [:, i, j] is the dot product between the channel vector of Q at
        # column i and the channel vector of K at column j within the same row,
        # i.e. (1, c2) * (c2, 1) = (1, 1)
        row_attn = torch.bmm(Q, K)
        # row_attn[:, i, 0:w] holds the similarities between position i of a row in Q and every
        # position 0..w of the same row in K -- this is the row attention

        # softmax over the last dimension: each set of weights row_attn[k, i, 0:w] sums to 1
        row_attn = self.softmax(row_attn)

        # size = (b*h, c1, w); transpose row_attn first so the weights of each output position
        # sum to 1, then take the weighted sum of V along the row
        out = torch.bmm(V, row_attn.permute(0, 2, 1))
        # size = (b, c1, h, w)
        out = out.view(b, h, -1, w).permute(0, 2, 1, 3)
        out = self.gamma * out + x
        return out


# quick test of row attention
x = torch.randn(4, 8, 16, 20).to(device)
row_attn = RowAttention(in_dim=8, q_k_dim=4, device=device).to(device)
print(row_attn(x).size())

Column attention works the same way:

# Column attention for axial attention
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmax

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


class ColAttention(nn.Module):

    def __init__(self, in_dim, q_k_dim, device):
        '''
        Parameters
        ----------
        in_dim : int
            channel of input img tensor
        q_k_dim: int
            channel of Q, K vector
        device : torch.device
        '''
        super(ColAttention, self).__init__()
        self.in_dim = in_dim
        self.q_k_dim = q_k_dim
        self.device = device

        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.in_dim, kernel_size=1)
        self.softmax = Softmax(dim=2)
        # keep gamma as a registered Parameter; it moves with the module on .to(device)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        '''
        Parameters
        ----------
        x : Tensor
            4-D, (batch, in_dims, height, width) -- (b, c1, h, w)
        '''

        ## c1 = in_dims; c2 = q_k_dim
        b, _, h, w = x.size()

        Q = self.query_conv(x)  # size = (b, c2, h, w)
        K = self.key_conv(x)    # size = (b, c2, h, w)
        V = self.value_conv(x)  # size = (b, c1, h, w)

        Q = Q.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h).permute(0, 2, 1)  # size = (b*w, h, c2)
        K = K.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h)                   # size = (b*w, c2, h)
        V = V.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h)                   # size = (b*w, c1, h)

        # size = (b*w, h, h); [:, i, j] is the dot product between the channel vector of Q at
        # row i and the channel vector of K at row j within the same column,
        # i.e. (1, c2) * (c2, 1) = (1, 1)
        col_attn = torch.bmm(Q, K)
        # col_attn[:, i, 0:h] holds the similarities between position i of a column in Q and
        # every position 0..h of the same column in K -- this is the column attention

        # softmax over the last dimension: each set of weights col_attn[k, i, 0:h] sums to 1
        col_attn = self.softmax(col_attn)

        # size = (b*w, c1, h); transpose col_attn first, then take the weighted sum of V along the column
        out = torch.bmm(V, col_attn.permute(0, 2, 1))

        # size = (b, c1, h, w)
        out = out.view(b, w, -1, h).permute(0, 2, 3, 1)

        out = self.gamma * out + x

        return out


# quick test of column attention
x = torch.randn(4, 8, 16, 20).to(device)
col_attn = ColAttention(in_dim=8, q_k_dim=4, device=device).to(device)
print(col_attn(x).size())

Criss-Cross Attention Module 2019

[Figure: Criss-Cross Attention]

The receptive field of CC-Attention is the (H + W - 1) pixels that lie in the same row and the same column as the target pixel.

import torch
import torch.nn as nn
from torch.nn import Softmax


class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module

    reference: https://github.com/speedinghzl/CCNet
    (the INF helper used below comes from that repository; see the snippet after this block)
    """
    def __init__(self, in_dim):
        super(CrissCrossAttention, self).__init__()

        self.query_conv = nn.Sequential(
            nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1),
            nn.BatchNorm2d(in_dim, eps=1e-5, momentum=0.01, affine=True),
            nn.ReLU()
        )
        self.key_conv = nn.Sequential(
            nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1),
            nn.BatchNorm2d(in_dim, eps=1e-5, momentum=0.01, affine=True),
            nn.ReLU()
        )
        self.value_conv = nn.Sequential(
            nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1),
            nn.BatchNorm2d(in_dim, eps=1e-5, momentum=0.01, affine=True),
            nn.ReLU()
        )

        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, query, key, value):
        m_batchsize, _, height, width = query.size()

        proj_query = self.query_conv(query)
        # column-wise (H) and row-wise (W) views of the queries
        proj_query_H = proj_query.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width).permute(0, 2, 1)

        proj_key = self.key_conv(key)
        proj_key_H = proj_key.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_key_W = proj_key.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)

        proj_value = self.value_conv(value)
        proj_value_H = proj_value.permute(0, 3, 1, 2).contiguous().view(m_batchsize * width, -1, height)
        proj_value_W = proj_value.permute(0, 2, 1, 3).contiguous().view(m_batchsize * height, -1, width)

        # affinities along the column and the row; INF masks the duplicated center position
        energy_H = (torch.bmm(proj_query_H, proj_key_H) + self.INF(m_batchsize, height, width)).view(m_batchsize, width, height, height).permute(0, 2, 1, 3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize, height, width, width)
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))

        att_H = concate[:, :, :, 0:height].permute(0, 2, 1, 3).contiguous().view(m_batchsize * width, height, height)
        att_W = concate[:, :, :, height:height + width].contiguous().view(m_batchsize * height, width, width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize, width, -1, height).permute(0, 2, 3, 1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize, height, -1, width).permute(0, 2, 1, 3)
        return self.gamma * (out_H + out_W) + value
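
The block above assumes the INF helper from the reference CCNet repository, which builds a (B*W, H, H) tensor with -inf on the diagonal so the center pixel is not counted twice when the row and column energies are concatenated. Its definition (adapted from that repository, with the .cuda() call dropped), plus a quick shape check:

def INF(B, H, W):
    # (B*W, H, H): each H x H slice has -inf on its diagonal
    return -torch.diag(torch.tensor(float("inf")).repeat(H), 0).unsqueeze(0).repeat(B * W, 1, 1)


x = torch.randn(2, 64, 14, 14)
cca = CrissCrossAttention(in_dim=64)
print(cca(x, x, x).shape)   # torch.Size([2, 64, 14, 14])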

[Figure]

Coordinate Attention 2021

[Figure: Coordinate Attention]

Coordinate Attention builds on channel attention while also taking positional information into account, combining channel attention with spatial attention.

import torch
import torch.nn as nn


class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CA(nn.Module):
    def __init__(self, inp, reduction):
        super(CA, self).__init__()
        # h: height (rows), w: width (columns)
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))  # (b,c,h,w) --> (b,c,h,1)
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))  # (b,c,h,w) --> (b,c,1,w)

        # mip = max(8, inp // reduction) is what the paper's authors use
        mip = inp // reduction

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x

        n, c, h, w = x.size()
        x_h = self.pool_h(x)                      # (b,c,h,1)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)  # (b,c,w,1)

        # concatenate the two pooled directions and share one 1x1 conv
        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)

        # split back into the height and width parts
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        out = identity * a_w * a_h

        return out
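
A quick shape check for the CA block above (values are illustrative):

x = torch.randn(2, 64, 32, 32)
ca = CA(inp=64, reduction=16)
print(ca(x).shape)   # torch.Size([2, 64, 32, 32])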

Attentional Feature Fusion 2021

WACV 2021 Open Access Repository (thecvf.com)

YimianDai/open-aff: code and trained models for “Attentional Feature Fusion” (github.com)

These attention modules are usually embedded inside building blocks (also called units), and those blocks are then stacked in multi-scale networks. A sketch of that pattern follows below.
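
As an illustration of that pattern, here is a sketch of my own (not code from the AFF repository) that drops a channel-attention module such as the SEAttention defined earlier into a simple residual unit:

import torch
from torch import nn


class ResidualBlockWithAttention(nn.Module):
    """Hypothetical residual unit that wraps an attention module around its conv branch."""

    def __init__(self, channels, attention: nn.Module):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.attention = attention        # e.g. SEAttention(channel=channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.body(x)
        out = self.attention(out)         # reweight the conv features before the residual add
        return self.relu(out + x)


# usage: block = ResidualBlockWithAttention(64, SEAttention(channel=64, reduction=16))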

References

  1. Axial Attention 和 Criss-Cross Attention及其代码实现_cross attention代码-CSDN博客
  2. sknet阅读笔记及pytorch实现代码_pytorch sknet-CSDN博客
  3. 【注意力机制集锦】Channel Attention通道注意力网络结构、源码解读系列一_通道注意力机制结构图-CSDN博客
  4. 【注意力机制集锦2】BAM&SGE&DAN原文、结构、源码详解_bam注意力机制-CSDN博客

Thanks to lyp2333/External-Attention-pytorch (github.com) and xmu-xiaoma666/External-Attention-pytorch: 🍀 Pytorch implementation of various Attention Mechanisms, MLP, Re-parameter, Convolution, which is helpful to further understand papers.⭐⭐⭐ (github.com)
