feat: refactor the query API and add query support to the python async API (#1113)

This PR also changes the docstrings of a number of existing nodejs methods,
because it adds a jsdoc linter.
This commit is contained in:
Weston Pace
2024-03-18 12:36:49 -07:00
parent 2db257ca29
commit 4180b44472
38 changed files with 2609 additions and 754 deletions

51
python/src/arrow.rs Normal file
View File

@@ -0,0 +1,51 @@
// use arrow::datatypes::SchemaRef;
// use lancedb::arrow::SendableRecordBatchStream;
use std::sync::Arc;
use arrow::{
datatypes::SchemaRef,
pyarrow::{IntoPyArrow, ToPyArrow},
};
use futures::stream::StreamExt;
use lancedb::arrow::SendableRecordBatchStream;
use pyo3::{pyclass, pymethods, PyAny, PyObject, PyRef, PyResult, Python};
use pyo3_asyncio::tokio::future_into_py;
use crate::error::PythonErrorExt;
/// A stream of record batches exposed to Python as a `#[pyclass]`.
///
/// Holds a schema handle next to the underlying stream so that the schema
/// can be read without locking the stream.
#[pyclass]
pub struct RecordBatchStream {
// Schema handle kept alongside the stream.
schema: SchemaRef,
// The wrapped stream, shared behind an async (tokio) mutex.
inner: Arc<tokio::sync::Mutex<SendableRecordBatchStream>>,
}
impl RecordBatchStream {
    /// Wraps `inner` so that it can be handed to Python.
    ///
    /// The stream's schema is read up front and stored next to the stream;
    /// the stream itself is placed behind a shared tokio mutex.
    pub fn new(inner: SendableRecordBatchStream) -> Self {
        let schema = inner.schema().clone();
        let inner = Arc::new(tokio::sync::Mutex::new(inner));
        Self { schema, inner }
    }
}
#[pymethods]
impl RecordBatchStream {
    /// Returns the stored schema converted to a pyarrow object.
    pub fn schema(&self, py: Python) -> PyResult<PyObject> {
        // `to_pyarrow` only borrows the schema, so there is no need to
        // deep-clone it first just to feed the consuming `into_pyarrow`.
        self.schema.to_pyarrow(py)
    }

    /// Returns a Python awaitable for the next item of the stream.
    ///
    /// The awaitable resolves to the next batch converted to a pyarrow
    /// object, or to `None` once the stream yields no more items. An error
    /// produced by the stream is converted with `infer_error` and raised.
    pub fn next(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
        let inner = self_.inner.clone();
        future_into_py(self_.py(), async move {
            // The lock is held only while polling the stream for one item.
            let inner_next = inner.lock().await.next().await;
            inner_next
                .map(|item| {
                    let item = item.infer_error()?;
                    // The batch is owned here, so the conversion may consume it.
                    Python::with_gil(|py| item.into_pyarrow(py))
                })
                .transpose()
        })
    }
}

View File

@@ -12,15 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use arrow::RecordBatchStream;
use connection::{connect, Connection};
use env_logger::Env;
use index::{Index, IndexConfig};
use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
use query::{Query, VectorQuery};
use table::Table;
pub mod arrow;
pub mod connection;
pub mod error;
pub mod index;
pub mod query;
pub mod table;
pub mod util;
@@ -34,6 +38,9 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<Table>()?;
m.add_class::<Index>()?;
m.add_class::<IndexConfig>()?;
m.add_class::<Query>()?;
m.add_class::<VectorQuery>()?;
m.add_class::<RecordBatchStream>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(())

125
python/src/query.rs Normal file
View File

@@ -0,0 +1,125 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use arrow::array::make_array;
use arrow::array::ArrayData;
use arrow::pyarrow::FromPyArrow;
use lancedb::query::{
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
};
use pyo3::pyclass;
use pyo3::pymethods;
use pyo3::PyAny;
use pyo3::PyRef;
use pyo3::PyResult;
use pyo3_asyncio::tokio::future_into_py;
use crate::arrow::RecordBatchStream;
use crate::error::PythonErrorExt;
use crate::util::parse_distance_type;
/// A LanceDB query exposed to Python as a `#[pyclass]`.
#[pyclass]
pub struct Query {
// The wrapped LanceDB query; the setter methods replace it with an updated copy.
inner: LanceDbQuery,
}
impl Query {
pub fn new(query: LanceDbQuery) -> Self {
Self { inner: query }
}
}
#[pymethods]
impl Query {
    /// Applies `predicate` to the query through `only_if`.
    pub fn r#where(&mut self, predicate: String) {
        let filtered = self.inner.clone().only_if(predicate);
        self.inner = filtered;
    }

    /// Sets the selected columns from `columns`, passed on as a dynamic selection.
    pub fn select(&mut self, columns: Vec<(String, String)>) {
        let projection = Select::dynamic(&columns);
        self.inner = self.inner.clone().select(projection);
    }

    /// Sets the query's limit.
    pub fn limit(&mut self, limit: u32) {
        let max_rows = limit as usize;
        self.inner = self.inner.clone().limit(max_rows);
    }

    /// Builds a [`VectorQuery`] from a copy of this query and `vector`.
    ///
    /// `vector` is read as pyarrow array data; a failed conversion or a
    /// failure reported by `nearest_to` is returned as a Python error.
    pub fn nearest_to(&mut self, vector: &PyAny) -> PyResult<VectorQuery> {
        let array = make_array(ArrayData::from_pyarrow(vector)?);
        let vector_query = self.inner.clone().nearest_to(array).infer_error()?;
        Ok(VectorQuery {
            inner: vector_query,
        })
    }

    /// Returns a Python awaitable that executes a copy of the query and
    /// resolves to a [`RecordBatchStream`] over the results.
    pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
        let query = self_.inner.clone();
        let py = self_.py();
        future_into_py(py, async move {
            let stream = query.execute().await.infer_error()?;
            Ok(RecordBatchStream::new(stream))
        })
    }
}
/// A LanceDB vector query exposed to Python as a `#[pyclass]`.
#[pyclass]
pub struct VectorQuery {
// The wrapped LanceDB vector query; the setter methods replace it with an updated copy.
inner: LanceDbVectorQuery,
}
#[pymethods]
impl VectorQuery {
    /// Applies `predicate` to the query through `only_if`.
    pub fn r#where(&mut self, predicate: String) {
        let filtered = self.inner.clone().only_if(predicate);
        self.inner = filtered;
    }

    /// Sets the selected columns from `columns`, passed on as a dynamic selection.
    pub fn select(&mut self, columns: Vec<(String, String)>) {
        let projection = Select::dynamic(&columns);
        self.inner = self.inner.clone().select(projection);
    }

    /// Sets the query's limit.
    pub fn limit(&mut self, limit: u32) {
        let max_rows = limit as usize;
        self.inner = self.inner.clone().limit(max_rows);
    }

    /// Sets the column passed to the underlying query's `column` option.
    pub fn column(&mut self, column: String) {
        let retargeted = self.inner.clone().column(&column);
        self.inner = retargeted;
    }

    /// Sets the distance type from its name.
    ///
    /// The name is parsed with `parse_distance_type`; if parsing fails the
    /// error is returned and the query is left untouched.
    pub fn distance_type(&mut self, distance_type: String) -> PyResult<()> {
        let parsed = parse_distance_type(distance_type)?;
        self.inner = self.inner.clone().distance_type(parsed);
        Ok(())
    }

    /// Enables the underlying query's `postfilter` option.
    pub fn postfilter(&mut self) {
        let updated = self.inner.clone().postfilter();
        self.inner = updated;
    }

    /// Sets the underlying query's refine factor.
    pub fn refine_factor(&mut self, refine_factor: u32) {
        let updated = self.inner.clone().refine_factor(refine_factor);
        self.inner = updated;
    }

    /// Sets the underlying query's `nprobes` value.
    pub fn nprobes(&mut self, nprobe: u32) {
        let probes = nprobe as usize;
        self.inner = self.inner.clone().nprobes(probes);
    }

    /// Enables the underlying query's `bypass_vector_index` option.
    pub fn bypass_vector_index(&mut self) {
        let updated = self.inner.clone().bypass_vector_index();
        self.inner = updated;
    }

    /// Returns a Python awaitable that executes a copy of the query and
    /// resolves to a [`RecordBatchStream`] over the results.
    pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
        let query = self_.inner.clone();
        let py = self_.py();
        future_into_py(py, async move {
            let stream = query.execute().await.infer_error()?;
            Ok(RecordBatchStream::new(stream))
        })
    }
}

View File

@@ -14,6 +14,7 @@ use pyo3_asyncio::tokio::future_into_py;
use crate::{
error::PythonErrorExt,
index::{Index, IndexConfig},
query::Query,
};
#[pyclass]
@@ -179,4 +180,8 @@ impl Table {
async move { inner.restore().await.infer_error() },
)
}
/// Starts a new [`Query`] against this table.
pub fn query(&self) -> Query {
    // NOTE(review): the `unwrap` panics if `inner_ref()` yields no table —
    // presumably when the table has been closed; confirm against `inner_ref`.
    let table = self.inner_ref().unwrap();
    Query::new(table.query())
}
}

View File

@@ -1,6 +1,10 @@
use std::sync::Mutex;
use pyo3::{exceptions::PyRuntimeError, PyResult};
use lancedb::DistanceType;
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
PyResult,
};
/// A wrapper around a rust builder
///
@@ -33,3 +37,15 @@ impl<T> BuilderWrapper<T> {
Ok(result)
}
}
/// Parses the name of a distance type, ignoring case.
///
/// Accepted names are `l2`, `cosine` and `dot`.
///
/// # Errors
///
/// Returns a `PyValueError` quoting the input exactly as given when the
/// name is not one of the accepted values.
pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceType> {
    let raw = distance_type.as_ref();
    let parsed = match raw.to_lowercase().as_str() {
        "l2" => DistanceType::L2,
        "cosine" => DistanceType::Cosine,
        "dot" => DistanceType::Dot,
        _ => {
            return Err(PyValueError::new_err(format!(
                "Invalid distance type '{}'. Must be one of l2, cosine, or dot",
                raw
            )));
        }
    };
    Ok(parsed)
}