diff --git a/python/python/lancedb/integrations/__init__.py b/python/python/lancedb/integrations/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/python/lancedb/integrations/pyarrow.py b/python/python/lancedb/integrations/pyarrow.py
new file mode 100644
index 00000000..2225ef94
--- /dev/null
+++ b/python/python/lancedb/integrations/pyarrow.py
@@ -0,0 +1,248 @@
+import logging
+from typing import Any, List, Optional, Tuple, Union, Literal
+
+import pyarrow as pa
+
+from ..table import Table
+
+Filter = Union[str, pa.compute.Expression]
+Keys = Union[str, List[str]]
+JoinType = Literal[
+    "left semi",
+    "right semi",
+    "left anti",
+    "right anti",
+    "inner",
+    "left outer",
+    "right outer",
+    "full outer",
+]
+
+
+class PyarrowScannerAdapter(pa.dataset.Scanner):
+    def __init__(
+        self,
+        table: Table,
+        columns: Optional[List[str]] = None,
+        filter: Optional[Filter] = None,
+        batch_size: Optional[int] = None,
+        batch_readahead: Optional[int] = None,
+        fragment_readahead: Optional[int] = None,
+        fragment_scan_options: Optional[Any] = None,
+        use_threads: bool = True,
+        memory_pool: Optional[Any] = None,
+    ):
+        self.table = table
+        self.columns = columns
+        self.filter = filter
+        self.batch_size = batch_size
+        if batch_readahead is not None:
+            logging.debug("ignoring batch_readahead which has no lance equivalent")
+        if fragment_readahead is not None:
+            logging.debug("ignoring fragment_readahead which has no lance equivalent")
+        if fragment_scan_options is not None:
+            raise NotImplementedError("fragment_scan_options not supported")
+        if use_threads is False:
+            raise NotImplementedError("use_threads=False not supported")
+        if memory_pool is not None:
+            raise NotImplementedError("memory_pool not supported")
+
+    def count_rows(self):
+        return self.table.count_rows(self.filter)
+
+    def from_batches(self, **kwargs):
+        raise NotImplementedError
+
+    def from_dataset(self, **kwargs):
+        raise NotImplementedError
+
+    def from_fragment(self, **kwargs):
+        raise NotImplementedError
+
+    def head(self, num_rows: int):
+        return self.to_reader(limit=num_rows).read_all()
+
+    @property
+    def projected_schema(self):
+        return self.head(1).schema
+
+    def scan_batches(self):
+        return self.to_reader()
+
+    def take(self, indices: List[int]):
+        raise NotImplementedError
+
+    def to_batches(self):
+        return self.to_reader()
+
+    def to_table(self):
+        return self.to_reader().read_all()
+
+    def to_reader(self, *, limit: Optional[int] = None):
+        query = self.table.search()
+        # Disable the builtin limit
+        if limit is None:
+            num_rows = self.count_rows()
+            query.limit(num_rows)
+        elif limit <= 0:
+            raise ValueError("limit must be positive")
+        else:
+            query.limit(limit)
+        if self.columns is not None:
+            query = query.select(self.columns)
+        if self.filter is not None:
+            query = query.where(self.filter, prefilter=True)
+        return query.to_batches(batch_size=self.batch_size)
+
+
+class PyarrowDatasetAdapter(pa.dataset.Dataset):
+    def __init__(self, table: Table):
+        self.table = table
+
+    def count_rows(self, filter: Optional[Filter] = None):
+        return self.table.count_rows(filter)
+
+    def get_fragments(self, filter: Optional[Filter] = None):
+        raise NotImplementedError
+
+    def head(
+        self,
+        num_rows: int,
+        columns: Optional[List[str]] = None,
+        filter: Optional[Filter] = None,
+        batch_size: Optional[int] = None,
+        batch_readahead: Optional[int] = None,
+        fragment_readahead: Optional[int] = None,
+        fragment_scan_options: Optional[Any] = None,
+        use_threads: bool = True,
+        memory_pool: Optional[Any] = None,
+    ):
+        return self.scanner(
+            columns,
+            filter,
+            batch_size,
+            batch_readahead,
+            fragment_readahead,
+            fragment_scan_options,
+            use_threads,
+            memory_pool,
+        ).head(num_rows)
+
+    def join(
+        self,
+        right_dataset: Any,
+        keys: Keys,
+        right_keys: Optional[Keys] = None,
+        join_type: Optional[JoinType] = None,
+        left_suffix: Optional[str] = None,
+        right_suffix: Optional[str] = None,
+        coalesce_keys: bool = True,
+        use_threads: bool = True,
+    ):
+        raise NotImplementedError
+
+    def join_asof(
+        self,
+        right_dataset: Any,
+        on: str,
+        by: Keys,
+        tolerance: int,
+        right_on: Optional[str] = None,
+        right_by: Optional[Keys] = None,
+    ):
+        raise NotImplementedError
+
+    @property
+    def partition_expression(self):
+        raise NotImplementedError
+
+    def replace_schema(self, schema: pa.Schema):
+        raise NotImplementedError
+
+    def scanner(
+        self,
+        columns: Optional[List[str]] = None,
+        filter: Optional[Filter] = None,
+        batch_size: Optional[int] = None,
+        batch_readahead: Optional[int] = None,
+        fragment_readahead: Optional[int] = None,
+        fragment_scan_options: Optional[Any] = None,
+        use_threads: bool = True,
+        memory_pool: Optional[Any] = None,
+    ):
+        return PyarrowScannerAdapter(
+            self.table,
+            columns,
+            filter,
+            batch_size,
+            batch_readahead,
+            fragment_readahead,
+            fragment_scan_options,
+            use_threads,
+            memory_pool,
+        )
+
+    @property
+    def schema(self):
+        return self.table.schema
+
+    def sort_by(self, sorting: Union[str, List[Tuple[str, bool]]]):
+        raise NotImplementedError
+
+    def take(
+        self,
+        indices: List[int],
+        columns: Optional[List[str]] = None,
+        filter: Optional[Filter] = None,
+        batch_size: Optional[int] = None,
+        batch_readahead: Optional[int] = None,
+        fragment_readahead: Optional[int] = None,
+        fragment_scan_options: Optional[Any] = None,
+        use_threads: bool = True,
+        memory_pool: Optional[Any] = None,
+    ):
+        raise NotImplementedError
+
+    def to_batches(
+        self,
+        columns: Optional[List[str]] = None,
+        filter: Optional[Filter] = None,
+        batch_size: Optional[int] = None,
+        batch_readahead: Optional[int] = None,
+        fragment_readahead: Optional[int] = None,
+        fragment_scan_options: Optional[Any] = None,
+        use_threads: bool = True,
+        memory_pool: Optional[Any] = None,
+    ):
+        return self.scanner(
+            columns,
+            filter,
+            batch_size,
+            batch_readahead,
+            fragment_readahead,
+            fragment_scan_options,
+            use_threads,
+            memory_pool,
+        ).to_batches()
+
+    def to_table(
+        self,
+        columns: Optional[List[str]] = None,
+        filter: Optional[Filter] = None,
+        batch_size: Optional[int] = None,
+        batch_readahead: Optional[int] = None,
+        fragment_readahead: Optional[int] = None,
+        fragment_scan_options: Optional[Any] = None,
+        use_threads: bool = True,
+        memory_pool: Optional[Any] = None,
+    ):
+        return self.scanner(
+            columns,
+            filter,
+            batch_size,
+            batch_readahead,
+            fragment_readahead,
+            fragment_scan_options,
+            use_threads,
+            memory_pool,
+        ).to_table()
diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py
index dbd0295c..984be7b2 100644
--- a/python/python/lancedb/query.py
+++ b/python/python/lancedb/query.py
@@ -325,6 +325,14 @@ class LanceQueryBuilder(ABC):
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
+        """
+        Execute the query and return the results as a pyarrow
+        [RecordBatchReader](https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatchReader.html)
+        """
+        raise NotImplementedError
+
     def to_list(self) -> List[dict]:
         """
         Execute the query and return the results as a list of dictionaries.
@@ -869,6 +877,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
         check_reranker_result(results)
         return results
 
+    def to_batches(self, /, batch_size: Optional[int] = None):
+        raise NotImplementedError("to_batches is not yet supported on an FTS query")
+
     def tantivy_to_arrow(self) -> pa.Table:
         try:
             import tantivy
@@ -971,6 +982,9 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
 
 class LanceEmptyQueryBuilder(LanceQueryBuilder):
     def to_arrow(self) -> pa.Table:
+        return self.to_batches().read_all()
+
+    def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
         query = Query(
             columns=self._columns,
             filter=self._where,
@@ -980,7 +994,7 @@ class LanceEmptyQueryBuilder(LanceQueryBuilder):
             # not actually respected in remote query
             offset=self._offset or 0,
         )
-        return self._table._execute_query(query).read_all()
+        return self._table._execute_query(query)
 
     def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
         """Rerank the results using the specified reranker.
@@ -1135,6 +1149,9 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
             results = results.drop(["_rowid"])
         return results
 
+    def to_batches(self, /, batch_size: Optional[int] = None):
+        raise NotImplementedError("to_batches not yet supported on a hybrid query")
+
     def _rank(self, results: pa.Table, column: str, ascending: bool = True):
         if len(results) == 0:
             return results
diff --git a/python/python/tests/test_duckdb.py b/python/python/tests/test_duckdb.py
new file mode 100644
index 00000000..e122ba94
--- /dev/null
+++ b/python/python/tests/test_duckdb.py
@@ -0,0 +1,21 @@
+import duckdb
+import pyarrow as pa
+
+import lancedb
+from lancedb.integrations.pyarrow import PyarrowDatasetAdapter
+
+
+def test_basic_query(tmp_path):
+    data = pa.table({"x": [1, 2, 3, 4], "y": [5, 6, 7, 8]})
+    conn = lancedb.connect(tmp_path)
+    tbl = conn.create_table("test", data)
+
+    adapter = PyarrowDatasetAdapter(tbl)  # noqa: F841
+
+    duck_conn = duckdb.connect()
+
+    results = duck_conn.sql("SELECT SUM(x) FROM adapter").fetchall()
+    assert results[0][0] == 10
+
+    results = duck_conn.sql("SELECT SUM(y) FROM adapter").fetchall()
+    assert results[0][0] == 26
diff --git a/python/python/tests/test_pyarrow.py b/python/python/tests/test_pyarrow.py
new file mode 100644
index 00000000..d4a1e07c
--- /dev/null
+++ b/python/python/tests/test_pyarrow.py
@@ -0,0 +1,47 @@
+import pyarrow as pa
+
+import lancedb
+from lancedb.integrations.pyarrow import PyarrowDatasetAdapter
+
+
+def test_dataset_adapter(tmp_path):
+    data = pa.table({"x": [1, 2, 3, 4], "y": [5, 6, 7, 8]})
+    conn = lancedb.connect(tmp_path)
+    tbl = conn.create_table("test", data)
+
+    adapter = PyarrowDatasetAdapter(tbl)
+
+    assert adapter.count_rows() == 4
+    assert adapter.count_rows("x > 2") == 2
+    assert adapter.schema == data.schema
+    assert adapter.head(2) == data.slice(0, 2)
+    assert adapter.to_table() == data
+    assert adapter.to_batches().read_all() == data
+    assert adapter.scanner().to_table() == data
+    assert adapter.scanner().to_batches().read_all() == data
+
+    assert adapter.scanner().projected_schema == data.schema
+    assert adapter.scanner(columns=["x"]).projected_schema == pa.schema(
+        [data.schema.field("x")]
+    )
+    assert adapter.scanner(columns=["x"]).to_table() == pa.table({"x": [1, 2, 3, 4]})
+
+    # Make sure we bypass the limit
+    data = pa.table({"x": range(100)})
+    tbl = conn.create_table("test2", data)
+
+    adapter = PyarrowDatasetAdapter(tbl)
+
+    assert adapter.count_rows() == 100
+    assert adapter.to_table().num_rows == 100
+    assert adapter.head(10).num_rows == 10
+
+    # Empty table
+    tbl = conn.create_table("test3", None, schema=pa.schema({"x": pa.int64()}))
+    adapter = PyarrowDatasetAdapter(tbl)
+
+    assert adapter.count_rows() == 0
+    assert adapter.to_table().num_rows == 0
+    assert adapter.head(10).num_rows == 0
+
+    assert adapter.scanner().projected_schema == pa.schema({"x": pa.int64()})
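Usage sketch: a minimal example of how the new PyarrowDatasetAdapter could be driven directly and through DuckDB's replacement scan over local Python variables, mirroring the tests above. The database path and the "demo" table name are assumptions for illustration, not part of this change.

    import duckdb
    import pyarrow as pa

    import lancedb
    from lancedb.integrations.pyarrow import PyarrowDatasetAdapter

    # Hypothetical scratch location and table name.
    conn = lancedb.connect("/tmp/lancedb_demo")
    tbl = conn.create_table("demo", pa.table({"x": [1, 2, 3, 4], "y": [5, 6, 7, 8]}))

    dataset = PyarrowDatasetAdapter(tbl)

    # Column projection mirrors the behavior exercised in test_pyarrow.py.
    assert dataset.to_table(columns=["x"]) == pa.table({"x": [1, 2, 3, 4]})

    # Filters are forwarded as SQL-style predicates to the underlying Lance scan.
    print(dataset.to_table(filter="y > 6"))

    # DuckDB picks up `dataset` via its replacement scan, as in test_duckdb.py.
    print(duckdb.connect().sql("SELECT SUM(x) FROM dataset").fetchall())  # [(10,)]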