LoRa 大语言模型的低阶适应

LoRa 大语言模型的低阶适应

本文是LoRA: Low-Rank Adaptation of Large Language Models论文的学习记录

LoRA用来降低大语言模型下游任务训练的算力及内存资源需求量,降低预训练大模型产品化落地的成本。

Lora原理示意图

对预训练模型的下游任务微调,一般采用全参数微调或者固定预训练参数并添加MLP的适配层,但这两种方法都有局限性

全参数微调:对像GPT3这样1750亿参数的模型进行全参数微调,需要的算力资源过大,影响模型下游任务推广。

添加适配层:增加的适配层影响模型部署后的推理速度。

1 低秩矩阵

深度学习模型可以简化为一系列的矩阵乘法,其中 W_0\in R^{d\times k} 是表示预训练模型的权重矩阵。

预训练语言模型中的权值矩阵可以当做一个低秩(Low Rank)矩阵(此句可能不准确)。这里秩是矩阵的线性无关列/行的最大数量,低秩表示秩小于矩阵的行数/列数。如果一个矩阵的秩等于矩阵的行数或者列数,这个矩阵称为满秩(Full Rank)矩阵。他们的特点如下

低秩矩阵:低秩矩阵可以使用两个较小矩阵相乘表示,可用于数据压缩。从低秩矩阵的这个特点看加入到一个参数极大的模型中,就像为模型增加一个可学习的正则化模块。

满秩矩阵:对于线性方程组满秩矩阵可用于线性代数求解。

假设 \Delta W预训练语言模型微调需要学习的参数矩阵,那假设他也应该是一个低秩矩阵,进一步将他分解为B和A两个小矩阵的乘积,原矩阵维度为(d,k),矩阵B维度为(d,r)矩阵A维度为(r,k)其中r的大小远小于d或者k。

W_0+\Delta W = W_0 + BA ,where B\in R^{d\times r},A\in R^{r\times k}

在模型微调时固定预训练的参数,只更新B和A这两个小矩阵。此时如果输入x完成一次前向计算

h=W_0x+\Delta Wx = W_0x + BAx

到这里我们可以看出,Lora对现有大模型微调方法的参数多训练成本高和添加适配层增加模型部署延迟问题是一个很好的解决方案。首先模型微调针对比预训练模型参数矩阵小的多B、A矩阵,其次在模型推理时可以先在模型初始化阶段将\Delta W 矩阵加到预训练模型参数矩阵,这样推理时并不会消耗额外时间。如果资源足够或者对推理时间不敏感

2 一些结论

初始化矩阵B和矩阵A时,B初始化为0,A进行高斯分布的随机初始化。

选择低秩矩阵时,选择多个待调矩阵并提供较小r的策略比选择单一矩阵并提供较大r的策略好。

首先在r的选择上,取r=1已经表现足够优秀。

3 代码实现

在代码库github.com/microsoft/Lo的README.md描述段快速上手流程。简略描述下

1、构造用来微调的lora层

import loralib as lora
# Add a pair of low-rank adaptation matrices with rank r=16
layer = lora.Linear(in_features, out_features, r=16)

lora.Linear继承自Pytorch的nn.Linear。这里是建立一个Linear线性层。lora.Linear在构造时根据输入的r,创建lora_A和lora_B,然后计算一个scaling用来缩放参数。默认

if r > 0:
    self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
    self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
    self.scaling = self.lora_alpha / self.r
    # Freezing the pre-trained weight matrix
    self.weight.requires_grad = False
self.reset_parameters()

接着通过reset_parameters()函数使用kaiming_uniform初始化lora_A,使用0初始化loar_B

nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)

lora.Linear在进行前向推理时,先计算元参数结果,再叠加Lora矩阵计算结果

def forward(self, x: torch.Tensor):
    def T(w):
        return w.T if self.fan_in_fan_out else w
    if self.r > 0 and not self.merged:
        result = F.linear(x, T(self.weight), bias=self.bias)
        if self.r > 0:
            result += (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return result
    else:
        return F.linear(x, T(self.weight), bias=self.bias)

2、仅将模型中的lora结构设置为可训练

model = BigModel()
# This sets requires_grad to False for all parameters without the string "lora_" in their names
lora.mark_only_lora_as_trainable(model)
# Training loop
for batch in dataloader:

mark_only_lora_as_trainable的实现主要通过将model.named_parameters()中的带'loar_'字段参数的requires_grad梯度更新标志位设置为True,非'lora_'参数的requires_grad梯度更新标志位设置为False实现。

3、保存checkpoint时,仅保存LoRA参数

torch.save(lora.lora_state_dict(model), checkpoint_path)

lora_state_dict的实现主要通过过滤model.state_dict()参数中的'loar_'字段实现。

4、加载模型参数时,先加载原模型参数再加载LoRA参数,并将加载模型参数命令的strict参数设置为False

# Load the pretrained checkpoint first
model.load_state_dict(torch.load('ckpt_pretrained.pt'), strict=False)
# Then load the LoRA checkpoint
model.load_state_dict(torch.load('ckpt_lora.pt'), strict=False)

参考

github.com/microsoft/Lo

HELLO七仔:Transformer模型

发布于 2023-04-06 17:18・IP 属地四川