mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-14 02:20:40 +00:00
feat: refactor the query API and add query support to the python async API (#1113)
In addition, there are also a number of changes in nodejs to the docstrings of existing methods because this PR adds a jsdoc linter.
This commit is contained in:
51
python/src/arrow.rs
Normal file
51
python/src/arrow.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
// use arrow::datatypes::SchemaRef;
|
||||
// use lancedb::arrow::SendableRecordBatchStream;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow::{
|
||||
datatypes::SchemaRef,
|
||||
pyarrow::{IntoPyArrow, ToPyArrow},
|
||||
};
|
||||
use futures::stream::StreamExt;
|
||||
use lancedb::arrow::SendableRecordBatchStream;
|
||||
use pyo3::{pyclass, pymethods, PyAny, PyObject, PyRef, PyResult, Python};
|
||||
use pyo3_asyncio::tokio::future_into_py;
|
||||
|
||||
use crate::error::PythonErrorExt;
|
||||
|
||||
#[pyclass]
|
||||
pub struct RecordBatchStream {
|
||||
schema: SchemaRef,
|
||||
inner: Arc<tokio::sync::Mutex<SendableRecordBatchStream>>,
|
||||
}
|
||||
|
||||
impl RecordBatchStream {
|
||||
pub fn new(inner: SendableRecordBatchStream) -> Self {
|
||||
let schema = inner.schema().clone();
|
||||
Self {
|
||||
schema,
|
||||
inner: Arc::new(tokio::sync::Mutex::new(inner)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl RecordBatchStream {
|
||||
pub fn schema(&self, py: Python) -> PyResult<PyObject> {
|
||||
(*self.schema).clone().into_pyarrow(py)
|
||||
}
|
||||
|
||||
pub fn next(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let inner_next = inner.lock().await.next().await;
|
||||
inner_next
|
||||
.map(|item| {
|
||||
let item = item.infer_error()?;
|
||||
Python::with_gil(|py| item.to_pyarrow(py))
|
||||
})
|
||||
.transpose()
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -12,15 +12,19 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use arrow::RecordBatchStream;
|
||||
use connection::{connect, Connection};
|
||||
use env_logger::Env;
|
||||
use index::{Index, IndexConfig};
|
||||
use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
|
||||
use query::{Query, VectorQuery};
|
||||
use table::Table;
|
||||
|
||||
pub mod arrow;
|
||||
pub mod connection;
|
||||
pub mod error;
|
||||
pub mod index;
|
||||
pub mod query;
|
||||
pub mod table;
|
||||
pub mod util;
|
||||
|
||||
@@ -34,6 +38,9 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<Table>()?;
|
||||
m.add_class::<Index>()?;
|
||||
m.add_class::<IndexConfig>()?;
|
||||
m.add_class::<Query>()?;
|
||||
m.add_class::<VectorQuery>()?;
|
||||
m.add_class::<RecordBatchStream>()?;
|
||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
Ok(())
|
||||
|
||||
125
python/src/query.rs
Normal file
125
python/src/query.rs
Normal file
@@ -0,0 +1,125 @@
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use arrow::array::make_array;
|
||||
use arrow::array::ArrayData;
|
||||
use arrow::pyarrow::FromPyArrow;
|
||||
use lancedb::query::{
|
||||
ExecutableQuery, Query as LanceDbQuery, QueryBase, Select, VectorQuery as LanceDbVectorQuery,
|
||||
};
|
||||
use pyo3::pyclass;
|
||||
use pyo3::pymethods;
|
||||
use pyo3::PyAny;
|
||||
use pyo3::PyRef;
|
||||
use pyo3::PyResult;
|
||||
use pyo3_asyncio::tokio::future_into_py;
|
||||
|
||||
use crate::arrow::RecordBatchStream;
|
||||
use crate::error::PythonErrorExt;
|
||||
use crate::util::parse_distance_type;
|
||||
|
||||
#[pyclass]
|
||||
pub struct Query {
|
||||
inner: LanceDbQuery,
|
||||
}
|
||||
|
||||
impl Query {
|
||||
pub fn new(query: LanceDbQuery) -> Self {
|
||||
Self { inner: query }
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl Query {
|
||||
pub fn r#where(&mut self, predicate: String) {
|
||||
self.inner = self.inner.clone().only_if(predicate);
|
||||
}
|
||||
|
||||
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
||||
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
||||
}
|
||||
|
||||
pub fn limit(&mut self, limit: u32) {
|
||||
self.inner = self.inner.clone().limit(limit as usize);
|
||||
}
|
||||
|
||||
pub fn nearest_to(&mut self, vector: &PyAny) -> PyResult<VectorQuery> {
|
||||
let data: ArrayData = ArrayData::from_pyarrow(vector)?;
|
||||
let array = make_array(data);
|
||||
let inner = self.inner.clone().nearest_to(array).infer_error()?;
|
||||
Ok(VectorQuery { inner })
|
||||
}
|
||||
|
||||
pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let inner_stream = inner.execute().await.infer_error()?;
|
||||
Ok(RecordBatchStream::new(inner_stream))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
pub struct VectorQuery {
|
||||
inner: LanceDbVectorQuery,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl VectorQuery {
|
||||
pub fn r#where(&mut self, predicate: String) {
|
||||
self.inner = self.inner.clone().only_if(predicate);
|
||||
}
|
||||
|
||||
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
||||
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
||||
}
|
||||
|
||||
pub fn limit(&mut self, limit: u32) {
|
||||
self.inner = self.inner.clone().limit(limit as usize);
|
||||
}
|
||||
|
||||
pub fn column(&mut self, column: String) {
|
||||
self.inner = self.inner.clone().column(&column);
|
||||
}
|
||||
|
||||
pub fn distance_type(&mut self, distance_type: String) -> PyResult<()> {
|
||||
let distance_type = parse_distance_type(distance_type)?;
|
||||
self.inner = self.inner.clone().distance_type(distance_type);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn postfilter(&mut self) {
|
||||
self.inner = self.inner.clone().postfilter();
|
||||
}
|
||||
|
||||
pub fn refine_factor(&mut self, refine_factor: u32) {
|
||||
self.inner = self.inner.clone().refine_factor(refine_factor);
|
||||
}
|
||||
|
||||
pub fn nprobes(&mut self, nprobe: u32) {
|
||||
self.inner = self.inner.clone().nprobes(nprobe as usize);
|
||||
}
|
||||
|
||||
pub fn bypass_vector_index(&mut self) {
|
||||
self.inner = self.inner.clone().bypass_vector_index()
|
||||
}
|
||||
|
||||
pub fn execute(self_: PyRef<'_, Self>) -> PyResult<&PyAny> {
|
||||
let inner = self_.inner.clone();
|
||||
future_into_py(self_.py(), async move {
|
||||
let inner_stream = inner.execute().await.infer_error()?;
|
||||
Ok(RecordBatchStream::new(inner_stream))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,7 @@ use pyo3_asyncio::tokio::future_into_py;
|
||||
use crate::{
|
||||
error::PythonErrorExt,
|
||||
index::{Index, IndexConfig},
|
||||
query::Query,
|
||||
};
|
||||
|
||||
#[pyclass]
|
||||
@@ -179,4 +180,8 @@ impl Table {
|
||||
async move { inner.restore().await.infer_error() },
|
||||
)
|
||||
}
|
||||
|
||||
pub fn query(&self) -> Query {
|
||||
Query::new(self.inner_ref().unwrap().query())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
use std::sync::Mutex;
|
||||
|
||||
use pyo3::{exceptions::PyRuntimeError, PyResult};
|
||||
use lancedb::DistanceType;
|
||||
use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
PyResult,
|
||||
};
|
||||
|
||||
/// A wrapper around a rust builder
|
||||
///
|
||||
@@ -33,3 +37,15 @@ impl<T> BuilderWrapper<T> {
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceType> {
|
||||
match distance_type.as_ref().to_lowercase().as_str() {
|
||||
"l2" => Ok(DistanceType::L2),
|
||||
"cosine" => Ok(DistanceType::Cosine),
|
||||
"dot" => Ok(DistanceType::Dot),
|
||||
_ => Err(PyValueError::new_err(format!(
|
||||
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
|
||||
distance_type.as_ref()
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user