feat: add create_index to the async python API (#1052)

This also refactors the rust lancedb index builder API (and,
correspondingly, the nodejs API)
This commit is contained in:
Weston Pace
2024-03-12 05:17:05 -07:00
parent 90af5cf028
commit f822255683
38 changed files with 1329 additions and 766 deletions

12
nodejs/src/error.rs Normal file
View File

@@ -0,0 +1,12 @@
pub type Result<T> = napi::Result<T>;
pub trait NapiErrorExt<T> {
/// Convert to a napi error using from_reason(err.to_string())
fn default_error(self) -> Result<T>;
}
impl<T> NapiErrorExt<T> for std::result::Result<T, lancedb::Error> {
fn default_error(self) -> Result<T> {
self.map_err(|err| napi::Error::from_reason(err.to_string()))
}
}

View File

@@ -14,126 +14,73 @@
use std::sync::Mutex;
use lance_linalg::distance::MetricType as LanceMetricType;
use lancedb::index::IndexBuilder as LanceDbIndexBuilder;
use lancedb::Table as LanceDbTable;
use lancedb::index::scalar::BTreeIndexBuilder;
use lancedb::index::vector::IvfPqIndexBuilder;
use lancedb::index::Index as LanceDbIndex;
use lancedb::DistanceType;
use napi_derive::napi;
#[napi]
pub enum IndexType {
Scalar,
IvfPq,
pub struct Index {
inner: Mutex<Option<LanceDbIndex>>,
}
#[napi]
pub enum MetricType {
L2,
Cosine,
Dot,
}
impl From<MetricType> for LanceMetricType {
fn from(metric: MetricType) -> Self {
match metric {
MetricType::L2 => Self::L2,
MetricType::Cosine => Self::Cosine,
MetricType::Dot => Self::Dot,
}
impl Index {
pub fn consume(&self) -> napi::Result<LanceDbIndex> {
self.inner
.lock()
.unwrap()
.take()
.ok_or(napi::Error::from_reason(
"attempt to use an index more than once",
))
}
}
#[napi]
pub struct IndexBuilder {
inner: Mutex<Option<LanceDbIndexBuilder>>,
}
impl IndexBuilder {
fn modify(
&self,
mod_fn: impl Fn(LanceDbIndexBuilder) -> LanceDbIndexBuilder,
) -> napi::Result<()> {
let mut inner = self.inner.lock().unwrap();
let inner_builder = inner.take().ok_or_else(|| {
napi::Error::from_reason("IndexBuilder has already been consumed".to_string())
})?;
let inner_builder = mod_fn(inner_builder);
inner.replace(inner_builder);
Ok(())
}
}
#[napi]
impl IndexBuilder {
pub fn new(tbl: &LanceDbTable) -> Self {
let inner = tbl.create_index(&[]);
Self {
inner: Mutex::new(Some(inner)),
}
}
#[napi]
pub fn replace(&self, v: bool) -> napi::Result<()> {
self.modify(|b| b.replace(v))
}
#[napi]
pub fn column(&self, c: String) -> napi::Result<()> {
self.modify(|b| b.columns(&[c.as_str()]))
}
#[napi]
pub fn name(&self, name: String) -> napi::Result<()> {
self.modify(|b| b.name(name.as_str()))
}
#[napi]
impl Index {
#[napi(factory)]
pub fn ivf_pq(
&self,
metric_type: Option<MetricType>,
distance_type: Option<String>,
num_partitions: Option<u32>,
num_sub_vectors: Option<u32>,
num_bits: Option<u32>,
max_iterations: Option<u32>,
sample_rate: Option<u32>,
) -> napi::Result<()> {
self.modify(|b| {
let mut b = b.ivf_pq();
if let Some(metric_type) = metric_type {
b = b.metric_type(metric_type.into());
}
if let Some(num_partitions) = num_partitions {
b = b.num_partitions(num_partitions);
}
if let Some(num_sub_vectors) = num_sub_vectors {
b = b.num_sub_vectors(num_sub_vectors);
}
if let Some(num_bits) = num_bits {
b = b.num_bits(num_bits);
}
if let Some(max_iterations) = max_iterations {
b = b.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
b = b.sample_rate(sample_rate);
}
b
) -> napi::Result<Self> {
let mut ivf_pq_builder = IvfPqIndexBuilder::default();
if let Some(distance_type) = distance_type {
let distance_type = match distance_type.as_str() {
"l2" => Ok(DistanceType::L2),
"cosine" => Ok(DistanceType::Cosine),
"dot" => Ok(DistanceType::Dot),
_ => Err(napi::Error::from_reason(format!(
"Invalid distance type '{}'. Must be one of l2, cosine, or dot",
distance_type
))),
}?;
ivf_pq_builder = ivf_pq_builder.distance_type(distance_type);
}
if let Some(num_partitions) = num_partitions {
ivf_pq_builder = ivf_pq_builder.num_partitions(num_partitions);
}
if let Some(num_sub_vectors) = num_sub_vectors {
ivf_pq_builder = ivf_pq_builder.num_sub_vectors(num_sub_vectors);
}
if let Some(max_iterations) = max_iterations {
ivf_pq_builder = ivf_pq_builder.max_iterations(max_iterations);
}
if let Some(sample_rate) = sample_rate {
ivf_pq_builder = ivf_pq_builder.sample_rate(sample_rate);
}
Ok(Self {
inner: Mutex::new(Some(LanceDbIndex::IvfPq(ivf_pq_builder))),
})
}
#[napi]
pub fn scalar(&self) -> napi::Result<()> {
self.modify(|b| b.scalar())
}
#[napi]
pub async fn build(&self) -> napi::Result<()> {
let inner = self.inner.lock().unwrap().take().ok_or_else(|| {
napi::Error::from_reason("IndexBuilder has already been consumed".to_string())
})?;
inner
.build()
.await
.map_err(|e| napi::Error::from_reason(format!("Failed to build index: {}", e)))?;
Ok(())
#[napi(factory)]
pub fn btree() -> Self {
Self {
inner: Mutex::new(Some(LanceDbIndex::BTree(BTreeIndexBuilder::default()))),
}
}
}

View File

@@ -13,7 +13,7 @@
// limitations under the License.
use futures::StreamExt;
use lance::io::RecordBatchStream;
use lancedb::arrow::SendableRecordBatchStream;
use lancedb::ipc::batches_to_ipc_file;
use napi::bindgen_prelude::*;
use napi_derive::napi;
@@ -21,12 +21,12 @@ use napi_derive::napi;
/** Typescript-style Async Iterator over RecordBatches */
#[napi]
pub struct RecordBatchIterator {
inner: Box<dyn RecordBatchStream + Unpin>,
inner: SendableRecordBatchStream,
}
#[napi]
impl RecordBatchIterator {
pub(crate) fn new(inner: Box<dyn RecordBatchStream + Unpin>) -> Self {
pub(crate) fn new(inner: SendableRecordBatchStream) -> Self {
Self { inner }
}

View File

@@ -16,6 +16,7 @@ use connection::Connection;
use napi_derive::*;
mod connection;
mod error;
mod index;
mod iterator;
mod query;

View File

@@ -74,6 +74,6 @@ impl Query {
let inner_stream = self.inner.execute_stream().await.map_err(|e| {
napi::Error::from_reason(format!("Failed to execute query stream: {}", e))
})?;
Ok(RecordBatchIterator::new(Box::new(inner_stream)))
Ok(RecordBatchIterator::new(inner_stream))
}
}

View File

@@ -13,13 +13,16 @@
// limitations under the License.
use arrow_ipc::writer::FileWriter;
use lance::dataset::ColumnAlteration as LanceColumnAlteration;
use lancedb::ipc::ipc_file_to_batches;
use lancedb::table::{AddDataMode, Table as LanceDbTable};
use lancedb::table::{
AddDataMode, ColumnAlteration as LanceColumnAlteration, NewColumnTransform,
Table as LanceDbTable,
};
use napi::bindgen_prelude::*;
use napi_derive::napi;
use crate::index::IndexBuilder;
use crate::error::NapiErrorExt;
use crate::index::Index;
use crate::query::Query;
#[napi]
@@ -129,8 +132,22 @@ impl Table {
}
#[napi]
pub fn create_index(&self) -> napi::Result<IndexBuilder> {
Ok(IndexBuilder::new(self.inner_ref()?))
pub async fn create_index(
&self,
index: Option<&Index>,
column: String,
replace: Option<bool>,
) -> napi::Result<()> {
let lancedb_index = if let Some(index) = index {
index.consume()?
} else {
lancedb::index::Index::Auto
};
let mut builder = self.inner_ref()?.create_index(&[column], lancedb_index);
if let Some(replace) = replace {
builder = builder.replace(replace);
}
builder.execute().await.default_error()
}
#[napi]
@@ -144,7 +161,7 @@ impl Table {
.into_iter()
.map(|sql| (sql.name, sql.value_sql))
.collect::<Vec<_>>();
let transforms = lance::dataset::NewColumnTransform::SqlExpressions(transforms);
let transforms = NewColumnTransform::SqlExpressions(transforms);
self.inner_ref()?
.add_columns(transforms, None)
.await