Mirror of https://github.com/lancedb/lancedb.git
SDK Python: exposes the pyarrow batch API during query execution. This is relevant when there is no vector search query, the dataset is large, and the filtered result is larger than memory.

Co-authored-by: Ishani Ghose <isghose@amazon.com>
Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
Committed by Weston Pace
parent 968c62cb8f
commit 0838e12b30
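A minimal usage sketch of the batch API this commit adds (the database path, table name, and process() helper are illustrative, not part of the change):

import lancedb

db = lancedb.connect("./data/sample-lancedb")  # hypothetical local database
table = db.open_table("my_table")              # hypothetical table

# Stream the result instead of materializing it: the reader yields
# pyarrow.RecordBatch objects, so a filtered result larger than memory
# can be processed one batch at a time. batch_size is positional-only.
reader = table.search([0.0, 0.0]).limit(1_000_000).to_batches(1024)
for batch in reader:
    process(batch)  # placeholder for per-batch work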
@@ -16,7 +16,16 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Type, Union
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)

 import deprecation
 import numpy as np
@@ -515,6 +524,21 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
         and also the "_distance" column which is the distance between the query
         vector and the returned vectors.
         """
+        return self.to_batches().read_all()
+
+    def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
+        """
+        Execute the query and return the result as a RecordBatchReader object.
+
+        Parameters
+        ----------
+        batch_size: int
+            The maximum number of selected records in a RecordBatch object.
+
+        Returns
+        -------
+        pa.RecordBatchReader
+        """
         vector = self._query if isinstance(self._query, list) else self._query.tolist()
         if isinstance(vector[0], np.ndarray):
             vector = [v.tolist() for v in vector]
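After this change, to_arrow() is thin sugar over the new method: it drains the reader with read_all(). A self-contained pyarrow sketch of that relationship (the data is illustrative):

import pyarrow as pa

# A RecordBatchReader streams batches lazily; read_all() materializes
# everything into a single pa.Table, which is exactly what to_arrow()
# now does with the reader returned by to_batches().
batches = [pa.record_batch({"id": [1, 2]}), pa.record_batch({"id": [3, 4]})]
reader = pa.RecordBatchReader.from_batches(batches[0].schema, iter(batches))
tbl = reader.read_all()
assert tbl.num_rows == 4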
@@ -530,9 +554,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
             vector_column=self._vector_column,
             with_row_id=self._with_row_id,
         )
-        result_set = self._table._execute_query(query)
+        result_set = self._table._execute_query(query, batch_size)
         if self._reranker is not None:
-            result_set = self._reranker.rerank_vector(self._str_query, result_set)
+            rs_table = result_set.read_all()
+            result_set = self._reranker.rerank_vector(self._str_query, rs_table)
+            # convert result_set back to RecordBatchReader
+            result_set = pa.RecordBatchReader.from_batches(
+                result_set.schema, result_set.to_batches()
+            )

         return result_set
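Rerankers operate on a materialized pa.Table, so the reader is drained first and the reranked table is wrapped back into a reader to keep the return type uniform. The rewrap pattern in isolation (a sketch, not the library's API):

import pyarrow as pa

def rewrap(tbl: pa.Table) -> pa.RecordBatchReader:
    # pa.Table.to_batches() yields the table's chunks as RecordBatches;
    # from_batches() turns them back into a streaming reader.
    return pa.RecordBatchReader.from_batches(tbl.schema, tbl.to_batches())

One trade-off worth noting: when a reranker is attached, the whole result is still materialized in memory, so the streaming benefit applies only to the unreranked path.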
@@ -295,7 +295,9 @@ class RemoteTable(Table):
         vector_column_name = inf_vector_column_query(self.schema)
         return LanceVectorQueryBuilder(self, query, vector_column_name)

-    def _execute_query(self, query: Query) -> pa.Table:
+    def _execute_query(
+        self, query: Query, batch_size: Optional[int] = None
+    ) -> pa.RecordBatchReader:
         if (
             query.vector is not None
             and len(query.vector) > 0
@@ -321,13 +323,12 @@ class RemoteTable(Table):
                 q = query.copy()
                 q.vector = v
                 results.append(submit(self._name, q))

             return pa.concat_tables(
                 [add_index(r.result().to_arrow(), i) for i, r in enumerate(results)]
-            )
+            ).to_reader()
         else:
             result = self._conn._client.query(self._name, query)
-            return result.to_arrow()
+            return result.to_arrow().to_reader()

     def _do_merge(
         self,
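The multi-vector branch unions the per-query tables and converts the result with pa.Table.to_reader(). A runnable pyarrow sketch of that conversion (column names are illustrative; the real code adds the index column via add_index):

import pyarrow as pa

t1 = pa.table({"id": [1, 2], "query_index": [0, 0]})  # illustrative columns
t2 = pa.table({"id": [3, 4], "query_index": [1, 1]})

# concat_tables() unions the per-vector results; to_reader() exposes the
# combined table through the same RecordBatchReader interface as the
# streaming path, without copying the underlying chunks.
reader = pa.concat_tables([t1, t2]).to_reader()
assert reader.read_all().num_rows == 4

Note that the remote path only wraps an already-materialized table; the signature change keeps the interface consistent rather than making the remote query itself streaming.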
@@ -567,7 +567,9 @@ class Table(ABC):
         raise NotImplementedError

     @abstractmethod
-    def _execute_query(self, query: Query) -> pa.Table:
+    def _execute_query(
+        self, query: Query, batch_size: Optional[int] = None
+    ) -> pa.RecordBatchReader:
         pass

     @abstractmethod
@@ -1588,10 +1590,11 @@ class LanceTable(Table):

         self._dataset_mut.update(values_sql, where)

-    def _execute_query(self, query: Query) -> pa.Table:
+    def _execute_query(
+        self, query: Query, batch_size: Optional[int] = None
+    ) -> pa.RecordBatchReader:
         ds = self.to_lance()
-        return ds.to_table(
+        return ds.scanner(
             columns=query.columns,
             filter=query.filter,
             prefilter=query.prefilter,
@@ -1604,7 +1607,8 @@ class LanceTable(Table):
                 "refine_factor": query.refine_factor,
             },
             with_row_id=query.with_row_id,
-        )
+            batch_size=batch_size,
+        ).to_reader()

     def _do_merge(
         self,
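The switch from ds.to_table(...) to ds.scanner(...).to_reader() is where the streaming actually happens: a Lance scanner plans the read but pulls batches on demand. A minimal sketch against a local Lance dataset (path and filter are illustrative):

import lance

ds = lance.dataset("./data/my_table.lance")  # hypothetical dataset path

# to_table() would materialize every matching row up front; the scanner
# streams them in batches sized by batch_size.
reader = ds.scanner(filter="id > 100", batch_size=1024).to_reader()
for batch in reader:
    ...  # each batch is a pyarrow.RecordBatch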
@@ -2153,13 +2157,17 @@ class AsyncTable:
     ) -> AsyncVectorQuery:
         """
         Search the table with a given query vector.

         This is a convenience method for preparing a vector query and
         is the same thing as calling `nearestTo` on the builder returned
         by `query`. See [nearest_to][AsyncQuery.nearest_to] for more details.
         """
         return self.query().nearest_to(query_vector)

+    async def _execute_query(
+        self, query: Query, batch_size: Optional[int] = None
+    ) -> pa.RecordBatchReader:
+        pass
+
     async def _do_merge(
         self,
         merge: LanceMergeInsertBuilder,
@@ -13,6 +13,7 @@

 import unittest.mock as mock
 from datetime import timedelta
+from typing import Optional

 import lance
 import lancedb
@@ -35,9 +36,9 @@ class MockTable:
     def to_lance(self):
         return lance.dataset(self.uri)

-    def _execute_query(self, query):
+    def _execute_query(self, query, batch_size: Optional[int] = None):
         ds = self.to_lance()
-        return ds.to_table(
+        return ds.scanner(
             columns=query.columns,
             filter=query.filter,
             prefilter=query.prefilter,
@@ -49,7 +50,8 @@ class MockTable:
                 "nprobes": query.nprobes,
                 "refine_factor": query.refine_factor,
             },
-        )
+            batch_size=batch_size,
+        ).to_reader()


 @pytest.fixture
@@ -115,6 +117,25 @@ def test_query_builder(table):
     assert all(np.array(rs[0]["vector"]) == [1, 2])


+def test_query_builder_batches(table):
+    rs = (
+        LanceVectorQueryBuilder(table, [0, 0], "vector")
+        .limit(2)
+        .select(["id", "vector"])
+        .to_batches(1)
+    )
+    rs_list = []
+    for item in rs:
+        rs_list.append(item)
+        assert isinstance(item, pa.RecordBatch)
+    assert len(rs_list) == 1
+    assert len(rs_list[0]["id"]) == 2
+    assert all(rs_list[0].to_pandas()["vector"][0] == [1.0, 2.0])
+    assert rs_list[0].to_pandas()["id"][0] == 1
+    assert all(rs_list[0].to_pandas()["vector"][1] == [3.0, 4.0])
+    assert rs_list[0].to_pandas()["id"][1] == 2
+
+
 def test_dynamic_projection(table):
     rs = (
         LanceVectorQueryBuilder(table, [0, 0], "vector")
@@ -199,7 +220,8 @@ def test_query_builder_with_different_vector_column():
                 nprobes=20,
                 refine_factor=None,
                 vector_column="foo_vector",
-            )
+            ),
+            None,
         )

@@ -163,7 +163,7 @@ def test_cohere_reranker(tmp_path):
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
     result_explicit = (
         table.search(query_vector)
-        .rerank(reranker=reranker, query=query)
+        .rerank(reranker=reranker, query_string=query)
         .limit(30)
         .to_arrow()
     )
@@ -225,7 +225,7 @@ def test_cross_encoder_reranker(tmp_path):

     result_explicit = (
         table.search(query_vector)
-        .rerank(reranker=reranker, query=query)
+        .rerank(reranker=reranker, query_string=query)
         .limit(30)
         .to_arrow()
     )
@@ -286,7 +286,7 @@ def test_colbert_reranker(tmp_path):
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
     result_explicit = (
         table.search(query_vector)
-        .rerank(reranker=reranker, query=query)
+        .rerank(reranker=reranker, query_string=query)
         .limit(30)
         .to_arrow()
     )
@@ -351,7 +351,7 @@ def test_openai_reranker(tmp_path):
     assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
     result_explicit = (
         table.search(query_vector)
-        .rerank(reranker=reranker, query=query)
+        .rerank(reranker=reranker, query_string=query)
         .limit(30)
         .to_arrow()
     )
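The reranker tests all track the same rename: the text-query keyword on rerank() is now query_string, which appears to disambiguate the string query from the vector being searched. The call shape, as a sketch (table, reranker, query_vector, and query are assumed to exist as in the tests):

result = (
    table.search(query_vector)
    .rerank(reranker=reranker, query_string=query)
    .limit(30)
    .to_arrow()
)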