This commit is contained in:
ayush chaurasia
2024-04-16 09:24:29 +05:30
parent ea34c0b4c4
commit 6bc488f674

View File

@@ -5,7 +5,6 @@ from llama_index.core.node_parser import SentenceSplitter
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
from lancedb.embeddings.fine_tuner.llm import Openai
from lancedb.embeddings import get_registry
from llama_index.core import VectorStoreIndex
from llama_index.vector_stores.lancedb import LanceDBVectorStore
from llama_index.core import ServiceContext, VectorStoreIndex, StorageContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
@@ -40,7 +39,7 @@ def get_llama_dataset(dataset: str):
"""
if not os.path.exists(f"./data/{dataset}"):
os.system(
f"llamaindex-cli download-llamadataset {dataset} --download-dir ./data/{dataset}"
f"llamaindex-cli download-llamadataset {dataset} --download-dir ./data/{dataset}" # noqa
)
rag_dataset = LabelledRagDataset.from_json(f"./data/{dataset}/rag_dataset.json")
docs = SimpleDirectoryReader(input_dir=f"./data/{dataset}/source_files").load_data()
@@ -89,7 +88,7 @@ def finetune(
valset = valset if valset is not None else trainset
model = get_registry().get("sentence-transformers").create(name=model)
base_result = model.evaluate(valset, path=f"./data/eval/", top_k=top_k)
base_result = model.evaluate(valset, path="./data/eval/", top_k=top_k)
base_hit_rate = pd.DataFrame(base_result)["is_hit"].mean()
model.finetune(trainset=trainset, valset=valset, path=path, epochs=epochs)