This commit is contained in:
ayush chaurasia
2024-04-15 17:07:29 +05:30
parent 3ca96a852f
commit d0c1113417

View File

@@ -0,0 +1,45 @@
import uuid
import pytest
from lancedb.embeddings import get_registry
from lancedb.embeddings.fine_tuner import QADataset, TextChunk
from tqdm import tqdm
@pytest.mark.slow
def test_finetuning_sentence_transformers(tmp_path):
queries = {}
relevant_docs = {}
chunks = [
"This is a chunk related to legal docs",
"This is another chunk related financial docs",
"This is a chunk related to sports docs",
"This is another chunk related to fashion docs",
]
text_chunks = [TextChunk.from_chunk(chunk) for chunk in chunks]
for chunk in tqdm(text_chunks):
questions = [
"What is this chunk about?",
"What is the main topic of this chunk?",
]
for question in questions:
question_id = str(uuid.uuid4())
queries[question_id] = question
relevant_docs[question_id] = [chunk.id]
ds = QADataset.from_responses(text_chunks, queries, relevant_docs)
assert len(ds.queries) == 8
assert len(ds.corpus) == 4
model = get_registry().get("sentence-transformers").create()
model.finetune(trainset=ds, valset=ds, path=str(tmp_path / "model"), epochs=1)
model = (
get_registry().get("sentence-transformers").create(name=str(tmp_path / "model"))
)
res = model.evaluate(ds)
assert res is not None
def test_text_chunk():
# TODO
pass