diff --git a/docs/src/python/python.md b/docs/src/python/python.md index f59362cb..00dcd84b 100644 --- a/docs/src/python/python.md +++ b/docs/src/python/python.md @@ -147,8 +147,19 @@ to return the entire (typically filtered) table. Vector searches return the rows nearest to a query vector and can be created with the [AsyncTable.vector_search][lancedb.table.AsyncTable.vector_search] method. -::: lancedb.query.AsyncQueryBase ::: lancedb.query.AsyncQuery + options: + inherited_members: true ::: lancedb.query.AsyncVectorQuery + options: + inherited_members: true + +::: lancedb.query.AsyncFTSQuery + options: + inherited_members: true + +::: lancedb.query.AsyncHybridQuery + options: + inherited_members: true diff --git a/python/pyproject.toml b/python/pyproject.toml index 8dcc69e5..71eb1b94 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -55,7 +55,7 @@ tests = [ "tantivy", "pyarrow-stubs", ] -dev = ["ruff", "pre-commit", "pyright"] +dev = ["ruff", "pre-commit", "pyright", 'typing-extensions>=4.0.0; python_version < "3.11"'] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] clip = ["torch", "pillow", "open-clip"] embeddings = [ diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index b431e3be..fb8c0ac9 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -84,11 +84,15 @@ class RecordBatchStream: class Query: def where(self, filter: str): ... def select(self, columns: Tuple[str, str]): ... + def select_columns(self, columns: List[str]): ... def limit(self, limit: int): ... def offset(self, offset: int): ... + def fast_search(self): ... + def with_row_id(self): ... + def postfilter(self): ... def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ... def nearest_to_text(self, query: dict) -> FTSQuery: ... - async def execute(self, max_batch_legnth: Optional[int]) -> RecordBatchStream: ... + async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ... class FTSQuery: def where(self, filter: str): ... @@ -98,6 +102,8 @@ class FTSQuery: def fast_search(self): ... def with_row_id(self): ... def postfilter(self): ... + def get_query(self) -> str: ... + def add_query_vector(self, query_vec: pa.Array) -> None: ... def nearest_to(self, query_vec: pa.Array) -> HybridQuery: ... async def execute(self, max_batch_length: Optional[int]) -> RecordBatchStream: ... async def explain_plan(self) -> str: ... diff --git a/python/python/lancedb/arrow.py b/python/python/lancedb/arrow.py index 06393e66..602d7df7 100644 --- a/python/python/lancedb/arrow.py +++ b/python/python/lancedb/arrow.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional, Union import pyarrow as pa @@ -12,17 +12,27 @@ class AsyncRecordBatchReader: Also allows access to the schema of the stream """ - def __init__(self, inner: RecordBatchStream): - self.inner_ = inner - - @property - def schema(self) -> pa.Schema: + def __init__( + self, + inner: Union[RecordBatchStream, pa.Table], + max_batch_length: Optional[int] = None, + ): """ - Get the schema of the batches produced by the stream - Accessing the schema does not consume any data from the stream + Attributes + ---------- + schema : pa.Schema + The schema of the batches produced by the stream. 
+ Accessing the schema does not consume any data from the stream """ - return self.inner_.schema() + if isinstance(inner, pa.Table): + self._inner = self._async_iter_from_table(inner, max_batch_length) + self.schema: pa.Schema = inner.schema + elif isinstance(inner, RecordBatchStream): + self._inner = inner + self.schema: pa.Schema = inner.schema + else: + raise TypeError("inner must be a RecordBatchStream or a Table") async def read_all(self) -> List[pa.RecordBatch]: """ @@ -38,7 +48,18 @@ class AsyncRecordBatchReader: return self async def __anext__(self) -> pa.RecordBatch: - next = await self.inner_.next() - if next is None: - raise StopAsyncIteration - return next + return await self._inner.__anext__() + + @staticmethod + async def _async_iter_from_table( + table: pa.Table, max_batch_length: Optional[int] = None + ): + """ + Create an AsyncRecordBatchReader from a Table + + This is useful when you have a Table that you want to iterate + over asynchronously + """ + batches = table.to_batches(max_chunksize=max_batch_length) + for batch in batches: + yield batch diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 76f41453..98a0bbf1 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -31,6 +31,7 @@ from .rerankers.util import check_reranker_result from .util import safe_import_pandas, flatten_columns if TYPE_CHECKING: + import sys import PIL import polars as pl @@ -42,6 +43,11 @@ if TYPE_CHECKING: from .pydantic import LanceModel from .table import Table + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + pd = safe_import_pandas() @@ -1418,7 +1424,7 @@ class AsyncQueryBase(object): """ self._inner = inner - def where(self, predicate: str) -> AsyncQuery: + def where(self, predicate: str) -> Self: """ Only return rows matching the given predicate @@ -1437,7 +1443,7 @@ class AsyncQueryBase(object): self._inner.where(predicate) return self - def select(self, columns: Union[List[str], dict[str, str]]) -> AsyncQuery: + def select(self, columns: Union[List[str], dict[str, str]]) -> Self: """ Return only the specified columns. @@ -1475,7 +1481,7 @@ class AsyncQueryBase(object): raise TypeError("columns must be a list of column names or a dict") return self - def limit(self, limit: int) -> AsyncQuery: + def limit(self, limit: int) -> Self: """ Set the maximum number of results to return. @@ -1485,7 +1491,7 @@ class AsyncQueryBase(object): self._inner.limit(limit) return self - def offset(self, offset: int) -> AsyncQuery: + def offset(self, offset: int) -> Self: """ Set the offset for the results. @@ -1497,7 +1503,7 @@ class AsyncQueryBase(object): self._inner.offset(offset) return self - def fast_search(self) -> AsyncQuery: + def fast_search(self) -> Self: """ Skip searching un-indexed data. @@ -1511,14 +1517,14 @@ class AsyncQueryBase(object): self._inner.fast_search() return self - def with_row_id(self) -> AsyncQuery: + def with_row_id(self) -> Self: """ Include the _rowid column in the results. """ self._inner.with_row_id() return self - def postfilter(self) -> AsyncQuery: + def postfilter(self) -> Self: """ If this is called then filtering will happen after the search instead of before. 
@@ -1807,8 +1813,8 @@ class AsyncFTSQuery(AsyncQueryBase): self._inner = inner self._reranker = None - def get_query(self): - self._inner.get_query() + def get_query(self) -> str: + return self._inner.get_query() def rerank( self, @@ -1891,29 +1897,18 @@ class AsyncFTSQuery(AsyncQueryBase): self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector)) ) - async def to_arrow(self) -> pa.Table: - results = await super().to_arrow() + async def to_batches( + self, *, max_batch_length: Optional[int] = None + ) -> AsyncRecordBatchReader: + reader = await super().to_batches() + results = pa.Table.from_batches(await reader.read_all(), reader.schema) if self._reranker: - results = self._reranker.rerank_fts(results) - return results + results = self._reranker.rerank_fts(self.get_query(), results) + return AsyncRecordBatchReader(results, max_batch_length=max_batch_length) -class AsyncVectorQuery(AsyncQueryBase): - def __init__(self, inner: LanceVectorQuery): - """ - Construct an AsyncVectorQuery - - This method is not intended to be called directly. Instead, create - a query first with [AsyncTable.query][lancedb.table.AsyncTable.query] and then - use [AsyncQuery.nearest_to][lancedb.query.AsyncQuery.nearest_to]] to convert to - a vector query. Or you can use - [AsyncTable.vector_search][lancedb.table.AsyncTable.vector_search] - """ - super().__init__(inner) - self._inner = inner - self._reranker = None - - def column(self, column: str) -> AsyncVectorQuery: +class AsyncVectorQueryBase: + def column(self, column: str) -> Self: """ Set the vector column to query @@ -1926,7 +1921,7 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.column(column) return self - def nprobes(self, nprobes: int) -> AsyncVectorQuery: + def nprobes(self, nprobes: int) -> Self: """ Set the number of partitions to search (probe) @@ -1954,7 +1949,7 @@ class AsyncVectorQuery(AsyncQueryBase): def distance_range( self, lower_bound: Optional[float] = None, upper_bound: Optional[float] = None - ) -> AsyncVectorQuery: + ) -> Self: """Set the distance range to use. Only rows with distances within range [lower_bound, upper_bound) @@ -1975,7 +1970,7 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.distance_range(lower_bound, upper_bound) return self - def ef(self, ef: int) -> AsyncVectorQuery: + def ef(self, ef: int) -> Self: """ Set the number of candidates to consider during search @@ -1990,7 +1985,7 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.ef(ef) return self - def refine_factor(self, refine_factor: int) -> AsyncVectorQuery: + def refine_factor(self, refine_factor: int) -> Self: """ A multiplier to control how many additional rows are taken during the refine step @@ -2026,7 +2021,7 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.refine_factor(refine_factor) return self - def distance_type(self, distance_type: str) -> AsyncVectorQuery: + def distance_type(self, distance_type: str) -> Self: """ Set the distance metric to use @@ -2044,7 +2039,7 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.distance_type(distance_type) return self - def bypass_vector_index(self) -> AsyncVectorQuery: + def bypass_vector_index(self) -> Self: """ If this is called then any vector index is skipped @@ -2057,6 +2052,23 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.bypass_vector_index() return self + +class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase): + def __init__(self, inner: LanceVectorQuery): + """ + Construct an AsyncVectorQuery + + This method is not intended to be called directly. 
Instead, create + a query first with [AsyncTable.query][lancedb.table.AsyncTable.query] and then + use [AsyncQuery.nearest_to][lancedb.query.AsyncQuery.nearest_to]] to convert to + a vector query. Or you can use + [AsyncTable.vector_search][lancedb.table.AsyncTable.vector_search] + """ + super().__init__(inner) + self._inner = inner + self._reranker = None + self._query_string = None + def rerank( self, reranker: Reranker = RRFReranker(), query_string: Optional[str] = None ) -> AsyncHybridQuery: @@ -2065,6 +2077,11 @@ class AsyncVectorQuery(AsyncQueryBase): self._reranker = reranker + if not self._query_string and not query_string: + raise ValueError("query_string must be provided to rerank the results.") + + self._query_string = query_string + return self def nearest_to_text( @@ -2100,14 +2117,17 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.nearest_to_text({"query": query, "columns": columns}) ) - async def to_arrow(self) -> pa.Table: - results = await super().to_arrow() + async def to_batches( + self, *, max_batch_length: Optional[int] = None + ) -> AsyncRecordBatchReader: + reader = await super().to_batches() + results = pa.Table.from_batches(await reader.read_all(), reader.schema) if self._reranker: - results = self._reranker.rerank_vector(results) - return results + results = self._reranker.rerank_vector(self._query_string, results) + return AsyncRecordBatchReader(results, max_batch_length=max_batch_length) -class AsyncHybridQuery(AsyncQueryBase): +class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase): """ A query builder that performs hybrid vector and full text search. Results are combined and reranked based on the specified reranker. @@ -2155,10 +2175,9 @@ class AsyncHybridQuery(AsyncQueryBase): return self - async def to_batches(self): - raise NotImplementedError("to_batches not yet supported on a hybrid query") - - async def to_arrow(self) -> pa.Table: + async def to_batches( + self, *, max_batch_length: Optional[int] = None + ) -> AsyncRecordBatchReader: fts_query = AsyncFTSQuery(self._inner.to_fts_query()) vec_query = AsyncVectorQuery(self._inner.to_vector_query()) @@ -2173,7 +2192,7 @@ class AsyncHybridQuery(AsyncQueryBase): vec_query.to_arrow(), ) - return LanceHybridQueryBuilder._combine_hybrid_results( + result = LanceHybridQueryBuilder._combine_hybrid_results( fts_results=fts_results, vector_results=vector_results, norm=self._norm, @@ -2183,6 +2202,8 @@ class AsyncHybridQuery(AsyncQueryBase): with_row_ids=with_row_ids, ) + return AsyncRecordBatchReader(result, max_batch_length=max_batch_length) + async def explain_plan(self, verbose: Optional[bool] = False): """Return the execution plan for this query. 
diff --git a/python/python/tests/test_hybrid_query.py b/python/python/tests/test_hybrid_query.py
index 4014e2a8..20f8ea2e 100644
--- a/python/python/tests/test_hybrid_query.py
+++ b/python/python/tests/test_hybrid_query.py
@@ -67,6 +67,7 @@ async def test_async_hybrid_query_filters(table: AsyncTable):
         .where("text not in ('a', 'dog')")
         .nearest_to([0.3, 0.3])
         .nearest_to_text("*a*")
+        .distance_type("l2")
         .limit(2)
         .to_arrow()
     )
diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py
index 82cdca6a..3a2bac9a 100644
--- a/python/python/tests/test_query.py
+++ b/python/python/tests/test_query.py
@@ -7,6 +7,7 @@ from pathlib import Path
 
 import lancedb
 from lancedb.index import IvfPq, FTS
+from lancedb.rerankers.cross_encoder import CrossEncoderReranker
 import numpy as np
 import pandas.testing as tm
 import pyarrow as pa
@@ -515,15 +516,24 @@ async def test_query_async(table_async: AsyncTable):
         expected_columns=["id", "vector", "_rowid"],
     )
 
+
+@pytest.mark.asyncio
+@pytest.mark.slow
+async def test_query_reranked_async(table_async: AsyncTable):
     # FTS with rerank
     await table_async.create_index("text", config=FTS(with_position=False))
     await check_query(
-        table_async.query().nearest_to_text("dog").rerank(),
+        table_async.query().nearest_to_text("dog").rerank(CrossEncoderReranker()),
         expected_num_rows=1,
     )
 
     # Vector query with rerank
-    await check_query(table_async.vector_search([1, 2]).rerank(), expected_num_rows=2)
+    await check_query(
+        table_async.vector_search([1, 2]).rerank(
+            CrossEncoderReranker(), query_string="dog"
+        ),
+        expected_num_rows=2,
+    )
 
 
 @pytest.mark.asyncio
diff --git a/python/src/arrow.rs b/python/src/arrow.rs
index c5c53b54..cb797345 100644
--- a/python/src/arrow.rs
+++ b/python/src/arrow.rs
@@ -9,7 +9,10 @@ use arrow::{
 };
 use futures::stream::StreamExt;
 use lancedb::arrow::SendableRecordBatchStream;
-use pyo3::{pyclass, pymethods, Bound, PyAny, PyObject, PyRef, PyResult, Python};
+use pyo3::{
+    exceptions::PyStopAsyncIteration, pyclass, pymethods, Bound, PyAny, PyObject, PyRef, PyResult,
+    Python,
+};
 use pyo3_async_runtimes::tokio::future_into_py;
 
 use crate::error::PythonErrorExt;
@@ -32,20 +35,25 @@ impl RecordBatchStream {
 
 #[pymethods]
 impl RecordBatchStream {
+    #[getter]
     pub fn schema(&self, py: Python) -> PyResult<PyObject> {
         (*self.schema).clone().into_pyarrow(py)
     }
 
-    pub fn next(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
+    pub fn __aiter__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> {
+        self_
+    }
+
+    pub fn __anext__(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
         let inner = self_.inner.clone();
         future_into_py(self_.py(), async move {
-            let inner_next = inner.lock().await.next().await;
-            inner_next
-                .map(|item| {
-                    let item = item.infer_error()?;
-                    Python::with_gil(|py| item.to_pyarrow(py))
-                })
-                .transpose()
+            let inner_next = inner
+                .lock()
+                .await
+                .next()
+                .await
+                .ok_or_else(|| PyStopAsyncIteration::new_err(""))?;
+            Python::with_gil(|py| inner_next.infer_error()?.to_pyarrow(py))
        })
    }
 }
diff --git a/python/src/lib.rs b/python/src/lib.rs
index a68e7711..eff066d0 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -21,7 +21,7 @@ use pyo3::{
     types::{PyModule, PyModuleMethods},
     wrap_pyfunction, Bound, PyResult, Python,
 };
-use query::{Query, VectorQuery};
+use query::{FTSQuery, HybridQuery, Query, VectorQuery};
 use table::Table;
 
 pub mod arrow;
@@ -42,6 +42,8 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::()?;
     m.add_class::()?;
     m.add_class::()?;
+    m.add_class::<FTSQuery>()?;
+    m.add_class::<HybridQuery>()?;
     m.add_class::()?;
m.add_class::()?; m.add_function(wrap_pyfunction!(connect, m)?)?;
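
Not part of the diff: below is a minimal usage sketch of the async streaming and reranking surface these changes expose. The database path, table name, sample rows, and the choice of `CrossEncoderReranker` (which needs the sentence-transformers extra installed) are illustrative assumptions rather than anything taken from this PR; the calls themselves mirror the updated tests and docstrings above.

```python
import asyncio

import lancedb
from lancedb.index import FTS
from lancedb.rerankers.cross_encoder import CrossEncoderReranker


async def main():
    # Assumed setup: a throwaway local database and a tiny table with a
    # 2-d "vector" column and a "text" column.
    db = await lancedb.connect_async("data/example-lancedb")
    table = await db.create_table(
        "docs",
        data=[
            {"id": 1, "vector": [0.1, 0.1], "text": "a cat"},
            {"id": 2, "vector": [0.9, 0.9], "text": "a dog"},
        ],
        mode="overwrite",
    )
    await table.create_index("text", config=FTS(with_position=False))

    # Stream a vector search as record batches; async iteration over
    # AsyncRecordBatchReader is what the arrow.py / arrow.rs changes enable.
    reader = await table.vector_search([0.9, 0.9]).limit(2).to_batches(
        max_batch_length=1
    )
    async for batch in reader:
        print(batch.num_rows)

    # Reranked FTS and vector queries, mirroring the updated tests: the
    # vector-side rerank() now requires a query_string.
    fts_results = await (
        table.query().nearest_to_text("dog").rerank(CrossEncoderReranker()).to_arrow()
    )
    vec_results = await (
        table.vector_search([0.9, 0.9])
        .rerank(CrossEncoderReranker(), query_string="dog")
        .to_arrow()
    )
    print(fts_results.num_rows, vec_results.num_rows)


asyncio.run(main())
```

The `async for` loop leans on the new `__aiter__`/`__anext__` methods added to `RecordBatchStream`, and the reranked results flow through the reworked `to_batches()` overrides, so both paths exercise the code touched above.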