Compare commits

...

15 Commits

Author SHA1 Message Date
Lance Release
7a3ef68306 Bump version: 0.9.0-beta.3 → 0.9.0-beta.4 2024-12-20 16:02:53 +00:00
Ryan Green
43952e01d7 bump version 2024-12-20 09:44:46 -06:00
Ryan Green
495c335831 Fix fast_search 2024-12-20 09:43:39 -06:00
Ryan Green
77707db543 Backport fast_search and empty query builder for remote table 2024-12-20 09:21:05 -06:00
Ryan Green
d6d7ad3b06 bump version 2024-12-18 10:21:04 -06:00
Ryan Green
e58d64c286 Remove unsupported Retry params 2024-12-18 10:08:38 -06:00
Ryan Green
76cbd18c46 bump version 2024-12-18 09:38:36 -06:00
Ryan Green
4abb38ac70 bump version 2024-12-18 09:37:58 -06:00
Ryan Green
cc7bc5011d Merge remote-tracking branch 'origin/python-v0.9.0-patch' into python-v0.9.0-patch
# Conflicts:
#	python/pyproject.toml
2024-12-18 08:59:35 -06:00
Ryan Green
8193183304 override urllib3 version 2024-12-18 08:59:24 -06:00
Ryan Green
cf28b58b7d override urllib3 version 2024-12-18 08:58:41 -06:00
Lance Release
e3b7ee47b9 Bump version: 0.9.0 → 0.9.0-final.1 2024-12-13 01:16:24 +00:00
Lu Qiu
97c9c906e4 Fix version test 2024-12-12 17:10:07 -08:00
Lu Qiu
358f86b9c6 fix 2024-12-12 16:44:24 -08:00
Lu Qiu
5489e215a3 Support storage options and folder prefix 2024-12-12 16:17:34 -08:00
10 changed files with 278 additions and 48 deletions

View File

@@ -1,5 +1,5 @@
[tool.bumpversion] [tool.bumpversion]
current_version = "0.9.0" current_version = "0.9.0-beta.4"
parse = """(?x) parse = """(?x)
(?P<major>0|[1-9]\\d*)\\. (?P<major>0|[1-9]\\d*)\\.
(?P<minor>0|[1-9]\\d*)\\. (?P<minor>0|[1-9]\\d*)\\.

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "lancedb-python" name = "lancedb-python"
version = "0.9.0" version = "0.9.0-beta.4"
edition.workspace = true edition.workspace = true
description = "Python bindings for LanceDB" description = "Python bindings for LanceDB"
license.workspace = true license.workspace = true

View File

@@ -13,6 +13,7 @@ dependencies = [
"packaging", "packaging",
"cachetools", "cachetools",
"overrides>=0.7", "overrides>=0.7",
"urllib3==1.26.19"
] ]
description = "lancedb" description = "lancedb"
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }] authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]

View File

@@ -35,6 +35,7 @@ def connect(
host_override: Optional[str] = None, host_override: Optional[str] = None,
read_consistency_interval: Optional[timedelta] = None, read_consistency_interval: Optional[timedelta] = None,
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None, request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
storage_options: Optional[Dict[str, str]] = None,
**kwargs, **kwargs,
) -> DBConnection: ) -> DBConnection:
"""Connect to a LanceDB database. """Connect to a LanceDB database.
@@ -70,6 +71,9 @@ def connect(
executor will be used for making requests. This is for LanceDB Cloud executor will be used for making requests. This is for LanceDB Cloud
only and is only used when making batch requests (i.e., passing in only and is only used when making batch requests (i.e., passing in
multiple queries to the search method at once). multiple queries to the search method at once).
storage_options: dict, optional
Additional options for the storage backend. See available options at
https://lancedb.github.io/lancedb/guides/storage/
Examples Examples
-------- --------
@@ -105,12 +109,16 @@ def connect(
region, region,
host_override, host_override,
request_thread_pool=request_thread_pool, request_thread_pool=request_thread_pool,
storage_options=storage_options,
**kwargs, **kwargs,
) )
if kwargs: if kwargs:
raise ValueError(f"Unknown keyword arguments: {kwargs}") raise ValueError(f"Unknown keyword arguments: {kwargs}")
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval) return LanceDBConnection(
uri,
read_consistency_interval=read_consistency_interval,
)
async def connect_async( async def connect_async(

View File

@@ -117,6 +117,8 @@ class Query(pydantic.BaseModel):
with_row_id: bool = False with_row_id: bool = False
fast_search: bool = False
class LanceQueryBuilder(ABC): class LanceQueryBuilder(ABC):
"""An abstract query builder. Subclasses are defined for vector search, """An abstract query builder. Subclasses are defined for vector search,
@@ -125,12 +127,14 @@ class LanceQueryBuilder(ABC):
@classmethod @classmethod
def create( def create(
cls, cls,
table: "Table", table: "Table",
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]], query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
query_type: str, query_type: str,
vector_column_name: str, vector_column_name: str,
ordering_field_name: str = None, ordering_field_name: Optional[str] = None,
fts_columns: Union[str, List[str]] = [],
fast_search: bool = False,
) -> LanceQueryBuilder: ) -> LanceQueryBuilder:
""" """
Create a query builder based on the given query and query type. Create a query builder based on the given query and query type.
@@ -147,13 +151,18 @@ class LanceQueryBuilder(ABC):
If "auto", the query type is inferred based on the query. If "auto", the query type is inferred based on the query.
vector_column_name: str vector_column_name: str
The name of the vector column to use for vector search. The name of the vector column to use for vector search.
fast_search: bool
Skip flat search of unindexed data.
""" """
if query is None: # Check hybrid search first as it supports empty query pattern
return LanceEmptyQueryBuilder(table)
if query_type == "hybrid": if query_type == "hybrid":
# hybrid fts and vector query # hybrid fts and vector query
return LanceHybridQueryBuilder(table, query, vector_column_name) return LanceHybridQueryBuilder(
table, query, vector_column_name, fts_columns=fts_columns
)
if query is None:
return LanceEmptyQueryBuilder(table)
# remember the string query for reranking purpose # remember the string query for reranking purpose
str_query = query if isinstance(query, str) else None str_query = query if isinstance(query, str) else None
@@ -165,12 +174,17 @@ class LanceQueryBuilder(ABC):
) )
if query_type == "hybrid": if query_type == "hybrid":
return LanceHybridQueryBuilder(table, query, vector_column_name) return LanceHybridQueryBuilder(
table, query, vector_column_name, fts_columns=fts_columns
)
if isinstance(query, str): if isinstance(query, str):
# fts # fts
return LanceFtsQueryBuilder( return LanceFtsQueryBuilder(
table, query, ordering_field_name=ordering_field_name table,
query,
ordering_field_name=ordering_field_name,
fts_columns=fts_columns,
) )
if isinstance(query, list): if isinstance(query, list):
@@ -180,7 +194,9 @@ class LanceQueryBuilder(ABC):
else: else:
raise TypeError(f"Unsupported query type: {type(query)}") raise TypeError(f"Unsupported query type: {type(query)}")
return LanceVectorQueryBuilder(table, query, vector_column_name, str_query) return LanceVectorQueryBuilder(
table, query, vector_column_name, str_query, fast_search
)
@classmethod @classmethod
def _resolve_query(cls, table, query, query_type, vector_column_name): def _resolve_query(cls, table, query, query_type, vector_column_name):
@@ -196,8 +212,6 @@ class LanceQueryBuilder(ABC):
elif query_type == "auto": elif query_type == "auto":
if isinstance(query, (list, np.ndarray)): if isinstance(query, (list, np.ndarray)):
return query, "vector" return query, "vector"
if isinstance(query, tuple):
return query, "hybrid"
else: else:
conf = table.embedding_functions.get(vector_column_name) conf = table.embedding_functions.get(vector_column_name)
if conf is not None: if conf is not None:
@@ -224,9 +238,14 @@ class LanceQueryBuilder(ABC):
def __init__(self, table: "Table"): def __init__(self, table: "Table"):
self._table = table self._table = table
self._limit = 10 self._limit = 10
self._offset = 0
self._columns = None self._columns = None
self._where = None self._where = None
self._prefilter = False
self._with_row_id = False self._with_row_id = False
self._vector = None
self._text = None
self._ef = None
@deprecation.deprecated( @deprecation.deprecated(
deprecated_in="0.3.1", deprecated_in="0.3.1",
@@ -337,11 +356,13 @@ class LanceQueryBuilder(ABC):
---------- ----------
limit: int limit: int
The maximum number of results to return. The maximum number of results to return.
By default the query is limited to the first 10. The default query limit is 10 results.
Call this method and pass 0, a negative value, For ANN/KNN queries, you must specify a limit.
or None to remove the limit. Entering 0, a negative number, or None will reset
*WARNING* if you have a large dataset, removing the limit to the default value of 10.
the limit can potentially result in reading a *WARNING* if you have a large dataset, setting
the limit to a large number, e.g. the table size,
can potentially result in reading a
large amount of data into memory and cause large amount of data into memory and cause
out of memory issues. out of memory issues.
@@ -351,11 +372,33 @@ class LanceQueryBuilder(ABC):
The LanceQueryBuilder object. The LanceQueryBuilder object.
""" """
if limit is None or limit <= 0: if limit is None or limit <= 0:
self._limit = None if isinstance(self, LanceVectorQueryBuilder):
raise ValueError("Limit is required for ANN/KNN queries")
else:
self._limit = None
else: else:
self._limit = limit self._limit = limit
return self return self
def offset(self, offset: int) -> LanceQueryBuilder:
"""Set the offset for the results.
Parameters
----------
offset: int
The offset to start fetching results from.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
if offset is None or offset <= 0:
self._offset = 0
else:
self._offset = offset
return self
def select(self, columns: Union[list[str], dict[str, str]]) -> LanceQueryBuilder: def select(self, columns: Union[list[str], dict[str, str]]) -> LanceQueryBuilder:
"""Set the columns to return. """Set the columns to return.
@@ -417,6 +460,80 @@ class LanceQueryBuilder(ABC):
self._with_row_id = with_row_id self._with_row_id = with_row_id
return self return self
def explain_plan(self, verbose: Optional[bool] = False) -> str:
"""Return the execution plan for this query.
Examples
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", [{"vector": [99, 99]}])
>>> query = [100, 100]
>>> plan = table.search(query).explain_plan(True)
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance]
GlobalLimitExec: skip=0, fetch=10
FilterExec: _distance@2 IS NOT NULL
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
KNNVectorDistance: metric=l2
LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false
Parameters
----------
verbose : bool, default False
Use a verbose output format.
Returns
-------
plan : str
""" # noqa: E501
ds = self._table.to_lance()
return ds.scanner(
nearest={
"column": self._vector_column,
"q": self._query,
"k": self._limit,
"metric": self._metric,
"nprobes": self._nprobes,
"refine_factor": self._refine_factor,
},
prefilter=self._prefilter,
filter=self._str_query,
limit=self._limit,
with_row_id=self._with_row_id,
offset=self._offset,
).explain_plan(verbose)
def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder:
"""Set the vector to search for.
Parameters
----------
vector: np.ndarray or list
The vector to search for.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
raise NotImplementedError
def text(self, text: str) -> LanceQueryBuilder:
"""Set the text to search for.
Parameters
----------
text: str
The text to search for.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
raise NotImplementedError
class LanceVectorQueryBuilder(LanceQueryBuilder): class LanceVectorQueryBuilder(LanceQueryBuilder):
""" """
@@ -440,11 +557,12 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
""" """
def __init__( def __init__(
self, self,
table: "Table", table: "Table",
query: Union[np.ndarray, list, "PIL.Image.Image"], query: Union[np.ndarray, list, "PIL.Image.Image"],
vector_column: str, vector_column: str,
str_query: Optional[str] = None, str_query: Optional[str] = None,
fast_search: bool = False,
): ):
super().__init__(table) super().__init__(table)
self._query = query self._query = query
@@ -455,13 +573,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._prefilter = False self._prefilter = False
self._reranker = None self._reranker = None
self._str_query = str_query self._str_query = str_query
self._fast_search = fast_search
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder: def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceVectorQueryBuilder:
"""Set the distance metric to use. """Set the distance metric to use.
Parameters Parameters
---------- ----------
metric: "L2" or "cosine" metric: "L2" or "cosine" or "dot"
The distance metric to use. By default "L2" is used. The distance metric to use. By default "L2" is used.
Returns Returns
@@ -469,7 +588,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
LanceVectorQueryBuilder LanceVectorQueryBuilder
The LanceQueryBuilder object. The LanceQueryBuilder object.
""" """
self._metric = metric self._metric = metric.lower()
return self return self
def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder: def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
@@ -494,6 +613,28 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
self._nprobes = nprobes self._nprobes = nprobes
return self return self
def ef(self, ef: int) -> LanceVectorQueryBuilder:
"""Set the number of candidates to consider during search.
Higher values will yield better recall (more likely to find vectors if
they exist) at the expense of latency.
This only applies to the HNSW-related index.
The default value is 1.5 * limit.
Parameters
----------
ef: int
The number of candidates to consider during search.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._ef = ef
return self
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder: def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
"""Set the refine factor to use, increasing the number of vectors sampled. """Set the refine factor to use, increasing the number of vectors sampled.
@@ -554,15 +695,11 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
refine_factor=self._refine_factor, refine_factor=self._refine_factor,
vector_column=self._vector_column, vector_column=self._vector_column,
with_row_id=self._with_row_id, with_row_id=self._with_row_id,
offset=self._offset,
fast_search=self._fast_search,
ef=self._ef,
) )
result_set = self._table._execute_query(query, batch_size) result_set = self._table._execute_query(query, batch_size)
if self._reranker is not None:
rs_table = result_set.read_all()
result_set = self._reranker.rerank_vector(self._str_query, rs_table)
# convert result_set back to RecordBatchReader
result_set = pa.RecordBatchReader.from_batches(
result_set.schema, result_set.to_batches()
)
return result_set return result_set
@@ -591,7 +728,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
return self return self
def rerank( def rerank(
self, reranker: Reranker, query_string: Optional[str] = None self, reranker: Reranker, query_string: Optional[str] = None
) -> LanceVectorQueryBuilder: ) -> LanceVectorQueryBuilder:
"""Rerank the results using the specified reranker. """Rerank the results using the specified reranker.
@@ -756,12 +893,34 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
class LanceEmptyQueryBuilder(LanceQueryBuilder): class LanceEmptyQueryBuilder(LanceQueryBuilder):
def to_arrow(self) -> pa.Table: def to_arrow(self) -> pa.Table:
ds = self._table.to_lance() return self.to_batches().read_all()
return ds.to_table(
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
query = Query(
columns=self._columns, columns=self._columns,
filter=self._where, filter=self._where,
limit=self._limit, k=self._limit or 10,
with_row_id=self._with_row_id,
vector=[],
# not actually respected in remote query
offset=self._offset or 0,
) )
return self._table._execute_query(query)
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
"""Rerank the results using the specified reranker.
Parameters
----------
reranker: Reranker
The reranker to use.
Returns
-------
LanceEmptyQueryBuilder
The LanceQueryBuilder object.
"""
raise NotImplementedError("Reranking is not yet supported.")
class LanceHybridQueryBuilder(LanceQueryBuilder): class LanceHybridQueryBuilder(LanceQueryBuilder):

View File

@@ -55,11 +55,13 @@ class RestfulLanceDBClient:
region: str region: str
api_key: Credential api_key: Credential
host_override: Optional[str] = attrs.field(default=None) host_override: Optional[str] = attrs.field(default=None)
db_prefix: Optional[str] = attrs.field(default=None)
closed: bool = attrs.field(default=False, init=False) closed: bool = attrs.field(default=False, init=False)
connection_timeout: float = attrs.field(default=120.0, kw_only=True) connection_timeout: float = attrs.field(default=120.0, kw_only=True)
read_timeout: float = attrs.field(default=300.0, kw_only=True) read_timeout: float = attrs.field(default=300.0, kw_only=True)
storage_options: Optional[Dict[str, str]] = attrs.field(default=None, kw_only=True)
@functools.cached_property @functools.cached_property
def session(self) -> requests.Session: def session(self) -> requests.Session:
@@ -92,6 +94,18 @@ class RestfulLanceDBClient:
headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com" headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
if self.host_override: if self.host_override:
headers["x-lancedb-database"] = self.db_name headers["x-lancedb-database"] = self.db_name
if self.storage_options:
if self.storage_options.get("account_name") is not None:
headers["x-azure-storage-account-name"] = self.storage_options[
"account_name"
]
if self.storage_options.get("azure_storage_account_name") is not None:
headers["x-azure-storage-account-name"] = self.storage_options[
"azure_storage_account_name"
]
if self.db_prefix:
headers["x-lancedb-database-prefix"] = self.db_prefix
return headers return headers
@staticmethod @staticmethod
@@ -158,6 +172,7 @@ class RestfulLanceDBClient:
headers["content-type"] = content_type headers["content-type"] = content_type
if request_id is not None: if request_id is not None:
headers["x-request-id"] = request_id headers["x-request-id"] = request_id
with self.session.post( with self.session.post(
urljoin(self.url, uri), urljoin(self.url, uri),
headers=headers, headers=headers,
@@ -245,7 +260,6 @@ def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
connect=connect_retries, connect=connect_retries,
read=read_retries, read=read_retries,
backoff_factor=backoff_factor, backoff_factor=backoff_factor,
backoff_jitter=backoff_jitter,
status_forcelist=statuses, status_forcelist=statuses,
allowed_methods=methods, allowed_methods=methods,
) )

View File

@@ -15,7 +15,7 @@ import inspect
import logging import logging
import uuid import uuid
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Iterable, List, Optional, Union from typing import Dict, Iterable, List, Optional, Union
from urllib.parse import urlparse from urllib.parse import urlparse
from cachetools import TTLCache from cachetools import TTLCache
@@ -44,20 +44,25 @@ class RemoteDBConnection(DBConnection):
request_thread_pool: Optional[ThreadPoolExecutor] = None, request_thread_pool: Optional[ThreadPoolExecutor] = None,
connection_timeout: float = 120.0, connection_timeout: float = 120.0,
read_timeout: float = 300.0, read_timeout: float = 300.0,
storage_options: Optional[Dict[str, str]] = None,
): ):
"""Connect to a remote LanceDB database.""" """Connect to a remote LanceDB database."""
parsed = urlparse(db_url) parsed = urlparse(db_url)
if parsed.scheme != "db": if parsed.scheme != "db":
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://") raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
self.db_name = parsed.netloc self.db_name = parsed.netloc
prefix = parsed.path.lstrip("/")
self.db_prefix = None if not prefix else prefix
self.api_key = api_key self.api_key = api_key
self._client = RestfulLanceDBClient( self._client = RestfulLanceDBClient(
self.db_name, self.db_name,
region, region,
api_key, api_key,
host_override, host_override,
self.db_prefix,
connection_timeout=connection_timeout, connection_timeout=connection_timeout,
read_timeout=read_timeout, read_timeout=read_timeout,
storage_options=storage_options,
) )
self._request_thread_pool = request_thread_pool self._request_thread_pool = request_thread_pool
self._table_cache = TTLCache(maxsize=10000, ttl=300) self._table_cache = TTLCache(maxsize=10000, ttl=300)

View File

@@ -22,6 +22,7 @@ from lance import json_to_schema
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder from lancedb.merge import LanceMergeInsertBuilder
from lancedb.query import LanceQueryBuilder
from ..query import LanceVectorQueryBuilder from ..query import LanceVectorQueryBuilder
from ..table import Query, Table, _sanitize_data from ..table import Query, Table, _sanitize_data
@@ -228,10 +229,21 @@ class RemoteTable(Table):
content_type=ARROW_STREAM_CONTENT_TYPE, content_type=ARROW_STREAM_CONTENT_TYPE,
) )
def query(
self,
query: Union[VEC, str] = None,
query_type: str = "vector",
vector_column_name: Optional[str] = None,
fast_search: bool = False,
) -> LanceVectorQueryBuilder:
return self.search(query, query_type, vector_column_name, fast_search)
def search( def search(
self, self,
query: Union[VEC, str], query: Union[VEC, str] = None,
query_type: str = "vector",
vector_column_name: Optional[str] = None, vector_column_name: Optional[str] = None,
fast_search: bool = False,
) -> LanceVectorQueryBuilder: ) -> LanceVectorQueryBuilder:
"""Create a search query to find the nearest neighbors """Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search] of the given query vector. We currently support [vector search][search]
@@ -278,6 +290,11 @@ class RemoteTable(Table):
- If the table has multiple vector columns then the *vector_column_name* - If the table has multiple vector columns then the *vector_column_name*
needs to be specified. Otherwise, an error is raised. needs to be specified. Otherwise, an error is raised.
fast_search: bool, optional
Skip a flat search of unindexed data. This may improve
search performance but search results will not include unindexed data.
- *default False*.
Returns Returns
------- -------
LanceQueryBuilder LanceQueryBuilder
@@ -293,7 +310,14 @@ class RemoteTable(Table):
""" """
if vector_column_name is None: if vector_column_name is None:
vector_column_name = inf_vector_column_query(self.schema) vector_column_name = inf_vector_column_query(self.schema)
return LanceVectorQueryBuilder(self, query, vector_column_name)
return LanceQueryBuilder.create(
self,
query,
query_type,
vector_column_name=vector_column_name,
fast_search=fast_search,
)
def _execute_query( def _execute_query(
self, query: Query, batch_size: Optional[int] = None self, query: Query, batch_size: Optional[int] = None

View File

@@ -21,6 +21,7 @@ class FakeLanceDBClient:
pass pass
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
print(f"{query=}")
assert table_name == "test" assert table_name == "test"
t = pa.schema([]).empty_table() t = pa.schema([]).empty_table()
return VectorQueryResult(t) return VectorQueryResult(t)
@@ -39,3 +40,21 @@ def test_remote_db():
table = conn["test"] table = conn["test"]
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))]) table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
table.search([1.0, 2.0]).to_pandas() table.search([1.0, 2.0]).to_pandas()
def test_empty_query_with_filter():
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
setattr(conn, "_client", FakeLanceDBClient())
table = conn["test"]
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
print(table.query().select(["vector"]).where("foo == bar").to_arrow())
def test_fast_search_query_with_filter():
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
setattr(conn, "_client", FakeLanceDBClient())
table = conn["test"]
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
print(table.query([0, 0], fast_search=True).select(["vector"]).where("foo == bar").to_arrow())

View File

@@ -735,7 +735,7 @@ def test_create_scalar_index(db):
indices = table.to_lance().list_indices() indices = table.to_lance().list_indices()
assert len(indices) == 1 assert len(indices) == 1
scalar_index = indices[0] scalar_index = indices[0]
assert scalar_index["type"] == "Scalar" assert scalar_index["type"] == "BTree"
# Confirm that prefiltering still works with the scalar index column # Confirm that prefiltering still works with the scalar index column
results = table.search().where("x = 'c'").to_arrow() results = table.search().where("x = 'c'").to_arrow()