From 495c335831c02b6cb413589366b9b44bb7af53a3 Mon Sep 17 00:00:00 2001 From: Ryan Green Date: Fri, 20 Dec 2024 09:43:39 -0600 Subject: [PATCH] Fix fast_search --- python/python/lancedb/query.py | 2 ++ python/python/lancedb/remote/client.py | 1 + python/python/tests/test_remote_db.py | 10 ++++++++++ 3 files changed, 13 insertions(+) diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index c45346c1..f729902e 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -117,6 +117,8 @@ class Query(pydantic.BaseModel): with_row_id: bool = False + fast_search: bool = False + class LanceQueryBuilder(ABC): """An abstract query builder. Subclasses are defined for vector search, diff --git a/python/python/lancedb/remote/client.py b/python/python/lancedb/remote/client.py index 6e3abde4..217cd795 100644 --- a/python/python/lancedb/remote/client.py +++ b/python/python/lancedb/remote/client.py @@ -172,6 +172,7 @@ class RestfulLanceDBClient: headers["content-type"] = content_type if request_id is not None: headers["x-request-id"] = request_id + with self.session.post( urljoin(self.url, uri), headers=headers, diff --git a/python/python/tests/test_remote_db.py b/python/python/tests/test_remote_db.py index 92b82b85..e321434e 100644 --- a/python/python/tests/test_remote_db.py +++ b/python/python/tests/test_remote_db.py @@ -21,6 +21,7 @@ class FakeLanceDBClient: pass def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult: + print(f"{query=}") assert table_name == "test" t = pa.schema([]).empty_table() return VectorQueryResult(t) @@ -48,3 +49,12 @@ def test_empty_query_with_filter(): table = conn["test"] table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))]) print(table.query().select(["vector"]).where("foo == bar").to_arrow()) + + +def test_fast_search_query_with_filter(): + conn = lancedb.connect("db://client-will-be-injected", api_key="fake") + setattr(conn, "_client", FakeLanceDBClient()) + + table = conn["test"] + table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))]) + print(table.query([0, 0], fast_search=True).select(["vector"]).where("foo == bar").to_arrow())