Specify and Index Column for Vector Search (#217)

This commit is contained in:
Philip Kung
2023-06-26 16:11:08 -07:00
committed by GitHub
parent e850df56f1
commit 313e66c4c5
4 changed files with 97 additions and 12 deletions

View File

@@ -45,7 +45,12 @@ class LanceQueryBuilder:
0 6 [0.4, 0.4] 0.0
"""
def __init__(self, table: "lancedb.table.LanceTable", query: np.ndarray):
def __init__(
self,
table: "lancedb.table.LanceTable",
query: np.ndarray,
vector_column_name: str = VECTOR_COLUMN_NAME,
):
self._metric = "L2"
self._nprobes = 20
self._refine_factor = None
@@ -54,6 +59,7 @@ class LanceQueryBuilder:
self._limit = 10
self._columns = None
self._where = None
self._vector_column_name = vector_column_name
def limit(self, limit: int) -> LanceQueryBuilder:
"""Set the maximum number of results to return.
@@ -195,7 +201,7 @@ class LanceQueryBuilder:
columns=self._columns,
filter=self._where,
nearest={
"column": VECTOR_COLUMN_NAME,
"column": self._vector_column_name,
"q": self._query,
"k": self._limit,
"metric": self._metric,

View File

@@ -182,7 +182,13 @@ class LanceTable:
def _dataset_uri(self) -> str:
return os.path.join(self._conn.uri, f"{self.name}.lance")
def create_index(self, metric="L2", num_partitions=256, num_sub_vectors=96):
def create_index(
self,
metric="L2",
num_partitions=256,
num_sub_vectors=96,
vector_column_name=VECTOR_COLUMN_NAME,
):
"""Create an index on the table.
Parameters
@@ -198,7 +204,7 @@ class LanceTable:
Default is 96.
"""
self._dataset.create_index(
column=VECTOR_COLUMN_NAME,
column=vector_column_name,
index_type="IVF_PQ",
metric=metric,
num_partitions=num_partitions,
@@ -256,7 +262,9 @@ class LanceTable:
self._reset_dataset()
return len(self)
def search(self, query: Union[VEC, str]) -> LanceQueryBuilder:
def search(
self, query: Union[VEC, str], vector_column_name=VECTOR_COLUMN_NAME
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector.
@@ -275,7 +283,7 @@ class LanceTable:
"""
if isinstance(query, str):
# fts
return LanceFtsQueryBuilder(self, query)
return LanceFtsQueryBuilder(self, query, vector_column_name)
if isinstance(query, list):
query = np.array(query)
@@ -283,7 +291,7 @@ class LanceTable:
query = query.astype(np.float32)
else:
raise TypeError(f"Unsupported query type: {type(query)}")
return LanceQueryBuilder(self, query)
return LanceQueryBuilder(self, query, vector_column_name)
@classmethod
def create(cls, db, name, data, schema=None, mode="create"):

View File

@@ -11,6 +11,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest.mock as mock
import lance
import numpy as np
import pandas as pd
@@ -20,6 +22,7 @@ import pytest
from lancedb.db import LanceDBConnection
from lancedb.query import LanceQueryBuilder
from lancedb.table import LanceTable
class MockTable:
@@ -48,24 +51,30 @@ def table(tmp_path) -> MockTable:
def test_query_builder(table):
# Runs a nearest-neighbour query for [0, 0], capped at one result and selecting "id".
# NOTE(review): the two assignments below look like the before/after forms of the
# same statement in this change view; the second one names the vector column
# ("vector") explicitly instead of relying on the builder's default - confirm.
df = LanceQueryBuilder(table, [0, 0]).limit(1).select(["id"]).to_df()
df = LanceQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
# The first returned row is expected to carry id 1 and the vector [1, 2].
assert df["id"].values[0] == 1
assert all(df["vector"].values[0] == [1, 2])
def test_query_builder_with_filter(table):
# Runs a nearest-neighbour query for [0, 0] restricted by the filter "id = 2".
# NOTE(review): the two assignments below look like the before/after forms of the
# same statement in this change view; the second one names the vector column
# ("vector") explicitly instead of relying on the builder's default - confirm.
df = LanceQueryBuilder(table, [0, 0]).where("id = 2").to_df()
df = LanceQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
# The first returned row is expected to carry id 2 and the vector [3, 4].
assert df["id"].values[0] == 2
assert all(df["vector"].values[0] == [3, 4])
def test_query_builder_with_metric(table):
query = [4, 8]
df_default = LanceQueryBuilder(table, query).to_df()
df_l2 = LanceQueryBuilder(table, query).metric("L2").to_df()
vector_column_name = "vector"
df_default = LanceQueryBuilder(table, query, vector_column_name).to_df()
df_l2 = LanceQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
tm.assert_frame_equal(df_default, df_l2)
df_cosine = LanceQueryBuilder(table, query).metric("cosine").limit(1).to_df()
df_cosine = (
LanceQueryBuilder(table, query, vector_column_name)
.metric("cosine")
.limit(1)
.to_df()
)
assert df_cosine.score[0] == pytest.approx(
cosine_distance(query, df_cosine.vector[0]),
abs=1e-6,
@@ -73,5 +82,35 @@ def test_query_builder_with_metric(table):
assert 0 <= df_cosine.score[0] <= 1
def test_query_builder_with_different_vector_column():
    """Check that a custom vector column name reaches the dataset query.

    A mocked table hands out a mocked dataset; after configuring the builder
    (cosine metric, a filter, a column selection and a limit of 2) and calling
    ``to_arrow()``, the dataset's ``to_table`` must have been called exactly
    once with the custom column name inside the ``nearest`` argument.
    """
    column = "foo_vector"
    query = [4, 8]
    table = mock.MagicMock(spec=LanceTable)

    # Configure the builder one step at a time; each call returns the builder
    # to continue with.
    builder = LanceQueryBuilder(table, query, column)
    builder = builder.metric("cosine")
    builder = builder.where("b < 10")
    builder = builder.select(["b"])
    builder = builder.limit(2)

    # Wire the mocked table so that it serves a mocked dataset and reports a
    # connection that is not a managed remote one.
    ds = mock.Mock()
    table.to_lance.return_value = ds
    table._conn = mock.MagicMock()
    table._conn.is_managed_remote = False

    builder.to_arrow()

    expected_nearest = {
        "column": column,
        "q": query,
        "k": 2,
        "metric": "cosine",
        "nprobes": 20,
        "refine_factor": None,
    }
    ds.to_table.assert_called_once_with(
        columns=["b"],
        filter="b < 10",
        nearest=expected_nearest,
    )
def cosine_distance(vec1, vec2):
    """Return one minus the cosine similarity of ``vec1`` and ``vec2``.

    The similarity is the dot product of the two vectors divided by the
    product of their norms (as computed by ``np.linalg.norm``).
    """
    dot_product = np.dot(vec1, vec2)
    norm_product = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return 1 - dot_product / norm_product

View File

@@ -13,11 +13,13 @@
import functools
from pathlib import Path
from unittest.mock import PropertyMock, patch
import pandas as pd
import pyarrow as pa
import pytest
from lancedb.db import LanceDBConnection
from lancedb.table import LanceTable
@@ -142,3 +144,33 @@ def test_versioning(db):
table.checkout(1)
assert table.version == 1
assert len(table) == 2
def test_create_index_method():
    """Verify that ``LanceTable.create_index`` forwards its arguments to the dataset.

    ``LanceTable._reset_dataset`` and the ``LanceTable._dataset`` property are
    replaced with mocks for the duration of the test. The test then calls
    ``create_index`` and asserts that the mocked dataset's ``create_index``
    was called exactly once with the same metric, partition and sub-vector
    settings, the ``IVF_PQ`` index type, and the requested column.
    """
    # A distinctive column name is used on purpose: if create_index ignored
    # ``vector_column_name`` and fell back to its default column, the
    # assertion below would fail instead of passing by coincidence.
    vector_column_name = "foo_vector"

    reset_patch = patch.object(LanceTable, "_reset_dataset", return_value=None)
    dataset_patch = patch.object(LanceTable, "_dataset", new_callable=PropertyMock)
    with reset_patch, dataset_patch as mock_dataset:
        # Setup mock responses
        mock_dataset.return_value.create_index.return_value = None

        # Create a LanceTable object
        connection = LanceDBConnection(uri="mock.uri")
        table = LanceTable(connection, "test_table")

        # Call the create_index method
        table.create_index(
            metric="L2",
            num_partitions=256,
            num_sub_vectors=96,
            vector_column_name=vector_column_name,
        )

        # Check that the _dataset.create_index method was called with the
        # right parameters
        mock_dataset.return_value.create_index.assert_called_once_with(
            column=vector_column_name,
            index_type="IVF_PQ",
            metric="L2",
            num_partitions=256,
            num_sub_vectors=96,
        )