feat: refactor the query API and add query support to the python async API (#1113)
This PR also adds a jsdoc linter, which required a number of docstring changes to existing methods in the nodejs package.
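For context, a minimal end-to-end sketch of the new async query path introduced here (based on the doctest in this PR; the table name and data are illustrative):

import asyncio

import lancedb


async def main():
    # connect_async / create_table are the async entry points used throughout this PR
    conn = await lancedb.connect_async("./.lancedb")
    table = await conn.create_table("my_table", data=[{"a": 1, "b": 2}])

    # Plain (non-vector) query: filter, project, and collect to an Arrow table
    results = await table.query().where("a > 0").select(["a", "b"]).limit(10).to_arrow()
    print(results.num_rows)


asyncio.run(main())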
@@ -22,6 +22,9 @@ pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }
# Prevent dynamic linking of lzma, which comes from datafusion
lzma-sys = { version = "*", features = ["static"] }
pin-project = "1.1.5"
futures.workspace = true
tokio = { version = "1.36.0", features = ["sync"] }

[build-dependencies]
pyo3-build-config = { version = "0.20.3", features = [
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple

 import pyarrow as pa

@@ -40,6 +40,8 @@ class Table:
    async def checkout_latest(self): ...
    async def restore(self): ...
    async def list_indices(self) -> List[IndexConfig]: ...
    def query(self) -> Query: ...
    def vector_search(self) -> VectorQuery: ...

class IndexConfig:
    index_type: str
@@ -52,3 +54,27 @@ async def connect(
    host_override: Optional[str],
    read_consistency_interval: Optional[float],
) -> Connection: ...

class RecordBatchStream:
    def schema(self) -> pa.Schema: ...
    async def next(self) -> Optional[pa.RecordBatch]: ...

class Query:
    def where(self, filter: str): ...
    def select(self, columns: List[Tuple[str, str]]): ...
    def limit(self, limit: int): ...
    def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
    async def execute(self) -> RecordBatchStream: ...

class VectorQuery:
    async def execute(self) -> RecordBatchStream: ...
    def where(self, filter: str): ...
    def select(self, columns: List[Tuple[str, str]]): ...
    def select_with_projection(self, columns: List[Tuple[str, str]]): ...
    def limit(self, limit: int): ...
    def column(self, column: str): ...
    def distance_type(self, distance_type: str): ...
    def postfilter(self): ...
    def refine_factor(self, refine_factor: int): ...
    def nprobes(self, nprobes: int): ...
    def bypass_vector_index(self): ...
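The `select` stubs above take (output name, SQL expression) tuples. The Python wrapper shown later in this diff normalizes both accepted input forms (a list of names or a dict) into that shape before calling the native layer. A standalone sketch of that conversion (the helper name `to_column_tuples` is hypothetical; the logic mirrors AsyncQueryBase.select below):

from typing import Dict, List, Tuple, Union


def to_column_tuples(columns: Union[List[str], Dict[str, str]]) -> List[Tuple[str, str]]:
    # A dict maps output column name -> SQL expression used to compute it
    if isinstance(columns, dict):
        return list(columns.items())
    # A plain list selects each named column as itself
    return [(c, c) for c in columns]


assert to_column_tuples(["id", "vector"]) == [("id", "id"), ("vector", "vector")]
assert to_column_tuples({"combined": "a + b", "c": "c"}) == [("combined", "a + b"), ("c", "c")]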
|
||||
44
python/python/lancedb/arrow.py
Normal file
44
python/python/lancedb/arrow.py
Normal file
@@ -0,0 +1,44 @@
from typing import List

import pyarrow as pa

from ._lancedb import RecordBatchStream


class AsyncRecordBatchReader:
    """
    An async iterator over a stream of RecordBatches.

    Also allows access to the schema of the stream
    """

    def __init__(self, inner: RecordBatchStream):
        self.inner_ = inner

    @property
    def schema(self) -> pa.Schema:
        """
        Get the schema of the batches produced by the stream

        Accessing the schema does not consume any data from the stream
        """
        return self.inner_.schema()

    async def read_all(self) -> List[pa.RecordBatch]:
        """
        Read all the record batches from the stream

        This consumes the entire stream and returns a list of record batches

        If there are a lot of results this may consume a lot of memory
        """
        return [batch async for batch in self]

    def __aiter__(self):
        return self

    async def __anext__(self) -> pa.RecordBatch:
        next = await self.inner_.next()
        if next is None:
            raise StopAsyncIteration
        return next
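A short usage sketch for this reader (assuming a `table` created with the async API, as elsewhere in this PR): batches can be consumed incrementally with `async for`, or buffered all at once with `read_all`:

async def consume(table):
    reader = await table.query().to_batches()
    # The schema is available before any batch is pulled from the stream
    print(reader.schema)
    async for batch in reader:
        print(batch.num_rows)
    # Alternatively (on a fresh reader), buffer everything in memory:
    # batches = await reader.read_all()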
@@ -24,6 +24,7 @@ import pyarrow as pa
import pydantic

from . import __version__
from .arrow import AsyncRecordBatchReader
from .common import VEC
from .rerankers.base import Reranker
from .rerankers.linear_combination import LinearCombinationReranker
@@ -33,6 +34,8 @@ if TYPE_CHECKING:
    import PIL
    import polars as pl

    from ._lancedb import Query as LanceQuery
    from ._lancedb import VectorQuery as LanceVectorQuery
    from .pydantic import LanceModel
    from .table import Table
@@ -921,3 +924,334 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
        """
        self._vector_query.refine_factor(refine_factor)
        return self


class AsyncQueryBase(object):
    def __init__(self, inner: Union[LanceQuery, LanceVectorQuery]):
        """
        Construct an AsyncQueryBase

        This method is not intended to be called directly. Instead, use the
        [Table.query][] method to create a query.
        """
        self._inner = inner

    def where(self, predicate: str) -> AsyncQuery:
        """
        Only return rows matching the given predicate

        The predicate should be supplied as an SQL query string. For example:

        >>> predicate = "x > 10"
        >>> predicate = "y > 0 AND y < 100"
        >>> predicate = "x > 5 OR y = 'test'"

        Filtering performance can often be improved by creating a scalar index
        on the filter column(s).
        """
        self._inner.where(predicate)
        return self

    def select(self, columns: Union[List[str], dict[str, str]]) -> AsyncQuery:
        """
        Return only the specified columns.

        By default a query will return all columns from the table. However, this can
        have a very significant impact on latency. LanceDb stores data in a columnar
        fashion. This means we can finely tune our I/O to select exactly the columns
        we need.

        As a best practice you should always limit queries to the columns that you
        need. If you pass in a list of column names then only those columns will be
        returned.

        You can also use this method to create new "dynamic" columns based on your
        existing columns. For example, you may not care about "a" or "b" but instead
        simply want "a + b". This is often seen in the SELECT clause of an SQL query
        (e.g. `SELECT a+b FROM my_table`).

        To create dynamic columns you can pass in a dict[str, str]. A column will be
        returned for each entry in the map. The key provides the name of the column.
        The value is an SQL string used to specify how the column is calculated.

        For example, an SQL query might state `SELECT a + b AS combined, c`. The
        equivalent input to this method would be `{"combined": "a + b", "c": "c"}`.

        Columns will always be returned in the order given, even if that order is
        different than the order used when adding the data.
        """
        if isinstance(columns, dict):
            column_tuples = list(columns.items())
        else:
            try:
                column_tuples = [(c, c) for c in columns]
            except TypeError:
                raise TypeError("columns must be a list of column names or a dict")
        self._inner.select(column_tuples)
        return self

    def limit(self, limit: int) -> AsyncQuery:
        """
        Set the maximum number of results to return.

        By default, a plain search has no limit. If this method is not
        called then every valid row from the table will be returned.
        """
        self._inner.limit(limit)
        return self

    async def to_batches(self) -> AsyncRecordBatchReader:
        """
        Execute the query and return the results as an Apache Arrow RecordBatchReader.
        """
        return AsyncRecordBatchReader(await self._inner.execute())

    async def to_arrow(self) -> pa.Table:
        """
        Execute the query and collect the results into an Apache Arrow Table.

        This method will collect all results into memory before returning. If
        you expect a large number of results, you may want to use [to_batches][].
        """
        batch_iter = await self.to_batches()
        return pa.Table.from_batches(
            await batch_iter.read_all(), schema=batch_iter.schema
        )

    async def to_pandas(self) -> "pd.DataFrame":
        """
        Execute the query and collect the results into a pandas DataFrame.

        This method will collect all results into memory before returning. If
        you expect a large number of results, you may want to use [to_batches][]
        and convert each batch to pandas separately.

        Example
        -------

        >>> import asyncio
        >>> from lancedb import connect_async
        >>> async def doctest_example():
        ...     conn = await connect_async("./.lancedb")
        ...     table = await conn.create_table("my_table", data=[{"a": 1, "b": 2}])
        ...     async for batch in await table.query().to_batches():
        ...         batch_df = batch.to_pandas()
        >>> asyncio.run(doctest_example())
        """
        return (await self.to_arrow()).to_pandas()


class AsyncQuery(AsyncQueryBase):
    def __init__(self, inner: LanceQuery):
        """
        Construct an AsyncQuery

        This method is not intended to be called directly. Instead, use the
        [Table.query][] method to create a query.
        """
        super().__init__(inner)
        self._inner = inner

    @classmethod
    def _query_vec_to_array(cls, vec: Union[VEC, Tuple]):
        if isinstance(vec, list):
            return pa.array(vec)
        if isinstance(vec, np.ndarray):
            return pa.array(vec)
        if isinstance(vec, pa.Array):
            return vec
        if isinstance(vec, pa.ChunkedArray):
            return vec.combine_chunks()
        if isinstance(vec, tuple):
            return pa.array(vec)
        # We've checked everything we formally support in our typings
        # but, as a fallback, let pyarrow try and convert it anyway.
        # This can allow for some more exotic things like iterables
        return pa.array(vec)

    def nearest_to(
        self, query_vector: Optional[Union[VEC, Tuple]] = None
    ) -> AsyncVectorQuery:
        """
        Find the nearest vectors to the given query vector.

        This converts the query from a plain query to a vector query.

        This method will attempt to convert the input to the query vector
        expected by the embedding model. If the input cannot be converted
        then an error will be thrown.

        By default, there is no embedding model, and the input should be
        something that can be converted to a pyarrow array of floats. This
        includes lists, numpy arrays, and tuples.

        If there is only one vector column (a column whose data type is a
        fixed size list of floats) then the column does not need to be specified.
        If there is more than one vector column you must use
        [AsyncVectorQuery.column][] to specify which column you would like to
        compare with.

        If no index has been created on the vector column then a vector query
        will perform a distance comparison between the query vector and every
        vector in the database and then sort the results. This is sometimes
        called a "flat search".

        For small databases, with tens of thousands of vectors or less, this can
        be reasonably fast. In larger databases you should create a vector index
        on the column. If there is a vector index then an "approximate" nearest
        neighbor search (frequently called an ANN search) will be performed. This
        search is much faster, but the results will be approximate.

        The query can be further parameterized using the returned builder. There
        are various ANN search parameters that will let you fine tune your recall
        accuracy vs search latency.

        Vector searches always have a [limit][]. If `limit` has not been called then
        a default `limit` of 10 will be used.
        """
        return AsyncVectorQuery(
            self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
        )


class AsyncVectorQuery(AsyncQueryBase):
    def __init__(self, inner: LanceVectorQuery):
        """
        Construct an AsyncVectorQuery

        This method is not intended to be called directly. Instead, create
        a query first with [Table.query][] and then use [AsyncQuery.nearest_to][]
        to convert to a vector query.
        """
        super().__init__(inner)
        self._inner = inner

    def column(self, column: str) -> AsyncVectorQuery:
        """
        Set the vector column to query

        This controls which column is compared to the query vector supplied in
        the call to [Query.nearest_to][].

        This parameter must be specified if the table has more than one column
        whose data type is a fixed-size-list of floats.
        """
        self._inner.column(column)
        return self

    def nprobes(self, nprobes: int) -> AsyncVectorQuery:
        """
        Set the number of partitions to search (probe)

        This argument is only used when the vector column has an IVF PQ index.
        If there is no index then this value is ignored.

        The IVF stage of IVF PQ divides the input into partitions (clusters) of
        related values.

        The partitions whose centroids are closest to the query vector will be
        exhaustively searched to find matches. This parameter controls how many
        partitions should be searched.

        Increasing this value will increase the recall of your query but will
        also increase the latency of your query. The default value is 20. This
        default is good for many cases but the best value to use will depend on
        your data and the recall that you need to achieve.

        For best results we recommend tuning this parameter with a benchmark against
        your actual data to find the smallest possible value that will still give
        you the desired recall.
        """
        self._inner.nprobes(nprobes)
        return self

    def refine_factor(self, refine_factor: int) -> AsyncVectorQuery:
        """
        A multiplier to control how many additional rows are taken during the refine
        step

        This argument is only used when the vector column has an IVF PQ index.
        If there is no index then this value is ignored.

        An IVF PQ index stores compressed (quantized) values. The query vector is
        compared against these values and, since they are compressed, the comparison
        is inaccurate.

        This parameter can be used to refine the results. It can both improve
        recall and correct the ordering of the nearest results.

        To refine results LanceDb will first perform an ANN search to find the nearest
        `limit` * `refine_factor` results. In other words, if `refine_factor` is 3 and
        `limit` is the default (10) then the first 30 results will be selected. LanceDb
        then fetches the full, uncompressed, values for these 30 results. The results
        are then reordered by the true distance and only the nearest 10 are kept.

        Note: there is a difference between calling this method with a value of 1 and
        never calling this method at all. Calling this method with any value will have
        an impact on your search latency. When you call this method with a
        `refine_factor` of 1 then LanceDb still needs to fetch the full, uncompressed,
        values so that it can potentially reorder the results.

        Note: if this method is NOT called then the distances returned in the _distance
        column will be approximate distances based on the comparison of the quantized
        query vector and the quantized result vectors. This can be considerably
        different than the true distance between the query vector and the actual
        uncompressed vector.
        """
        self._inner.refine_factor(refine_factor)
        return self

    def distance_type(self, distance_type: str) -> AsyncVectorQuery:
        """
        Set the distance metric to use

        When performing a vector search we try and find the "nearest" vectors according
        to some kind of distance metric. This parameter controls which distance metric
        to use. The supported values are "l2", "cosine", and "dot".

        Note: if there is a vector index then the distance type used MUST match the
        distance type used to train the vector index. If this is not done then the
        results will be invalid.

        By default "l2" is used.
        """
        self._inner.distance_type(distance_type)
        return self

    def postfilter(self) -> AsyncVectorQuery:
        """
        If this is called then filtering will happen after the vector search instead of
        before.

        By default filtering will be performed before the vector search. This is how
        filtering is typically understood to work. This prefilter step does add some
        additional latency. Creating a scalar index on the filter column(s) can
        often improve this latency. However, sometimes a filter is too complex or
        scalar indices cannot be applied to the column. In these cases postfiltering
        can be used instead of prefiltering to improve latency.

        Post filtering applies the filter to the results of the vector search. This
        means we only run the filter on a much smaller set of data. However, it can
        cause the query to return fewer than `limit` results (or even no results) if
        none of the nearest results match the filter.

        Post filtering happens during the "refine stage" (described in more detail in
        [refine_factor][]). This means that setting a higher refine factor can often
        help restore some of the results lost by post filtering.
        """
        self._inner.postfilter()
        return self

    def bypass_vector_index(self) -> AsyncVectorQuery:
        """
        If this is called then any vector index is skipped

        An exhaustive (flat) search will be performed. The query vector will
        be compared to every vector in the table. At high scales this can be
        expensive. However, this is often still useful. For example, skipping
        the vector index can give you ground truth results which you can use to
        calculate your recall to select an appropriate value for nprobes.
        """
        self._inner.bypass_vector_index()
        return self
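Putting the ANN knobs documented above together, a hedged sketch of a tuned vector search (the column name "vector" and the parameter values are illustrative, not defaults the PR mandates):

async def tuned_search(table):
    return await (
        table.query()
        .where("id > 0")
        .nearest_to([0.1, 0.3])   # list, tuple, np.ndarray, or pa.Array all work
        .column("vector")         # only required if there are multiple vector columns
        .distance_type("cosine")  # must match the distance type the index was trained with
        .nprobes(20)              # more partitions searched -> better recall, higher latency
        .refine_factor(3)         # re-rank the top 30 (3 * limit) by true distance
        .limit(10)
        .to_arrow()
    )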
@@ -43,7 +43,7 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME
 from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
 from .merge import LanceMergeInsertBuilder
 from .pydantic import LanceModel, model_to_dict
-from .query import LanceQueryBuilder, Query
+from .query import AsyncQuery, AsyncVectorQuery, LanceQueryBuilder, Query
 from .util import (
     fs_from_uri,
     inf_vector_column_query,
@@ -1899,6 +1899,9 @@ class AsyncTable:
        """
        return await self._inner.count_rows(filter)

    def query(self) -> AsyncQuery:
        return AsyncQuery(self._inner.query())

    async def to_pandas(self) -> "pd.DataFrame":
        """Return the table as a pandas DataFrame.

@@ -1906,7 +1909,7 @@ class AsyncTable:
         -------
         pd.DataFrame
         """
-        return self.to_arrow().to_pandas()
+        return (await self.to_arrow()).to_pandas()

     async def to_arrow(self) -> pa.Table:
         """Return the table as a pyarrow Table.
@@ -1915,7 +1918,7 @@ class AsyncTable:
         -------
         pa.Table
         """
-        raise NotImplementedError
+        return await self.query().to_arrow()

     async def create_index(
         self,
@@ -2068,90 +2071,18 @@ class AsyncTable:

         return LanceMergeInsertBuilder(self, on)

-    async def search(
+    def vector_search(
         self,
-        query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
-        vector_column_name: Optional[str] = None,
-        query_type: str = "auto",
-    ) -> LanceQueryBuilder:
-        """Create a search query to find the nearest neighbors
-        of the given query vector. We currently support [vector search][search]
-        and [full-text search][experimental-full-text-search].
-
-        All query options are defined in [Query][lancedb.query.Query].
-
-        Examples
-        --------
-        >>> import lancedb
-        >>> db = lancedb.connect("./.lancedb")
-        >>> data = [
-        ...     {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
-        ...     {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
-        ...     {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
-        ... ]
-        >>> table = db.create_table("my_table", data)
-        >>> query = [0.4, 1.4, 2.4]
-        >>> (table.search(query)
-        ...     .where("original_width > 1000", prefilter=True)
-        ...     .select(["caption", "original_width", "vector"])
-        ...     .limit(2)
-        ...     .to_pandas())
-          caption  original_width           vector  _distance
-        0     foo            2000  [0.5, 3.4, 1.3]   5.220000
-        1    test            3000  [0.3, 6.2, 2.6]  23.089996
-
-        Parameters
-        ----------
-        query: list/np.ndarray/str/PIL.Image.Image, default None
-            The targetted vector to search for.
-
-            - *default None*.
-            Acceptable types are: list, np.ndarray, PIL.Image.Image
-
-            - If None then the select/where/limit clauses are applied to filter
-            the table
-        vector_column_name: str, optional
-            The name of the vector column to search.
-
-            The vector column needs to be a pyarrow fixed size list type
-
-            - If not specified then the vector column is inferred from
-            the table schema
-
-            - If the table has multiple vector columns then the *vector_column_name*
-            needs to be specified. Otherwise, an error is raised.
-        query_type: str
-            *default "auto"*.
-            Acceptable types are: "vector", "fts", "hybrid", or "auto"
-
-            - If "auto" then the query type is inferred from the query;
-
-                - If `query` is a list/np.ndarray then the query type is
-                "vector";
-
-                - If `query` is a PIL.Image.Image then either do vector search,
-                or raise an error if no corresponding embedding function is found.
-
-            - If `query` is a string, then the query type is "vector" if the
-            table has embedding functions else the query type is "fts"
-
-        Returns
-        -------
-        LanceQueryBuilder
-            A query builder object representing the query.
-            Once executed, the query returns
-
-            - selected columns
-
-            - the vector
-
-            - and also the "_distance" column which is the distance between the query
-            vector and the returned vector.
-        """
-        raise NotImplementedError
-
-    async def _execute_query(self, query: Query) -> pa.Table:
-        pass
+        query_vector: Optional[Union[VEC, Tuple]] = None,
+    ) -> AsyncVectorQuery:
+        """
+        Search the table with a given query vector.
+
+        This is a convenience method for preparing a vector query and
+        is the same thing as calling `nearest_to` on the builder returned
+        by `query`. See [nearest_to][AsyncQuery.nearest_to] for more details.
+        """
+        return self.query().nearest_to(query_vector)

     async def _do_merge(
         self,
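The convenience method in action; the two forms below are equivalent (vector values are illustrative):

# table.vector_search(v) is shorthand for table.query().nearest_to(v)
hits_a = await table_async.vector_search([1, 2]).limit(1).to_arrow()
hits_b = await table_async.query().nearest_to([1, 2]).limit(1).to_arrow()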
@@ -12,16 +12,19 @@
 # limitations under the License.

 import unittest.mock as mock
+from datetime import timedelta

 import lance
+import lancedb
 import numpy as np
 import pandas.testing as tm
 import pyarrow as pa
 import pytest
+import pytest_asyncio
 from lancedb.db import LanceDBConnection
 from lancedb.pydantic import LanceModel, Vector
-from lancedb.query import LanceVectorQueryBuilder, Query
-from lancedb.table import LanceTable
+from lancedb.query import AsyncQueryBase, LanceVectorQueryBuilder, Query
+from lancedb.table import AsyncTable, LanceTable


 class MockTable:
@@ -65,6 +68,24 @@ def table(tmp_path) -> MockTable:
    return MockTable(tmp_path)


@pytest_asyncio.fixture
async def table_async(tmp_path) -> AsyncTable:
    conn = await lancedb.connect_async(
        tmp_path, read_consistency_interval=timedelta(seconds=0)
    )
    data = pa.table(
        {
            "vector": pa.array(
                [[1, 2], [3, 4]], type=pa.list_(pa.float32(), list_size=2)
            ),
            "id": pa.array([1, 2]),
            "str_field": pa.array(["a", "b"]),
            "float_field": pa.array([1.0, 2.0]),
        }
    )
    return await conn.create_table("test", data)


def test_cast(table):
    class TestModel(LanceModel):
        vector: Vector(2)
@@ -184,3 +205,109 @@ def test_query_builder_with_different_vector_column():

def cosine_distance(vec1, vec2):
    return 1 - np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))


async def check_query(
    query: AsyncQueryBase, *, expected_num_rows=None, expected_columns=None
):
    num_rows = 0
    results = await query.to_batches()
    async for batch in results:
        if expected_columns is not None:
            assert batch.schema.names == expected_columns
        num_rows += batch.num_rows
    if expected_num_rows is not None:
        assert num_rows == expected_num_rows


@pytest.mark.asyncio
async def test_query_async(table_async: AsyncTable):
    await check_query(
        table_async.query(),
        expected_num_rows=2,
        expected_columns=["vector", "id", "str_field", "float_field"],
    )
    await check_query(table_async.query().where("id = 2"), expected_num_rows=1)
    await check_query(
        table_async.query().select(["id", "vector"]), expected_columns=["id", "vector"]
    )
    await check_query(
        table_async.query().select({"foo": "id", "bar": "id + 1"}),
        expected_columns=["foo", "bar"],
    )
    await check_query(table_async.query().limit(1), expected_num_rows=1)
    await check_query(
        table_async.query().nearest_to(pa.array([1, 2])), expected_num_rows=2
    )
    # Support different types of inputs for the vector query
    for vector_query in [
        [1, 2],
        [1.0, 2.0],
        np.array([1, 2]),
        (1, 2),
    ]:
        await check_query(
            table_async.query().nearest_to(vector_query), expected_num_rows=2
        )

    # No easy way to check these vector query parameters are doing what they say. We
    # just check that they don't raise exceptions and assume this is tested at a lower
    # level.
    await check_query(
        table_async.query().where("id = 2").nearest_to(pa.array([1, 2])).postfilter(),
        expected_num_rows=1,
    )
    await check_query(
        table_async.query().nearest_to(pa.array([1, 2])).refine_factor(1),
        expected_num_rows=2,
    )
    await check_query(
        table_async.query().nearest_to(pa.array([1, 2])).nprobes(10),
        expected_num_rows=2,
    )
    await check_query(
        table_async.query().nearest_to(pa.array([1, 2])).bypass_vector_index(),
        expected_num_rows=2,
    )
    await check_query(
        table_async.query().nearest_to(pa.array([1, 2])).distance_type("dot"),
        expected_num_rows=2,
    )
    await check_query(
        table_async.query().nearest_to(pa.array([1, 2])).distance_type("DoT"),
        expected_num_rows=2,
    )

    # Make sure we can use a vector query as a base query (e.g. call limit on it)
    # Also make sure `vector_search` works
    await check_query(table_async.vector_search([1, 2]).limit(1), expected_num_rows=1)

    # Also check an empty query
    await check_query(table_async.query().where("id < 0"), expected_num_rows=0)


@pytest.mark.asyncio
async def test_query_to_arrow_async(table_async: AsyncTable):
    table = await table_async.to_arrow()
    assert table.num_rows == 2
    assert table.num_columns == 4

    table = await table_async.query().to_arrow()
    assert table.num_rows == 2
    assert table.num_columns == 4

    table = await table_async.query().where("id < 0").to_arrow()
    assert table.num_rows == 0
    assert table.num_columns == 4


@pytest.mark.asyncio
async def test_query_to_pandas_async(table_async: AsyncTable):
    df = await table_async.to_pandas()
    assert df.shape == (2, 4)

    df = await table_async.query().to_pandas()
    assert df.shape == (2, 4)

    df = await table_async.query().where("id < 0").to_pandas()
    assert df.shape == (0, 4)
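As the bypass_vector_index docstring suggests, the flat search can serve as ground truth for measuring ANN recall. A hedged sketch of that measurement (the helper name and the "id" column are assumptions for illustration):

async def estimate_recall(table, query_vec, k=10, nprobes=20):
    # Ground truth: exhaustive (flat) search with the index bypassed
    exact = await (
        table.query().nearest_to(query_vec).bypass_vector_index().limit(k).to_arrow()
    )
    # Candidate: ANN search through the index
    approx = await (
        table.query().nearest_to(query_vec).nprobes(nprobes).limit(k).to_arrow()
    )
    # Assumes rows carry a unique "id" column to compare result sets
    exact_ids = set(exact["id"].to_pylist())
    approx_ids = set(approx["id"].to_pylist())
    return len(exact_ids & approx_ids) / k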
python/src/arrow.rs (new file, 51 lines)
|
||||
// use arrow::datatypes::SchemaRef;
|
||||
// use lancedb::arrow::SendableRecordBatchStream;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow::{
|
||||
datatypes::SchemaRef,
|
||||
pyarrow::{IntoPyArrow, ToPyArrow},
|
||||
};
|
||||
use futures::stream::StreamExt;
|
||||
use lancedb::arrow::SendableRecordBatchStream;
|
||||
use pyo3::{pyclass, pymethods, PyAny, PyObject, PyRef, PyResult, Python};
|
||||
use pyo3_asyncio::tokio::future_into_py;
|
||||
|
||||
use crate::error::PythonErrorExt;
|
||||
|
||||
#[pyclass]
|
||||
pub struct RecordBatchStream {
|
||||
schema: SchemaRef,
|
||||
inner: Arc<tokio::sync::Mutex<SendableRecordBatchStream>>,
|
||||
}
|
||||
|
||||
impl RecordBatchStream {
|
||||
pub fn new(inner: SendableRecordBatchStream) -> Self {
|
||||
let schema = inner.schema().clone();
|
||||
Self {
|
||||
schema,
|
||||
inner: Arc::new(tokio::sync::Mutex::new(inner)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl RecordBatchStream {
|
||||
pub fn schema(&self, py: Python) -> PyResult<PyObject> {
|
||||
(*self.schema).clone().into_pyarrow(py)
|
||||
}
|
||||
|
||||
pub fn next(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let inner_next = inner.lock().await.next().await;
|
||||
inner_next
|
||||
.map(|item| {
|
||||
let item = item.infer_error()?;
|
||||
Python::with_gil(|py| item.to_pyarrow(py))
|
||||
})
|
||||
.transpose()
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -12,15 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use arrow::RecordBatchStream;
use connection::{connect, Connection};
use env_logger::Env;
use index::{Index, IndexConfig};
use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
use query::{Query, VectorQuery};
use table::Table;

pub mod arrow;
pub mod connection;
pub mod error;
pub mod index;
pub mod query;
pub mod table;
pub mod util;

@@ -34,6 +38,9 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<Table>()?;
    m.add_class::<Index>()?;
    m.add_class::<IndexConfig>()?;
    m.add_class::<Query>()?;
    m.add_class::<VectorQuery>()?;
    m.add_class::<RecordBatchStream>()?;
    m.add_function(wrap_pyfunction!(connect, m)?)?;
    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
    Ok(())
python/src/query.rs (new file, 125 lines)
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use arrow::array::make_array;
|
||||
use arrow::array::ArrayData;
|
||||
use arrow::pyarrow::FromPyArrow;
|
||||
use lancedb::query::{
|
||||
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
|
||||
};
|
||||
use pyo3::pyclass;
|
||||
use pyo3::pymethods;
|
||||
use pyo3::PyAny;
|
||||
use pyo3::PyRef;
|
||||
use pyo3::PyResult;
|
||||
use pyo3_asyncio::tokio::future_into_py;
|
||||
|
||||
use crate::arrow::RecordBatchStream;
|
||||
use crate::error::PythonErrorExt;
|
||||
use crate::util::parse_distance_type;
|
||||
|
||||
#[pyclass]
|
||||
pub struct Query {
|
||||
inner: LanceDbQuery,
|
||||
}
|
||||
|
||||
impl Query {
|
||||
pub fn new(query: LanceDbQuery) -> Self {
|
||||
Self { inner: query }
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl Query {
|
||||
pub fn r#where(&mut self, predicate: String) {
|
||||
self.inner = self.inner.clone().only_if(predicate);
|
||||
}
|
||||
|
||||
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
||||
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
||||
}
|
||||
|
||||
pub fn limit(&mut self, limit: u32) {
|
||||
self.inner = self.inner.clone().limit(limit as usize);
|
||||
}
|
||||
|
||||
pub fn nearest_to(&mut self, vector: &PyAny) -> PyResult<VectorQuery> {
|
||||
let data: ArrayData = ArrayData::from_pyarrow(vector)?;
|
||||
let array = make_array(data);
|
||||
let inner = self.inner.clone().nearest_to(array).infer_error()?;
|
||||
Ok(VectorQuery { inner })
|
||||
}
|
||||
|
||||
pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let inner_stream = inner.execute().await.infer_error()?;
|
||||
Ok(RecordBatchStream::new(inner_stream))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
pub struct VectorQuery {
|
||||
inner: LanceDbVectorQuery,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl VectorQuery {
|
||||
pub fn r#where(&mut self, predicate: String) {
|
||||
self.inner = self.inner.clone().only_if(predicate);
|
||||
}
|
||||
|
||||
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
||||
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
||||
}
|
||||
|
||||
pub fn limit(&mut self, limit: u32) {
|
||||
self.inner = self.inner.clone().limit(limit as usize);
|
||||
}
|
||||
|
||||
pub fn column(&mut self, column: String) {
|
||||
self.inner = self.inner.clone().column(&column);
|
||||
}
|
||||
|
||||
pub fn distance_type(&mut self, distance_type: String) -> PyResult<()> {
|
||||
let distance_type = parse_distance_type(distance_type)?;
|
||||
self.inner = self.inner.clone().distance_type(distance_type);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn postfilter(&mut self) {
|
||||
self.inner = self.inner.clone().postfilter();
|
||||
}
|
||||
|
||||
pub fn refine_factor(&mut self, refine_factor: u32) {
|
||||
self.inner = self.inner.clone().refine_factor(refine_factor);
|
||||
}
|
||||
|
||||
pub fn nprobes(&mut self, nprobe: u32) {
|
||||
self.inner = self.inner.clone().nprobes(nprobe as usize);
|
||||
}
|
||||
|
||||
pub fn bypass_vector_index(&mut self) {
|
||||
self.inner = self.inner.clone().bypass_vector_index()
|
||||
}
|
||||
|
||||
pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let inner_stream = inner.execute().await.infer_error()?;
|
||||
Ok(RecordBatchStream::new(inner_stream))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,7 @@ use pyo3_asyncio::tokio::future_into_py;
use crate::{
    error::PythonErrorExt,
    index::{Index, IndexConfig},
    query::Query,
};

#[pyclass]
@@ -179,4 +180,8 @@ impl Table {
            async move { inner.restore().await.infer_error() },
        )
    }

    pub fn query(&self) -> Query {
        Query::new(self.inner_ref().unwrap().query())
    }
}
@@ -1,6 +1,10 @@
 use std::sync::Mutex;

-use pyo3::{exceptions::PyRuntimeError, PyResult};
+use lancedb::DistanceType;
+use pyo3::{
+    exceptions::{PyRuntimeError, PyValueError},
+    PyResult,
+};

 /// A wrapper around a rust builder
 ///
@@ -33,3 +37,15 @@ impl<T> BuilderWrapper<T> {
        Ok(result)
    }
}

pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceType> {
    match distance_type.as_ref().to_lowercase().as_str() {
        "l2" => Ok(DistanceType::L2),
        "cosine" => Ok(DistanceType::Cosine),
        "dot" => Ok(DistanceType::Dot),
        _ => Err(PyValueError::new_err(format!(
            "Invalid distance type '{}'. Must be one of l2, cosine, or dot",
            distance_type.as_ref()
        ))),
    }
}
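Since parse_distance_type lowercases its input, the Python-facing distance_type argument is case-insensitive, which is what the "DoT" test above exercises; only "l2", "cosine", and "dot" are accepted:

q = table_async.query().nearest_to([1, 2])
q.distance_type("dot")   # accepted
q.distance_type("DoT")   # also accepted, normalized to "dot"
# q.distance_type("hamming")  # would raise ValueError: not one of l2, cosine, dot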