Low-Rank Adaptation (LoRA): A Parameter-Efficient Fine-Tuning Method for LLMs


Large language models (LLMs), such as GPT-3 and LaMDA, are at the forefront of natural language processing (NLP). Trained on terabytes of text data, they generate human-like text and power multiple applications, including chatbots, virtual assistants, and search engines.

This initial training process, where the LLM learns general language representations from a massive dataset, is computationally intensive and requires significant resources. Once pre-trained, the model can be adapted for specific tasks through fine-tuning.

Fine-tuning involves training the model on a smaller, task-specific dataset to improve its performance on a particular task. However, due to the billions of parameters these models contain, standard fine-tuning can still be computationally expensive and resource-intensive.

This challenge has led to the development of parameter-efficient fine-tuning (PEFT) techniques. PEFT aims to adapt LLMs by modifying only a small number of parameters while reducing computational costs and preserving the model's original capabilities.

Low-Rank Adaptation (LoRA) stands out for its efficiency among PEFT techniques. It freezes the pre-trained weights and introduces trainable rank decomposition matrices into each layer of the transformer architecture.

Alan Ritter, an associate professor at Georgia Tech, says LoRA has democratized LLM training by giving more people the ability to fine-tune larger models.

An illustration of regular finetuning (left) and LoRA finetuning (right) | Source

This article explains why parameter-efficient fine-tuning, and LoRA in particular, is essential, how LoRA works, and how to implement it using PyTorch.

Why Low-Rank Adaptation?

Full fine-tuning methods, where all pre-trained model parameters are updated, come with significant drawbacks:

  • High Computational Cost: Updating billions of parameters, like the 175 billion in GPT-3, requires immense computational resources, making fine-tuning time-consuming and expensive.
  • Storage Requirements: Storing the fine-tuned models requires substantial storage capacity, which can be a limiting factor.
  • Risk of Overfitting: Fine-tuning all parameters, especially when using limited task-specific data, can make the model more prone to overfitting.

In contrast, transfer learning techniques generally update only a smaller subset of the model's parameters.

Low-Rank Adaptation refines this approach further by introducing a lightweight solution that reduces the number of parameters that need to be updated during fine-tuning.

This leads to several key benefits:

  • Faster Training and Inference: With fewer parameters to update, LoRA enables faster training and inference speeds compared to full retraining.
  • Lower Storage Requirements: The reduced parameter count translates to lower storage needs for the fine-tuned model.
  • Reduced Overfitting Risk: Limiting the number of trainable parameters constrains the model's capacity to memorize the training data, which helps mitigate overfitting.

Understanding LoRA

Before we get into LoRA, it's important to grasp the ideas of matrix rank and low-rank matrices. Knowing these concepts will help us understand how LoRA works more efficiently for fine-tuning.

What Is the Rank of a Matrix?

A matrix is an array of numbers arranged in rows and columns. The rank of a matrix refers to the maximum number of linearly independent rows or columns within it.

  • Linear independence means that none of the rows (or columns) can be expressed as a linear combination of the others.

Consider the matrix 𝐴:

    ⎡ 1  2  3 ⎤
A = ⎢ 2  4  6 ⎥
    ⎣ 3  6  9 ⎦

In this matrix,

  • The second row is 2 times the first row, and the third is 3 times the first.
  • Similarly, the second and third columns are multiples of the first column.

Since all rows and columns are linearly dependent, the matrix rank is 1, as only one row or column carries independent information.
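
You can verify this numerically: PyTorch computes the rank of a matrix directly. Below is a minimal sketch using torch.linalg.matrix_rank on the matrix from the example above:

import torch

A = torch.tensor([[1., 2., 3.],
                  [2., 4., 6.],
                  [3., 6., 9.]])

# Every row is a multiple of the first, so only one row is linearly independent
print(torch.linalg.matrix_rank(A))  # tensor(1)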

Now, let's look at the concept of lower-rank matrices.

Low-Rank Matrices

A matrix is considered low-rank if its rank is less than the smaller of its number of rows and columns. For a matrix 𝐴 of size 𝑚 × 𝑛, a rank less than min(𝑚, 𝑛) indicates a low rank.

In the above example, 𝐴 is a 3×3 matrix with a rank of 1. Since its rank is less than 3, it is considered a low-rank matrix.

How LoRA Works

LoRA introduces efficiency in fine-tuning by using the concept of low-rank decomposition. In standard fine-tuning, the weight update for a pre-trained model is represented as:

W_updated = W_pretrained + Δ𝑊

Here:

  • W_pretrained: The original weight matrix.
  • Δ𝑊: The weight update matrix learned during fine-tuning.

Instead of directly learning and applying the full weight update matrix Δ𝑊, LoRA approximates it as the product of two smaller matrices (a low-rank decomposition), 𝐴 and 𝐵:

Δ𝑊 ≈ 𝐵 · 𝐴

Source

  • 𝐴: This matrix has dimensions 𝑟 × 𝑑, where 𝑟 is the rank (a hyperparameter) and 𝑑 is the hidden dimension of the layer. It is initialized with random values from a Gaussian distribution.
  • 𝐵: This matrix has dimensions 𝑑 × 𝑟 and is initialized with zeros, so Δ𝑊 = 𝐵 · 𝐴 is zero at the start of fine-tuning and the model initially behaves exactly like the pre-trained one.

Therefore, the updated weights become:

W_updated = W_pretrained + 𝐵 · 𝐴

Source

The rank 𝑟 controls the expressiveness of the adaptation:

  • Lower 𝑟: Fewer trainable parameters, more memory-efficient.
  • Higher 𝑟: Greater flexibility in updating the weights.

 A and B Matrices Decomposition with Different Rank (r) | Source

During fine-tuning, only the matrices 𝐴 and 𝐵 are trained. This reduces the number of trainable parameters dramatically compared to standard fine-tuning, where all the weights in the LLM would be updated.
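
To make the savings concrete, here is a minimal sketch of the decomposition (the values d = 1000 and r = 4 are illustrative, not taken from a real model):

import torch

d, r = 1000, 4               # hidden dimension and LoRA rank (illustrative)

B = torch.zeros(d, r)        # initialized with zeros
A = torch.randn(r, d)        # Gaussian initialization

delta_W = B @ A              # low-rank approximation of the weight update
print(delta_W.shape)         # torch.Size([1000, 1000])

# Trainable parameters: 2*d*r = 8,000 instead of d*d = 1,000,000
print(A.numel() + B.numel(), "vs", d * d)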

LoRA Implementation with PyTorch

Let’s walk through the implementation of LoRA in PyTorch. We will train a network to classify MNIST digits and then fine-tune the network on a poorly performing digit. This implementation provides a hands-on demonstration, building up the components progressively.

The code used in this guide has been re-implemented from this GitHub repo.

Step 1: Prepare the Environment

The first step is to set up your environment and import the necessary libraries. For this implementation, you'll need torch and torchvision.

Since this code is designed to run on Colab, ensure the runtime is set to GPU for faster training. If you're using a local GPU setup, ensure your environment is configured accordingly.

Import the necessary libraries:



import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm


# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Step 2: Load the MNIST Dataset

Load the MNIST dataset, which contains images of handwritten digits, using torchvision.datasets:



transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])


# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)


# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

Step 3: Define the Neural Network

Create a neural network model with a few layers. In this case, we're using a simple model called RichBoyNet with three linear layers and ReLU activations:



class RichBoyNet(nn.Module):
    def __init__(self, hidden_size_1=1000, hidden_size_2=2000):
        super(RichBoyNet,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()


    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x


net = RichBoyNet().to(device)

Step 4: Pre-train the Model

Train the model on the MNIST dataset for one epoch to simulate general pre-training. This step helps the model learn general features from the data:


def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)


    total_iterations = 0


    for epoch in range(epochs):
        net.train()


        loss_sum = 0
        num_iterations = 0


        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()


            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return


train(train_loader, net, epochs=1)

Output



Epoch 1: 100%|██████████| 6000/6000 [00:58<00:00, 102.78it/s, loss=0.238]

Step 5: Store Original Weights

Store a copy of the pre-trained model's weights. This is important for comparing the weights before and after fine-tuning with LoRA:



original_weights = {}
for name, param in net.named_parameters():
    original_weights[name] = param.clone().detach()

Step 6: Test the Pre-trained Model

Evaluate the performance of the pre-trained model on the test dataset:



def test():
    correct = 0
    total = 0


    wrong_counts = [0 for i in range(10)]


    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = net(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                else:
                    wrong_counts[y[idx]] +=1
                total +=1
    print(f'Accuracy: {round(correct/total, 3)}')
    for i in range(len(wrong_counts)):
        print(f'wrong counts for the digit {i}: {wrong_counts[i]}')


test()

Output

[Output: overall test accuracy with per-digit wrong counts; the digit 9 has by far the highest error count]

The network performs poorly on the digit 9, so let's fine-tune it on that digit. But before introducing the LoRA matrices, let's count how many parameters are in the original network.



total_parameters_original = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    print(f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')
print(f'Total number of parameters: {total_parameters_original:,}')

Output

Layer 1: W: torch.Size([1000, 784]) + B: torch.Size([1000])
Layer 2: W: torch.Size([2000, 1000]) + B: torch.Size([2000])
Layer 3: W: torch.Size([10, 2000]) + B: torch.Size([10])
Total number of parameters: 2,807,010

Step 7: Define LoRA Parametrization and Add it to the Model

Create a class that implements the LoRA technique. This class will add low-rank matrices (lora_A and lora_B) to the linear layers of the model:



class LoRAParametrization(nn.Module):
    def __init__(self, features_in, features_out, rank=1, alpha=1, device='cpu'):
        super().__init__()
        # Per the LoRA paper: A is Gaussian-initialized and B starts at zero,
        # so B*A (the weight update) is zero at the beginning of training
        self.lora_A = nn.Parameter(torch.zeros((rank, features_out)).to(device))
        self.lora_B = nn.Parameter(torch.zeros((features_in, rank)).to(device))
        nn.init.normal_(self.lora_A, mean=0, std=1)
        # Scale the update by alpha/rank, as in the LoRA paper
        self.scale = alpha / rank
        self.enabled = True


    def forward(self, original_weights):
        if self.enabled:
            # Return W + (B*A)*scale
            return original_weights + torch.matmul(self.lora_B, self.lora_A).view(original_weights.shape) * self.scale
        else:
            return original_weights

Apply the LoRA parametrization to the linear layers of the model using parametrize.register_parametrization:



import torch.nn.utils.parametrize as parametrize


def linear_layer_parameterization(layer, device, rank=1, lora_alpha=1):
    # Note: nn.Linear stores its weight as (out_features, in_features); the
    # names below only need to yield B*A with the same shape as the weight
    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, alpha=lora_alpha, device=device
    )


parametrize.register_parametrization(
    net.linear1, "weight", linear_layer_parameterization(net.linear1, device)
)
parametrize.register_parametrization(
    net.linear2, "weight", linear_layer_parameterization(net.linear2, device)
)
parametrize.register_parametrization(
    net.linear3, "weight", linear_layer_parameterization(net.linear3, device)
)


def enable_disable_lora(enabled=True):
    for layer in [net.linear1, net.linear2, net.linear3]:
        layer.parametrizations["weight"][0].enabled = enabled

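Before freezing anything, it's worth counting how many parameters LoRA has added. A quick sketch (with rank=1, this comes to 6,794 trainable LoRA parameters on top of the 2,807,010 frozen ones):

total_parameters_lora = 0
for index, layer in enumerate([net.linear1, net.linear2, net.linear3]):
    lora = layer.parametrizations["weight"][0]
    total_parameters_lora += lora.lora_A.nelement() + lora.lora_B.nelement()
    print(f'Layer {index+1}: lora_A: {lora.lora_A.shape} + lora_B: {lora.lora_B.shape}')
print(f'Total number of LoRA parameters: {total_parameters_lora:,}')
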
Step 8: Freeze Non-LoRA Parameters

Freeze the original weights of the model so that only the LoRA parameters are updated during fine-tuning:



# Freeze the non-Lora parameters
for name, param in net.named_parameters():
    if 'lora' not in name:
        print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

Output

Freezing non-LoRA parameter linear1.bias
Freezing non-LoRA parameter linear1.parametrizations.weight.original
Freezing non-LoRA parameter linear2.bias
Freezing non-LoRA parameter linear2.parametrizations.weight.original
Freezing non-LoRA parameter linear3.bias
Freezing non-LoRA parameter linear3.parametrizations.weight.original

Step 9: Fine-tune with LoRA and Test

Fine-tune the model with LoRA on a specific task. In this case, the model is fine-tuned on a subset of the MNIST dataset containing only the digit 9:



# Load the MNIST dataset again, keeping only the digit 9
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
keep_indices = mnist_trainset.targets == 9
mnist_trainset.data = mnist_trainset.data[keep_indices]
mnist_trainset.targets = mnist_trainset.targets[keep_indices]


# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)


train(train_loader, net, epochs=1, total_iterations_limit=100)

Output



Epoch 1:  99%|█████████▉| 99/100 [00:01<00:00, 68.50it/s, loss=0.15]

Test the model's performance with LoRA enabled (the digit 9 should be classified better) and disabled (the accuracy and error counts should match the original network) to observe the impact of LoRA on the specific task:



# Test with LoRA enabled
enable_disable_lora(enabled=True)
test()

Output

[Output: test accuracy with LoRA enabled; the wrong count for the digit 9 drops sharply]

The digit 9 is classified more accurately when LoRA is enabled.



# Test with LoRA disabled
enable_disable_lora(enabled=False)
test()

Output

[Output: test accuracy with LoRA disabled; accuracy and per-digit wrong counts match the original network]

Accuracy and error counts are the same as the original network when LoRA is disabled.
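
You can also confirm this programmatically. Once the parametrization is registered, the frozen weights live under parametrizations.weight.original, so they can be compared against the copies stored in Step 5 (a quick sanity check):

# LoRA trained only lora_A and lora_B; the original weights are untouched
assert torch.all(net.linear1.parametrizations.weight.original == original_weights['linear1.weight'])
assert torch.all(net.linear2.parametrizations.weight.original == original_weights['linear2.weight'])
assert torch.all(net.linear3.parametrizations.weight.original == original_weights['linear3.weight'])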

Limitations of LoRA

While LoRA is an effective technique for fine-tuning, it's important to know its limitations:

  • Not Ideal for All Tasks: LoRA might not be the optimal choice for all fine-tuning tasks, especially those requiring substantial changes to the model's behavior.
  • Hyperparameter Tuning: LoRA requires careful tuning of hyperparameters, particularly the rank r, to achieve optimal performance.

Conclusion

Today, we explored the concept of Low-Rank Adaptation (LoRA). After building a foundational understanding, we demonstrated the implementation of LoRA in PyTorch by fine-tuning a network on the MNIST dataset.

This exercise shows that low-rank adaptation is an important advancement in model fine-tuning, achieving strong task-specific performance while training only a small fraction of the parameters.

By utilizing lower-rank matrix decomposition, LoRA allows for effective task-specific fine-tuning of large language models while avoiding the significant computational and storage costs that come with standard methods.
