diff --git a/python/lancedb/query.py b/python/lancedb/query.py index 21333bec..949e5ef7 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -24,6 +24,7 @@ class LanceQueryBuilder: """ def __init__(self, table: "lancedb.table.LanceTable", query: np.ndarray): + self._metric = "l2" self._nprobes = 20 self._refine_factor = None self._table = table @@ -77,6 +78,21 @@ class LanceQueryBuilder: self._where = where return self + def metric(self, metric: str) -> LanceQueryBuilder: + """Set the distance metric to use. + + Parameters + ---------- + metric: str + The distance metric to use. By default "l2" is used. + + Returns + ------- + The LanceQueryBuilder object. + """ + self._metric = metric + return self + def nprobes(self, nprobes: int) -> LanceQueryBuilder: """Set the number of probes to use. @@ -118,6 +134,7 @@ class LanceQueryBuilder: "column": VECTOR_COLUMN_NAME, "q": self._query, "k": self._limit, + "metric": self._metric, "nprobes": self._nprobes, "refine_factor": self._refine_factor, }, diff --git a/python/tests/test_query.py b/python/tests/test_query.py index c08cdd8f..ae1bebda 100644 --- a/python/tests/test_query.py +++ b/python/tests/test_query.py @@ -14,7 +14,9 @@ import lance from lancedb.query import LanceQueryBuilder +import numpy as np import pandas as pd +import pandas.testing as tm import pyarrow as pa import pytest @@ -60,3 +62,20 @@ def test_query_builder_with_filter(table): df = LanceQueryBuilder(table, [0, 0]).where("id = 2").to_df() assert df["id"].values[0] == 2 assert all(df["vector"].values[0] == [3, 4]) + + +def test_query_builder_with_metric(table): + query = [4, 8] + df_default = LanceQueryBuilder(table, query).to_df() + df_l2 = LanceQueryBuilder(table, query).metric("l2").to_df() + tm.assert_frame_equal(df_default, df_l2) + + df_cosine = LanceQueryBuilder(table, query).metric("cosine").limit(1).to_df() + assert df_cosine.score[0] == pytest.approx( + cosine_distance(query, df_cosine.vector[0]) + ) + assert 0 <= df_cosine.score[0] <= 1 + + +def cosine_distance(vec1, vec2): + return 1 - np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))