mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 07:09:57 +00:00
feat: add create_index to the async python API (#1052)
This also refactors the rust lancedb index builder API (and, correspondingly, the nodejs API)
This commit is contained in:
87
python/src/index.rs
Normal file
87
python/src/index.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
// Copyright 2024 Lance Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Mutex;
|
||||
|
||||
use lancedb::{
|
||||
index::{scalar::BTreeIndexBuilder, vector::IvfPqIndexBuilder, Index as LanceDbIndex},
|
||||
DistanceType,
|
||||
};
|
||||
use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
pyclass, pymethods, PyResult,
|
||||
};
|
||||
|
||||
#[pyclass]
|
||||
pub struct Index {
|
||||
inner: Mutex<Option<LanceDbIndex>>,
|
||||
}
|
||||
|
||||
impl Index {
|
||||
pub fn consume(&self) -> PyResult<LanceDbIndex> {
|
||||
self.inner
|
||||
.lock()
|
||||
.unwrap()
|
||||
.take()
|
||||
.ok_or_else(|| PyRuntimeError::new_err("cannot use an Index more than once"))
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl Index {
|
||||
#[staticmethod]
|
||||
pub fn ivf_pq(
|
||||
distance_type: Option<String>,
|
||||
num_partitions: Option<u32>,
|
||||
num_sub_vectors: Option<u32>,
|
||||
max_iterations: Option<u32>,
|
||||
sample_rate: Option<u32>,
|
||||
) -> PyResult<Self> {
|
||||
let mut ivf_pq_builder = IvfPqIndexBuilder::default();
|
||||
if let Some(distance_type) = distance_type {
|
||||
let distance_type = match distance_type.as_str() {
|
||||
"l2" => Ok(DistanceType::L2),
|
||||
"cosine" => Ok(DistanceType::Cosine),
|
||||
"dot" => Ok(DistanceType::Dot),
|
||||
_ => Err(PyValueError::new_err(format!(
|
||||
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
|
||||
distance_type
|
||||
))),
|
||||
}?;
|
||||
ivf_pq_builder = ivf_pq_builder.distance_type(distance_type);
|
||||
}
|
||||
if let Some(num_partitions) = num_partitions {
|
||||
ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions);
|
||||
}
|
||||
if let Some(num_sub_vectors) = num_sub_vectors {
|
||||
ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
|
||||
}
|
||||
if let Some(max_iterations) = max_iterations {
|
||||
ivf_pq_builder = ivf_pq_builder.max_iterations(max_iterations);
|
||||
}
|
||||
if let Some(sample_rate) = sample_rate {
|
||||
ivf_pq_builder = ivf_pq_builder.sample_rate(sample_rate);
|
||||
}
|
||||
Ok(Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::IvfPq(ivf_pq_builder))),
|
||||
})
|
||||
}
|
||||
|
||||
#[staticmethod]
|
||||
pub fn btree() -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
inner: Mutex::new(Some(LanceDbIndex::BTree(BTreeIndexBuilder::default()))),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -14,11 +14,15 @@
|
||||
|
||||
use connection::{connect, Connection};
|
||||
use env_logger::Env;
|
||||
use index::Index;
|
||||
use pyo3::{pymodule, types::PyModule, wrap_pyfunction, PyResult, Python};
|
||||
use table::Table;
|
||||
|
||||
pub mod connection;
|
||||
pub mod error;
|
||||
pub mod index;
|
||||
pub mod table;
|
||||
pub mod util;
|
||||
|
||||
#[pymodule]
|
||||
pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
@@ -27,6 +31,8 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
.write_style("LANCEDB_LOG_STYLE");
|
||||
env_logger::init_from_env(env);
|
||||
m.add_class::<Connection>()?;
|
||||
m.add_class::<Table>()?;
|
||||
m.add_class::<Index>()?;
|
||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
Ok(())
|
||||
|
||||
@@ -9,7 +9,7 @@ use pyo3::{
|
||||
};
|
||||
use pyo3_asyncio::tokio::future_into_py;
|
||||
|
||||
use crate::error::PythonErrorExt;
|
||||
use crate::{error::PythonErrorExt, index::Index};
|
||||
|
||||
#[pyclass]
|
||||
pub struct Table {
|
||||
@@ -81,6 +81,28 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn create_index<'a>(
|
||||
self_: PyRef<'a, Self>,
|
||||
column: String,
|
||||
index: Option<&Index>,
|
||||
replace: Option<bool>,
|
||||
) -> PyResult<&'a PyAny> {
|
||||
let index = if let Some(index) = index {
|
||||
index.consume()?
|
||||
} else {
|
||||
lancedb::index::Index::Auto
|
||||
};
|
||||
let mut op = self_.inner_ref()?.create_index(&[column], index);
|
||||
if let Some(replace) = replace {
|
||||
op = op.replace(replace);
|
||||
}
|
||||
|
||||
future_into_py(self_.py(), async move {
|
||||
op.execute().await.infer_error()?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn __repr__(&self) -> String {
|
||||
match &self.inner {
|
||||
None => format!("ClosedTable({})", self.name),
|
||||
|
||||
35
python/src/util.rs
Normal file
35
python/src/util.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
use std::sync::Mutex;
|
||||
|
||||
use pyo3::{exceptions::PyRuntimeError, PyResult};
|
||||
|
||||
/// A wrapper around a rust builder
|
||||
///
|
||||
/// Rust builders are often implemented so that the builder methods
|
||||
/// consume the builder and return a new one. This is not compatible
|
||||
/// with the pyo3, which, being garbage collected, cannot easily obtain
|
||||
/// ownership of an object.
|
||||
///
|
||||
/// This wrapper converts the compile-time safety of rust into runtime
|
||||
/// errors if any attempt to use the builder happens after it is consumed.
|
||||
pub struct BuilderWrapper<T> {
|
||||
name: String,
|
||||
inner: Mutex<Option<T>>,
|
||||
}
|
||||
|
||||
impl<T> BuilderWrapper<T> {
|
||||
pub fn new(name: impl AsRef<str>, inner: T) -> Self {
|
||||
Self {
|
||||
name: name.as_ref().to_string(),
|
||||
inner: Mutex::new(Some(inner)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn consume<O>(&self, mod_fn: impl FnOnce(T) -> O) -> PyResult<O> {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
let inner_builder = inner.take().ok_or_else(|| {
|
||||
PyRuntimeError::new_err(format!("{} has already been consumed", self.name))
|
||||
})?;
|
||||
let result = mod_fn(inner_builder);
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user