llama源码学习 | model.py


最开始我们还是先看引入的模块和包
import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
ColumnParallelLinear,
ParallelEmbedding,
RowParallelLinear,
from torch import nn
这部分代码导入了所需的Python模块。
math
、
dataclasses
、
typing
、
torch
和
torch.nn.functional
是Python的标准库,用于基本的数学运算、数据类定义、类型注解和PyTorch的函数接口。
fairscale.nn.model_parallel.initialize
和
fairscale.nn.model_parallel.layers
是FairScale库的模块,用于实现模型并行化。
from dataclasses import dataclass
: 这句代码是从Python的dataclasses
模块中导入dataclass
装饰器。dataclass
是Python 3.7及以上版本中的一个功能,它可以自动添加特殊方法(如__init__()
和__repr__()
)到类中,使得定义类变得更加简洁。
from typing import Any, Optional, Tuple
: 这句代码是从Python的typing
模块中导入Any
,Optional
和Tuple
类型。typing
模块是Python 3.5及以上版本中的一个功能,它提供了对Python代码的静态类型支持。Any
表示任何类型,Optional
表示可选类型(即某个类型或者None
),Tuple
表示元组类型。
import fairscale.nn.model_parallel.initialize as fs_init
: 这句代码是导入fairscale
库中的model_parallel.initialize
模块,并将其重命名为fs_init
。fairscale
是一个优化和扩展PyTorch的库,它提供了一些高级功能,如模型并行化、优化器状态分片等。在这里,model_parallel.initialize
模块可能包含了一些初始化模型并行化的函数或者类。
fairscale.nn.model_parallel.layers
: 这是fairscale
库中的一个模块,它包含了一些并行化的神经网络层,如ColumnParallelLinear
,ParallelEmbedding
,RowParallelLinear
等。这些层可以在多个设备上并行计算,从而加速模型的训练。
torch.nn.functional
: 这是PyTorch库中的一个模块,它包含了一些函数,这些函数对应了神经网络中的操作,如卷积、池化、非线性激活等。这些函数的特点是,它们不包含状态(即没有权重),因此它们既可以作为独立的函数使用,也可以作为其他层的一部分使用。
FairScale是一个Python库,它提供了一系列工具和功能,用于优化和扩展PyTorch的能力。FairScale的主要目标是使研究人员和工程师能够更容易地使用先进的分布式训练和优化技术,从而提高模型训练的效率和规模。
FairScale的主要功能包括:
1. 模型并行化:FairScale提供了一种高效的模型并行化方法,可以将大型模型分布在多个GPU上进行训练。这使得训练更大、更复杂的模型成为可能,而不需要增加单个GPU的内存需求。
2. 优化器状态分片:在训练大型模型时,优化器(如Adam或SGD)的状态可能会占用大量的GPU内存。FairScale提供了一种优化器状态分片(Optimizer State Sharding,OSS)的技术,可以将优化器的状态分布在多个GPU上,从而减少单个GPU的内存需求。
3. 梯度累积:在训练大型模型时,可能需要在多个小批量(mini-batch)上累积梯度,然后再进行一次参数更新,以减少GPU内存的需求。FairScale提供了一种梯度累积的方法,可以在不增加代码复杂性的情况下实现这一功能。
4. 混合精度训练:FairScale支持混合精度训练,这是一种使用较低精度(如半精度浮点数)进行计算,但使用较高精度(如单精度浮点数)进行参数更新的方法。混合精度训练可以减少GPU内存的需求,并可能加速模型的训练。
在实现模型并行化时,FairScale提供了`fairscale.nn.model_parallel`模块,其中包含了一些并行化的神经网络层,如`ColumnParallelLinear`,`ParallelEmbedding`,`RowParallelLinear`等。这些层可以在多个设备上并行计算,从而加速模型的训练。例如,`ColumnParallelLinear`是一个线性层,它将输入的特征分布在多个设备上,并在每个设备上进行部分计算,然后再将结果聚合起来。这样,即使模型的大小超过单个设备的内存容量,也可以进行训练。
@dataclass
class ModelArgs:
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
n_kv_heads: Optional[int] = None
vocab_size: int = -1 # defined later by tokenizer
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
norm_eps: float = 1e-5
max_batch_size: int = 32
max_seq_len: int = 2048
这部分代码定义了一个数据类
ModelArgs
,用于存储模型的参数。这些参数包括模型的维度、层数、头数、词汇表大小、批量大小、序列长度等。
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
这部分代码定义了一个名为
RMSNorm
的类,它是一个标准化层。在神经网络中,标准化层可以帮助改善模型的训练速度和性能。
RMSNorm
类是一种标准化层,它实现了Root Mean Square Layer Normalization(RMSNorm)。标准化是深度学习中常用的一种技术,可以帮助加速模型的训练,并提高模型的性能。
在__init__方法中,首先定义了一个小的常数eps,用于防止除以零的错误。 然后定义了一个可学习的权重weight,它的形状与输入的维度相同。 在_norm方法中,首先计算输入x的平方,然后计算其在最后一个维度上的均值, 然后加上eps,然后对结果取平方根的倒数,最后用这个值乘以x,得到标准化后的结果。 在forward方法中,首先调用_norm方法对输入x进行标准化, 然后将结果乘以权重weight,得到最终的输出。
在RMSNorm中,`weight`是一个可学习的参数,它的形状与输入的维度相同。这个权重参数在标准化后的数据上进行缩放,可以增加模型的表达能力。标准化操作可以将输入数据的均值变为0,方差变为1,这样可以防止梯度消失或爆炸,有助于模型的训练。然而,这种操作可能会改变数据的原始分布,导致一些有用的信息丢失。为了解决这个问题,我们在标准化后的数据上乘以一个可学习的权重参数,这样可以恢复数据的原始分布,同时保留标准化带来的优点。这种方法在一些标准化技术中是常见的,例如Batch Normalization和Layer Normalization也有类似的操作。在这些方法中,除了标准化操作,还会有一个可学习的缩放参数和一个可学习的偏移参数,它们可以在标准化后的数据上进行缩放和偏移,以增加模型的表达能力。
标准化层(Normalization Layer)是深度学习模型中的一种常见层,它的主要作用是对输入数据进行标准化处理,使得数据的均值为0,方差为1。这样可以加快模型的收敛速度,提高模型的训练稳定性,并有助于防止梯度消失或梯度爆炸的问题。
标准化层的实现通常包括以下步骤:
1. 计算输入数据的均值和方差。
2. 使用这个均值和方差将输入数据进行标准化,即减去均值,然后除以方差的平方根。
3. 对标准化后的数据进行缩放和平移,这两个操作的参数是可学习的。
在PyTorch中,有多种类型的标准化层,如批标准化(Batch Normalization)、层标准化(Layer Normalization)、实例标准化(Instance Normalization)等,它们的区别主要在于计算均值和方差时所使用的数据范围不同。
例如,批标准化(Batch Normalization)是最常见的一种标准化方法,它在每个特征维度上,对一个批次(batch)中的所有数据计算均值和方差。在PyTorch中,可以使用`torch.nn.BatchNorm1d`、`torch.nn.BatchNorm2d`、`torch.nn.BatchNorm3d`等类来实现批标准化。另一种常见的标准化方法是层标准化(Layer Normalization),它在单个数据样本上,对所有特征维度计算均值和方差。在PyTorch中,可以使用`torch.nn.LayerNorm`类来实现层标准化。使用的是RMS标准化(RMS Normalization),这是一种新的标准化方法,它的实现方式与上述方法类似,但在计算标准化参数时使用了均方根(Root Mean Square)方法。
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
这部分代码定义了一个函数
precompute_freqs_cis
,用于预计算频率。这个函数在模型中用于实现Rotary Positional Embedding。
precompute_freqs_cis
函数用于预计算旋转位置编码(Rotary Positional Encoding)中的频率项。这是一种新的位置编码方法,用于在自注意力机制中引入序列中的位置信息。
在自然语言处理任务中,位置信息是非常重要的,因为词语在句子中的位置会影响其语义。例如,在"我爱你"和"你爱我"这两个句子中,虽然词语相同,但是由于位置不同,其意义完全不同。因此,我们需要一种方法来将位置信息引入模型中。在这个函数中,首先计算了一个频率向量
freqs
,然后使用
torch.outer
函数计算了一个频率矩阵,然后使用
torch.polar
函数将频率矩阵转换为复数形式,得到了预计算的频率项
freqs_cis
。
旋转位置编码是一种新的位置编码方法,它的主要思想是将位置信息编码为复数,然后通过复数乘法将位置信息融入到模型中。这种方法可以有效地引入位置信息,而且计算效率高,不会增加额外的参数。
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
这部分代码定义了一个函数
reshape_for_broadcast
,用于调整张量的形状以便进行广播操作。这个函数的目的是调整
freqs_cis
的形状,使其可以与
x
进行广播操作。它首先获取
x
的维度数
ndim
,然后创建一个新的形状
shape
,这个形状的每个维度与
x
的对应维度相同,除了第二维和最后一维,这两个维度的大小与
freqs_cis
的形状相同。然后,它使用
view
方法将
freqs_cis
的形状调整为新的形状。
广播(Broadcasting)是一种强大的机制,它允许NumPy和其他类似的库在执行算术运算时处理具有不同形状的数组。这在深度学习和科学计算中非常常见,因为我们经常需要对不同形状的张量进行操作。广播的基本原则是:在两个数组的形状不完全相同的情况下,如果从尾部开始比较,维度相同或者其中一个是1,那么就可以进行广播操作。广播操作会自动扩展维度为1的维度以匹配另一个数组的维度。例如,假设我们有一个形状为(3,)的数组a和一个形状为(3,3)的数组b,我们想要将a和b的每一行相加。由于a和b的最后一个维度是匹配的,因此我们可以使用广播将a扩展为(3,3),然后进行相加操作。在PyTorch中,广播操作是自动进行的,你不需要显式地调用任何函数。例如,如果你有两个形状不同但兼容的张量,你可以直接将它们相加,PyTorch会自动进行广播。
广播操作可以使你的代码更简洁,更易读,同时也可以提高计算效率,因为它避免了显式地复制数据。
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
这部分代码定义了一个函数
apply_rotary_emb
,用于应用Rotary Positional Embedding。
apply_rotary_emb
函数是用来应用旋转位置编码(Rotary Positional Encoding)的。这是一种新的位置编码方法,用于在自注意力机制中引入序列中的位置信息。
在自然语言处理任务中,位置信息是非常重要的,因为词语在句子中的位置会影响其语义。例如,在"我爱你"和"你爱我"这两个句子中,虽然词语相同,但是由于位置不同,其意义完全不同。因此,我们需要一种方法来将位置信息引入模型中。
旋转位置编码是一种新的位置编码方法,它的主要思想是将位置信息编码为复数,然后通过复数乘法将位置信息融入到模型中。这种方法可以有效地引入位置信息,而且计算效率高,不会增加额外的参数。在这个函数中,首先将输入的
xq
和
xk
转换为复数形式,然后调用
reshape_for_broadcast
函数将
freqs_cis
调整为可以进行广播操作的形状,然后通过复数乘法将位置信息融入到
xq
和
xk
中,最后将结果转换回实数形式
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
这部分代码定义了一个函数
repeat_kv
,用于重复张量。
repeat_kv
函数是用来重复张量的。在某些情况下,我们可能需要将一个张量在某个维度上重复多次,以匹配另一个张量的形状。这在深度学习中是很常见的,因为我们经常需要对不同形状的张量进行操作。
在这个函数中,
n_rep
参数表示要重复的次数。如果
n_rep
为1,那么直接返回输入的张量
x
。否则,使用
torch.repeat_interleave
函数将
x
在第2个维度(索引从0开始)上重复
n_rep
次。
在这个函数中,首先获取输入张量
x
的形状,然后判断
n_rep
是否为1。如果
n_rep
为1,那么直接返回
x
。否则,将
x
在第4个维度上增加一个维度,然后使用
expand
函数将
x
在第4个维度上扩展
n_rep
次,然后使用
reshape
函数将
x
的形状调整为
(bs, slen, n_kv_heads * n_rep, head_dim)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
self.wk = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
self.wv = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=False,
input_is_parallel=True,
init_method=lambda x: x,
self.cache_k = torch.zeros(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
).cuda()
self.cache_v = torch.zeros(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
).cuda()
这部分代码定义了一个名为
Attention
的类,它是一个注意力层。在Transformer模型中,注意力层用于计算输入的每个元素对输出的每个元素的影响。
Attention
类是实现自注意力机制(Self-Attention)的关键部分。自注意力机制是Transformer模型的核心组成部分,它允许模型在处理序列数据时,对序列中的每个元素都进行注意力分配,从而捕捉序列中的长距离依赖关系。
在
__init__
函数中,首先初始化了一些参数,包括注意力头的数量(
n_heads
),模型并行的大小(
model_parallel_size
),本地注意力头的数量(
n_local_heads
),本地键/值头的数量(
n_local_kv_heads
),重复的次数(
n_rep
),以及每个头的维度(
head_dim
)。
然后,定义了四个线性层,分别用于计算查询(
wq
),键(
wk
),值(
wv
)和输出(
wo
)。这四个线性层都使用了
fairscale
库中的并行线性层,以实现模型并行。
在
forward
函数中,首先获取输入张量
x
的形状,然后通过线性层计算查询,键和值。
然后,将查询,键和值的形状调整为适合进行注意力计算的形状,然后应用旋转位置编码。
接着,定义了两个缓存,用于存储键和值。这是为了实现自注意力机制中的缓存机制,可以提高计算效率。
然后,从缓存中获取键和值,然后重复键和值以匹配查询的形状。
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
self.w2 = RowParallelLinear(
hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
self.w3 = ColumnParallelLinear(
dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
这部分代码定义了一个名为
FeedForward
的类,它是一个前馈网络层。在Transformer模型中,前馈网络层用于对注意力层的输出进行进一步的处理。
FeedForward
类是实现前馈神经网络的部分。在Transformer模型中,每个Transformer块都包含一个自注意力层和一个前馈神经网络层。前馈神经网络层是一个全连接的神经网络,它在每个位置上独立地处理输入。
FeedForward
类是实现前馈神经网络的部分。在Transformer模型中,每个Transformer块都包含一个自注意力层和一个前馈神经网络层。前馈神经网络层是一个全连接的神经网络,它在每个位置上独立地处理输入。
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
这部分代码定义了一个名为
TransformerBlock
的类,它是一个Transformer层。在Transformer模型中,每一层都包括一个注意力操作和一个前馈网络操作。
class Transformer(nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
self.params = params
self.vocab_size = params.vocab_size
self.n_layers = params.n_layers
self.tok_embeddings = ParallelEmbedding(
params.vocab_size, params.dim, init_method=lambda x: x
self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = ColumnParallelLinear(
params.dim, params.vocab_size, bias=False, init_method=lambda x: x
self.freqs_cis = precompute_freqs_cis(
self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
这部分代码定义了一个名为
Transformer
的类,它是整个模型的主体。这个类包括词嵌入层、多个Transformer层、一个标准化层和一个输出层。这是整个Transformer模型的主体。它的构造函数接收一个包含模型参数的
ModelArgs
对象。在构造函数中,它初始化了词嵌入层(
self.tok_embeddings
),一系列的Transformer块(
self.layers
),一个归一化层(
self.norm
),和一个输出层(
self.output
)。在前向传播函数(
forward
)中,它首先通过词嵌入层处理输入的词标记,然后通过每一个Transformer块处理结果,接着通过归一化层处理结果,最后通过输出层得到最终的输出。
@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
mask = None
if seqlen > 1:
mask = torch.full(
(1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)