mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 22:59:57 +00:00
update
This commit is contained in:
@@ -10,13 +10,18 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from tqdm import tqdm
|
||||
|
||||
import lancedb
|
||||
|
||||
from .fine_tuner import QADataset
|
||||
from .utils import TEXT, retry_with_exponential_backoff
|
||||
|
||||
|
||||
@@ -126,6 +131,22 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
def __hash__(self) -> int:
    """Hash the instance from the set of its (attribute name, value) pairs."""
    # A frozenset makes the result independent of attribute order; every
    # attribute value must itself be hashable or this raises TypeError.
    attribute_pairs = frozenset(vars(self).items())
    return hash(attribute_pairs)
|
||||
|
||||
def finetune(self, dataset: QADataset, *args, **kwargs):
    """
    Finetune the embedding function on a dataset.

    This implementation performs no work and unconditionally raises.
    NOTE(review): presumably subclasses that support finetuning override
    this method - confirm against the concrete embedding functions.

    Parameters
    ----------
    dataset: QADataset
        The dataset to finetune on (unused here).

    Raises
    ------
    NotImplementedError
        Always.
    """
    unsupported_message = "Finetuning is not supported for this embedding function"
    raise NotImplementedError(unsupported_message)
|
||||
|
||||
def evaluate(self, dataset: QADataset, top_k=5, path=None, *args, **kwargs):
    """
    Evaluate the embedding function on a dataset.

    This implementation performs no work and unconditionally raises.
    NOTE(review): presumably subclasses that support evaluation override
    this method - confirm against the concrete embedding functions.

    Parameters
    ----------
    dataset: QADataset
        The dataset to evaluate on (unused here).
    top_k: int, default 5
        Unused here.
    path: optional
        Unused here.

    Raises
    ------
    NotImplementedError
        Always.
    """
    unsupported_message = "Evaluation is not supported for this embedding function"
    raise NotImplementedError(unsupported_message)
|
||||
|
||||
|
||||
class EmbeddingFunctionConfig(BaseModel):
|
||||
"""
|
||||
@@ -159,3 +180,52 @@ class TextEmbeddingFunction(EmbeddingFunction):
|
||||
Generate the embeddings for the given texts
|
||||
"""
|
||||
pass
|
||||
|
||||
def evaluate(self, dataset: QADataset, top_k=5, path=None, *args, **kwargs):
    """
    Run a top-k retrieval check of this embedding function over a dataset.

    The corpus is written to a LanceDB table named "eval" (any existing
    table of that name at ``path`` is overwritten), each query is searched
    against it, and a query counts as a hit when its expected document id
    is among the ``top_k`` ids returned. Only the first entry of
    ``dataset.relevant_docs[query_id]`` is treated as the expected document.

    Pro - Should work for any embedding model
    Con - Returns very simple metrics.

    Parameters
    ----------
    dataset: QADataset
        The dataset to evaluate on. Its ``corpus``, ``queries`` and
        ``relevant_docs`` attributes are read.
    top_k: int, default 5
        Number of results requested per query.
    path: str, optional
        Location passed to ``lancedb.connect``. Defaults to an "eval"
        directory under the current working directory.

    Returns
    -------
    list of dict
        One record per query with the keys "is_hit", "retrieved",
        "expected" and "query" (the query id). The hit rate is the mean
        of "is_hit" over the records.
    """
    corpus = dataset.corpus
    queries = dataset.queries
    relevant_docs = dataset.relevant_docs

    if not path:
        path = os.path.join(os.getcwd(), "eval")
    db = lancedb.connect(path)

    # NOTE(review): SourceField/VectorField presumably make the table embed
    # ``text`` on insert and embed query strings on search - confirm.
    class Schema(lancedb.pydantic.LanceModel):
        id: str
        text: str = self.SourceField()
        vector: lancedb.pydantic.Vector(self.ndims()) = self.VectorField()

    retriever = db.create_table("eval", schema=Schema, mode="overwrite")
    # Ids are stored as strings because the schema declares ``id: str``.
    retriever.add([{"id": str(k), "text": v} for k, v in corpus.items()])

    eval_results = []
    for query_id, query in tqdm(queries.items()):
        retrieved_nodes = retriever.search(query).limit(top_k).to_list()
        retrieved_ids = [node["id"] for node in retrieved_nodes]
        expected_id = relevant_docs[query_id][0]  # assume 1 relevant doc
        # Corpus ids were stringified on insert, so compare the expected id
        # in the same form; otherwise a non-string id could never match.
        is_hit = str(expected_id) in retrieved_ids

        eval_results.append(
            {
                "is_hit": is_hit,
                "retrieved": retrieved_ids,
                "expected": expected_id,
                "query": query_id,
            }
        )

    return eval_results
|
||||
|
||||
Reference in New Issue
Block a user