From ff00a3242c7b66ad3da9772d260f710f3454cd09 Mon Sep 17 00:00:00 2001
From: ayush chaurasia
Date: Mon, 15 Apr 2024 07:52:04 +0530
Subject: [PATCH] update

---
 python/python/lancedb/embeddings/base.py | 95 ++++++++++++++++++++++++
 1 file changed, 95 insertions(+)

diff --git a/python/python/lancedb/embeddings/base.py b/python/python/lancedb/embeddings/base.py
index 3d940810..1dd84b0e 100644
--- a/python/python/lancedb/embeddings/base.py
+++ b/python/python/lancedb/embeddings/base.py
@@ -10,13 +10,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from abc import ABC, abstractmethod
 from typing import List, Union
 
 import numpy as np
 import pyarrow as pa
 from pydantic import BaseModel, Field, PrivateAttr
+from tqdm import tqdm
 
+import lancedb
+
+from .fine_tuner import QADataset
 from .utils import TEXT, retry_with_exponential_backoff
 
 
@@ -126,6 +131,36 @@ class EmbeddingFunction(BaseModel, ABC):
     def __hash__(self) -> int:
         return hash(frozenset(vars(self).items()))
 
+    def finetune(self, dataset: QADataset, *args, **kwargs):
+        """
+        Finetune the embedding function on a dataset.
+
+        Subclasses may override this method to add finetuning support.
+
+        Raises
+        ------
+        NotImplementedError
+            Always, in this base implementation.
+        """
+        raise NotImplementedError(
+            "Finetuning is not supported for this embedding function"
+        )
+
+    def evaluate(self, dataset: QADataset, top_k=5, path=None, *args, **kwargs):
+        """
+        Evaluate the embedding function on a dataset.
+
+        Subclasses may override this method to add evaluation support.
+
+        Raises
+        ------
+        NotImplementedError
+            Always, in this base implementation.
+        """
+        raise NotImplementedError(
+            "Evaluation is not supported for this embedding function"
+        )
+
 
 class EmbeddingFunctionConfig(BaseModel):
     """
@@ -159,3 +194,63 @@ class TextEmbeddingFunction(EmbeddingFunction):
         Generate the embeddings for the given texts
         """
         pass
+
+    def evaluate(self, dataset: QADataset, top_k=5, path=None, *args, **kwargs):
+        """
+        Evaluate the embedding function on a dataset. For each query this checks
+        whether the expected document is among the top-k retrieved documents, so a
+        hit-rate can be derived from the results. Assumes that the first relevant
+        document is the expected document.
+        Pro - Should work for any embedding model
+        Con - Returns a very simple metric.
+
+        Parameters
+        ----------
+        dataset: QADataset
+            The dataset to evaluate on
+        top_k: int, default 5
+            Number of documents to retrieve for each query
+        path: str, optional
+            Location of the LanceDB database that holds the evaluation table.
+            Defaults to an "eval" directory under the current working directory.
+
+        Returns
+        -------
+        list of dict
+            One entry per query with the keys "is_hit", "retrieved", "expected"
+            and "query"
+        """
+        corpus = dataset.corpus
+        queries = dataset.queries
+        relevant_docs = dataset.relevant_docs
+        path = path or os.path.join(os.getcwd(), "eval")
+        db = lancedb.connect(path)
+
+        class Schema(lancedb.pydantic.LanceModel):
+            id: str
+            text: str = self.SourceField()
+            vector: lancedb.pydantic.Vector(self.ndims()) = self.VectorField()
+
+        # NOTE: mode="overwrite" replaces any existing "eval" table at this path
+        retriever = db.create_table("eval", schema=Schema, mode="overwrite")
+        pylist = [{"id": str(k), "text": v} for k, v in corpus.items()]
+        retriever.add(pylist)
+
+        eval_results = []
+        for query_id, query in tqdm(queries.items()):
+            retrieved_nodes = retriever.search(query).limit(top_k).to_list()
+            retrieved_ids = [node["id"] for node in retrieved_nodes]
+            # corpus ids are stored as str(k) above, so normalise the expected id
+            # the same way; otherwise non-string ids could never match
+            expected_id = str(relevant_docs[query_id][0])
+            is_hit = expected_id in retrieved_ids  # assume 1 relevant doc
+
+            eval_result = {
+                "is_hit": is_hit,
+                "retrieved": retrieved_ids,
+                "expected": expected_id,
+                "query": query_id,
+            }
+            eval_results.append(eval_result)
+
+        return eval_results