mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 02:42:57 +00:00
add test
This commit is contained in:
45
python/python/tests/test_embedding_tuner.py
Normal file
45
python/python/tests/test_embedding_tuner.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from lancedb.embeddings import get_registry
|
||||
from lancedb.embeddings.fine_tuner import QADataset, TextChunk
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
@pytest.mark.slow
def test_finetuning_sentence_transformers(tmp_path):
    """Fine-tune the sentence-transformers embedding on a tiny QA dataset.

    Four text chunks are each paired with the same two questions, giving
    eight queries in total. The model is fine-tuned for a single epoch with
    ``tmp_path / "model"`` passed as the output path, then a new embedding
    function is created using that same path as its ``name`` and evaluated
    on the dataset.
    """
    passages = (
        "This is a chunk related to legal docs",
        "This is another chunk related financial docs",
        "This is a chunk related to sports docs",
        "This is another chunk related to fashion docs",
    )
    # The same two questions are asked about every chunk.
    prompts = (
        "What is this chunk about?",
        "What is the main topic of this chunk?",
    )

    text_chunks = [TextChunk.from_chunk(passage) for passage in passages]

    queries = {}
    relevant_docs = {}
    for text_chunk in tqdm(text_chunks):
        for prompt in prompts:
            # A fresh random id per (chunk, question) pair keeps keys unique.
            query_id = str(uuid.uuid4())
            queries[query_id] = prompt
            relevant_docs[query_id] = [text_chunk.id]

    ds = QADataset.from_responses(text_chunks, queries, relevant_docs)

    # 4 chunks x 2 questions each.
    assert len(ds.queries) == 8
    assert len(ds.corpus) == 4

    model_dir = str(tmp_path / "model")
    sentence_transformers = get_registry().get("sentence-transformers")

    base_model = sentence_transformers.create()
    base_model.finetune(trainset=ds, valset=ds, path=model_dir, epochs=1)

    # Look the entry up again, as a fresh registry query, before creating the
    # embedding function that points at the fine-tuned output path.
    tuned_model = get_registry().get("sentence-transformers").create(name=model_dir)
    res = tuned_model.evaluate(ds)
    assert res is not None
|
||||
|
||||
|
||||
def test_text_chunk():
    """Placeholder test; no assertions are implemented yet."""
    # TODO
    return None
|
||||
Reference in New Issue
Block a user