Transformer and Attention (Part 3)

This part covers some finer details: positional encodings, and Transformers for images.

Linear Attention

With $Q\in\mathbb{R}^{n\times d_k},\boldsymbol{K}\in\mathbb{R}^{m\times d_k},\boldsymbol{V}\in\mathbb{R}^{m\times d_v}$, we usually have $n>d_k$ or even $n\gg d_k$. Standard attention applies a softmax to $QK^{\top}$, which forces the $n\times m$ score matrix to be materialized, so time and memory scale as $O(nm)$. If the softmax is removed, the product can be re-associated as $Q(K^{\top}V)$, and the complexity drops to the ideal linear level: this is Linear Attention.

To preserve the distributional properties of attention (non-negative weights that sum to one), we only need $\mathrm{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j)\ge 0$ to hold in the generalized form $\mathrm{Attention}(Q,K,V)_i=\frac{\sum_j \mathrm{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j)\,\boldsymbol{v}_j}{\sum_j \mathrm{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j)}$. For example, the kernel can be built from an activation function with non-negative output, i.e. $\mathrm{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j)=\phi(\boldsymbol{q}_i)^{\top}\varphi(\boldsymbol{k}_j)$ with $\phi,\varphi\ge 0$.
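As a concrete illustration, here is a minimal sketch of this idea (my own example, not code from the post cited below): take the non-negative feature map $\phi(x)=\mathrm{elu}(x)+1$; without the softmax, the numerator becomes $\phi(\boldsymbol{q}_i)^{\top}\sum_j\phi(\boldsymbol{k}_j)\boldsymbol{v}_j^{\top}$, so $K$ and $V$ can be contracted first and the cost is linear in sequence length.

```python
import torch
import torch.nn.functional as F


def linear_attention(q, k, v, eps=1e-6):
    """O(n * d_k * d_v) attention with the non-negative feature map phi(x) = elu(x) + 1.

    q: (batch, n, d_k), k: (batch, m, d_k), v: (batch, m, d_v)
    """
    q = F.elu(q) + 1
    k = F.elu(k) + 1
    kv = torch.einsum("bmk,bmv->bkv", k, v)                         # contract over the sequence first
    z = 1.0 / (torch.einsum("bnk,bk->bn", q, k.sum(dim=1)) + eps)   # row-wise normaliser
    return torch.einsum("bnk,bkv,bn->bnv", q, kv, z)


q, k, v = torch.randn(2, 1024, 64), torch.randn(2, 1024, 64), torch.randn(2, 1024, 64)
print(linear_attention(q, k, v).shape)  # torch.Size([2, 1024, 64])
```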

Another option is to keep a softmax, but apply it to $Q$ and $K$ separately:

$$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}_2(Q)\,\mathrm{softmax}_1(K)^{\top}V$$

where $\mathrm{softmax}_1$ and $\mathrm{softmax}_2$ denote a softmax over the first dimension ($n$) and the second dimension ($d$), respectively.

The post 线性Attention的探索：Attention必须有个Softmax吗？ - 科学空间|Scientific Spaces also proposes Taylor-expanding the exponential: $e^{\boldsymbol{q}_i^\top\boldsymbol{k}_j}\approx1+\boldsymbol{q}_i^\top\boldsymbol{k}_j$.

To keep this approximation non-negative, $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$ are $l_2$-normalized first, giving $\mathrm{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j)=1+\left(\frac{\boldsymbol{q}_i}{\Vert\boldsymbol{q}_i\Vert}\right)^{\top}\left(\frac{\boldsymbol{k}_j}{\Vert\boldsymbol{k}_j\Vert}\right)\ge 0$.

There are also sparse attention variants, which are not covered here.

Transformers and Attention for Images

Attention and the Transformer were first developed for NLP, so attention usually processes 1-D sequences. Conformer ([2005.08100] Conformer: Convolution-augmented Transformer for Speech Recognition (arxiv.org)) combines a CNN with the Transformer for speech recognition; its attention uses relative positional encoding.

import torch
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import nn, einsum


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()


class FeedForward(nn.Module):
    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            Swish(),  # could also be replaced by nn.SiLU()
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, max_pos_emb=512):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head**-0.5
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.max_pos_emb = max_pos_emb
        self.rel_pos_emb = nn.Embedding(2 * max_pos_emb + 1, dim_head)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, context_mask=None):
        n, device, h, max_pos_emb, has_context = (
            x.shape[-2],
            x.device,
            self.heads,
            self.max_pos_emb,
            exists(context),
        )
        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

        # shaw's relative positional embedding
        seq = torch.arange(n, device=device)
        dist = rearrange(seq, "i -> i ()") - rearrange(seq, "j -> () j")
        dist = dist.clamp(-max_pos_emb, max_pos_emb) + max_pos_emb
        rel_pos_emb = self.rel_pos_emb(dist).to(q)
        pos_attn = einsum("b h n d, n r d -> b h n r", q, rel_pos_emb) * self.scale
        dots = dots + pos_attn

        if exists(mask) or exists(context_mask):
            mask = default(mask, lambda: torch.ones(*x.shape[:2], device=device))
            context_mask = (
                default(context_mask, mask)
                if not has_context
                else default(
                    context_mask, lambda: torch.ones(*context.shape[:2], device=device)
                )
            )
            mask_value = -torch.finfo(dots.dtype).max
            mask = rearrange(mask, "b i -> b () i ()") * rearrange(
                context_mask, "b j -> b () () j"
            )
            dots.masked_fill_(~mask, mask_value)

        attn = dots.softmax(dim=-1)

        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.to_out(out)
        return self.dropout(out)


def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return pad, pad - (kernel_size + 1) % 2


class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)


class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()


class ConformerConvModule(nn.Module):
    def __init__(
        self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0
    ):
        super().__init__()
        inner_dim = dim * expansion_factor
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange("b n d -> b d n"),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim=1),
            DepthWiseConv1d(
                inner_dim, inner_dim, kernel_size=kernel_size, padding=padding
            ),
            nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
            Swish(),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange("b d n -> b n d"),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.scale = scale
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)


class ConformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.0,
        ff_dropout=0.0,
        conv_dropout=0.0,
        conv_causal=False
    ):
        super().__init__()
        self.ff1 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
        self.attn = Attention(
            dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout
        )
        self.conv = ConformerConvModule(
            dim=dim,
            causal=conv_causal,
            expansion_factor=conv_expansion_factor,
            kernel_size=conv_kernel_size,
            dropout=conv_dropout,
        )
        self.ff2 = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)

        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        x = self.ff1(x) + x
        x = self.attn(x, mask=mask) + x
        x = self.conv(x) + x
        x = self.ff2(x) + x
        x = self.post_norm(x)
        return x


class Conformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.0,
        ff_dropout=0.0,
        conv_dropout=0.0,
        conv_causal=False
    ):
        super().__init__()
        self.dim = dim
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                ConformerBlock(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    ff_mult=ff_mult,
                    conv_expansion_factor=conv_expansion_factor,
                    conv_kernel_size=conv_kernel_size,
                    conv_causal=conv_causal,
                )
            )

    def forward(self, x):
        for block in self.layers:
            x = block(x)
        return x

The previous part already made heavy use of attention on feature maps, i.e. on 2-D data. This part introduces Transformers (and their variants) that perform well in vision.

Vision Transformer

[Figure: ViT architecture]

ViT is the landmark work that brought the Transformer into CV: the image is cut into patches and projected into a token sequence (patch embedding), positional encodings are added, and from there the problem is handled just as in NLP.


import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import nn

# helpers


def pair(t):
    return t if isinstance(t, tuple) else (t, t)


def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature**omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)


# classes
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head**-0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim=-1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Attention(dim, heads=heads, dim_head=dim_head),
                        FeedForward(dim, mlp_dim),
                    ]
                )
            )

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)


class SimpleViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        channels=3,
        dim_head=64
    ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert (
            image_height % patch_height == 0 and image_width % patch_width == 0
        ), "Image dimensions must be divisible by the patch size."

        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange(
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                p1=patch_height,
                p2=patch_width,
            ),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = posemb_sincos_2d(
            h=image_height // patch_height,
            w=image_width // patch_width,
            dim=dim,
        )

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        self.pool = "mean"
        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        device = img.device

        x = self.to_patch_embedding(img)
        x += self.pos_embedding.to(device, dtype=x.dtype)

        x = self.transformer(x)
        x = x.mean(dim=1)

        x = self.to_latent(x)
        return self.linear_head(x)

The model above encodes the patch positions with fixed (sine/cosine) absolute positional encodings; the attention and feed-forward network are no different from the standard Transformer.
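A quick smoke test of the SimpleViT defined above (a minimal sketch; the hyperparameters are illustrative and not taken from the paper):

```python
model = SimpleViT(
    image_size=224, patch_size=16, num_classes=1000,
    dim=384, depth=6, heads=6, mlp_dim=1536,
)
logits = model(torch.randn(2, 3, 224, 224))   # two RGB images of size 224x224
print(logits.shape)                           # torch.Size([2, 1000])
```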

Convolutional Attention

The Vision Transformer above uses absolute positional encoding, but relative positional encoding or convolution-based position information can be used instead.

Convolutional Position Embedding (CPE) takes the 2-D nature of the input into account: positional information is gathered by a zero-padded 2-D convolution. CPE can be used to inject position information at different stages of a ViT, and can be placed inside the self-attention module, in the feed-forward network, or between two encoder layers.

Convolutional attention, in turn, usually applies a 2-D convolution or a depth-wise convolution to the already-patchified image tokens (for example, to produce Q, K and V).

class ConvolutionalPositionEmbedding(nn.Module):
    def __init__(self, d_model, kernel_size=3, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(d_model, d_model, kernel_size, padding=padding)

    def forward(self, x):
        x = x.transpose(1, 2)  # swap the channel and sequence-length dimensions
        x = x.unsqueeze(2)     # insert a dummy spatial dimension: (B, d, 1, N)
        x = self.conv(x)       # gather positional information with the zero-padded conv
        x = x.squeeze(2)       # remove the dummy dimension
        x = x.transpose(1, 2)  # swap the dimensions back to (B, N, d)
        return x
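A quick shape check of the module above (a minimal sketch; the sizes are illustrative):

```python
import torch

cpe = ConvolutionalPositionEmbedding(d_model=192)
tokens = torch.randn(2, 196, 192)   # (batch, sequence length, channels), 196 = 14 * 14 patches
out = cpe(tokens)                    # same shape, position info mixed in by the zero-padded conv
print(out.shape)                     # torch.Size([2, 196, 192])
```

Note that this module flattens the tokens into a 1×N map, so the convolution only mixes along the sequence dimension; the PEG module shown later reshapes the tokens back to an H×W grid, which is the fully 2-D variant.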

CVT

[Figure: CvT architecture]

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyleft (C) 2024 proanimer, Inc. All Rights Reserved
# author: proanimer
# createTime: 2024/2/18 10:38 AM
# lastModifiedTime: 2024/2/18 10:38 AM
# file: cvt.py
# software: classicNets
import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import einsum


class SepConv2d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
    ):
        super(SepConv2d, self).__init__()
        self.depthwise = torch.nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
        )
        self.bn = torch.nn.BatchNorm2d(in_channels)
        self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.bn(x)
        x = self.pointwise(x)
        return x


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class ConvAttention(nn.Module):
    def __init__(
        self,
        dim,
        img_size,
        heads=8,
        dim_head=64,
        kernel_size=3,
        q_stride=1,
        k_stride=1,
        v_stride=1,
        dropout=0.0,
        last_stage=False,
    ):
        super().__init__()
        self.last_stage = last_stage
        self.img_size = img_size
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head**-0.5
        pad = (kernel_size - q_stride) // 2
        self.to_q = SepConv2d(dim, inner_dim, kernel_size, q_stride, pad)
        self.to_k = SepConv2d(dim, inner_dim, kernel_size, k_stride, pad)
        self.to_v = SepConv2d(dim, inner_dim, kernel_size, v_stride, pad)

        self.to_out = (
            nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
            if project_out
            else nn.Identity()
        )

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        if self.last_stage:
            cls_token = x[:, 0]
            x = x[:, 1:]
            cls_token = rearrange(cls_token.unsqueeze(1), "b n (h d) -> b h n d", h=h)
        x = rearrange(x, "b (l w) n -> b n l w", l=self.img_size, w=self.img_size)
        q = self.to_q(x)
        q = rearrange(q, "b (h d) l w -> b h (l w) d", h=h)

        v = self.to_v(x)
        v = rearrange(v, "b (h d) l w -> b h (l w) d", h=h)

        k = self.to_k(x)
        k = rearrange(k, "b (h d) l w -> b h (l w) d", h=h)

        if self.last_stage:
            q = torch.cat((cls_token, q), dim=2)
            v = torch.cat((cls_token, v), dim=2)
            k = torch.cat((cls_token, k), dim=2)

        dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale

        attn = dots.softmax(dim=-1)

        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.to_out(out)
        return out


class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        img_size,
        depth,
        heads,
        dim_head,
        mlp_dim,
        dropout=0.0,
        last_stage=False,
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PreNorm(
                            dim,
                            ConvAttention(
                                dim,
                                img_size,
                                heads=heads,
                                dim_head=dim_head,
                                dropout=dropout,
                                last_stage=last_stage,
                            ),
                        ),
                        PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
                    ]
                )
            )

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class cvt(nn.Module):
    def __init__(
        self,
        image_size,
        in_channels,
        num_classes,
        dim=64,
        kernels=[7, 3, 3],
        strides=[4, 2, 2],
        heads=[1, 3, 6],
        depth=[1, 2, 10],
        pool="cls",
        dropout=0.0,
        emb_dropout=0.0,
        scale_dim=4,
    ):
        super(cvt, self).__init__()
        assert pool in {
            "cls",
            "mean",
        }, "pool type must be either cls (cls token) or mean (mean pooling)"
        self.pool = pool
        self.dim = dim
        # stage 1
        self.stage1_conv_embed = nn.Sequential(
            nn.Conv2d(in_channels, dim, kernels[0], strides[0], 2),
            Rearrange("b c h w -> b (h w) c", h=image_size // 4, w=image_size // 4),
            nn.LayerNorm(dim),
        )
        self.stage1_transformer = nn.Sequential(
            Transformer(
                dim,
                img_size=image_size // 4,
                depth=depth[0],
                heads=heads[0],
                dim_head=dim // heads[0],
                mlp_dim=dim * scale_dim,
                dropout=dropout,
            ),
            Rearrange("b (h w) c -> b c h w", h=image_size // 4, w=image_size // 4),
        )
        # stage 2
        in_channels = dim
        scale = heads[1] // heads[0]
        dim = scale * dim
        self.stage2_conv_embed = nn.Sequential(
            nn.Conv2d(in_channels, dim, kernels[1], strides[1], 1),
            Rearrange("b c h w -> b (h w) c", h=image_size // 8, w=image_size // 8),
            nn.LayerNorm(dim),
        )
        self.stage2_transformer = nn.Sequential(
            Transformer(
                dim,
                img_size=image_size // 8,
                depth=depth[1],
                heads=heads[1],
                dim_head=dim // heads[1],
                mlp_dim=dim * scale_dim,
                dropout=dropout,
            ),
            Rearrange("b (h w) c -> b c h w", h=image_size // 8, w=image_size // 8),
        )
        # stage 3: the cls token only exists in this stage, so only here last_stage=True
        in_channels = dim
        scale = heads[2] // heads[1]
        dim = scale * dim
        self.stage3_conv_embed = nn.Sequential(
            nn.Conv2d(in_channels, dim, kernels[2], strides[2], 1),
            Rearrange("b c h w -> b (h w) c", h=image_size // 16, w=image_size // 16),
            nn.LayerNorm(dim),
        )
        self.stage3_transformer = nn.Sequential(
            Transformer(
                dim=dim,
                img_size=image_size // 16,
                depth=depth[2],
                heads=heads[2],
                dim_head=self.dim,
                mlp_dim=dim * scale_dim,
                dropout=dropout,
                last_stage=True,
            ),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.drop_large = nn.Dropout(emb_dropout)

        self.mlp_head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))

    def forward(self, img):
        xs = self.stage1_conv_embed(img)
        xs = self.stage1_transformer(xs)

        xs = self.stage2_conv_embed(xs)
        xs = self.stage2_transformer(xs)

        xs = self.stage3_conv_embed(xs)
        b, n, _ = xs.shape
        cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
        xs = torch.cat((cls_tokens, xs), dim=1)
        xs = self.stage3_transformer(xs)
        xs = xs.mean(dim=1) if self.pool == "mean" else xs[:, 0]
        xs = self.mlp_head(xs)
        return xs
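A quick smoke test of the cvt model defined above (a minimal sketch with random weights; the sizes are illustrative):

```python
model = cvt(image_size=224, in_channels=3, num_classes=1000)
logits = model(torch.randn(1, 3, 224, 224))   # stages at strides 4, 8, 16; cls token added in stage 3
print(logits.shape)                           # torch.Size([1, 1000])
```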

PVT

[Figure: PVT (Pyramid Vision Transformer) architecture]

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyleft (C) 2024 proanimer, Inc. All Rights Reserved
# author: proanimer
# createTime: 2024/2/18 2:22 PM
# lastModifiedTime: 2024/2/18 2:22 PM
# file: pvt.py
# software: classicNets
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        sr_ratio=1,
    ):
        super().__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = (
            self.q(x)
            .reshape(B, N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = (
                self.kv(x_)
                .reshape(B, -1, 2, self.num_heads, C // self.num_heads)
                .permute(2, 0, 3, 1, 4)
            )
        else:
            kv = (
                self.kv(x)
                .reshape(B, -1, 2, self.num_heads, C // self.num_heads)
                .permute(2, 0, 3, 1, 4)
            )
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # q (B,H,N,C) K(B,H,C,N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (
            (attn @ v).transpose(1, 2).reshape(B, N, C)
        )  # (B,H,N,N) @ (B,H,N,C) -> (B,H,N,C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        sr_ratio=1,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


class PatchEmbed(nn.Module):
    """Image to Patch Embedding"""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        assert (
            img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0
        ), f"img_size {img_size} should be divided by patch_size {patch_size}."
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape

        x = (
            self.proj(x).flatten(2).transpose(1, 2)
        )  # B,C,H,W->B,embed_dim,seq*seq->B,seq*seq,embed_dim
        x = self.norm(x)
        H, W = H // self.patch_size[0], W // self.patch_size[1]

        return x, (H, W)


class PyramidVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dims=[64, 128, 256, 512],
        num_heads=[1, 2, 4, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        depths=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1],
        F4=False,
        num_stages=4,
    ):
        super().__init__()
        self.depths = depths
        self.F4 = F4
        self.num_stages = num_stages

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule
        cur = 0

        for i in range(num_stages):
            patch_embed = PatchEmbed(
                img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
                patch_size=patch_size if i == 0 else 2,
                in_chans=in_chans if i == 0 else embed_dims[i - 1],
                embed_dim=embed_dims[i],
            )  # [B,seq=num_patches,dim=patch_size**2*embed_dim]
            num_patches = (
                patch_embed.num_patches
                if i != num_stages - 1
                else patch_embed.num_patches + 1
            )
            pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims[i]))
            pos_drop = nn.Dropout(p=drop_rate)

            block = nn.ModuleList(
                [
                    Block(
                        dim=embed_dims[i],
                        num_heads=num_heads[i],
                        mlp_ratio=mlp_ratios[i],
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        drop=drop_rate,
                        attn_drop=attn_drop_rate,
                        drop_path=dpr[cur + j],
                        norm_layer=norm_layer,
                        sr_ratio=sr_ratios[i],
                    )
                    for j in range(depths[i])
                ]
            )
            cur += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"pos_embed{i + 1}", pos_embed)
            setattr(self, f"pos_drop{i + 1}", pos_drop)
            setattr(self, f"block{i + 1}", block)

            trunc_normal_(pos_embed, std=0.02)

        # init weights
        self.apply(self._init_weights)
        # self.init_weights(pretrained)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        if H * W == self.patch_embed1.num_patches:
            return pos_embed
        else:
            return (
                F.interpolate(
                    pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(
                        0, 3, 1, 2
                    ),
                    size=(H, W),
                    mode="bilinear",
                )
                .reshape(1, -1, H * W)
                .permute(0, 2, 1)
            )

    def forward_features(self, x):
        outs = []
        B = x.shape[0]
        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            pos_embed = getattr(self, f"pos_embed{i + 1}")
            pos_drop = getattr(self, f"pos_drop{i + 1}")
            block = getattr(self, f"block{i + 1}")
            x, (H, W) = patch_embed(x)
            if i == self.num_stages - 1:
                pos_embed = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W)
            else:
                pos_embed = self._get_pos_embed(pos_embed, patch_embed, H, W)

            x = pos_drop(x + pos_embed)
            for blk in block:
                x = blk(x, H, W)
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)

        return outs

    def forward(self, x):
        x = self.forward_features(x)

        if self.F4:
            x = x[3:4]
        return x
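The key piece here is the spatial-reduction attention (sr_ratio): keys and values are downsampled by a strided convolution before attention, which keeps the cost manageable at high resolution. A quick smoke test of the feature pyramid (a minimal sketch; patch size 4 as in the paper, all other arguments left at the defaults above, random weights):

```python
model = PyramidVisionTransformer(patch_size=4)   # stage strides 4, 8, 16, 32
feats = model(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)
# torch.Size([1, 64, 56, 56])
# torch.Size([1, 128, 28, 28])
# torch.Size([1, 256, 14, 14])
# torch.Size([1, 512, 7, 7])
```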

PVT v2

PVT v2 (2106.13797.pdf (arxiv.org)) improves on the original PVT: the spatial-reduction scheme is revised, the patch embedding is replaced by an overlapping patch embedding, and a depth-wise convolution is added to the feed-forward network (a rough sketch of the latter two changes follows the figure below).

[Figures: PVT v2 building blocks — overlapping patch embedding and convolutional feed-forward]
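A rough sketch of those two changes (my own simplification, not the official PVT v2 code; the class names are mine): overlapping patch embedding is just a strided convolution whose kernel is larger than its stride, and the convolutional feed-forward inserts a 3×3 depth-wise convolution between the two linear layers.

```python
import torch
import torch.nn as nn


class OverlapPatchEmbed(nn.Module):
    def __init__(self, in_chans=3, embed_dim=64, patch_size=7, stride=4):
        super().__init__()
        # kernel > stride, so neighbouring patches overlap
        self.proj = nn.Conv2d(in_chans, embed_dim, patch_size, stride, patch_size // 2)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.proj(x)
        _, _, h, w = x.shape
        return self.norm(x.flatten(2).transpose(1, 2)), h, w


class ConvFeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.dwconv = nn.Conv2d(hidden_dim, hidden_dim, 3, 1, 1, groups=hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x, h, w):
        x = self.fc1(x)
        b, n, c = x.shape
        x = self.dwconv(x.transpose(1, 2).reshape(b, c, h, w))  # depth-wise conv on the 2-D layout
        x = x.flatten(2).transpose(1, 2)
        return self.fc2(self.act(x))


img = torch.randn(1, 3, 224, 224)
tokens, h, w = OverlapPatchEmbed()(img)
print(tokens.shape, h, w)                      # torch.Size([1, 3136, 64]) 56 56
print(ConvFeedForward(64, 256)(tokens, h, w).shape)  # torch.Size([1, 3136, 64])
```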

PEG in CPVT

[Figure: the Positional Encoding Generator (PEG) in CPVT]

conditional position encoding


From the paper 2102.10882.pdf (arxiv.org) (Conditional Positional Encodings for Vision Transformers):

import torch
import torch.nn as nn


class PEG(nn.Module):
    def __init__(self, dim=256, k=3):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
        # Only for demo use, more complicated functions are effective too.

    def forward(self, x, H, W):
        B, N, C = x.shape
        cls_token, feat_token = x[:, 0], x[:, 1:]  # the cls token does not go through the PEG
        cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
        x = self.proj(cnn_feat) + cnn_feat  # generate the PE and add it to the features themselves
        x = x.flatten(2).transpose(1, 2)
        x = torch.cat((cls_token.unsqueeze(1), x), dim=1)
        return x


# Abridged pseudocode from the CPVT paper: TransformerEncoderLayer, PatchEmbed and
# self.cls_tokens are assumed to be defined elsewhere.
class VisionTransformer(nn.Module):
    def __init__(self, layers=12, dim=192, nhead=3, img_size=224, patch_size=16):
        super().__init__()
        self.pos_block = PEG(dim)
        self.blocks = nn.ModuleList(
            [TransformerEncoderLayer(dim, nhead, dim * 4) for _ in range(layers)]
        )
        self.patch_embed = PatchEmbed(img_size, patch_size, dim * 4)

    def forward_features(self, x):
        B, C, H, W = x.shape
        x, patch_size = self.patch_embed(x)
        _H, _W = H // patch_size, W // patch_size
        x = torch.cat((self.cls_tokens, x), dim=1)
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i == 0:  # apply the PEG right after the first encoder layer
                x = self.pos_block(x, _H, _W)
        return x[:, 0]

LocalViT

[Figure: LocalViT — a depth-wise convolution inserted into the feed-forward network]

# Residual, PreNorm, Attention and Rearrange are as defined in the snippets above.
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, patch_height, patch_width, scale=4, depth_kernel=3, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))),
                Residual(PreNorm(dim, ConvFF(dim, scale, depth_kernel, patch_height, patch_width)))
            ]))

    def forward(self, x):
        for attn, convff in self.layers:
            x = attn(x)
            cls_tokens = x[:, 0]
            x = convff(x[:, 1:])   # the cls token is excluded from the convolutional FFN
            x = torch.cat((cls_tokens.unsqueeze(1), x), dim=1)
        return x


class ConvFF(nn.Module):
    def __init__(self, dim=192, scale=4, depth_kernel=3, patch_height=14, patch_width=14, dropout=0.):
        super().__init__()

        scale_dim = dim * scale
        self.up_proj = nn.Sequential(
            Rearrange('b (h w) c -> b c h w', h=patch_height, w=patch_width),
            nn.Conv2d(dim, scale_dim, kernel_size=1),
            nn.Hardswish()
        )

        self.depth_conv = nn.Sequential(
            nn.Conv2d(scale_dim, scale_dim, kernel_size=depth_kernel, padding=1, groups=scale_dim, bias=True),
            nn.Conv2d(scale_dim, scale_dim, kernel_size=1, bias=True),
            nn.Hardswish()
        )

        self.down_proj = nn.Sequential(
            nn.Conv2d(scale_dim, dim, kernel_size=1),
            nn.Dropout(dropout),
            Rearrange('b c h w -> b (h w) c')
        )

    def forward(self, x):
        x = self.up_proj(x)
        x = self.depth_conv(x)
        return self.down_proj(x)

LocalViT thus uses 2-D (depth-wise) convolution inside the feed-forward network.

Absolute and Relative Positional Encoding in Transformers

Positional encodings can be learnable parameters (nn.Embedding or nn.Parameter) or fixed values such as sinusoidal encodings. Independently of that, they can encode absolute positions or relative positions.
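For concreteness, a minimal sketch (my own illustration) of the two learnable flavours mentioned above:

```python
import torch
import torch.nn as nn

max_len, dim = 512, 256
x = torch.randn(2, 128, dim)                       # token embeddings for a length-128 sequence

# 1) learnable absolute positions stored in an nn.Parameter, added by slicing
pos_param = nn.Parameter(torch.zeros(1, max_len, dim))
x1 = x + pos_param[:, : x.size(1)]

# 2) learnable positions stored in an nn.Embedding, looked up by integer index
pos_table = nn.Embedding(max_len, dim)
idx = torch.arange(x.size(1))
x2 = x + pos_table(idx)[None]

print(x1.shape, x2.shape)                          # torch.Size([2, 128, 256]) torch.Size([2, 128, 256])
```

The fixed sinusoidal alternative is the module shown below.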

Absolute positional encoding

The original Transformer adds sinusoidal positional encodings to the inputs; these are regarded as absolute positional encodings:

import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # the buffer is fixed and requires no gradient
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
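A quick check of the module above (a minimal sketch):

```python
pe = PositionalEncoding(d_model=512, dropout=0.0)
x = torch.zeros(2, 10, 512)            # a batch of two 10-token sequences
print(pe(x).shape)                      # torch.Size([2, 10, 512])
print(pe(x)[0, 0, :4])                  # position 0 gives the sin/cos pattern tensor([0., 1., 0., 1.])
```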

There are several reasons to prefer relative over absolute positional encoding. Using absolute positions necessarily limits the number of tokens a model can handle: if a language model can encode at most 1024 positions, no sequence longer than 1024 tokens can be processed. Relative positional encoding, in contrast, can generalize to unseen lengths, because in principle the only information it encodes is the pairwise distance between two tokens.

A brief history of relative positional encoding

Relative Position Embedding (RPE) techniques inject information about relative positions into the attention module. The underlying idea is that the spatial relationship between tokens (or image patches) carries more weight than their absolute positions. RPE values are read from a lookup table of learnable parameters, indexed by the relative distance between patches. RPE extends to sequences of different lengths, but it can increase training and inference time.

In Attention Is All You Need, self-attention can be written as follows, with positions encoded by adding sinusoidal absolute encodings to the inputs:

$$z_i=\sum_{j}\alpha_{ij}\,(x_jW^V),\qquad \alpha_{ij}=\frac{\exp e_{ij}}{\sum_{k}\exp e_{ik}},\qquad e_{ij}=\frac{(x_iW^Q)(x_jW^K)^{\top}}{\sqrt{d_z}}$$

1-D data

Shaw

Relative positional encoding appears both in Swin Transformer and in Self-Attention with Relative Position Representations; the earliest of these papers is 1803.02155.pdf (arxiv.org).

Shaw et al. add trainable relative-position terms to both the keys and the values:

$$z_i=\sum_{j}\alpha_{ij}\,(x_jW^V+a_{ij}^{V}),\qquad e_{ij}=\frac{x_iW^Q\,(x_jW^K+a_{ij}^{K})^{\top}}{\sqrt{d_z}}$$

where $a_{ij}^{K}=w^{K}_{\mathrm{clip}(j-i,\,k)}$ and $a_{ij}^{V}=w^{V}_{\mathrm{clip}(j-i,\,k)}$; the $w^{K}$ and $w^{V}$ embeddings are trainable parameters, and the relative distance is clipped to $[-k,k]$.

Below is the relative-position attention from 1803.02155.pdf (arxiv.org), as implemented in the Conformer Attention shown earlier:

# shaw's relative positional embedding
seq = torch.arange(n, device=device)
dist = rearrange(seq, "i -> i ()") - rearrange(seq, "j -> () j")
dist = dist.clamp(-max_pos_emb, max_pos_emb) + max_pos_emb
rel_pos_emb = self.rel_pos_emb(dist).to(q)
pos_attn = einsum("b h n d, n r d -> b h n r", q, rel_pos_emb) * self.scale
dots = dots + pos_attn

if exists(mask) or exists(context_mask):
    mask = default(mask, lambda: torch.ones(*x.shape[:2], device=device))
    context_mask = (
        default(context_mask, mask)
        if not has_context
        else default(
            context_mask, lambda: torch.ones(*context.shape[:2], device=device)
        )
    )
    mask_value = -torch.finfo(dots.dtype).max
    mask = rearrange(mask, "b i -> b () i ()") * rearrange(
        context_mask, "b j -> b () () j"
    )
    dots.masked_fill_(~mask, mask_value)

attn = dots.softmax(dim=-1)

out = einsum("b h i j, b h j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
out = self.to_out(out)

Transformer-XL

As is well known, $q_i=x_iW_Q$ and $k_j=x_jW_K$. After adding absolute positional encodings $p$ to the inputs, expanding the usual attention score gives

$$q_ik_j^{\top}=x_iW_QW_K^{\top}x_j^{\top}+x_iW_QW_K^{\top}p_j^{\top}+p_iW_QW_K^{\top}x_j^{\top}+p_iW_QW_K^{\top}p_j^{\top}$$

Transformer-XL's modification is simple: replace $p_j$ with the relative position vector $R_{i-j}$, and replace the two occurrences of $p_iW_Q$ with two trainable vectors $u,v$:

$$a_{ij}=x_iW_QW_{K,E}^{\top}x_j^{\top}+x_iW_QW_{K,R}^{\top}R_{i-j}^{\top}+u\,W_{K,E}^{\top}x_j^{\top}+v\,W_{K,R}^{\top}R_{i-j}^{\top}$$

Later refinements mostly build on this formulation, and no longer modify the computation of $V$.

In Transformer-XL (the same encoding is also used in XLNet):

class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer('inv_freq', inv_freq)  # was missing: forward reads self.inv_freq

    def forward(self, pos_seq):
        sinusoid_inp = torch.outer(pos_seq, self.inv_freq)  # outer product of positions and frequencies
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
        return pos_emb[:, None, :]
w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head

r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)       # rlen x n_head x d_head

#### compute attention score
rw_head_q = w_head_q + r_w_bias                                # add the bias u       # qlen x bsz x n_head x d_head
AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))    # content terms        # qlen x klen x bsz x n_head

rr_head_q = w_head_q + r_r_bias                                # add the bias v
BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))     # relative-position terms  # qlen x klen x bsz x n_head
BD = self._rel_shift(BD)

# [qlen x klen x bsz x n_head]
attn_score = AC + BD
attn_score.mul_(self.scale)

Here $u$ and $v$ (r_w_bias and r_r_bias in the code) are two learnable parameters, and $W^R$ is a matrix that projects the sinusoidal relative position $s_{i-j}$ into a position-dependent key vector (r_head_k in the code).

Music transformer

Huang et al. (Music Transformer) later improved on Shaw's relative positional encoding: the relative term for the values is dropped, and the relative logits are computed memory-efficiently with a "skewing" trick,

$$\mathrm{RelativeAttention}=\mathrm{softmax}\!\left(\frac{QK^{\top}+S^{\mathrm{rel}}}{\sqrt{d_h}}\right)V,\qquad S^{\mathrm{rel}}=QR^{\top}$$

where $R$ holds the relative-position embeddings $r_{i-j}$.

In addition, 2009.13658.pdf (arxiv.org) (Improve Transformer Models with Better Relative Position Embeddings) proposes further variants, as do T5 and DeBERTa:

T5

T5 simply adds a learned scalar bias $b_{i-j}$, looked up by (bucketed) relative distance, to the attention logit $e_{ij}$; no relative term interacts with the content at all.

DeBERTa

DeBERTa uses disentangled attention: besides the content-to-content term, it adds content-to-position and position-to-content terms computed from relative-position embeddings.

To summarize: when computing the attention weights (or, in some variants, the final attention output), a term that depends on the relative position is added. That relative-position value is typically computed like this:

# shaw's relative positional embedding
seq = torch.arange(n, device=device)
dist = rearrange(seq, "i -> i ()") - rearrange(seq, "j -> () j")
dist = dist.clamp(-max_pos_emb, max_pos_emb) + max_pos_emb
rel_pos_emb = self.rel_pos_emb(dist).to(q)

Most of the above targets 1-D data such as audio and text.

2-D data

Stand-Alone Self-Attention in Vision Models

SASA

The formula is

$$y_{ij}=\sum_{a,b\in\mathcal{N}_k(i,j)}\mathrm{softmax}_{ab}\!\left(q_{ij}^{\top}k_{ab}+q_{ij}^{\top}r_{a-i,b-j}\right)v_{ab}$$

The relative distance is factorized by dimension: each position $ab\in\mathcal{N}_k(i,j)$ (the $k\times k$ neighborhood around $(i,j)$) yields two offsets, a row offset $a-i$ and a column offset $b-j$.

The row and column offsets are each associated with an embedding, $r_{a-i}$ and $r_{b-j}$, of dimension $\tfrac{1}{2}d_{\mathrm{out}}$; the two are concatenated to form $r_{a-i,b-j}$.

Equivalently, the positional term can be written with trainable parameters $p$ of length $\tfrac{1}{2}d_z$ per offset.

import torch
import torch.nn as nn
import torch.nn.functional as F

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


class SASA_Layer(nn.Module):
    def __init__(self, in_channels, kernel_size=7, num_heads=8, image_size=224, inference=False):
        super(SASA_Layer, self).__init__()
        self.kernel_size = min(kernel_size, image_size)  # receptive field shouldn't be larger than input H/W
        self.num_heads = num_heads
        self.dk = self.dv = in_channels
        self.dkh = self.dk // self.num_heads
        self.dvh = self.dv // self.num_heads

        assert self.dk % self.num_heads == 0, "dk should be divided by num_heads. (example: dk: 32, num_heads: 8)"
        assert self.dv % self.num_heads == 0, "dv should be divided by num_heads. (example: dv: 32, num_heads: 8)"

        self.k_conv = nn.Conv2d(self.dk, self.dk, kernel_size=1).to(device)
        self.q_conv = nn.Conv2d(self.dk, self.dk, kernel_size=1).to(device)
        self.v_conv = nn.Conv2d(self.dv, self.dv, kernel_size=1).to(device)

        # Positional encodings
        self.rel_encoding_h = nn.Parameter(torch.randn(self.dk // 2, self.kernel_size, 1), requires_grad=True)
        self.rel_encoding_w = nn.Parameter(torch.randn(self.dk // 2, 1, self.kernel_size), requires_grad=True)

        # later access attention weights
        self.inference = inference
        if self.inference:
            self.register_parameter('weights', None)

    def forward(self, x):
        batch_size, _, height, width = x.size()

        # Compute k, q, v
        padded_x = F.pad(x, [(self.kernel_size-1)//2, (self.kernel_size-1)-((self.kernel_size-1)//2), (self.kernel_size-1)//2, (self.kernel_size-1)-((self.kernel_size-1)//2)])
        k = self.k_conv(padded_x)
        q = self.q_conv(x)
        v = self.v_conv(padded_x)

        # Unfold patches into [BS, num_heads*depth, horizontal_patches, vertical_patches, kernel_size, kernel_size]
        k = k.unfold(2, self.kernel_size, 1).unfold(3, self.kernel_size, 1)
        v = v.unfold(2, self.kernel_size, 1).unfold(3, self.kernel_size, 1)

        # Reshape into [BS, num_heads, horizontal_patches, vertical_patches, depth_per_head, kernel_size*kernel_size]
        k = k.reshape(batch_size, self.num_heads, height, width, self.dkh, -1)
        v = v.reshape(batch_size, self.num_heads, height, width, self.dvh, -1)

        # Reshape into [BS, num_heads, height, width, depth_per_head, 1]
        q = q.reshape(batch_size, self.num_heads, height, width, self.dkh, 1)

        qk = torch.matmul(q.transpose(4, 5), k)
        qk = qk.reshape(batch_size, self.num_heads, height, width, self.kernel_size, self.kernel_size)

        # Add positional encoding
        qr_h = torch.einsum('bhxydz,cij->bhxyij', q, self.rel_encoding_h)
        qr_w = torch.einsum('bhxydz,cij->bhxyij', q, self.rel_encoding_w)
        qk += qr_h
        qk += qr_w

        qk = qk.reshape(batch_size, self.num_heads, height, width, 1, self.kernel_size*self.kernel_size)
        weights = F.softmax(qk, dim=-1)

        if self.inference:
            self.weights = nn.Parameter(weights)

        attn_out = torch.matmul(weights, v.transpose(4, 5))
        attn_out = attn_out.reshape(batch_size, -1, height, width)
        return attn_out

The code above may not be entirely faithful: it would be more appropriate to look up an embedding indexed by the $i,j$ offset differences.

Rethinking and Improving Relative Position Encoding for Vision Transformer

This is a good paper on relative positional encoding for 2-D image data in attention, and can also be seen as an improvement over SASA above.

[Figure: bias mode vs. context mode relative position encoding]

Previous relative position encoding methods all depend on the input embeddings. This raises a question: can the encoding be made independent of the input?

The paper introduces a bias mode and a context mode of relative position encoding to study this question. The former is independent of the input embeddings, while the latter interacts with the queries, keys, or values; these are the two modes in the figure above.

In bias mode, a bias term is added when computing the attention weight; the bias is a learnable parameter representing the weight of a relative position.

In context mode there are several possible forms; $r$ is still a trainable vector representing the relative position, but it interacts with $Q$ or $K$.

The context mode can also be applied to the value embeddings.

To compute relative positions on the 2-D image plane and define the relative weight $r_{ij}$, two undirected mappings, Euclidean and Quantization, and two directed mappings, Cross and Product, are proposed.

In the Euclidean method, two nearby neighbors with different relative displacements may be mapped to the same index; for example, the 2-D relative positions $(1,0)$ and $(1,1)$ are both mapped to index 1. Assuming nearby neighbors should be kept separate, the Euclidean distance is quantized so that different real distances map to different integers.

The operation $\mathrm{quant}(\cdot)$ maps a set of real numbers $\{0,1,1.41,2,2.24,\dots\}$ to a set of integers $\{0,1,2,3,4,\dots\}$. This method is still undirected.

The direction between pixels also matters for images, so directed mappings are proposed. The Cross method computes an encoding separately along the horizontal and vertical directions and then sums the two.

If the distance along one direction is the same, the Cross method encodes different relative positions into the same embedding, and it also adds extra computation. To be more efficient and capture more directional information, the Product method is designed: each pair of (horizontal, vertical) offsets gets its own bucket index.

[Figure: the Product method for mapping 2-D relative positions to bucket indices]
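To make the bias mode concrete, here is a minimal sketch of a learnable 2-D relative-position bias table indexed by (row offset, column offset) pairs (my own illustration, in the spirit of the bias mode here and of the per-window bias used by Swin below; the class name and shapes are mine):

```python
import torch
import torch.nn as nn


class RelPosBias2d(nn.Module):
    """Bias-mode 2-D relative position encoding: a learned scalar per (dy, dx) offset and head,
    added directly to the attention logits."""

    def __init__(self, window, heads):
        super().__init__()
        self.table = nn.Parameter(torch.zeros((2 * window - 1) ** 2, heads))
        coords = torch.stack(torch.meshgrid(
            torch.arange(window), torch.arange(window), indexing="ij"), dim=-1).reshape(-1, 2)
        rel = coords[:, None] - coords[None, :] + window - 1        # offsets shifted to be >= 0
        self.register_buffer("index", rel[..., 0] * (2 * window - 1) + rel[..., 1])

    def forward(self, attn_logits):                                 # (B, heads, N, N), N = window ** 2
        bias = self.table[self.index]                               # (N, N, heads)
        return attn_logits + bias.permute(2, 0, 1).unsqueeze(0)


logits = torch.randn(2, 4, 49, 49)                                  # a 7x7 window, 4 heads
print(RelPosBias2d(7, 4)(logits).shape)                             # torch.Size([2, 4, 49, 49])
```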

Others

Swin transformer

[2103.14030] Swin Transformer: Hierarchical Vision Transformer using Shifted Windows (arxiv.org)

[2111.09883] Swin Transformer V2: Scaling Up Capacity and Resolution (arxiv.org)

[Figures: Swin Transformer — hierarchical feature maps and shifted-window partitioning]

The challenges in transferring the Transformer from language to vision come from the differences between the two domains: visual entities vary greatly in scale, and pixels in images are far more numerous than words in text.

To address this, Swin proposes a hierarchical Transformer whose representations are computed with shifted windows. The shifted-window scheme restricts self-attention to non-overlapping local windows while still allowing cross-window connections, which brings higher efficiency. The hierarchical architecture is flexible enough to model at various scales and has computational complexity linear in image size (a minimal sketch of the window partitioning follows).
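Here is that sketch of the window mechanism (my own minimal version, assuming the feature map divides evenly into M×M windows): attention is computed only inside each window, and shifting the feature map by M/2 before partitioning provides the cross-window connections.

```python
import torch


def window_partition(x, m):
    """(B, H, W, C) -> (B * H//m * W//m, m*m, C): tokens grouped into non-overlapping m x m windows."""
    b, h, w, c = x.shape
    x = x.view(b, h // m, m, w // m, m, c)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, m * m, c)


x = torch.randn(2, 56, 56, 96)
print(window_partition(x, 7).shape)                      # torch.Size([128, 49, 96])

shifted = torch.roll(x, shifts=(-3, -3), dims=(1, 2))    # shifted-window variant (M/2 = 3 for M = 7)
print(window_partition(shifted, 7).shape)                # torch.Size([128, 49, 96])
```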


Swin Transformer V2

[2111.09883] Swin Transformer V2: Scaling Up Capacity and Resolution (arxiv.org)

[Figure: Swin Transformer V2 block changes compared with V1]

Twins

[2104.13840] Twins: Revisiting the Design of Spatial Attention in Vision Transformers (arxiv.org)

[Figures: Twins-PCPVT and Twins-SVT architectures]

This work revisits the design of spatial attention and shows that a carefully designed yet simple spatial attention mechanism performs favorably against state-of-the-art schemes. Two Vision Transformer architectures are proposed, Twins-PCPVT and Twins-SVT. The proposed architectures are efficient and easy to implement, involving only matrix multiplications that are highly optimized in modern deep learning frameworks, and they achieve strong performance on a wide range of vision tasks including image-level classification.

More broadly, many other attention variants for 2-D data have since appeared, such as spatial attention and channel attention, but the basic ideas are similar.

References

  1. Relative position embedding - 知乎 (zhihu.com)
  2. [1803.02155] Self-Attention with Relative Position Representations (arxiv.org)
  3. Relative Positional Embedding | Chao Yang (placebokkk.github.io)
  4. Improve Transformer Models with Better Relative Position Embeddings (aclanthology.org)
  5. 让研究人员绞尽脑汁的Transformer位置编码 - 知乎 (zhihu.com)
  6. 《A survey of the Vision Transformers and its CNN-Transformer based Variants》第一期 - 知乎 (zhihu.com)