mirror of https://github.com/lancedb/lancedb.git
synced 2025-12-26 22:59:57 +00:00
update
@@ -15,25 +15,33 @@ import time
 import wandb
+import pandas as pd


 def get_paths_from_dataset(dataset: str, split=True):
     """
     Returns paths of:
     - downloaded dataset, lance train dataset, lance test dataset, finetuned model
     """
     if split:
-        return f"./data/{dataset}", f"./data/{dataset}_lance_train", f"./data/{dataset}_lance_test", f"./data/tuned_{dataset}"
+        return (
+            f"./data/{dataset}",
+            f"./data/{dataset}_lance_train",
+            f"./data/{dataset}_lance_test",
+            f"./data/tuned_{dataset}",
+        )
     return f"./data/{dataset}", f"./data/{dataset}_lance", f"./data/tuned_{dataset}"


 def get_llama_dataset(dataset: str):
     """
     returns:
     - nodes, documents, rag_dataset
     """
     if not os.path.exists(f"./data/{dataset}"):
-        os.system(f"llamaindex-cli download-llamadataset {dataset} --download-dir ./data/{dataset}")
+        os.system(
+            f"llamaindex-cli download-llamadataset {dataset} --download-dir ./data/{dataset}"
+        )
     rag_dataset = LabelledRagDataset.from_json(f"./data/{dataset}/rag_dataset.json")
     docs = SimpleDirectoryReader(input_dir=f"./data/{dataset}/source_files").load_data()
@@ -42,6 +50,7 @@ def get_llama_dataset(dataset: str):
     return nodes, docs, rag_dataset


 def lance_dataset_from_llama_nodes(nodes: list, name: str, split=True):
     llm = Openai()
     chunks = [TextChunk.from_llama_index_node(node) for node in nodes]
@@ -53,12 +62,14 @@ def lance_dataset_from_llama_nodes(nodes: list, name: str, split=True):
         ds = QADataset.from_llm(chunks, llm)
         ds.save(f"./data/{name}_lance")
         return ds

-    if os.path.exists(f"./data/{name}_lance_train") and os.path.exists(f"./data/{name}_lance_test"):
+    if os.path.exists(f"./data/{name}_lance_train") and os.path.exists(
+        f"./data/{name}_lance_test"
+    ):
         train_ds = QADataset.load(f"./data/{name}_lance_train")
         test_ds = QADataset.load(f"./data/{name}_lance_test")
         return train_ds, test_ds
     # split chunks randomly
     train_size = int(len(chunks) * 0.65)
     train_chunks = chunks[:train_size]
     test_chunks = chunks[train_size:]
@@ -69,29 +80,29 @@ def lance_dataset_from_llama_nodes(nodes: list, name: str, split=True):
     return train_ds, test_ds


-def finetune(trainset: str, model: str, epochs: int, path: str, valset: str = None, top_k=5):
+def finetune(
+    trainset: str, model: str, epochs: int, path: str, valset: str = None, top_k=5
+):
     print(f"Finetuning {model} for {epochs} epochs")
     print(f"trainset query instances: {len(trainset.queries)}")
     print(f"valset query instances: {len(valset.queries)}")

     valset = valset if valset is not None else trainset
     model = get_registry().get("sentence-transformers").create(name=model)
-    base_result = model.evaluate(valset, path=f"./data/eva/", top_k=top_k)
+    base_result = model.evaluate(valset, path=f"./data/eval/", top_k=top_k)
     base_hit_rate = pd.DataFrame(base_result)["is_hit"].mean()

     model.finetune(trainset=trainset, valset=valset, path=path, epochs=epochs)
     tuned = get_registry().get("sentence-transformers").create(name=path)
-    tuned_result = tuned.evaluate(valset, path=f"./data/eva/{str(time.time())}", top_k=top_k)
+    tuned_result = tuned.evaluate(
+        valset, path=f"./data/eval/{str(time.time())}", top_k=top_k
+    )
     tuned_hit_rate = pd.DataFrame(tuned_result)["is_hit"].mean()

     return base_hit_rate, tuned_hit_rate


-def eval_rag(dataset: str, model: str):
+def do_eval_rag(dataset: str, model: str):
     # Requires - pip install llama-index-vector-stores-lancedb
     # Requires - pip install llama-index-embeddings-huggingface
     nodes, docs, rag_dataset = get_llama_dataset(dataset)
@@ -101,8 +112,8 @@ def eval_rag(dataset: str, model: str):
     storage_context = StorageContext.from_defaults(vector_store=vector_store)
     service_context = ServiceContext.from_defaults(embed_model=embed_model)
     index = VectorStoreIndex(
         nodes,
         service_context=service_context,
         show_progress=True,
         storage_context=storage_context,
     )
@@ -112,9 +123,7 @@ def eval_rag(dataset: str, model: str):
     query_engine = index.as_query_engine()

     # evaluate using the RagEvaluatorPack
-    RagEvaluatorPack = download_llama_pack(
-        "RagEvaluatorPack", "./rag_evaluator_pack"
-    )
+    RagEvaluatorPack = download_llama_pack("RagEvaluatorPack", "./rag_evaluator_pack")
     rag_evaluator_pack = RagEvaluatorPack(
         rag_dataset=rag_dataset, query_engine=query_engine
     )
@@ -126,36 +135,69 @@ def eval_rag(dataset: str, model: str):
     return metrics


-def main(dataset, model, epochs, top_k=5, eval_rag=False, project: str = "lancedb_finetune"):
+def main(
+    dataset,
+    model,
+    epochs,
+    top_k=5,
+    eval_rag=False,
+    split=True,
+    project: str = "lancedb_finetune",
+):
     nodes, _, _ = get_llama_dataset(dataset)
-    trainset, valset = lance_dataset_from_llama_nodes(nodes, dataset)
-    data_path, lance_train_path, lance_test_path, tuned_path = get_paths_from_dataset(dataset, split=True)
-
-    base_hit_rate, tuned_hit_rate = finetune(trainset, model, epochs, tuned_path, valset, top_k=top_k)
+    trainset = None
+    valset = None
+    if split:
+        trainset, valset = lance_dataset_from_llama_nodes(nodes, dataset, split)
+        data_path, lance_train_path, lance_test_path, tuned_path = (
+            get_paths_from_dataset(dataset, split=split)
+        )
+    else:
+        trainset = lance_dataset_from_llama_nodes(nodes, dataset, split)
+        valset = trainset
+        data_path, lance_path, tuned_path = get_paths_from_dataset(dataset, split=split)
+
+    base_hit_rate, tuned_hit_rate = finetune(
+        trainset, model, epochs, tuned_path, valset, top_k=top_k
+    )

     # Base model metrics
-    metrics = eval_rag(dataset, model)
+    metrics = do_eval_rag(dataset, model) if eval_rag else {}

     # Tuned model metrics
-    metrics_tuned = eval_rag(dataset, tuned_path)
+    metrics_tuned = do_eval_rag(dataset, tuned_path) if eval_rag else {}

-    wandb.init(project="lancedb_finetune", name=f"{dataset}_{model}_{epochs}")
-    wandb.log({
-        "hit_rate": tuned_hit_rate,
-    })
+    wandb.init(project="lancedb_finetune", name=f"{dataset}_{model}_{epochs}")
+    wandb.log(
+        {
+            "hit_rate": tuned_hit_rate,
+        }
+    )
     wandb.log(metrics_tuned)
     wandb.finish()

-    wandb.init(project="lancedb_finetune", name=f"{dataset}_{model}_base")
-    wandb.log({
-        "hit_rate": base_hit_rate,
-    })
+    wandb.init(project="lancedb_finetune", name=f"{dataset}_{model}_base")
+    wandb.log(
+        {
+            "hit_rate": base_hit_rate,
+        }
+    )
     wandb.log(metrics)
     wandb.finish()


 def banchmark_all():
-    datasets = ["Uber10KDataset2021", "MiniTruthfulQADataset", "MiniSquadV2Dataset", "MiniEsgBenchDataset", "MiniCovidQaDataset", "Llama2PaperDataset", "HistoryOfAlexnetDataset", "PatronusAIFinanceBenchDataset"]
+    datasets = [
+        "Uber10KDataset2021",
+        "MiniTruthfulQADataset",
+        "MiniSquadV2Dataset",
+        "MiniEsgBenchDataset",
+        "MiniCovidQaDataset",
+        "Llama2PaperDataset",
+        "HistoryOfAlexnetDataset",
+        "PatronusAIFinanceBenchDataset",
+    ]
     models = ["BAAI/bge-small-en-v1.5"]
     top_ks = [5]
     for top_k in top_ks:
@@ -163,19 +205,30 @@ def banchmark_all():
         for dataset in datasets:
             main(dataset, model, 5, top_k=top_k)


 if __name__ == "__main__":
     import argparse

     parser = argparse.ArgumentParser()
     parser.add_argument("--dataset", type=str, default="BraintrustCodaHelpDeskDataset")
     parser.add_argument("--model", type=str, default="BAAI/bge-small-en-v1.5")
     parser.add_argument("--epochs", type=int, default=4)
     parser.add_argument("--project", type=str, default="lancedb_finetune")
     parser.add_argument("--top_k", type=int, default=5)
-    parser.add_argument("--eval-rag", action="store_true")
+    parser.add_argument("--split", type=int, default=1)
+    parser.add_argument("--eval-rag", action="store_true", default=False)
     parser.add_argument("--benchmark-all", action="store_true")
     args = parser.parse_args()

     if args.benchmark_all:
         banchmark_all()
     else:
-        main(args.dataset, args.model, args.epochs, args.top_k, args.eval_rag, args.project)
+        main(
+            args.dataset,
+            args.model,
+            args.epochs,
+            args.top_k,
+            args.eval_rag,
+            args.split,
+            args.project,
+        )
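Taken together, the script's flow is: download a llama-index dataset, build QADataset train/test splits from its nodes, finetune the sentence-transformers model through the lancedb registry, and log hit rates (plus optional RAG metrics) to wandb. A minimal sketch of driving it programmatically — the module name `benchmark` is hypothetical, and the dataset/model names are just the script's own defaults:

    # Hypothetical import path; main and banchmark_all are defined in the script above.
    from benchmark import main, banchmark_all

    # One dataset: 65/35 chunk split, finetune, log hit rates to wandb.
    main(
        "BraintrustCodaHelpDeskDataset",  # llama-index dataset name (script default)
        "BAAI/bge-small-en-v1.5",         # base sentence-transformers model
        epochs=4,
        top_k=5,
        eval_rag=False,  # skip the RagEvaluatorPack pass; hit rate only
        split=True,
    )

    # Or sweep every dataset in the benchmark list.
    banchmark_all()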
@@ -10,12 +10,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Union
+from typing import Any, List, Optional, Union

 import numpy as np
+import logging
+from lancedb.embeddings.fine_tuner import QADataset

 from ..util import attempt_import_or_raise
 from .base import TextEmbeddingFunction
+from .fine_tuner.basetuner import BaseEmbeddingTuner
 from .registry import register
 from .utils import weak_lru
@@ -80,3 +83,151 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
             "sentence_transformers", "sentence-transformers"
         )
         return sentence_transformers.SentenceTransformer(self.name, device=self.device)
+
+    def finetune(self, trainset: QADataset, *args, **kwargs):
+        """
+        Finetune the Sentence Transformers model
+
+        Parameters
+        ----------
+        trainset: QADataset
+            The dataset to use for finetuning
+        """
+        tuner = SentenceTransformersTuner(
+            model=self.embedding_model,
+            trainset=trainset,
+            **kwargs,
+        )
+        tuner.finetune()
+
+
+class SentenceTransformersTuner(BaseEmbeddingTuner):
+    """Sentence Transformers Embedding Finetuning Engine."""
+
+    def __init__(
+        self,
+        model: Any,
+        trainset: QADataset,
+        valset: Optional[QADataset] = None,
+        path: Optional[str] = "~/.lancedb/embeddings/models",
+        batch_size: int = 8,
+        epochs: int = 1,
+        show_progress: bool = True,
+        eval_steps: int = 50,
+        max_input_per_doc: int = -1,
+        loss: Optional[Any] = None,
+        evaluator: Optional[Any] = None,
+        run_name: Optional[str] = None,
+        log_wandb: bool = False,
+    ) -> None:
"""
|
||||
Parameters
|
||||
----------
|
||||
model: str
|
||||
The model to use for finetuning.
|
||||
trainset: QADataset
|
||||
The training dataset.
|
||||
valset: Optional[QADataset]
|
||||
The validation dataset.
|
||||
path: Optional[str]
|
||||
The path to save the model.
|
||||
batch_size: int, default=8
|
||||
The batch size.
|
||||
epochs: int, default=1
|
||||
The number of epochs.
|
||||
show_progress: bool, default=True
|
||||
Whether to show progress.
|
||||
eval_steps: int, default=50
|
||||
The number of steps to evaluate.
|
||||
max_input_per_doc: int, default=-1
|
||||
The number of input per document.
|
||||
if -1, use all documents.
|
||||
"""
|
||||
+        from sentence_transformers import InputExample, losses
+        from sentence_transformers.evaluation import InformationRetrievalEvaluator
+        from torch.utils.data import DataLoader
+
+        self.model = model
+        self.trainset = trainset
+        self.valset = valset
+        self.path = path
+        self.batch_size = batch_size
+        self.epochs = epochs
+        self.show_progress = show_progress
+        self.eval_steps = eval_steps
+        self.max_input_per_doc = max_input_per_doc
+        self.evaluator = None
+        self.run_name = run_name
+        self.log_wandb = log_wandb
+
+        if self.max_input_per_doc < -1 or self.max_input_per_doc == 0:
+            raise ValueError("max_input_per_doc must be -1 or greater than 0.")
+
+        examples: Any = []
+        for query_id, query in self.trainset.queries.items():
+            if max_input_per_doc == -1:
+                node_ids = self.trainset.relevant_docs[query_id]
+            else:
+                # take at most max_input_per_doc relevant docs for this query
+                node_ids = self.trainset.relevant_docs[query_id][:max_input_per_doc]
+            for node_id in node_ids:
+                text = self.trainset.corpus[node_id]
+                example = InputExample(texts=[query, text])
+                examples.append(example)
+
+        self.examples = examples
+
+        self.loader: DataLoader = DataLoader(examples, batch_size=batch_size)
+
+        if self.valset is not None:
+            eval_engine = evaluator or InformationRetrievalEvaluator
+            self.evaluator = eval_engine(
+                valset.queries, valset.corpus, valset.relevant_docs
+            )
+
+        # define loss
+        self.loss = loss or losses.MultipleNegativesRankingLoss(self.model)
+        self.warmup_steps = int(len(self.loader) * epochs * 0.1)
+
+    def finetune(self) -> None:
+        """Finetune the Sentence Transformers model."""
+        self.model.fit(
+            train_objectives=[(self.loader, self.loss)],
+            epochs=self.epochs,
+            warmup_steps=self.warmup_steps,
+            output_path=self.path,
+            show_progress_bar=self.show_progress,
+            evaluator=self.evaluator,
+            evaluation_steps=self.eval_steps,
+            callback=self._wandb_callback if self.log_wandb else None,
+        )
+
+        self.helper()
+
+    def helper(self) -> None:
+        """Log post-finetuning usage instructions."""
+        logging.info("Finetuning complete.")
+        logging.info(f"Model saved to {self.path}.")
+        logging.info("You can now use the model as follows:")
+        logging.info(
+            f"model = get_registry().get('sentence-transformers').create(name='./{self.path}')"  # noqa
+        )
+
+    def _wandb_callback(self, score, epoch, steps):
+        try:
+            import wandb
+        except ImportError:
+            raise ImportError(
+                "wandb is not installed. Please install it using `pip install wandb`"
+            )
+        run = wandb.run or wandb.init(
+            project="sbert_lancedb_finetune", name=self.run_name
+        )
+        run.log({"epoch": epoch, "steps": steps, "score": score})
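For readers unfamiliar with the underlying sentence-transformers API, the tuner above boils down to the standard retrieval-finetuning recipe: one InputExample per (query, relevant passage) pair, MultipleNegativesRankingLoss with in-batch negatives, and an InformationRetrievalEvaluator on the validation split. A self-contained sketch of that recipe under those assumptions — the toy queries/corpus below are invented for illustration, in the same shape as a QADataset:

    from sentence_transformers import InputExample, SentenceTransformer, losses
    from sentence_transformers.evaluation import InformationRetrievalEvaluator
    from torch.utils.data import DataLoader

    # Toy data: query ids -> text, doc ids -> text, and a relevance map.
    queries = {"q1": "What is LanceDB?"}
    corpus = {"d1": "LanceDB is an open-source vector database."}
    relevant_docs = {"q1": ["d1"]}

    model = SentenceTransformer("BAAI/bge-small-en-v1.5")

    # One positive (query, passage) pair per relevance entry.
    examples = [
        InputExample(texts=[queries[qid], corpus[did]])
        for qid, dids in relevant_docs.items()
        for did in dids
    ]
    loader = DataLoader(examples, batch_size=8)
    loss = losses.MultipleNegativesRankingLoss(model)  # other pairs in the batch serve as negatives
    evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs)

    model.fit(
        train_objectives=[(loader, loss)],
        epochs=1,
        warmup_steps=int(len(loader) * 0.1),  # 10% warmup, mirroring the tuner
        evaluator=evaluator,
        evaluation_steps=50,
        output_path="./tuned-model",
    )

MultipleNegativesRankingLoss is a natural default here because a QADataset records only positive pairs; it needs no explicitly labeled negatives.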