feat: add create_index to the async python API (#1052)

This also refactors the rust lancedb index builder API (and,
correspondingly, the nodejs API)
This commit is contained in:
Weston Pace
2024-03-12 05:17:05 -07:00
committed by GitHub
parent ae1cf4441d
commit 356e89a800
38 changed files with 1330 additions and 767 deletions

87
python/src/index.rs Normal file
View File

@@ -0,0 +1,87 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Mutex;
use lancedb::{
index::{scalar::BTreeIndexBuilder, vector::IvfPqIndexBuilder, Index as LanceDbIndex},
DistanceType,
};
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
pyclass, pymethods, PyResult,
};
#[pyclass]
pub struct Index {
inner: Mutex<Option<LanceDbIndex>>,
}
impl Index {
pub fn consume(&self) -> PyResult<LanceDbIndex> {
self.inner
.lock()
.unwrap()
.take()
.ok_or_else(|| PyRuntimeError::new_err("cannot use an Index more than once"))
}
}
#[pymethods]
impl Index {
#[staticmethod]
pub fn ivf_pq(
distance_type: Option<String>,
num_partitions: Option<u32>,
num_sub_vectors: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
) -> PyResult<Self> {
let mut ivf_pq_builder = IvfPqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = match distance_type.as_str() {
"l2" => Ok(DistanceType::L2),
"cosine" => Ok(DistanceType::Cosine),
"dot" => Ok(DistanceType::Dot),
_ => Err(PyValueError::new_err(format!(
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
distance_type
))),
}?;
ivf_pq_builder = ivf_pq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {
ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions);
}
if let Some(num_sub_vectors) = num_sub_vectors {
ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
}
if let Some(max_iterations) = max_iterations {
ivf_pq_builder = ivf_pq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
ivf_pq_builder = ivf_pq_builder.sample_rate(sample_rate);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfPq(ivf_pq_builder))),
})
}
#[staticmethod]
pub fn btree() -> PyResult<Self> {
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::BTree(BTreeIndexBuilder::default()))),
})
}
}

View File

@@ -14,11 +14,15 @@
use connection::{connect, Connection};
use env_logger::Env;
use index::Index;
use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
use table::Table;
pub mod connection;
pub mod error;
pub mod index;
pub mod table;
pub mod util;
#[pymodule]
pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
@@ -27,6 +31,8 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
.write_style("LANCEDB_LOG_STYLE");
env_logger::init_from_env(env);
m.add_class::<Connection>()?;
m.add_class::<Table>()?;
m.add_class::<Index>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(())

View File

@@ -9,7 +9,7 @@ use pyo3::{
};
use pyo3_asyncio::tokio::future_into_py;
use crate::error::PythonErrorExt;
use crate::{error::PythonErrorExt, index::Index};
#[pyclass]
pub struct Table {
@@ -81,6 +81,28 @@ impl Table {
})
}
pub fn create_index<'a>(
self_: PyRef<'a, Self>,
column: String,
index: Option<&Index>,
replace: Option<bool>,
) -> PyResult<&'a PyAny> {
let index = if let Some(index) = index {
index.consume()?
} else {
lancedb::index::Index::Auto
};
let mut op = self_.inner_ref()?.create_index(&[column], index);
if let Some(replace) = replace {
op = op.replace(replace);
}
future_into_py(self_.py(), async move {
op.execute().await.infer_error()?;
Ok(())
})
}
pub fn __repr__(&self) -> String {
match &self.inner {
None => format!("ClosedTable({})", self.name),

35
python/src/util.rs Normal file
View File

@@ -0,0 +1,35 @@
use std::sync::Mutex;
use pyo3::{exceptions::PyRuntimeError, PyResult};
/// A wrapper around a rust builder
///
/// Rust builders are often implemented so that the builder methods
/// consume the builder and return a new one. This is not compatible
/// with the pyo3, which, being garbage collected, cannot easily obtain
/// ownership of an object.
///
/// This wrapper converts the compile-time safety of rust into runtime
/// errors if any attempt to use the builder happens after it is consumed.
pub struct BuilderWrapper<T> {
name: String,
inner: Mutex<Option<T>>,
}
impl<T> BuilderWrapper<T> {
pub fn new(name: impl AsRef<str>, inner: T) -> Self {
Self {
name: name.as_ref().to_string(),
inner: Mutex::new(Some(inner)),
}
}
pub fn consume<O>(&self, mod_fn: impl FnOnce(T) -> O) -> PyResult<O> {
let mut inner = self.inner.lock().unwrap();
let inner_builder = inner.take().ok_or_else(|| {
PyRuntimeError::new_err(format!("{} has already been consumed", self.name))
})?;
let result = mod_fn(inner_builder);
Ok(result)
}
}