From 313e66c4c58eb9daadc6fc5a1d0b6f863c63c4f5 Mon Sep 17 00:00:00 2001 From: Philip Kung Date: Mon, 26 Jun 2023 16:11:08 -0700 Subject: [PATCH] Specify and Index Column for Vector Search (#217) --- python/lancedb/query.py | 10 ++++++-- python/lancedb/table.py | 18 ++++++++++---- python/tests/test_query.py | 49 ++++++++++++++++++++++++++++++++++---- python/tests/test_table.py | 32 +++++++++++++++++++++++++ 4 files changed, 97 insertions(+), 12 deletions(-) diff --git a/python/lancedb/query.py b/python/lancedb/query.py index cb1e51c7..bccc6905 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -45,7 +45,12 @@ class LanceQueryBuilder: 0 6 [0.4, 0.4] 0.0 """ - def __init__(self, table: "lancedb.table.LanceTable", query: np.ndarray): + def __init__( + self, + table: "lancedb.table.LanceTable", + query: np.ndarray, + vector_column_name: str = VECTOR_COLUMN_NAME, + ): self._metric = "L2" self._nprobes = 20 self._refine_factor = None @@ -54,6 +59,7 @@ class LanceQueryBuilder: self._limit = 10 self._columns = None self._where = None + self._vector_column_name = vector_column_name def limit(self, limit: int) -> LanceQueryBuilder: """Set the maximum number of results to return. @@ -195,7 +201,7 @@ class LanceQueryBuilder: columns=self._columns, filter=self._where, nearest={ - "column": VECTOR_COLUMN_NAME, + "column": self._vector_column_name, "q": self._query, "k": self._limit, "metric": self._metric, diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 5eacb3a1..b4e4daae 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -182,7 +182,13 @@ class LanceTable: def _dataset_uri(self) -> str: return os.path.join(self._conn.uri, f"{self.name}.lance") - def create_index(self, metric="L2", num_partitions=256, num_sub_vectors=96): + def create_index( + self, + metric="L2", + num_partitions=256, + num_sub_vectors=96, + vector_column_name=VECTOR_COLUMN_NAME, + ): """Create an index on the table. Parameters @@ -198,7 +204,7 @@ class LanceTable: Default is 96. """ self._dataset.create_index( - column=VECTOR_COLUMN_NAME, + column=vector_column_name, index_type="IVF_PQ", metric=metric, num_partitions=num_partitions, @@ -256,7 +262,9 @@ class LanceTable: self._reset_dataset() return len(self) - def search(self, query: Union[VEC, str]) -> LanceQueryBuilder: + def search( + self, query: Union[VEC, str], vector_column_name=VECTOR_COLUMN_NAME + ) -> LanceQueryBuilder: """Create a search query to find the nearest neighbors of the given query vector. @@ -275,7 +283,7 @@ class LanceTable: """ if isinstance(query, str): # fts - return LanceFtsQueryBuilder(self, query) + return LanceFtsQueryBuilder(self, query, vector_column_name) if isinstance(query, list): query = np.array(query) @@ -283,7 +291,7 @@ class LanceTable: query = query.astype(np.float32) else: raise TypeError(f"Unsupported query type: {type(query)}") - return LanceQueryBuilder(self, query) + return LanceQueryBuilder(self, query, vector_column_name) @classmethod def create(cls, db, name, data, schema=None, mode="create"): diff --git a/python/tests/test_query.py b/python/tests/test_query.py index 55713544..1af20f2b 100644 --- a/python/tests/test_query.py +++ b/python/tests/test_query.py @@ -11,6 +11,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest.mock as mock + import lance import numpy as np import pandas as pd @@ -20,6 +22,7 @@ import pytest from lancedb.db import LanceDBConnection from lancedb.query import LanceQueryBuilder +from lancedb.table import LanceTable class MockTable: @@ -48,24 +51,30 @@ def table(tmp_path) -> MockTable: def test_query_builder(table): - df = LanceQueryBuilder(table, [0, 0]).limit(1).select(["id"]).to_df() + df = LanceQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df() assert df["id"].values[0] == 1 assert all(df["vector"].values[0] == [1, 2]) def test_query_builder_with_filter(table): - df = LanceQueryBuilder(table, [0, 0]).where("id = 2").to_df() + df = LanceQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df() assert df["id"].values[0] == 2 assert all(df["vector"].values[0] == [3, 4]) def test_query_builder_with_metric(table): query = [4, 8] - df_default = LanceQueryBuilder(table, query).to_df() - df_l2 = LanceQueryBuilder(table, query).metric("L2").to_df() + vector_column_name = "vector" + df_default = LanceQueryBuilder(table, query, vector_column_name).to_df() + df_l2 = LanceQueryBuilder(table, query, vector_column_name).metric("L2").to_df() tm.assert_frame_equal(df_default, df_l2) - df_cosine = LanceQueryBuilder(table, query).metric("cosine").limit(1).to_df() + df_cosine = ( + LanceQueryBuilder(table, query, vector_column_name) + .metric("cosine") + .limit(1) + .to_df() + ) assert df_cosine.score[0] == pytest.approx( cosine_distance(query, df_cosine.vector[0]), abs=1e-6, @@ -73,5 +82,35 @@ def test_query_builder_with_metric(table): assert 0 <= df_cosine.score[0] <= 1 +def test_query_builder_with_different_vector_column(): + table = mock.MagicMock(spec=LanceTable) + query = [4, 8] + vector_column_name = "foo_vector" + builder = ( + LanceQueryBuilder(table, query, vector_column_name) + .metric("cosine") + .where("b < 10") + .select(["b"]) + .limit(2) + ) + ds = mock.Mock() + table.to_lance.return_value = ds + table._conn = mock.MagicMock() + table._conn.is_managed_remote = False + builder.to_arrow() + ds.to_table.assert_called_once_with( + columns=["b"], + filter="b < 10", + nearest={ + "column": vector_column_name, + "q": query, + "k": 2, + "metric": "cosine", + "nprobes": 20, + "refine_factor": None, + }, + ) + + def cosine_distance(vec1, vec2): return 1 - np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) diff --git a/python/tests/test_table.py b/python/tests/test_table.py index cf672fa8..6021dc2b 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -13,11 +13,13 @@ import functools from pathlib import Path +from unittest.mock import PropertyMock, patch import pandas as pd import pyarrow as pa import pytest +from lancedb.db import LanceDBConnection from lancedb.table import LanceTable @@ -142,3 +144,33 @@ def test_versioning(db): table.checkout(1) assert table.version == 1 assert len(table) == 2 + + +def test_create_index_method(): + with patch.object(LanceTable, "_reset_dataset", return_value=None): + with patch.object( + LanceTable, "_dataset", new_callable=PropertyMock + ) as mock_dataset: + # Setup mock responses + mock_dataset.return_value.create_index.return_value = None + + # Create a LanceTable object + connection = LanceDBConnection(uri="mock.uri") + table = LanceTable(connection, "test_table") + + # Call the create_index method + table.create_index( + metric="L2", + num_partitions=256, + num_sub_vectors=96, + vector_column_name="vector", + ) + + # Check that the _dataset.create_index method was called with the right parameters + mock_dataset.return_value.create_index.assert_called_once_with( + column="vector", + index_type="IVF_PQ", + metric="L2", + num_partitions=256, + num_sub_vectors=96, + )