diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py
index abc577e0..1bae974d 100644
--- a/python/python/lancedb/query.py
+++ b/python/python/lancedb/query.py
@@ -16,7 +16,16 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Type, Union
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 
 import deprecation
 import numpy as np
@@ -515,6 +524,21 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         and also the "_distance" column which is the distance between the query
         vector and the returned vectors.
         """
+        return self.to_batches().read_all()
+
+    def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
+        """
+        Execute the query and return the result as a RecordBatchReader object.
+
+        Parameters
+        ----------
+        batch_size: Optional[int]
+            The maximum number of records to return in each RecordBatch.
+
+        Returns
+        -------
+        pa.RecordBatchReader
+        """
         vector = self._query if isinstance(self._query, list) else self._query.tolist()
         if isinstance(vector[0], np.ndarray):
             vector = [v.tolist() for v in vector]
@@ -530,9 +554,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
             vector_column=self._vector_column,
             with_row_id=self._with_row_id,
         )
-        result_set = self._table._execute_query(query)
+        result_set = self._table._execute_query(query, batch_size)
         if self._reranker is not None:
-            result_set = self._reranker.rerank_vector(self._str_query, result_set)
+            rs_table = result_set.read_all()
+            result_set = self._reranker.rerank_vector(self._str_query, rs_table)
+            # convert result_set back to RecordBatchReader
+            result_set = pa.RecordBatchReader.from_batches(
+                result_set.schema, result_set.to_batches()
+            )
 
         return result_set
 
diff --git a/python/python/lancedb/remote/table.py b/python/python/lancedb/remote/table.py
index cb055d39..65bc450d 100644
--- a/python/python/lancedb/remote/table.py
+++ b/python/python/lancedb/remote/table.py
@@ -295,7 +295,9 @@ class RemoteTable(Table):
             vector_column_name = inf_vector_column_query(self.schema)
         return LanceVectorQueryBuilder(self, query, vector_column_name)
 
-    def _execute_query(self, query: Query) -> pa.Table:
+    def _execute_query(
+        self, query: Query, batch_size: Optional[int] = None
+    ) -> pa.RecordBatchReader:
         if (
             query.vector is not None
             and len(query.vector) > 0
@@ -321,13 +323,12 @@ class RemoteTable(Table):
                 q = query.copy()
                 q.vector = v
                 results.append(submit(self._name, q))
-
             return pa.concat_tables(
                 [add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
-            )
+            ).to_reader()
         else:
             result = self._conn._client.query(self._name, query)
-            return result.to_arrow()
+            return result.to_arrow().to_reader()
 
     def _do_merge(
         self,
diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py
index 75d4a2b0..485d3ef8 100644
--- a/python/python/lancedb/table.py
+++ b/python/python/lancedb/table.py
@@ -567,7 +567,9 @@ class Table(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def _execute_query(self, query: Query) -> pa.Table:
+    def _execute_query(
+        self, query: Query, batch_size: Optional[int] = None
+    ) -> pa.RecordBatchReader:
         pass
 
     @abstractmethod
@@ -1588,10 +1590,11 @@ class LanceTable(Table):
 
         self._dataset_mut.update(values_sql, where)
 
-    def _execute_query(self, query: Query) -> pa.Table:
+    def _execute_query(
+        self, query: Query, batch_size: Optional[int] = None
+    ) -> pa.RecordBatchReader:
         ds = self.to_lance()
-
-        return ds.to_table(
+        return ds.scanner(
             columns=query.columns,
             filter=query.filter,
             prefilter=query.prefilter,
@@ -1604,7 +1607,8 @@ class LanceTable(Table):
                 "refine_factor": query.refine_factor,
             },
             with_row_id=query.with_row_id,
-        )
+            batch_size=batch_size,
+        ).to_reader()
 
     def _do_merge(
         self,
@@ -2153,13 +2157,17 @@ class AsyncTable:
     ) -> AsyncVectorQuery:
         """
         Search the table with a given query vector.
-
         This is a convenience method for preparing a vector query and
         is the same thing as calling `nearestTo` on the builder
         returned by `query`. Seer [nearest_to][AsyncQuery.nearest_to]
         for more details.
         """
         return self.query().nearest_to(query_vector)
 
+    async def _execute_query(
+        self, query: Query, batch_size: Optional[int] = None
+    ) -> pa.RecordBatchReader:
+        pass
+
     async def _do_merge(
         self,
diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py
index 49d207b8..6f047bd3 100644
--- a/python/python/tests/test_query.py
+++ b/python/python/tests/test_query.py
@@ -13,6 +13,7 @@
 import unittest.mock as mock
 from datetime import timedelta
+from typing import Optional
 
 import lance
 import lancedb
@@ -35,9 +36,9 @@ class MockTable:
     def to_lance(self):
         return lance.dataset(self.uri)
 
-    def _execute_query(self, query):
+    def _execute_query(self, query, batch_size: Optional[int] = None):
         ds = self.to_lance()
-        return ds.to_table(
+        return ds.scanner(
             columns=query.columns,
             filter=query.filter,
             prefilter=query.prefilter,
@@ -49,7 +50,8 @@ class MockTable:
                 "nprobes": query.nprobes,
                 "refine_factor": query.refine_factor,
             },
-        )
+            batch_size=batch_size,
+        ).to_reader()
 
 
 @pytest.fixture
@@ -115,6 +117,25 @@ def test_query_builder(table):
     assert all(np.array(rs[0]["vector"]) == [1, 2])
 
 
+def test_query_builder_batches(table):
+    rs = (
+        LanceVectorQueryBuilder(table, [0, 0], "vector")
+        .limit(2)
+        .select(["id", "vector"])
+        .to_batches(1)
+    )
+    rs_list = []
+    for item in rs:
+        rs_list.append(item)
+        assert isinstance(item, pa.RecordBatch)
+    assert len(rs_list) == 1
+    assert len(rs_list[0]["id"]) == 2
+    assert all(rs_list[0].to_pandas()["vector"][0] == [1.0, 2.0])
+    assert rs_list[0].to_pandas()["id"][0] == 1
+    assert all(rs_list[0].to_pandas()["vector"][1] == [3.0, 4.0])
+    assert rs_list[0].to_pandas()["id"][1] == 2
+
+
 def test_dynamic_projection(table):
     rs = (
         LanceVectorQueryBuilder(table, [0, 0], "vector")
@@ -199,7 +220,8 @@ def test_query_builder_with_different_vector_column():
             nprobes=20,
            refine_factor=None,
             vector_column="foo_vector",
-        )
+        ),
+        None,
     )
diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py
index 6465c51a..7775d598 100644
--- a/python/python/tests/test_rerankers.py
+++ b/python/python/tests/test_rerankers.py
@@ -163,7 +163,7 @@ def test_cohere_reranker(tmp_path):
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
     result_explicit = (
         table.search(query_vector)
-        .rerank(reranker=reranker, query=query)
+        .rerank(reranker=reranker, query_string=query)
         .limit(30)
         .to_arrow()
     )
@@ -225,7 +225,7 @@ def test_cross_encoder_reranker(tmp_path):
 
     result_explicit = (
         table.search(query_vector)
-        .rerank(reranker=reranker, query=query)
+        .rerank(reranker=reranker, query_string=query)
         .limit(30)
         .to_arrow()
     )
@@ -286,7 +286,7 @@ def test_colbert_reranker(tmp_path):
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
     result_explicit = (
         table.search(query_vector)
-        .rerank(reranker=reranker, query=query)
+        .rerank(reranker=reranker, query_string=query)
         .limit(30)
         .to_arrow()
     )
@@ -351,7 +351,7 @@ def test_openai_reranker(tmp_path):
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
     result_explicit = (
         table.search(query_vector)
-        .rerank(reranker=reranker, query=query)
+        .rerank(reranker=reranker, query_string=query)
         .limit(30)
         .to_arrow()
     )
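
Usage note (reviewer sketch, not part of the patch): the snippet below shows how the new `to_batches` method added to `LanceVectorQueryBuilder` can be consumed alongside the existing `to_arrow` path. The database path, table name, and sample rows are hypothetical and chosen only to illustrate the streaming behaviour.

import lancedb
import pyarrow as pa

# Hypothetical local database and table, used only for illustration.
db = lancedb.connect("/tmp/lancedb-to-batches-demo")
table = db.create_table(
    "vectors",
    data=[
        {"id": 1, "vector": [1.0, 2.0]},
        {"id": 2, "vector": [3.0, 4.0]},
    ],
)

# to_arrow() still returns a fully materialized pa.Table; with this patch it
# is implemented by draining the reader returned from to_batches().
full = table.search([0.0, 0.0]).limit(2).to_arrow()
assert isinstance(full, pa.Table)

# to_batches(batch_size) returns a pa.RecordBatchReader; batch_size is passed
# through to the underlying Lance scanner and iteration yields RecordBatch
# objects instead of one concatenated table.
reader = (
    table.search([0.0, 0.0])
    .limit(2)
    .select(["id", "vector"])
    .to_batches(1)
)
for batch in reader:
    assert isinstance(batch, pa.RecordBatch)
    print(batch.to_pandas())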