Mirror of https://github.com/lancedb/lancedb.git
Synced 2025-12-23 05:19:58 +00:00

Compare commits — 17 commits: lu/python-...python-v0.
| SHA1 |
|---|
| 1e0cc69401 |
| f31e0c749d |
| 7a3ef68306 |
| 43952e01d7 |
| 495c335831 |
| 77707db543 |
| d6d7ad3b06 |
| e58d64c286 |
| 76cbd18c46 |
| 4abb38ac70 |
| cc7bc5011d |
| 8193183304 |
| cf28b58b7d |
| e3b7ee47b9 |
| 97c9c906e4 |
| 358f86b9c6 |
| 5489e215a3 |
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "0.9.0"
+current_version = "0.9.0-beta.5"
 parse = """(?x)
     (?P<major>0|[1-9]\\d*)\\.
     (?P<minor>0|[1-9]\\d*)\\.
@@ -1,6 +1,6 @@
 [package]
 name = "lancedb-python"
-version = "0.9.0"
+version = "0.9.0-beta.5"
 edition.workspace = true
 description = "Python bindings for LanceDB"
 license.workspace = true
@@ -13,6 +13,7 @@ dependencies = [
     "packaging",
     "cachetools",
     "overrides>=0.7",
+    "urllib3==1.26.19"
 ]
 description = "lancedb"
 authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
@@ -35,6 +35,7 @@ def connect(
     host_override: Optional[str] = None,
     read_consistency_interval: Optional[timedelta] = None,
     request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
+    storage_options: Optional[Dict[str, str]] = None,
     **kwargs,
 ) -> DBConnection:
     """Connect to a LanceDB database.
@@ -70,6 +71,9 @@ def connect(
         executor will be used for making requests. This is for LanceDB Cloud
         only and is only used when making batch requests (i.e., passing in
         multiple queries to the search method at once).
+    storage_options: dict, optional
+        Additional options for the storage backend. See available options at
+        https://lancedb.github.io/lancedb/guides/storage/

     Examples
     --------
@@ -105,12 +109,16 @@ def connect(
             region,
             host_override,
             request_thread_pool=request_thread_pool,
+            storage_options=storage_options,
             **kwargs,
         )

     if kwargs:
         raise ValueError(f"Unknown keyword arguments: {kwargs}")
-    return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
+    return LanceDBConnection(
+        uri,
+        read_consistency_interval=read_consistency_interval,
+    )


 async def connect_async(
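Usage sketch (illustrative, not part of the patch): the new `storage_options` argument is forwarded from `connect()` to the storage backend. The bucket name and region below are placeholders; the accepted keys are the ones documented in the storage guide linked above.

```python
import lancedb

# Placeholder bucket/region; storage_options keys follow the LanceDB
# storage guide (https://lancedb.github.io/lancedb/guides/storage/).
db = lancedb.connect(
    "s3://example-bucket/lancedb",
    storage_options={"region": "us-east-1"},
)
tbl = db.create_table("vectors", [{"vector": [0.1, 0.2], "item": "a"}], mode="overwrite")
print(tbl.search([0.1, 0.2]).limit(1).to_pandas())
```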
@@ -117,6 +117,8 @@ class Query(pydantic.BaseModel):

     with_row_id: bool = False

+    fast_search: bool = False
+

 class LanceQueryBuilder(ABC):
     """An abstract query builder. Subclasses are defined for vector search,
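Illustrative sketch (not part of the patch): the new `fast_search` field rides along on the serialized `Query` payload. The field names are taken from this diff; it is assumed here that only `vector` and `k` are required, which may not hold for every version.

```python
from lancedb.query import Query  # assumed import path for the pydantic model

# A nearest-neighbour request that asks the engine to skip the flat scan
# of unindexed rows.
q = Query(vector=[0.1, 0.2], k=10, fast_search=True)
print(q.fast_search)  # True
```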
@@ -125,12 +127,14 @@ class LanceQueryBuilder(ABC):

     @classmethod
     def create(
-        cls,
-        table: "Table",
-        query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
-        query_type: str,
-        vector_column_name: str,
-        ordering_field_name: str = None,
+        cls,
+        table: "Table",
+        query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
+        query_type: str,
+        vector_column_name: str,
+        ordering_field_name: Optional[str] = None,
+        fts_columns: Union[str, List[str]] = [],
+        fast_search: bool = False,
     ) -> LanceQueryBuilder:
         """
         Create a query builder based on the given query and query type.
@@ -147,13 +151,18 @@ class LanceQueryBuilder(ABC):
             If "auto", the query type is inferred based on the query.
         vector_column_name: str
             The name of the vector column to use for vector search.
+        fast_search: bool
+            Skip flat search of unindexed data.
         """
-        if query is None:
-            return LanceEmptyQueryBuilder(table)
-
+        # Check hybrid search first as it supports empty query pattern
         if query_type == "hybrid":
             # hybrid fts and vector query
-            return LanceHybridQueryBuilder(table, query, vector_column_name)
+            return LanceHybridQueryBuilder(
+                table, query, vector_column_name, fts_columns=fts_columns
+            )
+
+        if query is None:
+            return LanceEmptyQueryBuilder(table)

         # remember the string query for reranking purpose
         str_query = query if isinstance(query, str) else None
@@ -165,12 +174,17 @@ class LanceQueryBuilder(ABC):
         )

         if query_type == "hybrid":
-            return LanceHybridQueryBuilder(table, query, vector_column_name)
+            return LanceHybridQueryBuilder(
+                table, query, vector_column_name, fts_columns=fts_columns
+            )

         if isinstance(query, str):
             # fts
             return LanceFtsQueryBuilder(
-                table, query, ordering_field_name=ordering_field_name
+                table,
+                query,
+                ordering_field_name=ordering_field_name,
+                fts_columns=fts_columns,
             )

         if isinstance(query, list):
@@ -180,7 +194,9 @@ class LanceQueryBuilder(ABC):
         else:
             raise TypeError(f"Unsupported query type: {type(query)}")

-        return LanceVectorQueryBuilder(table, query, vector_column_name, str_query)
+        return LanceVectorQueryBuilder(
+            table, query, vector_column_name, str_query, fast_search
+        )

     @classmethod
     def _resolve_query(cls, table, query, query_type, vector_column_name):
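Usage sketch (illustrative, not part of the patch): how the new `fts_columns` and `fast_search` arguments could surface through `Table.search()`, assuming the table-level wrapper forwards these keywords (those call sites are outside this excerpt) and that FTS / ANN indices already exist on `tbl`.

```python
# Full-text search restricted to specific columns.
tbl.create_fts_index(["text", "title"])  # requires the tantivy extra
fts_hits = tbl.search("hello", query_type="fts", fts_columns=["text", "title"]).to_list()

# Vector search that skips unindexed rows; rows added after the last
# index build will simply not appear in the results.
ann_hits = tbl.search([0.1, 0.2], fast_search=True).limit(5).to_list()
```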
@@ -196,8 +212,6 @@ class LanceQueryBuilder(ABC):
         elif query_type == "auto":
             if isinstance(query, (list, np.ndarray)):
                 return query, "vector"
-            if isinstance(query, tuple):
-                return query, "hybrid"
             else:
                 conf = table.embedding_functions.get(vector_column_name)
                 if conf is not None:
@@ -224,9 +238,14 @@ class LanceQueryBuilder(ABC):
     def __init__(self, table: "Table"):
         self._table = table
         self._limit = 10
+        self._offset = 0
         self._columns = None
         self._where = None
+        self._prefilter = False
         self._with_row_id = False
+        self._vector = None
+        self._text = None
+        self._ef = None

     @deprecation.deprecated(
         deprecated_in="0.3.1",
@@ -337,11 +356,13 @@ class LanceQueryBuilder(ABC):
         ----------
         limit: int
             The maximum number of results to return.
-            By default the query is limited to the first 10.
-            Call this method and pass 0, a negative value,
-            or None to remove the limit.
-            *WARNING* if you have a large dataset, removing
-            the limit can potentially result in reading a
+            The default query limit is 10 results.
+            For ANN/KNN queries, you must specify a limit.
+            Entering 0, a negative number, or None will reset
+            the limit to the default value of 10.
+            *WARNING* if you have a large dataset, setting
+            the limit to a large number, e.g. the table size,
+            can potentially result in reading a
             large amount of data into memory and cause
             out of memory issues.

@@ -351,11 +372,33 @@
             The LanceQueryBuilder object.
         """
         if limit is None or limit <= 0:
-            self._limit = None
+            if isinstance(self, LanceVectorQueryBuilder):
+                raise ValueError("Limit is required for ANN/KNN queries")
+            else:
+                self._limit = None
         else:
             self._limit = limit
         return self

+    def offset(self, offset: int) -> LanceQueryBuilder:
+        """Set the offset for the results.
+
+        Parameters
+        ----------
+        offset: int
+            The offset to start fetching results from.
+
+        Returns
+        -------
+        LanceQueryBuilder
+            The LanceQueryBuilder object.
+        """
+        if offset is None or offset <= 0:
+            self._offset = 0
+        else:
+            self._offset = offset
+        return self
+
     def select(self, columns: Union[list[str], dict[str, str]]) -> LanceQueryBuilder:
         """Set the columns to return.

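Usage sketch (illustrative, not part of the patch): `offset()` pairs with `limit()` for simple pagination, and `limit(None)` on a vector query now raises instead of silently removing the cap.

```python
import lancedb

db = lancedb.connect("./.lancedb")
tbl = db.create_table(
    "items",
    [{"vector": [float(i), float(i)], "id": i} for i in range(100)],
    mode="overwrite",
)

page1 = tbl.search([0.0, 0.0]).limit(10).to_list()             # results 0-9
page2 = tbl.search([0.0, 0.0]).limit(10).offset(10).to_list()  # results 10-19

# ANN/KNN queries must keep a limit:
# tbl.search([0.0, 0.0]).limit(None)  -> ValueError("Limit is required for ANN/KNN queries")
```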
@@ -417,6 +460,80 @@
         self._with_row_id = with_row_id
         return self

+    def explain_plan(self, verbose: Optional[bool] = False) -> str:
+        """Return the execution plan for this query.
+
+        Examples
+        --------
+        >>> import lancedb
+        >>> db = lancedb.connect("./.lancedb")
+        >>> table = db.create_table("my_table", [{"vector": [99, 99]}])
+        >>> query = [100, 100]
+        >>> plan = table.search(query).explain_plan(True)
+        >>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
+        ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance]
+          GlobalLimitExec: skip=0, fetch=10
+            FilterExec: _distance@2 IS NOT NULL
+              SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
+                KNNVectorDistance: metric=l2
+                  LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false
+
+        Parameters
+        ----------
+        verbose : bool, default False
+            Use a verbose output format.
+
+        Returns
+        -------
+        plan : str
+        """  # noqa: E501
+        ds = self._table.to_lance()
+        return ds.scanner(
+            nearest={
+                "column": self._vector_column,
+                "q": self._query,
+                "k": self._limit,
+                "metric": self._metric,
+                "nprobes": self._nprobes,
+                "refine_factor": self._refine_factor,
+            },
+            prefilter=self._prefilter,
+            filter=self._str_query,
+            limit=self._limit,
+            with_row_id=self._with_row_id,
+            offset=self._offset,
+        ).explain_plan(verbose)
+
+    def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder:
+        """Set the vector to search for.
+
+        Parameters
+        ----------
+        vector: np.ndarray or list
+            The vector to search for.
+
+        Returns
+        -------
+        LanceQueryBuilder
+            The LanceQueryBuilder object.
+        """
+        raise NotImplementedError
+
+    def text(self, text: str) -> LanceQueryBuilder:
+        """Set the text to search for.
+
+        Parameters
+        ----------
+        text: str
+            The text to search for.
+
+        Returns
+        -------
+        LanceQueryBuilder
+            The LanceQueryBuilder object.
+        """
+        raise NotImplementedError
+

 class LanceVectorQueryBuilder(LanceQueryBuilder):
     """
@@ -440,11 +557,12 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
     """

     def __init__(
-        self,
-        table: "Table",
-        query: Union[np.ndarray, list, "PIL.Image.Image"],
-        vector_column: str,
-        str_query: Optional[str] = None,
+        self,
+        table: "Table",
+        query: Union[np.ndarray, list, "PIL.Image.Image"],
+        vector_column: str,
+        str_query: Optional[str] = None,
+        fast_search: bool = False,
     ):
         super().__init__(table)
         self._query = query
@@ -455,13 +573,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         self._prefilter = False
         self._reranker = None
         self._str_query = str_query
+        self._fast_search = fast_search

-    def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
+    def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceVectorQueryBuilder:
         """Set the distance metric to use.

         Parameters
         ----------
-        metric: "L2" or "cosine"
+        metric: "L2" or "cosine" or "dot"
             The distance metric to use. By default "L2" is used.

         Returns
@@ -469,7 +588,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         LanceVectorQueryBuilder
             The LanceQueryBuilder object.
         """
-        self._metric = metric
+        self._metric = metric.lower()
         return self

     def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
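Usage sketch (illustrative, not part of the patch): selecting the newly accepted "dot" metric; the value is lower-cased before being passed down, so "Dot" or "DOT" behave the same. Assumes `tbl` is an existing table with a "vector" column.

```python
results = tbl.search([0.1, 0.2]).metric("dot").limit(10).to_pandas()
```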
@@ -494,6 +613,28 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         self._nprobes = nprobes
         return self

+    def ef(self, ef: int) -> LanceVectorQueryBuilder:
+        """Set the number of candidates to consider during search.
+
+        Higher values will yield better recall (more likely to find vectors if
+        they exist) at the expense of latency.
+
+        This only applies to the HNSW-related index.
+        The default value is 1.5 * limit.
+
+        Parameters
+        ----------
+        ef: int
+            The number of candidates to consider during search.
+
+        Returns
+        -------
+        LanceVectorQueryBuilder
+            The LanceQueryBuilder object.
+        """
+        self._ef = ef
+        return self
+
     def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
         """Set the refine factor to use, increasing the number of vectors sampled.

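Usage sketch (illustrative, not part of the patch): `ef` only affects HNSW-based indices, while `nprobes` affects IVF partitioning; both trade latency for recall and assume an index has already been built on the vector column.

```python
results = (
    tbl.search([0.1, 0.2])
    .limit(10)
    .nprobes(20)  # IVF: number of partitions probed
    .ef(40)       # HNSW: candidate pool size (default ~1.5 * limit)
    .to_arrow()
)
```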
@@ -554,15 +695,11 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
             refine_factor=self._refine_factor,
             vector_column=self._vector_column,
             with_row_id=self._with_row_id,
+            offset=self._offset,
+            fast_search=self._fast_search,
+            ef=self._ef,
         )
         result_set = self._table._execute_query(query, batch_size)
         if self._reranker is not None:
             rs_table = result_set.read_all()
             result_set = self._reranker.rerank_vector(self._str_query, rs_table)
             # convert result_set back to RecordBatchReader
             result_set = pa.RecordBatchReader.from_batches(
                 result_set.schema, result_set.to_batches()
             )

         return result_set
@@ -591,7 +728,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         return self

     def rerank(
-        self, reranker: Reranker, query_string: Optional[str] = None
+        self, reranker: Reranker, query_string: Optional[str] = None
     ) -> LanceVectorQueryBuilder:
         """Rerank the results using the specified reranker.

@@ -756,12 +893,34 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):

 class LanceEmptyQueryBuilder(LanceQueryBuilder):
     def to_arrow(self) -> pa.Table:
-        ds = self._table.to_lance()
-        return ds.to_table(
+        return self.to_batches().read_all()
+
+    def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
+        query = Query(
             columns=self._columns,
             filter=self._where,
             limit=self._limit,
+            k=self._limit or 10,
+            with_row_id=self._with_row_id,
+            vector=[],
+            # not actually respected in remote query
+            offset=self._offset or 0,
         )
+        return self._table._execute_query(query)
+
+    def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
+        """Rerank the results using the specified reranker.
+
+        Parameters
+        ----------
+        reranker: Reranker
+            The reranker to use.
+
+        Returns
+        -------
+        LanceEmptyQueryBuilder
+            The LanceQueryBuilder object.
+        """
+        raise NotImplementedError("Reranking is not yet supported.")


 class LanceHybridQueryBuilder(LanceQueryBuilder):
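Usage sketch (illustrative, not part of the patch): with `LanceEmptyQueryBuilder` now building a `Query` and going through `_execute_query`, a filter-only scan (no vector, no text) follows the same path locally and against LanceDB Cloud. Assumes `tbl` has an `id` column.

```python
rows = (
    tbl.search()          # no query vector -> empty query builder
    .where("id > 5")
    .select(["id"])
    .limit(20)
    .to_arrow()
)
```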
@@ -55,11 +55,13 @@ class RestfulLanceDBClient:
     region: str
     api_key: Credential
     host_override: Optional[str] = attrs.field(default=None)
+    db_prefix: Optional[str] = attrs.field(default=None)

     closed: bool = attrs.field(default=False, init=False)

     connection_timeout: float = attrs.field(default=120.0, kw_only=True)
     read_timeout: float = attrs.field(default=300.0, kw_only=True)
+    storage_options: Optional[Dict[str, str]] = attrs.field(default=None, kw_only=True)

     @functools.cached_property
     def session(self) -> requests.Session:
@@ -92,6 +94,18 @@ class RestfulLanceDBClient:
             headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
         if self.host_override:
             headers["x-lancedb-database"] = self.db_name
+        if self.storage_options:
+            if self.storage_options.get("account_name") is not None:
+                headers["x-azure-storage-account-name"] = self.storage_options[
+                    "account_name"
+                ]
+            if self.storage_options.get("azure_storage_account_name") is not None:
+                headers["x-azure-storage-account-name"] = self.storage_options[
+                    "azure_storage_account_name"
+                ]
+        if self.db_prefix:
+            headers["x-lancedb-database-prefix"] = self.db_prefix
+
         return headers

     @staticmethod
@@ -158,6 +172,7 @@ class RestfulLanceDBClient:
         headers["content-type"] = content_type
+        if request_id is not None:
             headers["x-request-id"] = request_id

         with self.session.post(
             urljoin(self.url, uri),
             headers=headers,
@@ -245,7 +260,6 @@ def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
         connect=connect_retries,
         read=read_retries,
         backoff_factor=backoff_factor,
-        backoff_jitter=backoff_jitter,
         status_forcelist=statuses,
         allowed_methods=methods,
     )
@@ -15,7 +15,7 @@ import inspect
 import logging
 import uuid
 from concurrent.futures import ThreadPoolExecutor
-from typing import Iterable, List, Optional, Union
+from typing import Dict, Iterable, List, Optional, Union
 from urllib.parse import urlparse

 from cachetools import TTLCache
@@ -44,20 +44,25 @@ class RemoteDBConnection(DBConnection):
         request_thread_pool: Optional[ThreadPoolExecutor] = None,
         connection_timeout: float = 120.0,
         read_timeout: float = 300.0,
+        storage_options: Optional[Dict[str, str]] = None,
     ):
         """Connect to a remote LanceDB database."""
         parsed = urlparse(db_url)
         if parsed.scheme != "db":
             raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
         self.db_name = parsed.netloc
+        prefix = parsed.path.lstrip("/")
+        self.db_prefix = None if not prefix else prefix
         self.api_key = api_key
         self._client = RestfulLanceDBClient(
             self.db_name,
             region,
             api_key,
             host_override,
+            self.db_prefix,
             connection_timeout=connection_timeout,
             read_timeout=read_timeout,
+            storage_options=storage_options,
         )
         self._request_thread_pool = request_thread_pool
         self._table_cache = TTLCache(maxsize=10000, ttl=300)
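Usage sketch (illustrative, not part of the patch): the path component of the `db://` URI becomes `db_prefix`, and `storage_options` is handed to the REST client, which turns them into the `x-lancedb-database-prefix` and `x-azure-storage-account-name` headers shown above. The database name, prefix, API key, and storage account are placeholders.

```python
import lancedb

db = lancedb.connect(
    "db://example-db/example-prefix",  # netloc -> db_name, path -> db_prefix
    api_key="sk-placeholder",
    region="us-east-1",
    storage_options={"azure_storage_account_name": "exampleaccount"},
)
```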
@@ -15,13 +15,14 @@ import logging
 import uuid
 from concurrent.futures import Future
 from functools import cached_property
-from typing import Dict, Iterable, Optional, Union
+from typing import Dict, Iterable, Optional, Union, Literal

 import pyarrow as pa
 from lance import json_to_schema

 from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
 from lancedb.merge import LanceMergeInsertBuilder
+from lancedb.query import LanceQueryBuilder

 from ..query import LanceVectorQueryBuilder
 from ..table import Query, Table, _sanitize_data
@@ -81,6 +82,7 @@ class RemoteTable(Table):
     def create_scalar_index(
         self,
         column: str,
+        index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
     ):
         """Creates a scalar index
         Parameters
@@ -89,8 +91,6 @@ class RemoteTable(Table):
             The column to be indexed. Must be a boolean, integer, float,
             or string column.
         """
-        index_type = "scalar"
-
         data = {
             "column": column,
             "index_type": index_type,
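Usage sketch (illustrative, not part of the patch): with the hard-coded `index_type = "scalar"` removed, the remote `create_scalar_index` forwards the caller's choice. Assumes `remote_tbl` is a `RemoteTable` with these columns.

```python
remote_tbl.create_scalar_index("user_id", index_type="BTREE")    # high-cardinality column
remote_tbl.create_scalar_index("category", index_type="BITMAP")  # low-cardinality column
```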
@@ -228,10 +228,21 @@ class RemoteTable(Table):
             content_type=ARROW_STREAM_CONTENT_TYPE,
         )

+    def query(
+        self,
+        query: Union[VEC, str] = None,
+        query_type: str = "vector",
+        vector_column_name: Optional[str] = None,
+        fast_search: bool = False,
+    ) -> LanceVectorQueryBuilder:
+        return self.search(query, query_type, vector_column_name, fast_search)
+
     def search(
         self,
-        query: Union[VEC, str],
+        query: Union[VEC, str] = None,
         query_type: str = "vector",
         vector_column_name: Optional[str] = None,
+        fast_search: bool = False,
     ) -> LanceVectorQueryBuilder:
         """Create a search query to find the nearest neighbors
         of the given query vector. We currently support [vector search][search]
@@ -278,6 +289,11 @@ class RemoteTable(Table):
         - If the table has multiple vector columns then the *vector_column_name*
           needs to be specified. Otherwise, an error is raised.

+        fast_search: bool, optional
+            Skip a flat search of unindexed data. This may improve
+            search performance but search results will not include unindexed data.
+
+            - *default False*.
         Returns
         -------
         LanceQueryBuilder
@@ -293,7 +309,14 @@ class RemoteTable(Table):
         """
         if vector_column_name is None:
             vector_column_name = inf_vector_column_query(self.schema)
-        return LanceVectorQueryBuilder(self, query, vector_column_name)
+
+        return LanceQueryBuilder.create(
+            self,
+            query,
+            query_type,
+            vector_column_name=vector_column_name,
+            fast_search=fast_search,
+        )

     def _execute_query(
         self, query: Query, batch_size: Optional[int] = None
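Usage sketch (illustrative, not part of the patch): `RemoteTable.query()` is now a thin alias for `search()`, and both accept `fast_search`; rows not yet covered by the vector index are skipped when it is set.

```python
hits = (
    remote_tbl.query([0.1, 0.2], fast_search=True)
    .select(["id"])
    .where("id > 5")
    .limit(10)
    .to_arrow()
)
```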
@@ -21,6 +21,7 @@ class FakeLanceDBClient:
         pass

     def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
+        print(f"{query=}")
         assert table_name == "test"
         t = pa.schema([]).empty_table()
         return VectorQueryResult(t)
@@ -39,3 +40,21 @@ def test_remote_db():
     table = conn["test"]
     table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
     table.search([1.0, 2.0]).to_pandas()
+
+
+def test_empty_query_with_filter():
+    conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
+    setattr(conn, "_client", FakeLanceDBClient())
+
+    table = conn["test"]
+    table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
+    print(table.query().select(["vector"]).where("foo == bar").to_arrow())
+
+
+def test_fast_search_query_with_filter():
+    conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
+    setattr(conn, "_client", FakeLanceDBClient())
+
+    table = conn["test"]
+    table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
+    print(table.query([0, 0], fast_search=True).select(["vector"]).where("foo == bar").to_arrow())
@@ -735,7 +735,7 @@ def test_create_scalar_index(db):
     indices = table.to_lance().list_indices()
     assert len(indices) == 1
     scalar_index = indices[0]
-    assert scalar_index["type"] == "Scalar"
+    assert scalar_index["type"] == "BTree"

     # Confirm that prefiltering still works with the scalar index column
     results = table.search().where("x = 'c'").to_arrow()