mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 10:52:56 +00:00
update
This commit is contained in:
@@ -5,7 +5,6 @@ from llama_index.core.node_parser import SentenceSplitter
|
||||
from lancedb.embeddings.fine_tuner.dataset import QADataset, TextChunk
|
||||
from lancedb.embeddings.fine_tuner.llm import Openai
|
||||
from lancedb.embeddings import get_registry
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.vector_stores.lancedb import LanceDBVectorStore
|
||||
from llama_index.core import ServiceContext, VectorStoreIndex, StorageContext
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
@@ -40,7 +39,7 @@ def get_llama_dataset(dataset: str):
|
||||
"""
|
||||
if not os.path.exists(f"./data/{dataset}"):
|
||||
os.system(
|
||||
f"llamaindex-cli download-llamadataset {dataset} --download-dir ./data/{dataset}"
|
||||
f"llamaindex-cli download-llamadataset {dataset} --download-dir ./data/{dataset}" # noqa
|
||||
)
|
||||
rag_dataset = LabelledRagDataset.from_json(f"./data/{dataset}/rag_dataset.json")
|
||||
docs = SimpleDirectoryReader(input_dir=f"./data/{dataset}/source_files").load_data()
|
||||
@@ -89,7 +88,7 @@ def finetune(
|
||||
|
||||
valset = valset if valset is not None else trainset
|
||||
model = get_registry().get("sentence-transformers").create(name=model)
|
||||
base_result = model.evaluate(valset, path=f"./data/eval/", top_k=top_k)
|
||||
base_result = model.evaluate(valset, path="./data/eval/", top_k=top_k)
|
||||
base_hit_rate = pd.DataFrame(base_result)["is_hit"].mean()
|
||||
|
||||
model.finetune(trainset=trainset, valset=valset, path=path, epochs=epochs)
|
||||
|
||||
Reference in New Issue
Block a user