mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-17 20:10:39 +00:00
Specify and Index Column for Vector Search (#217)
This commit is contained in:
@@ -45,7 +45,12 @@ class LanceQueryBuilder:
|
||||
0 6 [0.4, 0.4] 0.0
|
||||
"""
|
||||
|
||||
def __init__(self, table: "lancedb.table.LanceTable", query: np.ndarray):
|
||||
def __init__(
|
||||
self,
|
||||
table: "lancedb.table.LanceTable",
|
||||
query: np.ndarray,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
):
|
||||
self._metric = "L2"
|
||||
self._nprobes = 20
|
||||
self._refine_factor = None
|
||||
@@ -54,6 +59,7 @@ class LanceQueryBuilder:
|
||||
self._limit = 10
|
||||
self._columns = None
|
||||
self._where = None
|
||||
self._vector_column_name = vector_column_name
|
||||
|
||||
def limit(self, limit: int) -> LanceQueryBuilder:
|
||||
"""Set the maximum number of results to return.
|
||||
@@ -195,7 +201,7 @@ class LanceQueryBuilder:
|
||||
columns=self._columns,
|
||||
filter=self._where,
|
||||
nearest={
|
||||
"column": VECTOR_COLUMN_NAME,
|
||||
"column": self._vector_column_name,
|
||||
"q": self._query,
|
||||
"k": self._limit,
|
||||
"metric": self._metric,
|
||||
|
||||
@@ -182,7 +182,13 @@ class LanceTable:
|
||||
def _dataset_uri(self) -> str:
|
||||
return os.path.join(self._conn.uri, f"{self.name}.lance")
|
||||
|
||||
def create_index(self, metric="L2", num_partitions=256, num_sub_vectors=96):
|
||||
def create_index(
|
||||
self,
|
||||
metric="L2",
|
||||
num_partitions=256,
|
||||
num_sub_vectors=96,
|
||||
vector_column_name=VECTOR_COLUMN_NAME,
|
||||
):
|
||||
"""Create an index on the table.
|
||||
|
||||
Parameters
|
||||
@@ -198,7 +204,7 @@ class LanceTable:
|
||||
Default is 96.
|
||||
"""
|
||||
self._dataset.create_index(
|
||||
column=VECTOR_COLUMN_NAME,
|
||||
column=vector_column_name,
|
||||
index_type="IVF_PQ",
|
||||
metric=metric,
|
||||
num_partitions=num_partitions,
|
||||
@@ -256,7 +262,9 @@ class LanceTable:
|
||||
self._reset_dataset()
|
||||
return len(self)
|
||||
|
||||
def search(self, query: Union[VEC, str]) -> LanceQueryBuilder:
|
||||
def search(
|
||||
self, query: Union[VEC, str], vector_column_name=VECTOR_COLUMN_NAME
|
||||
) -> LanceQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector.
|
||||
|
||||
@@ -275,7 +283,7 @@ class LanceTable:
|
||||
"""
|
||||
if isinstance(query, str):
|
||||
# fts
|
||||
return LanceFtsQueryBuilder(self, query)
|
||||
return LanceFtsQueryBuilder(self, query, vector_column_name)
|
||||
|
||||
if isinstance(query, list):
|
||||
query = np.array(query)
|
||||
@@ -283,7 +291,7 @@ class LanceTable:
|
||||
query = query.astype(np.float32)
|
||||
else:
|
||||
raise TypeError(f"Unsupported query type: {type(query)}")
|
||||
return LanceQueryBuilder(self, query)
|
||||
return LanceQueryBuilder(self, query, vector_column_name)
|
||||
|
||||
@classmethod
|
||||
def create(cls, db, name, data, schema=None, mode="create"):
|
||||
|
||||
Reference in New Issue
Block a user