mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-26 06:39:57 +00:00
feat: allow distance type (metric) to be specified during hybrid search (#1777)
This commit is contained in:
@@ -991,13 +991,10 @@ def test_count_rows(db):
|
||||
assert table.count_rows(filter="text='bar'") == 1
|
||||
|
||||
|
||||
def test_hybrid_search(db, tmp_path):
|
||||
# This test uses an FTS index
|
||||
pytest.importorskip("lancedb.fts")
|
||||
|
||||
def setup_hybrid_search_table(tmp_path, embedding_func):
|
||||
db = MockDB(str(tmp_path))
|
||||
# Create a LanceDB table schema with a vector and a text column
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get("test")()
|
||||
emb = EmbeddingFunctionRegistry.get_instance().get(embedding_func)()
|
||||
|
||||
class MyTable(LanceModel):
|
||||
text: str = emb.SourceField()
|
||||
@@ -1030,6 +1027,15 @@ def test_hybrid_search(db, tmp_path):
|
||||
# Create an FTS index
|
||||
table.create_fts_index("text")
|
||||
|
||||
return table, MyTable, emb
|
||||
|
||||
|
||||
def test_hybrid_search(tmp_path):
|
||||
# This test uses an FTS index
|
||||
pytest.importorskip("lancedb.fts")
|
||||
|
||||
table, MyTable, emb = setup_hybrid_search_table(tmp_path, "test")
|
||||
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score")
|
||||
@@ -1094,6 +1100,24 @@ def test_hybrid_search(db, tmp_path):
|
||||
table.search(query_type="hybrid").text("Arrrrggghhhhhhh").to_list()
|
||||
|
||||
|
||||
def test_hybrid_search_metric_type(db, tmp_path):
|
||||
# This test uses an FTS index
|
||||
pytest.importorskip("lancedb.fts")
|
||||
|
||||
# Use the "nonnorm" embedding function so that L2 and dot results differ
|
||||
# (on unit-length vectors both metrics would produce the same ranking)
|
||||
table, _, _ = setup_hybrid_search_table(tmp_path, "nonnorm")
|
||||
|
||||
# Search once with a custom metric (dot) and once without specifying a metric
|
||||
result_dot = (
|
||||
table.search("feeling lucky", query_type="hybrid").metric("dot").to_arrow()
|
||||
)
|
||||
result_l2 = table.search("feeling lucky", query_type="hybrid").to_arrow()
|
||||
assert len(result_dot) > 0
|
||||
assert len(result_l2) > 0
|
||||
assert result_dot["_relevance_score"] != result_l2["_relevance_score"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"consistency_interval", [None, timedelta(seconds=0), timedelta(seconds=0.1)]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user