Compare commits

..

1 Commits

Author SHA1 Message Date
ayush chaurasia
45f2d9976b update 2024-04-11 16:24:29 +05:30
2 changed files with 195 additions and 33 deletions

View File

@@ -0,0 +1,188 @@
import os
from llama_index.core import SimpleDirectoryReader
from llama_index.core.llama_dataset import LabelledRagDataset
from llama_index.core.node_parser import SentenceSplitter
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from lancedb.embeddings.fine_tuner.llm import Openai
from lancedb.embeddings import get_registry
from llama_index.core import VectorStoreIndex
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.core import ServiceContext, VectorStoreIndex, StorageContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.llama_pack import download_llama_pack
import time
import wandb
import pandas as pd
def get_paths_from_dataset(dataset: str, split=True):
"""
Returns paths of:
- downloaded dataset, lance train dataset, lance test dataset, finetuned model
"""
if split:
return f"./data/{dataset}", f"./data/{dataset}_lance_train", f"./data/{dataset}_lance_test", f"./data/tuned_{dataset}"
return f"./data/{dataset}", f"./data/{dataset}_lance", f"./data/tuned_{dataset}"
def get_llama_dataset(dataset: str):
"""
returns:
- nodes, documents, rag_dataset
"""
if not os.path.exists(f"./data/{dataset}"):
os.system(f"llamaindex-cli download-llamadataset {dataset} --download-dir ./data/{dataset}")
rag_dataset = LabelledRagDataset.from_json(f"./data/{dataset}/rag_dataset.json")
docs = SimpleDirectoryReader(input_dir=f"./data/{dataset}/source_files").load_data()
parser = SentenceSplitter()
nodes = parser.get_nodes_from_documents(docs)
return nodes, docs, rag_dataset
def lance_dataset_from_llama_nodes(nodes: list, name: str, split=True):
llm = Openai()
chunks = [TextChunk.from_llama_index_node(node) for node in nodes]
# train test split 75-35
if not split:
if os.path.exists(f"./data/{name}_lance"):
ds = QADataset.load(f"./data/{name}_lance")
return ds
ds = QADataset.from_llm(chunks, llm)
ds.save(f"./data/{name}_lance")
return ds
if os.path.exists(f"./data/{name}_lance_train") and os.path.exists(f"./data/{name}_lance_test"):
train_ds = QADataset.load(f"./data/{name}_lance_train")
test_ds = QADataset.load(f"./data/{name}_lance_test")
return train_ds, test_ds
# split chunks random
train_size = int(len(chunks) * 0.65)
train_chunks = chunks[:train_size]
test_chunks = chunks[train_size:]
train_ds = QADataset.from_llm(train_chunks, llm)
test_ds = QADataset.from_llm(test_chunks, llm)
train_ds.save(f"./data/{name}_lance_train")
test_ds.save(f"./data/{name}_lance_test")
return train_ds, test_ds
def finetune(trainset: str, model: str, epochs: int, path: str, valset: str = None):
print(f"Finetuning {model} for {epochs} epochs")
print(f"trainset query instances: {len(trainset.queries)}")
print(f"valset query instances: {len(valset.queries)}")
valset = valset if valset is not None else trainset
model = get_registry().get("sentence-transformers").create(name=model)
base_result = model.evaluate(valset, path=f"./data/eva/")
base_hit_rate = pd.DataFrame(base_result)["is_hit"].mean()
model.finetune(trainset=trainset, valset=valset, path=path, epochs=epochs)
tuned = get_registry().get("sentence-transformers").create(name=path)
tuned_result = tuned.evaluate(valset, path=f"./data/eva/{str(time.time())}")
tuned_hit_rate = pd.DataFrame(tuned_result)["is_hit"].mean()
return base_hit_rate, tuned_hit_rate
def eval_rag(dataset: str, model: str):
# Requires - pip install llama-index-vector-stores-lancedb
# Requires - pip install llama-index-embeddings-huggingface
nodes, docs, rag_dataset = get_llama_dataset(dataset)
embed_model = HuggingFaceEmbedding(model)
vector_store = LanceDBVectorStore(uri="/tmp/lancedb")
storage_context = StorageContext.from_defaults(vector_store=vector_store)
service_context = ServiceContext.from_defaults(embed_model=embed_model)
index = VectorStoreIndex(
nodes,
service_context=service_context,
show_progress=True,
storage_context=storage_context,
)
# build basic RAG system
index = VectorStoreIndex.from_documents(documents=docs)
query_engine = index.as_query_engine()
# evaluate using the RagEvaluatorPack
RagEvaluatorPack = download_llama_pack(
"RagEvaluatorPack", "./rag_evaluator_pack"
)
rag_evaluator_pack = RagEvaluatorPack(
rag_dataset=rag_dataset, query_engine=query_engine
)
metrics = rag_evaluator_pack.run(
batch_size=20, # batches the number of openai api calls to make
sleep_time_in_seconds=1, # seconds to sleep before making an api call
)
return metrics
def main(dataset, model, epochs, project: str = "lancedb_finetune"):
nodes, _, _ = get_llama_dataset(dataset)
trainset, valset = lance_dataset_from_llama_nodes(nodes, dataset)
data_path, lance_train_path, lance_test_path, tuned_path = get_paths_from_dataset(dataset, split=True)
base_hit_rate, tuned_hit_rate = finetune(trainset, model, epochs, tuned_path, valset)
""""
wandb.init(project="lancedb_finetune", name=f"{dataset}_{model}_base")
wandb.log({
"hit_rate": base_hit_rate,
})
wandb.finish()
wandb.init(project="lancedb_finetune", name=f"{dataset}_{model}_{epochs}")
wandb.log({
"hit_rate": tuned_hit_rate,
})
wandb.finish()
"""
# Base model model metrics
metrics = eval_rag(dataset, model)
# Tuned model metrics
metrics_tuned = eval_rag(dataset, tuned_path)
wandb.init(project="lancedb_rageval", name=f"{dataset}_{model}_base")
import pdb; pdb.set_trace()
wandb.log(
metrics
)
wandb.finish()
wandb.init(project="lancedb_rageval", name=f"{dataset}_{model}_{epochs}")
wandb.log(metrics_tuned)
def banchmark_all():
datasets = [#"Uber10KDataset2021", "MiniTruthfulQADataset", "MiniSquadV2Dataset", "MiniEsgBenchDataset", "MiniCovidQaDataset", "Llama2PaperDataset", "HistoryOfAlexnetDataset",
"PatronusAIFinanceBenchDataset"]
models = ["BAAI/bge-small-en-v1.5"]
for model in models:
for dataset in datasets:
main(dataset, model, 5)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="BraintrustCodaHelpDeskDataset")
parser.add_argument("--model", type=str, default="BAAI/bge-small-en-v1.5")
parser.add_argument("--epochs", type=int, default=4)
parser.add_argument("--project", type=str, default="lancedb_finetune")
parser.add_argument("--benchmark-all", action="store_true")
args = parser.parse_args()
if args.benchmark_all:
banchmark_all()
else:
main(args.dataset, args.model, args.epochs, args.project)

View File

@@ -1,13 +1,13 @@
import re
import uuid
from pathlib import Path
from typing import Any, Dict, List, Tuple, Optional
from typing import Any, Dict, List, Tuple
import lance
import pyarrow as pa
from pydantic import BaseModel
from tqdm import tqdm
from lancedb.utils.general import LOGGER
from .llm import BaseLLM
DEFAULT_PROMPT_TMPL = """\
@@ -37,7 +37,7 @@ class QADataset(BaseModel):
relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
"""
path: Optional[str] = None
queries: Dict[str, str] # id -> query
corpus: Dict[str, str] # id -> text
relevant_docs: Dict[str, List[str]] # query id -> list of retrieved doc ids
@@ -53,7 +53,6 @@ class QADataset(BaseModel):
def save(self, path: str, mode: str = "overwrite") -> None:
"""Save to lance dataset"""
self.path = path
save_dir = Path(path)
save_dir.mkdir(parents=True, exist_ok=True)
@@ -87,28 +86,20 @@ class QADataset(BaseModel):
)
@classmethod
def load(cls, path: str, version: Optional[int] = None) -> "QADataset":
def load(cls, path: str) -> "QADataset":
"""Load from .lance data"""
load_dir = Path(path)
queries = lance.dataset(load_dir / "queries.lance", version=version).to_table().to_pydict()
corpus = lance.dataset(load_dir / "corpus.lance", version=version).to_table().to_pydict()
queries = lance.dataset(load_dir / "queries.lance").to_table().to_pydict()
corpus = lance.dataset(load_dir / "corpus.lance").to_table().to_pydict()
relevant_docs = (
lance.dataset(load_dir / "relevant_docs.lance", version=version).to_table().to_pydict()
lance.dataset(load_dir / "relevant_docs.lance").to_table().to_pydict()
)
return cls(
path=str(path),
queries=dict(zip(queries["id"], queries["query"])),
corpus=dict(zip(corpus["id"], corpus["text"])),
relevant_docs=dict(zip(relevant_docs["query_id"], relevant_docs["doc_id"])),
)
@classmethod
def switch_version(cls, version: int) -> "QADataset":
"""Switch version of a dataset."""
if not cls.path:
raise ValueError("Path not set. You need to call save() first.")
return cls.load(cls.path, version=version)
# generate queries as a convenience function
@classmethod
def from_llm(
@@ -151,23 +142,6 @@ class QADataset(BaseModel):
"""Create a QADataset from a list of TextChunks and a list of questions."""
node_dict = {node.id: node.text for node in docs}
return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)
def versions(self) -> List[int]:
"""Get the versions of the dataset."""
# TODO: tidy this up
data_paths = self._get_data_file_paths()
return lance.dataset(data_paths[0]).versions()
def _get_data_file_paths(self) -> str:
"""Get the absolute path of the dataset."""
queries = self.path / "queries.lance"
corpus = self.path / "corpus.lance"
relevant_docs = self.path / "relevant_docs.lance"
return queries, corpus, relevant_docs
class TextChunk(BaseModel):