From bb2e624ff0e5b0eae9ced8db7ebc49a5082439e1 Mon Sep 17 00:00:00 2001 From: Ayush Chaurasia Date: Thu, 11 Jul 2024 17:34:29 +0530 Subject: [PATCH] docs: add fine tuning section in retriever guide and minor fixes (#1438) --- docs/mkdocs.yml | 4 + .../embeddings/default_embedding_functions.md | 2 +- .../guides/tuning_retrievers/1_query_types.md | 3 + .../guides/tuning_retrievers/2_reranking.md | 4 +- .../tuning_retrievers/3_embed_tuning.md | 82 + docs/src/notebooks/embedding_tuner.ipynb | 1437 +++++++++++++++++ docs/src/notebooks/lancedb_reranking.ipynb | 2 +- 7 files changed, 1531 insertions(+), 3 deletions(-) create mode 100644 docs/src/guides/tuning_retrievers/3_embed_tuning.md create mode 100644 docs/src/notebooks/embedding_tuner.ipynb diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 0413704b..38c577d7 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -105,6 +105,7 @@ nav: - Jina Reranker: reranking/jina.md - OpenAI Reranker: reranking/openai.md - Building Custom Rerankers: reranking/custom_reranker.md + - Example: notebooks/lancedb_reranking.ipynb - Filtering: sql.md - Versioning & Reproducibility: notebooks/reproducibility.ipynb - Configuring Storage: guides/storage.md @@ -112,6 +113,7 @@ nav: - Tuning retrieval performance: - Choosing right query type: guides/tuning_retrievers/1_query_types.md - Reranking: guides/tuning_retrievers/2_reranking.md + - Embedding fine-tuning: guides/tuning_retrievers/3_embed_tuning.md - 🧬 Managing embeddings: - Overview: embeddings/index.md - Embedding functions: embeddings/embedding_functions.md @@ -188,6 +190,7 @@ nav: - Jina Reranker: reranking/jina.md - OpenAI Reranker: reranking/openai.md - Building Custom Rerankers: reranking/custom_reranker.md + - Example: notebooks/lancedb_reranking.ipynb - Filtering: sql.md - Versioning & Reproducibility: notebooks/reproducibility.ipynb - Configuring Storage: guides/storage.md @@ -195,6 +198,7 @@ nav: - Tuning retrieval performance: - Choosing right query type: guides/tuning_retrievers/1_query_types.md - Reranking: guides/tuning_retrievers/2_reranking.md + - Embedding fine-tuning: guides/tuning_retrievers/3_embed_tuning.md - Managing Embeddings: - Overview: embeddings/index.md - Embedding functions: embeddings/embedding_functions.md diff --git a/docs/src/embeddings/default_embedding_functions.md b/docs/src/embeddings/default_embedding_functions.md index ae026acf..14910485 100644 --- a/docs/src/embeddings/default_embedding_functions.md +++ b/docs/src/embeddings/default_embedding_functions.md @@ -563,7 +563,7 @@ uris = [ # get each uri as bytes image_bytes = [requests.get(uri).content for uri in uris] table.add( - [{"label": labels, "image_uri": uris, "image_bytes": image_bytes}] + pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes}) ) ``` Now we can search using text from both the default vector column and the custom vector column diff --git a/docs/src/guides/tuning_retrievers/1_query_types.md b/docs/src/guides/tuning_retrievers/1_query_types.md index 9e8d1354..73a78eac 100644 --- a/docs/src/guides/tuning_retrievers/1_query_types.md +++ b/docs/src/guides/tuning_retrievers/1_query_types.md @@ -1,4 +1,7 @@ ## Improving retriever performance + +Try it yourself - Open In Colab
+
 VectorDBs are used as retrievers in recommender or chatbot-based systems for retrieving relevant data based on user queries. For example, a retriever is a critical component of Retrieval Augmented Generation (RAG) architectures. In this section, we will discuss how to improve the performance of retrievers. There are several ways to improve the performance of retrievers. Some of the common techniques are:
diff --git a/docs/src/guides/tuning_retrievers/2_reranking.md b/docs/src/guides/tuning_retrievers/2_reranking.md
index 8e8fce56..059169e7 100644
--- a/docs/src/guides/tuning_retrievers/2_reranking.md
+++ b/docs/src/guides/tuning_retrievers/2_reranking.md
@@ -1,4 +1,6 @@
-Continuing from the previous example, we can now rerank the results using more complex rerankers.
+Continuing from the previous section, we can now rerank the results using more complex rerankers.
+
+Try it yourself - Open In Colab<br/>
## Reranking search results
You can rerank any search results using a reranker. The syntax for reranking is as follows:
diff --git a/docs/src/guides/tuning_retrievers/3_embed_tuning.md b/docs/src/guides/tuning_retrievers/3_embed_tuning.md
new file mode 100644
index 00000000..bb42c3b3
--- /dev/null
+++ b/docs/src/guides/tuning_retrievers/3_embed_tuning.md
@@ -0,0 +1,82 @@
+## Fine-tuning the Embedding Model
+Try it yourself - Open In Colab<br/>
+
+Another way to improve retriever performance is to fine-tune the embedding model itself. Fine-tuning helps the model learn better representations for the documents and queries in your dataset. This can be particularly useful when the dataset is very different from the data the embedding model was pre-trained on.
+
+We'll use the same dataset as in the previous sections. Start off by splitting the dataset into training and validation sets:
+```python
+import pandas as pd
+from sklearn.model_selection import train_test_split
+
+# train_test_split expects a DataFrame, so load the CSV first
+df = pd.read_csv("data_qa.csv")
+train_df, validation_df = train_test_split(df, test_size=0.2, random_state=42)
+
+train_df.to_csv("data_train.csv", index=False)
+validation_df.to_csv("data_val.csv", index=False)
+```
+
+You can use any tuning API to fine-tune embedding models. In this example, we'll use llama-index, as it also comes with utilities for synthetic data generation and for training the model.
+
+Then parse the dataset into llama-index text nodes and generate synthetic QA pairs from each node:
+```python
+from pathlib import Path
+
+from llama_index.core.node_parser import SentenceSplitter
+from llama_index.readers.file import PagedCSVReader
+from llama_index.finetuning import generate_qa_embedding_pairs
+from llama_index.llms.openai import OpenAI
+
+def load_corpus(file):
+    # Each CSV row becomes a document, which is then split into text nodes
+    loader = PagedCSVReader(encoding="utf-8")
+    docs = loader.load_data(file=Path(file))
+
+    parser = SentenceSplitter()
+    nodes = parser.get_nodes_from_documents(docs)
+
+    return nodes
+
+train_nodes = load_corpus("data_train.csv")
+val_nodes = load_corpus("data_val.csv")
+
+train_dataset = generate_qa_embedding_pairs(
+    llm=OpenAI(model="gpt-3.5-turbo"), nodes=train_nodes, verbose=False
+)
+val_dataset = generate_qa_embedding_pairs(
+    llm=OpenAI(model="gpt-3.5-turbo"), nodes=val_nodes, verbose=False
+)
+```
+
+Now we'll use the `SentenceTransformersFinetuneEngine` to fine-tune the model. You can also use the `sentence-transformers` or `transformers` libraries directly.
+
+```python
+from llama_index.finetuning import SentenceTransformersFinetuneEngine
+
+finetune_engine = SentenceTransformersFinetuneEngine(
+    train_dataset,
+    model_id="BAAI/bge-small-en-v1.5",
+    model_output_path="tuned_model",
+    val_dataset=val_dataset,
+)
+finetune_engine.finetune()
+embed_model = finetune_engine.get_finetuned_model()
+```
+This saves the fine-tuned embedding model in the `tuned_model` folder.
+
+## Evaluation results
+To evaluate the retriever, you can either use this model to ingest the data into LanceDB directly, or use llama-index's LanceDB integration to create a `VectorStoreIndex` and use it as a retriever.
+On performing the same hit-rate evaluation as before, we see a significant improvement across all query types.
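+
+For reference, here's a minimal sketch of the second option using llama-index's LanceDB integration. The `uri`, `table_name`, and example query below are illustrative, and the snippet assumes the `val_nodes` and `embed_model` objects created above:
+
+```python
+from llama_index.core import StorageContext, VectorStoreIndex
+from llama_index.vector_stores.lancedb import LanceDBVectorStore
+
+# Embed the nodes with the fine-tuned model and store them in a LanceDB table
+vector_store = LanceDBVectorStore(uri="./lancedb", table_name="docs_tuned")  # illustrative path/table
+storage_context = StorageContext.from_defaults(vector_store=vector_store)
+index = VectorStoreIndex(
+    val_nodes, storage_context=storage_context, embed_model=embed_model
+)
+
+# The index can now be queried as a retriever for hit-rate evaluation
+retriever = index.as_retriever(similarity_top_k=5)
+results = retriever.retrieve("What is the context length of Llama 2?")
+```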
+
+### Baseline
+| Query Type | Hit-rate@5 |
+| --- | --- |
+| Vector Search | 0.640 |
+| Full-text Search | 0.595 |
+| Reranked Vector Search | 0.677 |
+| Reranked Full-text Search | 0.672 |
+| Hybrid Search (w/ CohereReranker) | 0.759 |
+
+### Fine-tuned model (2 iterations)
+| Query Type | Hit-rate@5 |
+| --- | --- |
+| Vector Search | 0.672 |
+| Full-text Search | 0.595 |
+| Reranked Vector Search | 0.754 |
+| Reranked Full-text Search | 0.672 |
+| Hybrid Search (w/ CohereReranker) | 0.768 |
diff --git a/docs/src/notebooks/embedding_tuner.ipynb b/docs/src/notebooks/embedding_tuner.ipynb
new file mode 100644
index 00000000..416f5fd5
--- /dev/null
+++ b/docs/src/notebooks/embedding_tuner.ipynb
@@ -0,0 +1,1437 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+  "colab": {
+   "provenance": [],
+   "gpuType": "T4"
+  },
+  "kernelspec": {
+   "name": "python3",
+   "display_name": "Python 3"
+  },
+  "language_info": {
+   "name": "python"
+  },
+  "accelerator": "GPU"
+ },
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "source": [
+    "# Improve retrieval performance by fine-tuning the embedding model\n",
+    "\n",
+    "Another way to improve retriever performance is to fine-tune the embedding model itself. Fine-tuning the embedding model can help in learning better representations for the documents and queries in the dataset. This can be particularly useful when the dataset is very different from the pre-trained data used to train the embedding model."
+   ],
+   "metadata": {
+    "id": "rYMbEXANHZ0B"
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "collapsed": true,
+    "id": "6T7bwebVquFE",
+    "outputId": "55bea6d1-631f-409e-9b7b-cb441d26102a"
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+      "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 12.0.1 which is incompatible.\n",
+      "datasets 2.20.0 requires pyarrow>=15.0.0, but you have pyarrow 12.0.1 which is incompatible.\u001b[0m\u001b[31m\n",
+      "\u001b[0m"
+     ]
+    }
+   ],
+   "source": [
+    "%pip install llama-index-llms-openai llama-index-embeddings-openai llama-index-finetuning llama-index-readers-file scikit-learn llama-index-embeddings-huggingface llama-index-vector-stores-lancedb pyarrow==12.0.1 -qq"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "# For eval utils\n",
+    "!git clone https://github.com/lancedb/ragged.git\n",
+    "!cd ragged && pip install .\n"
+   ],
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "6RRNyCDJDEcQ",
+    "outputId": "bbcb0689-e82f-4593-f53c-77c3443a929d"
+   },
+   "execution_count": 22,
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "Cloning into 'ragged'...\n",
+      "remote: Enumerating objects: 160, done.\u001b[K\n",
+      "remote: Counting objects: 100% (160/160), done.\u001b[K\n",
+      "remote: Compressing objects: 100% (103/103), done.\u001b[K\n",
+      "remote: Total 160 (delta 70), reused 125 (delta 41), pack-reused 0\u001b[K\n",
+      "Receiving objects: 100% (160/160), 38.15 KiB | 9.54 MiB/s, done.\n",
+      "Resolving deltas: 100% (70/70), done.\n",
+      "Processing /content/ragged\n",
+      "... [verbose pip install log truncated] ...\n",
+      "Successfully built ragged\n",
+      "Successfully installed datasets-2.20.0 dill-0.3.8 gitdb-4.0.11 gitpython-3.1.43 multiprocess-0.70.16 pyarrow-15.0.0 pydeck-0.9.1 ragged-0.1.dev0 requests-2.32.3 smmap-5.0.1 streamlit-1.36.0 watchdog-4.0.1 xxhash-3.4.1\n"
+     ]
+    }
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## The dataset\n",
+    "The dataset we'll use is a synthetic QA dataset generated from the Llama2 review paper. The paper was divided into chunks, with each chunk being a unique context. An LLM was prompted to ask questions relevant to each context for testing a retriever.\n",
+    "The exact code and other utility functions for this can be found in [this](https://github.com/lancedb/ragged) repo.\n"
+   ],
+   "metadata": {
+    "id": "B_2S_b0c3pdp"
+   }
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "!wget https://raw.githubusercontent.com/AyushExel/assets/main/data_qa.csv"
+   ],
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "4QFDh3jD3d1X",
+    "outputId": "642f53c8-a084-4c34-db6a-bfee35abbd28"
+   },
+   "execution_count": 8,
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "--2024-07-09 20:37:46--  https://raw.githubusercontent.com/AyushExel/assets/main/data_qa.csv\n",
+      "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...\n",
+      "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
+      "HTTP request sent, awaiting response... 200 OK\n",
+      "Length: 680439 (664K) [text/plain]\n",
+      "Saving to: 'data_qa.csv'\n",
+      "\n",
+      "data_qa.csv         100%[===================>] 664.49K  --.-KB/s    in 0.006s  \n",
+      "\n",
+      "2024-07-09 20:37:47 (100 MB/s) - 'data_qa.csv' saved [680439/680439]\n",
+      "\n"
+     ]
+    }
+   ]
+  },
+  {
+   "cell_type": "code",
+   "source": [
+    "import pandas as pd\n",
+    "\n",
+    "data = pd.read_csv(\"data_qa.csv\")"
+   ],
+   "metadata": {
+    "id": "AIF2zczc3kwW"
+   },
+   "execution_count": 9,
+   "outputs": []
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Pre-processing\n",
+    "Now we need to parse the context (corpus) of the dataset into llama-index text nodes."
" + ], + "metadata": { + "id": "_xV40VSy3twE" + } + }, + { + "cell_type": "code", + "source": [ + "from pathlib import Path\n", + "from llama_index.core.node_parser import SentenceSplitter\n", + "from llama_index.readers.file import PagedCSVReader\n", + "\n", + "def load_corpus(file, verbose=False):\n", + " if verbose:\n", + " print(f\"Loading files {file}...\")\n", + "\n", + " loader = PagedCSVReader(encoding=\"utf-8\")\n", + " docs = loader.load_data(file=Path(file))\n", + "\n", + " if verbose:\n", + " print(f\"Loaded {len(docs)} docs\")\n", + "\n", + " parser = SentenceSplitter()\n", + " nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)\n", + "\n", + " if verbose:\n", + " print(f\"Parsed {len(nodes)} nodes\")\n", + "\n", + " return nodes" + ], + "metadata": { + "id": "mzDZYUX4qxBC" + }, + "execution_count": 10, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.read_csv(\"data_qa.csv\", index_col=0)" + ], + "metadata": { + "id": "eoLOdNO-4HbV" + }, + "execution_count": 11, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"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\"" + ], + "metadata": { + "id": "EqsFZ5KYqzvg" + }, + "execution_count": 12, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Split into train and validation sets. We'll use the original df for val as that has different queries generated via a different prompt.\n" + ], + "metadata": { + "id": "zrwa35x96FLZ" + } + }, + { + "cell_type": "code", + "source": [ + "from sklearn.model_selection import train_test_split\n", + "\n", + "# Randomly shuffle df.\n", + "#df = df.sample(frac=1, random_state=42)\n", + "\n", + "train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)\n", + "\n", + "train_df.to_csv(\"train_data_qa.csv\", index=False)\n", + "val_df.to_csv(\"val_data_qa.csv\", index=False)" + ], + "metadata": { + "id": "diHhY9Ipq9Uw" + }, + "execution_count": 13, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "train_nodes = load_corpus(\"train_data_qa.csv\", verbose=True)\n", + "val_nodes = load_corpus(\"val_data_qa.csv\", verbose=True)" + ], + "metadata": { + "id": "C7PKGtXPq_Fc", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 188, + "referenced_widgets": [ + "3c85bfeaccc84a47844c770fa1fb2511", + "7461a200b0634607ac479708e3cba537", + "156a3c94ba094fbf86e70681c69ca31a", + "4261378f06ed48cc8cef251cd2c096ab", + "59aeeeae529a440fab4c231501fce4f6", + "1308faaa9fa944b2b17e96e8cd9a9445", + "83c0d87febcc4dfaa95b4d3e2005a416", + "f503b1be4e2e42c8bf4460eea2f1bb07", + "8228bf5a569844d584003446649731a6", + "2255fb2d83734ef88843ffe47116da84", + "9f30b1969bf24b86b8deaa41ea7231f6", + "f55c2f3c448741819e618b44bc0b1976", + "b0bad294bb6443388b77f854c4f77569", + "812696b2a65c4ca281da45f286ab95cf", + "9a026ebe3c8b416e9c5c4d7dd05bba66", + "af4dfd45973d466cb5c78d002c723cd6", + "211ee4b118154b0a94cbc686fdf90c55", + "3929b1c14657468792c74c6610598af5", + "222f778312d745aebc6f1d33c651dca8", + "d65d92433f304e389e0ce8aa7baf7155", + "a81c89c5c5a64ad0938d8d1e9789838c", + "95a2e33e8d244d24813b386e60301a2a" + ] + }, + "outputId": "bcb428bd-5d02-444c-e456-22260402faa8" + }, + "execution_count": 14, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Loading files train_data_qa.csv...\n", + "Loaded 176 docs\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Parsing nodes: 0%| | 
0/176 [00:00