mirror of https://github.com/lancedb/lancedb.git
synced 2025-12-27 07:09:57 +00:00
feat: add take_offsets and take_row_ids (#2584)
These operations have existed in Lance for a long while, and many users need to drop down to Lance for this capability. This PR adds the API and implements it using filters (e.g. `_rowid IN (...)`) so that it doesn't currently add any load to `BaseTable`. I'm not sure that is sustainable, as base table implementations may want to specialize how they handle this method; however, I figure it is a good starting point. In addition, unlike Lance, this API does not currently guarantee anything about the order of the take results. This is necessary for the fallback filter approach to work (SQL filters cannot guarantee result order).
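As a usage sketch, the new synchronous API looks roughly like this (the database path and table name are illustrative, not from this PR; `take_row_ids` is given literal ids here only because a freshly created table assigns row ids sequentially, which the test in this commit also relies on):

```python
import lancedb
import pyarrow as pa

db = lancedb.connect("/tmp/take_demo")  # illustrative path
table = db.create_table("demo", pa.table({"idx": list(range(100))}), mode="overwrite")

# Take by offset; the result order is NOT guaranteed to match the input order.
offsets = table.take_offsets([5, 2, 17]).to_arrow()

# Take by row id; to recover the requested order, include _rowid and realign manually.
ids = [5, 2, 17]
rows = table.take_row_ids(ids).with_row_id().to_list()
by_id = {row["_rowid"]: row for row in rows}
ordered = [by_id[i] for i in ids]
print([row["idx"] for row in ordered])  # [5, 2, 17]
```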
@@ -28,6 +28,7 @@ import pyarrow.fs as pa_fs
import pydantic

from lancedb.pydantic import PYDANTIC_VERSION
from lancedb.background_loop import LOOP

from . import __version__
from .arrow import AsyncRecordBatchReader
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
    from ._lancedb import FTSQuery as LanceFTSQuery
    from ._lancedb import HybridQuery as LanceHybridQuery
    from ._lancedb import VectorQuery as LanceVectorQuery
    from ._lancedb import TakeQuery as LanceTakeQuery
    from ._lancedb import PyQueryRequest
    from .common import VEC
    from .pydantic import LanceModel
@@ -2139,7 +2141,11 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):


class AsyncQueryBase(object):
    def __init__(self, inner: Union[LanceQuery, LanceVectorQuery]):
    """
    Base class for all async queries (take, scan, vector, fts, hybrid)
    """

    def __init__(self, inner: Union[LanceQuery, LanceVectorQuery, LanceTakeQuery]):
        """
        Construct an AsyncQueryBase
@@ -2149,27 +2155,14 @@ class AsyncQueryBase(object):
        self._inner = inner

    def to_query_object(self) -> Query:
        """
        Convert the query into a query object

        This is currently experimental but can be useful as the query object is pure
        python and more easily serializable.
        """
        return Query.from_inner(self._inner.to_query_request())

    def where(self, predicate: str) -> Self:
        """
        Only return rows matching the given predicate

        The predicate should be supplied as an SQL query string.

        Examples
        --------

        >>> predicate = "x > 10"
        >>> predicate = "y > 0 AND y < 100"
        >>> predicate = "x > 5 OR y = 'test'"

        Filtering performance can often be improved by creating a scalar index
        on the filter column(s).
        """
        self._inner.where(predicate)
        return self

    def select(self, columns: Union[List[str], dict[str, str]]) -> Self:
        """
        Return only the specified columns.
@@ -2208,42 +2201,6 @@ class AsyncQueryBase(object):
            raise TypeError("columns must be a list of column names or a dict")
        return self

    def limit(self, limit: int) -> Self:
        """
        Set the maximum number of results to return.

        By default, a plain search has no limit. If this method is not
        called then every valid row from the table will be returned.
        """
        self._inner.limit(limit)
        return self

    def offset(self, offset: int) -> Self:
        """
        Set the offset for the results.

        Parameters
        ----------
        offset: int
            The offset to start fetching results from.
        """
        self._inner.offset(offset)
        return self

    def fast_search(self) -> Self:
        """
        Skip searching un-indexed data.

        This can make queries faster, but will miss any data that has not been
        indexed.

        !!! tip
            You can add new data into an existing index by calling
            [AsyncTable.optimize][lancedb.table.AsyncTable.optimize].
        """
        self._inner.fast_search()
        return self

    def with_row_id(self) -> Self:
        """
        Include the _rowid column in the results.
@@ -2251,27 +2208,6 @@ class AsyncQueryBase(object):
        self._inner.with_row_id()
        return self

    def postfilter(self) -> Self:
        """
        If this is called then filtering will happen after the search instead of
        before.
        By default filtering will be performed before the search. This is how
        filtering is typically understood to work. This prefilter step does add some
        additional latency. Creating a scalar index on the filter column(s) can
        often improve this latency. However, sometimes a filter is too complex or
        scalar indices cannot be applied to the column. In these cases postfiltering
        can be used instead of prefiltering to improve latency.
        Post filtering applies the filter to the results of the search. This
        means we only run the filter on a much smaller set of data. However, it can
        cause the query to return fewer than `limit` results (or even no results) if
        none of the nearest results match the filter.
        Post filtering happens during the "refine stage" (described in more detail in
        @see {@link VectorQuery#refineFactor}). This means that setting a higher refine
        factor can often help restore some of the results lost by post filtering.
        """
        self._inner.postfilter()
        return self

    async def to_batches(
        self,
        *,
@@ -2295,7 +2231,9 @@ class AsyncQueryBase(object):
            complete within the specified time, an error will be raised.
        """
        return AsyncRecordBatchReader(
            await self._inner.execute(max_batch_length, timeout)
            await self._inner.execute(
                max_batch_length=max_batch_length, timeout=timeout
            )
        )

    async def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table:
@@ -2454,7 +2392,98 @@ class AsyncQueryBase(object):
        return await self._inner.analyze_plan()


class AsyncQuery(AsyncQueryBase):
class AsyncStandardQuery(AsyncQueryBase):
    """
    Base class for "standard" async queries (all but take currently)
    """

    def __init__(self, inner: Union[LanceQuery, LanceVectorQuery]):
        """
        Construct an AsyncStandardQuery

        This method is not intended to be called directly. Instead, use the
        [AsyncTable.query][lancedb.table.AsyncTable.query] method to create a query.
        """
        super().__init__(inner)

    def where(self, predicate: str) -> Self:
        """
        Only return rows matching the given predicate

        The predicate should be supplied as an SQL query string.

        Examples
        --------

        >>> predicate = "x > 10"
        >>> predicate = "y > 0 AND y < 100"
        >>> predicate = "x > 5 OR y = 'test'"

        Filtering performance can often be improved by creating a scalar index
        on the filter column(s).
        """
        self._inner.where(predicate)
        return self

    def limit(self, limit: int) -> Self:
        """
        Set the maximum number of results to return.

        By default, a plain search has no limit. If this method is not
        called then every valid row from the table will be returned.
        """
        self._inner.limit(limit)
        return self

    def offset(self, offset: int) -> Self:
        """
        Set the offset for the results.

        Parameters
        ----------
        offset: int
            The offset to start fetching results from.
        """
        self._inner.offset(offset)
        return self

    def fast_search(self) -> Self:
        """
        Skip searching un-indexed data.

        This can make queries faster, but will miss any data that has not been
        indexed.

        !!! tip
            You can add new data into an existing index by calling
            [AsyncTable.optimize][lancedb.table.AsyncTable.optimize].
        """
        self._inner.fast_search()
        return self

    def postfilter(self) -> Self:
        """
        If this is called then filtering will happen after the search instead of
        before.
        By default filtering will be performed before the search. This is how
        filtering is typically understood to work. This prefilter step does add some
        additional latency. Creating a scalar index on the filter column(s) can
        often improve this latency. However, sometimes a filter is too complex or
        scalar indices cannot be applied to the column. In these cases postfiltering
        can be used instead of prefiltering to improve latency.
        Post filtering applies the filter to the results of the search. This
        means we only run the filter on a much smaller set of data. However, it can
        cause the query to return fewer than `limit` results (or even no results) if
        none of the nearest results match the filter.
        Post filtering happens during the "refine stage" (described in more detail in
        @see {@link VectorQuery#refineFactor}). This means that setting a higher refine
        factor can often help restore some of the results lost by post filtering.
        """
        self._inner.postfilter()
        return self


class AsyncQuery(AsyncStandardQuery):
    def __init__(self, inner: LanceQuery):
        """
        Construct an AsyncQuery
@@ -2588,7 +2617,7 @@ class AsyncQuery(AsyncQueryBase):
        return AsyncFTSQuery(self._inner.nearest_to_text({"query": query}))


class AsyncFTSQuery(AsyncQueryBase):
class AsyncFTSQuery(AsyncStandardQuery):
    """A query for full text search for LanceDB."""

    def __init__(self, inner: LanceFTSQuery):
@@ -2867,7 +2896,7 @@ class AsyncVectorQueryBase:
        return self


class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
class AsyncVectorQuery(AsyncStandardQuery, AsyncVectorQueryBase):
    def __init__(self, inner: LanceVectorQuery):
        """
        Construct an AsyncVectorQuery
@@ -2950,7 +2979,7 @@ class AsyncVectorQuery(AsyncQueryBase, AsyncVectorQueryBase):
        return AsyncRecordBatchReader(results, max_batch_length=max_batch_length)


class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase):
class AsyncHybridQuery(AsyncStandardQuery, AsyncVectorQueryBase):
    """
    A query builder that performs hybrid vector and full text search.
    Results are combined and reranked based on the specified reranker.
@@ -3102,3 +3131,252 @@ class AsyncHybridQuery(AsyncQueryBase, AsyncVectorQueryBase):
            results.append(await self._inner.to_fts_query().analyze_plan())

        return "\n".join(results)


class AsyncTakeQuery(AsyncQueryBase):
    """
    Builder for parameterizing and executing take queries.
    """

    def __init__(self, inner: LanceTakeQuery):
        super().__init__(inner)


class BaseQueryBuilder(object):
    """
    Wraps AsyncQueryBase and provides a synchronous interface
    """

    def __init__(self, inner: AsyncQueryBase):
        self._inner = inner

    def to_query_object(self) -> Query:
        return self._inner.to_query_object()

    def select(self, columns: Union[List[str], dict[str, str]]) -> Self:
        """
        Return only the specified columns.

        By default a query will return all columns from the table. However, this can
        have a very significant impact on latency. LanceDb stores data in a columnar
        fashion. This means we can finely tune our I/O to select exactly the columns
        we need.

        As a best practice you should always limit queries to the columns that you need.
        If you pass in a list of column names then only those columns will be
        returned.

        You can also use this method to create new "dynamic" columns based on your
        existing columns. For example, you may not care about "a" or "b" but instead
        simply want "a + b". This is often seen in the SELECT clause of an SQL query
        (e.g. `SELECT a+b FROM my_table`).

        To create dynamic columns you can pass in a dict[str, str]. A column will be
        returned for each entry in the map. The key provides the name of the column.
        The value is an SQL string used to specify how the column is calculated.

        For example, an SQL query might state `SELECT a + b AS combined, c`. The
        equivalent input to this method would be `{"combined": "a + b", "c": "c"}`.

        Columns will always be returned in the order given, even if that order is
        different than the order used when adding the data.
        """
        self._inner.select(columns)
        return self

    def with_row_id(self) -> Self:
        """
        Include the _rowid column in the results.
        """
        self._inner.with_row_id()
        return self

    def to_batches(
        self,
        *,
        max_batch_length: Optional[int] = None,
        timeout: Optional[timedelta] = None,
    ) -> pa.RecordBatchReader:
        """
        Execute the query and return the results as an Apache Arrow RecordBatchReader.

        Parameters
        ----------

        max_batch_length: Optional[int]
            The maximum number of selected records in a single RecordBatch object.
            If not specified, a default batch length is used.
            It is possible for batches to be smaller than the provided length if the
            underlying data is stored in smaller chunks.
        timeout: Optional[timedelta]
            The maximum time to wait for the query to complete.
            If not specified, no timeout is applied. If the query does not
            complete within the specified time, an error will be raised.
        """
        async_iter = LOOP.run(self._inner.execute(max_batch_length, timeout))

        def iter_sync():
            try:
                while True:
                    yield LOOP.run(async_iter.__anext__())
            except StopAsyncIteration:
                return

        return pa.RecordBatchReader.from_batches(async_iter.schema, iter_sync())

    def to_arrow(self, timeout: Optional[timedelta] = None) -> pa.Table:
        """
        Execute the query and collect the results into an Apache Arrow Table.

        This method will collect all results into memory before returning. If
        you expect a large number of results, you may want to use
        [to_batches][lancedb.query.AsyncQueryBase.to_batches]

        Parameters
        ----------
        timeout: Optional[timedelta]
            The maximum time to wait for the query to complete.
            If not specified, no timeout is applied. If the query does not
            complete within the specified time, an error will be raised.
        """
        return LOOP.run(self._inner.to_arrow(timeout))

    def to_list(self, timeout: Optional[timedelta] = None) -> List[dict]:
        """
        Execute the query and return the results as a list of dictionaries.

        Each list entry is a dictionary with the selected column names as keys,
        or all table columns if `select` is not called. The vector and the "_distance"
        fields are returned whether or not they're explicitly selected.

        Parameters
        ----------
        timeout: Optional[timedelta]
            The maximum time to wait for the query to complete.
            If not specified, no timeout is applied. If the query does not
            complete within the specified time, an error will be raised.
        """
        return LOOP.run(self._inner.to_list(timeout))

    def to_pandas(
        self,
        flatten: Optional[Union[int, bool]] = None,
        timeout: Optional[timedelta] = None,
    ) -> "pd.DataFrame":
        """
        Execute the query and collect the results into a pandas DataFrame.

        This method will collect all results into memory before returning. If you
        expect a large number of results, you may want to use
        [to_batches][lancedb.query.AsyncQueryBase.to_batches] and convert each batch to
        pandas separately.

        Examples
        --------

        >>> import asyncio
        >>> from lancedb import connect_async
        >>> async def doctest_example():
        ...     conn = await connect_async("./.lancedb")
        ...     table = await conn.create_table("my_table", data=[{"a": 1, "b": 2}])
        ...     async for batch in await table.query().to_batches():
        ...         batch_df = batch.to_pandas()
        >>> asyncio.run(doctest_example())

        Parameters
        ----------
        flatten: Optional[Union[int, bool]]
            If flatten is True, flatten all nested columns.
            If flatten is an integer, flatten the nested columns up to the
            specified depth.
            If unspecified, do not flatten the nested columns.
        timeout: Optional[timedelta]
            The maximum time to wait for the query to complete.
            If not specified, no timeout is applied. If the query does not
            complete within the specified time, an error will be raised.
        """
        return LOOP.run(self._inner.to_pandas(flatten, timeout))

    def to_polars(
        self,
        timeout: Optional[timedelta] = None,
    ) -> "pl.DataFrame":
        """
        Execute the query and collect the results into a Polars DataFrame.

        This method will collect all results into memory before returning. If you
        expect a large number of results, you may want to use
        [to_batches][lancedb.query.AsyncQueryBase.to_batches] and convert each batch to
        polars separately.

        Parameters
        ----------
        timeout: Optional[timedelta]
            The maximum time to wait for the query to complete.
            If not specified, no timeout is applied. If the query does not
            complete within the specified time, an error will be raised.

        Examples
        --------

        >>> import asyncio
        >>> import polars as pl
        >>> from lancedb import connect_async
        >>> async def doctest_example():
        ...     conn = await connect_async("./.lancedb")
        ...     table = await conn.create_table("my_table", data=[{"a": 1, "b": 2}])
        ...     async for batch in await table.query().to_batches():
        ...         batch_df = pl.from_arrow(batch)
        >>> asyncio.run(doctest_example())
        """
        return LOOP.run(self._inner.to_polars(timeout))

    def explain_plan(self, verbose: Optional[bool] = False):
        """Return the execution plan for this query.

        Examples
        --------
        >>> import asyncio
        >>> from lancedb import connect_async
        >>> async def doctest_example():
        ...     conn = await connect_async("./.lancedb")
        ...     table = await conn.create_table("my_table", [{"vector": [99, 99]}])
        ...     query = [100, 100]
        ...     plan = await table.query().nearest_to([1, 2]).explain_plan(True)
        ...     print(plan)
        >>> asyncio.run(doctest_example())  # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
        ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance]
          GlobalLimitExec: skip=0, fetch=10
            FilterExec: _distance@2 IS NOT NULL
              SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST, _rowid@1 ASC NULLS LAST], preserve_partitioning=[false]
                KNNVectorDistance: metric=l2
                  LanceRead: uri=..., projection=[vector], ...

        Parameters
        ----------
        verbose : bool, default False
            Use a verbose output format.

        Returns
        -------
        plan : str
        """  # noqa: E501
        return LOOP.run(self._inner.explain_plan(verbose))

    def analyze_plan(self):
        """Execute the query and display with runtime metrics.

        Returns
        -------
        plan : str
        """
        return LOOP.run(self._inner.analyze_plan())


class LanceTakeQueryBuilder(BaseQueryBuilder):
    """
    Builder for parameterizing and executing take queries.
    """

    def __init__(self, inner: AsyncTakeQuery):
        super().__init__(inner)
@@ -26,7 +26,7 @@ from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
from lancedb.merge import LanceMergeInsertBuilder
from lancedb.embeddings import EmbeddingFunctionRegistry

from ..query import LanceVectorQueryBuilder, LanceQueryBuilder
from ..query import LanceVectorQueryBuilder, LanceQueryBuilder, LanceTakeQueryBuilder
from ..table import AsyncTable, IndexStatistics, Query, Table, Tags
@@ -617,6 +617,12 @@ class RemoteTable(Table):
    def stats(self):
        return LOOP.run(self._table.stats())

    def take_offsets(self, offsets: list[int]) -> LanceTakeQueryBuilder:
        return LanceTakeQueryBuilder(self._table.take_offsets(offsets))

    def take_row_ids(self, row_ids: list[int]) -> LanceTakeQueryBuilder:
        return LanceTakeQueryBuilder(self._table.take_row_ids(row_ids))

    def uses_v2_manifest_paths(self) -> bool:
        raise NotImplementedError(
            "uses_v2_manifest_paths() is not supported on the LanceDB Cloud"
@@ -51,6 +51,7 @@ from .query import (
    AsyncFTSQuery,
    AsyncHybridQuery,
    AsyncQuery,
    AsyncTakeQuery,
    AsyncVectorQuery,
    FullTextQuery,
    LanceEmptyQueryBuilder,
@@ -58,6 +59,7 @@ from .query import (
    LanceHybridQueryBuilder,
    LanceQueryBuilder,
    LanceVectorQueryBuilder,
    LanceTakeQueryBuilder,
    Query,
)
from .util import (
@@ -1103,6 +1105,66 @@ class Table(ABC):
        """
        raise NotImplementedError

    @abstractmethod
    def take_offsets(self, offsets: list[int]) -> LanceTakeQueryBuilder:
        """
        Take a list of offsets from the table.

        Offsets are 0-indexed and relative to the current version of the table. Offsets
        are not stable. A row with an offset of N may have a different offset in a
        different version of the table (e.g. if an earlier row is deleted).

        Offsets are mostly useful for sampling as the set of all valid offsets is easily
        known in advance to be [0, len(table)).

        No guarantees are made regarding the order in which results are returned. If
        you desire an output order that matches the order of the given offsets, you will
        need to add the row offset column to the output and align it yourself.

        Parameters
        ----------
        offsets: list[int]
            The offsets to take.

        Returns
        -------
        LanceTakeQueryBuilder
            A query builder that can be executed to fetch the rows at the given
            offsets.
        """

    @abstractmethod
    def take_row_ids(self, row_ids: list[int]) -> LanceTakeQueryBuilder:
        """
        Take a list of row ids from the table.

        Row ids are not stable and are relative to the current version of the table.
        They can change due to compaction and updates.

        No guarantees are made regarding the order in which results are returned. If
        you desire an output order that matches the order of the given ids, you will
        need to add the row id column to the output and align it yourself.

        Unlike offsets, row ids are not 0-indexed and no assumptions should be made
        about the possible range of row ids. In order to use this method you must
        first obtain the row ids by scanning or searching the table.

        Even so, row ids are more stable than offsets and can be useful in some
        situations.

        There is an ongoing effort to make row ids stable, which is tracked at
        https://github.com/lancedb/lancedb/issues/1120

        Parameters
        ----------
        row_ids: list[int]
            The row ids to take.

        Returns
        -------
        LanceTakeQueryBuilder
            A query builder that can be executed to fetch the rows with the given
            row ids.
        """

    @abstractmethod
    def _execute_query(
        self,
@@ -1648,6 +1710,12 @@ class LanceTable(Table):
        """Get the current version of the table"""
        return LOOP.run(self._table.version())

    def take_offsets(self, offsets: list[int]) -> LanceTakeQueryBuilder:
        return LanceTakeQueryBuilder(self._table.take_offsets(offsets))

    def take_row_ids(self, row_ids: list[int]) -> LanceTakeQueryBuilder:
        return LanceTakeQueryBuilder(self._table.take_row_ids(row_ids))

    @property
    def tags(self) -> Tags:
        """Tag management for the table.
@@ -4030,6 +4098,58 @@ class AsyncTable:
        """
        await self._inner.restore(version)

    def take_offsets(self, offsets: list[int]) -> AsyncTakeQuery:
        """
        Take a list of offsets from the table.

        Offsets are 0-indexed and relative to the current version of the table. Offsets
        are not stable. A row with an offset of N may have a different offset in a
        different version of the table (e.g. if an earlier row is deleted).

        Offsets are mostly useful for sampling as the set of all valid offsets is easily
        known in advance to be [0, len(table)).

        Parameters
        ----------
        offsets: list[int]
            The offsets to take.

        Returns
        -------
        AsyncTakeQuery
            A query object that can be executed to get the rows at the given offsets.
        """
        return AsyncTakeQuery(self._inner.take_offsets(offsets))

    def take_row_ids(self, row_ids: list[int]) -> AsyncTakeQuery:
        """
        Take a list of row ids from the table.

        Row ids are not stable and are relative to the current version of the table.
        They can change due to compaction and updates.

        Unlike offsets, row ids are not 0-indexed and no assumptions should be made
        about the possible range of row ids. In order to use this method you must
        first obtain the row ids by scanning or searching the table.

        Even so, row ids are more stable than offsets and can be useful in some
        situations.

        There is an ongoing effort to make row ids stable, which is tracked at
        https://github.com/lancedb/lancedb/issues/1120

        Parameters
        ----------
        row_ids: list[int]
            The row ids to take.

        Returns
        -------
        AsyncTakeQuery
            A query object that can be executed to get the rows.
        """
        return AsyncTakeQuery(self._inner.take_row_ids(row_ids))

    @property
    def tags(self) -> AsyncTags:
        """Tag management for the dataset.
@@ -1327,6 +1327,34 @@ def test_query_timeout(tmp_path):
    )


def test_take_queries(tmp_path):
    db = lancedb.connect(tmp_path)
    data = pa.table(
        {
            "idx": range(100),
        }
    )
    table = db.create_table("test", data)

    # Take by offset
    assert list(
        sorted(table.take_offsets([5, 2, 17]).to_pandas()["idx"].to_list())
    ) == [
        2,
        5,
        17,
    ]

    # Take by row id
    assert list(
        sorted(table.take_row_ids([5, 2, 17]).to_pandas()["idx"].to_list())
    ) == [
        2,
        5,
        17,
    ]


@pytest.mark.asyncio
async def test_query_timeout_async(tmp_path):
    db = await lancedb.connect_async(tmp_path)
@@ -13,10 +13,12 @@ use lancedb::index::scalar::{
    BooleanQuery, BoostQuery, FtsQuery, FullTextSearchQuery, MatchQuery, MultiMatchQuery, Occur,
    Operator, PhraseQuery,
};
use lancedb::query::QueryBase;
use lancedb::query::QueryExecutionOptions;
use lancedb::query::QueryFilter;
use lancedb::query::{
    ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
    ExecutableQuery, Query as LanceDbQuery, Select, TakeQuery as LanceDbTakeQuery,
    VectorQuery as LanceDbVectorQuery,
};
use lancedb::table::AnyQuery;
use pyo3::prelude::{PyAnyMethods, PyDictMethods};
@@ -488,6 +490,76 @@ impl Query {
        }
    }
}

#[pyclass]
pub struct TakeQuery {
    inner: LanceDbTakeQuery,
}

impl TakeQuery {
    pub fn new(query: LanceDbTakeQuery) -> Self {
        Self { inner: query }
    }
}

#[pymethods]
impl TakeQuery {
    pub fn select(&mut self, columns: Vec<(String, String)>) {
        self.inner = self.inner.clone().select(Select::dynamic(&columns));
    }

    pub fn select_columns(&mut self, columns: Vec<String>) {
        self.inner = self.inner.clone().select(Select::columns(&columns));
    }

    pub fn with_row_id(&mut self) {
        self.inner = self.inner.clone().with_row_id();
    }

    #[pyo3(signature = (max_batch_length=None, timeout=None))]
    pub fn execute(
        self_: PyRef<'_, Self>,
        max_batch_length: Option<u32>,
        timeout: Option<Duration>,
    ) -> PyResult<Bound<'_, PyAny>> {
        let inner = self_.inner.clone();
        future_into_py(self_.py(), async move {
            let mut opts = QueryExecutionOptions::default();
            if let Some(max_batch_length) = max_batch_length {
                opts.max_batch_length = max_batch_length;
            }
            if let Some(timeout) = timeout {
                opts.timeout = Some(timeout);
            }
            let inner_stream = inner.execute_with_options(opts).await.infer_error()?;
            Ok(RecordBatchStream::new(inner_stream))
        })
    }

    pub fn explain_plan(self_: PyRef<'_, Self>, verbose: bool) -> PyResult<Bound<'_, PyAny>> {
        let inner = self_.inner.clone();
        future_into_py(self_.py(), async move {
            inner
                .explain_plan(verbose)
                .await
                .map_err(|e| PyRuntimeError::new_err(e.to_string()))
        })
    }

    pub fn analyze_plan(self_: PyRef<'_, Self>) -> PyResult<Bound<'_, PyAny>> {
        let inner = self_.inner.clone();
        future_into_py(self_.py(), async move {
            inner
                .analyze_plan()
                .await
                .map_err(|e| PyRuntimeError::new_err(e.to_string()))
        })
    }

    pub fn to_query_request(&self) -> PyQueryRequest {
        PyQueryRequest::from(AnyQuery::Query(self.inner.clone().into_request()))
    }
}

#[pyclass]
#[derive(Clone)]
pub struct FTSQuery {
@@ -5,7 +5,7 @@ use std::{collections::HashMap, sync::Arc};
use crate::{
    error::PythonErrorExt,
    index::{extract_index_params, IndexConfig},
    query::Query,
    query::{Query, TakeQuery},
};
use arrow::{
    datatypes::{DataType, Schema},
@@ -568,6 +568,20 @@ impl Table {
        Ok(Tags::new(self.inner_ref()?.clone()))
    }

    #[pyo3(signature = (offsets))]
    pub fn take_offsets(self_: PyRef<'_, Self>, offsets: Vec<u64>) -> PyResult<TakeQuery> {
        Ok(TakeQuery::new(
            self_.inner_ref()?.clone().take_offsets(offsets),
        ))
    }

    #[pyo3(signature = (row_ids))]
    pub fn take_row_ids(self_: PyRef<'_, Self>, row_ids: Vec<u64>) -> PyResult<TakeQuery> {
        Ok(TakeQuery::new(
            self_.inner_ref()?.clone().take_row_ids(row_ids),
        ))
    }

    /// Optimize the on-disk data by compacting and pruning old data, for better performance.
    #[pyo3(signature = (cleanup_since_ms=None, delete_unverified=None, retrain=None))]
    pub fn optimize(
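For comparison with the synchronous sketch above, a minimal sketch of the async path (connection path and table name are again illustrative, not from this PR). `AsyncTable.take_offsets` returns an `AsyncTakeQuery`, which exposes the shared `select`/`with_row_id`/`to_arrow` builders from `AsyncQueryBase`:

```python
import asyncio
import lancedb

async def main():
    db = await lancedb.connect_async("/tmp/take_demo_async")  # illustrative path
    table = await db.create_table(
        "demo", data=[{"idx": i} for i in range(100)], mode="overwrite"
    )
    # take_offsets builds an AsyncTakeQuery; execute it like any other async query.
    result = await table.take_offsets([5, 2, 17]).select(["idx"]).to_arrow()
    print(result.num_rows)  # 3, in no guaranteed order

asyncio.run(main())
```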