Specify and Index Column for Vector Search (#217)

This commit is contained in:
Philip Kung
2023-06-26 16:11:08 -07:00
committed by GitHub
parent e850df56f1
commit 313e66c4c5
4 changed files with 97 additions and 12 deletions

View File

@@ -11,6 +11,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest.mock as mock
import lance
import numpy as np
import pandas as pd
@@ -20,6 +22,7 @@ import pytest
from lancedb.db import LanceDBConnection
from lancedb.query import LanceQueryBuilder
from lancedb.table import LanceTable
class MockTable:
@@ -48,24 +51,30 @@ def table(tmp_path) -> MockTable:
def test_query_builder(table):
df = LanceQueryBuilder(table, [0, 0]).limit(1).select(["id"]).to_df()
df = LanceQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
assert df["id"].values[0] == 1
assert all(df["vector"].values[0] == [1, 2])
def test_query_builder_with_filter(table):
df = LanceQueryBuilder(table, [0, 0]).where("id = 2").to_df()
df = LanceQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
assert df["id"].values[0] == 2
assert all(df["vector"].values[0] == [3, 4])
def test_query_builder_with_metric(table):
query = [4, 8]
df_default = LanceQueryBuilder(table, query).to_df()
df_l2 = LanceQueryBuilder(table, query).metric("L2").to_df()
vector_column_name = "vector"
df_default = LanceQueryBuilder(table, query, vector_column_name).to_df()
df_l2 = LanceQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
tm.assert_frame_equal(df_default, df_l2)
df_cosine = LanceQueryBuilder(table, query).metric("cosine").limit(1).to_df()
df_cosine = (
LanceQueryBuilder(table, query, vector_column_name)
.metric("cosine")
.limit(1)
.to_df()
)
assert df_cosine.score[0] == pytest.approx(
cosine_distance(query, df_cosine.vector[0]),
abs=1e-6,
@@ -73,5 +82,35 @@ def test_query_builder_with_metric(table):
assert 0 <= df_cosine.score[0] <= 1
def test_query_builder_with_different_vector_column():
table = mock.MagicMock(spec=LanceTable)
query = [4, 8]
vector_column_name = "foo_vector"
builder = (
LanceQueryBuilder(table, query, vector_column_name)
.metric("cosine")
.where("b < 10")
.select(["b"])
.limit(2)
)
ds = mock.Mock()
table.to_lance.return_value = ds
table._conn = mock.MagicMock()
table._conn.is_managed_remote = False
builder.to_arrow()
ds.to_table.assert_called_once_with(
columns=["b"],
filter="b < 10",
nearest={
"column": vector_column_name,
"q": query,
"k": 2,
"metric": "cosine",
"nprobes": 20,
"refine_factor": None,
},
)
def cosine_distance(vec1, vec2):
return 1 - np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))