From c998a47e178fe9d840e53688e8e88d8cc14f3867 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Tue, 3 Dec 2024 15:42:54 -0800 Subject: [PATCH] feat: add a pyarrow dataset adapater for LanceDB tables (#1902) This currently only works for local tables (remote tables cannot be queried) This is also exclusive to the sync interface. However, since the pyarrow dataset interface is synchronous I am not sure if there is much value in making an async-wrapping variant. In addition, I added a `to_batches` method to the base query in the sync API. This already exists in the async API. In the sync API this PR only adds support for vector queries and scalar queries and not for hybrid or FTS queries. --- .../python/lancedb/integrations/__init__.py | 0 python/python/lancedb/integrations/pyarrow.py | 248 ++++++++++++++++++ python/python/lancedb/query.py | 19 +- python/python/tests/test_duckdb.py | 21 ++ python/python/tests/test_pyarrow.py | 47 ++++ 5 files changed, 334 insertions(+), 1 deletion(-) create mode 100644 python/python/lancedb/integrations/__init__.py create mode 100644 python/python/lancedb/integrations/pyarrow.py create mode 100644 python/python/tests/test_duckdb.py create mode 100644 python/python/tests/test_pyarrow.py diff --git a/python/python/lancedb/integrations/__init__.py b/python/python/lancedb/integrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/python/lancedb/integrations/pyarrow.py b/python/python/lancedb/integrations/pyarrow.py new file mode 100644 index 00000000..2225ef94 --- /dev/null +++ b/python/python/lancedb/integrations/pyarrow.py @@ -0,0 +1,248 @@ +import logging +from typing import Any, List, Optional, Tuple, Union, Literal + +import pyarrow as pa + +from ..table import Table + +Filter = Union[str, pa.compute.Expression] +Keys = Union[str, List[str]] +JoinType = Literal[ + "left semi", + "right semi", + "left anti", + "right anti", + "inner", + "left outer", + "right outer", + "full outer", +] + + +class PyarrowScannerAdapter(pa.dataset.Scanner): + def __init__( + self, + table: Table, + columns: Optional[List[str]] = None, + filter: Optional[Filter] = None, + batch_size: Optional[int] = None, + batch_readahead: Optional[int] = None, + fragment_readahead: Optional[int] = None, + fragment_scan_options: Optional[Any] = None, + use_threads: bool = True, + memory_pool: Optional[Any] = None, + ): + self.table = table + self.columns = columns + self.filter = filter + self.batch_size = batch_size + if batch_readahead is not None: + logging.debug("ignoring batch_readahead which has no lance equivalent") + if fragment_readahead is not None: + logging.debug("ignoring fragment_readahead which has no lance equivalent") + if fragment_scan_options is not None: + raise NotImplementedError("fragment_scan_options not supported") + if use_threads is False: + raise NotImplementedError("use_threads=False not supported") + if memory_pool is not None: + raise NotImplementedError("memory_pool not supported") + + def count_rows(self): + return self.table.count_rows(self.filter) + + def from_batches(self, **kwargs): + raise NotImplementedError + + def from_dataset(self, **kwargs): + raise NotImplementedError + + def from_fragment(self, **kwargs): + raise NotImplementedError + + def head(self, num_rows: int): + return self.to_reader(limit=num_rows).read_all() + + @property + def projected_schema(self): + return self.head(1).schema + + def scan_batches(self): + return self.to_reader() + + def take(self, indices: List[int]): + raise NotImplementedError + + def to_batches(self): + return self.to_reader() + + def to_table(self): + return self.to_reader().read_all() + + def to_reader(self, *, limit: Optional[int] = None): + query = self.table.search() + # Disable the builtin limit + if limit is None: + num_rows = self.count_rows() + query.limit(num_rows) + elif limit <= 0: + raise ValueError("limit must be positive") + else: + query.limit(limit) + if self.columns is not None: + query = query.select(self.columns) + if self.filter is not None: + query = query.where(self.filter, prefilter=True) + return query.to_batches(batch_size=self.batch_size) + + +class PyarrowDatasetAdapter(pa.dataset.Dataset): + def __init__(self, table: Table): + self.table = table + + def count_rows(self, filter: Optional[Filter] = None): + return self.table.count_rows(filter) + + def get_fragments(self, filter: Optional[Filter] = None): + raise NotImplementedError + + def head( + self, + num_rows: int, + columns: Optional[List[str]] = None, + filter: Optional[Filter] = None, + batch_size: Optional[int] = None, + batch_readahead: Optional[int] = None, + fragment_readahead: Optional[int] = None, + fragment_scan_options: Optional[Any] = None, + use_threads: bool = True, + memory_pool: Optional[Any] = None, + ): + return self.scanner( + columns, + filter, + batch_size, + batch_readahead, + fragment_readahead, + fragment_scan_options, + use_threads, + memory_pool, + ).head(num_rows) + + def join( + self, + right_dataset: Any, + keys: Keys, + right_keys: Optional[Keys] = None, + join_type: Optional[JoinType] = None, + left_suffix: Optional[str] = None, + right_suffix: Optional[str] = None, + coalesce_keys: bool = True, + use_threads: bool = True, + ): + raise NotImplementedError + + def join_asof( + self, + right_dataset: Any, + on: str, + by: Keys, + tolerance: int, + right_on: Optional[str] = None, + right_by: Optional[Keys] = None, + ): + raise NotImplementedError + + @property + def partition_expression(self): + raise NotImplementedError + + def replace_schema(self, schema: pa.Schema): + raise NotImplementedError + + def scanner( + self, + columns: Optional[List[str]] = None, + filter: Optional[Filter] = None, + batch_size: Optional[int] = None, + batch_readahead: Optional[int] = None, + fragment_readahead: Optional[int] = None, + fragment_scan_options: Optional[Any] = None, + use_threads: bool = True, + memory_pool: Optional[Any] = None, + ): + return PyarrowScannerAdapter( + self.table, + columns, + filter, + batch_size, + batch_readahead, + fragment_readahead, + fragment_scan_options, + use_threads, + memory_pool, + ) + + @property + def schema(self): + return self.table.schema + + def sort_by(self, sorting: Union[str, List[Tuple[str, bool]]]): + raise NotImplementedError + + def take( + self, + indices: List[int], + columns: Optional[List[str]] = None, + filter: Optional[Filter] = None, + batch_size: Optional[int] = None, + batch_readahead: Optional[int] = None, + fragment_readahead: Optional[int] = None, + fragment_scan_options: Optional[Any] = None, + use_threads: bool = True, + memory_pool: Optional[Any] = None, + ): + raise NotImplementedError + + def to_batches( + self, + columns: Optional[List[str]] = None, + filter: Optional[Filter] = None, + batch_size: Optional[int] = None, + batch_readahead: Optional[int] = None, + fragment_readahead: Optional[int] = None, + fragment_scan_options: Optional[Any] = None, + use_threads: bool = True, + memory_pool: Optional[Any] = None, + ): + return self.scanner( + columns, + filter, + batch_size, + batch_readahead, + fragment_readahead, + fragment_scan_options, + use_threads, + memory_pool, + ).to_batches() + + def to_table( + self, + columns: Optional[List[str]] = None, + filter: Optional[Filter] = None, + batch_size: Optional[int] = None, + batch_readahead: Optional[int] = None, + fragment_readahead: Optional[int] = None, + fragment_scan_options: Optional[Any] = None, + use_threads: bool = True, + memory_pool: Optional[Any] = None, + ): + return self.scanner( + columns, + filter, + batch_size, + batch_readahead, + fragment_readahead, + fragment_scan_options, + use_threads, + memory_pool, + ).to_table() diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index dbd0295c..984be7b2 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -325,6 +325,14 @@ class LanceQueryBuilder(ABC): """ raise NotImplementedError + @abstractmethod + def to_batches(self, /, batch_size: Optional[int] = None) -> pa.Table: + """ + Execute the query and return the results as a pyarrow + [RecordBatchReader](https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatchReader.html) + """ + raise NotImplementedError + def to_list(self) -> List[dict]: """ Execute the query and return the results as a list of dictionaries. @@ -869,6 +877,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): check_reranker_result(results) return results + def to_batches(self, /, batch_size: Optional[int] = None): + raise NotImplementedError("to_batches on an FTS query") + def tantivy_to_arrow(self) -> pa.Table: try: import tantivy @@ -971,6 +982,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder): class LanceEmptyQueryBuilder(LanceQueryBuilder): def to_arrow(self) -> pa.Table: + return self.to_batches().read_all() + + def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader: query = Query( columns=self._columns, filter=self._where, @@ -980,7 +994,7 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder): # not actually respected in remote query offset=self._offset or 0, ) - return self._table._execute_query(query).read_all() + return self._table._execute_query(query) def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder: """Rerank the results using the specified reranker. @@ -1135,6 +1149,9 @@ class LanceHybridQueryBuilder(LanceQueryBuilder): results = results.drop(["_rowid"]) return results + def to_batches(self): + raise NotImplementedError("to_batches not yet supported on a hybrid query") + def _rank(self, results: pa.Table, column: str, ascending: bool = True): if len(results) == 0: return results diff --git a/python/python/tests/test_duckdb.py b/python/python/tests/test_duckdb.py new file mode 100644 index 00000000..e122ba94 --- /dev/null +++ b/python/python/tests/test_duckdb.py @@ -0,0 +1,21 @@ +import duckdb +import pyarrow as pa + +import lancedb +from lancedb.integrations.pyarrow import PyarrowDatasetAdapter + + +def test_basic_query(tmp_path): + data = pa.table({"x": [1, 2, 3, 4], "y": [5, 6, 7, 8]}) + conn = lancedb.connect(tmp_path) + tbl = conn.create_table("test", data) + + adapter = PyarrowDatasetAdapter(tbl) # noqa: F841 + + duck_conn = duckdb.connect() + + results = duck_conn.sql("SELECT SUM(x) FROM adapter").fetchall() + assert results[0][0] == 10 + + results = duck_conn.sql("SELECT SUM(y) FROM adapter").fetchall() + assert results[0][0] == 26 diff --git a/python/python/tests/test_pyarrow.py b/python/python/tests/test_pyarrow.py new file mode 100644 index 00000000..d4a1e07c --- /dev/null +++ b/python/python/tests/test_pyarrow.py @@ -0,0 +1,47 @@ +import pyarrow as pa + +import lancedb +from lancedb.integrations.pyarrow import PyarrowDatasetAdapter + + +def test_dataset_adapter(tmp_path): + data = pa.table({"x": [1, 2, 3, 4], "y": [5, 6, 7, 8]}) + conn = lancedb.connect(tmp_path) + tbl = conn.create_table("test", data) + + adapter = PyarrowDatasetAdapter(tbl) + + assert adapter.count_rows() == 4 + assert adapter.count_rows("x > 2") == 2 + assert adapter.schema == data.schema + assert adapter.head(2) == data.slice(0, 2) + assert adapter.to_table() == data + assert adapter.to_batches().read_all() == data + assert adapter.scanner().to_table() == data + assert adapter.scanner().to_batches().read_all() == data + + assert adapter.scanner().projected_schema == data.schema + assert adapter.scanner(columns=["x"]).projected_schema == pa.schema( + [data.schema.field("x")] + ) + assert adapter.scanner(columns=["x"]).to_table() == pa.table({"x": [1, 2, 3, 4]}) + + # Make sure we bypass the limit + data = pa.table({"x": range(100)}) + tbl = conn.create_table("test2", data) + + adapter = PyarrowDatasetAdapter(tbl) + + assert adapter.count_rows() == 100 + assert adapter.to_table().num_rows == 100 + assert adapter.head(10).num_rows == 10 + + # Empty table + tbl = conn.create_table("test3", None, schema=pa.schema({"x": pa.int64()})) + adapter = PyarrowDatasetAdapter(tbl) + + assert adapter.count_rows() == 0 + assert adapter.to_table().num_rows == 0 + assert adapter.head(10).num_rows == 0 + + assert adapter.scanner().projected_schema == pa.schema({"x": pa.int64()})