feat: allow distance type (metric) to be specified during hybrid search (#1777)

This commit is contained in:
Weston Pace
2024-10-29 13:51:18 -07:00
committed by GitHub
parent d71df4572e
commit 55104c5bae
3 changed files with 70 additions and 7 deletions

View File

@@ -26,7 +26,7 @@ registry = EmbeddingFunctionRegistry.get_instance()
@registry.register("test")
class MockTextEmbeddingFunction(TextEmbeddingFunction):
"""
Return the hash of the first 10 characters
Return the hash of the first 10 characters (normalized)
"""
def generate_embeddings(self, texts):
@@ -41,6 +41,23 @@ class MockTextEmbeddingFunction(TextEmbeddingFunction):
return 10
@registry.register("nonnorm")
class MockNonNormTextEmbeddingFunction(TextEmbeddingFunction):
    """
    Mock embedding function that produces non-normalized vectors.

    Each text is embedded as the code points (``ord``) of its first 10
    characters. Intended for tests that need L2 and dot results to differ.
    """

    def ndims(self):
        # Embeddings always have a fixed width of 10.
        return 10

    def generate_embeddings(self, texts):
        embeddings = []
        for text in texts:
            embeddings.append(self._compute_one_embedding(text))
        return embeddings

    def _compute_one_embedding(self, row):
        codepoints = np.array([float(ord(ch)) for ch in row[:10]])
        if len(codepoints) != 10:
            # Rows shorter than 10 characters fall back to an all-zero vector.
            return [0] * 10
        return codepoints
class RateLimitedAPI:
rate_limit = 0.1 # 1 request per 0.1 second
last_request_time = 0

View File

@@ -983,6 +983,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._reranker = RRFReranker()
self._nprobes = None
self._refine_factor = None
self._metric = None
self._phrase_query = False
def _validate_query(self, query, vector=None, text=None):
@@ -1050,6 +1051,8 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._fts_query.with_row_id(True)
if self._phrase_query:
self._fts_query.phrase_query(True)
if self._metric:
self._vector_query.metric(self._metric)
if self._nprobes:
self._vector_query.nprobes(self._nprobes)
if self._refine_factor:
@@ -1067,6 +1070,7 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
if self._norm == "rank":
vector_results = self._rank(vector_results, "_distance")
fts_results = self._rank(fts_results, "_score")
# normalize the scores to be between 0 and 1, 0 being most relevant
vector_results = self._normalize_scores(vector_results, "_distance")
@@ -1115,7 +1119,9 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
rng = max
else:
rng = max - min
scores = (scores - min) / rng
# If rng is 0 then min and max are both 0 and so we can leave the scores as is
if rng != 0:
scores = (scores - min) / rng
if invert:
scores = 1 - scores
# replace the _score column with the ranks
@@ -1177,6 +1183,22 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
self._nprobes = nprobes
return self
def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceHybridQueryBuilder:
    """Set the distance metric to use.

    Parameters
    ----------
    metric: "L2" or "cosine" or "dot"
        The distance metric to use. By default "L2" is used.
        The name is stored lower-cased.

    Returns
    -------
    LanceHybridQueryBuilder
        This query builder, so that calls can be chained.
    """
    normalized = metric.lower()
    self._metric = normalized
    return self
def refine_factor(self, refine_factor: int) -> LanceHybridQueryBuilder:
"""
Refine the vector search results by reading extra elements and

View File

@@ -991,13 +991,10 @@ def test_count_rows(db):
assert table.count_rows(filter="text='bar'") == 1
def test_hybrid_search(db, tmp_path):
# This test uses an FTS index
pytest.importorskip("lancedb.fts")
def setup_hybrid_search_table(tmp_path, embedding_func):
db = MockDB(str(tmp_path))
# Create a LanceDB table schema with a vector and a text column
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func)()
class MyTable(LanceModel):
text: str = emb.SourceField()
@@ -1030,6 +1027,15 @@ def test_hybrid_search(db, tmp_path):
# Create a fts index
table.create_fts_index("text")
return table, MyTable, emb
def test_hybrid_search(tmp_path):
# This test uses an FTS index
pytest.importorskip("lancedb.fts")
table, MyTable, emb = setup_hybrid_search_table(tmp_path, "test")
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score")
@@ -1094,6 +1100,24 @@ def test_hybrid_search(db, tmp_path):
table.search(query_type="hybrid").text("Arrrrggghhhhhhh").to_list()
def test_hybrid_search_metric_type(db, tmp_path):
    """Check that a custom distance metric changes hybrid search scores."""
    # Hybrid search relies on an FTS index
    pytest.importorskip("lancedb.fts")

    # The "nonnorm" embedding function is required here: its vectors are not
    # normalized, so L2 and dot produce different results
    table, _model, _embedding = setup_hybrid_search_table(tmp_path, "nonnorm")

    # First query with an explicit metric, then with the default one
    dot_builder = table.search("feeling lucky", query_type="hybrid")
    dot_hits = dot_builder.metric("dot").to_arrow()
    default_hits = table.search("feeling lucky", query_type="hybrid").to_arrow()

    assert len(dot_hits) > 0
    assert len(default_hits) > 0
    assert dot_hits["_relevance_score"] != default_hits["_relevance_score"]
@pytest.mark.parametrize(
"consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
)