mirror of https://github.com/lancedb/lancedb.git
synced 2025-12-23 05:19:58 +00:00

Compare commits: docs/mcp...fine_tuner (18 commits)

| SHA1 |
|---|
| db67b27a42 |
| 4b0820ef15 |
| c7bb919561 |
| df404b726e |
| ffbb104648 |
| 3ebd561fd9 |
| 6bc488f674 |
| ea34c0b4c4 |
| 1a827925eb |
| fe5888d661 |
| 6074e6b7ee |
| fd8de238bb |
| d0c1113417 |
| 3ca96a852f |
| 9428c6b565 |
| ff00a3242c |
| 878deb73a0 |
| c75bb65609 |
docs/benchmarks/llama-index-datasets.py (new file, 280 lines)
@@ -0,0 +1,280 @@
import argparse
import os
from llama_index.core import SimpleDirectoryReader
from llama_index.core.llama_dataset import LabelledRagDataset
from llama_index.core.node_parser import SentenceSplitter
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from lancedb.embeddings.fine_tuner.llm import Openai
from lancedb.embeddings import get_registry
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.core import ServiceContext, VectorStoreIndex, StorageContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.llama_pack import download_llama_pack

import time
import wandb

import pandas as pd


def get_paths_from_dataset(dataset: str, split=True):
    """
    Returns paths of:
    - downloaded dataset, lance train dataset, lance test dataset, finetuned model
    """
    if split:
        return (
            f"./data/{dataset}",
            f"./data/{dataset}_lance_train",
            f"./data/{dataset}_lance_test",
            f"./data/tuned_{dataset}",
        )
    return f"./data/{dataset}", f"./data/{dataset}_lance", f"./data/tuned_{dataset}"


def get_llama_dataset(dataset: str):
    """
    Returns:
    - nodes, documents, rag_dataset
    """
    if not os.path.exists(f"./data/{dataset}"):
        os.system(
            f"llamaindex-cli download-llamadataset {dataset} --download-dir ./data/{dataset}"  # noqa
        )
    rag_dataset = LabelledRagDataset.from_json(f"./data/{dataset}/rag_dataset.json")
    docs = SimpleDirectoryReader(input_dir=f"./data/{dataset}/source_files").load_data()

    parser = SentenceSplitter()
    nodes = parser.get_nodes_from_documents(docs)

    return nodes, docs, rag_dataset


def lance_dataset_from_llama_nodes(nodes: list, name: str, split=True):
    llm = Openai()
    chunks = [TextChunk.from_llama_index_node(node) for node in nodes]
    # train-test split 65-35
    if not split:
        if os.path.exists(f"./data/{name}_lance"):
            ds = QADataset.load(f"./data/{name}_lance")
            return ds
        ds = QADataset.from_llm(chunks, llm)
        ds.save(f"./data/{name}_lance")
        return ds

    if os.path.exists(f"./data/{name}_lance_train") and os.path.exists(
        f"./data/{name}_lance_test"
    ):
        train_ds = QADataset.load(f"./data/{name}_lance_train")
        test_ds = QADataset.load(f"./data/{name}_lance_test")
        return train_ds, test_ds
    # split chunks randomly
    train_size = int(len(chunks) * 0.65)
    train_chunks = chunks[:train_size]
    test_chunks = chunks[train_size:]
    train_ds = QADataset.from_llm(train_chunks, llm)
    test_ds = QADataset.from_llm(test_chunks, llm)
    train_ds.save(f"./data/{name}_lance_train")
    test_ds.save(f"./data/{name}_lance_test")
    return train_ds, test_ds


def finetune(
    trainset: QADataset,
    model: str,
    epochs: int,
    path: str,
    valset: QADataset = None,
    top_k=5,
):
    valset = valset if valset is not None else trainset
    print(f"Finetuning {model} for {epochs} epochs")
    print(f"trainset query instances: {len(trainset.queries)}")
    print(f"valset query instances: {len(valset.queries)}")

    model = get_registry().get("sentence-transformers").create(name=model)
    base_result = model.evaluate(valset, path="./data/eval/", top_k=top_k)
    base_hit_rate = pd.DataFrame(base_result)["is_hit"].mean()

    model.finetune(trainset=trainset, valset=valset, path=path, epochs=epochs)
    tuned = get_registry().get("sentence-transformers").create(name=path)
    tuned_result = tuned.evaluate(
        valset, path=f"./data/eval/{str(time.time())}", top_k=top_k
    )
    tuned_hit_rate = pd.DataFrame(tuned_result)["is_hit"].mean()

    return base_hit_rate, tuned_hit_rate


def do_eval_rag(dataset: str, model: str):
    # Requires - pip install llama-index-vector-stores-lancedb
    # Requires - pip install llama-index-embeddings-huggingface
    nodes, _, rag_dataset = get_llama_dataset(dataset)

    embed_model = HuggingFaceEmbedding(model)
    vector_store = LanceDBVectorStore(uri="/tmp/lancedb")
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    service_context = ServiceContext.from_defaults(embed_model=embed_model)

    # build a basic RAG system on top of the LanceDB-backed index
    index = VectorStoreIndex(
        nodes,
        service_context=service_context,
        show_progress=True,
        storage_context=storage_context,
    )
    query_engine = index.as_query_engine()

    # evaluate using the RagEvaluatorPack
    RagEvaluatorPack = download_llama_pack("RagEvaluatorPack", "./rag_evaluator_pack")
    rag_evaluator_pack = RagEvaluatorPack(
        rag_dataset=rag_dataset, query_engine=query_engine
    )

    metrics = rag_evaluator_pack.run(
        batch_size=20,  # batches the number of openai api calls to make
        sleep_time_in_seconds=1,  # seconds to sleep before making an api call
    )

    return metrics


def main(
    dataset,
    model,
    epochs,
    top_k=5,
    eval_rag=False,
    split=True,
    project: str = "lancedb_finetune",
):
    nodes, _, _ = get_llama_dataset(dataset)
    trainset = None
    valset = None
    if split:
        trainset, valset = lance_dataset_from_llama_nodes(nodes, dataset, split)
        data_path, lance_train_path, lance_test_path, tuned_path = (
            get_paths_from_dataset(dataset, split=split)
        )
    else:
        trainset = lance_dataset_from_llama_nodes(nodes, dataset, split)
        valset = trainset
        data_path, lance_path, tuned_path = get_paths_from_dataset(dataset, split=split)

    base_hit_rate, tuned_hit_rate = finetune(
        trainset, model, epochs, tuned_path, valset, top_k=top_k
    )

    # Base model metrics
    metrics = do_eval_rag(dataset, model) if eval_rag else {}

    # Tuned model metrics
    metrics_tuned = do_eval_rag(dataset, tuned_path) if eval_rag else {}

    wandb.init(project=project, name=f"{dataset}_{model}_{epochs}")
    wandb.log(
        {
            "hit_rate": tuned_hit_rate,
        }
    )
    wandb.log(metrics_tuned)
    wandb.finish()

    wandb.init(project=project, name=f"{dataset}_{model}_base")
    wandb.log(
        {
            "hit_rate": base_hit_rate,
        }
    )
    wandb.log(metrics)
    wandb.finish()


def benchmark_all():
    datasets = [
        "Uber10KDataset2021",
        "MiniTruthfulQADataset",
        "MiniSquadV2Dataset",
        "MiniEsgBenchDataset",
        "MiniCovidQaDataset",
        "Llama2PaperDataset",
        "HistoryOfAlexnetDataset",
        "PatronusAIFinanceBenchDataset",
    ]
    models = ["BAAI/bge-small-en-v1.5"]
    top_ks = [5]
    for top_k in top_ks:
        for model in models:
            for dataset in datasets:
                main(dataset, model, 5, top_k=top_k)


if __name__ == "__main__":
    """
    Benchmark the fine-tuning process for a given dataset and model.

    Usage:
    - For a single dataset:
      python lancedb/docs/benchmarks/llama-index-datasets.py --dataset Uber10KDataset2021 --model BAAI/bge-small-en-v1.5 --epochs 4 --top_k 5 --split 1 --project lancedb_finetune

    - For all datasets and models across all top_ks:
      python lancedb/docs/benchmarks/llama-index-datasets.py --benchmark-all
    """

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="BraintrustCodaHelpDeskDataset",
        help="The dataset to use for fine-tuning",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="BAAI/bge-small-en-v1.5",
        help="The model to use for fine-tuning",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=4,
        help="The number of epochs to fine-tune the model",
    )
    parser.add_argument(
        "--project",
        type=str,
        default="lancedb_finetune",
        help="The wandb project to log the results to",
    )
    parser.add_argument(
        "--top_k", type=int, default=5, help="The number of top results to evaluate"
    )
    parser.add_argument(
        "--split",
        type=int,
        default=1,
        help="Whether to split the dataset into train and test (65-35 split); default is 1",
    )
    parser.add_argument(
        "--eval-rag",
        action="store_true",
        default=False,
        help="Whether to evaluate the model using RAG",
    )
    parser.add_argument(
        "--benchmark-all",
        action="store_true",
        default=False,
        help="Benchmark all datasets across all models and top_ks",
    )
    args = parser.parse_args()

    if args.benchmark_all:
        benchmark_all()
    else:
        main(
            args.dataset,
            args.model,
            args.epochs,
            args.top_k,
            args.eval_rag,
            args.split,
            args.project,
        )
@@ -10,13 +10,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC, abstractmethod
from typing import List, Union

import numpy as np
import pyarrow as pa
from pydantic import BaseModel, Field, PrivateAttr
from tqdm import tqdm

import lancedb

from .fine_tuner import QADataset
from .utils import TEXT, retry_with_exponential_backoff
@@ -126,6 +131,22 @@ class EmbeddingFunction(BaseModel, ABC):
    def __hash__(self) -> int:
        return hash(frozenset(vars(self).items()))

    def finetune(self, dataset: QADataset, *args, **kwargs):
        """
        Finetune the embedding function on a dataset
        """
        raise NotImplementedError(
            "Finetuning is not supported for this embedding function"
        )

    def evaluate(self, dataset: QADataset, top_k=5, path=None, *args, **kwargs):
        """
        Evaluate the embedding function on a dataset
        """
        raise NotImplementedError(
            "Evaluation is not supported for this embedding function"
        )


class EmbeddingFunctionConfig(BaseModel):
    """
@@ -159,3 +180,52 @@ class TextEmbeddingFunction(EmbeddingFunction):
        Generate the embeddings for the given texts
        """
        pass

    def evaluate(self, dataset: QADataset, top_k=5, path=None, *args, **kwargs):
        """
        Evaluate the embedding function on a dataset. This calculates the hit rate
        for the top-k retrieved documents for each query in the dataset, assuming
        that the first relevant document is the expected document.
        Pro - Should work for any embedding model.
        Con - Only computes a very simple metric.

        Parameters
        ----------
        dataset: QADataset
            The dataset to evaluate on

        Returns
        -------
        list of dict
            One evaluation result per query
        """
        corpus = dataset.corpus
        queries = dataset.queries
        relevant_docs = dataset.relevant_docs
        path = path or os.path.join(os.getcwd(), "eval")
        db = lancedb.connect(path)

        class Schema(lancedb.pydantic.LanceModel):
            id: str
            text: str = self.SourceField()
            vector: lancedb.pydantic.Vector(self.ndims()) = self.VectorField()

        retriever = db.create_table("eval", schema=Schema, mode="overwrite")
        pylist = [{"id": str(k), "text": v} for k, v in corpus.items()]
        retriever.add(pylist)

        eval_results = []
        for query_id, query in tqdm(queries.items()):
            retrieved_nodes = retriever.search(query).limit(top_k).to_list()
            retrieved_ids = [node["id"] for node in retrieved_nodes]
            expected_id = relevant_docs[query_id][0]
            is_hit = expected_id in retrieved_ids  # assume 1 relevant doc

            eval_result = {
                "is_hit": is_hit,
                "retrieved": retrieved_ids,
                "expected": expected_id,
                "query": query_id,
            }
            eval_results.append(eval_result)

        return eval_results
python/python/lancedb/embeddings/fine_tuner/README.md (new file, 147 lines)
@@ -0,0 +1,147 @@
### Fine-tuning workflow
The fine-tuning workflow is as follows:
1. Create a `QADataset` object.
2. Initialize any embedding function using the LanceDB embedding API.
3. Call the `finetune` method on the embedding object with the `QADataset` object as an argument.
4. Evaluate the fine-tuned model using the `evaluate` method in the embedding API.

# End-to-End Examples
The following is an example of how to fine-tune an embedding model using the LanceDB embedding API.

## Example 1: Fine-tuning from a synthetic dataset

```python
import os
import pandas as pd

from lancedb.embeddings.fine_tuner.llm import Openai
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from lancedb.embeddings import get_registry
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter

dataset = "Uber10KDataset2021"
lance_dataset_dir = dataset + "_lance"
valset_dir = dataset + "_lance_val"
finetuned_model_path = "./model_finetuned"

# 1. Create a QADataset object. See all llama-index datasets here:
# https://github.com/run-llama/llama_index/tree/main/llama-datasets
if not os.path.exists(f"./data/{dataset}"):
    os.system(
        f"llamaindex-cli download-llamadataset {dataset} --download-dir ./data/{dataset}"
    )
docs = SimpleDirectoryReader(input_dir=f"./data/{dataset}/source_files").load_data()

parser = SentenceSplitter()
nodes = parser.get_nodes_from_documents(docs)
# convert llama-index TextNodes to TextChunks
chunks = [TextChunk.from_llama_index_node(node) for node in nodes]
llm = Openai()

if os.path.exists(lance_dataset_dir):
    trainset = QADataset.load(lance_dataset_dir)
else:
    trainset = QADataset.from_llm(chunks, llm, num_questions_per_chunk=2)
    trainset.save(lance_dataset_dir)

# Ideally we would have a standard dataset for validation, but here we just
# generate another synthetic dataset.
if os.path.exists(valset_dir):
    valset = QADataset.load(valset_dir)
else:
    valset = QADataset.from_llm(chunks, llm, num_questions_per_chunk=4)
    valset.save(valset_dir)

# 2. Initialize the embedding model
model = get_registry().get("sentence-transformers").create(
    name="sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
)

# 3. Fine-tune the model
model.finetune(trainset=trainset, path=finetuned_model_path, epochs=4)

# 4. Evaluate the fine-tuned model
base = get_registry().get("sentence-transformers").create(
    name="sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
)
base_results = base.evaluate(valset, top_k=5)

tuned = get_registry().get("sentence-transformers").create(name=finetuned_model_path)
tuned_results = tuned.evaluate(valset, top_k=5)

openai = get_registry().get("openai").create(name="text-embedding-3-small")
openai_results = openai.evaluate(valset, top_k=5)

print("openai-embedding-v3 hit-rate - ", pd.DataFrame(openai_results)["is_hit"].mean())
print("fine-tuned hit-rate - ", pd.DataFrame(tuned_results)["is_hit"].mean())
print("base model hit-rate - ", pd.DataFrame(base_results)["is_hit"].mean())
```

The fine-tuning workflow for embeddings consists of the following parts:

### QADataset
This class manages the data for fine-tuning. It contains the following builder methods:
```
from_llm(
    nodes: List[TextChunk],
    llm: BaseLLM,
    qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL,
    num_questions_per_chunk: int = 2,
) -> "QADataset"
```
Creates synthetic data from a language model and text chunks of the original document on which the model is to be fine-tuned.

```python
from_responses(docs: List["TextChunk"], queries: Dict[str, str], relevant_docs: Dict[str, List[str]]) -> "QADataset"
```
Creates a dataset from queries and responses gathered in a real-world scenario. Designed for knowledge distillation from a larger LLM to a smaller one.
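For example, a minimal sketch of building a dataset from collected responses (the chunk text and query below are placeholder data, not from a real workload):

```python
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk

# hypothetical corpus chunk and a user query known to be answered by it
docs = [TextChunk.from_chunk("LanceDB is an embedded vector database.")]
queries = {"q1": "What kind of database is LanceDB?"}
relevant_docs = {"q1": [docs[0].id]}  # query id -> list of relevant chunk ids

ds = QADataset.from_responses(docs, queries, relevant_docs)
ds.save("./data/responses_lance")  # persisted as .lance files
```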
It also contains the following data attributes:
```
queries (Dict[str, str]): Dict id -> query.
corpus (Dict[str, str]): Dict id -> string.
relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
```

### TextChunk
This class manages the text data for fine-tuning. It is designed to standardize the output of various text splitting/pre-processing tools, such as llama-index and langchain, so they can be used interchangeably. It contains the following attributes:
```
text: str
id: str
metadata: Dict[str, Any] = {}
```

Builder Methods:

```python
from_llama_index_node(node) -> "TextChunk"
```
Creates a text chunk from a llama-index node.

```python
from_langchain_node(node) -> "TextChunk"
```
Creates a text chunk from a langchain node (not implemented yet).

```python
from_chunk(cls, chunk: str, metadata: dict = {}) -> "TextChunk"
```
Creates a text chunk from a plain string, generating a unique id.
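A quick sketch of the builders in use (the strings here are placeholders):

```python
from lancedb.embeddings.fine_tuner.dataset import TextChunk

chunk = TextChunk.from_chunk(
    "Some paragraph of source text.", metadata={"source": "example.txt"}
)
print(chunk.id)         # auto-generated uuid4 string
print(chunk.to_dict())  # {"text": ..., "id": ..., "metadata": ...}
```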
### FineTuner
This class performs the fine-tuning of embeddings. It is exposed to the user via a high-level function in the base embedding API.
```python
class BaseEmbeddingTuner(ABC):
    """Base Embedding finetuning engine."""

    @abstractmethod
    def finetune(self) -> None:
        """Finetune the embedding model."""

    def helper(self) -> None:
        """A helper method called after finetuning, meant to provide
        usage instructions or other helpful information."""
        pass
```
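`SentenceTransformersTuner`, added in this PR, is one such engine. Below is a trimmed, hypothetical subclass to illustrate the contract; `MyTuner` and its constructor arguments are illustrative only, not part of the library:

```python
from lancedb.embeddings.fine_tuner.basetuner import BaseEmbeddingTuner
from lancedb.embeddings.fine_tuner.dataset import QADataset


class MyTuner(BaseEmbeddingTuner):
    """Illustrative only: wraps some framework-specific training loop."""

    def __init__(self, model, trainset: QADataset, path: str):
        self.model = model
        self.trainset = trainset
        self.path = path

    def finetune(self) -> None:
        # run the framework-specific training loop over
        # (query, relevant text) pairs from self.trainset,
        # then persist the tuned weights to self.path
        ...
        self.helper()

    def helper(self) -> None:
        print(f"Model saved to {self.path}")
```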
### Embedding API finetuning implementation
Each embedding API needs to implement the `finetune` method in order to support fine-tuning. A vanilla evaluation technique is implemented in the base `TextEmbeddingFunction` class that calculates hit-rate @ `top_k`.
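Consuming that vanilla evaluation looks like this (a sketch assuming `valset` is a `QADataset` built as in Example 1):

```python
import pandas as pd

from lancedb.embeddings import get_registry

model = get_registry().get("sentence-transformers").create(
    name="BAAI/bge-small-en-v1.5"
)
results = model.evaluate(valset, top_k=5)  # one result dict per query
hit_rate = pd.DataFrame(results)["is_hit"].mean()  # hit-rate @ top_k
print(f"hit-rate@5: {hit_rate:.3f}")
```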
python/python/lancedb/embeddings/fine_tuner/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .dataset import QADataset, TextChunk
from .llm import Gemini, Openai

__all__ = ["QADataset", "TextChunk", "Openai", "Gemini"]
python/python/lancedb/embeddings/fine_tuner/basetuner.py (new file, 19 lines)
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod


class BaseEmbeddingTuner(ABC):
    """Base Embedding finetuning engine."""

    @abstractmethod
    def finetune(self) -> None:
        """
        Finetune the embedding model.
        """
        pass

    def helper(self) -> None:
        """
        A helper method called after finetuning. This is meant to provide
        usage instructions or other helpful information.
        """
        pass
python/python/lancedb/embeddings/fine_tuner/dataset.py (new file, 283 lines)
@@ -0,0 +1,283 @@
import re
import uuid
from pathlib import Path
from typing import Any, Dict, List, Tuple

import lance
import pyarrow as pa
from pydantic import BaseModel
from tqdm import tqdm

from .llm import BaseLLM

DEFAULT_PROMPT_TMPL = """\
Context information is below.

---------------------
{context_str}
---------------------

Given the context information and no prior knowledge,
generate only questions based on the query below.

You are a Teacher/Professor. Your task is to set up \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. Restrict the questions to the \
context information provided.
"""


class QADataset(BaseModel):
    """Embedding QA Finetuning Dataset.

    Args:
        queries (Dict[str, str]): Dict id -> query.
        corpus (Dict[str, str]): Dict id -> string.
        relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
    """

    queries: Dict[str, str]  # id -> query
    corpus: Dict[str, str]  # id -> text
    relevant_docs: Dict[str, List[str]]  # query id -> list of relevant doc ids
    mode: str = "text"

    @property
    def query_docid_pairs(self) -> List[Tuple[str, List[str]]]:
        """Get (query, relevant doc ids) pairs."""
        return [
            (query, self.relevant_docs[query_id])
            for query_id, query in self.queries.items()
        ]

    def save(self, path: str, mode: str = "overwrite") -> None:
        """
        Save the current dataset to a directory as .lance files.

        Parameters
        ----------
        path : str
            The path to save the dataset.
        mode : str, optional
            The mode to save the dataset, by default "overwrite". Accepts
            lance write modes.
        """
        save_dir = Path(path)
        save_dir.mkdir(parents=True, exist_ok=True)

        # convert to pydict {"id": [...]}
        queries = {
            "id": list(self.queries.keys()),
            "query": list(self.queries.values()),
        }
        corpus = {
            "id": list(self.corpus.keys()),
            "text": [
                val or " " for val in self.corpus.values()
            ],  # lance saves empty strings as null
        }
        relevant_docs = {
            "query_id": list(self.relevant_docs.keys()),
            "doc_id": list(self.relevant_docs.values()),
        }

        # write to lance
        lance.write_dataset(
            pa.Table.from_pydict(queries), save_dir / "queries.lance", mode=mode
        )
        lance.write_dataset(
            pa.Table.from_pydict(corpus), save_dir / "corpus.lance", mode=mode
        )
        lance.write_dataset(
            pa.Table.from_pydict(relevant_docs),
            save_dir / "relevant_docs.lance",
            mode=mode,
        )

    @classmethod
    def load(cls, path: str) -> "QADataset":
        """
        Load a QADataset from a directory.

        Parameters
        ----------
        path : str
            The path to load the dataset from.

        Returns
        -------
        QADataset
            The loaded QADataset.
        """
        load_dir = Path(path)
        queries = lance.dataset(load_dir / "queries.lance").to_table().to_pydict()
        corpus = lance.dataset(load_dir / "corpus.lance").to_table().to_pydict()
        relevant_docs = (
            lance.dataset(load_dir / "relevant_docs.lance").to_table().to_pydict()
        )
        return cls(
            queries=dict(zip(queries["id"], queries["query"])),
            corpus=dict(zip(corpus["id"], corpus["text"])),
            relevant_docs=dict(zip(relevant_docs["query_id"], relevant_docs["doc_id"])),
        )

    # generate queries as a convenience function
    @classmethod
    def from_llm(
        cls,
        nodes: "List[TextChunk]",
        llm: BaseLLM,
        qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL,
        num_questions_per_chunk: int = 2,
    ) -> "QADataset":
        """
        Generate a QADataset from a list of TextChunks.

        Parameters
        ----------
        nodes : List[TextChunk]
            The list of text chunks.
        llm : BaseLLM
            The language model used to generate questions.
        qa_generate_prompt_tmpl : str, optional
            The template for generating questions, by default DEFAULT_PROMPT_TMPL.
        num_questions_per_chunk : int, optional
            The number of questions to generate per chunk, by default 2.

        Returns
        -------
        QADataset
            The generated QADataset.
        """
        node_dict = {node.id: node.text for node in nodes}

        queries = {}
        relevant_docs = {}
        for node_id, text in tqdm(node_dict.items()):
            query = qa_generate_prompt_tmpl.format(
                context_str=text, num_questions_per_chunk=num_questions_per_chunk
            )
            response = llm.chat_completion(query)

            result = str(response).strip().split("\n")
            questions = [
                re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
            ]
            questions = [question for question in questions if len(question) > 0]
            for question in questions:
                question_id = str(uuid.uuid4())
                queries[question_id] = question
                relevant_docs[question_id] = [node_id]

        return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)

    @classmethod
    def from_responses(
        cls,
        nodes: List["TextChunk"],
        queries: Dict[str, str],
        relevant_docs: Dict[str, List[str]],
    ) -> "QADataset":
        """
        Create a QADataset from a list of TextChunks along with queries
        and their relevant docs.

        Parameters
        ----------
        nodes : List[TextChunk]
            The list of text chunks.
        queries : Dict[str, str]
            The queries. Dict query id -> query.
        relevant_docs : Dict[str, List[str]]
            The relevant docs. Dict query id -> list of doc ids.

        Returns
        -------
        QADataset
            The QADataset.
        """
        node_dict = {node.id: node.text for node in nodes}
        return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)


class TextChunk(BaseModel):
    """
    Simple text chunk for storing text nodes. Acts as a wrapper around text and
    allows interoperability between different text processing libraries.

    Args:
        text (str): The text of the chunk.
        id (str): The id of the chunk.
        metadata (Dict[str, Any], optional): The metadata of the chunk. Defaults to {}.
    """

    text: str
    id: str
    metadata: Dict[str, Any] = {}

    @classmethod
    def from_chunk(cls, chunk: str, metadata: dict = {}) -> "TextChunk":
        """
        Create a TextChunk from a string chunk.

        Parameters
        ----------
        chunk : str
            The text chunk.
        metadata : dict, optional
            The metadata, by default {}.

        Returns
        -------
        TextChunk
            The text chunk.
        """
        # generate a unique id
        return cls(text=chunk, id=str(uuid.uuid4()), metadata=metadata)

    @classmethod
    def from_llama_index_node(cls, node):
        """
        Generate a TextChunk from a llama-index node.

        Parameters
        ----------
        node : llama_index.core.schema.TextNode
            The llama-index node.
        """
        return cls(text=node.text, id=node.node_id, metadata=node.metadata)

    @classmethod
    def from_langchain_node(cls, node):
        """
        Generate a TextChunk from a langchain node.

        Parameters
        ----------
        node : langchain document node
            The langchain node.
        """
        raise NotImplementedError("Not implemented yet.")

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to a dictionary.

        Returns
        -------
        Dict[str, Any]
            The dictionary.
        """
        return self.dict()

    def __str__(self) -> str:
        return self.text

    def __repr__(self) -> str:
        return f"TextChunk(text={self.text}, id={self.id}, metadata={self.metadata})"
python/python/lancedb/embeddings/fine_tuner/llm.py (new file, 88 lines)
@@ -0,0 +1,88 @@
import os
import re
from functools import cached_property
from typing import List, Optional

from pydantic import BaseModel

from ...util import attempt_import_or_raise
from ..utils import api_key_not_found_help


class BaseLLM(BaseModel):
    """
    TODO:
    Base class for the language models used to generate fine-tuning data.
    This class is loosely designed right now, and will be updated as the
    usage gets clearer.
    """

    class Config:
        protected_namespaces = ()  # Disable protected namespace check

    model_name: str
    model_kwargs: dict = {}

    @cached_property
    def _client(self):
        """
        Get the client for the language model
        """
        raise NotImplementedError

    def chat_completion(self, prompt: str, **kwargs):
        """
        Get the chat completion for the given prompt
        """
        raise NotImplementedError


class Openai(BaseLLM):
    model_name: str = "gpt-3.5-turbo"
    kwargs: dict = {}
    api_key: Optional[str] = None

    @cached_property
    def _client(self):
        """
        Get the client for the language model
        """
        openai = attempt_import_or_raise("openai")

        if not os.environ.get("OPENAI_API_KEY"):
            api_key_not_found_help("openai")
        return openai.OpenAI()

    def chat_completion(self, prompt: str) -> str:
        """
        Get the chat completion for the given prompt
        """
        # TODO: this is the legacy openai api; replace with completions
        completion = self._client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ],
            **self.kwargs,
        )

        text = completion.choices[0].message.content

        return text

    def get_questions(self, prompt: str) -> List[str]:
        """
        Get a cleaned-up list of questions from the chat completion
        for the given prompt
        """
        response = self.chat_completion(prompt)
        result = str(response).strip().split("\n")
        questions = [
            re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
        ]
        questions = [question for question in questions if len(question) > 0]
        return questions


class Gemini(BaseLLM):
    pass
@@ -10,12 +10,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from typing import Any, List, Optional, Union

import numpy as np
import logging
from lancedb.embeddings.fine_tuner import QADataset

from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .fine_tuner.basetuner import BaseEmbeddingTuner
from .registry import register
from .utils import weak_lru
@@ -80,3 +83,151 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
            "sentence_transformers", "sentence-transformers"
        )
        return sentence_transformers.SentenceTransformer(self.name, device=self.device)

    def finetune(self, trainset: QADataset, *args, **kwargs):
        """
        Finetune the Sentence Transformers model

        Parameters
        ----------
        trainset: QADataset
            The dataset to use for finetuning
        """
        tuner = SentenceTransformersTuner(
            model=self.embedding_model,
            trainset=trainset,
            **kwargs,
        )
        tuner.finetune()


class SentenceTransformersTuner(BaseEmbeddingTuner):
    """Sentence Transformers Embedding Finetuning Engine."""

    def __init__(
        self,
        model: Any,
        trainset: QADataset,
        valset: Optional[QADataset] = None,
        path: Optional[str] = "~/.lancedb/embeddings/models",
        batch_size: int = 8,
        epochs: int = 1,
        show_progress: bool = True,
        eval_steps: int = 50,
        max_input_per_doc: int = -1,
        loss: Optional[Any] = None,
        evaluator: Optional[Any] = None,
        run_name: Optional[str] = None,
        log_wandb: bool = False,
    ) -> None:
        """
        Parameters
        ----------
        model: Any
            The sentence-transformers model instance to finetune.
        trainset: QADataset
            The training dataset.
        valset: Optional[QADataset]
            The validation dataset.
        path: Optional[str]
            The path to save the model.
        batch_size: int, default=8
            The batch size.
        epochs: int, default=1
            The number of epochs.
        show_progress: bool, default=True
            Whether to show progress.
        eval_steps: int, default=50
            The number of steps between evaluations.
        max_input_per_doc: int, default=-1
            The maximum number of relevant docs used per query.
            If -1, use all relevant docs.
        loss: Optional[Any]
            The loss; defaults to MultipleNegativesRankingLoss.
        evaluator: Optional[Any]
            The evaluator class; defaults to InformationRetrievalEvaluator.
        run_name: Optional[str]
            The wandb run name.
        log_wandb: bool, default=False
            Whether to log evaluation scores to wandb.
        """
        from sentence_transformers import InputExample, losses
        from sentence_transformers.evaluation import InformationRetrievalEvaluator
        from torch.utils.data import DataLoader

        self.model = model
        self.trainset = trainset
        self.valset = valset
        self.path = path
        self.batch_size = batch_size
        self.epochs = epochs
        self.show_progress = show_progress
        self.eval_steps = eval_steps
        self.max_input_per_doc = max_input_per_doc
        self.evaluator = None
        self.run_name = run_name
        self.log_wandb = log_wandb

        if self.max_input_per_doc == 0 or self.max_input_per_doc < -1:
            raise ValueError("max_input_per_doc must be -1 or greater than 0.")

        examples: Any = []
        for query_id, query in self.trainset.queries.items():
            if max_input_per_doc == -1:
                relevant = self.trainset.relevant_docs[query_id]
            else:
                # cap the number of relevant docs used per query
                relevant = self.trainset.relevant_docs[query_id][:max_input_per_doc]
            for node_id in relevant:
                text = self.trainset.corpus[node_id]
                examples.append(InputExample(texts=[query, text]))

        self.examples = examples

        self.loader: DataLoader = DataLoader(examples, batch_size=batch_size)

        if self.valset is not None:
            eval_engine = evaluator or InformationRetrievalEvaluator
            self.evaluator = eval_engine(
                valset.queries, valset.corpus, valset.relevant_docs
            )

        # define loss
        self.loss = loss or losses.MultipleNegativesRankingLoss(self.model)
        self.warmup_steps = int(len(self.loader) * epochs * 0.1)

    def finetune(self) -> None:
        """Finetune the Sentence Transformers model."""
        self.model.fit(
            train_objectives=[(self.loader, self.loss)],
            epochs=self.epochs,
            warmup_steps=self.warmup_steps,
            output_path=self.path,
            show_progress_bar=self.show_progress,
            evaluator=self.evaluator,
            evaluation_steps=self.eval_steps,
            callback=self._wandb_callback if self.log_wandb else None,
        )

        self.helper()

    def helper(self) -> None:
        """A helper method."""
        logging.info("Finetuning complete.")
        logging.info(f"Model saved to {self.path}.")  # noqa
        logging.info("You can now use the model as follows:")
        logging.info(
            f"model = get_registry().get('sentence-transformers').create(name='./{self.path}')"  # noqa
        )

    def _wandb_callback(self, score, epoch, steps):
        try:
            import wandb
        except ImportError:
            raise ImportError(
                "wandb is not installed. Please install it using `pip install wandb`"
            )
        run = wandb.run or wandb.init(
            project="sbert_lancedb_finetune", name=self.run_name
        )
        run.log({"epoch": epoch, "steps": steps, "score": score})
python/python/tests/test_embedding_tuner.py (new file, 45 lines)
@@ -0,0 +1,45 @@
import uuid

import pytest
from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner import QADataset, TextChunk
from tqdm import tqdm


@pytest.mark.slow
def test_finetuning_sentence_transformers(tmp_path):
    queries = {}
    relevant_docs = {}
    chunks = [
        "This is a chunk related to legal docs",
        "This is another chunk related to financial docs",
        "This is a chunk related to sports docs",
        "This is another chunk related to fashion docs",
    ]
    text_chunks = [TextChunk.from_chunk(chunk) for chunk in chunks]
    for chunk in tqdm(text_chunks):
        questions = [
            "What is this chunk about?",
            "What is the main topic of this chunk?",
        ]
        for question in questions:
            question_id = str(uuid.uuid4())
            queries[question_id] = question
            relevant_docs[question_id] = [chunk.id]
    ds = QADataset.from_responses(text_chunks, queries, relevant_docs)

    assert len(ds.queries) == 8
    assert len(ds.corpus) == 4

    model = get_registry().get("sentence-transformers").create()
    model.finetune(trainset=ds, valset=ds, path=str(tmp_path / "model"), epochs=1)
    model = (
        get_registry().get("sentence-transformers").create(name=str(tmp_path / "model"))
    )
    res = model.evaluate(ds)
    assert res is not None


def test_text_chunk():
    # TODO
    pass