Compare commits (2): docs_march...embedding_

| Author | SHA1 | Date |
|---|---|---|
| | 45f2d9976b | |
| | f23641d703 | |
@@ -85,7 +85,7 @@ markdown_extensions:
      alternate_style: true
  - md_in_html
  - attr_list

nav:
  - Home:
    - LanceDB: index.md
@@ -104,14 +104,6 @@ nav:
      - Overview: hybrid_search/hybrid_search.md
      - Comparing Rerankers: hybrid_search/eval.md
      - Airbnb financial data example: notebooks/hybrid_search.ipynb
    - Reranking:
      - Quickstart: reranking/index.md
      - Cohere Reranker: reranking/cohere.md
      - Linear Combination Reranker: reranking/linear_combination.md
      - Cross Encoder Reranker: reranking/cross_encoder.md
      - ColBERT Reranker: reranking/colbert.md
      - OpenAI Reranker: reranking/openai.md
      - Building Custom Rerankers: reranking/custom_reranker.md
    - Filtering: sql.md
    - Versioning & Reproducibility: notebooks/reproducibility.ipynb
    - Configuring Storage: guides/storage.md
@@ -178,14 +170,6 @@ nav:
      - Overview: hybrid_search/hybrid_search.md
      - Comparing Rerankers: hybrid_search/eval.md
      - Airbnb financial data example: notebooks/hybrid_search.ipynb
    - Reranking:
      - Quickstart: reranking/index.md
      - Cohere Reranker: reranking/cohere.md
      - Linear Combination Reranker: reranking/linear_combination.md
      - Cross Encoder Reranker: reranking/cross_encoder.md
      - ColBERT Reranker: reranking/colbert.md
      - OpenAI Reranker: reranking/openai.md
      - Building Custom Rerankers: reranking/custom_reranker.md
    - Filtering: sql.md
    - Versioning & Reproducibility: notebooks/reproducibility.ipynb
    - Configuring Storage: guides/storage.md
@@ -239,4 +223,4 @@ extra_javascript:
extra:
  analytics:
    provider: google
    property: G-B7NFM40W74
Before Width: | Height: | Size: 147 KiB After Width: | Height: | Size: 104 KiB
Before Width: | Height: | Size: 98 KiB After Width: | Height: | Size: 83 KiB
Before Width: | Height: | Size: 204 KiB After Width: | Height: | Size: 131 KiB
Before Width: | Height: | Size: 112 KiB After Width: | Height: | Size: 82 KiB
Before Width: | Height: | Size: 217 KiB After Width: | Height: | Size: 113 KiB
Before Width: | Height: | Size: 256 KiB After Width: | Height: | Size: 97 KiB
Before Width: | Height: | Size: 20 KiB After Width: | Height: | Size: 6.7 KiB
docs/src/eval/bench_fine_tuned_hybrid.py (new file, 150 lines)
import json
import os

import pandas as pd
import requests
import wandb
from tqdm import tqdm

from llama_index.core import ServiceContext, VectorStoreIndex, StorageContext, SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TextNode
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding

from lancedb.rerankers import CrossEncoderReranker, ColbertReranker, CohereReranker, LinearCombinationReranker
from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk, DEFAULT_PROMPT_TMPL
from lancedb.embeddings.fine_tuner.llm import Openai

TRAIN_DATASET_FPATH = './data/train_dataset.json'
VAL_DATASET_FPATH = './data/val_dataset.json'

with open(TRAIN_DATASET_FPATH, 'r') as f:
    train_dataset = json.load(f)

with open(VAL_DATASET_FPATH, 'r') as f:
    val_dataset = json.load(f)


def train_embedding_model(epoch):
    def download_test_files(url):
        # Download the file to the current working directory.
        files = []
        filename = os.path.basename(url)
        if not os.path.exists(filename):
            print(f"Downloading {url} to {filename}")
            r = requests.get(url)
            with open(filename, 'wb') as f:
                f.write(r.content)
        files.append(filename)
        return files

    def get_dataset(url, name):
        reader = SimpleDirectoryReader(input_files=download_test_files(url))
        docs = reader.load_data()

        parser = SentenceSplitter()
        nodes = parser.get_nodes_from_documents(docs)

        if os.path.exists(name):
            ds = QADataset.load(name)
        else:
            llm = Openai()
            # Convert llama-index TextNodes to TextChunks.
            chunks = [TextChunk.from_llama_index_node(node) for node in nodes]
            ds = QADataset.from_llm(chunks, llm, num_questions_per_chunk=2)
            ds.save(name)
        return ds

    train_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf'
    ds = get_dataset(train_url, "qa_dataset_uber")

    model = get_registry().get("sentence-transformers").create(name="BAAI/bge-small-en-v1.5")
    model.finetune(trainset=ds, valset=None, path="model_airbnb", epochs=epoch, log_wandb=True, run_name="lyft_finetune")


def evaluate(
    dataset,
    embed_model,
    reranker=None,
    top_k=5,
    verbose=False,
):
    corpus = dataset['corpus']
    queries = dataset['queries']
    relevant_docs = dataset['relevant_docs']

    vector_store = LanceDBVectorStore(uri="/tmp/lancedb")
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    service_context = ServiceContext.from_defaults(embed_model=embed_model)
    nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]
    index = VectorStoreIndex(
        nodes,
        service_context=service_context,
        show_progress=True,
        storage_context=storage_context,
    )
    tbl = vector_store.connection.open_table(vector_store.table_name)
    tbl.create_fts_index("text", replace=True)

    eval_results = []
    for query_id, query in tqdm(queries.items()):
        query_vector = embed_model.get_query_embedding(query)
        try:
            if reranker is None:
                rs = tbl.search(query_vector).limit(top_k).to_pandas()
            else:
                rs = tbl.search((query_vector, query)).rerank(reranker=reranker).limit(top_k).to_pandas()
        except Exception as e:
            print(f'Error with query: {query_id} {e}')
            continue
        retrieved_ids = rs['id'].tolist()[:top_k]
        expected_id = relevant_docs[query_id][0]
        is_hit = expected_id in retrieved_ids  # assume a single relevant doc per query
        if len(eval_results) == 0:
            print(f"Query: {query}")
            print(f"Expected: {expected_id}")
            print(f"Retrieved: {retrieved_ids}")
        eval_result = {
            'is_hit': is_hit,
            'retrieved': retrieved_ids,
            'expected': expected_id,
            'query': query_id,
        }
        eval_results.append(eval_result)
    return eval_results


if __name__ == '__main__':
    train_embedding_model(4)
    # embed_model = OpenAIEmbedding()  # model="text-embedding-3-small"
    rerankers = {
        "Vector Search": None,
        "Cohere": CohereReranker(),
        "Cross Encoder": CrossEncoderReranker(),
        "Colbert": ColbertReranker(),
        "linear": LinearCombinationReranker(),
    }
    top_ks = [3]
    for top_k in top_ks:
        # for epoch in epochs:
        for name, reranker in rerankers.items():
            # embed_model = HuggingFaceEmbedding("./model_airbnb")
            embed_model = OpenAIEmbedding()
            wandb.init(project="Reranker-based", name=name)
            val_eval_results = evaluate(val_dataset, embed_model, reranker=reranker, top_k=top_k)
            df = pd.DataFrame(val_eval_results)

            hit_rate = df['is_hit'].mean()
            print(f'Hit rate: {hit_rate:.2f}')
            wandb.log({f"openai_base_hit_rate_@{top_k}": hit_rate})
            wandb.finish()
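The metric logged above is hit-rate@k: the fraction of queries whose single labeled document appears in the top-k retrieved results. A minimal illustration of the same computation used in `evaluate()`:

```python
import pandas as pd

# Three queries; the expected document was retrieved for two of them.
eval_results = [{"is_hit": True}, {"is_hit": False}, {"is_hit": True}]
hit_rate = pd.DataFrame(eval_results)["is_hit"].mean()
print(f"Hit rate: {hit_rate:.2f}")  # Hit rate: 0.67
```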
docs/src/eval/llama-finetuning-bench.py (new file, 188 lines)
import os
import time

import pandas as pd
import wandb

from llama_index.core import SimpleDirectoryReader, ServiceContext, VectorStoreIndex, StorageContext
from llama_index.core.llama_dataset import LabelledRagDataset
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.llama_pack import download_llama_pack
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from lancedb.embeddings.fine_tuner.llm import Openai


def get_paths_from_dataset(dataset: str, split=True):
    """
    Returns paths of:
    - downloaded dataset, lance train dataset, lance test dataset, finetuned model
    """
    if split:
        return f"./data/{dataset}", f"./data/{dataset}_lance_train", f"./data/{dataset}_lance_test", f"./data/tuned_{dataset}"
    return f"./data/{dataset}", f"./data/{dataset}_lance", f"./data/tuned_{dataset}"


def get_llama_dataset(dataset: str):
    """
    Returns:
    - nodes, documents, rag_dataset
    """
    if not os.path.exists(f"./data/{dataset}"):
        os.system(f"llamaindex-cli download-llamadataset {dataset} --download-dir ./data/{dataset}")
    rag_dataset = LabelledRagDataset.from_json(f"./data/{dataset}/rag_dataset.json")
    docs = SimpleDirectoryReader(input_dir=f"./data/{dataset}/source_files").load_data()

    parser = SentenceSplitter()
    nodes = parser.get_nodes_from_documents(docs)

    return nodes, docs, rag_dataset


def lance_dataset_from_llama_nodes(nodes: list, name: str, split=True):
    llm = Openai()
    chunks = [TextChunk.from_llama_index_node(node) for node in nodes]
    # 65-35 train/test split
    if not split:
        if os.path.exists(f"./data/{name}_lance"):
            ds = QADataset.load(f"./data/{name}_lance")
            return ds
        ds = QADataset.from_llm(chunks, llm)
        ds.save(f"./data/{name}_lance")
        return ds

    if os.path.exists(f"./data/{name}_lance_train") and os.path.exists(f"./data/{name}_lance_test"):
        train_ds = QADataset.load(f"./data/{name}_lance_train")
        test_ds = QADataset.load(f"./data/{name}_lance_test")
        return train_ds, test_ds
    # Split the chunks sequentially.
    train_size = int(len(chunks) * 0.65)
    train_chunks = chunks[:train_size]
    test_chunks = chunks[train_size:]
    train_ds = QADataset.from_llm(train_chunks, llm)
    test_ds = QADataset.from_llm(test_chunks, llm)
    train_ds.save(f"./data/{name}_lance_train")
    test_ds.save(f"./data/{name}_lance_test")
    return train_ds, test_ds


def finetune(trainset: QADataset, model: str, epochs: int, path: str, valset: QADataset = None):
    valset = valset if valset is not None else trainset
    print(f"Finetuning {model} for {epochs} epochs")
    print(f"trainset query instances: {len(trainset.queries)}")
    print(f"valset query instances: {len(valset.queries)}")

    model = get_registry().get("sentence-transformers").create(name=model)
    base_result = model.evaluate(valset, path="./data/eva/")
    base_hit_rate = pd.DataFrame(base_result)["is_hit"].mean()

    model.finetune(trainset=trainset, valset=valset, path=path, epochs=epochs)
    tuned = get_registry().get("sentence-transformers").create(name=path)
    tuned_result = tuned.evaluate(valset, path=f"./data/eva/{str(time.time())}")
    tuned_hit_rate = pd.DataFrame(tuned_result)["is_hit"].mean()

    return base_hit_rate, tuned_hit_rate


def eval_rag(dataset: str, model: str):
    # Requires: pip install llama-index-vector-stores-lancedb
    # Requires: pip install llama-index-embeddings-huggingface
    nodes, docs, rag_dataset = get_llama_dataset(dataset)

    embed_model = HuggingFaceEmbedding(model)
    vector_store = LanceDBVectorStore(uri="/tmp/lancedb")
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    service_context = ServiceContext.from_defaults(embed_model=embed_model)
    index = VectorStoreIndex(
        nodes,
        service_context=service_context,
        show_progress=True,
        storage_context=storage_context,
    )

    # Build a basic RAG system.
    index = VectorStoreIndex.from_documents(documents=docs)
    query_engine = index.as_query_engine()

    # Evaluate using the RagEvaluatorPack.
    RagEvaluatorPack = download_llama_pack(
        "RagEvaluatorPack", "./rag_evaluator_pack"
    )
    rag_evaluator_pack = RagEvaluatorPack(
        rag_dataset=rag_dataset, query_engine=query_engine
    )

    metrics = rag_evaluator_pack.run(
        batch_size=20,  # batches the number of openai api calls to make
        sleep_time_in_seconds=1,  # seconds to sleep before making an api call
    )

    return metrics


def main(dataset, model, epochs, project: str = "lancedb_finetune"):
    nodes, _, _ = get_llama_dataset(dataset)
    trainset, valset = lance_dataset_from_llama_nodes(nodes, dataset)
    data_path, lance_train_path, lance_test_path, tuned_path = get_paths_from_dataset(dataset, split=True)

    base_hit_rate, tuned_hit_rate = finetune(trainset, model, epochs, tuned_path, valset)
    """
    wandb.init(project="lancedb_finetune", name=f"{dataset}_{model}_base")
    wandb.log({
        "hit_rate": base_hit_rate,
    })
    wandb.finish()

    wandb.init(project="lancedb_finetune", name=f"{dataset}_{model}_{epochs}")
    wandb.log({
        "hit_rate": tuned_hit_rate,
    })
    wandb.finish()
    """

    # Base model metrics
    metrics = eval_rag(dataset, model)

    # Tuned model metrics
    metrics_tuned = eval_rag(dataset, tuned_path)

    wandb.init(project="lancedb_rageval", name=f"{dataset}_{model}_base")
    wandb.log(metrics)
    wandb.finish()

    wandb.init(project="lancedb_rageval", name=f"{dataset}_{model}_{epochs}")
    wandb.log(metrics_tuned)
    wandb.finish()


def benchmark_all():
    datasets = [  # "Uber10KDataset2021", "MiniTruthfulQADataset", "MiniSquadV2Dataset", "MiniEsgBenchDataset", "MiniCovidQaDataset", "Llama2PaperDataset", "HistoryOfAlexnetDataset",
        "PatronusAIFinanceBenchDataset"]
    models = ["BAAI/bge-small-en-v1.5"]

    for model in models:
        for dataset in datasets:
            main(dataset, model, 5)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="BraintrustCodaHelpDeskDataset")
    parser.add_argument("--model", type=str, default="BAAI/bge-small-en-v1.5")
    parser.add_argument("--epochs", type=int, default=4)
    parser.add_argument("--project", type=str, default="lancedb_finetune")
    parser.add_argument("--benchmark-all", action="store_true")
    args = parser.parse_args()

    if args.benchmark_all:
        benchmark_all()
    else:
        main(args.dataset, args.model, args.epochs, args.project)
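To run the benchmark above for a single dataset, you can also call `main()` directly (for instance, from a notebook after importing the script) with the script's own argparse defaults:

```python
# Equivalent to running the script with no CLI flags.
main("BraintrustCodaHelpDeskDataset", "BAAI/bge-small-en-v1.5", 4)
```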
docs/src/eval/test_fine_tune_from_llm.py (new file, 71 lines)
import os

import pandas as pd
import requests

from lancedb.embeddings.fine_tuner.llm import Openai
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from lancedb.embeddings import get_registry


test_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf'
train_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf'


def download_test_files(url):
    # Download the file to the current working directory.
    files = []
    filename = os.path.basename(url)
    if not os.path.exists(filename):
        print(f"Downloading {url} to {filename}")
        r = requests.get(url)
        with open(filename, 'wb') as f:
            f.write(r.content)
    files.append(filename)
    return files


def get_dataset(url, name):
    reader = SimpleDirectoryReader(input_files=download_test_files(url))
    docs = reader.load_data()

    parser = SentenceSplitter()
    nodes = parser.get_nodes_from_documents(docs)

    if os.path.exists(name):
        ds = QADataset.load(name)
    else:
        llm = Openai()
        # Convert llama-index TextNodes to TextChunks.
        chunks = [TextChunk.from_llama_index_node(node) for node in nodes]
        ds = QADataset.from_llm(chunks, llm)
        ds.save(name)
    return ds


trainset = get_dataset(test_url, "qa_dataset_1")
valset = get_dataset(train_url, "valset")

model = get_registry().get("sentence-transformers").create()
model.finetune(trainset=trainset, valset=valset, path="model_finetuned_1", epochs=4)

base = get_registry().get("sentence-transformers").create()
tuned = get_registry().get("sentence-transformers").create(name="./model_finetuned_1")
openai = get_registry().get("openai").create(name="text-embedding-3-large")

rs1 = base.evaluate(valset, path="val_res")
rs2 = tuned.evaluate(valset, path="val_res")
rs3 = openai.evaluate(valset)

print("openai-embedding-v3 hit-rate - ", pd.DataFrame(rs3)["is_hit"].mean())
print("fine-tuned hit-rate - ", pd.DataFrame(rs2)["is_hit"].mean())
print("Base model hit-rate - ", pd.DataFrame(rs1)["is_hit"].mean())
docs/src/eval/test_fine_tune_from_responses.py (new file, 119 lines)
import os
import uuid

import lancedb
import pandas as pd
import requests

from tqdm import tqdm
from lancedb.embeddings.fine_tuner.llm import Openai
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk, DEFAULT_PROMPT_TMPL
from lancedb.pydantic import LanceModel, Vector
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from lancedb.embeddings import get_registry


test_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf'
train_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf'


def download_test_files(url):
    # Download the file to the current working directory.
    files = []
    filename = os.path.basename(url)
    if not os.path.exists(filename):
        print(f"Downloading {url} to {filename}")
        r = requests.get(url)
        with open(filename, 'wb') as f:
            f.write(r.content)
    files.append(filename)
    return files


def get_node(url):
    reader = SimpleDirectoryReader(input_files=download_test_files(url))
    docs = reader.load_data()

    parser = SentenceSplitter()
    nodes = parser.get_nodes_from_documents(docs)

    return nodes


def get_dataset(url, name):
    reader = SimpleDirectoryReader(input_files=download_test_files(url))
    docs = reader.load_data()

    parser = SentenceSplitter()
    nodes = parser.get_nodes_from_documents(docs)

    if os.path.exists(name):
        ds = QADataset.load(name)
    else:
        llm = Openai()
        # Convert llama-index TextNodes to TextChunks.
        chunks = [TextChunk.from_llama_index_node(node) for node in nodes]
        ds = QADataset.from_llm(chunks, llm)
        ds.save(name)
    return ds


nodes = get_node(train_url)

db = lancedb.connect("~/lancedb/fine-tuning")
model = get_registry().get("openai").create()


class Schema(LanceModel):
    id: str
    text: str = model.SourceField()
    vector: Vector(model.ndims()) = model.VectorField()


retriever = db.create_table("fine-tuning", schema=Schema, mode="overwrite")
pylist = [{"id": str(node.node_id), "text": node.text} for node in nodes]
retriever.add(pylist)

ds_name = "response_data"
if os.path.exists(ds_name):
    ds = QADataset.load(ds_name)
else:
    # Generate questions
    llm = Openai()
    text_chunks = [TextChunk.from_llama_index_node(node) for node in nodes]

    queries = {}
    relevant_docs = {}
    for chunk in tqdm(text_chunks):
        text = chunk.text
        questions = llm.get_questions(DEFAULT_PROMPT_TMPL.format(context_str=text, num_questions_per_chunk=2))

        for question in questions:
            question_id = str(uuid.uuid4())
            queries[question_id] = question
            # Label the top retrieved chunk as the relevant document for this question.
            relevant_docs[question_id] = [retriever.search(question).to_pandas()["id"].tolist()[0]]
    ds = QADataset.from_responses(text_chunks, queries, relevant_docs)
    ds.save(ds_name)


# Fine-tune model
valset = get_dataset(train_url, "valset")

model = get_registry().get("sentence-transformers").create()
res_base = model.evaluate(valset)

model.finetune(trainset=ds, path="model_finetuned", epochs=4, log_wandb=True)
tuned = get_registry().get("sentence-transformers").create(name="./model_finetuned")
res_tuned = tuned.evaluate(valset)

openai_model = get_registry().get("openai").create()
# res_openai = openai_model.evaluate(valset)

# print(f"openai model results: {pd.DataFrame(res_openai)['is_hit'].mean()}")
print(f"base model results: {pd.DataFrame(res_base)['is_hit'].mean()}")
print(f"tuned model results: {pd.DataFrame(res_tuned)['is_hit'].mean()}")
@@ -5,9 +5,6 @@ LanceDB supports both semantic and keyword-based search (also termed full-text s

## Hybrid search in LanceDB
You can perform hybrid search in LanceDB by combining the results of semantic and full-text search via a reranking algorithm of your choice. LanceDB provides multiple rerankers out of the box, but you can always write a custom reranker if your use case needs more sophisticated logic.

!!! note
    You need to create a full-text search index before performing a hybrid search. You can create a full-text search index using the `create_fts_index()` method of the table object.

```python
import os
...
```
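The code block above is truncated in this diff view. As a minimal sketch of the flow it introduces (assuming a populated LanceDB table named `table` with a `text` column, as in the examples later in this page):

```python
# Hybrid search needs an FTS index on the text column first.
table.create_fts_index("text", replace=True)

# Then combine vector and keyword search in a single query.
results = table.search("rebel", query_type="hybrid").to_pandas()
```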
@@ -58,7 +55,188 @@ By default, LanceDB uses `LinearCombinationReranker(weight=0.7)` to combine and

## Available Rerankers
LanceDB provides a number of re-rankers out of the box. You can use any of these re-rankers by passing them to the `rerank()` method. Here's a list of available re-rankers:
## Custom Rerankers
You can also create custom rerankers by extending the base `Reranker` class. A custom reranker should implement the `rerank_hybrid()` method, which takes search results and returns a reranked list of search results. Visit the [custom rerankers](../reranking/custom_reranker.md) page for more information on creating custom rerankers.
### Linear Combination Reranker
This is the default re-ranker used by LanceDB. It combines the results of semantic and full-text search using a linear combination of the scores. The weight for the linear combination can be specified; it defaults to 0.7, i.e., 70% weight for semantic search and 30% weight for full-text search.

```python
from lancedb.rerankers import LinearCombinationReranker

reranker = LinearCombinationReranker(weight=0.3)  # Use 0.3 as the weight for vector search

results = table.search("rebel", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```

### Arguments
----------------
* `weight`: `float`, default `0.7`:
    The weight to use for the semantic search score. The weight for the full-text search score is `1 - weight`.
* `fill`: `float`, default `1.0`:
    The score to give to results that appear in only one of the two result sets. This is treated as a penalty, so a higher value means a lower score.
    TODO: We should just hardcode this -- it's pretty confusing as we invert scores to calculate the final score.
* `return_score` : str, default `"relevance"`
    Options are "relevance" or "all".
    The type of score to return. If "relevance", will return only the `_relevance_score`. If "all", will return all scores from the vector and FTS search along with the relevance score.
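For inspecting how the combined score was produced, `return_score="all"` keeps the raw scores (a minimal sketch; assumes the same populated `table` with an FTS index as the example above):

```python
from lancedb.rerankers import LinearCombinationReranker

# Keep the raw vector (_distance) and FTS (score) columns alongside _relevance_score.
reranker = LinearCombinationReranker(weight=0.7, return_score="all")
results = table.search("rebel", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```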
### Cohere Reranker
This re-ranker uses the [Cohere](https://cohere.ai/) API to combine the results of semantic and full-text search. You can use this re-ranker by passing `CohereReranker()` to the `rerank()` method. Note that you'll need to set the `COHERE_API_KEY` environment variable to use this re-ranker.

```python
from lancedb.rerankers import CohereReranker

reranker = CohereReranker()

results = table.search("vampire weekend", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```

### Arguments
----------------
* `model_name` : str, default `"rerank-english-v2.0"`
    The name of the Cohere rerank model to use. Available Cohere models are:
    - rerank-english-v2.0
    - rerank-multilingual-v2.0
* `column` : str, default `"text"`
    The name of the column to use as input to the reranker model.
* `top_n` : int, default `None`
    The number of results to return. If None, will return all results.

!!! Note
    Only returns `_relevance_score`. Does not support `return_score = "all"`.
### Cross Encoder Reranker
This reranker uses the [Sentence Transformers](https://www.sbert.net/) library to combine the results of semantic and full-text search. You can use it by passing `CrossEncoderReranker()` to the `rerank()` method.

```python
from lancedb.rerankers import CrossEncoderReranker

reranker = CrossEncoderReranker()

results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```

### Arguments
----------------
* `model` : str, default `"cross-encoder/ms-marco-TinyBERT-L-6"`
    The name of the cross encoder model to use. Available cross encoder models can be found [here](https://www.sbert.net/docs/pretrained_cross-encoders.html)
* `column` : str, default `"text"`
    The name of the column to use as input to the cross encoder model.
* `device` : str, default `None`
    The device to use for the cross encoder model. If None, will use "cuda" if available, otherwise "cpu".

!!! Note
    Only returns `_relevance_score`. Does not support `return_score = "all"`.
### ColBERT Reranker
This reranker uses the ColBERT model to combine the results of semantic and full-text search. You can use it by passing `ColbertReranker()` to the `rerank()` method.

The ColBERT reranker calculates the relevance of the given docs against the query and doesn't take the existing FTS and vector search scores into account, so it currently only supports `return_score="relevance"`. By default, it looks for a `text` column to rerank the results, but you can specify the column name to use as input to the model as described below.

```python
from lancedb.rerankers import ColbertReranker

reranker = ColbertReranker()

results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```

### Arguments
----------------
* `model_name` : `str`, default `"colbert-ir/colbertv2.0"`
    The name of the reranker model to use.
* `column` : `str`, default `"text"`
    The name of the column to use as input to the reranker model.
* `return_score` : `str`, default `"relevance"`
    Options are `"relevance"` or `"all"`. Only `"relevance"` is supported for now.

!!! Note
    Only returns `_relevance_score`. Does not support `return_score = "all"`.
### OpenAI Reranker
This reranker uses the OpenAI API to combine the results of semantic and full-text search. You can use it by passing `OpenaiReranker()` to the `rerank()` method.

!!! Note
    This prompts a chat model to rerank results; it is not a dedicated reranker model and should be treated as experimental.

!!! Tip
    - You might run into the token limit, so set the search `limit` based on your token budget.
    - It is recommended to use gpt-4-turbo-preview, the default model; older models might lead to undesired behaviour.

```python
from lancedb.rerankers import OpenaiReranker

reranker = OpenaiReranker()

results = table.search("harmony hall", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```

### Arguments
----------------
* `model_name` : `str`, default `"gpt-4-turbo-preview"`
    The name of the chat model to use.
* `column` : `str`, default `"text"`
    The name of the column to use as input to the model.
* `return_score` : `str`, default `"relevance"`
    Options are "relevance" or "all". Only "relevance" is supported for now.
* `api_key` : `str`, default `None`
    The API key to use. If None, will use the OPENAI_API_KEY environment variable.
## Building Custom Rerankers
You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores.

The `Reranker` base interface comes with a `merge_results()` method that can be used to combine the results of semantic and full-text search. This is a vanilla merging algorithm that simply concatenates the results and removes duplicates without taking the scores into consideration; it only keeps the first copy of each row encountered. This works well for cases that don't need the semantic and full-text scores to combine the results. If you want to use the scores or want to support `return_score="all"`, you'll need to implement your own merging algorithm.
```python
from lancedb.rerankers import Reranker
import pyarrow as pa


class MyReranker(Reranker):
    # param1/param2 stand in for whatever configuration your reranker needs.
    def __init__(self, param1, param2, return_score="relevance"):
        super().__init__(return_score)
        self.param1 = param1
        self.param2 = param2

    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table):
        # Use the built-in merging function
        combined_result = self.merge_results(vector_results, fts_results)

        # Do something with the combined results
        # ...

        # Return the combined results
        return combined_result
```
### Example of a Custom Reranker
For the sake of simplicity, let's build a custom reranker that enhances the Cohere reranker by accepting a list of filters, and accepts the other `CohereReranker` params as kwargs.
```python
from typing import List, Union

import pandas as pd
import pyarrow as pa

from lancedb.rerankers import CohereReranker


class ModifiedCohereReranker(CohereReranker):
    def __init__(self, filters: Union[str, List[str]], **kwargs):
        super().__init__(**kwargs)
        filters = filters if isinstance(filters, list) else [filters]
        self.filters = filters

    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table) -> pa.Table:
        combined_result = super().rerank_hybrid(query, vector_results, fts_results)
        df = combined_result.to_pandas()
        # Drop rows whose text matches any of the filters.
        for filter in self.filters:
            df = df.query("not text.str.contains(@filter)")

        return pa.Table.from_pandas(df)
```

!!! tip
    The `vector_results` and `fts_results` are pyarrow tables. You can convert them to pandas dataframes using the `to_pandas()` method and perform any operations you want. Once you are done, you can convert the dataframe back to a pyarrow table using `pa.Table.from_pandas()` and return it.
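Usage is the same as for any built-in reranker (a sketch; assumes `table` is populated and `COHERE_API_KEY` is set, as in the earlier examples):

```python
# Drop any hit whose text mentions "world", then rerank the rest with Cohere.
reranker = ModifiedCohereReranker(filters=["world"])
results = table.search("hello", query_type="hybrid").rerank(reranker=reranker).to_pandas()
```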
@@ -1,75 +0,0 @@
# Cohere Reranker

This re-ranker uses the [Cohere](https://cohere.ai/) API to rerank the search results. You can use this re-ranker by passing `CohereReranker()` to the `rerank()` method. Note that you'll either need to set the `COHERE_API_KEY` environment variable or pass the `api_key` argument to use this re-ranker.

!!! note
    Supported Query Types: Hybrid, Vector, FTS

```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker

embedder = get_registry().get("sentence-transformers").create()
db = lancedb.connect("~/.lancedb")

class Schema(LanceModel):
    text: str = embedder.SourceField()
    vector: Vector(embedder.ndims()) = embedder.VectorField()

data = [
    {"text": "hello world"},
    {"text": "goodbye world"}
]
tbl = db.create_table("test", schema=Schema, mode="overwrite")
tbl.add(data)
reranker = CohereReranker(api_key="key")

# Run vector search with a reranker
result = tbl.search("hello").rerank(reranker=reranker).to_list()

# Run FTS search with a reranker
result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list()

# Run hybrid search with a reranker
tbl.create_fts_index("text", replace=True)
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()
```

Accepted Arguments
----------------
| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| `model_name` | `str` | `"rerank-english-v2.0"` | The name of the reranker model to use. Available Cohere models are: rerank-english-v2.0, rerank-multilingual-v2.0 |
| `column` | `str` | `"text"` | The name of the column to use as input to the reranker model. |
| `top_n` | `int` | `None` | The number of results to return. If None, will return all results. |
| `api_key` | `str` | `None` | The API key for the Cohere API. If not provided, the `COHERE_API_KEY` environment variable is used. |
| `return_score` | `str` | `"relevance"` | Options are "relevance" or "all". The type of score to return. If "relevance", will return only the `_relevance_score`. If "all", will return the relevance score along with the vector and/or FTS scores, depending on the query type. |


## Supported Scores for each query type
You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type:

### Hybrid Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Results contain only the `_relevance_score` column |
| `all` | ❌ Not Supported | Would return the vector (`_distance`) and FTS (`score`) scores along with the hybrid search score (`_relevance_score`) |

### Vector Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Results contain only the `_relevance_score` column |
| `all` | ✅ Supported | Results contain the vector score (`_distance`) along with the relevance score (`_relevance_score`) |

### FTS Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Results contain only the `_relevance_score` column |
| `all` | ✅ Supported | Results contain the FTS score (`score`) along with the relevance score (`_relevance_score`) |
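For example, with vector search, where `"all"` is supported per the table above, you can keep the raw distance next to the relevance score (a sketch using the same `tbl` as the example above):

```python
# Return the vector _distance column in addition to _relevance_score.
reranker = CohereReranker(api_key="key", return_score="all")
result = tbl.search("hello").rerank(reranker=reranker).to_pandas()
```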
@@ -1,71 +0,0 @@
# ColBERT Reranker

This re-ranker uses a ColBERT model to rerank the search results. You can use this re-ranker by passing `ColbertReranker()` to the `rerank()` method.
!!! note
    Supported Query Types: Hybrid, Vector, FTS

```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import ColbertReranker

embedder = get_registry().get("sentence-transformers").create()
db = lancedb.connect("~/.lancedb")

class Schema(LanceModel):
    text: str = embedder.SourceField()
    vector: Vector(embedder.ndims()) = embedder.VectorField()

data = [
    {"text": "hello world"},
    {"text": "goodbye world"}
]
tbl = db.create_table("test", schema=Schema, mode="overwrite")
tbl.add(data)
reranker = ColbertReranker()

# Run vector search with a reranker
result = tbl.search("hello").rerank(reranker=reranker).to_list()

# Run FTS search with a reranker
result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list()

# Run hybrid search with a reranker
tbl.create_fts_index("text", replace=True)
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()
```

Accepted Arguments
----------------
| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| `model_name` | `str` | `"colbert-ir/colbertv2.0"` | The name of the reranker model to use. |
| `column` | `str` | `"text"` | The name of the column to use as input to the reranker model. |
| `device` | `str` | `None` | The device to use for the reranker model. If None, will use "cuda" if available, otherwise "cpu". |
| `return_score` | `str` | `"relevance"` | Options are "relevance" or "all". The type of score to return. If "relevance", will return only the `_relevance_score`. If "all", will return the relevance score along with the vector and/or FTS scores, depending on the query type. |


## Supported Scores for each query type
You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type:

### Hybrid Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Results contain only the `_relevance_score` column |
| `all` | ❌ Not Supported | Would return the vector (`_distance`) and FTS (`score`) scores along with the hybrid search score (`_relevance_score`) |

### Vector Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Results contain only the `_relevance_score` column |
| `all` | ✅ Supported | Results contain the vector score (`_distance`) along with the relevance score (`_relevance_score`) |

### FTS Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Results contain only the `_relevance_score` column |
| `all` | ✅ Supported | Results contain the FTS score (`score`) along with the relevance score (`_relevance_score`) |
@@ -1,70 +0,0 @@
# Cross Encoder Reranker

This re-ranker uses Cross Encoder models from sentence-transformers to rerank the search results. You can use this re-ranker by passing `CrossEncoderReranker()` to the `rerank()` method.
!!! note
    Supported Query Types: Hybrid, Vector, FTS

```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CrossEncoderReranker

embedder = get_registry().get("sentence-transformers").create()
db = lancedb.connect("~/.lancedb")

class Schema(LanceModel):
    text: str = embedder.SourceField()
    vector: Vector(embedder.ndims()) = embedder.VectorField()

data = [
    {"text": "hello world"},
    {"text": "goodbye world"}
]
tbl = db.create_table("test", schema=Schema, mode="overwrite")
tbl.add(data)
reranker = CrossEncoderReranker()

# Run vector search with a reranker
result = tbl.search("hello").rerank(reranker=reranker).to_list()

# Run FTS search with a reranker
result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list()

# Run hybrid search with a reranker
tbl.create_fts_index("text", replace=True)
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()
```

Accepted Arguments
----------------
| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| `model_name` | `str` | `"cross-encoder/ms-marco-TinyBERT-L-6"` | The name of the reranker model to use. |
| `column` | `str` | `"text"` | The name of the column to use as input to the cross encoder model. |
| `device` | `str` | `None` | The device to use for the cross encoder model. If None, will use "cuda" if available, otherwise "cpu". |
| `return_score` | `str` | `"relevance"` | Options are "relevance" or "all". The type of score to return. If "relevance", will return only the `_relevance_score`. If "all", will return the relevance score along with the vector and/or FTS scores, depending on the query type. |

## Supported Scores for each query type
You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type:

### Hybrid Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Results contain only the `_relevance_score` column |
| `all` | ❌ Not Supported | Would return the vector (`_distance`) and FTS (`score`) scores along with the hybrid search score (`_relevance_score`) |

### Vector Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Results contain only the `_relevance_score` column |
| `all` | ✅ Supported | Results contain the vector score (`_distance`) along with the relevance score (`_relevance_score`) |

### FTS Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Results contain only the `_relevance_score` column |
| `all` | ✅ Supported | Results contain the FTS score (`score`) along with the relevance score (`_relevance_score`) |
@@ -1,89 +0,0 @@
## Building Custom Rerankers
You can build your own custom reranker by subclassing the `Reranker` class and implementing the `rerank_hybrid()` method. Optionally, you can also implement the `rerank_vector()` and `rerank_fts()` methods if you want to support reranking for vector and FTS search separately.
Here's an example of a custom reranker that combines the results of semantic and full-text search using a linear combination of the scores.

The `Reranker` base interface comes with a `merge_results()` method that can be used to combine the results of semantic and full-text search. This is a vanilla merging algorithm that simply concatenates the results and removes duplicates without taking the scores into consideration; it only keeps the first copy of each row encountered. This works well for cases that don't need the semantic and full-text scores to combine the results. If you want to use the scores or want to support `return_score="all"`, you'll need to implement your own merging algorithm.

```python
from lancedb.rerankers import Reranker
import pyarrow as pa


class MyReranker(Reranker):
    # param1/param2 stand in for whatever configuration your reranker needs.
    def __init__(self, param1, param2, return_score="relevance"):
        super().__init__(return_score)
        self.param1 = param1
        self.param2 = param2

    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table):
        # Use the built-in merging function
        combined_result = self.merge_results(vector_results, fts_results)

        # Do something with the combined results
        # ...

        # Return the combined results
        return combined_result

    def rerank_vector(self, query: str, vector_results: pa.Table):
        # Do something with the vector results
        # ...

        # Return the vector results
        return vector_results

    def rerank_fts(self, query: str, fts_results: pa.Table):
        # Do something with the FTS results
        # ...

        # Return the FTS results
        return fts_results
```

### Example of a Custom Reranker
For the sake of simplicity, let's build a custom reranker that enhances the Cohere reranker by accepting a list of filters, and accepts the other `CohereReranker` params as kwargs.

```python
from typing import List, Union

import pandas as pd
import pyarrow as pa

from lancedb.rerankers import CohereReranker


class ModifiedCohereReranker(CohereReranker):
    def __init__(self, filters: Union[str, List[str]], **kwargs):
        super().__init__(**kwargs)
        filters = filters if isinstance(filters, list) else [filters]
        self.filters = filters

    def rerank_hybrid(self, query: str, vector_results: pa.Table, fts_results: pa.Table) -> pa.Table:
        combined_result = super().rerank_hybrid(query, vector_results, fts_results)
        df = combined_result.to_pandas()
        for filter in self.filters:
            df = df.query("not text.str.contains(@filter)")

        return pa.Table.from_pandas(df)

    def rerank_vector(self, query: str, vector_results: pa.Table) -> pa.Table:
        vector_results = super().rerank_vector(query, vector_results)
        df = vector_results.to_pandas()
        for filter in self.filters:
            df = df.query("not text.str.contains(@filter)")

        return pa.Table.from_pandas(df)

    def rerank_fts(self, query: str, fts_results: pa.Table) -> pa.Table:
        fts_results = super().rerank_fts(query, fts_results)
        df = fts_results.to_pandas()
        for filter in self.filters:
            df = df.query("not text.str.contains(@filter)")

        return pa.Table.from_pandas(df)
```

!!! tip
    The `vector_results` and `fts_results` are pyarrow tables. Learn more about pyarrow tables [here](https://arrow.apache.org/docs/python). They can be converted to other data types like pandas dataframes, pydicts, pylists, etc.

    For example, you can convert them to pandas dataframes using the `to_pandas()` method and perform any operations you want. Once you are done, you can convert the dataframe back to a pyarrow table using `pa.Table.from_pandas()` and return it.
@@ -1,61 +0,0 @@
Reranking is the process of reordering a list of items based on some criteria. In the context of search, reranking reorders the results returned by a search engine. This can be useful when the initial ranking of the search results is not satisfactory or when the user has provided additional information that can be used to improve the ranking.

LanceDB comes with some built-in rerankers. Some of the rerankers that are available in LanceDB are:

| Reranker | Description | Supported Query Types |
| --- | --- | --- |
| `LinearCombinationReranker` | Reranks search results based on a linear combination of FTS and vector search scores | Hybrid |
| `CohereReranker` | Uses the Cohere Rerank API to rerank results | Vector, FTS, Hybrid |
| `CrossEncoderReranker` | Uses a cross-encoder model to rerank search results | Vector, FTS, Hybrid |
| `ColbertReranker` | Uses a ColBERT model to rerank search results | Vector, FTS, Hybrid |
| `OpenaiReranker` (Experimental) | Uses OpenAI's chat model to rerank search results | Vector, FTS, Hybrid |


## Using a Reranker
Using a reranker is optional for vector and FTS search, but required for hybrid search. To use a reranker, create an instance of the reranker and pass it to the `rerank` method of the query builder.

```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker

embedder = get_registry().get("sentence-transformers").create()
db = lancedb.connect("~/.lancedb")

class Schema(LanceModel):
    text: str = embedder.SourceField()
    vector: Vector(embedder.ndims()) = embedder.VectorField()

data = [
    {"text": "hello world"},
    {"text": "goodbye world"}
]
tbl = db.create_table("test", schema=Schema, mode="overwrite")
tbl.add(data)
reranker = CohereReranker(api_key="your_api_key")

# Run vector search with a reranker
result = tbl.search("hello").rerank(reranker=reranker).to_list()

# Run FTS search with a reranker
result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list()

# Run hybrid search with a reranker
tbl.create_fts_index("text")
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()
```

## Available Rerankers
LanceDB comes with some built-in rerankers. Here are some of the rerankers that are available in LanceDB:

- [Cohere Reranker](./cohere.md)
- [Cross Encoder Reranker](./cross_encoder.md)
- [ColBERT Reranker](./colbert.md)
- [OpenAI Reranker](./openai.md)
- [Linear Combination Reranker](./linear_combination.md)

## Creating Custom Rerankers

LanceDB also allows you to create custom rerankers by extending the base `Reranker` class. A custom reranker should implement the `rerank_hybrid()` method, which takes search results and returns a reranked list of search results. This is covered in more detail in the [Creating Custom Rerankers](./custom_reranker.md) section.
@@ -1,52 +0,0 @@
# Linear Combination Reranker

This is the default re-ranker used by LanceDB hybrid search. It combines the results of semantic and full-text search using a linear combination of the scores. The weight for the linear combination can be specified; it defaults to 0.7, i.e., 70% weight for semantic search and 30% weight for full-text search.

!!! note
    Supported Query Types: Hybrid

```python
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import LinearCombinationReranker

embedder = get_registry().get("sentence-transformers").create()
db = lancedb.connect("~/.lancedb")

class Schema(LanceModel):
    text: str = embedder.SourceField()
    vector: Vector(embedder.ndims()) = embedder.VectorField()

data = [
    {"text": "hello world"},
    {"text": "goodbye world"}
]
tbl = db.create_table("test", schema=Schema, mode="overwrite")
tbl.add(data)
reranker = LinearCombinationReranker()

# Run hybrid search with a reranker
tbl.create_fts_index("text", replace=True)
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()
```

Accepted Arguments
----------------
| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| `weight` | `float` | `0.7` | The weight to use for the semantic search score. The weight for the full-text search score is `1 - weight`. |
| `return_score` | `str` | `"relevance"` | Options are "relevance" or "all". The type of score to return. If "relevance", will return only the `_relevance_score`. If "all", will return all scores from the vector and FTS search along with the relevance score. |


## Supported Scores for each query type
You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type:

### Hybrid Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Results contain only the `_relevance_score` column |
| `all` | ✅ Supported | Results contain the vector (`_distance`) and FTS (`score`) scores along with the hybrid search score (`_relevance_score`) |
@@ -1,73 +0,0 @@
|
||||
# OpenAI Reranker (Experimental)
|
||||
|
||||
This re-ranker uses OpenAI chat model to rerank the search results. You can use this re-ranker by passing `OpenAI()` to the `rerank()` method.
|
||||
!!! note
|
||||
Supported Query Types: Hybrid, Vector, FTS
|
||||
|
||||
!!! warning
|
||||
This re-ranker is experimental. OpenAI doesn't have a dedicated reranking model, so we are using the chat model for reranking.
|
||||
|
||||
```python
|
||||
import numpy
|
||||
import lancedb
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.rerankers import OpenaiReranker
|
||||
|
||||
embedder = get_registry().get("sentence-transformers").create()
|
||||
db = lancedb.connect("~/.lancedb")
|
||||
|
||||
class Schema(LanceModel):
|
||||
text: str = embedder.SourceField()
|
||||
vector: Vector(embedder.ndims()) = embedder.VectorField()
|
||||
|
||||
data = [
|
||||
{"text": "hello world"},
|
||||
{"text": "goodbye world"}
|
||||
]
|
||||
tbl = db.create_table("test", schema=Schema, mode="overwrite")
|
||||
tbl.add(data)
|
||||
reranker = OpenaiReranker()
|
||||
|
||||
# Run vector search with a reranker
|
||||
result = tbl.search("hello").rerank(reranker=reranker).to_list()
|
||||
|
||||
# Run FTS search with a reranker
|
||||
result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list()
|
||||
|
||||
# Run hybrid search with a reranker
|
||||
tbl.create_fts_index("text", replace=True)
|
||||
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()
|
||||
|
||||
```
|
||||
|
||||
Accepted Arguments
----------------

| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| `model_name` | `str` | `"gpt-4-turbo-preview"` | The name of the reranker model to use. |
| `column` | `str` | `"text"` | The name of the column to pass to the chat model as input. |
| `return_score` | `str` | `"relevance"` | Options are `"relevance"` or `"all"`. The type of score to return. If `"relevance"`, will return only the `_relevance_score` column. If `"all"`, will return the relevance score along with the vector and/or FTS scores, depending on the query type. |
| `api_key` | `str` | `None` | The API key to use. If `None`, will use the `OPENAI_API_KEY` environment variable. |
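
A minimal sketch of overriding these arguments at construction time (the values here are illustrative, not recommendations):

```python
# Hypothetical configuration: cheaper model, explicit column, key from the environment
reranker = OpenaiReranker(
    model_name="gpt-3.5-turbo",
    column="text",
    api_key=None,  # falls back to the OPENAI_API_KEY environment variable
)
```
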

## Supported Scores for each query type
You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type:

### Hybrid Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ❌ Not Supported | Would return the vector (`_distance`) and FTS (`score`) scores along with the relevance score (`_relevance_score`) |

### Vector Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ✅ Supported | Returns the vector score (`_distance`) along with the relevance score (`_relevance_score`) |

### FTS Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only the `_relevance_score` column |
| `all` | ✅ Supported | Returns the FTS score (`score`) along with the relevance score (`_relevance_score`) |
@@ -15,7 +15,6 @@ excluded_globs = [
    "../src/ann_indexes.md",
    "../src/basic.md",
    "../src/hybrid_search/hybrid_search.md",
    "../src/reranking/*.md",
]

python_prefix = "py"
@@ -10,13 +10,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC, abstractmethod
from typing import List, Union

import numpy as np
import pyarrow as pa
from pydantic import BaseModel, Field, PrivateAttr
from tqdm import tqdm

import lancedb

from .fine_tuner import QADataset
from .utils import TEXT, retry_with_exponential_backoff

@@ -126,6 +131,22 @@ class EmbeddingFunction(BaseModel, ABC):
    def __hash__(self) -> int:
        return hash(frozenset(vars(self).items()))

    def finetune(self, dataset: QADataset, *args, **kwargs):
        """
        Finetune the embedding function on a dataset
        """
        raise NotImplementedError(
            "Finetuning is not supported for this embedding function"
        )

    def evaluate(self, dataset: QADataset, top_k=5, path=None, *args, **kwargs):
        """
        Evaluate the embedding function on a dataset
        """
        raise NotImplementedError(
            "Evaluation is not supported for this embedding function"
        )


class EmbeddingFunctionConfig(BaseModel):
    """
@@ -159,3 +180,52 @@ class TextEmbeddingFunction(EmbeddingFunction):
        Generate the embeddings for the given texts
        """
        pass

    def evaluate(self, dataset: QADataset, top_k=5, path=None, *args, **kwargs):
        """
        Evaluate the embedding function on a dataset. This calculates the hit-rate
        for the top-k retrieved documents for each query in the dataset, assuming
        the first relevant document is the expected one.
        Pro - Works for any embedding model.
        Con - Returns a very simple metric.

        Parameters
        ----------
        dataset: QADataset
            The dataset to evaluate on
        top_k: int, default=5
            The number of documents to retrieve per query
        path: str, optional
            Where to create the temporary evaluation table. Defaults to ./eval

        Returns
        -------
        list of dict
            One result per query with hit/retrieved/expected fields
        """
        corpus = dataset.corpus
        queries = dataset.queries
        relevant_docs = dataset.relevant_docs
        path = path or os.path.join(os.getcwd(), "eval")
        db = lancedb.connect(path)

        class Schema(lancedb.pydantic.LanceModel):
            id: str
            text: str = self.SourceField()
            vector: lancedb.pydantic.Vector(self.ndims()) = self.VectorField()

        retriever = db.create_table("eval", schema=Schema, mode="overwrite")
        pylist = [{"id": str(k), "text": v} for k, v in corpus.items()]
        retriever.add(pylist)

        eval_results = []
        for query_id, query in tqdm(queries.items()):
            retrieved_nodes = retriever.search(query).limit(top_k).to_list()
            retrieved_ids = [node["id"] for node in retrieved_nodes]
            expected_id = relevant_docs[query_id][0]
            is_hit = expected_id in retrieved_ids  # assume 1 relevant doc

            eval_result = {
                "is_hit": is_hit,
                "retrieved": retrieved_ids,
                "expected": expected_id,
                "query": query_id,
            }
            eval_results.append(eval_result)

        return eval_results

133
python/python/lancedb/embeddings/fine_tuner/README.md
Normal file
@@ -0,0 +1,133 @@
Fine-tuning workflow for embeddings consists of the following parts:

### QADataset
This class is used for managing the data for fine-tuning. It contains the following builder methods:
```
- from_llm(
    nodes: List["TextChunk"],
    llm: BaseLLM,
    qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL,
    num_questions_per_chunk: int = 2,
) -> "QADataset"
```
Create synthetic data from a language model and text chunks of the original document on which the model is to be fine-tuned.

```python
from_responses(docs: List["TextChunk"], queries: Dict[str, str], relevant_docs: Dict[str, List[str]]) -> "QADataset"
```
Create a dataset from queries and responses collected in a real-world scenario. Designed to be used for knowledge distillation from a larger LLM to a smaller one.

It also contains the following data attributes:
```
queries (Dict[str, str]): Dict id -> query.
corpus (Dict[str, str]): Dict id -> string.
relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.
```
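
For instance, a minimal sketch (dataset path is illustrative; `ds` is any `QADataset` built via the builder methods above) of persisting, reloading, and inspecting a dataset:

```python
from lancedb.embeddings.fine_tuner import QADataset

ds.save("qa_dataset")              # writes queries/corpus/relevant_docs as .lance tables
ds = QADataset.load("qa_dataset")  # reloads the same three mappings

# (query, relevant doc ids) pairs, handy for inspection
pairs = ds.query_docid_pairs
```
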

### TextChunk
This class is used for managing the data for fine-tuning. It is designed to standardize working with various text splitting/pre-processing tools like llama-index and langchain. It contains the following attributes:
```
text: str
id: str
metadata: Dict[str, Any] = {}
```

Builder Methods:

```python
from_llama_index_node(node) -> "TextChunk"
```
Create a text chunk from a llama-index node.

```python
from_langchain_node(node) -> "TextChunk"
```
Create a text chunk from a langchain node.

```python
from_chunk(cls, chunk: str, metadata: dict = {}) -> "TextChunk"
```
Create a text chunk from a string.
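
As a minimal sketch (toy strings, illustrative only), chunks built this way can be combined with hand-written queries via `from_responses`:

```python
from lancedb.embeddings.fine_tuner import QADataset, TextChunk

# Hypothetical toy corpus
chunks = [TextChunk.from_chunk(t) for t in ["alpha doc", "beta doc"]]

queries = {"q1": "Which doc mentions alpha?"}
relevant_docs = {"q1": [chunks[0].id]}  # map each query to its relevant chunk ids

ds = QADataset.from_responses(chunks, queries, relevant_docs)
```
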
### FineTuner
This class is used for fine-tuning embeddings. It is exposed to the user via a high-level function in the base embedding API.
```python
class BaseEmbeddingTuner(ABC):
    """Base Embedding finetuning engine."""

    @abstractmethod
    def finetune(self) -> None:
        """Run the fine-tuning job."""

    def helper(self) -> None:
        """An optional post-finetuning hook."""
        pass
```

### Embedding API finetuning implementation
Each embedding API needs to implement the `finetune` method in order to support fine-tuning. A vanilla evaluation technique has been implemented in the `BaseEmbedding` class that calculates hit-rate @ `top_k`.
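
Concretely, a short sketch (assuming `model` is any embedding function from the registry and `ds` is a `QADataset`) of turning the per-query results into a single hit-rate number:

```python
import pandas as pd

res = model.evaluate(ds, top_k=5)  # list of dicts with an "is_hit" field per query
hit_rate = pd.DataFrame(res)["is_hit"].mean()  # fraction of queries whose expected doc was retrieved
print(f"hit-rate@5: {hit_rate:.3f}")
```
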

### Fine-tuning workflow
The fine-tuning workflow is as follows:
1. Create a `QADataset` object.
2. Initialize any embedding function using the LanceDB embedding API.
3. Call the `finetune` method on the embedding object with the `QADataset` object as an argument.
4. Evaluate the fine-tuned model using the `evaluate` method in the embedding API.

# End-to-End Examples
The following is an example of how to fine-tune an embedding model using the LanceDB embedding API.

## Example 1: Fine-tuning from a synthetic dataset
```python
import os

import pandas as pd

from lancedb.embeddings.fine_tuner.llm import Openai
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from lancedb.pydantic import LanceModel, Vector
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from lancedb.embeddings import get_registry

# 1. Create a QADataset object
url = "uber10k.pdf"
name = "uber10k_qa"  # where the generated dataset is saved/loaded (illustrative name)
reader = SimpleDirectoryReader(input_files=[url])
docs = reader.load_data()

parser = SentenceSplitter()
nodes = parser.get_nodes_from_documents(docs)

if os.path.exists(name):
    ds = QADataset.load(name)
else:
    llm = Openai()

    # convert llama-index TextNodes to TextChunks
    chunks = [TextChunk.from_llama_index_node(node) for node in nodes]

    ds = QADataset.from_llm(chunks, llm)
    ds.save(name)

# 2. Initialize the embedding model
model = get_registry().get("sentence-transformers").create()

# 3. Fine-tune the model
model.finetune(trainset=ds, path="model_finetuned", epochs=4)

# 4. Evaluate the fine-tuned model
base = get_registry().get("sentence-transformers").create()
tuned = get_registry().get("sentence-transformers").create(name="./model_finetuned")
openai = get_registry().get("openai").create(name="text-embedding-3-large")

rs1 = base.evaluate(ds, path="val_res")
rs2 = tuned.evaluate(ds, path="val_res")
rs3 = openai.evaluate(ds)

print("openai-embedding-v3 hit-rate - ", pd.DataFrame(rs3)["is_hit"].mean())
print("fine-tuned hit-rate - ", pd.DataFrame(rs2)["is_hit"].mean())
print("base model hit-rate - ", pd.DataFrame(rs1)["is_hit"].mean())
```
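
`finetune` forwards extra keyword arguments to the tuner, so — as a sketch, assuming `wandb` is installed — validation and experiment logging can be switched on without further code:

```python
model.finetune(
    trainset=ds,
    valset=ds,          # enables the information-retrieval evaluator during training
    path="model_finetuned",
    epochs=4,
    log_wandb=True,     # logs score/epoch/steps to a Weights & Biases run
)
```
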
4
python/python/lancedb/embeddings/fine_tuner/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .dataset import QADataset, TextChunk
from .llm import Gemini, Openai

__all__ = ["QADataset", "TextChunk", "Openai", "Gemini"]
13
python/python/lancedb/embeddings/fine_tuner/basetuner.py
Normal file
@@ -0,0 +1,13 @@
from abc import ABC, abstractmethod


class BaseEmbeddingTuner(ABC):
    """Base Embedding finetuning engine."""

    @abstractmethod
    def finetune(self) -> None:
        """Run the fine-tuning job."""

    def helper(self) -> None:
        """An optional post-finetuning hook."""
        pass
179
python/python/lancedb/embeddings/fine_tuner/dataset.py
Normal file
@@ -0,0 +1,179 @@
import re
import uuid
from pathlib import Path
from typing import Any, Dict, List, Tuple

import lance
import pyarrow as pa
from pydantic import BaseModel
from tqdm import tqdm

from .llm import BaseLLM

DEFAULT_PROMPT_TMPL = """\
Context information is below.

---------------------
{context_str}
---------------------

Given the context information and no prior knowledge,
generate only questions based on the query below.

You are a Teacher/Professor. Your task is to set up \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. Restrict the questions to the \
context information provided.
"""


class QADataset(BaseModel):
    """Embedding QA Finetuning Dataset.

    Args:
        queries (Dict[str, str]): Dict id -> query.
        corpus (Dict[str, str]): Dict id -> string.
        relevant_docs (Dict[str, List[str]]): Dict query id -> list of doc ids.

    """

    queries: Dict[str, str]  # id -> query
    corpus: Dict[str, str]  # id -> text
    relevant_docs: Dict[str, List[str]]  # query id -> list of relevant doc ids
    mode: str = "text"

    @property
    def query_docid_pairs(self) -> List[Tuple[str, List[str]]]:
        """Get query, relevant doc ids."""
        return [
            (query, self.relevant_docs[query_id])
            for query_id, query in self.queries.items()
        ]

    def save(self, path: str, mode: str = "overwrite") -> None:
        """Save to lance dataset"""
        save_dir = Path(path)
        save_dir.mkdir(parents=True, exist_ok=True)

        # convert to pydict {"id": []}
        queries = {
            "id": list(self.queries.keys()),
            "query": list(self.queries.values()),
        }
        corpus = {
            "id": list(self.corpus.keys()),
            "text": [
                val or " " for val in self.corpus.values()
            ],  # lance saves empty strings as null
        }
        relevant_docs = {
            "query_id": list(self.relevant_docs.keys()),
            "doc_id": list(self.relevant_docs.values()),
        }

        # write to lance
        lance.write_dataset(
            pa.Table.from_pydict(queries), save_dir / "queries.lance", mode=mode
        )
        lance.write_dataset(
            pa.Table.from_pydict(corpus), save_dir / "corpus.lance", mode=mode
        )
        lance.write_dataset(
            pa.Table.from_pydict(relevant_docs),
            save_dir / "relevant_docs.lance",
            mode=mode,
        )

    @classmethod
    def load(cls, path: str) -> "QADataset":
        """Load from .lance data"""
        load_dir = Path(path)
        queries = lance.dataset(load_dir / "queries.lance").to_table().to_pydict()
        corpus = lance.dataset(load_dir / "corpus.lance").to_table().to_pydict()
        relevant_docs = (
            lance.dataset(load_dir / "relevant_docs.lance").to_table().to_pydict()
        )
        return cls(
            queries=dict(zip(queries["id"], queries["query"])),
            corpus=dict(zip(corpus["id"], corpus["text"])),
            relevant_docs=dict(zip(relevant_docs["query_id"], relevant_docs["doc_id"])),
        )

    # generate queries as a convenience function
    @classmethod
    def from_llm(
        cls,
        nodes: "List[TextChunk]",
        llm: BaseLLM,
        qa_generate_prompt_tmpl: str = DEFAULT_PROMPT_TMPL,
        num_questions_per_chunk: int = 2,
    ) -> "QADataset":
        """Generate examples given a set of nodes."""
        node_dict = {node.id: node.text for node in nodes}

        queries = {}
        relevant_docs = {}
        for node_id, text in tqdm(node_dict.items()):
            query = qa_generate_prompt_tmpl.format(
                context_str=text, num_questions_per_chunk=num_questions_per_chunk
            )
            response = llm.chat_completion(query)

            # one question per line; strip leading list numbering like "1." or "2)"
            result = str(response).strip().split("\n")
            questions = [
                re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
            ]
            questions = [question for question in questions if len(question) > 0]
            for question in questions:
                question_id = str(uuid.uuid4())
                queries[question_id] = question
                relevant_docs[question_id] = [node_id]

        return QADataset(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)

    @classmethod
    def from_responses(
        cls,
        docs: List["TextChunk"],
        queries: Dict[str, str],
        relevant_docs: Dict[str, List[str]],
    ) -> "QADataset":
        """Create a QADataset from a list of TextChunks and a list of questions."""
        node_dict = {node.id: node.text for node in docs}
        return cls(queries=queries, corpus=node_dict, relevant_docs=relevant_docs)


class TextChunk(BaseModel):
    """Simple text chunk for generating questions."""

    text: str
    id: str
    metadata: Dict[str, Any] = {}

    @classmethod
    def from_chunk(cls, chunk: str, metadata: dict = {}) -> "TextChunk":
        """Create a TextChunk from a raw string."""
        # generate a unique id
        return cls(text=chunk, id=str(uuid.uuid4()), metadata=metadata)

    @classmethod
    def from_llama_index_node(cls, node):
        """Convert a llama-index node to a text chunk."""
        return cls(text=node.text, id=node.node_id, metadata=node.metadata)

    @classmethod
    def from_langchain_node(cls, node):
        """Convert a langchain node to a text chunk."""
        raise NotImplementedError("Not implemented yet.")

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a dictionary."""
        return self.dict()

    def __str__(self) -> str:
        return self.text

    def __repr__(self) -> str:
        return f"TextChunk(text={self.text}, id={self.id}, \
metadata={self.metadata})"
85
python/python/lancedb/embeddings/fine_tuner/llm.py
Normal file
@@ -0,0 +1,85 @@
import os
import re
from functools import cached_property
from typing import List, Optional

from pydantic import BaseModel

from ...util import attempt_import_or_raise
from ..utils import api_key_not_found_help


class BaseLLM(BaseModel):
    """
    TODO:
    Base class for Language Model based Embedding Functions. This class is
    loosely designed right now, and will be updated as the usage gets clearer.
    """

    model_name: str
    model_kwargs: dict = {}

    @cached_property
    def _client(self):
        """
        Get the client for the language model
        """
        raise NotImplementedError

    def chat_completion(self, prompt: str, **kwargs):
        """
        Get the chat completion for the given prompt
        """
        raise NotImplementedError

class Openai(BaseLLM):
    model_name: str = "gpt-3.5-turbo"
    kwargs: dict = {}
    api_key: Optional[str] = None

    @cached_property
    def _client(self):
        """
        Get the client for the language model
        """
        openai = attempt_import_or_raise("openai")

        if not os.environ.get("OPENAI_API_KEY"):
            api_key_not_found_help("openai")
        return openai.OpenAI()

    def chat_completion(self, prompt: str) -> str:
        """
        Get the chat completion for the given prompt
        """

        # TODO: this is the legacy openai api; replace with completions
        completion = self._client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ],
            **self.kwargs,
        )

        text = completion.choices[0].message.content

        return text

    def get_questions(self, prompt: str) -> List[str]:
        """
        Get the chat completion for the given prompt and parse it into a
        list of questions, one per non-empty line.
        """
        response = self.chat_completion(prompt)
        result = str(response).strip().split("\n")
        questions = [
            re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
        ]
        questions = [question for question in questions if len(question) > 0]
        return questions


class Gemini(BaseLLM):
    pass
@@ -103,9 +103,9 @@ class InstructorEmbeddingFunction(TextEmbeddingFunction):
    # convert_to_numpy: bool = True # Hardcoding this as numpy can be ingested directly

    source_instruction: str = "represent the document for retrieval"
    query_instruction: str = (
        "represent the document for retrieving the most similar documents"
    )
    query_instruction: (
        str
    ) = "represent the document for retrieving the most similar documents"

    @weak_lru(maxsize=1)
    def ndims(self):

@@ -10,12 +10,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
from typing import Any, List, Optional, Union

import numpy as np

from lancedb.embeddings.fine_tuner import QADataset
from lancedb.utils.general import LOGGER

from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .fine_tuner.basetuner import BaseEmbeddingTuner
from .registry import register
from .utils import weak_lru

@@ -80,3 +84,151 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
            "sentence_transformers", "sentence-transformers"
        )
        return sentence_transformers.SentenceTransformer(self.name, device=self.device)

    def finetune(self, trainset: QADataset, *args, **kwargs):
        """
        Finetune the Sentence Transformers model

        Parameters
        ----------
        trainset: QADataset
            The dataset to use for finetuning
        """
        tuner = SentenceTransformersTuner(
            model=self.embedding_model,
            trainset=trainset,
            **kwargs,
        )
        tuner.finetune()


class SentenceTransformersTuner(BaseEmbeddingTuner):
    """Sentence Transformers Embedding Finetuning Engine."""

    def __init__(
        self,
        model: Any,
        trainset: QADataset,
        valset: Optional[QADataset] = None,
        path: Optional[str] = "~/.lancedb/embeddings/models",
        batch_size: int = 8,
        epochs: int = 1,
        show_progress: bool = True,
        eval_steps: int = 50,
        max_input_per_doc: int = -1,
        loss: Optional[Any] = None,
        evaluator: Optional[Any] = None,
        run_name: Optional[str] = None,
        log_wandb: bool = False,
    ) -> None:
        """
        Parameters
        ----------
        model: Any
            The sentence-transformers model instance to finetune.
        trainset: QADataset
            The training dataset.
        valset: Optional[QADataset]
            The validation dataset.
        path: Optional[str]
            The path to save the model.
        batch_size: int, default=8
            The batch size.
        epochs: int, default=1
            The number of epochs.
        show_progress: bool, default=True
            Whether to show progress.
        eval_steps: int, default=50
            The number of steps between evaluations.
        max_input_per_doc: int, default=-1
            The maximum number of training inputs per query;
            if -1, use all relevant documents.
        loss: Optional[Any]
            The loss to train with; defaults to MultipleNegativesRankingLoss.
        evaluator: Optional[Any]
            The evaluator class; defaults to InformationRetrievalEvaluator.
        run_name: Optional[str]
            The wandb run name, if logging is enabled.
        log_wandb: bool, default=False
            Whether to log scores to wandb.
        """
        from sentence_transformers import InputExample, losses
        from sentence_transformers.evaluation import InformationRetrievalEvaluator
        from torch.utils.data import DataLoader

        self.model = model
        self.trainset = trainset
        self.valset = valset
        self.path = path
        self.batch_size = batch_size
        self.epochs = epochs
        self.show_progress = show_progress
        self.eval_steps = eval_steps
        self.max_input_per_doc = max_input_per_doc
        self.evaluator = None
        self.run_name = run_name
        self.log_wandb = log_wandb

        if self.max_input_per_doc < -1 or self.max_input_per_doc == 0:
            raise ValueError("max_input_per_doc must be -1 or greater than 0.")

        examples: Any = []
        for query_id, query in self.trainset.queries.items():
            if max_input_per_doc == -1:
                for node_id in self.trainset.relevant_docs[query_id]:
                    text = self.trainset.corpus[node_id]
                    example = InputExample(texts=[query, text])
                    examples.append(example)
            else:
                # take at most max_input_per_doc relevant docs per query
                for node_id in self.trainset.relevant_docs[query_id][
                    :max_input_per_doc
                ]:
                    text = self.trainset.corpus[node_id]
                    example = InputExample(texts=[query, text])
                    examples.append(example)

        self.examples = examples

        self.loader: DataLoader = DataLoader(examples, batch_size=batch_size)

        if self.valset is not None:
            eval_engine = evaluator or InformationRetrievalEvaluator
            self.evaluator = eval_engine(
                valset.queries, valset.corpus, valset.relevant_docs
            )

        # define loss
        self.loss = loss or losses.MultipleNegativesRankingLoss(self.model)
        self.warmup_steps = int(len(self.loader) * epochs * 0.1)

    def finetune(self) -> None:
        """Finetune the Sentence Transformers model."""
        self.model.fit(
            train_objectives=[(self.loader, self.loss)],
            epochs=self.epochs,
            warmup_steps=self.warmup_steps,
            output_path=self.path,
            show_progress_bar=self.show_progress,
            evaluator=self.evaluator,
            evaluation_steps=self.eval_steps,
            callback=self._wandb_callback if self.log_wandb else None,
        )

        self.helper()

    def helper(self) -> None:
        """Log usage instructions for the finetuned model."""
        LOGGER.info("Finetuning complete.")
        LOGGER.info(f"Model saved to {self.path}.")
        LOGGER.info("You can now use the model as follows:")
        LOGGER.info(
            f"model = get_registry().get('sentence-transformers').create(name='./{self.path}')"  # noqa
        )

    def _wandb_callback(self, score, epoch, steps):
        try:
            import wandb
        except ImportError:
            raise ImportError(
                "wandb is not installed. Please install it using `pip install wandb`"
            )
        run = wandb.run or wandb.init(
            project="sbert_lancedb_finetune", name=self.run_name
        )
        run.log({"epoch": epoch, "steps": steps, "score": score})

@@ -118,11 +118,6 @@ class Reranker(ABC):
            The results from the vector search
        fts_results : pa.Table
            The results from the FTS search

        Returns
        -------
        pa.Table
            The merged results
        """
        combined = pa.concat_tables([vector_results, fts_results], promote=True)
        row_id = combined.column("_rowid")

45
python/python/tests/test_embedding_fine_tuning.py
Normal file
@@ -0,0 +1,45 @@
import uuid

import pytest
from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner import QADataset, TextChunk
from tqdm import tqdm


@pytest.mark.slow
def test_finetuning_sentence_transformers(tmp_path):
    queries = {}
    relevant_docs = {}
    chunks = [
        "This is a chunk related to legal docs",
        "This is another chunk related to financial docs",
        "This is a chunk related to sports docs",
        "This is another chunk related to fashion docs",
    ]
    text_chunks = [TextChunk.from_chunk(chunk) for chunk in chunks]
    for chunk in tqdm(text_chunks):
        questions = [
            "What is this chunk about?",
            "What is the main topic of this chunk?",
        ]
        for question in questions:
            question_id = str(uuid.uuid4())
            queries[question_id] = question
            relevant_docs[question_id] = [chunk.id]
    ds = QADataset.from_responses(text_chunks, queries, relevant_docs)

    assert len(ds.queries) == 8
    assert len(ds.corpus) == 4

    model = get_registry().get("sentence-transformers").create()
    model.finetune(trainset=ds, valset=ds, path=str(tmp_path / "model"), epochs=1)
    model = (
        get_registry().get("sentence-transformers").create(name=str(tmp_path / "model"))
    )
    res = model.evaluate(ds)
    assert res is not None


def test_text_chunk():
    # TODO
    pass