mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
feat: allow distance type (metric) to be specified during hybrid search (#1777)
This commit is contained in:
@@ -26,7 +26,7 @@ registry = EmbeddingFunctionRegistry.get_instance()
|
||||
@registry.register("test")
|
||||
class MockTextEmbeddingFunction(TextEmbeddingFunction):
|
||||
"""
|
||||
Return the hash of the first 10 characters
|
||||
Return the hash of the first 10 characters (normalized)
|
||||
"""
|
||||
|
||||
def generate_embeddings(self, texts):
|
||||
@@ -41,6 +41,23 @@ class MockTextEmbeddingFunction(TextEmbeddingFunction):
|
||||
return 10
|
||||
|
||||
|
||||
@registry.register("nonnorm")
class MockNonNormTextEmbeddingFunction(TextEmbeddingFunction):
    """
    Mock embedding function whose vectors are NOT normalized.

    Each text is mapped to the code points (``ord``) of its first 10
    characters, as floats. Texts with fewer than 10 characters map to a
    list of ten zeros instead.
    """

    def generate_embeddings(self, texts):
        # One embedding per input text, in input order.
        return list(map(self._compute_one_embedding, texts))

    def _compute_one_embedding(self, row):
        code_points = np.array([float(ord(ch)) for ch in row[:10]])
        if len(code_points) != 10:
            # Too short to fill all dimensions: fall back to a zero vector.
            return [0] * 10
        return code_points

    def ndims(self):
        # Matches the 10-character window used in _compute_one_embedding.
        return 10
|
||||
|
||||
|
||||
class RateLimitedAPI:
|
||||
rate_limit = 0.1 # 1 request per 0.1 second
|
||||
last_request_time = 0
|
||||
|
||||
@@ -983,6 +983,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._reranker = RRFReranker()
|
||||
self._nprobes = None
|
||||
self._refine_factor = None
|
||||
self._metric = None
|
||||
self._phrase_query = False
|
||||
|
||||
def _validate_query(self, query, vector=None, text=None):
|
||||
@@ -1050,6 +1051,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._fts_query.with_row_id(True)
|
||||
if self._phrase_query:
|
||||
self._fts_query.phrase_query(True)
|
||||
if self._metric:
|
||||
self._vector_query.metric(self._metric)
|
||||
if self._nprobes:
|
||||
self._vector_query.nprobes(self._nprobes)
|
||||
if self._refine_factor:
|
||||
@@ -1067,6 +1070,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
if self._norm == "rank":
|
||||
vector_results = self._rank(vector_results, "_distance")
|
||||
fts_results = self._rank(fts_results, "_score")
|
||||
|
||||
# normalize the scores to be between 0 and 1, 0 being most relevant
|
||||
vector_results = self._normalize_scores(vector_results, "_distance")
|
||||
|
||||
@@ -1115,7 +1119,9 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
rng = max
|
||||
else:
|
||||
rng = max - min
|
||||
scores = (scores - min) / rng
|
||||
# If rng is 0 then min and max are both 0 and so we can leave the scores as is
|
||||
if rng != 0:
|
||||
scores = (scores - min) / rng
|
||||
if invert:
|
||||
scores = 1 - scores
|
||||
# replace the _score column with the ranks
|
||||
@@ -1177,6 +1183,22 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
self._nprobes = nprobes
|
||||
return self
|
||||
|
||||
def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceHybridQueryBuilder:
    """Set the distance metric to use.

    The metric name is lower-cased and stored on the builder; it is
    forwarded to the vector query when the hybrid query is executed.

    Parameters
    ----------
    metric: "L2" or "cosine" or "dot"
        The distance metric to use. By default "L2" is used.

    Returns
    -------
    LanceHybridQueryBuilder
        The LanceHybridQueryBuilder object.
    """
    # Stored lower-cased so "L2" and "l2" are treated the same downstream.
    self._metric = metric.lower()
    return self
|
||||
|
||||
def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder:
|
||||
"""
|
||||
Refine the vector search results by reading extra elements and
|
||||
|
||||
@@ -991,13 +991,10 @@ def test_count_rows(db):
|
||||
assert table.count_rows(filter="text='bar'") == 1
|
||||
|
||||
|
||||
def test_hybrid_search(db, tmp_path):
|
||||
# This test uses an FTS index
|
||||
pytest.importorskip("lancedb.fts")
|
||||
|
||||
def setup_hybrid_search_table(tmp_path, embedding_func):
|
||||
db = MockDB(str(tmp_path))
|
||||
# Create a LanceDB table schema with a vector and a text column
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func)()
|
||||
|
||||
class MyTable(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
@@ -1030,6 +1027,15 @@ def test_hybrid_search(db, tmp_path):
|
||||
# Create a fts index
|
||||
table.create_fts_index("text")
|
||||
|
||||
return table, MyTable, emb
|
||||
|
||||
|
||||
def test_hybrid_search(tmp_path):
|
||||
# This test uses an FTS index
|
||||
pytest.importorskip("lancedb.fts")
|
||||
|
||||
table, MyTable, emb = setup_hybrid_search_table(tmp_path, "test")
|
||||
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score")
|
||||
@@ -1094,6 +1100,24 @@ def test_hybrid_search(db, tmp_path):
|
||||
table.search(query_type="hybrid").text("Arrrrggghhhhhhh").to_list()
|
||||
|
||||
|
||||
def test_hybrid_search_metric_type(db, tmp_path):
    """Hybrid search with the "dot" metric must score differently than the default."""
    # The hybrid query needs an FTS index, so skip when FTS is unavailable.
    pytest.importorskip("lancedb.fts")

    # The "nonnorm" embedding function produces unnormalized vectors, which
    # is what makes L2 and dot results differ.
    table, _model, _emb = setup_hybrid_search_table(tmp_path, "nonnorm")

    # Query once with a custom metric ...
    dot_query = table.search("feeling lucky", query_type="hybrid").metric("dot")
    result_dot = dot_query.to_arrow()

    # ... and once with the default metric.
    default_query = table.search("feeling lucky", query_type="hybrid")
    result_l2 = default_query.to_arrow()

    for result in (result_dot, result_l2):
        assert len(result) > 0
    assert result_dot["_relevance_score"] != result_l2["_relevance_score"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user