Implementing Corrective RAG (CRAG) using LangGraph and Chroma DB
Introduction
In our Advanced RAG Techniques series, today we’re exploring a powerful technique to improve traditional RAG called Corrective RAG (CRAG).
Since LLMs are limited to generating responses from their pre-trained knowledge, RAG has become a popular technique for making them smarter by connecting them to external data.
At a high level, RAG involves breaking your data sources into smaller chunks and storing them in a vector database.
When a user asks a question, the most relevant chunks of information are retrieved from the vector database and passed on to the LLM to generate a response that aligns with your data.
Limitations of Traditional RAG
RAG enables LLMs to go beyond their training data, but it’s not always perfect. The quality of the generated response depends heavily on the quality of the data retrieved.
If the retrieval system doesn’t find the right information, the results can be:
- Inaccurate or irrelevant: Mismatched or incomplete responses are generated when the retrieved documents don't properly address the query.
- Hallucinations: Missing important context can cause the model to make up information or provide incorrect answers.
- Inconsistent: Sometimes, retrieval systems repeatedly fetch irrelevant documents, making the system unpredictable for users.
Corrective RAG (CRAG)
CRAG is an advanced RAG technique that improves the traditional RAG by actively evaluating and refining the retrieved documents to ensure accuracy and relevance.
By introducing an evaluator and corrective mechanisms, CRAG addresses the shortcomings of standard RAG. Here's how the process works:
- Initial Retrieval: A standard retriever pulls documents from a knowledge base.
- Evaluation: A retrieval evaluator grades the documents' relevance to the query, assigning confidence scores and categorizing them into three tiers (a minimal sketch of this routing logic follows the list):
  - Correct: If one or more documents score above an upper threshold, they are refined using a "decompose-then-recompose" algorithm to strip irrelevant details and emphasize the most critical points.
  - Incorrect: If all documents fall below a lower threshold, CRAG performs a web search to gather fresh and potentially more accurate external knowledge.
  - Ambiguous: For mixed results, CRAG combines the refinement of the initial documents with new information from web searches.
- Generation: The refined knowledge is used to generate a final response, ensuring high accuracy and relevance.
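To make the evaluation step concrete, here is a minimal sketch of the three-tier routing logic. The threshold values and the helper function are hypothetical, purely for illustration; the LangGraph implementation below uses a simpler binary (yes/no) grader rather than numeric confidence scores.
# hypothetical sketch of CRAG's three-tier routing (not the implementation used below)
def route(scores, upper=0.7, lower=0.3):
    # scores: confidence scores in [0, 1] assigned by a retrieval evaluator
    if any(s >= upper for s in scores):
        return "correct"    # refine the retrieved documents (decompose-then-recompose)
    if all(s <= lower for s in scores):
        return "incorrect"  # discard retrieval and fall back to a web search
    return "ambiguous"      # combine refined documents with web search results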
CRAG vs Traditional RAG
CRAG is an improved version of naive (traditional) RAG: it has the ability to fix errors in the information it retrieves.
With a built-in retrieval evaluator, CRAG identifies when retrieved information is incorrect or irrelevant and corrects it before it impacts the final output. This helps CRAG provide more accurate and reliable answers, cutting down on errors and misinformation.
While traditional RAG typically checks for relevance alone, CRAG refines documents further, filtering out irrelevant or imprecise details. This ensures the generated text is not just relevant but also precise.
Implementation
In this section, we will go through a step-by-step guide on how to implement CRAG using LangGraph. You can access the notebook on our GitHub.
You'll learn how to set up your environment, create a basic knowledge vector store, and configure the key components needed for CRAG, like the retrieval evaluator, question rewriter, and web search tool.
1. Installing Libraries and Setting Up Environment
We'll start by installing the necessary libraries and configuring the environment variables. (Depending on your environment, you may also need the langchain, langchain-openai, and langchain-community packages for the imports used later in the notebook.)
! pip install -q athina chromadb langgraph
import os
from google.colab import userdata
os.environ["OPENAI_API_KEY"] = userdata.get('OPENAI_API_KEY')
os.environ["TAVILY_API_KEY"] = userdata.get('TAVILY_API_KEY')
os.environ['ATHINA_API_KEY'] = userdata.get('ATHINA_API_KEY')
2. Loading the Embedding Model, Documents, and Preparing Embeddings
We'll be using the OpenAIEmbeddings model to generate the vectors. Next, load your documents (a CSV file) and prepare them for retrieval; the documents are split with the standard recursive text splitter (a sketch of that splitting step follows the loading code below).
# load embedding model
from langchain_openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings()
# load data
from langchain.document_loaders import CSVLoader
loader = CSVLoader("./context.csv")
documents = loader.load()
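The splitting step mentioned above isn't shown in the loading code, so here is a minimal sketch using RecursiveCharacterTextSplitter; the chunk_size and chunk_overlap values are illustrative rather than taken from the original notebook.
# optionally split the loaded documents into smaller chunks before indexing
from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
documents = text_splitter.split_documents(documents)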
3. Creating a Vector Store with ChromaDB
Now, create a vector store to store document embeddings for efficient similarity search.
# create vectorstore
from langchain.vectorstores import Chroma
vectorstore = Chroma.from_documents(documents, embeddings)
4. Setting Up the Retrievers
Define a retriever from the vector store.
# create retriever
retriever = vectorstore.as_retriever()
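By default, as_retriever returns the top-ranked chunks from a similarity search. If you want to control how many documents come back per query, you can pass search_kwargs; the value of k below is just an example.
# optional: cap the number of chunks returned per query (k is illustrative)
retriever = vectorstore.as_retriever(search_kwargs={"k": 4})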
5. Setting up a Grader/Evaluator for Documents
Here we'll define a grader that scores each retrieved document against a binary yes/no relevance criterion.
# create grader for doc retriever
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_openai import ChatOpenAI
# defining a data class for the grader
class GradeDocuments(BaseModel):
binary_score: str = Field(
description="Documents are relevant to the question, 'yes' or 'no'"
)
# LLM with function call
llm = ChatOpenAI(temperature=0)
structured_llm_grader = llm.with_structured_output(GradeDocuments)
# Prompt for the grader
system = """You are a grader assessing relevance of a retrieved document to a user question. \n
It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \n
Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
]
)
retrieval_grader = grade_prompt | structured_llm_grader
6. Running the Grader
# run grader on the first retrieved document
question = "how does interlibrary loan work"
docs = retriever.invoke(question)
print(retrieval_grader.invoke({"question": question, "document": docs[0].page_content}))
7. Set up a RAG Chain
In this step, we'll set up a RAG chain that takes the user's query and a set of retrieved documents to generate an answer. We'll also define the prompt template. At the end, we'll use an output parser to format the generated output so that it is easier to read.
# create document chain
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import ChatPromptTemplate
template = """"
You are a helpful assistant that answers questions based on the following context.'
Use the provided context to answer the question.
Context: {context}
Question: {question}
Answer:
"""
prompt = ChatPromptTemplate.from_template(template)
llm = ChatOpenAI(temperature=0)
# helper to join retrieved documents into a single context string
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)
rag_chain = prompt | llm | StrOutputParser()
# response
generation = rag_chain.invoke({"context": format_docs(docs), "question": question})
generation
Interlibrary loan (ILL) is a service that allows patrons of one library to borrow physical materials or receive electronic documents from another library. After receiving a request from a patron, the borrowing library identifies potential lending libraries with the desired item. The lending library then delivers the item physically or electronically to the borrowing library, who then delivers it to their patron. If necessary, arrangements are made for the return of the item. Fees may accompany interlibrary loan services, and the majority of requests are now managed through semi-automated systems. Libraries negotiate for ILL eligibility, and with the increasing demand for digital materials, they are exploring the legal, technical, and licensing aspects of lending and borrowing ebooks through interlibrary loan.
8. Initializing the Web Search Tool (using Tavily)
If the knowledge base doesn’t have enough information, CRAG turns to web search to fill in the gaps. This broadens the range of possible information sources. In this step, we use the Tavily API to search the web and find additional documents.
# define web search
from langchain_community.tools.tavily_search import TavilySearchResults
web_search_tool = TavilySearchResults(max_results=3)  # return up to 3 results per query
# # sample web search
# web_search_tool.invoke('USA election 2024')
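TavilySearchResults returns a list of result dictionaries, and the web_search node defined later relies on each result exposing a "content" field. A quick (commented-out) sanity check with an arbitrary query:
# # sanity check: each result is a dict with "content" (and "url") fields
# results = web_search_tool.invoke({"query": "what is corrective RAG?"})
# print(results[0]["content"])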
9. Setting up the LangGraph Workflow
To build the CRAG workflow with LangGraph, follow these three main steps:
- Define the graph state
- Define function nodes
- Connect all function nodes
Define the graph state
Create a shared state to store data as it moves between nodes during the workflow. This state will hold all the variables, such as the user's question, retrieved documents, and generated answers.
# define a data class for state
from typing import List
from typing_extensions import TypedDict

class GraphState(TypedDict):
    question: str          # the user's question
    generation: str        # the LLM-generated answer
    web_search: str        # "Yes"/"No" flag set by the grader to trigger a web search
    documents: List[str]   # retrieved documents (plus web results when needed)
Define function nodes to build the graph
In the LangGraph workflow, each function node handles a specific task in the CRAG pipeline, such as retrieving documents, generating answers, evaluating relevance, transforming queries, and searching the web. Here's a breakdown of each function:
The retrieve function finds documents from the knowledge base that are relevant to the user's question. It uses a retriever object, which is usually a vector store created from pre-processed documents. This function takes the current state, including the user's question, and uses the retriever to get relevant documents. It then adds these documents to the state.
# define graph steps
from langchain.schema import Document
# node function for retrieval
def retrieve(state):
print("---RETRIEVE---")
question = state["question"]
# Retrieval
documents = retriever.invoke(question)
return {"documents": documents, "question": question}
# node function for generation
def generate(state):
print("---GENERATE---")
question = state["question"]
documents = state["documents"]
# RAG generation
    generation = rag_chain.invoke({"context": format_docs(documents), "question": question})
return {"documents": documents, "question": question, "generation": generation}
# node function for grading document relevance
def grade_documents(state):
print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
question = state["question"]
documents = state["documents"]
# Score each doc
filtered_docs = []
web_search = "No"
for d in documents:
score = retrieval_grader.invoke(
{"question": question, "document": d.page_content}
)
grade = score.binary_score
if grade == "yes":
print("---GRADE: DOCUMENT RELEVANT---")
filtered_docs.append(d)
else:
print("---GRADE: DOCUMENT NOT RELEVANT---")
web_search = "Yes"
continue
return {"documents": filtered_docs, "question": question, "web_search": web_search}
# node function for web search
def web_search(state):
print("---WEB SEARCH---")
question = state["question"]
documents = state["documents"]
# Web search
docs = web_search_tool.invoke({"query": question})
web_results = "\n".join([d["content"] for d in docs])
web_results = Document(page_content=web_results)
documents.append(web_results)
return {"documents": documents, "question": question}
# node function for decision
def decide_to_generate(state):
    print("---ASSESS GRADED DOCUMENTS---")
    web_search = state["web_search"]
    if web_search == "Yes":
        print("---DECISION: WEB SEARCH---")
        return "web_search"
    else:
        print("---DECISION: GENERATE---")
        return "generate"
Connect all function nodes
Once all the function nodes have been defined, we can link them together in the LangGraph workflow to build the CRAG pipeline. This means connecting the nodes with edges to manage the flow of information and decisions, making sure the workflow runs correctly based on each step's results.
# Build graph
from langgraph.graph import END, StateGraph, START
# Graph
workflow = StateGraph(GraphState)
# Define the nodes
workflow.add_node("retrieve", retrieve) # retrieve
workflow.add_node("grade_documents", grade_documents) # grade documents
workflow.add_node("generate", generate) # generatae
workflow.add_node("web_search_node", web_search) # web search
# Build graph
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
"grade_documents",
decide_to_generate,
{
"web_search": "web_search_node",
"generate": "generate",
},
)
workflow.add_edge("web_search_node", "generate")
workflow.add_edge("generate", END)
# Compile
app = workflow.compile()
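Optionally, you can render the compiled graph to verify the wiring; this assumes you are in a notebook environment that can display images (draw_mermaid_png uses an external rendering service by default).
# optional: visualize the compiled graph
from IPython.display import Image, display
display(Image(app.get_graph().draw_mermaid_png()))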
Example 1:
Here is an example where the retriever fetches relevant documents:
# example 1 where documents are relevant
from pprint import pprint
inputs = {"question": "how does interlibrary loan work"}
for output in app.stream(inputs):
for key, value in output.items():
pprint(f"Node '{key}':")
pprint("\n---\n")
pprint(value["generation"])
---RETRIEVE---
"Node 'retrieve':"
'\n---\n'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: GENERATE---
"Node 'grade_documents':"
'\n---\n'
---GENERATE---
"Node 'generate':"
'\n---\n'
('Interlibrary loan works by allowing patrons of one library to borrow '
'physical materials or receive electronic documents from another library that '
'holds the desired item. The borrowing library identifies potential lending '
'libraries, which then deliver the item either physically or electronically. '
'The borrowing library receives the item, delivers it to their patron, and '
'arranges for its return if necessary. Fees may accompany interlibrary loan '
'services, and the majority of requests are now managed through '
'semi-automated systems.')
Example 2:
Here is an example where the documents fetched by the retriever are not relevant:
# example 2 where documents are not relevant
from pprint import pprint
inputs = {"question": "What is Retrieval-Augmented Generation (RAG)?"}
for output in app.stream(inputs):
for key, value in output.items():
pprint(f"Node '{key}':")
pprint("\n---\n")
pprint(value["generation"])
---RETRIEVE---
"Node 'retrieve':"
'\n---\n'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: WEB SEARCH---
"Node 'grade_documents':"
'\n---\n'
---WEB SEARCH---
"Node 'web_search_node':"
'\n---\n'
---GENERATE---
"Node 'generate':"
'\n---\n'
('Retrieval-Augmented Generation (RAG) is a technique that enhances generative '
'artificial intelligence models by allowing them to access and reference '
'external knowledge bases, such as specific documents or databases, before '
'generating a response. This process helps improve the relevance and accuracy '
"of the generated text by incorporating information beyond the model's own "
'training data.')
10. Preparing Data for Evaluation
# Create a dataframe to store the question, context, and response
inputs = {"question": "what are points on a mortgage?"}
outputs = []
expected_response = "Points, sometimes also called a 'discount point', are a form of pre-paid interest."
for output in app.stream(inputs):
for key, value in output.items():
if key == "generate":
question = value["question"]
documents = value["documents"]
generation = value["generation"]
context = "\n".join(doc.page_content for doc in documents)
# Append the result
outputs.append({
"query": question,
"context": context,
"response": generation,
"expected_response": expected_response
})
---RETRIEVE---
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: GENERATE---
---GENERATE---
# Convert to DataFrame
import pandas as pd
df = pd.DataFrame(outputs)
# Convert to dictionary
df_dict = df.to_dict(orient='records')
# Convert context to list
for record in df_dict:
if not isinstance(record.get('context'), list):
if record.get('context') is None:
record['context'] = []
else:
record['context'] = [record['context']]
Response Evaluation using Athina IDE
Now, we'll log our dataset to Athina IDE to evaluate the performance of our CRAG pipeline using the Context Precision eval.
# set api keys for Athina evals
from athina.keys import AthinaApiKey, OpenAiApiKey
OpenAiApiKey.set_key(os.getenv('OPENAI_API_KEY'))
AthinaApiKey.set_key(os.getenv('ATHINA_API_KEY'))
# load dataset
from athina.loaders import Loader
dataset = Loader().load_dict(df_dict)
# evaluate
from athina.evals import RagasContextPrecision
RagasContextPrecision(model="gpt-4o").run_batch(data=dataset).to_df()