feat: add to_batches API #805 (#1048)

SDK: Python

Description
Exposes the PyArrow batch API during query execution. This is relevant when there
is no vector search query, the dataset is large, and the filtered result is
larger than memory.
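
A minimal usage sketch of the new API (the database path, table name, and column names here are hypothetical; the builder methods match those in this change):

```python
import lancedb

db = lancedb.connect("./data/sample-db")  # hypothetical database path
tbl = db.open_table("my_table")           # hypothetical table name

# Stream results in bounded-memory chunks instead of materializing
# one large pa.Table via to_arrow().
reader = (
    tbl.search([0.0, 0.0])         # returns a LanceVectorQueryBuilder
    .select(["id", "vector"])      # hypothetical columns
    .limit(1000)
    .to_batches(100)               # pa.RecordBatchReader, up to 100 rows per batch
)
for batch in reader:               # each item is a pa.RecordBatch
    print(batch.num_rows)
```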

---------

Co-authored-by: Ishani Ghose <isghose@amazon.com>
Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
Author: Ishani Ghose
Date: 2024-03-20 13:38:06 -07:00
Committed by: Weston Pace
Parent: 968c62cb8f
Commit: 0838e12b30

5 changed files with 81 additions and 21 deletions


@@ -16,7 +16,16 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Type, Union
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 
 import deprecation
 import numpy as np
@@ -515,6 +524,21 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         and also the "_distance" column which is the distance between the query
         vector and the returned vectors.
         """
+        return self.to_batches().read_all()
+
+    def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
+        """
+        Execute the query and return the result as a RecordBatchReader object.
+
+        Parameters
+        ----------
+        batch_size: int
+            The maximum number of selected records in a RecordBatch object.
+
+        Returns
+        -------
+        pa.RecordBatchReader
+        """
         vector = self._query if isinstance(self._query, list) else self._query.tolist()
         if isinstance(vector[0], np.ndarray):
             vector = [v.tolist() for v in vector]
@@ -530,9 +554,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
             vector_column=self._vector_column,
             with_row_id=self._with_row_id,
         )
-        result_set = self._table._execute_query(query)
+        result_set = self._table._execute_query(query, batch_size)
         if self._reranker is not None:
-            result_set = self._reranker.rerank_vector(self._str_query, result_set)
+            rs_table = result_set.read_all()
+            result_set = self._reranker.rerank_vector(self._str_query, rs_table)
+            # convert result_set back to RecordBatchReader
+            result_set = pa.RecordBatchReader.from_batches(
+                result_set.schema, result_set.to_batches()
+            )
         return result_set
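
Note that when a reranker is attached, the stream is materialized with read_all(), reranked as a pa.Table, and then wrapped back into a reader, so the memory-bounded streaming benefit is lost on that path. A standalone sketch of that round-trip in plain PyArrow:

```python
import pyarrow as pa

tbl = pa.table({"id": [1, 2, 3]})
# pa.Table -> pa.RecordBatchReader, mirroring the conversion above
reader = pa.RecordBatchReader.from_batches(tbl.schema, tbl.to_batches())
assert reader.read_all().num_rows == 3
```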


@@ -295,7 +295,9 @@ class RemoteTable(Table):
             vector_column_name = inf_vector_column_query(self.schema)
         return LanceVectorQueryBuilder(self, query, vector_column_name)
 
-    def _execute_query(self, query: Query) -> pa.Table:
+    def _execute_query(
+        self, query: Query, batch_size: Optional[int] = None
+    ) -> pa.RecordBatchReader:
         if (
             query.vector is not None
             and len(query.vector) > 0
@@ -321,13 +323,12 @@ class RemoteTable(Table):
                     q = query.copy()
                     q.vector = v
                     results.append(submit(self._name, q))
                 return pa.concat_tables(
                     [add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
-                )
+                ).to_reader()
         else:
             result = self._conn._client.query(self._name, query)
-            return result.to_arrow()
+            return result.to_arrow().to_reader()
 
     def _do_merge(
         self,

@@ -567,7 +567,9 @@ class Table(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def _execute_query(self, query: Query) -> pa.Table:
+    def _execute_query(
+        self, query: Query, batch_size: Optional[int] = None
+    ) -> pa.RecordBatchReader:
         pass
 
     @abstractmethod
@@ -1588,10 +1590,11 @@ class LanceTable(Table):
         self._dataset_mut.update(values_sql, where)
 
-    def _execute_query(self, query: Query) -> pa.Table:
+    def _execute_query(
+        self, query: Query, batch_size: Optional[int] = None
+    ) -> pa.RecordBatchReader:
         ds = self.to_lance()
-        return ds.to_table(
+        return ds.scanner(
             columns=query.columns,
             filter=query.filter,
             prefilter=query.prefilter,
@@ -1604,7 +1607,8 @@ class LanceTable(Table):
                 "refine_factor": query.refine_factor,
             },
             with_row_id=query.with_row_id,
-        )
+            batch_size=batch_size,
+        ).to_reader()
 
     def _do_merge(
         self,
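
Locally, swapping ds.to_table(...) for ds.scanner(...).to_reader() is what actually streams. A minimal sketch against a Lance dataset (the path and filter are hypothetical; the scanner arguments are the ones used above):

```python
import lance

ds = lance.dataset("./my_table.lance")  # hypothetical dataset path
# Stream matching rows batch by batch instead of one big table
reader = ds.scanner(
    columns=["id"],
    filter="id > 0",   # hypothetical filter
    batch_size=1024,
).to_reader()
for batch in reader:
    print(batch.num_rows)
```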
@@ -2153,13 +2157,17 @@ class AsyncTable:
     ) -> AsyncVectorQuery:
         """
         Search the table with a given query vector.
 
         This is a convenience method for preparing a vector query and
         is the same thing as calling `nearestTo` on the builder returned
         by `query`. See [nearest_to][AsyncQuery.nearest_to] for more details.
         """
         return self.query().nearest_to(query_vector)
 
+    async def _execute_query(
+        self, query: Query, batch_size: Optional[int] = None
+    ) -> pa.RecordBatchReader:
+        pass
+
     async def _do_merge(
         self,
         merge: LanceMergeInsertBuilder,


@@ -13,6 +13,7 @@
 import unittest.mock as mock
 from datetime import timedelta
+from typing import Optional
 
 import lance
 import lancedb
@@ -35,9 +36,9 @@ class MockTable:
     def to_lance(self):
         return lance.dataset(self.uri)
 
-    def _execute_query(self, query):
+    def _execute_query(self, query, batch_size: Optional[int] = None):
         ds = self.to_lance()
-        return ds.to_table(
+        return ds.scanner(
             columns=query.columns,
             filter=query.filter,
             prefilter=query.prefilter,
@@ -49,7 +50,8 @@ class MockTable:
                 "nprobes": query.nprobes,
                 "refine_factor": query.refine_factor,
             },
-        )
+            batch_size=batch_size,
+        ).to_reader()
 
 
 @pytest.fixture
@@ -115,6 +117,25 @@ def test_query_builder(table):
     assert all(np.array(rs[0]["vector"]) == [1, 2])
 
 
+def test_query_builder_batches(table):
+    rs = (
+        LanceVectorQueryBuilder(table, [0, 0], "vector")
+        .limit(2)
+        .select(["id", "vector"])
+        .to_batches(1)
+    )
+    rs_list = []
+    for item in rs:
+        rs_list.append(item)
+        assert isinstance(item, pa.RecordBatch)
+    assert len(rs_list) == 1
+    assert len(rs_list[0]["id"]) == 2
+    assert all(rs_list[0].to_pandas()["vector"][0] == [1.0, 2.0])
+    assert rs_list[0].to_pandas()["id"][0] == 1
+    assert all(rs_list[0].to_pandas()["vector"][1] == [3.0, 4.0])
+    assert rs_list[0].to_pandas()["id"][1] == 2
+
+
 def test_dynamic_projection(table):
     rs = (
         LanceVectorQueryBuilder(table, [0, 0], "vector")
@@ -199,7 +220,8 @@ def test_query_builder_with_different_vector_column():
             nprobes=20,
             refine_factor=None,
             vector_column="foo_vector",
-        )
+        ),
+        None,
     )


@@ -163,7 +163,7 @@ def test_cohere_reranker(tmp_path):
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
     result_explicit = (
         table.search(query_vector)
-        .rerank(reranker=reranker, query=query)
+        .rerank(reranker=reranker, query_string=query)
         .limit(30)
         .to_arrow()
     )
@@ -225,7 +225,7 @@ def test_cross_encoder_reranker(tmp_path):
     result_explicit = (
         table.search(query_vector)
-        .rerank(reranker=reranker, query=query)
+        .rerank(reranker=reranker, query_string=query)
         .limit(30)
         .to_arrow()
     )
@@ -286,7 +286,7 @@ def test_colbert_reranker(tmp_path):
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
     result_explicit = (
         table.search(query_vector)
-        .rerank(reranker=reranker, query=query)
+        .rerank(reranker=reranker, query_string=query)
         .limit(30)
         .to_arrow()
     )
@@ -351,7 +351,7 @@ def test_openai_reranker(tmp_path):
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
     result_explicit = (
         table.search(query_vector)
-        .rerank(reranker=reranker, query=query)
+        .rerank(reranker=reranker, query_string=query)
        .limit(30)
         .to_arrow()
     )
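
These test updates track a rename of the rerank keyword from query to query_string. A usage sketch under that API (the table, query vector, reranker, and query text are assumed to exist, mirroring the tests above):

```python
# Assumes `table`, `query_vector`, and `reranker` are set up as in the tests.
result = (
    table.search(query_vector)
    .rerank(reranker=reranker, query_string="hello world")  # hypothetical query text
    .limit(30)
    .to_arrow()
)
```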