update benchmark script

ayush chaurasia
2024-04-16 11:13:52 +05:30
parent df404b726e
commit c7bb919561


@@ -1,3 +1,4 @@
import argparse
import os
from llama_index.core import SimpleDirectoryReader
from llama_index.core.llama_dataset import LabelledRagDataset
@@ -206,17 +207,63 @@ def banchmark_all():
if __name__ == "__main__":
import argparse
"""
Benchmark the fine-tuning process for a given dataset and model.
Usage:
- For a single dataset:
python lancedb/docs/benchmarks/llama-index-datasets.py --dataset Uber10KDataset2021 --model BAAI/bge-small-en-v1.5 --epochs 4 --top_k 5 --split 1 --project lancedb_finetune
- For all datasets and models across all top_ks:
python lancedb/docs/benchmarks/llama-index-datasets.py --benchmark-all
"""
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="BraintrustCodaHelpDeskDataset")
parser.add_argument("--model", type=str, default="BAAI/bge-small-en-v1.5")
parser.add_argument("--epochs", type=int, default=4)
parser.add_argument("--project", type=str, default="lancedb_finetune")
parser.add_argument("--top_k", type=int, default=5)
parser.add_argument("--split", type=int, default=1)
parser.add_argument("--eval-rag", action="store_true", default=False)
parser.add_argument("--benchmark-all", action="store_true")
parser.add_argument(
"--dataset",
type=str,
default="BraintrustCodaHelpDeskDataset",
help="The dataset to use for fine-tuning",
)
parser.add_argument(
"--model",
type=str,
default="BAAI/bge-small-en-v1.5",
help="The model to use for fine-tuning",
)
parser.add_argument(
"--epochs",
type=int,
default=4,
help="The number of epochs to fine-tune the model",
)
parser.add_argument(
"--project",
type=str,
default="lancedb_finetune",
help="The wandb project to log the results",
)
parser.add_argument(
"--top_k", type=int, default=5, help="The number of top results to evaluate"
)
parser.add_argument(
"--split",
type=int,
default=1,
help="Whether to split the dataset into train and test(65-35 split), default is 1",
)
parser.add_argument(
"--eval-rag",
action="store_true",
default=False,
help="Whether to evaluate the model using RAG",
)
parser.add_argument(
"--benchmark-all",
action="store_true",
default=False,
help="Benchmark all datasets across all models and top_ks",
)
args = parser.parse_args()
if args.benchmark_all:
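
The hunk ends at the --benchmark-all dispatch. As a quick, self-contained illustration (not part of the committed script), the sketch below reproduces the parser defined above and shows how the defaults resolve and how dashed flags such as --benchmark-all and --eval-rag surface as args.benchmark_all and args.eval_rag, which is what the dispatch relies on:

import argparse

# Illustrative reproduction of the parser from this diff; run standalone to
# inspect how arguments resolve. Argument names and defaults match the hunk.
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="BraintrustCodaHelpDeskDataset")
parser.add_argument("--model", type=str, default="BAAI/bge-small-en-v1.5")
parser.add_argument("--epochs", type=int, default=4)
parser.add_argument("--project", type=str, default="lancedb_finetune")
parser.add_argument("--top_k", type=int, default=5)
parser.add_argument("--split", type=int, default=1)
parser.add_argument("--eval-rag", action="store_true", default=False)
parser.add_argument("--benchmark-all", action="store_true", default=False)

# Parse a sample command line; argparse converts dashes in flag names to
# underscores on the resulting namespace.
args = parser.parse_args(["--dataset", "Uber10KDataset2021", "--top_k", "10"])
print(args.dataset, args.model, args.top_k, args.benchmark_all)
# -> Uber10KDataset2021 BAAI/bge-small-en-v1.5 10 False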