From dfabbe9081f04e17296459a39fd06c037e916fb2 Mon Sep 17 00:00:00 2001
From: Lei Xu
Date: Wed, 24 Jan 2024 10:05:12 -0800
Subject: [PATCH] feat(rust): create index API improvement (#853)

* Extract a minimal Table interface in the Rust SDK

* Make create_index composable in Rust.

* Fix compilation issues in the FFI bindings
---
 Cargo.toml                           |   3 +
 nodejs/src/table.rs                  |   8 +-
 rust/ffi/node/src/index/scalar.rs    |   8 +-
 rust/ffi/node/src/index/vector.rs    | 112 +++----
 rust/ffi/node/src/query.rs           |  17 +-
 rust/ffi/node/src/table.rs           |  38 ++-
 rust/vectordb/Cargo.toml             |   6 +-
 rust/vectordb/src/connection.rs      |  55 ++--
 rust/vectordb/src/error.rs           |  12 +
 rust/vectordb/src/index.rs           | 261 +++++++++++++++-
 rust/vectordb/src/index/vector.rs    | 173 -----------
 rust/vectordb/src/io/object_store.rs |   9 +-
 rust/vectordb/src/lib.rs             |   6 +-
 rust/vectordb/src/query.rs           |  16 +-
 rust/vectordb/src/table.rs           | 438 ++++++++++++++++-----------
 15 files changed, 665 insertions(+), 497 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 0dd1a401..d6615516 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -24,11 +24,14 @@ arrow-ord = "49.0"
 arrow-schema = "49.0"
 arrow-arith = "49.0"
 arrow-cast = "49.0"
+async-trait = "0"
 chrono = "0.4.23"
 half = { "version" = "=2.3.1", default-features = false, features = [
     "num-traits",
 ] }
+futures = "0"
 log = "0.4"
 object_store = "0.9.0"
 snafu = "0.7.4"
 url = "2"
+num-traits = "0.2"
diff --git a/nodejs/src/table.rs b/nodejs/src/table.rs
index a7bc2b4c..3dbce198 100644
--- a/nodejs/src/table.rs
+++ b/nodejs/src/table.rs
@@ -16,16 +16,16 @@ use crate::query::Query;
 use arrow_ipc::writer::FileWriter;
 use napi::bindgen_prelude::*;
 use napi_derive::napi;
-use vectordb::{ipc::ipc_file_to_batches, table::Table as LanceDBTable};
+use vectordb::{ipc::ipc_file_to_batches, table::TableRef};

 #[napi]
 pub struct Table {
-    pub(crate) table: LanceDBTable,
+    pub(crate) table: TableRef,
 }

 #[napi]
 impl Table {
-    pub(crate) fn new(table: LanceDBTable) -> Self {
+    pub(crate) fn new(table: TableRef) -> Self {
         Self { table }
     }

@@ -46,7 +46,7 @@ impl Table {
     pub async unsafe fn add(&mut self, buf: Buffer) -> napi::Result<()> {
         let batches = ipc_file_to_batches(buf.to_vec())
             .map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
-        self.table.add(batches, None).await.map_err(|e| {
+        self.table.add(Box::new(batches), None).await.map_err(|e| {
             napi::Error::from_reason(format!(
                 "Failed to add batches to table {}: {}",
                 self.table, e
diff --git a/rust/ffi/node/src/index/scalar.rs b/rust/ffi/node/src/index/scalar.rs
index f940b62b..0cd2f86b 100644
--- a/rust/ffi/node/src/index/scalar.rs
+++ b/rust/ffi/node/src/index/scalar.rs
@@ -29,10 +29,14 @@ pub(crate) fn table_create_scalar_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
     let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
     let index_params = cx.argument::<JsObject>(0)?;
-    let index_params_builder = get_index_params_builder(&mut cx, index_params).or_throw(&mut cx)?;
     let rt = runtime(&mut cx)?;
     let (deferred, promise) = cx.promise();
     let channel = cx.channel();
-    let mut table = js_table.table.clone();
+    let table = js_table.table.clone();
+
+    let column_name = index_params
+        .get_opt::<JsString, _, _>(&mut cx, "column")?
+        .map(|s| s.value(&mut cx))
+        .unwrap_or("vector".to_string()); // Backward compatibility
+
+    let tbl = table.clone();
+    let mut index_builder = tbl.create_index(&[&column_name]);
+    get_index_params_builder(&mut cx, index_params, &mut index_builder).or_throw(&mut cx)?;

     rt.spawn(async move {
-        let idx_result = table.create_index(&index_params_builder).await;
-
+        let idx_result = index_builder.build().await;
         deferred.settle_with(&channel, move |mut cx| {
             idx_result.or_throw(&mut cx)?;
             Ok(cx.boxed(JsTable::from(table)))
@@ -51,66 +56,39 @@ pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
 fn get_index_params_builder(
     cx: &mut FunctionContext,
     obj: Handle<JsObject>,
-) -> crate::error::Result<IvfPQIndexBuilder> {
-    let idx_type = obj.get::<JsString, _, _>(cx, "type")?.value(cx);
-
-    match idx_type.as_str() {
-        "ivf_pq" => {
-            let mut index_builder: IvfPQIndexBuilder = IvfPQIndexBuilder::new();
-            let mut pq_params = PQBuildParams::default();
-
-            obj.get_opt::<JsString, _, _>(cx, "column")?
-                .map(|s| index_builder.column(s.value(cx)));
-
-            obj.get_opt::<JsString, _, _>(cx, "index_name")?
-                .map(|s| index_builder.index_name(s.value(cx)));
-
-            if let Some(metric_type) = obj.get_opt::<JsString, _, _>(cx, "metric_type")? {
-                let metric_type = MetricType::try_from(metric_type.value(cx).as_str()).unwrap();
-                index_builder.metric_type(metric_type);
-            }
-
-            let num_partitions = obj.get_opt_usize(cx, "num_partitions")?;
-            let max_iters = obj.get_opt_usize(cx, "max_iters")?;
-
-            num_partitions.map(|np| {
-                let max_iters = max_iters.unwrap_or(50);
-                let ivf_params = IvfBuildParams {
-                    num_partitions: np,
-                    max_iters,
-                    ..Default::default()
-                };
-                index_builder.ivf_params(ivf_params)
-            });
-
-            if let Some(use_opq) = obj.get_opt::<JsBoolean, _, _>(cx, "use_opq")? {
-                pq_params.use_opq = use_opq.value(cx);
-            }
-
-            if let Some(num_sub_vectors) = obj.get_opt_usize(cx, "num_sub_vectors")? {
-                pq_params.num_sub_vectors = num_sub_vectors;
-            }
-
-            if let Some(num_bits) = obj.get_opt_usize(cx, "num_bits")? {
-                pq_params.num_bits = num_bits;
-            }
-
-            if let Some(max_iters) = obj.get_opt_usize(cx, "max_iters")? {
-                pq_params.max_iters = max_iters;
-            }
-
-            if let Some(max_opq_iters) = obj.get_opt_usize(cx, "max_opq_iters")? {
-                pq_params.max_opq_iters = max_opq_iters;
-            }
-
-            if let Some(replace) = obj.get_opt::<JsBoolean, _, _>(cx, "replace")? {
-                index_builder.replace(replace.value(cx));
-            }
-
-            Ok(index_builder)
+    builder: &mut IndexBuilder,
+) -> crate::error::Result<()> {
+    match obj.get::<JsString, _, _>(cx, "type")?.value(cx).as_str() {
+        "ivf_pq" => builder.ivf_pq(),
+        _ => {
+            return Err(InvalidIndexType {
+                index_type: "".into(),
+            })
         }
-        index_type => Err(InvalidIndexType {
-            index_type: index_type.into(),
-        }),
+    };
+
+    obj.get_opt::<JsString, _, _>(cx, "index_name")?
+        .map(|s| builder.name(s.value(cx).as_str()));
+
+    if let Some(metric_type) = obj.get_opt::<JsString, _, _>(cx, "metric_type")? {
+        let metric_type = MetricType::try_from(metric_type.value(cx).as_str())?;
+        builder.metric_type(metric_type);
     }
+
+    if let Some(np) = obj.get_opt_usize(cx, "num_partitions")? {
+        builder.num_partitions(np as u64);
+    }
+    if let Some(ns) = obj.get_opt_u32(cx, "num_sub_vectors")? {
+        builder.num_sub_vectors(ns);
+    }
+    if let Some(max_iters) = obj.get_opt_u32(cx, "max_iters")? {
+        builder.max_iterations(max_iters);
+    }
+    if let Some(num_bits) = obj.get_opt_u32(cx, "num_bits")? {
+        builder.num_bits(num_bits);
+    }
+    if let Some(replace) = obj.get_opt::<JsBoolean, _, _>(cx, "replace")?
{ + builder.replace(replace.value(cx)); + } + Ok(()) } diff --git a/rust/ffi/node/src/query.rs b/rust/ffi/node/src/query.rs index dc3a3438..708e007e 100644 --- a/rust/ffi/node/src/query.rs +++ b/rust/ffi/node/src/query.rs @@ -70,14 +70,15 @@ impl JsQuery { let query = query_vector.map(|q| convert::js_array_to_vec(q.deref(), &mut cx)); rt.spawn(async move { - let mut builder = table - .search(query) - .refine_factor(refine_factor) - .nprobes(nprobes) - .filter(filter) - .metric_type(metric_type) - .select(select) - .prefilter(prefilter); + let mut builder = table.query(); + if let Some(query) = query { + builder = builder + .query_vector(&query) + .refine_factor(refine_factor) + .nprobes(nprobes) + .metric_type(metric_type); + }; + builder = builder.filter(filter).select(select).prefilter(prefilter); if let Some(limit) = limit { builder = builder.limit(limit as usize); }; diff --git a/rust/ffi/node/src/table.rs b/rust/ffi/node/src/table.rs index 0364b7e6..734c48cf 100644 --- a/rust/ffi/node/src/table.rs +++ b/rust/ffi/node/src/table.rs @@ -20,19 +20,19 @@ use lance::io::ObjectStoreParams; use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer}; use neon::prelude::*; use neon::types::buffer::TypedArray; -use vectordb::Table; +use vectordb::TableRef; use crate::error::ResultExt; use crate::{convert, get_aws_creds, get_aws_region, runtime, JsDatabase}; pub(crate) struct JsTable { - pub table: Table, + pub table: TableRef, } impl Finalize for JsTable {} -impl From for JsTable { - fn from(table: Table) -> Self { +impl From for JsTable { + fn from(table: TableRef) -> Self { JsTable { table } } } @@ -96,7 +96,7 @@ impl JsTable { arrow_buffer_to_record_batch(buffer.as_slice(&cx)).or_throw(&mut cx)?; let rt = runtime(&mut cx)?; let channel = cx.channel(); - let mut table = js_table.table.clone(); + let table = js_table.table.clone(); let (deferred, promise) = cx.promise(); let write_mode = match write_mode.as_str() { @@ -118,7 +118,7 @@ impl JsTable { rt.spawn(async move { let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); - let add_result = table.add(batch_reader, Some(params)).await; + let add_result = table.add(Box::new(batch_reader), Some(params)).await; deferred.settle_with(&channel, move |mut cx| { add_result.or_throw(&mut cx)?; @@ -152,7 +152,7 @@ impl JsTable { let (deferred, promise) = cx.promise(); let predicate = cx.argument::(0)?.value(&mut cx); let channel = cx.channel(); - let mut table = js_table.table.clone(); + let table = js_table.table.clone(); rt.spawn(async move { let delete_result = table.delete(&predicate).await; @@ -167,7 +167,7 @@ impl JsTable { pub(crate) fn js_update(mut cx: FunctionContext) -> JsResult { let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; - let mut table = js_table.table.clone(); + let table = js_table.table.clone(); let rt = runtime(&mut cx)?; let (deferred, promise) = cx.promise(); @@ -218,7 +218,11 @@ impl JsTable { let predicate = predicate.as_deref(); - let update_result = table.update(predicate, updates_arg).await; + let update_result = table + .as_native() + .unwrap() + .update(predicate, updates_arg) + .await; deferred.settle_with(&channel, move |mut cx| { update_result.or_throw(&mut cx)?; Ok(cx.boxed(JsTable::from(table))) @@ -249,6 +253,8 @@ impl JsTable { rt.spawn(async move { let stats = table + .as_native() + .unwrap() .cleanup_old_versions(older_than, Some(delete_unverified)) .await; @@ -278,7 +284,7 @@ impl JsTable { let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; let rt = 
runtime(&mut cx)?; let (deferred, promise) = cx.promise(); - let mut table = js_table.table.clone(); + let table = js_table.table.clone(); let channel = cx.channel(); let js_options = cx.argument::(0)?; @@ -310,7 +316,11 @@ impl JsTable { } rt.spawn(async move { - let stats = table.compact_files(options, None).await; + let stats = table + .as_native() + .unwrap() + .compact_files(options, None) + .await; deferred.settle_with(&channel, move |mut cx| { let stats = stats.or_throw(&mut cx)?; @@ -349,7 +359,7 @@ impl JsTable { let table = js_table.table.clone(); rt.spawn(async move { - let indices = table.load_indices().await; + let indices = table.as_native().unwrap().load_indices().await; deferred.settle_with(&channel, move |mut cx| { let indices = indices.or_throw(&mut cx)?; @@ -389,8 +399,8 @@ impl JsTable { rt.spawn(async move { let load_stats = futures::try_join!( - table.count_indexed_rows(&index_uuid), - table.count_unindexed_rows(&index_uuid) + table.as_native().unwrap().count_indexed_rows(&index_uuid), + table.as_native().unwrap().count_unindexed_rows(&index_uuid) ); deferred.settle_with(&channel, move |mut cx| { diff --git a/rust/vectordb/Cargo.toml b/rust/vectordb/Cargo.toml index 4725294f..e46dfde2 100644 --- a/rust/vectordb/Cargo.toml +++ b/rust/vectordb/Cargo.toml @@ -26,11 +26,11 @@ lance-index = { workspace = true } lance-linalg = { workspace = true } lance-testing = { workspace = true } tokio = { version = "1.23", features = ["rt-multi-thread"] } -log = { workspace = true } +log.workspace = true async-trait = "0" bytes = "1" -futures = "0" -num-traits = "0" +futures.workspace = true +num-traits.workspace = true url = { workspace = true } serde = { version = "^1" } serde_json = { version = "1" } diff --git a/rust/vectordb/src/connection.rs b/rust/vectordb/src/connection.rs index 9d1816d7..92ad0ee5 100644 --- a/rust/vectordb/src/connection.rs +++ b/rust/vectordb/src/connection.rs @@ -27,7 +27,7 @@ use snafu::prelude::*; use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result}; use crate::io::object_store::MirroringObjectStoreWrapper; -use crate::table::{ReadParams, Table}; +use crate::table::{NativeTable, ReadParams, TableRef}; pub const LANCE_FILE_EXTENSION: &str = "lance"; @@ -46,17 +46,20 @@ pub trait Connection: Send + Sync { /// * `params` - Optional [`WriteParams`] to create the table. /// /// # Returns - /// Created [`Table`], or [`Err(Error::TableAlreadyExists)`] if the table already exists. + /// Created [`TableRef`], or [`Err(Error::TableAlreadyExists)`] if the table already exists. async fn create_table( &self, name: &str, batches: Box, params: Option, - ) -> Result
<Table>;
+    ) -> Result<TableRef>;

-    async fn open_table(&self, name: &str) -> Result
<Table>;
+    async fn open_table(&self, name: &str) -> Result<TableRef> {
+        self.open_table_with_params(name, ReadParams::default())
+            .await
+    }

-    async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result
<Table>;
+    async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<TableRef>;
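
The `Connection` trait now hands out `TableRef` (an `Arc<dyn Table>`) instead of an
owned `Table` struct, so handles are cheap to clone and share across async tasks.
A minimal usage sketch against this trait (the table name is illustrative; errors
are propagated with `?`):

    use vectordb::connection::{Connection, Database};

    async fn open_and_count() -> vectordb::Result<()> {
        let db = Database::connect("data/sample-lancedb").await?;
        // `open_table` now has a default implementation that forwards to
        // `open_table_with_params` with default read parameters.
        let table = db.open_table("my_table").await?;
        println!("{} rows", table.count_rows().await?);
        Ok(())
    }
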

     /// Drop a table in the database.
     ///
@@ -240,30 +243,19 @@ impl Connection for Database {
         name: &str,
         batches: Box<dyn RecordBatchReader + Send>,
         params: Option<WriteParams>,
-    ) -> Result<Table> {
+    ) -> Result<TableRef> {
         let table_uri = self.table_uri(name)?;
-        Table::create(
-            &table_uri,
-            name,
-            batches,
-            self.store_wrapper.clone(),
-            params,
-        )
-        .await
-    }
-
-    /// Open a table in the database.
-    ///
-    /// # Arguments
-    /// * `name` - The name of the table.
-    ///
-    /// # Returns
-    ///
-    /// * A [Table] object.
-    async fn open_table(&self, name: &str) -> Result
<Table> {
-        self.open_table_with_params(name, ReadParams::default())
-            .await
+        Ok(Arc::new(
+            NativeTable::create(
+                &table_uri,
+                name,
+                batches,
+                self.store_wrapper.clone(),
+                params,
+            )
+            .await?,
+        ))
     }

     /// Open a table in the database.
@@ -274,10 +266,13 @@
     ///
     /// # Returns
     ///
-    /// * A [Table] object.
{ + /// * A [TableRef] object. + async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result { let table_uri = self.table_uri(name)?; - Table::open_with_params(&table_uri, name, self.store_wrapper.clone(), params).await + Ok(Arc::new( + NativeTable::open_with_params(&table_uri, name, self.store_wrapper.clone(), params) + .await?, + )) } async fn drop_table(&self, name: &str) -> Result<()> { diff --git a/rust/vectordb/src/error.rs b/rust/vectordb/src/error.rs index e5418d3c..2bdf97d6 100644 --- a/rust/vectordb/src/error.rs +++ b/rust/vectordb/src/error.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::PoisonError; + use arrow_schema::ArrowError; use snafu::Snafu; @@ -35,6 +37,8 @@ pub enum Error { Lance { message: String }, #[snafu(display("LanceDB Schema Error: {message}"))] Schema { message: String }, + #[snafu(display("Runtime error: {message}"))] + Runtime { message: String }, } pub type Result = std::result::Result; @@ -70,3 +74,11 @@ impl From for Error { } } } + +impl From> for Error { + fn from(e: PoisonError) -> Self { + Self::Runtime { + message: e.to_string(), + } + } +} diff --git a/rust/vectordb/src/index.rs b/rust/vectordb/src/index.rs index ed07a8d5..1782b4b2 100644 --- a/rust/vectordb/src/index.rs +++ b/rust/vectordb/src/index.rs @@ -1,4 +1,4 @@ -// Copyright 2023 Lance Developers. +// Copyright 2024 Lance Developers. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,4 +12,263 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::{cmp::max, sync::Arc}; + +use lance_index::{DatasetIndexExt, IndexType}; +pub use lance_linalg::distance::MetricType; + pub mod vector; + +use crate::{Error, Result, Table}; + +/// Index Parameters. +pub enum IndexParams { + Scalar { + replace: bool, + }, + IvfPq { + replace: bool, + metric_type: MetricType, + num_partitions: u64, + num_sub_vectors: u32, + num_bits: u32, + sample_rate: u32, + max_iterations: u32, + }, +} + +/// Builder for Index Parameters. + +pub struct IndexBuilder { + table: Arc, + columns: Vec, + // General parameters + /// Index name. + name: Option, + /// Replace the existing index. + replace: bool, + + index_type: IndexType, + + // Scalar index parameters + // Nothing to set here. + + // IVF_PQ parameters + metric_type: MetricType, + num_partitions: Option, + // PQ related + num_sub_vectors: Option, + num_bits: u32, + + /// The rate to find samples to train kmeans. + sample_rate: u32, + /// Max iteration to train kmeans. + max_iterations: u32, +} + +impl IndexBuilder { + pub(crate) fn new(table: Arc, columns: &[&str]) -> Self { + IndexBuilder { + table, + columns: columns.iter().map(|c| c.to_string()).collect(), + name: None, + replace: true, + index_type: IndexType::Scalar, + metric_type: MetricType::L2, + num_partitions: None, + num_sub_vectors: None, + num_bits: 8, + sample_rate: 256, + max_iterations: 50, + } + } + + /// Build a Scalar Index. + /// + /// Accepted parameters: + /// - `replace`: Replace the existing index. + /// - `name`: Index name. Default: `None` + pub fn scalar(&mut self) -> &mut Self { + self.index_type = IndexType::Scalar; + self + } + + /// Build an IVF PQ index. + /// + /// Accepted parameters: + /// - `replace`: Replace the existing index. + /// - `name`: Index name. 
diff --git a/rust/vectordb/src/index.rs b/rust/vectordb/src/index.rs
index ed07a8d5..1782b4b2 100644
--- a/rust/vectordb/src/index.rs
+++ b/rust/vectordb/src/index.rs
@@ -1,4 +1,4 @@
-// Copyright 2023 Lance Developers.
+// Copyright 2024 Lance Developers.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,4 +12,263 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+use std::{cmp::max, sync::Arc};
+
+use lance_index::{DatasetIndexExt, IndexType};
+pub use lance_linalg::distance::MetricType;
+
 pub mod vector;
+
+use crate::{Error, Result, Table};
+
+/// Index Parameters.
+pub enum IndexParams {
+    Scalar {
+        replace: bool,
+    },
+    IvfPq {
+        replace: bool,
+        metric_type: MetricType,
+        num_partitions: u64,
+        num_sub_vectors: u32,
+        num_bits: u32,
+        sample_rate: u32,
+        max_iterations: u32,
+    },
+}
+
+/// Builder for index parameters.
+pub struct IndexBuilder {
+    table: Arc<dyn Table>,
+    columns: Vec<String>,
+    // General parameters
+    /// Index name.
+    name: Option<String>,
+    /// Replace the existing index.
+    replace: bool,
+
+    index_type: IndexType,
+
+    // Scalar index parameters
+    // Nothing to set here.
+
+    // IVF_PQ parameters
+    metric_type: MetricType,
+    num_partitions: Option<u64>,
+    // PQ related
+    num_sub_vectors: Option<u32>,
+    num_bits: u32,
+
+    /// The rate used to sample training data for k-means.
+    sample_rate: u32,
+    /// Max number of iterations to train k-means.
+    max_iterations: u32,
+}
+
+impl IndexBuilder {
+    pub(crate) fn new(table: Arc<dyn Table>, columns: &[&str]) -> Self {
+        IndexBuilder {
+            table,
+            columns: columns.iter().map(|c| c.to_string()).collect(),
+            name: None,
+            replace: true,
+            index_type: IndexType::Scalar,
+            metric_type: MetricType::L2,
+            num_partitions: None,
+            num_sub_vectors: None,
+            num_bits: 8,
+            sample_rate: 256,
+            max_iterations: 50,
+        }
+    }
+
+    /// Build a scalar index.
+    ///
+    /// Accepted parameters:
+    /// - `replace`: Replace the existing index.
+    /// - `name`: Index name. Default: `None`
+    pub fn scalar(&mut self) -> &mut Self {
+        self.index_type = IndexType::Scalar;
+        self
+    }
+
+    /// Build an IVF_PQ index.
+    ///
+    /// Accepted parameters:
+    /// - `replace`: Replace the existing index.
+    /// - `name`: Index name. Default: `None`
+    /// - `metric_type`: [MetricType] to use to build the vector index.
+    /// - `num_partitions`: Number of IVF partitions.
+    /// - `num_sub_vectors`: Number of PQ sub-vectors.
+    /// - `num_bits`: Number of bits used for PQ centroids.
+    /// - `sample_rate`: The rate used to sample training data for k-means.
+    /// - `max_iterations`: Max number of iterations to train k-means.
+    pub fn ivf_pq(&mut self) -> &mut Self {
+        self.index_type = IndexType::Vector;
+        self
+    }
+
+    /// Whether to replace the existing index, default is `true`.
+    pub fn replace(&mut self, v: bool) -> &mut Self {
+        self.replace = v;
+        self
+    }
+
+    /// Set the index name.
+    pub fn name(&mut self, name: &str) -> &mut Self {
+        self.name = Some(name.to_string());
+        self
+    }
+
+    /// [MetricType] to use to build the vector index.
+    ///
+    /// Default value is [MetricType::L2].
+    pub fn metric_type(&mut self, metric_type: MetricType) -> &mut Self {
+        self.metric_type = metric_type;
+        self
+    }
+
+    /// Number of IVF partitions.
+    pub fn num_partitions(&mut self, num_partitions: u64) -> &mut Self {
+        self.num_partitions = Some(num_partitions);
+        self
+    }
+
+    /// Number of PQ sub-vectors.
+    pub fn num_sub_vectors(&mut self, num_sub_vectors: u32) -> &mut Self {
+        self.num_sub_vectors = Some(num_sub_vectors);
+        self
+    }
+
+    /// Number of bits used for PQ centroids.
+    pub fn num_bits(&mut self, num_bits: u32) -> &mut Self {
+        self.num_bits = num_bits;
+        self
+    }
+
+    /// The rate used to sample training data for k-means.
+    pub fn sample_rate(&mut self, sample_rate: u32) -> &mut Self {
+        self.sample_rate = sample_rate;
+        self
+    }
+
+    /// Max number of iterations to train k-means.
+    pub fn max_iterations(&mut self, max_iterations: u32) -> &mut Self {
+        self.max_iterations = max_iterations;
+        self
+    }
+
+    /// Build the index with the accumulated parameters.
+    pub async fn build(&self) -> Result<()> {
+        if self.columns.len() != 1 {
+            return Err(Error::Schema {
+                message: "Only one column is supported for index".to_string(),
+            });
+        }
+        let column = &self.columns[0];
+        let schema = self.table.schema();
+        let field = schema.field_with_name(column)?;
+
+        let params = match self.index_type {
+            IndexType::Scalar => IndexParams::Scalar {
+                replace: self.replace,
+            },
+            IndexType::Vector => {
+                let num_partitions = if let Some(n) = self.num_partitions {
+                    n
+                } else {
+                    suggested_num_partitions(self.table.count_rows().await?)
+                };
+                let num_sub_vectors: u32 = if let Some(n) = self.num_sub_vectors {
+                    n
+                } else {
+                    match field.data_type() {
+                        arrow_schema::DataType::FixedSizeList(_, n) => {
+                            Ok::<u32, Error>(suggested_num_sub_vectors(*n as u32))
+                        }
+                        _ => Err(Error::Schema {
+                            message: format!(
+                                "Column '{}' is not a FixedSizeList",
+                                &self.columns[0]
+                            ),
+                        }),
+                    }?
+                };
+                IndexParams::IvfPq {
+                    replace: self.replace,
+                    metric_type: self.metric_type,
+                    num_partitions,
+                    num_sub_vectors,
+                    num_bits: self.num_bits,
+                    sample_rate: self.sample_rate,
+                    max_iterations: self.max_iterations,
+                }
+            }
+        };
+
+        let tbl = self
+            .table
+            .as_native()
+            .expect("Only native table is supported here");
+        let mut dataset = tbl.clone_inner_dataset();
+        match params {
+            IndexParams::Scalar { replace } => {
+                self.table
+                    .as_native()
+                    .unwrap()
+                    .create_scalar_index(column, replace)
+                    .await?
+            }
+            IndexParams::IvfPq {
+                replace,
+                metric_type,
+                num_partitions,
+                num_sub_vectors,
+                num_bits,
+                max_iterations,
+                ..
+ } => { + let lance_idx_params = lance::index::vector::VectorIndexParams::ivf_pq( + num_partitions as usize, + num_bits as u8, + num_sub_vectors as usize, + false, + metric_type, + max_iterations as usize, + ); + dataset + .create_index( + &[column], + IndexType::Vector, + None, + &lance_idx_params, + replace, + ) + .await?; + } + } + tbl.reset_dataset(dataset); + Ok(()) + } +} + +fn suggested_num_partitions(rows: usize) -> u64 { + let num_partitions = (rows as f64).sqrt() as u64; + max(1, num_partitions) +} + +fn suggested_num_sub_vectors(dim: u32) -> u32 { + if dim % 16 == 0 { + // Should be more aggressive than this default. + dim / 16 + } else if dim % 8 == 0 { + dim / 8 + } else { + log::warn!( + "The dimension of the vector is not divisible by 8 or 16, \ + which may cause performance degradation in PQ" + ); + 1 + } +} diff --git a/rust/vectordb/src/index/vector.rs b/rust/vectordb/src/index/vector.rs index 881612c9..9afcc467 100644 --- a/rust/vectordb/src/index/vector.rs +++ b/rust/vectordb/src/index/vector.rs @@ -14,104 +14,7 @@ use serde::Deserialize; -use lance::index::vector::pq::PQBuildParams; -use lance::index::vector::VectorIndexParams; use lance::table::format::{Index, Manifest}; -use lance_index::vector::ivf::IvfBuildParams; -use lance_linalg::distance::MetricType; - -pub trait VectorIndexBuilder { - fn get_column(&self) -> Option; - fn get_index_name(&self) -> Option; - fn build(&self) -> VectorIndexParams; - - fn get_replace(&self) -> bool; -} - -pub struct IvfPQIndexBuilder { - column: Option, - index_name: Option, - metric_type: Option, - ivf_params: Option, - pq_params: Option, - replace: bool, -} - -impl IvfPQIndexBuilder { - pub fn new() -> IvfPQIndexBuilder { - Default::default() - } -} - -impl Default for IvfPQIndexBuilder { - fn default() -> Self { - IvfPQIndexBuilder { - column: None, - index_name: None, - metric_type: None, - ivf_params: None, - pq_params: None, - replace: true, - } - } -} - -impl IvfPQIndexBuilder { - pub fn column(&mut self, column: String) -> &mut IvfPQIndexBuilder { - self.column = Some(column); - self - } - - pub fn index_name(&mut self, index_name: String) -> &mut IvfPQIndexBuilder { - self.index_name = Some(index_name); - self - } - - pub fn metric_type(&mut self, metric_type: MetricType) -> &mut IvfPQIndexBuilder { - self.metric_type = Some(metric_type); - self - } - - pub fn ivf_params(&mut self, ivf_params: IvfBuildParams) -> &mut IvfPQIndexBuilder { - self.ivf_params = Some(ivf_params); - self - } - - pub fn pq_params(&mut self, pq_params: PQBuildParams) -> &mut IvfPQIndexBuilder { - self.pq_params = Some(pq_params); - self - } - - pub fn replace(&mut self, replace: bool) -> &mut IvfPQIndexBuilder { - self.replace = replace; - self - } -} - -impl VectorIndexBuilder for IvfPQIndexBuilder { - fn get_column(&self) -> Option { - self.column.clone() - } - - fn get_index_name(&self) -> Option { - self.index_name.clone() - } - - fn build(&self) -> VectorIndexParams { - let ivf_params = self.ivf_params.clone().unwrap_or_default(); - let pq_params = self.pq_params.clone().unwrap_or_default(); - - VectorIndexParams::with_ivf_pq_params( - self.metric_type.unwrap_or(MetricType::L2), - ivf_params, - pq_params, - ) - } - - fn get_replace(&self) -> bool { - self.replace - } -} pub struct VectorIndex { pub columns: Vec, @@ -139,79 +42,3 @@ pub struct VectorIndexStatistics { pub num_indexed_rows: usize, pub num_unindexed_rows: usize, } - -#[cfg(test)] -mod tests { - use super::*; - - use lance::index::vector::StageParams; - use 
lance_index::vector::ivf::IvfBuildParams; - use lance_index::vector::pq::PQBuildParams; - - use crate::index::vector::{IvfPQIndexBuilder, VectorIndexBuilder}; - - #[test] - fn test_builder_no_params() { - let index_builder = IvfPQIndexBuilder::new(); - assert!(index_builder.get_column().is_none()); - assert!(index_builder.get_index_name().is_none()); - - let index_params = index_builder.build(); - assert_eq!(index_params.stages.len(), 2); - if let StageParams::Ivf(ivf_params) = index_params.stages.get(0).unwrap() { - let default = IvfBuildParams::default(); - assert_eq!(ivf_params.num_partitions, default.num_partitions); - assert_eq!(ivf_params.max_iters, default.max_iters); - } else { - panic!("Expected first stage to be ivf") - } - - if let StageParams::PQ(pq_params) = index_params.stages.get(1).unwrap() { - assert_eq!(pq_params.use_opq, false); - } else { - panic!("Expected second stage to be pq") - } - } - - #[test] - fn test_builder_all_params() { - let mut index_builder = IvfPQIndexBuilder::new(); - - index_builder - .column("c".to_owned()) - .metric_type(MetricType::Cosine) - .index_name("index".to_owned()); - - assert_eq!(index_builder.column.clone().unwrap(), "c"); - assert_eq!(index_builder.metric_type.unwrap(), MetricType::Cosine); - assert_eq!(index_builder.index_name.clone().unwrap(), "index"); - - let ivf_params = IvfBuildParams::new(500); - let mut pq_params = PQBuildParams::default(); - pq_params.use_opq = true; - pq_params.max_iters = 1; - pq_params.num_bits = 8; - pq_params.num_sub_vectors = 50; - pq_params.max_opq_iters = 2; - index_builder.ivf_params(ivf_params); - index_builder.pq_params(pq_params); - - let index_params = index_builder.build(); - assert_eq!(index_params.stages.len(), 2); - if let StageParams::Ivf(ivf_params) = index_params.stages.get(0).unwrap() { - assert_eq!(ivf_params.num_partitions, 500); - } else { - assert!(false, "Expected first stage to be ivf") - } - - if let StageParams::PQ(pq_params) = index_params.stages.get(1).unwrap() { - assert_eq!(pq_params.use_opq, true); - assert_eq!(pq_params.max_iters, 1); - assert_eq!(pq_params.num_bits, 8); - assert_eq!(pq_params.num_sub_vectors, 50); - assert_eq!(pq_params.max_opq_iters, 2); - } else { - assert!(false, "Expected second stage to be pq") - } - } -} diff --git a/rust/vectordb/src/io/object_store.rs b/rust/vectordb/src/io/object_store.rs index dc027bc2..df6cbc25 100644 --- a/rust/vectordb/src/io/object_store.rs +++ b/rust/vectordb/src/io/object_store.rs @@ -335,14 +335,15 @@ impl WrappingObjectStore for MirroringObjectStoreWrapper { #[cfg(all(test, not(windows)))] mod test { use super::*; - use crate::connection::{Connection, Database}; - use arrow_array::PrimitiveArray; + use futures::TryStreamExt; use lance::{dataset::WriteParams, io::ObjectStoreParams}; use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector}; use object_store::local::LocalFileSystem; use tempfile; + use crate::connection::{Connection, Database}; + #[tokio::test] async fn test_e2e() { let dir1 = tempfile::tempdir().unwrap().into_path(); @@ -374,9 +375,7 @@ mod test { assert_eq!(t.count_rows().await.unwrap(), 100); let q = t - .search(Some(PrimitiveArray::from_iter_values(vec![ - 0.1, 0.1, 0.1, 0.1, - ]))) + .search(&[0.1, 0.1, 0.1, 0.1]) .limit(10) .execute() .await diff --git a/rust/vectordb/src/lib.rs b/rust/vectordb/src/lib.rs index 64c5434a..b72a58b5 100644 --- a/rust/vectordb/src/lib.rs +++ b/rust/vectordb/src/lib.rs @@ -46,7 +46,7 @@ //! #### Connect to a database. //! //! ```rust -//! 
use vectordb::{connection::{Database, Connection}, Table, WriteMode}; +//! use vectordb::{connection::{Database, Connection}, WriteMode}; //! use arrow_schema::{Field, Schema}; //! # tokio::runtime::Runtime::new().unwrap().block_on(async { //! let db = Database::connect("data/sample-lancedb").await.unwrap(); @@ -119,7 +119,7 @@ //! # db.create_table("my_table", Box::new(batches), None).await.unwrap(); //! let table = db.open_table("my_table").await.unwrap(); //! let results = table -//! .search(Some(vec![1.0; 128])) +//! .search(&[1.0; 128]) //! .execute() //! .await //! .unwrap() @@ -143,6 +143,6 @@ pub mod utils; pub use connection::{Connection, Database}; pub use error::{Error, Result}; -pub use table::Table; +pub use table::{Table, TableRef}; pub use lance::dataset::WriteMode; diff --git a/rust/vectordb/src/query.rs b/rust/vectordb/src/query.rs index c6ec42e9..0e72a27f 100644 --- a/rust/vectordb/src/query.rs +++ b/rust/vectordb/src/query.rs @@ -1,4 +1,4 @@ -// Copyright 2023 Lance Developers. +// Copyright 2024 Lance Developers. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -48,10 +48,10 @@ impl Query { /// # Returns /// /// * A [Query] object. - pub(crate) fn new(dataset: Arc, vector: Option) -> Self { + pub(crate) fn new(dataset: Arc) -> Self { Query { dataset, - query_vector: vector, + query_vector: None, column: crate::table::VECTOR_COLUMN_NAME.to_string(), limit: None, nprobes: 20, @@ -206,7 +206,7 @@ mod tests { let ds = Dataset::write(batches, "memory://foo", None).await.unwrap(); let vector = Some(Float32Array::from_iter_values([0.1, 0.2])); - let query = Query::new(Arc::new(ds), vector.clone()); + let query = Query::new(Arc::new(ds)).query_vector(&[0.1, 0.2]); assert_eq!(query.query_vector, vector); let new_vector = Float32Array::from_iter_values([9.8, 8.7]); @@ -232,9 +232,7 @@ mod tests { let batches = make_non_empty_batches(); let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap()); - let vector = Some(Float32Array::from_iter_values([0.1; 4])); - - let query = Query::new(ds.clone(), vector.clone()); + let query = Query::new(ds.clone()).query_vector(&[0.1; 4]); let result = query .limit(10) .filter(Some("id % 2 == 0".to_string())) @@ -247,7 +245,7 @@ mod tests { assert!(batch.expect("should be Ok").num_rows() < 10); } - let query = Query::new(ds, vector.clone()); + let query = Query::new(ds).query_vector(&[0.1; 4]); let result = query .limit(10) .filter(Some("id % 2 == 0".to_string())) @@ -268,7 +266,7 @@ mod tests { let batches = make_non_empty_batches(); let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap()); - let query = Query::new(ds.clone(), None); + let query = Query::new(ds.clone()); let result = query .filter(Some("id % 2 == 0".to_string())) .execute() diff --git a/rust/vectordb/src/table.rs b/rust/vectordb/src/table.rs index 8901f2bf..a5a66804 100644 --- a/rust/vectordb/src/table.rs +++ b/rust/vectordb/src/table.rs @@ -1,4 +1,4 @@ -// Copyright 2023 LanceDB Developers. +// Copyright 2024 LanceDB Developers. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,72 +12,170 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! 
LanceDB Table APIs

+use std::path::Path;
+use std::sync::{Arc, Mutex};
+
+use arrow_array::RecordBatchReader;
+use arrow_schema::{Schema, SchemaRef};
 use chrono::Duration;
 use lance::dataset::builder::DatasetBuilder;
-use lance::index::scalar::ScalarIndexParams;
-use lance_index::optimize::OptimizeOptions;
-use lance_index::IndexType;
-use std::sync::Arc;
-
-use arrow_array::{Float32Array, RecordBatchReader};
-use arrow_schema::SchemaRef;
 use lance::dataset::cleanup::RemovalStats;
 use lance::dataset::optimize::{
     compact_files, CompactionMetrics, CompactionOptions, IndexRemapperOptions,
 };
+pub use lance::dataset::ReadParams;
 use lance::dataset::{Dataset, UpdateBuilder, WriteParams};
+use lance::index::scalar::ScalarIndexParams;
 use lance::io::WrappingObjectStore;
-use lance_index::DatasetIndexExt;
-use std::path::Path;
+use lance_index::{optimize::OptimizeOptions, DatasetIndexExt, IndexType};

 use crate::error::{Error, Result};
-use crate::index::vector::{VectorIndex, VectorIndexBuilder, VectorIndexStatistics};
+use crate::index::vector::{VectorIndex, VectorIndexStatistics};
+use crate::index::IndexBuilder;
 use crate::query::Query;
 use crate::utils::{PatchReadParam, PatchWriteParam};
 use crate::WriteMode;

-pub use lance::dataset::ReadParams;
-
 pub const VECTOR_COLUMN_NAME: &str = "vector";

+/// A Table is a collection of strongly typed rows.
+///
+/// The type of each row is defined in an Apache Arrow [Schema].
+#[async_trait::async_trait]
+pub trait Table: std::fmt::Display + Send + Sync {
+    fn as_any(&self) -> &dyn std::any::Any;
+
+    /// Cast as [`NativeTable`], or return `None` if it is not a [`NativeTable`].
+    fn as_native(&self) -> Option<&NativeTable>;
+
+    /// Get the name of the table.
+    fn name(&self) -> &str;
+
+    /// Get the arrow [Schema] of the table.
+    fn schema(&self) -> SchemaRef;
+
+    /// Count the number of rows in this dataset.
+    async fn count_rows(&self) -> Result<usize>;
+
+    /// Insert new records into this Table.
+    ///
+    /// # Arguments
+    ///
+    /// * `batches` RecordBatch to be saved in the Table
+    /// * `params` Append / Overwrite existing records. Default: Append
+    async fn add(
+        &self,
+        batches: Box<dyn RecordBatchReader + Send>,
+        params: Option<WriteParams>,
+    ) -> Result<()>;
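
Because `add` takes a boxed reader (keeping the trait object-safe), callers wrap
any concrete `RecordBatchReader` before handing it over, mirroring the change to
the Node bindings above. A minimal sketch inside an async context (`schema` and
`batches` assumed to already exist):

    let reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema.clone());
    table.add(Box::new(reader), None).await?;
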
+
+    /// Delete the rows from table that match the predicate.
+    ///
+    /// # Arguments
+    /// - `predicate` - The SQL predicate string to filter the rows to be deleted.
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// # use std::sync::Arc;
+    /// # use vectordb::connection::{Database, Connection};
+    /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
+    /// #   RecordBatchIterator, Int32Array};
+    /// # use arrow_schema::{Schema, Field, DataType};
+    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
+    /// let tmpdir = tempfile::tempdir().unwrap();
+    /// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
+    /// # let schema = Arc::new(Schema::new(vec![
+    /// #   Field::new("id", DataType::Int32, false),
+    /// #   Field::new("vector", DataType::FixedSizeList(
+    /// #       Arc::new(Field::new("item", DataType::Float32, true)), 128), true),
+    /// # ]));
+    /// let batches = RecordBatchIterator::new(vec![
+    ///     RecordBatch::try_new(schema.clone(),
+    ///         vec![
+    ///             Arc::new(Int32Array::from_iter_values(0..10)),
+    ///             Arc::new(FixedSizeListArray::from_iter_primitive::<Float32Type, _, _>(
+    ///                 (0..10).map(|_| Some(vec![Some(1.0); 128])), 128)),
+    ///         ]).unwrap()
+    ///     ].into_iter().map(Ok),
+    ///     schema.clone());
+    /// let tbl = db.create_table("delete_test", Box::new(batches), None).await.unwrap();
+    /// tbl.delete("id > 5").await.unwrap();
+    /// # });
+    /// ```
+    async fn delete(&self, predicate: &str) -> Result<()>;
+
+    /// Create an index on the given columns.
+    ///
+    /// # Examples
+    ///
+    /// ```no_run
+    /// # use std::sync::Arc;
+    /// # use vectordb::connection::{Database, Connection};
+    /// # use arrow_array::{FixedSizeListArray, types::Float32Type, RecordBatch,
+    /// #   RecordBatchIterator, Int32Array};
+    /// # use arrow_schema::{Schema, Field, DataType};
+    /// # tokio::runtime::Runtime::new().unwrap().block_on(async {
+    /// let tmpdir = tempfile::tempdir().unwrap();
+    /// let db = Database::connect(tmpdir.path().to_str().unwrap()).await.unwrap();
+    /// # let tbl = db.open_table("delete_test").await.unwrap();
+    /// tbl.create_index(&["vector"])
+    ///     .ivf_pq()
+    ///     .num_partitions(256)
+    ///     .build()
+    ///     .await
+    ///     .unwrap();
+    /// # });
+    /// ```
+    fn create_index(&self, columns: &[&str]) -> IndexBuilder;
+
+    /// Search the table with a given query vector.
+    fn search(&self, query: &[f32]) -> Query {
+        self.query().query_vector(query)
+    }
+
+    /// Create a Query builder.
+    fn query(&self) -> Query;
+}
+
+/// A reference-counted reference to a [Table].
+pub type TableRef = Arc<dyn Table>;
+
 /// A table in a LanceDB database.
 #[derive(Debug, Clone)]
-pub struct Table {
+pub struct NativeTable {
     name: String,
     uri: String,
-    dataset: Arc<Dataset>,
+    dataset: Arc<Mutex<Dataset>>,
     // the object store wrapper to use on write path
     store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
 }

-impl std::fmt::Display for Table {
+impl std::fmt::Display for NativeTable {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         write!(f, "Table({})", self.name)
     }
 }

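Since `TableRef` is just an `Arc<dyn Table>`, helper code can be written once
against the trait and reused for any implementation, native or not. A minimal
sketch (the function name is illustrative):

    use vectordb::{Result, TableRef};

    async fn row_counts(tables: &[TableRef]) -> Result<Vec<(String, usize)>> {
        let mut counts = Vec::new();
        for table in tables {
            // Dynamic dispatch through the new Table trait.
            counts.push((table.name().to_string(), table.count_rows().await?));
        }
        Ok(counts)
    }
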
-impl Table {
+impl NativeTable {
     /// Opens an existing Table
     ///
     /// # Arguments
     ///
-    /// * `uri` - The uri to a [Table]
+    /// * `uri` - The uri to a [NativeTable]
     /// * `name` - The table name
     ///
     /// # Returns
     ///
-    /// * A [Table] object.
+    /// * A [NativeTable] object.
     pub async fn open(uri: &str) -> Result<Self> {
         let name = Self::get_table_name(uri)?;
         Self::open_with_params(uri, &name, None, ReadParams::default()).await
     }

-    /// Open an Table with a given name.
-    pub async fn open_with_name(uri: &str, name: &str) -> Result<Self> {
-        Self::open_with_params(uri, name, None, ReadParams::default()).await
-    }
-
     /// Opens an existing Table
     ///
     /// # Arguments
@@ -88,7 +186,7 @@
     ///
     /// # Returns
     ///
-    /// * A [Table] object.
+    /// * A [NativeTable] object.
     pub async fn open_with_params(
         uri: &str,
         name: &str,
         write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
         params: ReadParams,
     ) -> Result<Self> {
@@ -113,25 +211,26 @@
             message: e.to_string(),
         },
     })?;
-        Ok(Table {
+        Ok(NativeTable {
             name: name.to_string(),
             uri: uri.to_string(),
-            dataset: Arc::new(dataset),
+            dataset: Arc::new(Mutex::new(dataset)),
             store_wrapper: write_store_wrapper,
         })
     }

-    /// Checkout a specific version of this [`Table`]
+    /// Make a new clone of the internal lance dataset.
+    pub(crate) fn clone_inner_dataset(&self) -> Dataset {
+        self.dataset.lock().expect("Lock poison").clone()
+    }
+
+    /// Checkout a specific version of this [NativeTable]
     ///
     pub async fn checkout(uri: &str, version: u64) -> Result<Self> {
         let name = Self::get_table_name(uri)?;
         Self::checkout_with_params(uri, &name, version, None, ReadParams::default()).await
     }

-    pub async fn checkout_with_name(uri: &str, name: &str, version: u64) -> Result<Self> {
-        Self::checkout_with_params(uri, name, version, None, ReadParams::default()).await
-    }
-
     pub async fn checkout_with_params(
         uri: &str,
         name: &str,
         version: u64,
@@ -154,26 +253,27 @@
             message: e.to_string(),
         },
     })?;
-        Ok(Table {
+        Ok(NativeTable {
             name: name.to_string(),
             uri: uri.to_string(),
-            dataset: Arc::new(dataset),
+            dataset: Arc::new(Mutex::new(dataset)),
             store_wrapper: write_store_wrapper,
         })
     }

     pub async fn checkout_latest(&self) -> Result<Self> {
-        let latest_version_id = self.dataset.latest_version_id().await?;
-        let dataset = if latest_version_id == self.dataset.version().version {
-            self.dataset.clone()
+        let dataset = self.clone_inner_dataset();
+        let latest_version_id = dataset.latest_version_id().await?;
+        let dataset = if latest_version_id == dataset.version().version {
+            dataset
         } else {
-            Arc::new(self.dataset.checkout_version(latest_version_id).await?)
+            dataset.checkout_version(latest_version_id).await?
         };
-        Ok(Table {
+        Ok(Self {
             name: self.name.clone(),
             uri: self.uri.clone(),
-            dataset,
+            dataset: Arc::new(Mutex::new(dataset)),
             store_wrapper: self.store_wrapper.clone(),
         })
     }
@@ -203,7 +303,7 @@
     ///
     /// # Returns
     ///
-    /// * A [Table] object.
+    /// * A [NativeTable] object.
     pub(crate) async fn create(
         uri: &str,
         name: &str,
@@ -227,46 +327,22 @@
             message: e.to_string(),
         },
     })?;
-        Ok(Table {
+        Ok(NativeTable {
             name: name.to_string(),
             uri: uri.to_string(),
-            dataset: Arc::new(dataset),
+            dataset: Arc::new(Mutex::new(dataset)),
             store_wrapper: write_store_wrapper,
         })
     }

-    /// Schema of this Table.
-    pub fn schema(&self) -> SchemaRef {
-        Arc::new(self.dataset.schema().into())
-    }
-
     /// Version of this Table
     pub fn version(&self) -> u64 {
-        self.dataset.version().version
-    }
-
-    /// Create index on the table.
- pub async fn create_index(&mut self, index_builder: &impl VectorIndexBuilder) -> Result<()> { - let mut dataset = self.dataset.as_ref().clone(); - dataset - .create_index( - &[index_builder - .get_column() - .unwrap_or(VECTOR_COLUMN_NAME.to_string()) - .as_str()], - IndexType::Vector, - index_builder.get_index_name(), - &index_builder.build(), - index_builder.get_replace(), - ) - .await?; - self.dataset = Arc::new(dataset); - Ok(()) + self.dataset.lock().expect("lock poison").version().version } /// Create a scalar index on the table - pub async fn create_scalar_index(&mut self, column: &str, replace: bool) -> Result<()> { - let mut dataset = self.dataset.as_ref().clone(); + pub async fn create_scalar_index(&self, column: &str, replace: bool) -> Result<()> { + let mut dataset = self.clone_inner_dataset(); let params = ScalarIndexParams::default(); dataset .create_index(&[column], IndexType::Scalar, None, ¶ms, replace) @@ -275,65 +351,21 @@ impl Table { } pub async fn optimize_indices(&mut self, options: &OptimizeOptions) -> Result<()> { - let mut dataset = self.dataset.as_ref().clone(); + let mut dataset = self.clone_inner_dataset(); dataset.optimize_indices(options).await?; Ok(()) } - /// Insert records into this Table - /// - /// # Arguments - /// - /// * `batches` RecordBatch to be saved in the Table - /// * `write_mode` Append / Overwrite existing records. Default: Append - /// # Returns - /// - /// * The number of rows added - pub async fn add( - &mut self, - batches: impl RecordBatchReader + Send + 'static, - params: Option, - ) -> Result<()> { - let params = Some(params.unwrap_or(WriteParams { - mode: WriteMode::Append, - ..WriteParams::default() - })); - - // patch the params if we have a write store wrapper - let params = match self.store_wrapper.clone() { - Some(wrapper) => params.patch_with_store_wrapper(wrapper)?, - None => params, - }; - - self.dataset = Arc::new(Dataset::write(batches, &self.uri, params).await?); - Ok(()) - } - pub fn query(&self) -> Query { - Query::new(self.dataset.clone(), None) - } - - /// Creates a new Query object that can be executed. - /// - /// # Arguments - /// - /// * `query_vector` The vector used for this query. - /// - /// # Returns - /// * A [Query] object. - pub fn search>(&self, query_vector: Option) -> Query { - Query::new(self.dataset.clone(), query_vector.map(|q| q.into())) + Query::new(self.clone_inner_dataset().into()) } pub fn filter(&self, expr: String) -> Query { - Query::new(self.dataset.clone(), None).filter(Some(expr)) + Query::new(self.clone_inner_dataset().into()).filter(Some(expr)) } /// Returns the number of rows in this Table - pub async fn count_rows(&self) -> Result { - Ok(self.dataset.count_rows().await?) - } /// Merge new data into this table. 
pub async fn merge( @@ -342,26 +374,14 @@ impl Table { left_on: &str, right_on: &str, ) -> Result<()> { - let mut dataset = self.dataset.as_ref().clone(); + let mut dataset = self.clone_inner_dataset(); dataset.merge(batches, left_on, right_on).await?; - self.dataset = Arc::new(dataset); + self.dataset = Arc::new(Mutex::new(dataset)); Ok(()) } - /// Delete rows from the table - pub async fn delete(&mut self, predicate: &str) -> Result<()> { - let mut dataset = self.dataset.as_ref().clone(); - dataset.delete(predicate).await?; - self.dataset = Arc::new(dataset); - Ok(()) - } - - pub async fn update( - &mut self, - predicate: Option<&str>, - updates: Vec<(&str, &str)>, - ) -> Result<()> { - let mut builder = UpdateBuilder::new(self.dataset.clone()); + pub async fn update(&self, predicate: Option<&str>, updates: Vec<(&str, &str)>) -> Result<()> { + let mut builder = UpdateBuilder::new(self.clone_inner_dataset().into()); if let Some(predicate) = predicate { builder = builder.update_where(predicate)?; } @@ -371,9 +391,8 @@ impl Table { } let operation = builder.build()?; - let new_ds = operation.execute().await?; - self.dataset = new_ds; - + let ds = operation.execute().await?; + self.reset_dataset(ds.as_ref().clone()); Ok(()) } @@ -393,8 +412,8 @@ impl Table { older_than: Duration, delete_unverified: Option, ) -> Result { - Ok(self - .dataset + let dataset = self.clone_inner_dataset(); + Ok(dataset .cleanup_old_versions(older_than, delete_unverified) .await?) } @@ -406,26 +425,28 @@ impl Table { /// /// This calls into [lance::dataset::optimize::compact_files]. pub async fn compact_files( - &mut self, + &self, options: CompactionOptions, remap_options: Option>, ) -> Result { - let mut dataset = self.dataset.as_ref().clone(); + let mut dataset = self.clone_inner_dataset(); let metrics = compact_files(&mut dataset, options, remap_options).await?; - self.dataset = Arc::new(dataset); + self.reset_dataset(dataset); Ok(metrics) } pub fn count_fragments(&self) -> usize { - self.dataset.count_fragments() + self.dataset.lock().expect("lock poison").count_fragments() } pub async fn count_deleted_rows(&self) -> Result { - Ok(self.dataset.count_deleted_rows().await?) + let dataset = self.clone_inner_dataset(); + Ok(dataset.count_deleted_rows().await?) 
} pub async fn num_small_files(&self, max_rows_per_group: usize) -> usize { - self.dataset.num_small_files(max_rows_per_group).await + let dataset = self.clone_inner_dataset(); + dataset.num_small_files(max_rows_per_group).await } pub async fn count_indexed_rows(&self, index_uuid: &str) -> Result> { @@ -443,8 +464,8 @@ impl Table { } pub async fn load_indices(&self) -> Result> { - let (indices, mf) = - futures::try_join!(self.dataset.load_indices(), self.dataset.latest_manifest())?; + let dataset = self.clone_inner_dataset(); + let (indices, mf) = futures::try_join!(dataset.load_indices(), dataset.latest_manifest())?; Ok(indices .iter() .map(|i| VectorIndex::new_from_format(&mf, i)) @@ -460,10 +481,8 @@ impl Table { if index.is_none() { return Ok(None); } - let index_stats = self - .dataset - .index_statistics(&index.unwrap().index_name) - .await?; + let dataset = self.clone_inner_dataset(); + let index_stats = dataset.index_statistics(&index.unwrap().index_name).await?; let index_stats: VectorIndexStatistics = serde_json::from_str(&index_stats).map_err(|e| Error::Lance { message: format!( @@ -474,6 +493,71 @@ impl Table { Ok(Some(index_stats)) } + + pub(crate) fn reset_dataset(&self, dataset: Dataset) { + *self.dataset.lock().expect("lock poison") = dataset; + } +} + +#[async_trait::async_trait] +impl Table for NativeTable { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_native(&self) -> Option<&NativeTable> { + Some(self) + } + + fn name(&self) -> &str { + self.name.as_str() + } + + fn schema(&self) -> SchemaRef { + let lance_schema = { self.dataset.lock().expect("lock poison").schema().clone() }; + Arc::new(Schema::from(&lance_schema)) + } + + async fn count_rows(&self) -> Result { + let dataset = { self.dataset.lock().expect("lock poison").clone() }; + Ok(dataset.count_rows().await?) 
+ } + + async fn add( + &self, + batches: Box, + params: Option, + ) -> Result<()> { + let params = Some(params.unwrap_or(WriteParams { + mode: WriteMode::Append, + ..WriteParams::default() + })); + + // patch the params if we have a write store wrapper + let params = match self.store_wrapper.clone() { + Some(wrapper) => params.patch_with_store_wrapper(wrapper)?, + None => params, + }; + + self.reset_dataset(Dataset::write(batches, &self.uri, params).await?); + Ok(()) + } + + fn create_index(&self, columns: &[&str]) -> IndexBuilder { + IndexBuilder::new(Arc::new(self.clone()), columns) + } + + fn query(&self) -> Query { + Query::new(Arc::new(self.dataset.lock().expect("lock poison").clone())) + } + + /// Delete rows from the table + async fn delete(&self, predicate: &str) -> Result<()> { + let mut dataset = self.clone_inner_dataset(); + dataset.delete(predicate).await?; + self.reset_dataset(dataset); + Ok(()) + } } #[cfg(test)] @@ -491,14 +575,11 @@ mod tests { use arrow_schema::{DataType, Field, Schema, TimeUnit}; use futures::TryStreamExt; use lance::dataset::{Dataset, WriteMode}; - use lance::index::vector::pq::PQBuildParams; use lance::io::{ObjectStoreParams, WrappingObjectStore}; - use lance_index::vector::ivf::IvfBuildParams; use rand::Rng; use tempfile::tempdir; use super::*; - use crate::index::vector::IvfPQIndexBuilder; #[tokio::test] async fn test_open() { @@ -510,7 +591,9 @@ mod tests { .await .unwrap(); - let table = Table::open(dataset_path.to_str().unwrap()).await.unwrap(); + let table = NativeTable::open(dataset_path.to_str().unwrap()) + .await + .unwrap(); assert_eq!(table.name, "test") } @@ -519,7 +602,7 @@ mod tests { async fn test_open_not_found() { let tmp_dir = tempdir().unwrap(); let uri = tmp_dir.path().to_str().unwrap(); - let table = Table::open(uri).await; + let table = NativeTable::open(uri).await; assert!(matches!(table.unwrap_err(), Error::TableNotFound { .. })); } @@ -539,12 +622,12 @@ mod tests { let batches = make_test_batches(); let _ = batches.schema().clone(); - Table::create(&uri, "test", batches, None, None) + NativeTable::create(&uri, "test", batches, None, None) .await .unwrap(); let batches = make_test_batches(); - let result = Table::create(&uri, "test", batches, None, None).await; + let result = NativeTable::create(&uri, "test", batches, None, None).await; assert!(matches!( result.unwrap_err(), Error::TableAlreadyExists { .. 
} @@ -558,7 +641,7 @@ mod tests { let batches = make_test_batches(); let schema = batches.schema().clone(); - let mut table = Table::create(&uri, "test", batches, None, None) + let table = NativeTable::create(&uri, "test", batches, None, None) .await .unwrap(); assert_eq!(table.count_rows().await.unwrap(), 10); @@ -574,7 +657,7 @@ mod tests { schema.clone(), ); - table.add(new_batches, None).await.unwrap(); + table.add(Box::new(new_batches), None).await.unwrap(); assert_eq!(table.count_rows().await.unwrap(), 20); assert_eq!(table.name, "test"); } @@ -586,7 +669,7 @@ mod tests { let batches = make_test_batches(); let schema = batches.schema().clone(); - let mut table = Table::create(uri, "test", batches, None, None) + let table = NativeTable::create(uri, "test", batches, None, None) .await .unwrap(); assert_eq!(table.count_rows().await.unwrap(), 10); @@ -607,7 +690,7 @@ mod tests { ..Default::default() }; - table.add(new_batches, Some(param)).await.unwrap(); + table.add(Box::new(new_batches), Some(param)).await.unwrap(); assert_eq!(table.count_rows().await.unwrap(), 10); assert_eq!(table.name, "test"); } @@ -640,7 +723,7 @@ mod tests { ); Dataset::write(record_batch_iter, uri, None).await.unwrap(); - let mut table = Table::open(uri).await.unwrap(); + let table = NativeTable::open(uri).await.unwrap(); table .update(Some("id > 5"), vec![("name", "'foo'")]) @@ -772,7 +855,7 @@ mod tests { ); Dataset::write(record_batch_iter, uri, None).await.unwrap(); - let mut table = Table::open(uri).await.unwrap(); + let table = NativeTable::open(uri).await.unwrap(); // check it can do update for each type let updates: Vec<(&str, &str)> = vec![ @@ -889,11 +972,10 @@ mod tests { .await .unwrap(); - let table = Table::open(uri).await.unwrap(); + let table = NativeTable::open(uri).await.unwrap(); - let vector = Float32Array::from_iter_values([0.1, 0.2]); - let query = table.search(Some(vector.clone())); - assert_eq!(vector, query.query_vector.unwrap()); + let query = table.search(&[0.1, 0.2]); + assert_eq!(&[0.1, 0.2], query.query_vector.unwrap().values()); } #[derive(Default, Debug)] @@ -937,7 +1019,7 @@ mod tests { ..Default::default() }; assert!(!wrapper.called()); - let _ = Table::open_with_params(uri, "test", None, param) + let _ = NativeTable::open_with_params(uri, "test", None, param) .await .unwrap(); assert!(wrapper.called()); @@ -991,23 +1073,23 @@ mod tests { schema, ); - let mut table = Table::create(uri, "test", batches, None, None) + let table = NativeTable::create(uri, "test", batches, None, None) .await .unwrap(); - let mut i = IvfPQIndexBuilder::new(); assert_eq!(table.count_indexed_rows("my_index").await.unwrap(), None); assert_eq!(table.count_unindexed_rows("my_index").await.unwrap(), None); - let index_builder = i - .column("embeddings".to_string()) - .index_name("my_index".to_string()) - .ivf_params(IvfBuildParams::new(256)) - .pq_params(PQBuildParams::default()); + table + .create_index(&["embeddings"]) + .ivf_pq() + .name("my_index") + .num_partitions(256) + .build() + .await + .unwrap(); - table.create_index(index_builder).await.unwrap(); - - assert_eq!(table.dataset.load_indices().await.unwrap().len(), 1); + assert_eq!(table.load_indices().await.unwrap().len(), 1); assert_eq!(table.count_rows().await.unwrap(), 512); assert_eq!(table.name, "test");
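
Taken together, the pieces above compose as follows from application code. When
`num_partitions` or `num_sub_vectors` are omitted, `build()` falls back to the
heuristics in index.rs: roughly the square root of the row count for partitions
and dim/16 sub-vectors, so a 128-dimensional table of one million rows would get
1000 partitions and 8 sub-vectors. A minimal end-to-end sketch based on the APIs
in this patch (table and column names are illustrative, and
`vectordb::index::MetricType` is assumed to be the re-export added in index.rs):

    use vectordb::connection::{Connection, Database};
    use vectordb::index::MetricType;

    async fn index_and_search() -> vectordb::Result<()> {
        let db = Database::connect("data/sample-lancedb").await?;
        let table = db.open_table("my_table").await?;

        // `create_index` now returns an IndexBuilder: the index type, metric
        // and IVF/PQ parameters are chained instead of being bundled into a
        // params struct up front.
        table
            .create_index(&["vector"])
            .ivf_pq()
            .metric_type(MetricType::L2)
            .num_partitions(256)
            .build()
            .await?;

        // `search` is now a thin wrapper over `query().query_vector(..)`.
        let _stream = table.search(&[0.1; 128]).limit(10).execute().await?;
        Ok(())
    }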