# Improve retrieval performance by fine-tuning the embedding model

Another way to improve retriever performance is to fine-tune the embedding model itself. Fine-tuning helps the model learn better representations of the documents and queries in your dataset, which is particularly useful when that dataset differs substantially from the data the embedding model was pre-trained on.

In [24]:
%pip install llama-index-llms-openai llama-index-embeddings-openai llama-index-finetuning llama-index-readers-file scikit-learn llama-index-embeddings-huggingface llama-index-vector-stores-lancedb pyarrow==12.0.1 -qq

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 12.0.1 which is incompatible.
datasets 2.20.0 requires pyarrow>=15.0.0, but you have pyarrow 12.0.1 which is incompatible.

In [22]:
# For eval utils
!git clone https://github.com/lancedb/ragged.git
!cd ragged && pip install .


Cloning into 'ragged'...
remote: Enumerating objects: 160, done.
remote: Counting objects: 100% (160/160), done.
remote: Compressing objects: 100% (103/103), done.
remote: Total 160 (delta 70), reused 125 (delta 41), pack-reused 0
Receiving objects: 100% (160/160), 38.15 KiB | 9.54 MiB/s, done.
Resolving deltas: 100% (70/70), done.
Processing /content/ragged
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting datasets (from ragged==0.1.dev0)
  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
Collecting streamlit (from ragged==0.1.dev0)
  Downloading streamlit-1.36.0-py2.py3-none-any.whl (8.6 MB)
Collecti

## The dataset
The dataset we'll use is a synthetic QA dataset generated from the Llama 2 review paper. The paper was divided into chunks, each serving as a unique context, and an LLM was prompted to ask questions relevant to each context in order to test a retriever.
The exact code and other utility functions for this can be found in [this](https://github.com/lancedb/ragged) repo.


In [8]:
!wget https://raw.githubusercontent.com/AyushExel/assets/main/data_qa.csv

--2024-07-09 20:37:46--  https://raw.githubusercontent.com/AyushExel/assets/main/data_qa.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 680439 (664K) [text/plain]
Saving to: ‘data_qa.csv’


2024-07-09 20:37:47 (100 MB/s) - ‘data_qa.csv’ saved [680439/680439]



In [9]:
import pandas as pd

data = pd.read_csv("data_qa.csv")

## Pre-processing
Now we need to parse the context (corpus) of the dataset into llama-index text nodes.

In [10]:
from pathlib import Path
from llama_index.core.node_parser import SentenceSplitter
from llama_index.readers.file import PagedCSVReader

def load_corpus(file, verbose=False):
    if verbose:
        print(f"Loading files {file}...")

    loader = PagedCSVReader(encoding="utf-8")
    docs = loader.load_data(file=Path(file))

    if verbose:
        print(f"Loaded {len(docs)} docs")

    parser = SentenceSplitter()
    nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)

    if verbose:
        print(f"Parsed {len(nodes)} nodes")

    return nodes

In [11]:
import pandas as pd

df = pd.read_csv("data_qa.csv", index_col=0)

In [12]:
import os

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder: set your own key and never commit a real one

Split the data into training and validation sets. For the final evaluation we'll use the original `data_qa.csv`, since its queries were generated with a different prompt than the ones used for fine-tuning.


In [13]:
from sklearn.model_selection import train_test_split

# Randomly shuffle df.
#df = df.sample(frac=1, random_state=42)

train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

train_df.to_csv("train_data_qa.csv", index=False)
val_df.to_csv("val_data_qa.csv", index=False)

In [14]:
train_nodes = load_corpus("train_data_qa.csv", verbose=True)
val_nodes = load_corpus("val_data_qa.csv", verbose=True)

Loading files train_data_qa.csv...
Loaded 176 docs


Parsing nodes:   0%|          | 0/176 [00:00<?, ?it/s]

Parsed 221 nodes
Loading files val_data_qa.csv...
Loaded 44 docs


Parsing nodes:   0%|          | 0/44 [00:00<?, ?it/s]

Parsed 59 nodes


### Generate queries from the training and validation contexts


In [15]:
from llama_index.finetuning import generate_qa_embedding_pairs
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset

In [16]:
from llama_index.llms.openai import OpenAI


train_dataset = generate_qa_embedding_pairs(
    llm=OpenAI(model="gpt-3.5-turbo"), nodes=train_nodes, verbose=False
)
val_dataset = generate_qa_embedding_pairs(
    llm=OpenAI(model="gpt-3.5-turbo"), nodes=val_nodes, verbose=False
)

train_dataset.save_json("train_dataset.json")
val_dataset.save_json("val_dataset.json")

100%|██████████| 221/221 [05:29<00:00,  1.49s/it]
221it [00:00, ?it/s]


In [17]:
# Load again
train_dataset = EmbeddingQAFinetuneDataset.from_json("train_dataset.json")

val_dataset = EmbeddingQAFinetuneDataset.from_json("val_dataset.json")
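
As far as we know, an `EmbeddingQAFinetuneDataset` holds three mappings: `queries` (query ID to question text), `corpus` (node ID to chunk text), and `relevant_docs` (query ID to the IDs of the chunks that answer it). The sketch below mimics that layout with a plain dictionary to show how training pairs can be recovered from it. All IDs and texts are hypothetical, and the saved JSON may contain extra fields depending on the library version.

```python
import json

# Hypothetical miniature dataset in the assumed queries/corpus/relevant_docs layout.
toy_dataset = {
    "queries": {
        "q1": "What data was the model pretrained on?",
        "q2": "How was safety evaluated?",
    },
    "corpus": {
        "n1": "Made-up chunk of text about the pretraining data.",
        "n2": "Made-up chunk of text about the safety evaluation.",
    },
    "relevant_docs": {"q1": ["n1"], "q2": ["n2"]},
}

# Round-trip through JSON, mirroring the save/load step above.
restored = json.loads(json.dumps(toy_dataset))

# Build (query, positive context) pairs, the unit the fine-tuning step trains on.
pairs = [
    (restored["queries"][query_id], restored["corpus"][node_id])
    for query_id, node_ids in restored["relevant_docs"].items()
    for node_id in node_ids
]

for query, context in pairs:
    print(f"{query!r} -> {context!r}")
```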

## Fine-tune the embedding model
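
Fine-tuning trains on (query, context) pairs so that each query embeds closer to its own context than to the other contexts in the same batch. We believe `SentenceTransformersFinetuneEngine` defaults to a multiple-negatives ranking loss for this, but treat that as an assumption and check the version you have installed. The sketch below illustrates that kind of loss in plain Python on toy 2-D vectors. It is not the library's code; the function names and the scale value of 20 are illustrative.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def in_batch_negatives_loss(
    query_vecs: List[Sequence[float]],
    context_vecs: List[Sequence[float]],
    scale: float = 20.0,
) -> float:
    """Mean cross-entropy where context i is the positive for query i
    and every other context in the batch acts as a negative."""
    losses = []
    for i, q in enumerate(query_vecs):
        logits = [scale * cosine(q, c) for c in context_vecs]
        # Numerically stable log-sum-exp over the batch.
        max_logit = max(logits)
        log_sum_exp = max_logit + math.log(
            sum(math.exp(logit - max_logit) for logit in logits)
        )
        losses.append(log_sum_exp - logits[i])
    return sum(losses) / len(losses)


# Toy 2-D embeddings (made up for illustration).
contexts = [(1.0, 0.0), (0.0, 1.0)]
well_aligned_queries = [(0.9, 0.1), (0.1, 0.9)]    # each query near its own context
poorly_aligned_queries = [(0.1, 0.9), (0.9, 0.1)]  # each query near the wrong context

aligned_loss = in_batch_negatives_loss(well_aligned_queries, contexts)
misaligned_loss = in_batch_negatives_loss(poorly_aligned_queries, contexts)
print("aligned loss:   ", aligned_loss)
print("misaligned loss:", misaligned_loss)
```

The loss is small when every query already sits next to its own context and large when it sits next to another one, which is the signal that pulls matching pairs together during training.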

In [18]:
import torch
from llama_index.finetuning import SentenceTransformersFinetuneEngine

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

finetune_engine = SentenceTransformersFinetuneEngine(
    train_dataset,
    model_id="BAAI/bge-small-en-v1.5",
    model_output_path="tuned_model",
    val_dataset=val_dataset,
    device=device
)



modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/94.8k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/133M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/366 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [19]:
finetune_engine.finetune()


Epoch:   0%|          | 0/2 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]

Iteration:   0%|          | 0/45 [00:00<?, ?it/s]
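
The progress bars are consistent with the dataset sizes seen earlier. Assuming two generated questions per node and a batch size of 10 (we believe these are the library defaults, but they are not set explicitly in this notebook), 221 training nodes would yield 442 pairs and 45 batches per epoch, which matches the `0/45` iteration counter. A quick check of that arithmetic:

```python
import math

num_train_nodes = 221      # "Parsed 221 nodes" for the training split
questions_per_node = 2     # assumed default of generate_qa_embedding_pairs
batch_size = 10            # assumed default of SentenceTransformersFinetuneEngine
epochs = 2                 # matches "Epoch: 0/2" in the output

num_pairs = num_train_nodes * questions_per_node
steps_per_epoch = math.ceil(num_pairs / batch_size)
total_steps = steps_per_epoch * epochs

print(f"{num_pairs} pairs, {steps_per_epoch} steps per epoch, {total_steps} steps in total")
```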

In [20]:
embed_model = finetune_engine.get_finetuned_model()


## Evaluate hit rate
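
Hit rate at k is the fraction of queries for which the ground-truth context appears among the top k retrieved results. The following is a minimal sketch of that definition in plain Python. The function name `hit_rate_at_k` and the toy IDs are illustrative, and the `ragged` implementation used in this notebook may differ in its details.

```python
from typing import Dict, List


def hit_rate_at_k(
    retrieved: Dict[str, List[str]],
    relevant: Dict[str, str],
    k: int = 5,
) -> float:
    """Fraction of queries whose relevant context ID is in the top-k results.

    `retrieved` maps a query ID to its ranked list of context IDs (best first);
    `relevant` maps a query ID to the single ground-truth context ID.
    """
    if not relevant:
        return 0.0
    hits = sum(
        1
        for query_id, context_id in relevant.items()
        if context_id in retrieved.get(query_id, [])[:k]
    )
    return hits / len(relevant)


# Toy example (hypothetical IDs): q1 and q3 are hits at k=2, q2 is a miss.
retrieved = {
    "q1": ["c1", "c7", "c3"],
    "q2": ["c4", "c5", "c2"],
    "q3": ["c9", "c3", "c1"],
}
relevant = {"q1": "c1", "q2": "c2", "q3": "c3"}

score = hit_rate_at_k(retrieved, relevant, k=2)
print(f"hit rate@2 = {score:.3f}")
```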


In [25]:
from ragged.dataset import CSVDataset, SquadDataset
from ragged.rag import llamaIndexRAG
from ragged.metrics.retriever.hit_rate import HitRate
from ragged.search_utils import QueryType


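def evaluate_vector(
    dataset,
    embed_model_name_or_path,
    top_k=5,
):
    """Evaluate hit rate for vector search only and return the result."""
    dataset = CSVDataset(dataset)
    hit_rate = HitRate(dataset, embed_model_kwarg={"name": embed_model_name_or_path})

    result = hit_rate.evaluate(top_k, query_type=QueryType.VECTOR)
    print(result)
    return result


def evaluate_all(
    dataset,
    embed_model_name_or_path,
    reranker,
    top_k=5,
):
    """Evaluate hit rate for all query types and return the result."""
    dataset = CSVDataset(dataset)
    hit_rate = HitRate(
        dataset,
        embed_model_kwarg={"name": embed_model_name_or_path},
        reranker=reranker,
    )

    # Return the result so callers can keep it (the original only printed it).
    result = hit_rate.evaluate(top_k, query_type=QueryType.ALL)
    print(result)
    return result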


In [28]:
from lancedb.rerankers import CohereReranker, LinearCombinationReranker


#linear_combination_reranker = LinearCombinationReranker()
cohere_reranker = CohereReranker(api_key="YOUR_COHERE_API_KEY")  # placeholder: supply your own key

#evaluate_all("data_qa.csv", "BAAI/bge-small-en-v1.5", linear_combination_reranker)
hit_rate_bge_cohere = evaluate_all("data_qa.csv", "BAAI/bge-small-en-v1.5", cohere_reranker)


INFO:lancedb:Adding 110 documents to LanceDB, in 1 batches of size 110
Adding batch to LanceDB: 100%|██████████| 110/110 [00:00<00:00, 165663.71it/s]
INFO:lancedb:Adding batch 0 to LanceDB
INFO:lancedb:created table with length 110
INFO:lancedb:Evaluating query type: vector
100%|██████████| 220/220 [00:10<00:00, 20.61it/s]
INFO:lancedb:Hit rate for vector: 0.6409090909090909
INFO:lancedb:Evaluating query type: fts
100%|██████████| 220/220 [00:00<00:00, 361.50it/s]
INFO:lancedb:Hit rate for fts: 0.5954545454545455
INFO:lancedb:Evaluating query type: rerank_vector
100%|██████████| 220/220 [01:32<00:00,  2.38it/s]
INFO:lancedb:Hit rate for rerank_vector: 0.6772727272727272
INFO:lancedb:Evaluating query type: rerank_fts
100%|██████████| 220/220 [01:23<00:00,  2.63it/s]
INFO:lancedb:Hit rate for rerank_fts: 0.6727272727272727
INFO:lancedb:Evaluating query type: hybrid
100%|██████████| 220/220 [01:28<00:00,  2.47it/s]
INFO:lancedb:Hit rate for hybrid: 0.759090909090909
vector=0.6409090909090909 fts=0.5954545454545455 rerank_vector=0.6772727272727272 rerank_fts=0.6727272727272727 hybrid=0.759090909090909


In [29]:
#evaluate_all("data_qa.csv", "tuned_model/", linear_combination_reranker)
evaluate_all("data_qa.csv", "tuned_model/", cohere_reranker)



INFO:lancedb:Adding 110 documents to LanceDB, in 1 batches of size 110
Adding batch to LanceDB: 100%|██████████| 110/110 [00:00<00:00, 91234.61it/s]
INFO:lancedb:Adding batch 0 to LanceDB
INFO:lancedb:created table with length 110
INFO:lancedb:Evaluating query type: vector
100%|██████████| 220/220 [00:09<00:00, 22.17it/s]
INFO:lancedb:Hit rate for vector: 0.6727272727272727
INFO:lancedb:Evaluating query type: fts
100%|██████████| 220/220 [00:00<00:00, 285.43it/s]
INFO:lancedb:Hit rate for fts: 0.5954545454545455
INFO:lancedb:Evaluating query type: rerank_vector
100%|██████████| 220/220 [01:29<00:00,  2.45it/s]
INFO:lancedb:Hit rate for rerank_vector: 0.7545454545454545
INFO:lancedb:Evaluating query type: rerank_fts
100%|██████████| 220/220 [01:22<00:00,  2.66it/s]
INFO:lancedb:Hit rate for rerank_fts: 0.6727272727272727
INFO:lancedb:Evaluating query type: hybrid
100%|██████████| 220/220 [01:28<00:00,  2.48it/s]
INFO:lancedb:Hit rate for hybrid: 0.7681818181818182
vector=0.6727272727272727 fts=0.5954545454545455 rerank_vector=0.7545454545454545 rerank_fts=0.6727272727272727 hybrid=0.7681818181818182
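
To make the comparison easier to read, the two summary lines printed by the evaluation runs can be placed side by side. The snippet below is a small sketch that copies the reported hit rates by hand (rounded to four decimals) and prints the absolute change per query type. The dictionary names are illustrative.

```python
# Hit rates copied from the two evaluation runs above (rounded to 4 decimals).
baseline = {
    "vector": 0.6409,
    "fts": 0.5955,
    "rerank_vector": 0.6773,
    "rerank_fts": 0.6727,
    "hybrid": 0.7591,
}
fine_tuned = {
    "vector": 0.6727,
    "fts": 0.5955,
    "rerank_vector": 0.7545,
    "rerank_fts": 0.6727,
    "hybrid": 0.7682,
}

# Absolute change in hit rate for each query type.
gains = {name: fine_tuned[name] - baseline[name] for name in baseline}

print(f"{'query type':<15}{'baseline':>10}{'tuned':>10}{'change':>10}")
for name in baseline:
    print(
        f"{name:<15}{baseline[name]:>10.4f}{fine_tuned[name]:>10.4f}{gains[name]:>+10.4f}"
    )
```

In this run the fine-tuned model improves every query type that uses embeddings, with the largest gain for `rerank_vector`. The `fts` and `rerank_fts` numbers are unchanged, which is expected because full-text search does not use the embedding model.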
