From d0c11134175261f27af7436b829bdbc78f701104 Mon Sep 17 00:00:00 2001
From: ayush chaurasia
Date: Mon, 15 Apr 2024 17:07:29 +0530
Subject: [PATCH] add test

---
Review notes (free-form text between the "---" separator and the diff;
not part of the commit message and not part of the patch itself):

* Adds python/python/tests/test_embedding_tuner.py containing two tests.
* test_finetuning_sentence_transformers is marked `slow`. It wraps four
  sample strings with TextChunk.from_chunk, registers the same two
  questions for every chunk under fresh uuid4 keys, and builds a
  QADataset from them (asserted: 8 queries, 4 corpus entries). It then
  calls finetune() on the "sentence-transformers" registry model with
  epochs=1, passing the same dataset as trainset and valset and writing
  to tmp_path / "model", creates the model again with name set to that
  path, and only asserts that evaluate() does not return None.
* test_text_chunk is a placeholder: its body is "# TODO" and "pass", so
  it passes without checking anything.
* The added file has no trailing newline (see the marker at the end of
  the hunk).
* NOTE(review): the second sample chunk reads "related financial docs"
  while the other three read "related to ..."; presumably a typo in the
  test data - confirm before changing, since it alters the file content.

 python/python/tests/test_embedding_tuner.py | 45 +++++++++++++++++++++
 1 file changed, 45 insertions(+)
 create mode 100644 python/python/tests/test_embedding_tuner.py

diff --git a/python/python/tests/test_embedding_tuner.py b/python/python/tests/test_embedding_tuner.py
new file mode 100644
index 00000000..c2733acd
--- /dev/null
+++ b/python/python/tests/test_embedding_tuner.py
@@ -0,0 +1,45 @@
+import uuid
+
+import pytest
+from lancedb.embeddings import get_registry
+from lancedb.embeddings.fine_tuner import QADataset, TextChunk
+from tqdm import tqdm
+
+
+@pytest.mark.slow
+def test_finetuning_sentence_transformers(tmp_path):
+    queries = {}
+    relevant_docs = {}
+    chunks = [
+        "This is a chunk related to legal docs",
+        "This is another chunk related financial docs",
+        "This is a chunk related to sports docs",
+        "This is another chunk related to fashion docs",
+    ]
+    text_chunks = [TextChunk.from_chunk(chunk) for chunk in chunks]
+    for chunk in tqdm(text_chunks):
+        questions = [
+            "What is this chunk about?",
+            "What is the main topic of this chunk?",
+        ]
+        for question in questions:
+            question_id = str(uuid.uuid4())
+            queries[question_id] = question
+            relevant_docs[question_id] = [chunk.id]
+    ds = QADataset.from_responses(text_chunks, queries, relevant_docs)
+
+    assert len(ds.queries) == 8
+    assert len(ds.corpus) == 4
+
+    model = get_registry().get("sentence-transformers").create()
+    model.finetune(trainset=ds, valset=ds, path=str(tmp_path / "model"), epochs=1)
+    model = (
+        get_registry().get("sentence-transformers").create(name=str(tmp_path / "model"))
+    )
+    res = model.evaluate(ds)
+    assert res is not None
+
+
+def test_text_chunk():
+    # TODO
+    pass
\ No newline at end of file