### Fine-tuning workflow

The fine-tuning workflow is as follows:

1. Create a `QADataset` object.
2. Initialize an embedding function using the LanceDB embedding API.
3. Call the `finetune` method on the embedding object with the `QADataset` object as an argument.
4. Evaluate the fine-tuned model using the `evaluate` method of the embedding API (see the condensed sketch after this list).

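In code, these four steps map onto a handful of calls. The following is a condensed sketch using the same API as Example 1 below; `chunks` (a list of `TextChunk`s), `llm`, and the held-out `valset` are assumed to already exist, and the model name and output path are placeholders:

```python
from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner.dataset import QADataset

# 1. Build a synthetic QA dataset from pre-chunked text (`chunks`) with an LLM (`llm`).
trainset = QADataset.from_llm(chunks, llm, num_questions_per_chunk=2)

# 2. Initialize an embedding function through the LanceDB embedding registry.
model = (
    get_registry()
    .get("sentence-transformers")
    .create(name="sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
)

# 3. Fine-tune the model and write the tuned weights to a local path.
model.finetune(trainset=trainset, path="./model_finetuned", epochs=4)

# 4. Reload the tuned weights and evaluate hit rate @ top_k on a held-out QADataset (`valset`).
tuned = get_registry().get("sentence-transformers").create(name="./model_finetuned")
results = tuned.evaluate(valset, top_k=5)
```
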
# End-to-End Examples

The following is an example of how to fine-tune an embedding model using the LanceDB embedding API.

## Example 1: Fine-tuning from a synthetic dataset

```python
import os

import pandas as pd

from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner.llm import Openai
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter

dataset = "Uber10KDataset2021"
lance_dataset_dir = dataset + "_lance"
valset_dir = dataset + "_lance_val"
finetuned_model_path = "./model_finetuned"

# 1. Create a QADataset object. All llama-index datasets are listed here:
# https://github.com/run-llama/llama_index/tree/main/llama-datasets
if not os.path.exists(f"./data/{dataset}"):
    os.system(
        f"llamaindex-cli download-llamadataset {dataset} --download-dir ./data/{dataset}"
    )
docs = SimpleDirectoryReader(input_dir=f"./data/{dataset}/source_files").load_data()

parser = SentenceSplitter()
nodes = parser.get_nodes_from_documents(docs)

# Convert llama-index TextNodes to TextChunks
chunks = [TextChunk.from_llama_index_node(node) for node in nodes]
llm = Openai()

if os.path.exists(lance_dataset_dir):
    trainset = QADataset.load(lance_dataset_dir)
else:
    trainset = QADataset.from_llm(chunks, llm, num_questions_per_chunk=2)
    trainset.save(lance_dataset_dir)

# Ideally we would evaluate against a standard validation dataset,
# but here we simply generate another synthetic one.
if os.path.exists(valset_dir):
    valset = QADataset.load(valset_dir)
else:
    valset = QADataset.from_llm(chunks, llm, num_questions_per_chunk=4)
    valset.save(valset_dir)

# 2. Initialize the embedding model
model = get_registry().get("sentence-transformers").create(name="sentence-transformers/multi-qa-MiniLM-L6-cos-v1")

# 3. Fine-tune the model
model.finetune(trainset=trainset, path=finetuned_model_path, epochs=4)

# 4. Evaluate the fine-tuned model against the base model and an OpenAI embedding model
base = get_registry().get("sentence-transformers").create(name="sentence-transformers/multi-qa-MiniLM-L6-cos-v1")
base_results = base.evaluate(valset, top_k=5)

tuned = get_registry().get("sentence-transformers").create(name=finetuned_model_path)
tuned_results = tuned.evaluate(valset, top_k=5)

openai = get_registry().get("openai").create(name="text-embedding-3-small")
openai_results = openai.evaluate(valset, top_k=5)

# Each result row carries an `is_hit` flag, so its mean is the hit rate @ top_k.
print("OpenAI text-embedding-3-small hit rate - ", pd.DataFrame(openai_results)["is_hit"].mean())
print("Fine-tuned model hit rate - ", pd.DataFrame(tuned_results)["is_hit"].mean())
print("Base model hit rate - ", pd.DataFrame(base_results)["is_hit"].mean())
```
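Once fine-tuned, the model can also be registered as the embedding function for a LanceDB table, so ingestion and search embed text with the tuned weights automatically. The following is a sketch of that follow-up step (it is not part of the fine-tuning API itself); it assumes the script above has already run, and the database path, table name, and query string are placeholders:

```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

# Connect to a local LanceDB and wrap the fine-tuned weights as an embedding function.
db = lancedb.connect("./lancedb")
func = get_registry().get("sentence-transformers").create(name=finetuned_model_path)

# Schema: `text` is the source field, `vector` is filled in automatically by `func`.
class Chunks(LanceModel):
    text: str = func.SourceField()
    vector: Vector(func.ndims()) = func.VectorField()

table = db.create_table("uber10k_chunks", schema=Chunks, mode="overwrite")
table.add([{"text": node.get_content()} for node in nodes])

# Queries are embedded with the same fine-tuned model at search time.
hits = table.search("What was Uber's revenue in 2021?").limit(5).to_pandas()
```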

The fine-tuning workflow for embeddings consists of the following parts:

### QADataset
### Embedding API finetuning implementation

Each embedding API needs to implement the `finetune` method in order to support fine-tuning. A vanilla evaluation technique is implemented in the `BaseEmbedding` class that calculates hit rate @ `top_k`.

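As a rough illustration of what that evaluation measures (this is not the library's own implementation, and all names below are made up for the sketch): for each generated question, the chunk it was generated from should appear among the `top_k` retrieved chunks. Each result row gets an `is_hit` flag, and the hit rate is the mean of that flag, which is why the example above averages the `is_hit` column.

```python
import numpy as np

def hit_rate_at_k(embed_fn, questions, source_chunk_ids, chunk_ids, chunk_texts, top_k=5):
    """Illustrative hit-rate @ top_k; `embed_fn` maps a list of strings to unit-norm vectors."""
    chunk_vecs = np.asarray(embed_fn(chunk_texts))  # (n_chunks, dim)
    results = []
    for question, expected_id in zip(questions, source_chunk_ids):
        q_vec = np.asarray(embed_fn([question]))[0]  # (dim,)
        scores = chunk_vecs @ q_vec                  # cosine similarity for normalized vectors
        top_ids = [chunk_ids[i] for i in np.argsort(-scores)[:top_k]]
        results.append({"question": question, "is_hit": expected_id in top_ids})
    hit_rate = float(np.mean([r["is_hit"] for r in results]))
    return results, hit_rate
```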