mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-07 12:22:59 +00:00
Specify and Index Column for Vector Search (#217)
This commit is contained in:
@@ -11,6 +11,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest.mock as mock
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -20,6 +22,7 @@ import pytest
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.query import LanceQueryBuilder
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
class MockTable:
|
||||
@@ -48,24 +51,30 @@ def table(tmp_path) -> MockTable:
|
||||
|
||||
|
||||
def test_query_builder(table):
|
||||
df = LanceQueryBuilder(table, [0, 0]).limit(1).select(["id"]).to_df()
|
||||
df = LanceQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
|
||||
assert df["id"].values[0] == 1
|
||||
assert all(df["vector"].values[0] == [1, 2])
|
||||
|
||||
|
||||
def test_query_builder_with_filter(table):
|
||||
df = LanceQueryBuilder(table, [0, 0]).where("id = 2").to_df()
|
||||
df = LanceQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
|
||||
assert df["id"].values[0] == 2
|
||||
assert all(df["vector"].values[0] == [3, 4])
|
||||
|
||||
|
||||
def test_query_builder_with_metric(table):
|
||||
query = [4, 8]
|
||||
df_default = LanceQueryBuilder(table, query).to_df()
|
||||
df_l2 = LanceQueryBuilder(table, query).metric("L2").to_df()
|
||||
vector_column_name = "vector"
|
||||
df_default = LanceQueryBuilder(table, query, vector_column_name).to_df()
|
||||
df_l2 = LanceQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
|
||||
tm.assert_frame_equal(df_default, df_l2)
|
||||
|
||||
df_cosine = LanceQueryBuilder(table, query).metric("cosine").limit(1).to_df()
|
||||
df_cosine = (
|
||||
LanceQueryBuilder(table, query, vector_column_name)
|
||||
.metric("cosine")
|
||||
.limit(1)
|
||||
.to_df()
|
||||
)
|
||||
assert df_cosine.score[0] == pytest.approx(
|
||||
cosine_distance(query, df_cosine.vector[0]),
|
||||
abs=1e-6,
|
||||
@@ -73,5 +82,35 @@ def test_query_builder_with_metric(table):
|
||||
assert 0 <= df_cosine.score[0] <= 1
|
||||
|
||||
|
||||
def test_query_builder_with_different_vector_column():
|
||||
table = mock.MagicMock(spec=LanceTable)
|
||||
query = [4, 8]
|
||||
vector_column_name = "foo_vector"
|
||||
builder = (
|
||||
LanceQueryBuilder(table, query, vector_column_name)
|
||||
.metric("cosine")
|
||||
.where("b < 10")
|
||||
.select(["b"])
|
||||
.limit(2)
|
||||
)
|
||||
ds = mock.Mock()
|
||||
table.to_lance.return_value = ds
|
||||
table._conn = mock.MagicMock()
|
||||
table._conn.is_managed_remote = False
|
||||
builder.to_arrow()
|
||||
ds.to_table.assert_called_once_with(
|
||||
columns=["b"],
|
||||
filter="b < 10",
|
||||
nearest={
|
||||
"column": vector_column_name,
|
||||
"q": query,
|
||||
"k": 2,
|
||||
"metric": "cosine",
|
||||
"nprobes": 20,
|
||||
"refine_factor": None,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def cosine_distance(vec1, vec2):
|
||||
return 1 - np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
||||
|
||||
Reference in New Issue
Block a user