Specify and Index Column for Vector Search (#217)

This commit is contained in:
Philip Kung
2023-06-26 16:11:08 -07:00
committed by GitHub
parent e850df56f1
commit 313e66c4c5
4 changed files with 97 additions and 12 deletions

View File

@@ -45,7 +45,12 @@ class LanceQueryBuilder:
0 6 [0.4, 0.4] 0.0
"""
def __init__(self, table: "lancedb.table.LanceTable", query: np.ndarray):
def __init__(
self,
table: "lancedb.table.LanceTable",
query: np.ndarray,
vector_column_name: str = VECTOR_COLUMN_NAME,
):
self._metric = "L2"
self._nprobes = 20
self._refine_factor = None
@@ -54,6 +59,7 @@ class LanceQueryBuilder:
self._limit = 10
self._columns = None
self._where = None
self._vector_column_name = vector_column_name
def limit(self, limit: int) -> LanceQueryBuilder:
"""Set the maximum number of results to return.
@@ -195,7 +201,7 @@ class LanceQueryBuilder:
columns=self._columns,
filter=self._where,
nearest={
"column": VECTOR_COLUMN_NAME,
"column": self._vector_column_name,
"q": self._query,
"k": self._limit,
"metric": self._metric,

View File

@@ -182,7 +182,13 @@ class LanceTable:
def _dataset_uri(self) -> str:
return os.path.join(self._conn.uri, f"{self.name}.lance")
def create_index(self, metric="L2", num_partitions=256, num_sub_vectors=96):
def create_index(
self,
metric="L2",
num_partitions=256,
num_sub_vectors=96,
vector_column_name=VECTOR_COLUMN_NAME,
):
"""Create an index on the table.
Parameters
@@ -198,7 +204,7 @@ class LanceTable:
Default is 96.
"""
self._dataset.create_index(
column=VECTOR_COLUMN_NAME,
column=vector_column_name,
index_type="IVF_PQ",
metric=metric,
num_partitions=num_partitions,
@@ -256,7 +262,9 @@ class LanceTable:
self._reset_dataset()
return len(self)
def search(self, query: Union[VEC, str]) -> LanceQueryBuilder:
def search(
self, query: Union[VEC, str], vector_column_name=VECTOR_COLUMN_NAME
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector.
@@ -275,7 +283,7 @@ class LanceTable:
"""
if isinstance(query, str):
# fts
return LanceFtsQueryBuilder(self, query)
return LanceFtsQueryBuilder(self, query, vector_column_name)
if isinstance(query, list):
query = np.array(query)
@@ -283,7 +291,7 @@ class LanceTable:
query = query.astype(np.float32)
else:
raise TypeError(f"Unsupported query type: {type(query)}")
return LanceQueryBuilder(self, query)
return LanceQueryBuilder(self, query, vector_column_name)
@classmethod
def create(cls, db, name, data, schema=None, mode="create"):