feat: allow distance type (metric) to be specified during hybrid search (#1777)

This commit is contained in:
Weston Pace
2024-10-29 13:51:18 -07:00
committed by GitHub
parent d71df4572e
commit 55104c5bae
3 changed files with 70 additions and 7 deletions

View File

@@ -991,13 +991,10 @@ def test_count_rows(db):
assert table.count_rows(filter="text='bar'") == 1
def test_hybrid_search(db, tmp_path):
# This test uses an FTS index
pytest.importorskip("lancedb.fts")
def setup_hybrid_search_table(tmp_path, embedding_func):
db = MockDB(str(tmp_path))
# Create a LanceDB table schema with a vector and a text column
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func)()
class MyTable(LanceModel):
text: str = emb.SourceField()
@@ -1030,6 +1027,15 @@ def test_hybrid_search(db, tmp_path):
# Create a fts index
table.create_fts_index("text")
return table, MyTable, emb
def test_hybrid_search(tmp_path):
# This test uses an FTS index
pytest.importorskip("lancedb.fts")
table, MyTable, emb = setup_hybrid_search_table(tmp_path, "test")
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score")
@@ -1094,6 +1100,24 @@ def test_hybrid_search(db, tmp_path):
table.search(query_type="hybrid").text("Arrrrggghhhhhhh").to_list()
def test_hybrid_search_metric_type(db, tmp_path):
    """Check that an explicit metric changes hybrid search relevance scores.

    The same hybrid query is run twice against one table: once with the
    metric set to ``"dot"`` and once with no metric specified. Both runs must
    return at least one row, and their ``_relevance_score`` columns must not
    be equal.

    Note: the ``db`` fixture is requested but not referenced in this body;
    the table comes from ``setup_hybrid_search_table``.
    """
    # Hybrid search relies on an FTS index, so skip when FTS is unavailable.
    pytest.importorskip("lancedb.fts")

    # The "nonnorm" embedding function is required here so that L2 and dot
    # results actually differ from one another.
    table, _, _ = setup_hybrid_search_table(tmp_path, "nonnorm")

    query_text = "feeling lucky"

    # Run with the custom metric first, then with no metric specified.
    dot_query = table.search(query_text, query_type="hybrid").metric("dot")
    scores_with_dot = dot_query.to_arrow()
    scores_default = table.search(query_text, query_type="hybrid").to_arrow()

    assert len(scores_with_dot) > 0
    assert len(scores_default) > 0

    assert scores_with_dot["_relevance_score"] != scores_default["_relevance_score"]
@pytest.mark.parametrize(
"consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
)