mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-08 12:52:58 +00:00
Specify and Index Column for Vector Search (#217)
This commit is contained in:
@@ -45,7 +45,12 @@ class LanceQueryBuilder:
|
||||
0 6 [0.4, 0.4] 0.0
|
||||
"""
|
||||
|
||||
def __init__(self, table: "lancedb.table.LanceTable", query: np.ndarray):
|
||||
def __init__(
|
||||
self,
|
||||
table: "lancedb.table.LanceTable",
|
||||
query: np.ndarray,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
):
|
||||
self._metric = "L2"
|
||||
self._nprobes = 20
|
||||
self._refine_factor = None
|
||||
@@ -54,6 +59,7 @@ class LanceQueryBuilder:
|
||||
self._limit = 10
|
||||
self._columns = None
|
||||
self._where = None
|
||||
self._vector_column_name = vector_column_name
|
||||
|
||||
def limit(self, limit: int) -> LanceQueryBuilder:
|
||||
"""Set the maximum number of results to return.
|
||||
@@ -195,7 +201,7 @@ class LanceQueryBuilder:
|
||||
columns=self._columns,
|
||||
filter=self._where,
|
||||
nearest={
|
||||
"column": VECTOR_COLUMN_NAME,
|
||||
"column": self._vector_column_name,
|
||||
"q": self._query,
|
||||
"k": self._limit,
|
||||
"metric": self._metric,
|
||||
|
||||
@@ -182,7 +182,13 @@ class LanceTable:
|
||||
def _dataset_uri(self) -> str:
|
||||
return os.path.join(self._conn.uri, f"{self.name}.lance")
|
||||
|
||||
def create_index(self, metric="L2", num_partitions=256, num_sub_vectors=96):
|
||||
def create_index(
|
||||
self,
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name=VECTOR_COLUMN_NAME,
|
||||
):
|
||||
"""Create an index on the table.
|
||||
|
||||
Parameters
|
||||
@@ -198,7 +204,7 @@ class LanceTable:
|
||||
Default is 96.
|
||||
"""
|
||||
self._dataset.create_index(
|
||||
column=VECTOR_COLUMN_NAME,
|
||||
column=vector_column_name,
|
||||
index_type="IVF_PQ",
|
||||
metric=metric,
|
||||
num_partitions=num_partitions,
|
||||
@@ -256,7 +262,9 @@ class LanceTable:
|
||||
self._reset_dataset()
|
||||
return len(self)
|
||||
|
||||
def search(self, query: Union[VEC, str]) -> LanceQueryBuilder:
|
||||
def search(
|
||||
self, query: Union[VEC, str], vector_column_name=VECTOR_COLUMN_NAME
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector.
|
||||
|
||||
@@ -275,7 +283,7 @@ class LanceTable:
|
||||
"""
|
||||
if isinstance(query, str):
|
||||
# fts
|
||||
return LanceFtsQueryBuilder(self, query)
|
||||
return LanceFtsQueryBuilder(self, query, vector_column_name)
|
||||
|
||||
if isinstance(query, list):
|
||||
query = np.array(query)
|
||||
@@ -283,7 +291,7 @@ class LanceTable:
|
||||
query = query.astype(np.float32)
|
||||
else:
|
||||
raise TypeError(f"Unsupported query type: {type(query)}")
|
||||
return LanceQueryBuilder(self, query)
|
||||
return LanceQueryBuilder(self, query, vector_column_name)
|
||||
|
||||
@classmethod
|
||||
def create(cls, db, name, data, schema=None, mode="create"):
|
||||
|
||||
@@ -11,6 +11,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest.mock as mock
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -20,6 +22,7 @@ import pytest
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.query import LanceQueryBuilder
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
class MockTable:
|
||||
@@ -48,24 +51,30 @@ def table(tmp_path) -> MockTable:
|
||||
|
||||
|
||||
def test_query_builder(table):
|
||||
df = LanceQueryBuilder(table, [0, 0]).limit(1).select(["id"]).to_df()
|
||||
df = LanceQueryBuilder(table, [0, 0], "vector").limit(1).select(["id"]).to_df()
|
||||
assert df["id"].values[0] == 1
|
||||
assert all(df["vector"].values[0] == [1, 2])
|
||||
|
||||
|
||||
def test_query_builder_with_filter(table):
|
||||
df = LanceQueryBuilder(table, [0, 0]).where("id = 2").to_df()
|
||||
df = LanceQueryBuilder(table, [0, 0], "vector").where("id = 2").to_df()
|
||||
assert df["id"].values[0] == 2
|
||||
assert all(df["vector"].values[0] == [3, 4])
|
||||
|
||||
|
||||
def test_query_builder_with_metric(table):
|
||||
query = [4, 8]
|
||||
df_default = LanceQueryBuilder(table, query).to_df()
|
||||
df_l2 = LanceQueryBuilder(table, query).metric("L2").to_df()
|
||||
vector_column_name = "vector"
|
||||
df_default = LanceQueryBuilder(table, query, vector_column_name).to_df()
|
||||
df_l2 = LanceQueryBuilder(table, query, vector_column_name).metric("L2").to_df()
|
||||
tm.assert_frame_equal(df_default, df_l2)
|
||||
|
||||
df_cosine = LanceQueryBuilder(table, query).metric("cosine").limit(1).to_df()
|
||||
df_cosine = (
|
||||
LanceQueryBuilder(table, query, vector_column_name)
|
||||
.metric("cosine")
|
||||
.limit(1)
|
||||
.to_df()
|
||||
)
|
||||
assert df_cosine.score[0] == pytest.approx(
|
||||
cosine_distance(query, df_cosine.vector[0]),
|
||||
abs=1e-6,
|
||||
@@ -73,5 +82,35 @@ def test_query_builder_with_metric(table):
|
||||
assert 0 <= df_cosine.score[0] <= 1
|
||||
|
||||
|
||||
def test_query_builder_with_different_vector_column():
|
||||
table = mock.MagicMock(spec=LanceTable)
|
||||
query = [4, 8]
|
||||
vector_column_name = "foo_vector"
|
||||
builder = (
|
||||
LanceQueryBuilder(table, query, vector_column_name)
|
||||
.metric("cosine")
|
||||
.where("b < 10")
|
||||
.select(["b"])
|
||||
.limit(2)
|
||||
)
|
||||
ds = mock.Mock()
|
||||
table.to_lance.return_value = ds
|
||||
table._conn = mock.MagicMock()
|
||||
table._conn.is_managed_remote = False
|
||||
builder.to_arrow()
|
||||
ds.to_table.assert_called_once_with(
|
||||
columns=["b"],
|
||||
filter="b < 10",
|
||||
nearest={
|
||||
"column": vector_column_name,
|
||||
"q": query,
|
||||
"k": 2,
|
||||
"metric": "cosine",
|
||||
"nprobes": 20,
|
||||
"refine_factor": None,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def cosine_distance(vec1, vec2):
|
||||
return 1 - np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
||||
|
||||
@@ -13,11 +13,13 @@
|
||||
|
||||
import functools
|
||||
from pathlib import Path
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lancedb.db import LanceDBConnection
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
|
||||
@@ -142,3 +144,33 @@ def test_versioning(db):
|
||||
table.checkout(1)
|
||||
assert table.version == 1
|
||||
assert len(table) == 2
|
||||
|
||||
|
||||
def test_create_index_method():
|
||||
with patch.object(LanceTable, "_reset_dataset", return_value=None):
|
||||
with patch.object(
|
||||
LanceTable, "_dataset", new_callable=PropertyMock
|
||||
) as mock_dataset:
|
||||
# Setup mock responses
|
||||
mock_dataset.return_value.create_index.return_value = None
|
||||
|
||||
# Create a LanceTable object
|
||||
connection = LanceDBConnection(uri="mock.uri")
|
||||
table = LanceTable(connection, "test_table")
|
||||
|
||||
# Call the create_index method
|
||||
table.create_index(
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name="vector",
|
||||
)
|
||||
|
||||
# Check that the _dataset.create_index method was called with the right parameters
|
||||
mock_dataset.return_value.create_index.assert_called_once_with(
|
||||
column="vector",
|
||||
index_type="IVF_PQ",
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user