Low-Rank Adaptation (LoRA): A Parameter-Efficient Fine-Tuning Technique for LLMs
Large language models (LLMs), such as GPT-3 and LaMDA, are at the forefront of natural language processing (NLP). Trained on terabytes of text data, they generate human-like text and power multiple applications, including chatbots, virtual assistants, and search engines.
This initial training process, where the LLM learns general language representations from a massive dataset, is computationally intensive and requires significant resources. Once pre-trained, the model can be adapted for specific tasks through fine-tuning.
Fine-tuning involves training the model on a smaller, task-specific dataset to improve its performance on a particular task. However, due to the billions of parameters these models contain, standard fine-tuning can still be computationally expensive and resource-intensive.
This challenge has led to the development of parameter-efficient fine-tuning (PEFT) techniques. PEFT aims to adapt LLMs by modifying only a small number of parameters while reducing computational costs and preserving the model's original capabilities.
Low-Rank Adaptation (LoRA) stands out for its efficiency among PEFT techniques. It freezes the pre-trained weights and introduces trainable rank decomposition matrices into each layer of the transformer architecture.
Alan Ritter, an associate professor at Georgia Tech, says LoRA has democratized LLM training by giving more people the ability to fine-tune larger models.
An illustration of regular fine-tuning (left) and LoRA fine-tuning (right) | Source
This article will explain why PEFT techniques, and LoRA in particular, are essential, how LoRA works, and how to implement it using PyTorch.
Why Low-Rank Adaptation?
Retraining methods, where all pre-trained model parameters are updated, come with significant drawbacks:
- High Computational Cost: Updating billions of parameters, like the 175 billion in GPT-3, requires immense computational resources, making fine-tuning time-consuming and expensive.
- Storage Requirements: Storing the fine-tuned models requires substantial storage capacity, which can be a limiting factor.
- Risk of Overfitting: Fine-tuning all parameters, especially when using limited task-specific data, can make the model more prone to overfitting.
In contrast, parameter-efficient fine-tuning and transfer-learning techniques update only a small subset of the model's parameters.
Low-Rank Adaptation refines this approach further by introducing a lightweight mechanism that reduces the number of parameters that need to be updated during fine-tuning (a worked example of the savings follows the list below).
This leads to several key benefits:
- Faster Training and Inference: With fewer parameters to update, LoRA enables faster training and inference speeds compared to full retraining.
- Lower Storage Requirements: The reduced parameter count translates to lower storage needs for the fine-tuned model.
- Reduced Overfitting Risk: Limiting the number of trainable parameters constrains the model's capacity to memorize the training data, which helps mitigate overfitting.
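To make these savings concrete, consider a single 4,096 × 4,096 weight matrix (an illustrative size, not taken from any particular model). Full fine-tuning updates all 4,096 × 4,096 ≈ 16.8 million entries, whereas LoRA with rank r = 8 trains only 8 × 4,096 + 4,096 × 8 = 65,536 parameters, roughly 0.4% of the original.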
Understanding LoRA
Before we get into LoRA, it's important to grasp the ideas of matrix rank and low-rank matrices. Knowing these concepts will help us understand how LoRA makes fine-tuning more efficient.
What Is the Rank of a Matrix?
A matrix is an array of numbers arranged in rows and columns. The rank of a matrix refers to the maximum number of linearly independent rows or columns within it.
- Linear independence means that none of the rows (or columns) can be expressed as a linear combination of the others.
Consider the matrix 𝐴:

𝐴 =
[ 1  2  3 ]
[ 2  4  6 ]
[ 3  6  9 ]

In this matrix,
- The second row is 2 times the first row, and the third is 3 times the first.
- Similarly, the second and third columns are multiples of the first column.
Since all rows and columns are linearly dependent, the matrix rank is 1, as only one row or column carries independent information.
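You can verify the rank numerically. Here is a quick sketch using NumPy (separate from the PyTorch code used later in this article):

import numpy as np

# The example matrix from above: every row and column is a multiple of the first
A = np.array([[1, 2, 3],
              [2, 4, 6],
              [3, 6, 9]])

print(np.linalg.matrix_rank(A))  # prints 1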
Now, let's look at the concept of lower-rank matrices.
Low-Rank Matrices
A matrix is considered low-rank if its rank is less than the smaller of its number of rows and columns. For a matrix 𝐴 of size 𝑚 × 𝑛, a rank less than min(𝑚, 𝑛) indicates a low-rank matrix.
In the above example, 𝐴 is a 3×3 matrix with a rank of 1. Since its rank is less than 3, it is considered a low-rank matrix.
How LoRA Works
LoRA introduces efficiency in fine-tuning by using the concept of low-rank decomposition. In standard fine-tuning, the weight update for a pre-trained model is represented as:

W_new = W_pretrained + ΔW

Here:
- W_pretrained: The original weight matrix.
- ΔW: The weight update matrix learned during fine-tuning.
Instead of directly learning and applying the full weight-update matrix ΔW, LoRA approximates it as the product of two smaller matrices (a low-rank decomposition), 𝐴 and 𝐵:
- 𝐴: This matrix has dimensions r × d, where r is the rank (a hyperparameter) and d is the hidden dimension of the layer. It's initialized with random values from a Gaussian distribution.
- 𝐵: This matrix has dimensions d × r and is initialized with zeros, so the update B·A starts at zero and the model initially behaves exactly like the pre-trained one.
Therefore, the updated weights become:

W_new = W_pretrained + B·A

(In practice, the product B·A is scaled by a factor α/r, as you will see in the implementation below.)
The rank 𝑟 controls the expressiveness of the adaptation:
- Lower 𝑟: Fewer trainable parameters, more memory-efficient.
- Higher 𝑟: Greater flexibility in updating the weights.
A and B Matrices Decomposition with Different Rank (r) | Source
During fine-tuning, only the matrices A and B are trained. This drastically reduces the number of trainable parameters compared to standard fine-tuning, where all the weights in the LLM would be updated.
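Before moving to the full implementation, here is a minimal sketch of the decomposition itself in PyTorch; the hidden dimension d and rank r are illustrative values of our own choosing:

import torch

d, r = 1024, 4                # hidden dimension and rank (illustrative values)
A = torch.randn(r, d)         # A: Gaussian-initialized, shape (r, d)
B = torch.zeros(d, r)         # B: zero-initialized, so the update starts at zero
delta_W = B @ A               # the approximated weight update, shape (d, d)

print(delta_W.shape)          # torch.Size([1024, 1024])
print(A.numel() + B.numel())  # 8192 trainable parameters with LoRA
print(d * d)                  # 1048576 parameters in the full update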
LoRA Implementation with PyTorch
Let’s walk through the implementation of LoRA in PyTorch. We will train a network to classify MNIST digits and then fine-tune the network on a poorly performing digit. This implementation provides a hands-on demonstration, building up the components progressively.
The code used in this guide has been re-implemented from this GitHub repo.
Step 1: Prepare the Environment
The first step is to set up your environment and import the necessary libraries. For this implementation, you'll need torch and torchvision.
Since this code is designed to run on Colab, ensure the runtime is set to GPU for faster training. If you're using a local GPU setup, ensure your environment is configured accordingly.
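If you are working in a fresh notebook or local environment rather than Colab (which ships with these packages preinstalled), a typical installation cell might look like this; adjust it to your setup:

!pip install torch torchvision matplotlib tqdm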
Import the necessary libraries:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Step 2: Load the MNIST Dataset
Load the MNIST dataset, which contains images of handwritten digits, using torchvision.datasets:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)
Step 3: Define the Neural Network
Create a neural network model with a few layers. In this case, we're using a simple model called RichBoyNet with three linear layers and ReLU activations:
class RichBoyNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(RichBoyNet, self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

net = RichBoyNet().to(device)
Step 4: Pre-train the Model
Train the model on the MNIST dataset for one epoch to simulate general pre-training. This step helps the model learn general features from the data:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    total_iterations = 0

    for epoch in range(epochs):
        net.train()
        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

train(train_loader, net, epochs=1)
Output
Epoch 1: 100%|██████████| 6000/6000 [00:58<00:00, 102.78it/s, loss=0.238]
Step 5: Store Original Weights
Store a copy of the pre-trained model's weights. This is important for comparing the weights before and after fine-tuning with LoRA:
original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.clone().detach()
Step 6: Test the Pre-trained Model
Evaluate the performance of the pre-trained model on the test dataset:
def test():
    correct = 0
    total = 0
    wrong_counts = [0 for i in range(10)]

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = net(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct += 1
                else:
                    wrong_counts[y[idx]] += 1
                total += 1
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')

test()
Output
The network performs poorly on the digit 9, so that is the digit we will fine-tune on. But before introducing the LoRA matrices, let's visualize how many parameters are in the original network.
total_parameters_original = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')
Output
Step 7: Define LoRA Parametrization and Add it to the Model
Create a class that implements the LoRA technique. This class will add low-rank matrices (lora_A and lora_B) to the linear layers of the model:
class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # A is Gaussian-initialized and B starts at zero, so B*A = 0 before training
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights
Apply the LoRA parametrization to the linear layers of the model using parametrize.register_parametrization:
import torch.nn.utils.parametrize as parametrize

def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Note: nn.Linear stores its weight as (out_features, in_features)
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )

parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)
parametrize.register_parametrization(
    net.linear2, "weight", linear_layer_parameterization(net.linear2, device)
)
parametrize.register_parametrization(
    net.linear3, "weight", linear_layer_parameterization(net.linear3, device)
)

def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled
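To see the efficiency gain concretely, we can count the parameters LoRA has added; this is a quick check included for illustration (the variable names are our own):

total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    lora = layer.parametrizations["weight"][0]
    total_parameters_lora += lora.lora_A.nelement() + lora.lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
print(f'Parameters (original): {total_parameters_non_lora:,}')
print(f'Parameters added by LoRA: {total_parameters_lora:,}')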
Step 8: Freeze Non-LoRA Parameters
Freeze the original weights of the model so that only the LoRA parameters are updated during fine-tuning:
# Freeze the non-LoRA parameters
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False
Output
Freezing non-LoRA parameter
Step 9: Fine-tune with LoRA and Test
Fine-tune the model with LoRA on a specific task. In this case, the model is fine-tuned on a subset of the MNIST dataset containing only the digit 9:
# Load the MNIST dataset again, keeping only the digit 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
exclude_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[exclude_indices]
mnist_trainset.targets = mnist_trainset.targets[exclude_indices]
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
train(train_loader, net, epochs=1, total_iterations_limit=100)
Output
Epoch 1: 99%|█████████▉| 99/100 [00:01<00:00, 68.50it/s, loss=0.15]
Test the model's performance with LoRA enabled (the digit 9 should now be classified more accurately) and with LoRA disabled (the accuracy and error counts should match those of the original network) to observe the impact of LoRA on the specific task:
# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()
Output
The digit 9 is classified more accurately when LoRA is enabled.
# Test with LoRA disabled
enable_disable_lora(enabled=False)
test()
Output
With LoRA disabled, the accuracy and error counts match those of the original network.
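As a final sanity check, we can confirm that the frozen weights were never modified during LoRA fine-tuning; here is a short sketch comparing them against the original_weights dictionary saved in Step 5:

# After parametrization, the frozen weight lives under parametrizations.weight.original
assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])
print('Frozen weights are unchanged by LoRA fine-tuning')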
Limitations of LoRA
While LoRA is an effective technique for fine-tuning, it's important to know its limitations:
- Not Ideal for All Tasks: LoRA might not be the optimal choice for all fine-tuning tasks, especially those requiring substantial changes to the model's behavior.
- Hyperparameter Tuning: LoRA requires careful tuning of hyperparameters, particularly the rank r, to achieve optimal performance.
Conclusion
Today, we explored the concept of Low-Rank Adaptation (LoRA). After building a foundational understanding, we demonstrated the implementation of LoRA in PyTorch by fine-tuning a network on the MNIST dataset.
This exercise shows that LoRA is an important advancement in fine-tuning: it adapts models with far fewer trainable parameters, making the process more efficient.
By utilizing lower-rank matrix decomposition, LoRA allows for effective task-specific fine-tuning of large language models while avoiding the significant computational and storage costs that come with standard methods.