Implementing Corrective RAG in the Easiest Way


Even though text-generation models are good at producing fluent content, they often fall short when it comes to returning facts. This happens because of the way they are trained. Retrieval Augmented Generation (RAG) techniques have been introduced to address this issue by fetching context from a knowledge base.

Corrective RAG (CRAG) adds an extra step to ensure the model sticks to the information it retrieves. It corrects factual inaccuracies in real time by ranking candidate options on how well they fit the text generated so far and how well they match the retrieved information, so accurate corrections are made before the text is finished.

Corrective Retrieval Augmented Generation paper

Overview of CRAG at Inference

Top Level Structure

CRAG has three main parts:

  1. Generative Model: It generates an initial sequence.
  2. Retrieval Model: It retrieves context from the knowledge base based on the initial sequence.
  3. Retrieval Evaluator: It manages the back-and-forth between the generator and retriever, ranks options, and decides the final sequence to be given as output.

In CRAG, the Retrieval Evaluator links the Retriever and Generator. It keeps track of the text generated so far, asks the generator for more, fetches updated knowledge, scores candidate options for both textual fit and factual accuracy, and chooses the best one to add to the output at each step.

Pseudocode of Algorithm given in CRAG Paper
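The paper uses a lightweight fine-tuned retrieval evaluator; as a simple stand-in for intuition, a prompted LLM grader can play the same role. Below is a minimal, hypothetical sketch (the prompt wording and the is_relevant helper are our own, not the paper's code):

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI

# Hypothetical stand-in for the paper's trained retrieval evaluator:
# ask an LLM to grade whether a retrieved document answers the question.
grader_prompt = ChatPromptTemplate.from_template(
    "You are grading whether a retrieved document is relevant to a question.\n"
    "Document:\n{document}\n\nQuestion: {question}\n"
    "Answer strictly 'yes' or 'no'."
)
retrieval_grader = grader_prompt | ChatOpenAI(temperature=0) | StrOutputParser()


def is_relevant(question: str, document: str) -> bool:
    # "yes" -> Correct, "no" -> Incorrect; anything in between can be treated as Ambiguous
    grade = retrieval_grader.invoke({"question": question, "document": document})
    return grade.strip().lower().startswith("yes")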

Implementation

Here is the Colab link

The implementation gives ratings/scores to retrieved documents based on how well they answer the question:

For Correct documents -

  1. If at least one document is relevant, the method moves on to text generation.
  2. Before generating text, it refines the retrieved knowledge.
  3. This breaks the document down into "knowledge strips".
  4. Each strip is rated, and the ones that don't matter are discarded (see the sketch below).
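As an illustration, a rough sketch of this knowledge-refinement step (splitting on paragraphs and the grade_strip helper are our own simplifications, not the paper's exact method) could look like this:

def refine_knowledge(document: str, question: str, grade_strip) -> str:
    """Break a document into 'knowledge strips' and keep only the relevant ones.

    grade_strip(question, strip) -> bool can be any relevance check, e.g. the
    LLM grader sketched above; the paper fine-tunes its own evaluator for this.
    """
    strips = [p.strip() for p in document.split("\n\n") if p.strip()]
    relevant = [s for s in strips if grade_strip(question, s)]
    return "\n".join(relevant)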

For Ambiguous or Incorrect documents -

  1. If all documents are irrelevant, or their relevance is uncertain, the method looks for more information.
  2. It supplements what it found with a web search.
  3. The diagram in the paper also shows that the question may be rewritten to get better results.

We will implement CRAG using LangGraph and LanceDB, with LanceDB providing lightning-fast retrieval from the knowledge base. Here is the flow of how it will work.

Flow Diagram

The steps to implement this flow are:

  1. Retrieve relevant documents.
  2. If no relevant document is found, fall back to supplementary retrieval with a web search (using the Tavily API), as sketched below.
  3. Rewrite the query to optimize it for web search.
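Steps 2 and 3 correspond to the web_search and transform_query nodes in the graph defined later. The full versions are in the Colab; a minimal sketch, assuming the {"keys": {...}} state layout used below, the Tavily tool from langchain_community, and prompt wording of our own, might look like:

from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI


def transform_query(state):
    """Rewrite the question so it works better as a web-search query."""
    keys = state["keys"]
    rewrite_prompt = ChatPromptTemplate.from_template(
        "Rewrite the following question as a concise web search query:\n{question}"
    )
    rewriter = rewrite_prompt | ChatOpenAI(temperature=0) | StrOutputParser()
    better_question = rewriter.invoke({"question": keys["question"]})
    return {"keys": {**keys, "question": better_question}}


def web_search(state):
    """Supplement the retrieved context with Tavily web-search results."""
    keys = state["keys"]
    results = TavilySearchResults(max_results=3).invoke({"query": keys["question"]})
    web_doc = Document(page_content="\n".join(r["content"] for r in results))
    return {"keys": {**keys, "documents": keys.get("documents", []) + [web_doc]}}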

Here is some pseudocode; for a full implementation, check the Colab.

Building the Retriever

We will use Jay Alammar's articles on Transformers as the knowledge base.

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import LanceDB
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# Using Jay Alammer's articles on Transformers, Bert, and using transformers for retrieval
urls = [
    "https://jalammar.github.io/illustrated-transformer/",
    "https://jalammar.github.io/illustrated-bert/",
    "https://jalammar.github.io/illustrated-retrieval-transformer/",
]

docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# document chunking
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)

Define the Knowledge Base using LanceDB

import lancedb

def lanceDBConnection(embed):
    """Connect to a local LanceDB instance and create/overwrite the demo table."""
    db = lancedb.connect("/tmp/lancedb")
    table = db.create_table(
        "crag_demo",
        data=[{"vector": embed.embed_query("Hello World"), "text": "Hello World"}],
        mode="overwrite",
    )
    return table

Embeddings Extraction

To extract embeddings for the documents, we will use the OpenAI embeddings function and insert the embedded chunks into the LanceDB knowledge base, which is then used for fetching context.

# OpenAI embeddings
embedder = OpenAIEmbeddings()

# LanceDB as vector store
table = lanceDBConnection(embedder)
vectorstore = LanceDB.from_documents(
    documents=doc_splits,
    embedding=embedder,
    connection=table,
)

# ready with our retriever
retriever = vectorstore.as_retriever()
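As a quick sanity check (the query text is just an example), the retriever can be invoked directly:

# fetch the chunks most similar to an example query
sample_docs = retriever.invoke("How does self-attention work?")
print(len(sample_docs), sample_docs[0].page_content[:200])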

Define the LangGraph Workflow

We will define the graph state and then build the LangGraph workflow by adding nodes and edges, following the flow diagram above.

from typing import Any, Dict, TypedDict


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        keys: A dictionary where each key is a string.
    """

    keys: Dict[str, Any]

This graph will include five nodes: Document Retriever, Document Grader, Generator, Query Transformer, and Web Search, plus one conditional edge: Decide to Generate.
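The node functions themselves live in the Colab. As a rough sketch of what the remaining ones might look like (reusing the retriever built above and the hypothetical is_relevant grader from earlier, with our own prompt wording; transform_query and web_search were sketched before):

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI


def retrieve(state):
    """Fetch candidate documents from the LanceDB-backed retriever."""
    keys = state["keys"]
    documents = retriever.invoke(keys["question"])
    return {"keys": {**keys, "documents": documents}}


def grade_documents(state):
    """Keep only relevant documents and flag whether a web search is needed."""
    keys = state["keys"]
    relevant = [d for d in keys["documents"] if is_relevant(keys["question"], d.page_content)]
    run_web_search = "Yes" if not relevant else "No"
    return {"keys": {**keys, "documents": relevant, "run_web_search": run_web_search}}


def generate(state):
    """Generate the final answer from the (possibly corrected) context."""
    keys = state["keys"]
    context = "\n\n".join(d.page_content for d in keys["documents"])
    rag_prompt = ChatPromptTemplate.from_template(
        "Answer the question using only the context below.\n\n"
        "Context:\n{context}\n\nQuestion: {question}"
    )
    rag_chain = rag_prompt | ChatOpenAI(temperature=0) | StrOutputParser()
    generation = rag_chain.invoke({"context": context, "question": keys["question"]})
    return {"keys": {**keys, "generation": generation}}


def decide_to_generate(state):
    """Conditional edge: rewrite the query for web search, or generate directly."""
    return "transform_query" if state["keys"]["run_web_search"] == "Yes" else "generate"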

The following code builds the graph for the flow shown in the diagram.

import pprint

from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve)  # retrieve docs
workflow.add_node("grade_documents", grade_documents)  # grade retrieved docs
workflow.add_node("generate", generate)  # generate answers
workflow.add_node("transform_query", transform_query)  # transform_query for web search
workflow.add_node("web_search", web_search)  # web search

# Build graph
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate": "generate",
    },
)
workflow.add_edge("transform_query", "web_search")
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)

# Compile
app = workflow.compile()

Now that the graph is compiled, we are ready to query it.

# Run
query_prompt = "How do Transformers work?"
inputs = {"keys": {"question": query_prompt}}
for output in app.stream(inputs):
    for key, value in output.items():
        # Node
        # print full state at each node
        pprint.pprint(value["keys"], indent=2, width=80, depth=None)
    pprint.pprint("------------------------")

# Final generation
print("*"*5, " Generated Answer ", "*"*5)
pprint.pprint(value["keys"]["generation"])

Here is the generated answer output, illustrating what each node does and the decisions it makes, ultimately resulting in the final generated output.

Output

Check out the Colab for the full implementation.

Challenges and Future Work

CRAG undoubtedly helps generate more factually accurate information from the knowledge base, but some challenges remain for its widespread adoption:

  1. Retrieval quality depends on comprehensive knowledge coverage.
  2. Computation cost and latency increase compared to basic RAG.
  3. The framework is sensitive to the retriever's limitations.
  4. Balancing fluency and factuality is challenging.