Mirror of https://github.com/lancedb/lancedb.git, synced 2026-01-06 20:02:58 +00:00
feat(python): Embedding API fine tuning support (#1125)
Based on https://github.com/lancedb/lancedb/pull/1023

Very WIP. I'm thinking of merging individual pieces into this feature branch instead of main, so we can still review the code in pieces and avoid polluting main.

- Adds support for creating a corpus from llama-index TextNode objects (the aim is to remove this dependency)
- Adds very basic support for an LLM API for chat completion; will expand as needs arise
- Adds a basic universal evaluator
- Adds sentence-transformers fine-tuning support

Known problems:

- [ ] W&B experiment tracking is not working for sentence-transformers
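For orientation before the diffs: a minimal sketch of the fine-tuning flow these scripts exercise. The `QADataset.from_llm`, `finetune`, and `evaluate` calls are taken from the files below; constructing `TextChunk(text=...)` directly from raw strings is an assumption, since the scripts only ever build chunks via `TextChunk.from_llama_index_node`.

```python
from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from lancedb.embeddings.fine_tuner.llm import Openai

# Assumed: TextChunk accepts raw text directly; the scripts below build
# chunks with TextChunk.from_llama_index_node(node) instead.
chunks = [TextChunk(text=t) for t in ["first passage ...", "second passage ..."]]

# Generate (question, chunk) training pairs with an LLM.
ds = QADataset.from_llm(chunks, Openai(), num_questions_per_chunk=2)

# Fine-tune a sentence-transformers model, then reload and score it.
model = get_registry().get("sentence-transformers").create(name="BAAI/bge-small-en-v1.5")
model.finetune(trainset=ds, path="model_tuned", epochs=4)

tuned = get_registry().get("sentence-transformers").create(name="./model_tuned")
hits = tuned.evaluate(ds)  # list of per-query dicts with an 'is_hit' flag
print(sum(r["is_hit"] for r in hits) / len(hits))
```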
docs/src/eval/bench_fine_tuned_hybrid.py (new file, +150 lines)
```python
import json
import os

import pandas as pd
import requests
import wandb
from tqdm import tqdm

from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from lancedb.embeddings.fine_tuner.llm import Openai
from lancedb.rerankers import (
    CohereReranker,
    ColbertReranker,
    CrossEncoderReranker,
    LinearCombinationReranker,
)
from llama_index.core import (
    ServiceContext,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
)
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import TextNode
from llama_index.embeddings.huggingface import HuggingFaceEmbedding  # noqa: F401 (used when benchmarking the fine-tuned model)
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.vector_stores.lancedb import LanceDBVectorStore

TRAIN_DATASET_FPATH = './data/train_dataset.json'
VAL_DATASET_FPATH = './data/val_dataset.json'

with open(TRAIN_DATASET_FPATH, 'r') as f:
    train_dataset = json.load(f)

with open(VAL_DATASET_FPATH, 'r') as f:
    val_dataset = json.load(f)


def train_embedding_model(epoch):
    def download_test_files(url):
        # Download the file into the current working directory, skipping
        # the download if it is already present.
        files = []
        filename = os.path.basename(url)
        if not os.path.exists(filename):
            print(f"Downloading {url} to {filename}")
            r = requests.get(url)
            with open(filename, 'wb') as f:
                f.write(r.content)
        files.append(filename)
        return files

    def get_dataset(url, name):
        reader = SimpleDirectoryReader(input_files=download_test_files(url))
        docs = reader.load_data()

        parser = SentenceSplitter()
        nodes = parser.get_nodes_from_documents(docs)

        if os.path.exists(name):
            ds = QADataset.load(name)
        else:
            llm = Openai()

            # Convert llama-index TextNodes to TextChunks.
            chunks = [TextChunk.from_llama_index_node(node) for node in nodes]

            ds = QADataset.from_llm(chunks, llm, num_questions_per_chunk=2)
            ds.save(name)
        return ds

    train_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf'
    ds = get_dataset(train_url, "qa_dataset_uber")

    model = get_registry().get("sentence-transformers").create(name="BAAI/bge-small-en-v1.5")
    model.finetune(trainset=ds, valset=None, path="model_airbnb", epochs=epoch,
                   log_wandb=True, run_name="lyft_finetune")


def evaluate(
    dataset,
    embed_model,
    reranker=None,
    top_k=5,
    verbose=False,
):
    corpus = dataset['corpus']
    queries = dataset['queries']
    relevant_docs = dataset['relevant_docs']

    # Index the corpus into LanceDB, then add a full-text index so hybrid
    # (vector + keyword) search is available to the rerankers.
    vector_store = LanceDBVectorStore(uri="/tmp/lancedb")
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    service_context = ServiceContext.from_defaults(embed_model=embed_model)
    nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]
    VectorStoreIndex(
        nodes,
        service_context=service_context,
        show_progress=True,
        storage_context=storage_context,
    )
    tbl = vector_store.connection.open_table(vector_store.table_name)
    tbl.create_fts_index("text", replace=True)

    eval_results = []
    for query_id, query in tqdm(queries.items()):
        query_vector = embed_model.get_query_embedding(query)
        try:
            if reranker is None:
                rs = tbl.search(query_vector).limit(top_k).to_pandas()
            else:
                rs = tbl.search((query_vector, query)).rerank(reranker=reranker).limit(top_k).to_pandas()
        except Exception as e:
            print(f'Error with query: {query_id} {e}')
            continue
        retrieved_ids = rs['id'].tolist()[:top_k]
        expected_id = relevant_docs[query_id][0]
        is_hit = expected_id in retrieved_ids  # assume exactly one relevant doc per query
        if len(eval_results) == 0:
            print(f"Query: {query}")
            print(f"Expected: {expected_id}")
            print(f"Retrieved: {retrieved_ids}")
        eval_results.append({
            'is_hit': is_hit,
            'retrieved': retrieved_ids,
            'expected': expected_id,
            'query': query_id,
        })
    return eval_results


if __name__ == '__main__':
    train_embedding_model(4)
    rerankers = {
        "Vector Search": None,
        "Cohere": CohereReranker(),
        "Cross Encoder": CrossEncoderReranker(),
        "Colbert": ColbertReranker(),
        "linear": LinearCombinationReranker(),
    }
    top_ks = [3]
    for top_k in top_ks:
        for name, reranker in rerankers.items():
            # embed_model = HuggingFaceEmbedding("./model_airbnb")  # swap in the fine-tuned model
            embed_model = OpenAIEmbedding()
            wandb.init(project="Reranker-based", name=name)
            val_eval_results = evaluate(val_dataset, embed_model, reranker=reranker, top_k=top_k)
            df = pd.DataFrame(val_eval_results)

            hit_rate = df['is_hit'].mean()
            print(f'Hit rate: {hit_rate:.2f}')
            wandb.log({f"openai_base_hit_rate_@{top_k}": hit_rate})
            wandb.finish()
```
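Note the design of `evaluate`: it indexes the corpus through llama-index's `LanceDBVectorStore`, then opens the underlying LanceDB table directly and builds an FTS index so the reranked path can run hybrid `(vector, query)` search. Hit-rate@k is then just the fraction of queries whose single relevant doc appears in the top `k` results.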
docs/src/eval/test_fine_tune_from_llm.py (new file, +71 lines)
```python
import os

import pandas as pd

from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from lancedb.embeddings.fine_tuner.llm import Openai
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter

test_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf'
train_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf'


def download_test_files(url):
    import requests

    # Download the file into the current working directory, skipping the
    # download if it is already present.
    files = []
    filename = os.path.basename(url)
    if not os.path.exists(filename):
        print(f"Downloading {url} to {filename}")
        r = requests.get(url)
        with open(filename, 'wb') as f:
            f.write(r.content)
    files.append(filename)
    return files


def get_dataset(url, name):
    reader = SimpleDirectoryReader(input_files=download_test_files(url))
    docs = reader.load_data()

    parser = SentenceSplitter()
    nodes = parser.get_nodes_from_documents(docs)

    if os.path.exists(name):
        ds = QADataset.load(name)
    else:
        llm = Openai()

        # Convert llama-index TextNodes to TextChunks.
        chunks = [TextChunk.from_llama_index_node(node) for node in nodes]

        ds = QADataset.from_llm(chunks, llm)
        ds.save(name)
    return ds


trainset = get_dataset(test_url, "qa_dataset_1")
valset = get_dataset(train_url, "valset")

model = get_registry().get("sentence-transformers").create()
model.finetune(trainset=trainset, valset=valset, path="model_finetuned_1", epochs=4)

base = get_registry().get("sentence-transformers").create()
tuned = get_registry().get("sentence-transformers").create(name="./model_finetuned_1")
openai = get_registry().get("openai").create(name="text-embedding-3-large")

rs1 = base.evaluate(valset, path="val_res")
rs2 = tuned.evaluate(valset, path="val_res")
rs3 = openai.evaluate(valset)

print("openai-embedding-v3 hit-rate - ", pd.DataFrame(rs3)["is_hit"].mean())
print("fine-tuned hit-rate - ", pd.DataFrame(rs2)["is_hit"].mean())
print("base model hit-rate - ", pd.DataFrame(rs1)["is_hit"].mean())
```
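This script leans on the `evaluate` method exposed directly on the embedding-function objects, so the base, fine-tuned, and OpenAI models are scored on the same validation set with the same hit-rate metric and their numbers are directly comparable.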
docs/src/eval/test_fine_tune_from_responses.py (new file, +119 lines)
```python
import os
import uuid

import lancedb
import pandas as pd
from tqdm import tqdm

from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner.dataset import (
    DEFAULT_PROMPT_TMPL,
    QADataset,
    TextChunk,
)
from lancedb.embeddings.fine_tuner.llm import Openai
from lancedb.pydantic import LanceModel, Vector
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter

test_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/lyft_2021.pdf'
train_url = 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/10k/uber_2021.pdf'


def download_test_files(url):
    import requests

    # Download the file into the current working directory, skipping the
    # download if it is already present.
    files = []
    filename = os.path.basename(url)
    if not os.path.exists(filename):
        print(f"Downloading {url} to {filename}")
        r = requests.get(url)
        with open(filename, 'wb') as f:
            f.write(r.content)
    files.append(filename)
    return files


def get_node(url):
    reader = SimpleDirectoryReader(input_files=download_test_files(url))
    docs = reader.load_data()

    parser = SentenceSplitter()
    nodes = parser.get_nodes_from_documents(docs)

    return nodes


def get_dataset(url, name):
    nodes = get_node(url)

    if os.path.exists(name):
        ds = QADataset.load(name)
    else:
        llm = Openai()

        # Convert llama-index TextNodes to TextChunks.
        chunks = [TextChunk.from_llama_index_node(node) for node in nodes]

        ds = QADataset.from_llm(chunks, llm)
        ds.save(name)
    return ds


nodes = get_node(train_url)

db = lancedb.connect("~/lancedb/fine-tuning")
model = get_registry().get("openai").create()


class Schema(LanceModel):
    id: str
    text: str = model.SourceField()
    vector: Vector(model.ndims()) = model.VectorField()


retriever = db.create_table("fine-tuning", schema=Schema, mode="overwrite")
pylist = [{"id": str(node.node_id), "text": node.text} for node in nodes]
retriever.add(pylist)

ds_name = "response_data"
if os.path.exists(ds_name):
    ds = QADataset.load(ds_name)
else:
    # Generate questions for each chunk, and label the retriever's top hit
    # as the relevant document for each question.
    llm = Openai()
    text_chunks = [TextChunk.from_llama_index_node(node) for node in nodes]

    queries = {}
    relevant_docs = {}
    for chunk in tqdm(text_chunks):
        text = chunk.text
        questions = llm.get_questions(
            DEFAULT_PROMPT_TMPL.format(context_str=text, num_questions_per_chunk=2)
        )

        for question in questions:
            question_id = str(uuid.uuid4())
            queries[question_id] = question
            relevant_docs[question_id] = [
                retriever.search(question).to_pandas()["id"].tolist()[0]
            ]
    ds = QADataset.from_responses(text_chunks, queries, relevant_docs)
    ds.save(ds_name)

# Fine-tune the model
valset = get_dataset(train_url, "valset")

model = get_registry().get("sentence-transformers").create()
res_base = model.evaluate(valset)

model.finetune(trainset=ds, path="model_finetuned", epochs=4, log_wandb=True)
tuned = get_registry().get("sentence-transformers").create(name="./model_finetuned")
res_tuned = tuned.evaluate(valset)

# openai_model = get_registry().get("openai").create()
# res_openai = openai_model.evaluate(valset)
# print(f"openai model results: {pd.DataFrame(res_openai)['is_hit'].mean()}")
print(f"base model results: {pd.DataFrame(res_base)['is_hit'].mean()}")
print(f"tuned model results: {pd.DataFrame(res_tuned)['is_hit'].mean()}")
```
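Unlike the previous script, the training set here is built with `QADataset.from_responses`: questions are generated per chunk via `DEFAULT_PROMPT_TMPL`, and the "relevant" document for each question is whatever the OpenAI-embedding retriever returns as its top hit. In other words, the model is fine-tuned on retrieval responses rather than on ground-truth question-chunk pairs.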