llama3是meta开源的大模型,在开源大模型中占着重要地位,在这之前可能是Mistral,目前也有gemma2,Qwen2以及微软的Phi3等.
llama表现很不错,LLM Leaderboard (toloka.ai),很多模型都是在它基础上微调得到的.
这里将llama介绍分为位置编码,transformer层,ffn层以及其中的norm的改进.
llama3相比于llama2,上下文窗口增大,tokenizer从sentencepiece变为tiktoken,token数也增多了.
| Feature | LLaMa 2 | LLaMa 3 | 
|---|---|---|
| Training Data Size | 2 trillion tokens | 15 trillion tokens (7x larger) | 
| Context Window | 4K tokens | 8K tokens | 
| Focus Area | General language understanding | Nuance, context, complex tasks | 
| False Refusal Rate | Higher | Lower | 
| Response Diversity | Lower | Higher | 
| Code Generation | Limited capability | Enhanced capability | 
旋转位置编码
旋转位置嵌入(RoPE)是一种用于基于transformer模型的技术,可将位置信息纳入标记表示中。与依赖正弦和余弦函数的传统位置编码不同,RoPE 利用旋转矩阵来编码绝对和相对位置信息。这种方法的提出是为了提高位置嵌入在transformer中的有效性。Meta的LLaMA、清华的ChatGLM都采用了RoPE
在llama3中代码如下:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute the complex rotary factors ``exp(i * m * w_k)`` for RoPE.

    Args:
        dim: Size of the feature axis being rotated; features are consumed in
            pairs, so ``dim // 2`` frequencies are produced.
        end: Number of positions ``m = 0 .. end - 1`` to precompute.
        theta: Base of the geometric frequency progression.

    Returns:
        A ``complex64`` tensor of shape ``(end, dim // 2)`` whose entry
        ``[m, k]`` has unit magnitude and phase ``m * theta ** (-2k / dim)``.
    """
    # Per-pair inverse frequencies w_k = theta^(-2k/dim), k = 0 .. dim//2 - 1.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # Position indices, created on the same device as the frequencies.
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    # Outer product gives the rotation angle for every (position, pair).
    freqs = torch.outer(t, freqs)
    # Unit-modulus complex numbers with those angles: e^{i * angle}.
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
# One rotary table shared by all heads: the rotated axis is the per-head
# dimension, and twice max_seq_len positions are precomputed.
self.freqs_cis = precompute_freqs_cis(
    dim=params.dim // params.n_heads,
    end=params.max_seq_len * 2,
    theta=params.rope_theta,
)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)
def apply_rotary_emb(
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key features by their position-dependent phases.

    Adjacent feature pairs on the last axis are treated as the real and
    imaginary parts of a complex number and multiplied by ``freqs_cis``.

    Args:
        xq: Query tensor; assumed (bs, seq_len, n_heads, head_dim) as built
            by the attention layer in this file.
        xk: Key tensor with the same layout (its head count may differ).
        freqs_cis: Complex factors of shape (seq_len, head_dim // 2).

    Returns:
        The rotated ``(xq, xk)``, each cast back to its input dtype.
    """

    def _to_complex(t: torch.Tensor) -> torch.Tensor:
        # (..., head_dim) -> (..., head_dim // 2, 2) -> complex (..., head_dim // 2)
        paired = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(paired)

    q_complex = _to_complex(xq)
    k_complex = _to_complex(xk)

    # Shaped from the query view; size-1 axes let it broadcast over keys too.
    rotation = reshape_for_broadcast(freqs_cis, q_complex)

    # Complex multiply performs the rotation; view_as_real + flatten(3) merges
    # the trailing (head_dim // 2, 2) axes back into head_dim (float32).
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
# Rotate queries and keys by position; only xq and xk are passed, so values
# are left unrotated.
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
计算freqs_cis,其是一个复数,旋转编码通常应用在q和k上.
下面是另一种实现:
import torch
import torch.nn as nn
class RotaryPositionalEmbeddings(nn.Module):
    """Rotary position embedding using cached real-valued cos/sin tables.

    Feature ``i`` is paired with feature ``i + d // 2`` and each pair is
    rotated by an angle ``m * base ** (-2i / d)``, where ``m`` is the index
    along axis 0 of the input. The input is indexed with four axes and its
    last axis must match ``d``.
    """

    def __init__(self, d: int, base: int = 10_000):
        super().__init__()
        self.base = base
        self.d = d
        # Built lazily on the first forward pass; grown when a longer input
        # arrives.
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        """(Re)build the cos/sin tables unless they already cover ``x``."""
        seq_len = x.shape[0]
        already_covered = (
            self.cos_cached is not None
            and seq_len <= self.cos_cached.shape[0]
        )
        if already_covered:
            return

        # Inverse frequencies 1 / base^(2i/d), one per feature pair.
        exponents = torch.arange(0, self.d, 2).float() / self.d
        powers = (self.base ** exponents).to(x.device)
        inv_freq = 1. / powers

        # Angle table m * inv_freq with shape (seq_len, d // 2).
        positions = torch.arange(seq_len, device=x.device).float()
        angles = positions[:, None] * inv_freq[None, :]

        # Duplicate so both members of each (i, i + d//2) pair share an angle.
        full_angles = torch.cat([angles, angles], dim=1)

        # Stored as (seq_len, 1, 1, d) so they broadcast over axes 1 and 2.
        self.cos_cached = full_angles.cos()[:, None, None, :]
        self.sin_cached = full_angles.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        """Return ``[-x_second_half, x_first_half]`` along the last axis."""
        half = self.d // 2
        first_half = x[:, :, :, :half]
        second_half = x[:, :, :, half:]
        return torch.cat([-second_half, first_half], dim=-1)

    def forward(self, x: torch.Tensor):
        """Apply the rotation: ``x * cos + neg_half(x) * sin``."""
        self._build_cache(x)
        swapped = self._neg_half(x)
        seq_len = x.shape[0]
        cos = self.cos_cached[:seq_len]
        sin = self.sin_cached[:seq_len]
        return (x * cos) + (swapped * sin)
if __name__ == '__main__':
    # Demo: three positions with four features each, expanded to the 4-D
    # layout (positions, 1, 1, features) that the module indexes.
    rows = [[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]]
    x = torch.tensor(rows, dtype=torch.float)
    x = x.unsqueeze(1).unsqueeze(1)
    rope = RotaryPositionalEmbeddings(4)
    p = rope(x)
    print(p)
RMSNorm
另一种规范化的方式(Root Mean Square Layer Normalization),该方法在2019年的论文 arXiv:1910.07467 (arxiv.org) 中提出:
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    Each vector on the last axis is divided by ``sqrt(mean(x ** 2) + eps)``
    and then scaled element-wise by ``weight``. No mean is subtracted and no
    bias is added.

    Args:
        dim: Size of the last axis, which is also the length of ``weight``.
        eps: Constant added to the mean square before the reciprocal square
            root, keeping the result finite for all-zero inputs.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Gain starts at one, so a fresh module only normalises.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last axis."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """Normalise in float32, cast back to ``x``'s dtype, apply the gain."""
        # The statistics are computed in float32 regardless of the input
        # dtype; only the normalised result is cast back.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
在llama3中,有三个地方使用RMSNorm:attention之前、ffn之前,以及所有transformer layer之后:
# Pre-norm residual connections: each sub-layer receives the normalised
# input, and its output is added back onto the un-normalised stream.
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))
GroupQuery Attention&&KVcache
GroupQuery:query的head数是kv的head数的若干倍.
KV cache:在自回归生成新的token时,之前token的K和V不会改变,不需要重复计算,所以只需要把已计算的值存下来复用即可.这是用显存换计算量的操作;而GroupQuery减少了kv的head数,从而降低KV cache的显存占用.
class Attention(nn.Module):
    """Attention layer with grouped key/value heads and a key/value cache.

    Queries use ``args.n_heads`` heads; keys and values use
    ``args.n_kv_heads`` heads (falling back to ``args.n_heads`` when it is
    ``None``). Head counts are divided by the model-parallel world size to
    obtain the per-rank ("local") counts, and every local key/value head is
    repeated ``n_rep`` times so it lines up with the local query heads.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        # Number of query heads that share one key/value head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        # NOTE(review): init_method is the identity, so no custom weight
        # initialisation is requested here - presumably weights come from a
        # checkpoint; confirm against the loader.
        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        # Key/value projections are narrower than the query projection
        # whenever n_kv_heads < n_heads.
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )
        # Key/value caches, preallocated for the largest batch and sequence
        # and placed on the CUDA device at construction time.
        # Shape: (max_batch_size, max_seq_len, n_local_kv_heads, head_dim).
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
            self,
            x: torch.Tensor,
            start_pos: int,
            freqs_cis: torch.Tensor,
            mask: Optional[torch.Tensor],
    ):
        """Attend from the new tokens in ``x`` to all cached positions.

        Args:
            x: Input of shape (bsz, seqlen, dim).
            start_pos: Cache index at which this chunk's keys and values are
                written; positions ``[0, start_pos)`` are read from the cache.
            freqs_cis: Rotary factors handed to ``apply_rotary_emb``.
            mask: Optional tensor added to the attention scores before the
                softmax; ``None`` skips masking.

        Returns:
            The result of the output projection ``wo`` applied to the
            concatenated per-head attention outputs.
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        # Split the projected features into heads.
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        # Positional rotation is applied to queries and keys only.
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
        # Match the caches to the device and dtype of the activations.
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)
        # Store the new keys/values, then read back everything up to the end
        # of this chunk.
        self.cache_k[:bsz, start_pos: start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos: start_pos + seqlen] = xv
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]
        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(
            keys, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(
            values, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(
            1, 2
        )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        # Scaled dot-product scores.
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        # Softmax runs in float32 and is cast back to the query dtype.
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        # Merge the heads back into a single feature axis.
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
Transformer
1  | class Transformer(nn.Module):  | 
silu激活函数
1  | class FeedForward(nn.Module):  | 

