From 9428c6b565f517bc2eac4e6030d706ed033ffb27 Mon Sep 17 00:00:00 2001 From: ayush chaurasia Date: Mon, 15 Apr 2024 16:59:16 +0530 Subject: [PATCH] update --- .../benchmarks/llama-index-datasets.py | 129 ++++++++++----- .../embeddings/sentence_transformers.py | 153 +++++++++++++++++- 2 files changed, 243 insertions(+), 39 deletions(-) diff --git a/python/python/lancedb/embeddings/fine_tuner/benchmarks/llama-index-datasets.py b/python/python/lancedb/embeddings/fine_tuner/benchmarks/llama-index-datasets.py index 01a474e4..281bad33 100644 --- a/python/python/lancedb/embeddings/fine_tuner/benchmarks/llama-index-datasets.py +++ b/python/python/lancedb/embeddings/fine_tuner/benchmarks/llama-index-datasets.py @@ -15,25 +15,33 @@ import time import wandb - import pandas as pd + def get_paths_from_dataset(dataset: str, split=True): """ Returns paths of: - downloaded dataset, lance train dataset, lance test dataset, finetuned model """ if split: - return f"./data/{dataset}", f"./data/{dataset}_lance_train", f"./data/{dataset}_lance_test", f"./data/tuned_{dataset}" + return ( + f"./data/{dataset}", + f"./data/{dataset}_lance_train", + f"./data/{dataset}_lance_test", + f"./data/tuned_{dataset}", + ) return f"./data/{dataset}", f"./data/{dataset}_lance", f"./data/tuned_{dataset}" + def get_llama_dataset(dataset: str): """ returns: - nodes, documents, rag_dataset """ if not os.path.exists(f"./data/{dataset}"): - os.system(f"llamaindex-cli download-llamadataset {dataset} --download-dir ./data/{dataset}") + os.system( + f"llamaindex-cli download-llamadataset {dataset} --download-dir ./data/{dataset}" + ) rag_dataset = LabelledRagDataset.from_json(f"./data/{dataset}/rag_dataset.json") docs = SimpleDirectoryReader(input_dir=f"./data/{dataset}/source_files").load_data() @@ -42,6 +50,7 @@ def get_llama_dataset(dataset: str): return nodes, docs, rag_dataset + def lance_dataset_from_llama_nodes(nodes: list, name: str, split=True): llm = Openai() chunks = 
[TextChunk.from_llama_index_node(node) for node in nodes] @@ -53,12 +62,14 @@ def lance_dataset_from_llama_nodes(nodes: list, name: str, split=True): ds = QADataset.from_llm(chunks, llm) ds.save(f"./data/{name}_lance") return ds - - if os.path.exists(f"./data/{name}_lance_train") and os.path.exists(f"./data/{name}_lance_test"): + + if os.path.exists(f"./data/{name}_lance_train") and os.path.exists( + f"./data/{name}_lance_test" + ): train_ds = QADataset.load(f"./data/{name}_lance_train") test_ds = QADataset.load(f"./data/{name}_lance_test") return train_ds, test_ds - # split chunks random + # split chunks random train_size = int(len(chunks) * 0.65) train_chunks = chunks[:train_size] test_chunks = chunks[train_size:] @@ -69,29 +80,29 @@ def lance_dataset_from_llama_nodes(nodes: list, name: str, split=True): return train_ds, test_ds - - - - -def finetune(trainset: str, model: str, epochs: int, path: str, valset: str = None, top_k=5): +def finetune( + trainset: str, model: str, epochs: int, path: str, valset: str = None, top_k=5 +): print(f"Finetuning {model} for {epochs} epochs") print(f"trainset query instances: {len(trainset.queries)}") print(f"valset query instances: {len(valset.queries)}") valset = valset if valset is not None else trainset model = get_registry().get("sentence-transformers").create(name=model) - base_result = model.evaluate(valset, path=f"./data/eva/", top_k=top_k) + base_result = model.evaluate(valset, path=f"./data/eval/", top_k=top_k) base_hit_rate = pd.DataFrame(base_result)["is_hit"].mean() model.finetune(trainset=trainset, valset=valset, path=path, epochs=epochs) tuned = get_registry().get("sentence-transformers").create(name=path) - tuned_result = tuned.evaluate(valset, path=f"./data/eva/{str(time.time())}", top_k=top_k) + tuned_result = tuned.evaluate( + valset, path=f"./data/eval/{str(time.time())}", top_k=top_k + ) tuned_hit_rate = pd.DataFrame(tuned_result)["is_hit"].mean() return base_hit_rate, tuned_hit_rate -def eval_rag(dataset: 
str, model: str): +def do_eval_rag(dataset: str, model: str): # Requires - pip install llama-index-vector-stores-lancedb # Requires - pip install llama-index-embeddings-huggingface nodes, docs, rag_dataset = get_llama_dataset(dataset) @@ -101,8 +112,8 @@ def eval_rag(dataset: str, model: str): storage_context = StorageContext.from_defaults(vector_store=vector_store) service_context = ServiceContext.from_defaults(embed_model=embed_model) index = VectorStoreIndex( - nodes, - service_context=service_context, + nodes, + service_context=service_context, show_progress=True, storage_context=storage_context, ) @@ -112,9 +123,7 @@ def eval_rag(dataset: str, model: str): query_engine = index.as_query_engine() # evaluate using the RagEvaluatorPack - RagEvaluatorPack = download_llama_pack( - "RagEvaluatorPack", "./rag_evaluator_pack" - ) + RagEvaluatorPack = download_llama_pack("RagEvaluatorPack", "./rag_evaluator_pack") rag_evaluator_pack = RagEvaluatorPack( rag_dataset=rag_dataset, query_engine=query_engine ) @@ -126,36 +135,69 @@ def eval_rag(dataset: str, model: str): return metrics -def main(dataset, model, epochs, top_k=5, eval_rag=False, project: str = "lancedb_finetune"): + +def main( + dataset, + model, + epochs, + top_k=5, + eval_rag=False, + split=True, + project: str = "lancedb_finetune", +): nodes, _, _ = get_llama_dataset(dataset) - trainset, valset = lance_dataset_from_llama_nodes(nodes, dataset) - data_path, lance_train_path, lance_test_path, tuned_path = get_paths_from_dataset(dataset, split=True) - - base_hit_rate, tuned_hit_rate = finetune(trainset, model, epochs, tuned_path, valset, top_k=top_k) + trainset = None + valset = None + if split: + trainset, valset = lance_dataset_from_llama_nodes(nodes, dataset, split) + data_path, lance_train_path, lance_test_path, tuned_path = ( + get_paths_from_dataset(dataset, split=split) + ) + else: + trainset = lance_dataset_from_llama_nodes(nodes, dataset, split) + valset = trainset + data_path, lance_path, tuned_path = 
get_paths_from_dataset(dataset, split=split) + + base_hit_rate, tuned_hit_rate = finetune( + trainset, model, epochs, tuned_path, valset, top_k=top_k + ) # Base model model metrics - metrics = eval_rag(dataset, model) + metrics = do_eval_rag(dataset, model) if eval_rag else {} # Tuned model metrics - metrics_tuned = eval_rag(dataset, tuned_path) - - wandb.init(project="lancedb_finetune", name=f"{dataset}_{model}_{epochs}") - wandb.log({ - "hit_rate": tuned_hit_rate, - }) + metrics_tuned = do_eval_rag(dataset, tuned_path) if eval_rag else {} + + wandb.init(project="lancedb_finetune", name=f"{dataset}_{model}_{epochs}") + wandb.log( + { + "hit_rate": tuned_hit_rate, + } + ) wandb.log(metrics_tuned) wandb.finish() - wandb.init(project="lancedb_finetune", name=f"{dataset}_{model}_base") - wandb.log({ - "hit_rate": base_hit_rate, - }) + wandb.init(project="lancedb_finetune", name=f"{dataset}_{model}_base") + wandb.log( + { + "hit_rate": base_hit_rate, + } + ) wandb.log(metrics) wandb.finish() def banchmark_all(): - datasets = ["Uber10KDataset2021", "MiniTruthfulQADataset", "MiniSquadV2Dataset", "MiniEsgBenchDataset", "MiniCovidQaDataset", "Llama2PaperDataset", "HistoryOfAlexnetDataset", "PatronusAIFinanceBenchDataset"] + datasets = [ + "Uber10KDataset2021", + "MiniTruthfulQADataset", + "MiniSquadV2Dataset", + "MiniEsgBenchDataset", + "MiniCovidQaDataset", + "Llama2PaperDataset", + "HistoryOfAlexnetDataset", + "PatronusAIFinanceBenchDataset", + ] models = ["BAAI/bge-small-en-v1.5"] top_ks = [5] for top_k in top_ks: @@ -163,19 +205,30 @@ def banchmark_all(): for dataset in datasets: main(dataset, model, 5, top_k=top_k) + if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser() parser.add_argument("--dataset", type=str, default="BraintrustCodaHelpDeskDataset") parser.add_argument("--model", type=str, default="BAAI/bge-small-en-v1.5") parser.add_argument("--epochs", type=int, default=4) parser.add_argument("--project", type=str, 
default="lancedb_finetune") parser.add_argument("--top_k", type=int, default=5) - parser.add_argument("--eval-rag", action="store_true") + parser.add_argument("--split", type=int, default=1) + parser.add_argument("--eval-rag", action="store_true", default=False) parser.add_argument("--benchmark-all", action="store_true") args = parser.parse_args() if args.benchmark_all: banchmark_all() else: - main(args.dataset, args.model, args.epochs, args.top_k, args.eval_rag, args.project) + main( + args.dataset, + args.model, + args.epochs, + args.top_k, + args.eval_rag, + args.split, + args.project, + ) diff --git a/python/python/lancedb/embeddings/sentence_transformers.py b/python/python/lancedb/embeddings/sentence_transformers.py index 97fe1318..de40507c 100644 --- a/python/python/lancedb/embeddings/sentence_transformers.py +++ b/python/python/lancedb/embeddings/sentence_transformers.py @@ -10,12 +10,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Union +from typing import Any, List, Optional, Union import numpy as np +import logging +from lancedb.embeddings.fine_tuner import QADataset from ..util import attempt_import_or_raise from .base import TextEmbeddingFunction +from .fine_tuner.basetuner import BaseEmbeddingTuner from .registry import register from .utils import weak_lru @@ -80,3 +83,151 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction): "sentence_transformers", "sentence-transformers" ) return sentence_transformers.SentenceTransformer(self.name, device=self.device) + + def finetune(self, trainset: QADataset, *args, **kwargs): + """ + Finetune the Sentence Transformers model + + Parameters + ---------- + trainset: QADataset + The dataset to use for finetuning + """ + tuner = SentenceTransformersTuner( + model=self.embedding_model, + trainset=trainset, + **kwargs, + ) + tuner.finetune() + + +class SentenceTransformersTuner(BaseEmbeddingTuner): + """Sentence Transformers Embedding Finetuning Engine.""" + + def __init__( + self, + model: Any, + trainset: QADataset, + valset: Optional[QADataset] = None, + path: Optional[str] = "~/.lancedb/embeddings/models", + batch_size: int = 8, + epochs: int = 1, + show_progress: bool = True, + eval_steps: int = 50, + max_input_per_doc: int = -1, + loss: Optional[Any] = None, + evaluator: Optional[Any] = None, + run_name: Optional[str] = None, + log_wandb: bool = False, + ) -> None: + """ + Parameters + ---------- + model: Any + The model to use for finetuning. + trainset: QADataset + The training dataset. + valset: Optional[QADataset] + The validation dataset. + path: Optional[str] + The path to save the model. + batch_size: int, default=8 + The batch size. + epochs: int, default=1 + The number of epochs. + show_progress: bool, default=True + Whether to show progress. + eval_steps: int, default=50 + The number of steps to evaluate. + max_input_per_doc: int, default=-1 + The number of inputs per document. + If -1, use all documents.
+ """ + from sentence_transformers import InputExample, losses + from sentence_transformers.evaluation import InformationRetrievalEvaluator + from torch.utils.data import DataLoader + + self.model = model + self.trainset = trainset + self.valset = valset + self.path = path + self.batch_size = batch_size + self.epochs = epochs + self.show_progress = show_progress + self.eval_steps = eval_steps + self.max_input_per_doc = max_input_per_doc + self.evaluator = None + self.epochs = epochs + self.show_progress = show_progress + self.eval_steps = eval_steps + self.run_name = run_name + self.log_wandb = log_wandb + + if self.max_input_per_doc < -1: + raise ValueError("max_input_per_doc must be -1 or greater than 0.") + + examples: Any = [] + for query_id, query in self.trainset.queries.items(): + if max_input_per_doc == -1: + for node_id in self.trainset.relevant_docs[query_id]: + text = self.trainset.corpus[node_id] + example = InputExample(texts=[query, text]) + examples.append(example) + else: + node_id = self.trainset.relevant_docs[query_id][ + min(max_input_per_doc, len(self.trainset.relevant_docs[query_id])) + ] + text = self.trainset.corpus[node_id] + example = InputExample(texts=[query, text]) + examples.append(example) + + self.examples = examples + + self.loader: DataLoader = DataLoader(examples, batch_size=batch_size) + + if self.valset is not None: + eval_engine = evaluator or InformationRetrievalEvaluator + self.evaluator = eval_engine( + valset.queries, valset.corpus, valset.relevant_docs + ) + self.evaluator = evaluator + + # define loss + self.loss = loss or losses.MultipleNegativesRankingLoss(self.model) + self.warmup_steps = int(len(self.loader) * epochs * 0.1) + + def finetune(self) -> None: + """Finetune the Sentence Transformers model.""" + self.model.fit( + train_objectives=[(self.loader, self.loss)], + epochs=self.epochs, + warmup_steps=self.warmup_steps, + output_path=self.path, + show_progress_bar=self.show_progress, + evaluator=self.evaluator, + 
evaluation_steps=self.eval_steps, + callback=self._wandb_callback if self.log_wandb else None, + ) + + self.helper() + + def helper(self) -> None: + """A helper method.""" + logging.info("Finetuning complete.") + logging.info(f"Model saved to {self.path}.") + logging.info("You can now use the model as follows:") + logging.info( + f"model = get_registry().get('sentence-transformers').create(name='./{self.path}')" # noqa + ) + + def _wandb_callback(self, score, epoch, steps): + try: + import wandb + except ImportError: + raise ImportError( + "wandb is not installed. Please install it using `pip install wandb`" + ) + run = wandb.run or wandb.init( + project="sbert_lancedb_finetune", name=self.run_name + ) + run.log({"epoch": epoch, "steps": steps, "score": score})