diff --git a/Cargo.toml b/Cargo.toml
index 2a0c973d..8448fd86 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -7,6 +7,7 @@ resolver = "2"
 [workspace.dependencies]
 lance = { "version" = "=0.8.1", "features" = ["dynamodb"] }
 lance-linalg = { "version" = "=0.8.1" }
+lance-testing = { "version" = "=0.8.1" }
 # Note that this one does not include pyarrow
 arrow = { version = "43.0.0", optional = false }
 arrow-array = "43.0"
diff --git a/node/src/integration_test/test.ts b/node/src/integration_test/test.ts
index f86a5070..4828213b 100644
--- a/node/src/integration_test/test.ts
+++ b/node/src/integration_test/test.ts
@@ -18,6 +18,9 @@ import * as chaiAsPromised from 'chai-as-promised'
 import { v4 as uuidv4 } from 'uuid'
 
 import * as lancedb from '../index'
+import { tmpdir } from 'os'
+import * as fs from 'fs'
+import * as path from 'path'
 
 const assert = chai.assert
 chai.use(chaiAsPromised)
@@ -41,3 +44,130 @@ describe('LanceDB AWS Integration test', function () {
     assert.equal(await table.countRows(), 6)
   })
 })
+
+describe('LanceDB Mirrored Store Integration test', function () {
+  it('s3://...?mirroredStore=... param is processed correctly', async function () {
+    this.timeout(600000)
+
+    const dir = tmpdir()
+    console.log(dir)
+    const conn = await lancedb.connect(`s3://lancedb-integtest?mirroredStore=${dir}`)
+    const data = Array(200).fill({ vector: Array(128).fill(1.0), id: 0 })
+    data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 1 }))
+    data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 2 }))
+    data.push(...Array(200).fill({ vector: Array(128).fill(1.0), id: 3 }))
+
+    const tableName = uuidv4()
+
+    // try create table and check if it's mirrored
+    const t = await conn.createTable(tableName, data, { writeMode: lancedb.WriteMode.Overwrite })
+
+    const mirroredPath = path.join(dir, `${tableName}.lance`)
+    fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
+      if (err != null) throw err
+      // there should be two dirs
+      assert.equal(files.length, 2)
+      assert.isTrue(files[0].isDirectory())
+      assert.isTrue(files[1].isDirectory())
+
+      fs.readdir(path.join(mirroredPath, '_transactions'), { withFileTypes: true }, (err, files) => {
+        if (err != null) throw err
+        assert.equal(files.length, 1)
+        assert.isTrue(files[0].name.endsWith('.txn'))
+      })
+
+      fs.readdir(path.join(mirroredPath, 'data'), { withFileTypes: true }, (err, files) => {
+        if (err != null) throw err
+        assert.equal(files.length, 1)
+        assert.isTrue(files[0].name.endsWith('.lance'))
+      })
+    })
+
+    // try create index and check if it's mirrored
+    await t.createIndex({ column: 'vector', type: 'ivf_pq' })
+
+    fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
+      if (err != null) throw err
+      // there should be three dirs
+      assert.equal(files.length, 3)
+      assert.isTrue(files[0].isDirectory())
+      assert.isTrue(files[1].isDirectory())
+      assert.isTrue(files[2].isDirectory())
+
+      // Two TXs now
+      fs.readdir(path.join(mirroredPath, '_transactions'), { withFileTypes: true }, (err, files) => {
+        if (err != null) throw err
+        assert.equal(files.length, 2)
+        assert.isTrue(files[0].name.endsWith('.txn'))
+        assert.isTrue(files[1].name.endsWith('.txn'))
+      })
+
+      fs.readdir(path.join(mirroredPath, 'data'), { withFileTypes: true }, (err, files) => {
+        if (err != null) throw err
+        assert.equal(files.length, 1)
+        assert.isTrue(files[0].name.endsWith('.lance'))
+      })
+
+      fs.readdir(path.join(mirroredPath, '_indices'), { withFileTypes: true }, (err, files) => {
+        if (err != null) throw err
+        assert.equal(files.length, 1)
+        assert.isTrue(files[0].isDirectory())
+
+        fs.readdir(path.join(mirroredPath, '_indices', files[0].name), { withFileTypes: true }, (err, files) => {
+          if (err != null) throw err
+
+          assert.equal(files.length, 1)
+          assert.isTrue(files[0].isFile())
+          assert.isTrue(files[0].name.endsWith('.idx'))
+        })
+      })
+    })
+
+    // try delete and check if it's mirrored
+    await t.delete('id = 0')
+
+    fs.readdir(mirroredPath, { withFileTypes: true }, (err, files) => {
+      if (err != null) throw err
+      // there should be four dirs
+      assert.equal(files.length, 4)
+      assert.isTrue(files[0].isDirectory())
+      assert.isTrue(files[1].isDirectory())
+      assert.isTrue(files[2].isDirectory())
+      assert.isTrue(files[3].isDirectory())
+
+      // Three TXs now
+      fs.readdir(path.join(mirroredPath, '_transactions'), { withFileTypes: true }, (err, files) => {
+        if (err != null) throw err
+        assert.equal(files.length, 3)
+        assert.isTrue(files[0].name.endsWith('.txn'))
+        assert.isTrue(files[1].name.endsWith('.txn'))
+      })
+
+      fs.readdir(path.join(mirroredPath, 'data'), { withFileTypes: true }, (err, files) => {
+        if (err != null) throw err
+        assert.equal(files.length, 1)
+        assert.isTrue(files[0].name.endsWith('.lance'))
+      })
+
+      fs.readdir(path.join(mirroredPath, '_indices'), { withFileTypes: true }, (err, files) => {
+        if (err != null) throw err
+        assert.equal(files.length, 1)
+        assert.isTrue(files[0].isDirectory())
+
+        fs.readdir(path.join(mirroredPath, '_indices', files[0].name), { withFileTypes: true }, (err, files) => {
+          if (err != null) throw err
+
+          assert.equal(files.length, 1)
+          assert.isTrue(files[0].isFile())
+          assert.isTrue(files[0].name.endsWith('.idx'))
+        })
+      })
+
+      fs.readdir(path.join(mirroredPath, '_deletions'), { withFileTypes: true }, (err, files) => {
+        if (err != null) throw err
+        assert.equal(files.length, 1)
+        assert.isTrue(files[0].name.endsWith('.arrow'))
+      })
+    })
+  })
+})
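The mirrored table lands under `${tmpdir()}/${tableName}.lance`, with one subdirectory per artifact kind. For readers following along in Rust, a minimal sketch of the same layout check, assuming a local mirror directory and the `walkdir` dev-dependency this PR adds (`table_dir` is a hypothetical path):

```rust
use std::path::Path;
use walkdir::WalkDir;

/// Walk a mirrored table directory. After create/index/delete we expect
/// `_transactions/*.txn`, `data/*.lance`, `_indices/<uuid>/*.idx`, and
/// `_deletions/*.arrow`; manifests are intentionally absent.
fn dump_mirror(table_dir: &Path) {
    for entry in WalkDir::new(table_dir).into_iter().filter_map(|e| e.ok()) {
        println!("{}", entry.path().display());
    }
}
```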
diff --git a/rust/ffi/node/src/lib.rs b/rust/ffi/node/src/lib.rs
index fcb8fe76..b0139b34 100644
--- a/rust/ffi/node/src/lib.rs
+++ b/rust/ffi/node/src/lib.rs
@@ -195,7 +195,7 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
     let (deferred, promise) = cx.promise();
 
     rt.spawn(async move {
-        let table_rst = database.open_table_with_params(&table_name, &params).await;
+        let table_rst = database.open_table_with_params(&table_name, params).await;
 
         deferred.settle_with(&channel, move |mut cx| {
             let js_table = JsTable::from(table_rst.or_throw(&mut cx)?);
diff --git a/rust/vectordb/Cargo.toml b/rust/vectordb/Cargo.toml
index 7b95dbee..1415dc7a 100644
--- a/rust/vectordb/Cargo.toml
+++ b/rust/vectordb/Cargo.toml
@@ -21,11 +21,16 @@ snafu = { workspace = true }
 half = { workspace = true }
 lance = { workspace = true }
 lance-linalg = { workspace = true }
+lance-testing = { workspace = true }
 tokio = { version = "1.23", features = ["rt-multi-thread"] }
 log = { workspace = true }
+async-trait = "0"
+bytes = "1"
+futures = "0"
 num-traits = "0"
 url = { workspace = true }
 
 [dev-dependencies]
 tempfile = "3.5.0"
 rand = { version = "0.8.3", features = ["small_rng"] }
+walkdir = "2"
\ No newline at end of file
diff --git a/rust/vectordb/src/database.rs b/rust/vectordb/src/database.rs
index 3bc57d1d..c53e54cd 100644
--- a/rust/vectordb/src/database.rs
+++ b/rust/vectordb/src/database.rs
@@ -14,13 +14,16 @@
 
 use std::fs::create_dir_all;
 use std::path::Path;
+use std::sync::Arc;
 
 use arrow_array::RecordBatchReader;
 use lance::dataset::WriteParams;
-use lance::io::object_store::ObjectStore;
+use lance::io::object_store::{ObjectStore, WrappingObjectStore};
+use object_store::local::LocalFileSystem;
 use snafu::prelude::*;
 
-use crate::error::{CreateDirSnafu, InvalidTableNameSnafu, Result};
+use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
+use crate::io::object_store::MirroringObjectStoreWrapper;
 use crate::table::{ReadParams, Table};
 
 pub const LANCE_FILE_EXTENSION: &str = "lance";
@@ -31,10 +34,14 @@ pub struct Database {
 
     pub(crate) uri: String,
     pub(crate) base_path: object_store::path::Path,
+
+    // the object store wrapper to use on write path
+    pub(crate) store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
 }
 
 const LANCE_EXTENSION: &str = "lance";
 const ENGINE: &str = "engine";
+const MIRRORED_STORE: &str = "mirroredStore";
 
 /// A connection to LanceDB
 impl Database {
@@ -55,6 +62,7 @@
             Ok(mut url) => {
                 // iter thru the query params and extract the commit store param
                 let mut engine = None;
+                let mut mirrored_store = None;
                 let mut filtered_querys = vec![];
 
                 // WARNING: specifying engine is NOT a publicly supported feature in lancedb yet
@@ -62,6 +70,13 @@
                 for (key, value) in url.query_pairs() {
                     if key == ENGINE {
                         engine = Some(value.to_string());
+                    } else if key == MIRRORED_STORE {
+                        if cfg!(windows) {
+                            return Err(Error::Lance {
+                                message: "mirrored store is not supported on windows".into(),
+                            });
+                        }
+                        mirrored_store = Some(value.to_string());
                     } else {
                         // to owned so we can modify the url
                         filtered_querys.push((key.to_string(), value.to_string()));
@@ -96,11 +111,21 @@
                     Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
                 }
 
+                let write_store_wrapper = match mirrored_store {
+                    Some(path) => {
+                        let mirrored_store = Arc::new(LocalFileSystem::new_with_prefix(path)?);
+                        let wrapper = MirroringObjectStoreWrapper::new(mirrored_store);
+                        Some(Arc::new(wrapper) as Arc<dyn WrappingObjectStore>)
+                    }
+                    None => None,
+                };
+
                 Ok(Database {
                     uri: table_base_uri,
                     query_string,
                     base_path,
                     object_store,
+                    store_wrapper: write_store_wrapper,
                 })
             }
             Err(_) => Self::open_path(uri).await,
@@ -117,6 +142,7 @@
             query_string: None,
             base_path,
             object_store,
+            store_wrapper: None,
         })
     }
 
@@ -166,7 +192,15 @@
         params: Option<WriteParams>,
     ) -> Result<Table> {
         let table_uri = self.table_uri(name)?;
-        Table::create(&table_uri, name, batches, params).await
+
+        Table::create(
+            &table_uri,
+            name,
+            batches,
+            self.store_wrapper.clone(),
+            params,
+        )
+        .await
     }
 
     /// Open a table in the database.
@@ -178,7 +212,7 @@
     ///
     /// * A [Table] object.
     pub async fn open_table(&self, name: &str) -> Result<Table> {
-        self.open_table_with_params(name, &ReadParams::default())
+        self.open_table_with_params(name, ReadParams::default())
             .await
     }
 
@@ -191,9 +225,9 @@
     /// # Returns
     ///
     /// * A [Table] object.
-    pub async fn open_table_with_params(&self, name: &str, params: &ReadParams) -> Result<Table> {
+    pub async fn open_table_with_params(&self, name: &str, params: ReadParams) -> Result<Table> {
         let table_uri = self.table_uri(name)?;
-        Table::open_with_params(&table_uri, name, params).await
+        Table::open_with_params(&table_uri, name, self.store_wrapper.clone(), params).await
     }
 
     /// Drop a table in the database.
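The `mirroredStore` query parameter above is the only user-facing switch for this feature. A minimal sketch of opting in from Rust, where the bucket and local NVMe path are hypothetical placeholders:

```rust
use vectordb::Database;

// Data, index, and deletion files are mirrored to the local path;
// manifests and all reads go to the primary (S3) store only.
async fn connect_mirrored() -> vectordb::error::Result<Database> {
    Database::connect("s3://my-bucket/db?mirroredStore=/mnt/nvme/lancedb-mirror").await
}
```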
diff --git a/rust/vectordb/src/io.rs b/rust/vectordb/src/io.rs
new file mode 100644
index 00000000..ef12e4fd
--- /dev/null
+++ b/rust/vectordb/src/io.rs
@@ -0,0 +1 @@
+pub mod object_store;
diff --git a/rust/vectordb/src/io/object_store.rs b/rust/vectordb/src/io/object_store.rs
new file mode 100644
index 00000000..6258cc04
--- /dev/null
+++ b/rust/vectordb/src/io/object_store.rs
@@ -0,0 +1,397 @@
+// Copyright 2023 Lance Developers.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//! A mirroring object store that mirrors writes to a secondary object store
+
+
+use std::{
+    fmt::Formatter,
+    pin::Pin,
+    sync::Arc,
+    task::{Context, Poll},
+};
+
+use bytes::Bytes;
+use futures::{stream::BoxStream, FutureExt, StreamExt};
+use lance::io::object_store::WrappingObjectStore;
+use object_store::{
+    path::Path, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Result,
+};
+
+use async_trait::async_trait;
+use tokio::{
+    io::{AsyncWrite, AsyncWriteExt},
+    task::JoinHandle,
+};
+
+#[derive(Debug)]
+struct MirroringObjectStore {
+    primary: Arc<dyn ObjectStore>,
+    secondary: Arc<dyn ObjectStore>,
+}
+
+impl std::fmt::Display for MirroringObjectStore {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        writeln!(f, "MirroringObjectStore")?;
+        writeln!(f, "primary:")?;
+        self.primary.fmt(f)?;
+        writeln!(f, "secondary:")?;
+        self.secondary.fmt(f)?;
+        Ok(())
+    }
+}
+
+trait PrimaryOnly {
+    fn primary_only(&self) -> bool;
+}
+
+impl PrimaryOnly for Path {
+    fn primary_only(&self) -> bool {
+        self.to_string().contains("manifest")
+    }
+}
+
+/// An object store that mirrors writes to a secondary object store first
+/// and then commits to the primary object store.
+///
+/// This is meant to mirror writes to a less-durable but lower-latency
+/// store. We have a primary store that is durable but slow, and a secondary
+/// store that is fast but not as durable.
+///
+/// Note: this object store does not mirror writes to *.manifest files
+#[async_trait]
+impl ObjectStore for MirroringObjectStore {
+    async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> {
+        if location.primary_only() {
+            self.primary.put(location, bytes).await
+        } else {
+            self.secondary.put(location, bytes.clone()).await?;
+            self.primary.put(location, bytes).await?;
+            Ok(())
+        }
+    }
+
+    async fn put_multipart(
+        &self,
+        location: &Path,
+    ) -> Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)> {
+        if location.primary_only() {
+            return self.primary.put_multipart(location).await;
+        }
+
+        let (id, stream) = self.secondary.put_multipart(location).await?;
+
+        let mirroring_upload = MirroringUpload::new(
+            Pin::new(stream),
+            self.primary.clone(),
+            self.secondary.clone(),
+            location.clone(),
+        );
+
+        Ok((id, Box::new(mirroring_upload)))
+    }
+
+    async fn abort_multipart(&self, location: &Path, multipart_id: &MultipartId) -> Result<()> {
+        if location.primary_only() {
+            return self.primary.abort_multipart(location, multipart_id).await;
+        }
+
+        self.secondary.abort_multipart(location, multipart_id).await
+    }
+
+    // Reads are routed to primary only
+    async fn get_opts(&self, location: &Path, options: GetOptions) -> Result<GetResult> {
+        self.primary.get_opts(location, options).await
+    }
+
+    async fn head(&self, location: &Path) -> Result<ObjectMeta> {
+        self.primary.head(location).await
+    }
+
+    // garbage collection on secondary will happen async from other means
+    async fn delete(&self, location: &Path) -> Result<()> {
+        self.primary.delete(location).await
+    }
+
+    async fn list(&self, prefix: Option<&Path>) -> Result<BoxStream<'_, Result<ObjectMeta>>> {
+        self.primary.list(prefix).await
+    }
+
+    async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result<ListResult> {
+        self.primary.list_with_delimiter(prefix).await
+    }
+
+    async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
+        if from.primary_only() {
+            self.primary.copy(from, to).await
+        } else {
+            self.secondary.copy(from, to).await?;
+            self.primary.copy(from, to).await?;
+            Ok(())
+        }
+    }
+
+    async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> Result<()> {
+        self.primary.copy_if_not_exists(from, to).await
+    }
+}
+
+struct MirroringUpload {
+    secondary_stream: Pin<Box<dyn AsyncWrite + Unpin + Send>>,
+
+    primary_store: Arc<dyn ObjectStore>,
+    secondary_store: Arc<dyn ObjectStore>,
+    location: Path,
+
+    state: MirroringUploadShutdown,
+}
+
+// The state goes from
+// None
+// -> (secondary)ShuttingDown
+// -> (secondary)ShutdownDone
+// -> Uploading(to primary)
+// -> Completed
+#[derive(Debug)]
+enum MirroringUploadShutdown {
+    None,
+    ShuttingDown,
+    ShutdownDone,
+    Uploading(Pin<Box<JoinHandle<()>>>),
+    Completed,
+}
+
+impl MirroringUpload {
+    pub fn new(
+        secondary_stream: Pin<Box<dyn AsyncWrite + Unpin + Send>>,
+        primary_store: Arc<dyn ObjectStore>,
+        secondary_store: Arc<dyn ObjectStore>,
+        location: Path,
+    ) -> Self {
+        Self {
+            secondary_stream,
+            primary_store,
+            secondary_store,
+            location,
+            state: MirroringUploadShutdown::None,
+        }
+    }
+}
+
+impl AsyncWrite for MirroringUpload {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<std::io::Result<usize>> {
+        if !matches!(self.state, MirroringUploadShutdown::None) {
+            return Poll::Ready(Err(std::io::Error::new(
+                std::io::ErrorKind::Other,
+                "already shutdown",
+            )));
+        }
+        // Write to secondary first
+        let mut_self = self.get_mut();
+        mut_self.secondary_stream.as_mut().poll_write(cx, buf)
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
+        if !matches!(self.state, MirroringUploadShutdown::None) {
+            return Poll::Ready(Err(std::io::Error::new(
+                std::io::ErrorKind::Other,
+                "already shutdown",
+            )));
+        }
+
+        let mut_self = self.get_mut();
+        mut_self.secondary_stream.as_mut().poll_flush(cx)
+    }
+
+    fn poll_shutdown(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        let mut_self = self.get_mut();
+
+        loop {
+            // try to shutdown secondary first
+            match &mut mut_self.state {
+                MirroringUploadShutdown::None | MirroringUploadShutdown::ShuttingDown => {
+                    match mut_self.secondary_stream.as_mut().poll_shutdown(cx) {
+                        Poll::Ready(Ok(())) => {
+                            mut_self.state = MirroringUploadShutdown::ShutdownDone;
+                            // don't return, no waker is setup
+                        }
+                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
+                        Poll::Pending => {
+                            mut_self.state = MirroringUploadShutdown::ShuttingDown;
+                            return Poll::Pending;
+                        }
+                    }
+                }
+                MirroringUploadShutdown::ShutdownDone => {
+                    let primary_store = mut_self.primary_store.clone();
+                    let secondary_store = mut_self.secondary_store.clone();
+                    let location = mut_self.location.clone();
+
+                    let upload_future =
+                        Box::pin(tokio::runtime::Handle::current().spawn(async move {
+                            let mut source =
+                                secondary_store.get(&location).await.unwrap().into_stream();
+                            let upload_stream = primary_store.put_multipart(&location).await;
+                            let (_, mut stream) = upload_stream.unwrap();
+
+                            while let Some(buf) = source.next().await {
+                                let buf = buf.unwrap();
+                                stream.write_all(&buf).await.unwrap();
+                            }
+
+                            stream.shutdown().await.unwrap();
+                        }));
+                    mut_self.state = MirroringUploadShutdown::Uploading(upload_future);
+                    // don't return, no waker is setup
+                }
+                MirroringUploadShutdown::Uploading(ref mut join_handle) => {
+                    match join_handle.poll_unpin(cx) {
+                        Poll::Ready(Ok(())) => {
+                            mut_self.state = MirroringUploadShutdown::Completed;
+                            return Poll::Ready(Ok(()));
+                        }
+                        Poll::Ready(Err(e)) => {
+                            mut_self.state = MirroringUploadShutdown::Completed;
+                            return Poll::Ready(Err(e.into()));
+                        }
+                        Poll::Pending => {
+                            return Poll::Pending;
+                        }
+                    }
+                }
+                MirroringUploadShutdown::Completed => {
+                    return Poll::Ready(Err(std::io::Error::new(
+                        std::io::ErrorKind::Other,
+                        "shutdown already completed",
+                    )))
+                }
+            }
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct MirroringObjectStoreWrapper {
+    secondary: Arc<dyn ObjectStore>,
+}
+
+impl MirroringObjectStoreWrapper {
+    pub fn new(secondary: Arc<dyn ObjectStore>) -> Self {
+        Self { secondary }
+    }
+}
+
+impl WrappingObjectStore for MirroringObjectStoreWrapper {
+    fn wrap(&self, primary: Arc<dyn ObjectStore>) -> Arc<dyn ObjectStore> {
+        Arc::new(MirroringObjectStore {
+            primary,
+            secondary: self.secondary.clone(),
+        })
+    }
+}
+
+// windows pathing can't be simply concatenated
+#[cfg(all(test, not(windows)))]
+mod test {
+    use super::*;
+
+    use crate::Database;
+    use arrow_array::PrimitiveArray;
+    use futures::TryStreamExt;
+    use lance::{dataset::WriteParams, io::object_store::ObjectStoreParams};
+    use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector};
+    use object_store::local::LocalFileSystem;
+    use tempfile;
+
+    #[tokio::test]
+    async fn test_e2e() {
+        let dir1 = tempfile::tempdir().unwrap().into_path();
+        let dir2 = tempfile::tempdir().unwrap().into_path();
+
+        let secondary_store = LocalFileSystem::new_with_prefix(dir2.to_str().unwrap()).unwrap();
+        let object_store_wrapper = Arc::new(MirroringObjectStoreWrapper {
+            secondary: Arc::new(secondary_store),
+        });
+
+        let db = Database::connect(dir1.to_str().unwrap()).await.unwrap();
+
+        let mut param = WriteParams::default();
+        let mut store_params = ObjectStoreParams::default();
+        store_params.object_store_wrapper = Some(object_store_wrapper);
+        param.store_params = Some(store_params);
+
+        let mut datagen = BatchGenerator::new();
+        datagen = datagen.col(Box::new(IncrementingInt32::default()));
+        datagen = datagen.col(Box::new(RandomVector::default().named("vector".into())));
+
+        let res = db
+            .create_table("test", datagen.batch(100), Some(param.clone()))
+            .await;
+
+        // leave this here for easy debugging
+        let t = res.unwrap();
+
+        assert_eq!(t.count_rows().await.unwrap(), 100);
+
+        let q = t
+            .search(PrimitiveArray::from_iter_values(vec![0.1, 0.1, 0.1, 0.1]))
+            .limit(10)
+            .execute()
+            .await
+            .unwrap();
+
+        let batches = q.try_collect::<Vec<_>>().await.unwrap();
+        assert_eq!(batches.len(), 1);
+        assert_eq!(batches[0].num_rows(), 10);
+
+        use walkdir::WalkDir;
+
+        let primary_location = dir1.join("test.lance").canonicalize().unwrap();
+        let secondary_location = dir2.join(primary_location.strip_prefix("/").unwrap());
+
+        let mut primary_iter = WalkDir::new(&primary_location).into_iter();
+        let mut secondary_iter = WalkDir::new(&secondary_location).into_iter();
+
+        let mut primary_elem = primary_iter.next();
+        let mut secondary_elem = secondary_iter.next();
+
+        loop {
+            if primary_elem.is_none() && secondary_elem.is_none() {
+                break;
+            }
+            // primary has more data than secondary, so it should not run out before secondary
+            let primary_f = primary_elem.unwrap().unwrap();
+            // hit manifest, skip; _versions contains all the manifests and should not exist on secondary
+            let primary_raw_path = primary_f.file_name().to_str().unwrap();
+            if primary_raw_path.contains("manifest") || primary_raw_path.contains("_versions") {
+                primary_elem = primary_iter.next();
+                continue;
+            }
+            let secondary_f = secondary_elem.unwrap().unwrap();
+            assert_eq!(
+                primary_f.path().strip_prefix(&primary_location),
+                secondary_f.path().strip_prefix(&secondary_location)
+            );
+
+            primary_elem = primary_iter.next();
+            secondary_elem = secondary_iter.next();
+        }
+    }
+}
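Besides the `mirroredStore` URI parameter, the wrapper can be wired in by hand through lance's `ObjectStoreParams`, which is exactly what `test_e2e` above does. A condensed sketch of that wiring, with a hypothetical mirror path:

```rust
use std::sync::Arc;

use lance::dataset::WriteParams;
use lance::io::object_store::{ObjectStoreParams, WrappingObjectStore};
use object_store::local::LocalFileSystem;
use vectordb::io::object_store::MirroringObjectStoreWrapper;

fn mirrored_write_params() -> WriteParams {
    // fast-but-less-durable secondary store; lance wraps the primary with it on write
    let secondary = Arc::new(LocalFileSystem::new_with_prefix("/mnt/nvme/mirror").unwrap());
    let wrapper: Arc<dyn WrappingObjectStore> =
        Arc::new(MirroringObjectStoreWrapper::new(secondary));

    let mut store_params = ObjectStoreParams::default();
    store_params.object_store_wrapper = Some(wrapper);

    let mut params = WriteParams::default();
    params.store_params = Some(store_params);
    params
}
```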
diff --git a/rust/vectordb/src/lib.rs b/rust/vectordb/src/lib.rs
index 46d1716e..3cb4c934 100644
--- a/rust/vectordb/src/lib.rs
+++ b/rust/vectordb/src/lib.rs
@@ -16,8 +16,10 @@ pub mod data;
 pub mod database;
 pub mod error;
 pub mod index;
+pub mod io;
 pub mod query;
 pub mod table;
+pub mod utils;
 
 pub use database::Database;
 pub use table::Table;
diff --git a/rust/vectordb/src/table.rs b/rust/vectordb/src/table.rs
index 252fc60a..77341630 100644
--- a/rust/vectordb/src/table.rs
+++ b/rust/vectordb/src/table.rs
@@ -18,11 +18,13 @@ use arrow_array::{Float32Array, RecordBatchReader};
 use arrow_schema::SchemaRef;
 use lance::dataset::{Dataset, WriteParams};
 use lance::index::IndexType;
+use lance::io::object_store::WrappingObjectStore;
 use std::path::Path;
 
 use crate::error::{Error, Result};
 use crate::index::vector::VectorIndexBuilder;
 use crate::query::Query;
+use crate::utils::{PatchReadParam, PatchWriteParam};
 use crate::WriteMode;
 
 pub use lance::dataset::ReadParams;
@@ -35,6 +37,9 @@ pub struct Table {
     name: String,
     uri: String,
     dataset: Arc<Dataset>,
+
+    // the object store wrapper to use on write path
+    store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
 }
 
 impl std::fmt::Display for Table {
@@ -56,12 +61,12 @@ impl Table {
     /// * A [Table] object.
     pub async fn open(uri: &str) -> Result<Self> {
         let name = Self::get_table_name(uri)?;
-        Self::open_with_params(uri, &name, &ReadParams::default()).await
+        Self::open_with_params(uri, &name, None, ReadParams::default()).await
     }
 
     /// Open an Table with a given name.
     pub async fn open_with_name(uri: &str, name: &str) -> Result<Self> {
-        Self::open_with_params(uri, name, &ReadParams::default()).await
+        Self::open_with_params(uri, name, None, ReadParams::default()).await
     }
 
     /// Opens an existing Table
@@ -75,8 +80,18 @@ impl Table {
     /// # Returns
     ///
     /// * A [Table] object.
-    pub async fn open_with_params(uri: &str, name: &str, params: &ReadParams) -> Result<Self> {
-        let dataset = Dataset::open_with_params(uri, params)
+    pub async fn open_with_params(
+        uri: &str,
+        name: &str,
+        write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
+        params: ReadParams,
+    ) -> Result<Self> {
+        // patch the params if we have a write store wrapper
+        let params = match write_store_wrapper.clone() {
+            Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
+            None => params,
+        };
+        let dataset = Dataset::open_with_params(uri, &params)
             .await
             .map_err(|e| match e {
                 lance::Error::DatasetNotFound { .. } => Error::TableNotFound {
@@ -90,6 +105,7 @@ impl Table {
             name: name.to_string(),
             uri: uri.to_string(),
             dataset: Arc::new(dataset),
+            store_wrapper: write_store_wrapper,
         })
     }
 
@@ -97,20 +113,26 @@ impl Table {
     ///
     pub async fn checkout(uri: &str, version: u64) -> Result<Self> {
         let name = Self::get_table_name(uri)?;
-        Self::checkout_with_params(uri, &name, version, &ReadParams::default()).await
+        Self::checkout_with_params(uri, &name, version, None, ReadParams::default()).await
     }
 
     pub async fn checkout_with_name(uri: &str, name: &str, version: u64) -> Result<Self> {
-        Self::checkout_with_params(uri, name, version, &ReadParams::default()).await
+        Self::checkout_with_params(uri, name, version, None, ReadParams::default()).await
    }
 
     pub async fn checkout_with_params(
         uri: &str,
         name: &str,
         version: u64,
-        params: &ReadParams,
+        write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
+        params: ReadParams,
     ) -> Result<Self> {
-        let dataset = Dataset::checkout_with_params(uri, version, params)
+        // patch the params if we have a write store wrapper
+        let params = match write_store_wrapper.clone() {
+            Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
+            None => params,
+        };
+        let dataset = Dataset::checkout_with_params(uri, version, &params)
             .await
             .map_err(|e| match e {
                 lance::Error::DatasetNotFound { .. } => Error::TableNotFound {
@@ -124,6 +146,7 @@ impl Table {
             name: name.to_string(),
             uri: uri.to_string(),
             dataset: Arc::new(dataset),
+            store_wrapper: write_store_wrapper,
         })
     }
 
@@ -157,8 +180,15 @@ impl Table {
         uri: &str,
         name: &str,
         batches: impl RecordBatchReader + Send + 'static,
+        write_store_wrapper: Option<Arc<dyn WrappingObjectStore>>,
        params: Option<WriteParams>,
     ) -> Result<Self> {
+        // patch the params if we have a write store wrapper
+        let params = match write_store_wrapper.clone() {
+            Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
+            None => params,
+        };
+
         let dataset = Dataset::write(batches, uri, params)
             .await
             .map_err(|e| match e {
@@ -173,6 +203,7 @@ impl Table {
             name: name.to_string(),
             uri: uri.to_string(),
             dataset: Arc::new(dataset),
+            store_wrapper: write_store_wrapper,
         })
     }
 
@@ -191,7 +222,8 @@ impl Table {
         use lance::index::DatasetIndexExt;
 
         let mut dataset = self.dataset.as_ref().clone();
-        dataset.create_index(
+        dataset
+            .create_index(
                 &[index_builder
                     .get_column()
                     .unwrap_or(VECTOR_COLUMN_NAME.to_string())
@@ -220,12 +252,18 @@ impl Table {
         batches: impl RecordBatchReader + Send + 'static,
         params: Option<WriteParams>,
     ) -> Result<()> {
-        let params = params.unwrap_or(WriteParams {
+        let params = Some(params.unwrap_or(WriteParams {
             mode: WriteMode::Append,
             ..WriteParams::default()
-        });
+        }));
 
-        self.dataset = Arc::new(Dataset::write(batches, &self.uri, Some(params)).await?);
+        // patch the params if we have a write store wrapper
+        let params = match self.store_wrapper.clone() {
+            Some(wrapper) => params.patch_with_store_wrapper(wrapper)?,
+            None => params,
+        };
+
+        self.dataset = Arc::new(Dataset::write(batches, &self.uri, params).await?);
         Ok(())
     }
 
@@ -329,10 +367,12 @@ mod tests {
         let batches = make_test_batches();
         let _ = batches.schema().clone();
-        Table::create(&uri, "test", batches, None).await.unwrap();
+        Table::create(&uri, "test", batches, None, None)
+            .await
+            .unwrap();
 
         let batches = make_test_batches();
-        let result = Table::create(&uri, "test", batches, None).await;
+        let result = Table::create(&uri, "test", batches, None, None).await;
         assert!(matches!(
             result.unwrap_err(),
             Error::TableAlreadyExists { .. }
@@ -346,7 +386,9 @@
         let batches = make_test_batches();
         let schema = batches.schema().clone();
-        let mut table = Table::create(&uri, "test", batches, None).await.unwrap();
+        let mut table = Table::create(&uri, "test", batches, None, None)
+            .await
+            .unwrap();
         assert_eq!(table.count_rows().await.unwrap(), 10);
 
         let new_batches = RecordBatchIterator::new(
@@ -372,7 +414,9 @@
         let batches = make_test_batches();
         let schema = batches.schema().clone();
-        let mut table = Table::create(uri, "test", batches, None).await.unwrap();
+        let mut table = Table::create(uri, "test", batches, None, None)
+            .await
+            .unwrap();
         assert_eq!(table.count_rows().await.unwrap(), 10);
 
         let new_batches = RecordBatchIterator::new(
@@ -455,7 +499,9 @@
             ..Default::default()
         };
         assert!(!wrapper.called());
-        let _ = Table::open_with_params(uri, "test", &param).await.unwrap();
+        let _ = Table::open_with_params(uri, "test", None, param)
+            .await
+            .unwrap();
         assert!(wrapper.called());
     }
 
@@ -507,7 +553,9 @@
             schema,
         );
 
-        let mut table = Table::create(uri, "test", batches, None).await.unwrap();
+        let mut table = Table::create(uri, "test", batches, None, None)
+            .await
+            .unwrap();
 
         let mut i = IvfPQIndexBuilder::new();
         let index_builder = i
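Note the call-site change that follows from taking `ReadParams` by value: the table layer moves and patches the params before handing them to lance, so borrowing no longer buys anything. A small sketch of the new signature in use (the table name is hypothetical):

```rust
use vectordb::table::{ReadParams, Table};

async fn open(uri: &str) -> vectordb::error::Result<Table> {
    // was: Table::open_with_params(uri, "t1", &ReadParams::default()).await
    Table::open_with_params(uri, "t1", None, ReadParams::default()).await
}
```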
diff --git a/rust/vectordb/src/utils.rs b/rust/vectordb/src/utils.rs
new file mode 100644
index 00000000..c0afa5d1
--- /dev/null
+++ b/rust/vectordb/src/utils.rs
@@ -0,0 +1,67 @@
+use std::sync::Arc;
+
+use lance::{
+    dataset::{ReadParams, WriteParams},
+    io::object_store::{ObjectStoreParams, WrappingObjectStore},
+};
+
+use crate::error::{Error, Result};
+
+pub trait PatchStoreParam {
+    fn patch_with_store_wrapper(
+        self,
+        wrapper: Arc<dyn WrappingObjectStore>,
+    ) -> Result<Option<ObjectStoreParams>>;
+}
+
+impl PatchStoreParam for Option<ObjectStoreParams> {
+    fn patch_with_store_wrapper(
+        self,
+        wrapper: Arc<dyn WrappingObjectStore>,
+    ) -> Result<Option<ObjectStoreParams>> {
+        let mut params = self.unwrap_or_default();
+        if params.object_store_wrapper.is_some() {
+            return Err(Error::Lance {
+                message: "can not patch param because object store is already set".into(),
+            });
+        }
+        params.object_store_wrapper = Some(wrapper);
+
+        Ok(Some(params))
+    }
+}
+
+pub trait PatchWriteParam {
+    fn patch_with_store_wrapper(
+        self,
+        wrapper: Arc<dyn WrappingObjectStore>,
+    ) -> Result<Option<WriteParams>>;
+}
+
+impl PatchWriteParam for Option<WriteParams> {
+    fn patch_with_store_wrapper(
+        self,
+        wrapper: Arc<dyn WrappingObjectStore>,
+    ) -> Result<Option<WriteParams>> {
+        let mut params = self.unwrap_or_default();
+        params.store_params = params.store_params.patch_with_store_wrapper(wrapper)?;
+        Ok(Some(params))
+    }
+}
+
+// NOTE: we have some API inconsistency here.
+// WriteParams is found in the form of Option<WriteParams> and ReadParams is found in the form of ReadParams
+
+pub trait PatchReadParam {
+    fn patch_with_store_wrapper(self, wrapper: Arc<dyn WrappingObjectStore>) -> Result<ReadParams>;
+}
+
+impl PatchReadParam for ReadParams {
+    fn patch_with_store_wrapper(
+        mut self,
+        wrapper: Arc<dyn WrappingObjectStore>,
+    ) -> Result<ReadParams> {
+        self.store_options = self.store_options.patch_with_store_wrapper(wrapper)?;
+        Ok(self)
+    }
+}
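The three traits compose: `PatchReadParam` and `PatchWriteParam` delegate to `PatchStoreParam`, which refuses to overwrite a wrapper that is already set. A small usage sketch under that assumption (`patch_both` is a hypothetical helper):

```rust
use std::sync::Arc;

use lance::dataset::{ReadParams, WriteParams};
use lance::io::object_store::WrappingObjectStore;
use vectordb::utils::{PatchReadParam, PatchWriteParam};

fn patch_both(
    wrapper: Arc<dyn WrappingObjectStore>,
) -> vectordb::error::Result<(ReadParams, Option<WriteParams>)> {
    // ReadParams is patched by value; WriteParams travels as an Option
    let read = ReadParams::default().patch_with_store_wrapper(wrapper.clone())?;
    let write = None::<WriteParams>.patch_with_store_wrapper(wrapper)?;
    Ok((read, write))
}
```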