feat: add a pyarrow dataset adapater for LanceDB tables (#1902)

This currently only works for local tables (remote tables cannot be
queried)
This is also exclusive to the sync interface. However, since the pyarrow
dataset interface is synchronous I am not sure if there is much value in
making an async-wrapping variant.

In addition, I added a `to_batches` method to the base query in the sync
API. This already exists in the async API. In the sync API this PR only
adds support for vector queries and scalar queries and not for hybrid or
FTS queries.
This commit is contained in:
Weston Pace
2024-12-03 15:42:54 -08:00
committed by GitHub
parent d8c758513c
commit c998a47e17
5 changed files with 334 additions and 1 deletions

View File

@@ -0,0 +1,248 @@
import logging
from typing import Any, List, Optional, Tuple, Union, Literal
import pyarrow as pa
from ..table import Table
Filter = Union[str, pa.compute.Expression]
Keys = Union[str, List[str]]
JoinType = Literal[
"left semi",
"right semi",
"left anti",
"right anti",
"inner",
"left outer",
"right outer",
"full outer",
]
class PyarrowScannerAdapter(pa.dataset.Scanner):
def __init__(
self,
table: Table,
columns: Optional[List[str]] = None,
filter: Optional[Filter] = None,
batch_size: Optional[int] = None,
batch_readahead: Optional[int] = None,
fragment_readahead: Optional[int] = None,
fragment_scan_options: Optional[Any] = None,
use_threads: bool = True,
memory_pool: Optional[Any] = None,
):
self.table = table
self.columns = columns
self.filter = filter
self.batch_size = batch_size
if batch_readahead is not None:
logging.debug("ignoring batch_readahead which has no lance equivalent")
if fragment_readahead is not None:
logging.debug("ignoring fragment_readahead which has no lance equivalent")
if fragment_scan_options is not None:
raise NotImplementedError("fragment_scan_options not supported")
if use_threads is False:
raise NotImplementedError("use_threads=False not supported")
if memory_pool is not None:
raise NotImplementedError("memory_pool not supported")
def count_rows(self):
return self.table.count_rows(self.filter)
def from_batches(self, **kwargs):
raise NotImplementedError
def from_dataset(self, **kwargs):
raise NotImplementedError
def from_fragment(self, **kwargs):
raise NotImplementedError
def head(self, num_rows: int):
return self.to_reader(limit=num_rows).read_all()
@property
def projected_schema(self):
return self.head(1).schema
def scan_batches(self):
return self.to_reader()
def take(self, indices: List[int]):
raise NotImplementedError
def to_batches(self):
return self.to_reader()
def to_table(self):
return self.to_reader().read_all()
def to_reader(self, *, limit: Optional[int] = None):
query = self.table.search()
# Disable the builtin limit
if limit is None:
num_rows = self.count_rows()
query.limit(num_rows)
elif limit <= 0:
raise ValueError("limit must be positive")
else:
query.limit(limit)
if self.columns is not None:
query = query.select(self.columns)
if self.filter is not None:
query = query.where(self.filter, prefilter=True)
return query.to_batches(batch_size=self.batch_size)
class PyarrowDatasetAdapter(pa.dataset.Dataset):
def __init__(self, table: Table):
self.table = table
def count_rows(self, filter: Optional[Filter] = None):
return self.table.count_rows(filter)
def get_fragments(self, filter: Optional[Filter] = None):
raise NotImplementedError
def head(
self,
num_rows: int,
columns: Optional[List[str]] = None,
filter: Optional[Filter] = None,
batch_size: Optional[int] = None,
batch_readahead: Optional[int] = None,
fragment_readahead: Optional[int] = None,
fragment_scan_options: Optional[Any] = None,
use_threads: bool = True,
memory_pool: Optional[Any] = None,
):
return self.scanner(
columns,
filter,
batch_size,
batch_readahead,
fragment_readahead,
fragment_scan_options,
use_threads,
memory_pool,
).head(num_rows)
def join(
self,
right_dataset: Any,
keys: Keys,
right_keys: Optional[Keys] = None,
join_type: Optional[JoinType] = None,
left_suffix: Optional[str] = None,
right_suffix: Optional[str] = None,
coalesce_keys: bool = True,
use_threads: bool = True,
):
raise NotImplementedError
def join_asof(
self,
right_dataset: Any,
on: str,
by: Keys,
tolerance: int,
right_on: Optional[str] = None,
right_by: Optional[Keys] = None,
):
raise NotImplementedError
@property
def partition_expression(self):
raise NotImplementedError
def replace_schema(self, schema: pa.Schema):
raise NotImplementedError
def scanner(
self,
columns: Optional[List[str]] = None,
filter: Optional[Filter] = None,
batch_size: Optional[int] = None,
batch_readahead: Optional[int] = None,
fragment_readahead: Optional[int] = None,
fragment_scan_options: Optional[Any] = None,
use_threads: bool = True,
memory_pool: Optional[Any] = None,
):
return PyarrowScannerAdapter(
self.table,
columns,
filter,
batch_size,
batch_readahead,
fragment_readahead,
fragment_scan_options,
use_threads,
memory_pool,
)
@property
def schema(self):
return self.table.schema
def sort_by(self, sorting: Union[str, List[Tuple[str, bool]]]):
raise NotImplementedError
def take(
self,
indices: List[int],
columns: Optional[List[str]] = None,
filter: Optional[Filter] = None,
batch_size: Optional[int] = None,
batch_readahead: Optional[int] = None,
fragment_readahead: Optional[int] = None,
fragment_scan_options: Optional[Any] = None,
use_threads: bool = True,
memory_pool: Optional[Any] = None,
):
raise NotImplementedError
def to_batches(
self,
columns: Optional[List[str]] = None,
filter: Optional[Filter] = None,
batch_size: Optional[int] = None,
batch_readahead: Optional[int] = None,
fragment_readahead: Optional[int] = None,
fragment_scan_options: Optional[Any] = None,
use_threads: bool = True,
memory_pool: Optional[Any] = None,
):
return self.scanner(
columns,
filter,
batch_size,
batch_readahead,
fragment_readahead,
fragment_scan_options,
use_threads,
memory_pool,
).to_batches()
def to_table(
self,
columns: Optional[List[str]] = None,
filter: Optional[Filter] = None,
batch_size: Optional[int] = None,
batch_readahead: Optional[int] = None,
fragment_readahead: Optional[int] = None,
fragment_scan_options: Optional[Any] = None,
use_threads: bool = True,
memory_pool: Optional[Any] = None,
):
return self.scanner(
columns,
filter,
batch_size,
batch_readahead,
fragment_readahead,
fragment_scan_options,
use_threads,
memory_pool,
).to_table()

View File

@@ -325,6 +325,14 @@ class LanceQueryBuilder(ABC):
"""
raise NotImplementedError
@abstractmethod
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.Table:
"""
Execute the query and return the results as a pyarrow
[RecordBatchReader](https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatchReader.html)
"""
raise NotImplementedError
def to_list(self) -> List[dict]:
"""
Execute the query and return the results as a list of dictionaries.
@@ -869,6 +877,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
check_reranker_result(results)
return results
def to_batches(self, /, batch_size: Optional[int] = None):
raise NotImplementedError("to_batches on an FTS query")
def tantivy_to_arrow(self) -> pa.Table:
try:
import tantivy
@@ -971,6 +982,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
class LanceEmptyQueryBuilder(LanceQueryBuilder):
def to_arrow(self) -> pa.Table:
return self.to_batches().read_all()
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
query = Query(
columns=self._columns,
filter=self._where,
@@ -980,7 +994,7 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
# not actually respected in remote query
offset=self._offset or 0,
)
return self._table._execute_query(query).read_all()
return self._table._execute_query(query)
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
"""Rerank the results using the specified reranker.
@@ -1135,6 +1149,9 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
results = results.drop(["_rowid"])
return results
def to_batches(self):
raise NotImplementedError("to_batches not yet supported on a hybrid query")
def _rank(self, results: pa.Table, column: str, ascending: bool = True):
if len(results) == 0:
return results

View File

@@ -0,0 +1,21 @@
import duckdb
import pyarrow as pa
import lancedb
from lancedb.integrations.pyarrow import PyarrowDatasetAdapter
def test_basic_query(tmp_path):
data = pa.table({"x": [1, 2, 3, 4], "y": [5, 6, 7, 8]})
conn = lancedb.connect(tmp_path)
tbl = conn.create_table("test", data)
adapter = PyarrowDatasetAdapter(tbl) # noqa: F841
duck_conn = duckdb.connect()
results = duck_conn.sql("SELECT SUM(x) FROM adapter").fetchall()
assert results[0][0] == 10
results = duck_conn.sql("SELECT SUM(y) FROM adapter").fetchall()
assert results[0][0] == 26

View File

@@ -0,0 +1,47 @@
import pyarrow as pa
import lancedb
from lancedb.integrations.pyarrow import PyarrowDatasetAdapter
def test_dataset_adapter(tmp_path):
data = pa.table({"x": [1, 2, 3, 4], "y": [5, 6, 7, 8]})
conn = lancedb.connect(tmp_path)
tbl = conn.create_table("test", data)
adapter = PyarrowDatasetAdapter(tbl)
assert adapter.count_rows() == 4
assert adapter.count_rows("x > 2") == 2
assert adapter.schema == data.schema
assert adapter.head(2) == data.slice(0, 2)
assert adapter.to_table() == data
assert adapter.to_batches().read_all() == data
assert adapter.scanner().to_table() == data
assert adapter.scanner().to_batches().read_all() == data
assert adapter.scanner().projected_schema == data.schema
assert adapter.scanner(columns=["x"]).projected_schema == pa.schema(
[data.schema.field("x")]
)
assert adapter.scanner(columns=["x"]).to_table() == pa.table({"x": [1, 2, 3, 4]})
# Make sure we bypass the limit
data = pa.table({"x": range(100)})
tbl = conn.create_table("test2", data)
adapter = PyarrowDatasetAdapter(tbl)
assert adapter.count_rows() == 100
assert adapter.to_table().num_rows == 100
assert adapter.head(10).num_rows == 10
# Empty table
tbl = conn.create_table("test3", None, schema=pa.schema({"x": pa.int64()}))
adapter = PyarrowDatasetAdapter(tbl)
assert adapter.count_rows() == 0
assert adapter.to_table().num_rows == 0
assert adapter.head(10).num_rows == 0
assert adapter.scanner().projected_schema == pa.schema({"x": pa.int64()})