From bcfc93cc889b1f3285f0aadb4921bf0625abeb43 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 20 Jan 2025 16:14:34 -0800 Subject: [PATCH] fix(python): various fixes for async query builders (#2048) This includes several improvements and fixes to the Python Async query builders: 1. The API reference docs show all the methods for each builder 2. The hybrid query builder now has all the same setter methods as the vector search one, so you can now set things like `.distance_type()` on a hybrid query. 3. Re-rankers are now properly hooked up and tested for FTS and vector search. Previously the re-rankers were accidentally bypassed in unit tests, because the builders overrode `.to_arrow()`, but the unit test called `.to_batches()` which was only defined in the base class. Now all builders implement `.to_batches()` and leave `.to_arrow()` to the base class. 4. The `AsyncQueryBase` and `AsyncVectoryQueryBase` setter methods now return `Self`, which provides the appropriate subclass as the type hint return value. Previously, `AsyncQueryBase` had them all hard-coded to `AsyncQuery`, which was unfortunate. (This required bringing in `typing-extensions` for older Python version, but I think it's worth it.) --- docs/src/python/python.md | 13 ++- python/pyproject.toml | 2 +- python/python/lancedb/_lancedb.pyi | 8 +- python/python/lancedb/arrow.py | 47 +++++++--- python/python/lancedb/query.py | 111 ++++++++++++++--------- python/python/tests/test_hybrid_query.py | 1 + python/python/tests/test_query.py | 14 ++- python/src/arrow.rs | 26 ++++-- python/src/lib.rs | 4 +- 9 files changed, 153 insertions(+), 73 deletions(-) diff --git a/docs/src/python/python.md b/docs/src/python/python.md index f59362cb..00dcd84b 100644 --- a/docs/src/python/python.md +++ b/docs/src/python/python.md @@ -147,8 +147,19 @@ to return the entire (typically filtered) table. Vector searches return the rows nearest to a query vector and can be created with the [AsyncTable.vector_search][lancedb.table.AsyncTable.vector_search] method. -::: lancedb.query.AsyncQueryBase ::: lancedb.query.AsyncQuery + options: + inherited_members: true ::: lancedb.query.AsyncVectorQuery + options: + inherited_members: true + +::: lancedb.query.AsyncFTSQuery + options: + inherited_members: true + +::: lancedb.query.AsyncHybridQuery + options: + inherited_members: true diff --git a/python/pyproject.toml b/python/pyproject.toml index 8dcc69e5..71eb1b94 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -55,7 +55,7 @@ tests = [ "tantivy", "pyarrow-stubs", ] -dev = ["ruff", "pre-commit", "pyright"] +dev = ["ruff", "pre-commit", "pyright", 'typing-extensions>=4.0.0; python_version < "3.11"'] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] clip = ["torch", "pillow", "open-clip"] embeddings = [ diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index b431e3be..fb8c0ac9 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -84,11 +84,15 @@ class RecordBatchStream: class Query: def where(self, filter: str): ... def select(self, columns: Tuple[str, str]): ... + def select_columns(self, columns: List[str]): ... def limit(self, limit: int): ... def offset(self, offset: int): ... + def fast_search(self): ... + def with_row_id(self): ... + def postfilter(self): ... def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ... def nearest_to_text(self, query: dict) -> FTSQuery: ... - async def execute(self, max_batch_legnth: Optional[int]) -> RecordBatchStream: ... + async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ... class FTSQuery: def where(self, filter: str): ... @@ -98,6 +102,8 @@ class FTSQuery: def fast_search(self): ... def with_row_id(self): ... def postfilter(self): ... + def get_query(self) -> str: ... + def add_query_vector(self, query_vec: pa.Array) -> None: ... def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ... async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ... async def explain_plan(self) -> str: ... diff --git a/python/python/lancedb/arrow.py b/python/python/lancedb/arrow.py index 06393e66..602d7df7 100644 --- a/python/python/lancedb/arrow.py +++ b/python/python/lancedb/arrow.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional, Union import pyarrow as pa @@ -12,17 +12,27 @@ class AsyncRecordBatchReader: Also allows access to the schema of the stream """ - def __init__(self, inner: RecordBatchStream): - self.inner_ = inner - - @property - def schema(self) -> pa.Schema: + def __init__( + self, + inner: Union[RecordBatchStream, pa.Table], + max_batch_length: Optional[int] = None, + ): """ - Get the schema of the batches produced by the stream - Accessing the schema does not consume any data from the stream + Attributes + ---------- + schema : pa.Schema + The schema of the batches produced by the stream. + Accessing the schema does not consume any data from the stream """ - return self.inner_.schema() + if isinstance(inner, pa.Table): + self._inner = self._async_iter_from_table(inner, max_batch_length) + self.schema: pa.Schema = inner.schema + elif isinstance(inner, RecordBatchStream): + self._inner = inner + self.schema: pa.Schema = inner.schema + else: + raise TypeError("inner must be a RecordBatchStream or a Table") async def read_all(self) -> List[pa.RecordBatch]: """ @@ -38,7 +48,18 @@ class AsyncRecordBatchReader: return self async def __anext__(self) -> pa.RecordBatch: - next = await self.inner_.next() - if next is None: - raise StopAsyncIteration - return next + return await self._inner.__anext__() + + @staticmethod + async def _async_iter_from_table( + table: pa.Table, max_batch_length: Optional[int] = None + ): + """ + Create an AsyncRecordBatchReader from a Table + + This is useful when you have a Table that you want to iterate + over asynchronously + """ + batches = table.to_batches(max_chunksize=max_batch_length) + for batch in batches: + yield batch diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 76f41453..98a0bbf1 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -31,6 +31,7 @@ from .rerankers.util import check_reranker_result from .util import safe_import_pandas, flatten_columns if TYPE_CHECKING: + import sys import PIL import polars as pl @@ -42,6 +43,11 @@ if TYPE_CHECKING: from .pydantic import LanceModel from .table import Table + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + pd = safe_import_pandas() @@ -1418,7 +1424,7 @@ class AsyncQueryBase(object): """ self._inner = inner - def where(self, predicate: str) -> AsyncQuery: + def where(self, predicate: str) -> Self: """ Only return rows matching the given predicate @@ -1437,7 +1443,7 @@ class AsyncQueryBase(object): self._inner.where(predicate) return self - def select(self, columns: Union[List[str], dict[str, str]]) -> AsyncQuery: + def select(self, columns: Union[List[str], dict[str, str]]) -> Self: """ Return only the specified columns. @@ -1475,7 +1481,7 @@ class AsyncQueryBase(object): raise TypeError("columns must be a list of column names or a dict") return self - def limit(self, limit: int) -> AsyncQuery: + def limit(self, limit: int) -> Self: """ Set the maximum number of results to return. @@ -1485,7 +1491,7 @@ class AsyncQueryBase(object): self._inner.limit(limit) return self - def offset(self, offset: int) -> AsyncQuery: + def offset(self, offset: int) -> Self: """ Set the offset for the results. @@ -1497,7 +1503,7 @@ class AsyncQueryBase(object): self._inner.offset(offset) return self - def fast_search(self) -> AsyncQuery: + def fast_search(self) -> Self: """ Skip searching un-indexed data. @@ -1511,14 +1517,14 @@ class AsyncQueryBase(object): self._inner.fast_search() return self - def with_row_id(self) -> AsyncQuery: + def with_row_id(self) -> Self: """ Include the _rowid column in the results. """ self._inner.with_row_id() return self - def postfilter(self) -> AsyncQuery: + def postfilter(self) -> Self: """ If this is called then filtering will happen after the search instead of before. @@ -1807,8 +1813,8 @@ class AsyncFTSQuery(AsyncQueryBase): self._inner = inner self._reranker = None - def get_query(self): - self._inner.get_query() + def get_query(self) -> str: + return self._inner.get_query() def rerank( self, @@ -1891,29 +1897,18 @@ class AsyncFTSQuery(AsyncQueryBase): self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector)) ) - async def to_arrow(self) -> pa.Table: - results = await super().to_arrow() + async def to_batches( + self, *, max_batch_length: Optional[int] = None + ) -> AsyncRecordBatchReader: + reader = await super().to_batches() + results = pa.Table.from_batches(await reader.read_all(), reader.schema) if self._reranker: - results = self._reranker.rerank_fts(results) - return results + results = self._reranker.rerank_fts(self.get_query(), results) + return AsyncRecordBatchReader(results, max_batch_length=max_batch_length) -class AsyncVectorQuery(AsyncQueryBase): - def __init__(self, inner: LanceVectorQuery): - """ - Construct an AsyncVectorQuery - - This method is not intended to be called directly. Instead, create - a query first with [AsyncTable.query][lancedb.table.AsyncTable.query] and then - use [AsyncQuery.nearest_to][lancedb.query.AsyncQuery.nearest_to]] to convert to - a vector query. Or you can use - [AsyncTable.vector_search][lancedb.table.AsyncTable.vector_search] - """ - super().__init__(inner) - self._inner = inner - self._reranker = None - - def column(self, column: str) -> AsyncVectorQuery: +class AsyncVectorQueryBase: + def column(self, column: str) -> Self: """ Set the vector column to query @@ -1926,7 +1921,7 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.column(column) return self - def nprobes(self, nprobes: int) -> AsyncVectorQuery: + def nprobes(self, nprobes: int) -> Self: """ Set the number of partitions to search (probe) @@ -1954,7 +1949,7 @@ class AsyncVectorQuery(AsyncQueryBase): def distance_range( self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None - ) -> AsyncVectorQuery: + ) -> Self: """Set the distance range to use. Only rows with distances within range [lower_bound, upper_bound) @@ -1975,7 +1970,7 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.distance_range(lower_bound, upper_bound) return self - def ef(self, ef: int) -> AsyncVectorQuery: + def ef(self, ef: int) -> Self: """ Set the number of candidates to consider during search @@ -1990,7 +1985,7 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.ef(ef) return self - def refine_factor(self, refine_factor: int) -> AsyncVectorQuery: + def refine_factor(self, refine_factor: int) -> Self: """ A multiplier to control how many additional rows are taken during the refine step @@ -2026,7 +2021,7 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.refine_factor(refine_factor) return self - def distance_type(self, distance_type: str) -> AsyncVectorQuery: + def distance_type(self, distance_type: str) -> Self: """ Set the distance metric to use @@ -2044,7 +2039,7 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.distance_type(distance_type) return self - def bypass_vector_index(self) -> AsyncVectorQuery: + def bypass_vector_index(self) -> Self: """ If this is called then any vector index is skipped @@ -2057,6 +2052,23 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.bypass_vector_index() return self + +class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase): + def __init__(self, inner: LanceVectorQuery): + """ + Construct an AsyncVectorQuery + + This method is not intended to be called directly. Instead, create + a query first with [AsyncTable.query][lancedb.table.AsyncTable.query] and then + use [AsyncQuery.nearest_to][lancedb.query.AsyncQuery.nearest_to]] to convert to + a vector query. Or you can use + [AsyncTable.vector_search][lancedb.table.AsyncTable.vector_search] + """ + super().__init__(inner) + self._inner = inner + self._reranker = None + self._query_string = None + def rerank( self, reranker: Reranker = RRFReranker(), query_string: Optional[str] = None ) -> AsyncHybridQuery: @@ -2065,6 +2077,11 @@ class AsyncVectorQuery(AsyncQueryBase): self._reranker = reranker + if not self._query_string and not query_string: + raise ValueError("query_string must be provided to rerank the results.") + + self._query_string = query_string + return self def nearest_to_text( @@ -2100,14 +2117,17 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.nearest_to_text({"query": query, "columns": columns}) ) - async def to_arrow(self) -> pa.Table: - results = await super().to_arrow() + async def to_batches( + self, *, max_batch_length: Optional[int] = None + ) -> AsyncRecordBatchReader: + reader = await super().to_batches() + results = pa.Table.from_batches(await reader.read_all(), reader.schema) if self._reranker: - results = self._reranker.rerank_vector(results) - return results + results = self._reranker.rerank_vector(self._query_string, results) + return AsyncRecordBatchReader(results, max_batch_length=max_batch_length) -class AsyncHybridQuery(AsyncQueryBase): +class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase): """ A query builder that performs hybrid vector and full text search. Results are combined and reranked based on the specified reranker. @@ -2155,10 +2175,9 @@ class AsyncHybridQuery(AsyncQueryBase): return self - async def to_batches(self): - raise NotImplementedError("to_batches not yet supported on a hybrid query") - - async def to_arrow(self) -> pa.Table: + async def to_batches( + self, *, max_batch_length: Optional[int] = None + ) -> AsyncRecordBatchReader: fts_query = AsyncFTSQuery(self._inner.to_fts_query()) vec_query = AsyncVectorQuery(self._inner.to_vector_query()) @@ -2173,7 +2192,7 @@ class AsyncHybridQuery(AsyncQueryBase): vec_query.to_arrow(), ) - return LanceHybridQueryBuilder._combine_hybrid_results( + result = LanceHybridQueryBuilder._combine_hybrid_results( fts_results=fts_results, vector_results=vector_results, norm=self._norm, @@ -2183,6 +2202,8 @@ class AsyncHybridQuery(AsyncQueryBase): with_row_ids=with_row_ids, ) + return AsyncRecordBatchReader(result, max_batch_length=max_batch_length) + async def explain_plan(self, verbose: Optional[bool] = False): """Return the execution plan for this query. diff --git a/python/python/tests/test_hybrid_query.py b/python/python/tests/test_hybrid_query.py index 4014e2a8..20f8ea2e 100644 --- a/python/python/tests/test_hybrid_query.py +++ b/python/python/tests/test_hybrid_query.py @@ -67,6 +67,7 @@ async def test_async_hybrid_query_filters(table: AsyncTable): .where("text not in ('a', 'dog')") .nearest_to([0.3, 0.3]) .nearest_to_text("*a*") + .distance_type("l2") .limit(2) .to_arrow() ) diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 82cdca6a..3a2bac9a 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -7,6 +7,7 @@ from pathlib import Path import lancedb from lancedb.index import IvfPq, FTS +from lancedb.rerankers.cross_encoder import CrossEncoderReranker import numpy as np import pandas.testing as tm import pyarrow as pa @@ -515,15 +516,24 @@ async def test_query_async(table_async: AsyncTable): expected_columns=["id", "vector", "_rowid"], ) + +@pytest.mark.asyncio +@pytest.mark.slow +async def test_query_reranked_async(table_async: AsyncTable): # FTS with rerank await table_async.create_index("text", config=FTS(with_position=False)) await check_query( - table_async.query().nearest_to_text("dog").rerank(), + table_async.query().nearest_to_text("dog").rerank(CrossEncoderReranker()), expected_num_rows=1, ) # Vector query with rerank - await check_query(table_async.vector_search([1, 2]).rerank(), expected_num_rows=2) + await check_query( + table_async.vector_search([1, 2]).rerank( + CrossEncoderReranker(), query_string="dog" + ), + expected_num_rows=2, + ) @pytest.mark.asyncio diff --git a/python/src/arrow.rs b/python/src/arrow.rs index c5c53b54..cb797345 100644 --- a/python/src/arrow.rs +++ b/python/src/arrow.rs @@ -9,7 +9,10 @@ use arrow::{ }; use futures::stream::StreamExt; use lancedb::arrow::SendableRecordBatchStream; -use pyo3::{pyclass, pymethods, Bound, PyAny, PyObject, PyRef, PyResult, Python}; +use pyo3::{ + exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, PyAny, PyObject, PyRef, PyResult, + Python, +}; use pyo3_async_runtimes::tokio::future_into_py; use crate::error::PythonErrorExt; @@ -32,20 +35,25 @@ impl RecordBatchStream { #[pymethods] impl RecordBatchStream { + #[getter] pub fn schema(&self, py: Python) -> PyResult { (*self.schema).clone().into_pyarrow(py) } - pub fn next(self_: PyRef<'_, Self>) -> PyResult> { + pub fn __aiter__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> { + self_ + } + + pub fn __anext__(self_: PyRef<'_, Self>) -> PyResult> { let inner = self_.inner.clone(); future_into_py(self_.py(), async move { - let inner_next = inner.lock().await.next().await; - inner_next - .map(|item| { - let item = item.infer_error()?; - Python::with_gil(|py| item.to_pyarrow(py)) - }) - .transpose() + let inner_next = inner + .lock() + .await + .next() + .await + .ok_or_else(|| PyStopAsyncIteration::new_err(""))?; + Python::with_gil(|py| inner_next.infer_error()?.to_pyarrow(py)) }) } } diff --git a/python/src/lib.rs b/python/src/lib.rs index a68e7711..eff066d0 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -21,7 +21,7 @@ use pyo3::{ types::{PyModule, PyModuleMethods}, wrap_pyfunction, Bound, PyResult, Python, }; -use query::{Query, VectorQuery}; +use query::{FTSQuery, HybridQuery, Query, VectorQuery}; use table::Table; pub mod arrow; @@ -42,6 +42,8 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(connect, m)?)?;