feat: refactor the query API and add query support to the python async API (#1113)

In addition, there are a number of changes to the docstrings of existing
nodejs methods, because this PR adds a jsdoc linter.
Weston Pace
2024-03-18 12:36:49 -07:00
parent 2db257ca29
commit 4180b44472
38 changed files with 2609 additions and 754 deletions


@@ -22,6 +22,9 @@ pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }
# Prevent dynamic linking of lzma, which comes from datafusion
lzma-sys = { version = "*", features = ["static"] }
pin-project = "1.1.5"
futures.workspace = true
tokio = { version = "1.36.0", features = ["sync"] }
[build-dependencies]
pyo3-build-config = { version = "0.20.3", features = [


@@ -1,4 +1,4 @@
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
import pyarrow as pa
@@ -40,6 +40,8 @@ class Table:
async def checkout_latest(self): ...
async def restore(self): ...
async def list_indices(self) -> List[IndexConfig]: ...
def query(self) -> Query: ...
def vector_search(self) -> VectorQuery: ...
class IndexConfig:
index_type: str
@@ -52,3 +54,27 @@ async def connect(
host_override: Optional[str],
read_consistency_interval: Optional[float],
) -> Connection: ...
class RecordBatchStream:
def schema(self) -> pa.Schema: ...
async def next(self) -> Optional[pa.RecordBatch]: ...
class Query:
def where(self, filter: str): ...
def select(self, columns: List[Tuple[str, str]]): ...
def limit(self, limit: int): ...
def nearest_to(self, query_vec: pa.Array) -> VectorQuery: ...
async def execute(self) -> RecordBatchStream: ...
class VectorQuery:
async def execute(self) -> RecordBatchStream: ...
def where(self, filter: str): ...
def select(self, columns: List[str]): ...
def select_with_projection(self, columns: List[Tuple[str, str]]): ...
def limit(self, limit: int): ...
def column(self, column: str): ...
def distance_type(self, distance_type: str): ...
def postfilter(self): ...
def refine_factor(self, refine_factor: int): ...
def nprobes(self, nprobes: int): ...
def bypass_vector_index(self): ...


@@ -0,0 +1,44 @@
from typing import List
import pyarrow as pa
from ._lancedb import RecordBatchStream
class AsyncRecordBatchReader:
"""
An async iterator over a stream of RecordBatches.
Also allows access to the schema of the stream
"""
def __init__(self, inner: RecordBatchStream):
self.inner_ = inner
@property
def schema(self) -> pa.Schema:
"""
Get the schema of the batches produced by the stream
Accessing the schema does not consume any data from the stream
"""
return self.inner_.schema()
async def read_all(self) -> List[pa.RecordBatch]:
"""
Read all the record batches from the stream
This consumes the entire stream and returns a list of record batches
If there are a lot of results this may consume a lot of memory
"""
return [batch async for batch in self]
def __aiter__(self):
return self
async def __anext__(self) -> pa.RecordBatch:
batch = await self.inner_.next()
if batch is None:
raise StopAsyncIteration
return batch
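# Illustrative usage sketch (not part of the module above); the "demo" table
# and its data are hypothetical.
import asyncio
import lancedb

async def _stream_example():
    conn = await lancedb.connect_async("./.lancedb")
    table = await conn.create_table("demo", data=[{"a": 1}, {"a": 2}])
    reader = await table.query().to_batches()  # AsyncRecordBatchReader
    print(reader.schema)  # schema is available without consuming the stream
    async for batch in reader:
        print(batch.num_rows)

asyncio.run(_stream_example())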


@@ -24,6 +24,7 @@ import pyarrow as pa
import pydantic
from . import __version__
from .arrow import AsyncRecordBatchReader
from .common import VEC
from .rerankers.base import Reranker
from .rerankers.linear_combination import LinearCombinationReranker
@@ -33,6 +34,8 @@ if TYPE_CHECKING:
import PIL
import polars as pl
from ._lancedb import Query as LanceQuery
from ._lancedb import VectorQuery as LanceVectorQuery
from .pydantic import LanceModel
from .table import Table
@@ -921,3 +924,334 @@ class LanceHybridQueryBuilder(LanceQueryBuilder):
"""
self._vector_query.refine_factor(refine_factor)
return self
class AsyncQueryBase(object):
def __init__(self, inner: Union[LanceQuery, LanceVectorQuery]):
"""
Construct an AsyncQueryBase
This method is not intended to be called directly. Instead, use the
[Table.query][] method to create a query.
"""
self._inner = inner
def where(self, predicate: str) -> AsyncQuery:
"""
Only return rows matching the given predicate
The predicate should be supplied as an SQL query string. For example:
>>> predicate = "x > 10"
>>> predicate = "y > 0 AND y < 100"
>>> predicate = "x > 5 OR y = 'test'"
Filtering performance can often be improved by creating a scalar index
on the filter column(s).
"""
self._inner.where(predicate)
return self
def select(self, columns: Union[List[str], dict[str, str]]) -> AsyncQuery:
"""
Return only the specified columns.
By default a query will return all columns from the table. However, this can
have a very significant impact on latency. LanceDB stores data in a columnar
fashion, which means we can finely tune our I/O to select exactly the columns
we need.
As a best practice you should always limit queries to the columns that you need.
If you pass in a list of column names then only those columns will be
returned.
You can also use this method to create new "dynamic" columns based on your
existing columns. For example, you may not care about "a" or "b" but instead
simply want "a + b". This is often seen in the SELECT clause of an SQL query
(e.g. `SELECT a+b FROM my_table`).
To create dynamic columns you can pass in a dict[str, str]. A column will be
returned for each entry in the map. The key provides the name of the column.
The value is an SQL string used to specify how the column is calculated.
For example, an SQL query might state `SELECT a + b AS combined, c`. The
equivalent input to this method would be `{"combined": "a + b", "c": "c"}`.
Columns will always be returned in the order given, even if that order is
different than the order used when adding the data.
"""
if isinstance(columns, dict):
column_tuples = list(columns.items())
else:
try:
column_tuples = [(c, c) for c in columns]
except TypeError:
raise TypeError("columns must be a list of column names or a dict")
self._inner.select(column_tuples)
return self
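# Hedged sketch of the two select() forms described above; the `table` handle
# and the "caption", "a", "b", "c" columns are hypothetical.
async def _select_examples(table):
    by_name = await table.query().select(["caption", "vector"]).to_arrow()
    dynamic = await table.query().select({"combined": "a + b", "c": "c"}).to_arrow()
    return by_name, dynamic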
def limit(self, limit: int) -> AsyncQuery:
"""
Set the maximum number of results to return.
By default, a plain search has no limit. If this method is not
called then every valid row from the table will be returned.
"""
self._inner.limit(limit)
return self
async def to_batches(self) -> AsyncRecordBatchReader:
"""
Execute the query and return the results as an Apache Arrow RecordBatchReader.
"""
return AsyncRecordBatchReader(await self._inner.execute())
async def to_arrow(self) -> pa.Table:
"""
Execute the query and collect the results into an Apache Arrow Table.
This method will collect all results into memory before returning. If
you expect a large number of results, you may want to use [to_batches][].
"""
batch_iter = await self.to_batches()
return pa.Table.from_batches(
await batch_iter.read_all(), schema=batch_iter.schema
)
async def to_pandas(self) -> "pd.DataFrame":
"""
Execute the query and collect the results into a pandas DataFrame.
This method will collect all results into memory before returning. If
you expect a large number of results, you may want to use [to_batches][]
and convert each batch to pandas separately.
Example
-------
>>> import asyncio
>>> from lancedb import connect_async
>>> async def doctest_example():
... conn = await connect_async("./.lancedb")
... table = await conn.create_table("my_table", data=[{"a": 1, "b": 2}])
... async for batch in await table.query().to_batches():
... batch_df = batch.to_pandas()
>>> asyncio.run(doctest_example())
"""
return (await self.to_arrow()).to_pandas()
class AsyncQuery(AsyncQueryBase):
def __init__(self, inner: LanceQuery):
"""
Construct an AsyncQuery
This method is not intended to be called directly. Instead, use the
[Table.query][] method to create a query.
"""
super().__init__(inner)
self._inner = inner
@classmethod
def _query_vec_to_array(cls, vec: Union[VEC, Tuple]):
if isinstance(vec, list):
return pa.array(vec)
if isinstance(vec, np.ndarray):
return pa.array(vec)
if isinstance(vec, pa.Array):
return vec
if isinstance(vec, pa.ChunkedArray):
return vec.combine_chunks()
if isinstance(vec, tuple):
return pa.array(vec)
# We've checked everything we formally support in our typings
# but, as a fallback, let pyarrow try and convert it anyway.
# This can allow for some more exotic things like iterables
return pa.array(vec)
def nearest_to(
self, query_vector: Optional[Union[VEC, Tuple]] = None
) -> AsyncVectorQuery:
"""
Find the nearest vectors to the given query vector.
This converts the query from a plain query to a vector query.
This method will attempt to convert the input to the query vector
expected by the embedding model. If the input cannot be converted
then an error will be thrown.
By default, there is no embedding model, and the input should be
something that can be converted to a pyarrow array of floats. This
includes lists, numpy arrays, and tuples.
If there is only one vector column (a column whose data type is a
fixed size list of floats) then the column does not need to be specified.
If there is more than one vector column you must use
[AsyncVectorQuery.column][] to specify which column you would like to
compare with.
If no index has been created on the vector column then a vector query
will perform a distance comparison between the query vector and every
vector in the database and then sort the results. This is sometimes
called a "flat search"
For small databases, with tens of thousands of vectors or less, this can
be reasonably fast. In larger databases you should create a vector index
on the column. If there is a vector index then an "approximate" nearest
neighbor search (frequently called an ANN search) will be performed. This
search is much faster, but the results will be approximate.
The query can be further parameterized using the returned builder. There
are various ANN search parameters that will let you fine tune your recall
accuracy vs search latency.
Vector searches always have a [limit][]. If `limit` has not been called then
a default `limit` of 10 will be used.
"""
return AsyncVectorQuery(
self._inner.nearest_to(AsyncQuery._query_vec_to_array(query_vector))
)
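# Minimal sketch of a vector search built through nearest_to(); assumes a
# hypothetical `table` with a single 3-dimensional vector column.
async def _nearest_to_example(table):
    return await (
        table.query()
        .nearest_to([0.4, 1.4, 2.4])  # converts the plain query to a vector query
        .limit(5)  # without this, vector searches default to a limit of 10
        .to_arrow()
    )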
class AsyncVectorQuery(AsyncQueryBase):
def __init__(self, inner: LanceVectorQuery):
"""
Construct an AsyncVectorQuery
This method is not intended to be called directly. Instead, create
a query first with [Table.query][] and then use [AsyncQuery.nearest_to][]
to convert to a vector query.
"""
super().__init__(inner)
self._inner = inner
def column(self, column: str) -> AsyncVectorQuery:
"""
Set the vector column to query
This controls which column is compared to the query vector supplied in
the call to [Query.nearest_to][].
This parameter must be specified if the table has more than one column
whose data type is a fixed-size-list of floats.
"""
self._inner.column(column)
return self
def nprobes(self, nprobes: int) -> AsyncVectorQuery:
"""
Set the number of partitions to search (probe)
This argument is only used when the vector column has an IVF PQ index.
If there is no index then this value is ignored.
The IVF stage of IVF PQ divides the input into partitions (clusters) of
related values.
The partitions whose centroids are closest to the query vector will be
exhaustively searched to find matches. This parameter controls how many
partitions should be searched.
Increasing this value will increase the recall of your query but will
also increase the latency of your query. The default value is 20. This
default is good for many cases but the best value to use will depend on
your data and the recall that you need to achieve.
For best results we recommend tuning this parameter with a benchmark against
your actual data to find the smallest possible value that will still give
you the desired recall.
"""
self._inner.nprobes(nprobes)
return self
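# Hedged sketch of the benchmark-style tuning recommended above; `table`,
# `vec`, and the ground-truth comparison are left hypothetical.
async def _tune_nprobes(table, vec):
    for nprobes in (10, 20, 50):
        candidates = await table.query().nearest_to(vec).nprobes(nprobes).to_arrow()
        # ...compare `candidates` against exact results to measure recall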
def refine_factor(self, refine_factor: int) -> AsyncVectorQuery:
"""
A multiplier to control how many additional rows are taken during the refine
step
This argument is only used when the vector column has an IVF PQ index.
If there is no index then this value is ignored.
An IVF PQ index stores compressed (quantized) values. The query vector is
compared against these values and, since they are compressed, the comparison is
inaccurate.
This parameter can be used to refine the results. It can both improve
recall and correct the ordering of the nearest results.
To refine results LanceDB will first perform an ANN search to find the nearest
`limit` * `refine_factor` results. In other words, if `refine_factor` is 3 and
`limit` is the default (10) then the first 30 results will be selected. LanceDB
then fetches the full, uncompressed, values for these 30 results. The results
are then reordered by the true distance and only the nearest 10 are kept.
Note: there is a difference between calling this method with a value of 1 and
never calling this method at all. Calling this method with any value will have
an impact on your search latency. When you call this method with a
`refine_factor` of 1 then LanceDB still needs to fetch the full, uncompressed,
values so that it can potentially reorder the results.
Note: if this method is NOT called then the distances returned in the _distance
column will be approximate distances based on the comparison of the quantized
query vector and the quantized result vectors. This can be considerably
different than the true distance between the query vector and the actual
uncompressed vector.
"""
self._inner.refine_factor(refine_factor)
return self
def distance_type(self, distance_type: str) -> AsyncVectorQuery:
"""
Set the distance metric to use
When performing a vector search we try to find the "nearest" vectors according
to some kind of distance metric. This parameter controls which distance metric
to use. See `IvfPqOptions.distanceType` for more details on the
different distance metrics available.
Note: if there is a vector index then the distance type used MUST match the
distance type used to train the vector index. If this is not done then the
results will be invalid.
By default "l2" is used.
"""
self._inner.distance_type(distance_type)
return self
def postfilter(self) -> AsyncVectorQuery:
"""
If this is called then filtering will happen after the vector search instead of
before.
By default filtering will be performed before the vector search. This is how
filtering is typically understood to work. This prefilter step does add some
additional latency. Creating a scalar index on the filter column(s) can
often improve this latency. However, sometimes a filter is too complex or
scalar indices cannot be applied to the column. In these cases postfiltering
can be used instead of prefiltering to improve latency.
Post filtering applies the filter to the results of the vector search. This
means we only run the filter on a much smaller set of data. However, it can
cause the query to return fewer than `limit` results (or even no results) if
none of the nearest results match the filter.
Post filtering happens during the "refine stage" (described in more detail in
[refine_factor][]). This means that setting a higher refine
factor can often help restore some of the results lost by post filtering.
"""
self._inner.postfilter()
return self
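# Hedged sketch contrasting the default prefilter with postfilter(); the
# "category" column, `table`, and `vec` are hypothetical.
async def _filter_examples(table, vec):
    pre = await table.query().where("category = 'a'").nearest_to(vec).to_arrow()
    post = await (
        table.query().where("category = 'a'").nearest_to(vec).postfilter().to_arrow()
    )
    return pre, post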
def bypass_vector_index(self) -> AsyncVectorQuery:
"""
If this is called then any vector index is skipped
An exhaustive (flat) search will be performed. The query vector will
be compared to every vector in the table. At high scales this can be
expensive. However, this is often still useful. For example, skipping
the vector index can give you ground truth results which you can use to
calculate your recall to select an appropriate value for nprobes.
"""
self._inner.bypass_vector_index()
return self
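# Sketch of the recall estimate described above; assumes a hypothetical
# unique "id" column, a `table` handle, and a query vector `vec`.
async def _estimate_recall(table, vec, k=10):
    ann = await table.query().nearest_to(vec).limit(k).to_arrow()
    exact = await table.query().nearest_to(vec).limit(k).bypass_vector_index().to_arrow()
    return len(set(ann["id"].to_pylist()) & set(exact["id"].to_pylist())) / k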


@@ -43,7 +43,7 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from .merge import LanceMergeInsertBuilder
from .pydantic import LanceModel, model_to_dict
from .query import LanceQueryBuilder, Query
from .query import AsyncQuery, AsyncVectorQuery, LanceQueryBuilder, Query
from .util import (
fs_from_uri,
inf_vector_column_query,
@@ -1899,6 +1899,9 @@ class AsyncTable:
"""
return await self._inner.count_rows(filter)
def query(self) -> AsyncQuery:
return AsyncQuery(self._inner.query())
async def to_pandas(self) -> "pd.DataFrame":
"""Return the table as a pandas DataFrame.
@@ -1906,7 +1909,7 @@ class AsyncTable:
-------
pd.DataFrame
"""
return self.to_arrow().to_pandas()
return (await self.to_arrow()).to_pandas()
async def to_arrow(self) -> pa.Table:
"""Return the table as a pyarrow Table.
@@ -1915,7 +1918,7 @@ class AsyncTable:
-------
pa.Table
"""
raise NotImplementedError
return await self.query().to_arrow()
async def create_index(
self,
@@ -2068,90 +2071,18 @@ class AsyncTable:
return LanceMergeInsertBuilder(self, on)
async def search(
def vector_search(
self,
query: Optional[Union[VEC, str, "PIL.Image.Image", Tuple]] = None,
vector_column_name: Optional[str] = None,
query_type: str = "auto",
) -> LanceQueryBuilder:
"""Create a search query to find the nearest neighbors
of the given query vector. We currently support [vector search][search]
and [full-text search][experimental-full-text-search].
All query options are defined in [Query][lancedb.query.Query].
Examples
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> data = [
... {"original_width": 100, "caption": "bar", "vector": [0.1, 2.3, 4.5]},
... {"original_width": 2000, "caption": "foo", "vector": [0.5, 3.4, 1.3]},
... {"original_width": 3000, "caption": "test", "vector": [0.3, 6.2, 2.6]}
... ]
>>> table = db.create_table("my_table", data)
>>> query = [0.4, 1.4, 2.4]
>>> (table.search(query)
... .where("original_width > 1000", prefilter=True)
... .select(["caption", "original_width", "vector"])
... .limit(2)
... .to_pandas())
caption original_width vector _distance
0 foo 2000 [0.5, 3.4, 1.3] 5.220000
1 test 3000 [0.3, 6.2, 2.6] 23.089996
Parameters
----------
query: list/np.ndarray/str/PIL.Image.Image, default None
The targeted vector to search for.
- *default None*.
Acceptable types are: list, np.ndarray, PIL.Image.Image
- If None then the select/where/limit clauses are applied to filter
the table
vector_column_name: str, optional
The name of the vector column to search.
The vector column needs to be a pyarrow fixed size list type
- If not specified then the vector column is inferred from
the table schema
- If the table has multiple vector columns then the *vector_column_name*
needs to be specified. Otherwise, an error is raised.
query_type: str
*default "auto"*.
Acceptable types are: "vector", "fts", "hybrid", or "auto"
- If "auto" then the query type is inferred from the query;
- If `query` is a list/np.ndarray then the query type is
"vector";
- If `query` is a PIL.Image.Image then either do vector search,
or raise an error if no corresponding embedding function is found.
- If `query` is a string, then the query type is "vector" if the
table has embedding functions else the query type is "fts"
Returns
-------
LanceQueryBuilder
A query builder object representing the query.
Once executed, the query returns
- selected columns
- the vector
- and also the "_distance" column which is the distance between the query
vector and the returned vector.
query_vector: Optional[Union[VEC, Tuple]] = None,
) -> AsyncVectorQuery:
"""
raise NotImplementedError
Search the table with a given query vector.
async def _execute_query(self, query: Query) -> pa.Table:
pass
This is a convenience method for preparing a vector query and
is the same thing as calling `nearest_to` on the builder returned
by `query`. See [nearest_to][AsyncQuery.nearest_to] for more details.
"""
return self.query().nearest_to(query_vector)
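# Hedged equivalence sketch: `vector_search` is shorthand for
# `query().nearest_to(...)`; `table` and `vec` are hypothetical.
async def _vector_search_example(table, vec):
    via_shorthand = await table.vector_search(vec).limit(5).to_arrow()
    via_builder = await table.query().nearest_to(vec).limit(5).to_arrow()
    return via_shorthand, via_builder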
async def _do_merge(
self,


@@ -12,16 +12,19 @@
# limitations under the License.
import unittest.mock as mock
from datetime import timedelta
import lance
import lancedb
import numpy as np
import pandas.testing as tm
import pyarrow as pa
import pytest
import pytest_asyncio
from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, Vector
from lancedb.query import LanceVectorQueryBuilder, Query
from lancedb.table import LanceTable
from lancedb.query import AsyncQueryBase, LanceVectorQueryBuilder, Query
from lancedb.table import AsyncTable, LanceTable
class MockTable:
@@ -65,6 +68,24 @@ def table(tmp_path) -> MockTable:
return MockTable(tmp_path)
@pytest_asyncio.fixture
async def table_async(tmp_path) -> AsyncTable:
conn = await lancedb.connect_async(
tmp_path, read_consistency_interval=timedelta(seconds=0)
)
data = pa.table(
{
"vector": pa.array(
[[1, 2], [3, 4]], type=pa.list_(pa.float32(), list_size=2)
),
"id": pa.array([1, 2]),
"str_field": pa.array(["a", "b"]),
"float_field": pa.array([1.0, 2.0]),
}
)
return await conn.create_table("test", data)
def test_cast(table):
class TestModel(LanceModel):
vector: Vector(2)
@@ -184,3 +205,109 @@ def test_query_builder_with_different_vector_column():
def cosine_distance(vec1, vec2):
return 1 - np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
async def check_query(
query: AsyncQueryBase, *, expected_num_rows=None, expected_columns=None
):
num_rows = 0
results = await query.to_batches()
async for batch in results:
if expected_columns is not None:
assert batch.schema.names == expected_columns
num_rows += batch.num_rows
if expected_num_rows is not None:
assert num_rows == expected_num_rows
@pytest.mark.asyncio
async def test_query_async(table_async: AsyncTable):
await check_query(
table_async.query(),
expected_num_rows=2,
expected_columns=["vector", "id", "str_field", "float_field"],
)
await check_query(table_async.query().where("id = 2"), expected_num_rows=1)
await check_query(
table_async.query().select(["id", "vector"]), expected_columns=["id", "vector"]
)
await check_query(
table_async.query().select({"foo": "id", "bar": "id + 1"}),
expected_columns=["foo", "bar"],
)
await check_query(table_async.query().limit(1), expected_num_rows=1)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])), expected_num_rows=2
)
# Support different types of inputs for the vector query
for vector_query in [
[1, 2],
[1.0, 2.0],
np.array([1, 2]),
(1, 2),
]:
await check_query(
table_async.query().nearest_to(vector_query), expected_num_rows=2
)
# No easy way to check these vector query parameters are doing what they say. We
# just check that they don't raise exceptions and assume this is tested at a lower
# level.
await check_query(
table_async.query().where("id = 2").nearest_to(pa.array([1, 2])).postfilter(),
expected_num_rows=1,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).refine_factor(1),
expected_num_rows=2,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).nprobes(10),
expected_num_rows=2,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).bypass_vector_index(),
expected_num_rows=2,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).distance_type("dot"),
expected_num_rows=2,
)
await check_query(
table_async.query().nearest_to(pa.array([1, 2])).distance_type("DoT"),
expected_num_rows=2,
)
# Make sure we can use a vector query as a base query (e.g. call limit on it)
# Also make sure `vector_search` works
await check_query(table_async.vector_search([1, 2]).limit(1), expected_num_rows=1)
# Also check an empty query
await check_query(table_async.query().where("id < 0"), expected_num_rows=0)
@pytest.mark.asyncio
async def test_query_to_arrow_async(table_async: AsyncTable):
table = await table_async.to_arrow()
assert table.num_rows == 2
assert table.num_columns == 4
table = await table_async.query().to_arrow()
assert table.num_rows == 2
assert table.num_columns == 4
table = await table_async.query().where("id < 0").to_arrow()
assert table.num_rows == 0
assert table.num_columns == 4
@pytest.mark.asyncio
async def test_query_to_pandas_async(table_async: AsyncTable):
df = await table_async.to_pandas()
assert df.shape == (2, 4)
df = await table_async.query().to_pandas()
assert df.shape == (2, 4)
df = await table_async.query().where("id < 0").to_pandas()
assert df.shape == (0, 4)

python/src/arrow.rs (new file)

@@ -0,0 +1,51 @@
use std::sync::Arc;
use arrow::{
datatypes::SchemaRef,
pyarrow::{IntoPyArrow, ToPyArrow},
};
use futures::stream::StreamExt;
use lancedb::arrow::SendableRecordBatchStream;
use pyo3::{pyclass, pymethods, PyAny, PyObject, PyRef, PyResult, Python};
use pyo3_asyncio::tokio::future_into_py;
use crate::error::PythonErrorExt;
#[pyclass]
pub struct RecordBatchStream {
schema: SchemaRef,
inner: Arc<tokio::sync::Mutex<SendableRecordBatchStream>>,
}
impl RecordBatchStream {
pub fn new(inner: SendableRecordBatchStream) -> Self {
let schema = inner.schema().clone();
Self {
schema,
inner: Arc::new(tokio::sync::Mutex::new(inner)),
}
}
}
#[pymethods]
impl RecordBatchStream {
pub fn schema(&self, py: Python) -> PyResult<PyObject> {
(*self.schema).clone().into_pyarrow(py)
}
pub fn next(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let inner_next = inner.lock().await.next().await;
inner_next
.map(|item| {
let item = item.infer_error()?;
Python::with_gil(|py| item.to_pyarrow(py))
})
.transpose()
})
}
}


@@ -12,15 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use arrow::RecordBatchStream;
use connection::{connect, Connection};
use env_logger::Env;
use index::{Index, IndexConfig};
use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
use query::{Query, VectorQuery};
use table::Table;
pub mod arrow;
pub mod connection;
pub mod error;
pub mod index;
pub mod query;
pub mod table;
pub mod util;
@@ -34,6 +38,9 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Table>()?;
m.add_class::<Index>()?;
m.add_class::<IndexConfig>()?;
m.add_class::<Query>()?;
m.add_class::<VectorQuery>()?;
m.add_class::<RecordBatchStream>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(())

python/src/query.rs (new file)

@@ -0,0 +1,125 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use arrow::array::make_array;
use arrow::array::ArrayData;
use arrow::pyarrow::FromPyArrow;
use lancedb::query::{
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
};
use pyo3::pyclass;
use pyo3::pymethods;
use pyo3::PyAny;
use pyo3::PyRef;
use pyo3::PyResult;
use pyo3_asyncio::tokio::future_into_py;
use crate::arrow::RecordBatchStream;
use crate::error::PythonErrorExt;
use crate::util::parse_distance_type;
#[pyclass]
pub struct Query {
inner: LanceDbQuery,
}
impl Query {
pub fn new(query: LanceDbQuery) -> Self {
Self { inner: query }
}
}
#[pymethods]
impl Query {
pub fn r#where(&mut self, predicate: String) {
self.inner = self.inner.clone().only_if(predicate);
}
pub fn select(&mut self, columns: Vec<(String, String)>) {
self.inner = self.inner.clone().select(Select::dynamic(&columns));
}
pub fn limit(&mut self, limit: u32) {
self.inner = self.inner.clone().limit(limit as usize);
}
pub fn nearest_to(&mut self, vector: &PyAny) -> PyResult<VectorQuery> {
let data: ArrayData = ArrayData::from_pyarrow(vector)?;
let array = make_array(data);
let inner = self.inner.clone().nearest_to(array).infer_error()?;
Ok(VectorQuery { inner })
}
pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let inner_stream = inner.execute().await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream))
})
}
}
#[pyclass]
pub struct VectorQuery {
inner: LanceDbVectorQuery,
}
#[pymethods]
impl VectorQuery {
pub fn r#where(&mut self, predicate: String) {
self.inner = self.inner.clone().only_if(predicate);
}
pub fn select(&mut self, columns: Vec<(String, String)>) {
self.inner = self.inner.clone().select(Select::dynamic(&columns));
}
pub fn limit(&mut self, limit: u32) {
self.inner = self.inner.clone().limit(limit as usize);
}
pub fn column(&mut self, column: String) {
self.inner = self.inner.clone().column(&column);
}
pub fn distance_type(&mut self, distance_type: String) -> PyResult<()> {
let distance_type = parse_distance_type(distance_type)?;
self.inner = self.inner.clone().distance_type(distance_type);
Ok(())
}
pub fn postfilter(&mut self) {
self.inner = self.inner.clone().postfilter();
}
pub fn refine_factor(&mut self, refine_factor: u32) {
self.inner = self.inner.clone().refine_factor(refine_factor);
}
pub fn nprobes(&mut self, nprobe: u32) {
self.inner = self.inner.clone().nprobes(nprobe as usize);
}
pub fn bypass_vector_index(&mut self) {
self.inner = self.inner.clone().bypass_vector_index();
}
pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
let inner = self_.inner.clone();
future_into_py(self_.py(), async move {
let inner_stream = inner.execute().await.infer_error()?;
Ok(RecordBatchStream::new(inner_stream))
})
}
}


@@ -14,6 +14,7 @@ use pyo3_asyncio::tokio::future_into_py;
use crate::{
error::PythonErrorExt,
index::{Index, IndexConfig},
query::Query,
};
#[pyclass]
@@ -179,4 +180,8 @@ impl Table {
async move { inner.restore().await.infer_error() },
)
}
pub fn query(&self) -> Query {
Query::new(self.inner_ref().unwrap().query())
}
}


@@ -1,6 +1,10 @@
use std::sync::Mutex;
use pyo3::{exceptions::PyRuntimeError, PyResult};
use lancedb::DistanceType;
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
PyResult,
};
/// A wrapper around a rust builder
///
@@ -33,3 +37,15 @@ impl<T> BuilderWrapper<T> {
Ok(result)
}
}
pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceType> {
match distance_type.as_ref().to_lowercase().as_str() {
"l2" => Ok(DistanceType::L2),
"cosine" => Ok(DistanceType::Cosine),
"dot" => Ok(DistanceType::Dot),
_ => Err(PyValueError::new_err(format!(
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
distance_type.as_ref()
))),
}
}