diff --git a/python/python/lancedb/embeddings/fine_tuner/README.md b/python/python/lancedb/embeddings/fine_tuner/README.md
index 7cee0cf9..afad9a70 100644
--- a/python/python/lancedb/embeddings/fine_tuner/README.md
+++ b/python/python/lancedb/embeddings/fine_tuner/README.md
@@ -1,3 +1,81 @@
+### Fine-tuning workflow
+The fine-tuning workflow is as follows:
+1. Create a `QADataset` object.
+2. Initialize any embedding function using the LanceDB embedding API.
+3. Call the `finetune` method on the embedding object with the `QADataset` object as an argument.
+4. Evaluate the fine-tuned model using the `evaluate` method in the embedding API.
+
+# End-to-End Examples
+The following is an example of how to fine-tune an embedding model using the LanceDB embedding API.
+
+## Example 1: Fine-tuning from a synthetic dataset
+
+```python
+import os
+import pandas as pd
+
+from lancedb.embeddings.fine_tuner.llm import Openai
+from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
+from lancedb.pydantic import LanceModel, Vector
+from llama_index.core import SimpleDirectoryReader
+from llama_index.core.node_parser import SentenceSplitter
+from llama_index.core.schema import MetadataMode
+from lancedb.embeddings import get_registry
+
+dataset = "Uber10KDataset2021"
+lance_dataset_dir = dataset + "_lance"
+valset_dir = dataset + "_lance_val"
+finetuned_model_path = "./model_finetuned"
+
+# 1. Create a QADataset object. See all datasets on llama-index here: https://github.com/run-llama/llama_index/tree/main/llama-datasets
+
+if not os.path.exists(f"./data/{dataset}"):
+    os.system(
+        f"llamaindex-cli download-llamadataset {dataset} --download-dir ./data/{dataset}"
+    )
+docs = SimpleDirectoryReader(input_dir=f"./data/{dataset}/source_files").load_data()
+
+parser = SentenceSplitter()
+nodes = parser.get_nodes_from_documents(docs)
+# convert Llama-index TextNode to TextChunk
+chunks = [TextChunk.from_llama_index_node(node) for node in nodes]
+llm = Openai()
+
+if os.path.exists(lance_dataset_dir):
+    trainset = QADataset.load(lance_dataset_dir)
+else:
+    trainset = QADataset.from_llm(chunks, llm, num_questions_per_chunk=2)
+    trainset.save(lance_dataset_dir)
+
+# Ideally, we should have a standard dataset for validation, but here we're just generating a synthetic dataset.
+if os.path.exists(valset_dir):
+    valset = QADataset.load(valset_dir)
+else:
+    valset = QADataset.from_llm(chunks, llm, num_questions_per_chunk=4)
+    valset.save(valset_dir)
+
+# 2. Initialize the embedding model
+model = get_registry().get("sentence-transformers").create(name="sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
+
+# 3. Fine-tune the model
+model.finetune(trainset=trainset, path=finetuned_model_path, epochs=4)
+
+# 4. Evaluate the fine-tuned model
+base = get_registry().get("sentence-transformers").create(name="sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
+base_results = base.evaluate(valset, top_k=5)
+
+tuned = get_registry().get("sentence-transformers").create(name=finetuned_model_path)
+tuned_results = tuned.evaluate(valset, top_k=5)
+
+openai = get_registry().get("openai").create(name="text-embedding-3-small")
+openai_results = openai.evaluate(valset, top_k=5)
+
+print("openai-embedding-v3 hit-rate - ", pd.DataFrame(openai_results)["is_hit"].mean())
+print("fine-tuned hit-rate - ", pd.DataFrame(tuned_results)["is_hit"].mean())
+print("Base model hit-rate - ", pd.DataFrame(base_results)["is_hit"].mean())
+```
+
 Fine-tuning workflow for embeddings consists for the following parts:
 ### QADataset
@@ -66,69 +144,4 @@ class BaseEmbeddingTuner(ABC):
 ```
 ### Embedding API finetuning implementation
-Each embedding API needs to implement `finetune` method in order to support fine-tuning. A vanilla evaluation technique has been implemented in the `BaseEmbedding` class that calculates hit_rate @ `top_k`.
-
-### Fine-tuning workflow
-The fine-tuning workflow is as follows:
-1. Create a `QADataset` object.
-2. Initialize any embedding function using LanceDB embedding API
-3. Call `finetune` method on the embedding object with the `QADataset` object as an argument.
-4. Evaluate the fine-tuned model using the `evaluate` method in the embedding API.
-
-# End-to-End Examples
-The following is an example of how to fine-tune an embedding model using the LanceDB embedding API.
-
-## Example 1: Fine-tuning from a synthetic dataset
-
-```python
-import pandas as pd
-
-from lancedb.embeddings.fine_tuner.llm import Openai
-from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
-from lancedb.pydantic import LanceModel, Vector
-from llama_index.core import SimpleDirectoryReader
-from llama_index.core.node_parser import SentenceSplitter
-from llama_index.core.schema import MetadataMode
-from lancedb.embeddings import get_registry
-
-# 1. Create a QADataset object
-url = "uber10k.pdf"
-reader = SimpleDirectoryReader(input_files=url)
-docs = reader.load_data()
-
-parser = SentenceSplitter()
-nodes = parser.get_nodes_from_documents(docs)
-
-if os.path.exists(name):
-    ds = QADataset.load(name)
-else:
-    llm = Openai()
-
-    # convert Llama-index TextNode to TextChunk
-    chunks = [TextChunk.from_llama_index_node(node) for node in nodes]
-
-    ds = QADataset.from_llm(chunks, llm)
-    ds.save(name)
-
-# 2. Initialize the embedding model
-model = get_registry().get("sentence-transformers").create()
-
-# 3. Fine-tune the model
-model.finetune(trainset=ds, path="model_finetuned", epochs=4)
-
-# 4. Evaluate the fine-tuned model
-base = get_registry().get("sentence-transformers").create()
-tuned = get_registry().get("sentence-transformers").create(name="./model_finetuned_1")
-openai = get_registry().get("openai").create(name="text-embedding-3-large")
-
-
-rs1 = base.evaluate(trainset, path="val_res")
-rs2 = tuned.evaluate(trainset, path="val_res")
-rs3 = openai.evaluate(trainset)
-
-print("openai-embedding-v3 hit-rate - ", pd.DataFrame(rs3)["is_hit"].mean())
-print("fine-tuned hit-rate - ", pd.DataFrame(rs2)["is_hit"].mean())
-print("Base model hite-rate - ", pd.DataFrame(rs1)["is_hit"].mean())
-```
-
-
+Each embedding API needs to implement the `finetune` method in order to support fine-tuning. A vanilla evaluation technique has been implemented in the `BaseEmbedding` class that calculates hit_rate @ `top_k`.
\ No newline at end of file
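
Reviewer note: the README text in this diff says `BaseEmbedding.evaluate` reports hit-rate @ `top_k`, but the computation itself is not shown here. For readers unfamiliar with the metric, below is a minimal, self-contained sketch of the usual definition — a query counts as a hit when the chunk it was generated from appears among its `top_k` most similar chunks. The `hit_rate_at_k` helper, its arguments, and the toy data are illustrative assumptions, not the actual `evaluate` internals.

```python
import numpy as np

def hit_rate_at_k(query_vecs: np.ndarray, chunk_vecs: np.ndarray,
                  expected_chunk_ids: list, k: int = 5) -> float:
    """Fraction of queries whose source chunk ranks in the top-k by cosine similarity."""
    # Normalize rows so a dot product equals cosine similarity.
    q = query_vecs / np.linalg.norm(query_vecs, axis=1, keepdims=True)
    c = chunk_vecs / np.linalg.norm(chunk_vecs, axis=1, keepdims=True)
    sims = q @ c.T                              # shape: (num_queries, num_chunks)
    top_k = np.argsort(-sims, axis=1)[:, :k]    # indices of the k most similar chunks
    # Per-query hit flags, analogous to the `is_hit` column averaged in the example.
    is_hit = [expected in row for expected, row in zip(expected_chunk_ids, top_k)]
    return float(np.mean(is_hit))

# Toy usage: random vectors stand in for real embeddings; each query is a
# slightly perturbed copy of its source chunk, so the hit-rate should be ~1.0.
rng = np.random.default_rng(0)
chunks = rng.normal(size=(100, 384))
queries = chunks[:10] + 0.01 * rng.normal(size=(10, 384))
print(hit_rate_at_k(queries, chunks, expected_chunk_ids=list(range(10)), k=5))
```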