diff --git a/Cargo.toml b/Cargo.toml
index a4340531..9d9d939b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -6,9 +6,9 @@ members = [
 resolver = "2"
 
 [workspace.dependencies]
-lance = "=0.5.3"
-arrow-array = "40.0"
-arrow-data = "40.0"
-arrow-schema = "40.0"
-arrow-ipc = "40.0"
+lance = "=0.5.5"
+arrow-array = "42.0"
+arrow-data = "42.0"
+arrow-schema = "42.0"
+arrow-ipc = "42.0"
 object_store = "0.6.1"
diff --git a/node/package-lock.json b/node/package-lock.json
index 5eee35b1..8a58eccc 100644
--- a/node/package-lock.json
+++ b/node/package-lock.json
@@ -1,12 +1,12 @@
 {
   "name": "vectordb",
-  "version": "0.1.9",
+  "version": "0.1.10",
   "lockfileVersion": 2,
   "requires": true,
   "packages": {
     "": {
       "name": "vectordb",
-      "version": "0.1.9",
+      "version": "0.1.10",
       "license": "Apache-2.0",
       "dependencies": {
         "@apache-arrow/ts": "^12.0.0",
diff --git a/node/src/index.ts b/node/src/index.ts
index 4721a53e..4f1cc400 100644
--- a/node/src/index.ts
+++ b/node/src/index.ts
@@ -122,6 +122,14 @@ export interface Table<T = number[]> {
   delete: (filter: string) => Promise<void>
 }
 
+export interface AwsCredentials {
+  accessKeyId: string
+
+  secretKey: string
+
+  sessionToken?: string
+}
+
 /**
  * A connection to a LanceDB database.
  */
@@ -186,16 +194,23 @@ export class LocalConnection implements Connection {
    * @param embeddings An embedding function to use on this Table
    */
  async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
- async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
+ async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>> {
     if (mode === undefined) {
       mode = WriteMode.Create
     }
-    const tbl = await tableCreate.call(this._db, name, await fromRecordsToBuffer(data, embeddings), mode.toLowerCase())
-    if (embeddings !== undefined) {
-      return new LocalTable(tbl, name, embeddings)
-    } else {
-      return new LocalTable(tbl, name)
+
+    const createArgs = [this._db, name, await fromRecordsToBuffer(data, embeddings), mode.toLowerCase()]
+    if (awsCredentials !== undefined) {
+      createArgs.push(awsCredentials.accessKeyId)
+      createArgs.push(awsCredentials.secretKey)
+      if (awsCredentials.sessionToken !== undefined) {
+        createArgs.push(awsCredentials.sessionToken)
+      }
     }
+
+    const tbl = await tableCreate.call(...createArgs)
+
+    return new LocalTable(tbl, name, embeddings, awsCredentials)
   }
 
   async createTableArrow (name: string, table: ArrowTable): Promise<Table> {
@@ -217,6 +232,7 @@ export class LocalTable<T = number[]> implements Table<T> {
   private readonly _tbl: any
   private readonly _name: string
   private readonly _embeddings?: EmbeddingFunction<T>
+  private readonly _awsCredentials?: AwsCredentials
 
   constructor (tbl: any, name: string)
   /**
@@ -225,10 +241,12 @@ export class LocalTable<T = number[]> implements Table<T> {
    * @param embeddings An embedding function to use when interacting with this table
    */
   constructor (tbl: any, name: string, embeddings: EmbeddingFunction<T>)
-  constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>) {
+  constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials)
+  constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials) {
     this._tbl = tbl
     this._name = name
     this._embeddings = embeddings
+    this._awsCredentials = awsCredentials
   }
 
   get name (): string {
@@ -250,7 +268,15 @@ export class LocalTable<T = number[]> implements Table<T> {
    * @return The number of rows added to the table
    */
   async add (data: Array<Record<string, unknown>>): Promise<number> {
-    return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString())
+    const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString()]
+    if (this._awsCredentials !== undefined) {
+      callArgs.push(this._awsCredentials.accessKeyId)
+      callArgs.push(this._awsCredentials.secretKey)
+      if (this._awsCredentials.sessionToken !== undefined) {
+        callArgs.push(this._awsCredentials.sessionToken)
+      }
+    }
+    return tableAdd.call(...callArgs)
   }
 
   /**
@@ -260,6 +286,14 @@ export class LocalTable<T = number[]> implements Table<T> {
    * @return The number of rows added to the table
    */
   async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
+    const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString()]
+    if (this._awsCredentials !== undefined) {
+      callArgs.push(this._awsCredentials.accessKeyId)
+      callArgs.push(this._awsCredentials.secretKey)
+      if (this._awsCredentials.sessionToken !== undefined) {
+        callArgs.push(this._awsCredentials.sessionToken)
+      }
+    }
-    return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString())
+    return tableAdd.call(...callArgs)
   }
 
diff --git a/rust/ffi/node/Cargo.toml b/rust/ffi/node/Cargo.toml
index 14d4dd31..815b28d0 100644
--- a/rust/ffi/node/Cargo.toml
+++ b/rust/ffi/node/Cargo.toml
@@ -19,3 +19,5 @@ lance = { workspace = true }
 vectordb = { path = "../../vectordb" }
 tokio = { version = "1.23", features = ["rt-multi-thread"] }
 neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] }
+object_store = { workspace = true, features = ["aws"] }
+async-trait = "0"
diff --git a/rust/ffi/node/src/arrow.rs b/rust/ffi/node/src/arrow.rs
index f16ea60c..c9c4ac81 100644
--- a/rust/ffi/node/src/arrow.rs
+++ b/rust/ffi/node/src/arrow.rs
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 use std::io::Cursor;
-use std::ops::Deref;
 use std::sync::Arc;
 
 use arrow_array::cast::as_list_array;
@@ -25,10 +24,13 @@ use lance::arrow::{FixedSizeListArrayExt, RecordBatchExt};
 
 pub(crate) fn convert_record_batch(record_batch: RecordBatch) -> RecordBatch {
     let column = record_batch
         .column_by_name("vector")
+        .cloned()
         .expect("vector column is missing");
-    let arr = as_list_array(column.deref());
+    // TODO: we should just consume the underlying js buffer in the future instead of copying the arrow data around a bunch of times
+    let arr = as_list_array(column.as_ref());
     let list_size = arr.values().len() / record_batch.num_rows();
-    let r = FixedSizeListArray::try_new(arr.values(), list_size as i32).unwrap();
+    let r =
+        FixedSizeListArray::try_new_from_values(arr.values().to_owned(), list_size as i32).unwrap();
 
     let schema = Arc::new(Schema::new(vec![Field::new(
         "vector",
diff --git a/rust/ffi/node/src/lib.rs b/rust/ffi/node/src/lib.rs
index 40cf3743..232cbaa3 100644
--- a/rust/ffi/node/src/lib.rs
+++ b/rust/ffi/node/src/lib.rs
@@ -17,19 +17,23 @@
 use std::convert::TryFrom;
 use std::ops::Deref;
 use std::sync::{Arc, Mutex};
 
-use arrow_array::{Float32Array, RecordBatchIterator, RecordBatchReader};
+use arrow_array::{Float32Array, RecordBatchIterator};
 use arrow_ipc::writer::FileWriter;
+use async_trait::async_trait;
 use futures::{TryFutureExt, TryStreamExt};
-use lance::dataset::{WriteMode, WriteParams};
+use lance::dataset::{ReadParams, WriteMode, WriteParams};
 use lance::index::vector::MetricType;
+use lance::io::object_store::ObjectStoreParams;
 use neon::prelude::*;
 use neon::types::buffer::TypedArray;
+use object_store::aws::{AwsCredential, AwsCredentialProvider};
+use object_store::CredentialProvider;
 use once_cell::sync::OnceCell;
 use tokio::runtime::Runtime;
 use vectordb::database::Database;
 use vectordb::error::Error;
-use vectordb::table::Table;
+use vectordb::table::{OpenTableParams, Table};
 
 use crate::arrow::arrow_buffer_to_record_batch;
@@ -49,6 +53,33 @@ struct JsTable {
 
 impl Finalize for JsTable {}
 
+// TODO: object_store didn't export this type so I copied it.
+// Make a request to object_store to export this type.
+#[derive(Debug)]
+pub struct StaticCredentialProvider<T> {
+    credential: Arc<T>,
+}
+
+impl<T> StaticCredentialProvider<T> {
+    pub fn new(credential: T) -> Self {
+        Self {
+            credential: Arc::new(credential),
+        }
+    }
+}
+
+#[async_trait]
+impl<T> CredentialProvider for StaticCredentialProvider<T>
+where
+    T: std::fmt::Debug + Send + Sync,
+{
+    type Credential = T;
+
+    async fn get_credential(&self) -> object_store::Result<Arc<T>> {
+        Ok(Arc::clone(&self.credential))
+    }
+}
+
 fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
     static RUNTIME: OnceCell<Runtime> = OnceCell::new();
 
@@ -97,19 +128,74 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult<JsPromise> {
     Ok(promise)
 }
 
+fn get_aws_creds<'a>(
+    cx: &mut FunctionContext<'a>,
+    arg_starting_location: i32,
+) -> Result<Option<AwsCredentialProvider>, NeonResult<Handle<'a, JsPromise>>> {
+    let secret_key_id = cx
+        .argument_opt(arg_starting_location)
+        .map(|arg| arg.downcast_or_throw::<JsString, _>(cx).ok())
+        .flatten()
+        .map(|v| v.value(cx));
+
+    let secret_key = cx
+        .argument_opt(arg_starting_location + 1)
+        .map(|arg| arg.downcast_or_throw::<JsString, _>(cx).ok())
+        .flatten()
+        .map(|v| v.value(cx));
+
+    let temp_token = cx
+        .argument_opt(arg_starting_location + 2)
+        .map(|arg| arg.downcast_or_throw::<JsString, _>(cx).ok())
+        .flatten()
+        .map(|v| v.value(cx));
+
+    match (secret_key_id, secret_key, temp_token) {
+        (Some(key_id), Some(key), optional_token) => Ok(Some(Arc::new(
+            StaticCredentialProvider::new(AwsCredential {
+                key_id,
+                secret_key: key,
+                token: optional_token,
+            }),
+        ))),
+        (None, None, None) => Ok(None),
+        _ => Err(cx.throw_error("Invalid credentials configuration")),
+    }
+}
+
 fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
     let db = cx
         .this()
         .downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
     let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
+    let aws_creds = match get_aws_creds(&mut cx, 1) {
+        Ok(creds) => creds,
+        Err(err) => return err,
+    };
+
+    let param = ReadParams {
+        store_options: Some(ObjectStoreParams {
+            aws_credentials: aws_creds,
+            ..ObjectStoreParams::default()
+        }),
+        ..ReadParams::default()
+    };
+
     let rt = runtime(&mut cx)?;
     let channel = cx.channel();
     let database = db.database.clone();
     let (deferred, promise) = cx.promise();
 
     rt.spawn(async move {
-        let table_rst = database.open_table(&table_name).await;
+        let table_rst = database
+            .open_table_with_params(
+                &table_name,
+                OpenTableParams {
+                    open_table_params: param,
+                },
+            )
+            .await;
 
         deferred.settle_with(&channel, move |mut cx| {
             let table = Arc::new(Mutex::new(
@@ -241,8 +327,6 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
         "create" => WriteMode::Create,
         _ => return cx.throw_error("Table::create only supports 'overwrite' and 'create' modes"),
     };
-    let mut params = WriteParams::default();
-    params.mode = mode;
 
     let rt = runtime(&mut cx)?;
     let channel = cx.channel();
@@ -250,11 +334,22 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
     let (deferred, promise) = cx.promise();
     let database = db.database.clone();
 
+    let aws_creds = match get_aws_creds(&mut cx, 3) {
+        Ok(creds) => creds,
+        Err(err) => return err,
+    };
+
+    let params = WriteParams {
+        store_params: Some(ObjectStoreParams {
+            aws_credentials: aws_creds,
+            ..ObjectStoreParams::default()
+        }),
+        mode,
+        ..WriteParams::default()
+    };
+
     rt.block_on(async move {
-        let batch_reader: Box<dyn RecordBatchReader + Send> = Box::new(RecordBatchIterator::new(
-            batches.into_iter().map(Ok),
-            schema,
-        ));
+        let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
         let table_rst = database
             .create_table(&table_name, batch_reader, Some(params))
             .await;
 
@@ -289,16 +384,27 @@ fn table_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
     let table = js_table.table.clone();
     let write_mode = write_mode_map.get(write_mode.as_str()).cloned();
 
+    let aws_creds = match get_aws_creds(&mut cx, 2) {
+        Ok(creds) => creds,
+        Err(err) => return err,
+    };
+
+    let params = WriteParams {
+        store_params: Some(ObjectStoreParams {
+            aws_credentials: aws_creds,
+            ..ObjectStoreParams::default()
+        }),
+        mode: write_mode.unwrap_or(WriteMode::Append),
+        ..WriteParams::default()
+    };
+
     rt.block_on(async move {
-        let batch_reader: Box<dyn RecordBatchReader + Send> = Box::new(RecordBatchIterator::new(
-            batches.into_iter().map(Ok),
-            schema,
-        ));
-        let add_result = table.lock().unwrap().add(batch_reader, write_mode).await;
+        let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
+        let add_result = table.lock().unwrap().add(batch_reader, Some(params)).await;
 
         deferred.settle_with(&channel, move |mut cx| {
-            let added = add_result.or_else(|err| cx.throw_error(err.to_string()))?;
-            Ok(cx.number(added as f64))
+            let _added = add_result.or_else(|err| cx.throw_error(err.to_string()))?;
+            Ok(cx.boolean(true))
         });
     });
     Ok(promise)
diff --git a/rust/vectordb/src/database.rs b/rust/vectordb/src/database.rs
index c759d9ab..4b0f161c 100644
--- a/rust/vectordb/src/database.rs
+++ b/rust/vectordb/src/database.rs
@@ -100,7 +100,7 @@ impl Database {
     pub async fn create_table(
         &self,
         name: &str,
-        batches: Box<dyn RecordBatchReader + Send>,
+        batches: impl RecordBatchReader + Send + 'static,
         params: Option<WriteParams>,
     ) -> Result<Table> {
         Table::create(&self.uri, name, batches, params).await
diff --git a/rust/vectordb/src/query.rs b/rust/vectordb/src/query.rs
index 41cf8174..372ec7e9 100644
--- a/rust/vectordb/src/query.rs
+++ b/rust/vectordb/src/query.rs
@@ -173,10 +173,8 @@ mod tests {
 
     #[tokio::test]
     async fn test_setters_getters() {
-        let mut batches: Box<dyn RecordBatchReader + Send> = make_test_batches();
-        let ds = Dataset::write(&mut batches, "memory://foo", None)
-            .await
-            .unwrap();
+        let batches = make_test_batches();
+        let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
 
         let vector = Float32Array::from_iter_values([0.1, 0.2]);
         let query = Query::new(Arc::new(ds), vector.clone());
@@ -202,10 +200,8 @@ mod tests {
 
     #[tokio::test]
     async fn test_execute() {
-        let mut batches: Box<dyn RecordBatchReader + Send> = make_test_batches();
-        let ds = Dataset::write(&mut batches, "memory://foo", None)
-            .await
-            .unwrap();
+        let batches = make_test_batches();
+        let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
 
         let vector = Float32Array::from_iter_values([0.1; 128]);
         let query = Query::new(Arc::new(ds), vector.clone());
@@ -213,7 +209,7 @@ mod tests {
         assert_eq!(result.is_ok(), true);
     }
 
-    fn make_test_batches() -> Box<dyn RecordBatchReader + Send> {
+    fn make_test_batches() -> impl RecordBatchReader + Send + 'static {
         let dim: usize = 128;
         let schema = Arc::new(ArrowSchema::new(vec![
             ArrowField::new("key", DataType::Int32, false),
             ArrowField::new(
                 "vector",
                 DataType::FixedSizeList(
                     Arc::new(ArrowField::new("item", DataType::Float32, true)),
                     dim as i32,
                 ),
                 true,
             ),
             ArrowField::new("uri", DataType::Utf8, true),
         ]));
-        Box::new(RecordBatchIterator::new(
+        RecordBatchIterator::new(
             vec![RecordBatch::new_empty(schema.clone())]
                 .into_iter()
                 .map(Ok),
             schema,
-        ))
+        )
     }
 }
diff --git a/rust/vectordb/src/table.rs b/rust/vectordb/src/table.rs
index ac9544a6..d01ceacd 100644
--- a/rust/vectordb/src/table.rs
+++ b/rust/vectordb/src/table.rs
@@ -22,8 +22,8 @@ use snafu::prelude::*;
 
 use crate::error::{Error, InvalidTableNameSnafu, Result};
 use crate::index::vector::VectorIndexBuilder;
-use crate::WriteMode;
 use crate::query::Query;
+use crate::WriteMode;
 
 pub const VECTOR_COLUMN_NAME: &str = "vector";
 pub const LANCE_FILE_EXTENSION: &str = "lance";
@@ -117,7 +117,7 @@ impl Table {
     pub async fn create(
         base_uri: &str,
         name: &str,
-        mut batches: Box<dyn RecordBatchReader + Send>,
+        batches: impl RecordBatchReader + Send + 'static,
         params: Option<WriteParams>,
     ) -> Result<Self> {
         let base_path = Path::new(base_uri);
@@ -127,7 +127,7 @@ impl Table {
             .to_str()
             .context(InvalidTableNameSnafu { name })?
             .to_string();
-        let dataset = Dataset::write(&mut batches, &uri, params)
+        let dataset = Dataset::write(batches, &uri, params)
             .await
             .map_err(|e| match e {
                 lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists {
@@ -176,14 +176,16 @@ impl Table {
     /// * The number of rows added
     pub async fn add(
         &mut self,
-        mut batches: Box<dyn RecordBatchReader + Send>,
-        write_mode: Option<WriteMode>,
-    ) -> Result<usize> {
-        let mut params = WriteParams::default();
-        params.mode = write_mode.unwrap_or(WriteMode::Append);
+        batches: impl RecordBatchReader + Send + 'static,
+        params: Option<WriteParams>,
+    ) -> Result<()> {
+        let params = params.unwrap_or(WriteParams {
+            mode: WriteMode::Append,
+            ..WriteParams::default()
+        });
 
-        self.dataset = Arc::new(Dataset::write(&mut batches, &self.uri, Some(params)).await?);
-        Ok(batches.count())
+        self.dataset = Arc::new(Dataset::write(batches, &self.uri, Some(params)).await?);
+        Ok(())
     }
 
     /// Creates a new Query object that can be executed.
@@ -207,12 +209,12 @@ impl Table {
     /// Merge new data into this table.
     pub async fn merge(
         &mut self,
-        mut batches: Box<dyn RecordBatchReader + Send>,
+        batches: impl RecordBatchReader + Send + 'static,
         left_on: &str,
         right_on: &str,
     ) -> Result<()> {
         let mut dataset = self.dataset.as_ref().clone();
-        dataset.merge(&mut batches, left_on, right_on).await?;
+        dataset.merge(batches, left_on, right_on).await?;
         self.dataset = Arc::new(dataset);
         Ok(())
     }
@@ -253,8 +255,8 @@ mod tests {
         let dataset_path = tmp_dir.path().join("test.lance");
         let uri = tmp_dir.path().to_str().unwrap();
 
-        let mut batches: Box<dyn RecordBatchReader + Send> = make_test_batches();
-        Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
+        let batches = make_test_batches();
+        Dataset::write(batches, dataset_path.to_str().unwrap(), None)
             .await
             .unwrap();
 
@@ -284,11 +286,11 @@ mod tests {
         let tmp_dir = tempdir().unwrap();
         let uri = tmp_dir.path().to_str().unwrap();
 
-        let batches: Box<dyn RecordBatchReader + Send> = make_test_batches();
+        let batches = make_test_batches();
         let _ = batches.schema().clone();
         Table::create(&uri, "test", batches, None).await.unwrap();
 
-        let batches: Box<dyn RecordBatchReader + Send> = make_test_batches();
+        let batches = make_test_batches();
         let result = Table::create(&uri, "test", batches, None).await;
         assert!(matches!(
             result.unwrap_err(),
@@ -301,12 +303,12 @@ mod tests {
         let tmp_dir = tempdir().unwrap();
         let uri = tmp_dir.path().to_str().unwrap();
 
-        let batches: Box<dyn RecordBatchReader + Send> = make_test_batches();
+        let batches = make_test_batches();
         let schema = batches.schema().clone();
         let mut table = Table::create(&uri, "test", batches, None).await.unwrap();
         assert_eq!(table.count_rows().await.unwrap(), 10);
 
-        let new_batches: Box<dyn RecordBatchReader + Send> = Box::new(RecordBatchIterator::new(
+        let new_batches = RecordBatchIterator::new(
             vec![RecordBatch::try_new(
                 schema.clone(),
                 vec![Arc::new(Int32Array::from_iter_values(100..110))],
             )
             .unwrap()]
             .into_iter()
             .map(Ok),
             schema.clone(),
-        ));
+        );
         table.add(new_batches, None).await.unwrap();
         assert_eq!(table.count_rows().await.unwrap(), 20);
@@ -327,12 +329,12 @@ mod tests {
         let tmp_dir = tempdir().unwrap();
         let uri = tmp_dir.path().to_str().unwrap();
 
-        let batches: Box<dyn RecordBatchReader + Send> = make_test_batches();
+        let batches = make_test_batches();
         let schema = batches.schema().clone();
         let mut table = Table::create(uri, "test", batches, None).await.unwrap();
         assert_eq!(table.count_rows().await.unwrap(), 10);
 
-        let new_batches: Box<dyn RecordBatchReader + Send> = Box::new(RecordBatchIterator::new(
+        let new_batches = RecordBatchIterator::new(
             vec![RecordBatch::try_new(
                 schema.clone(),
                 vec![Arc::new(Int32Array::from_iter_values(100..110))],
             )
             .unwrap()]
             .into_iter()
             .map(Ok),
             schema.clone(),
-        ));
+        );
+
+        let param: WriteParams = WriteParams {
+            mode: WriteMode::Overwrite,
+            ..Default::default()
+        };
 
         table
-            .add(new_batches, Some(WriteMode::Overwrite))
+            .add(new_batches, Some(param))
             .await
             .unwrap();
         assert_eq!(table.count_rows().await.unwrap(), 10);
@@ -357,8 +364,8 @@ mod tests {
         let dataset_path = tmp_dir.path().join("test.lance");
         let uri = tmp_dir.path().to_str().unwrap();
 
-        let mut batches: Box<dyn RecordBatchReader + Send> = make_test_batches();
-        Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
+        let batches = make_test_batches();
+        Dataset::write(batches, dataset_path.to_str().unwrap(), None)
             .await
             .unwrap();
 
@@ -369,7 +376,7 @@ mod tests {
         assert_eq!(vector, query.query_vector);
     }
 
-    #[derive(Default)]
+    #[derive(Default, Debug)]
     struct NoOpCacheWrapper {
         called: AtomicBool,
     }
@@ -396,8 +403,8 @@ mod tests {
         let dataset_path = tmp_dir.path().join("test.lance");
         let uri = tmp_dir.path().to_str().unwrap();
 
-        let mut batches: Box<dyn RecordBatchReader + Send> = make_test_batches();
-        Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
+        let batches = make_test_batches();
+        Dataset::write(batches, dataset_path.to_str().unwrap(), None)
             .await
             .unwrap();
 
@@ -417,15 +424,15 @@ mod tests {
         assert!(wrapper.called());
     }
 
-    fn make_test_batches() -> Box<dyn RecordBatchReader + Send> {
+    fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
         let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
-        Box::new(RecordBatchIterator::new(
+        RecordBatchIterator::new(
             vec![RecordBatch::try_new(
                 schema.clone(),
                 vec![Arc::new(Int32Array::from_iter_values(0..10))],
             )],
             schema,
-        ))
+        )
     }
 
     #[tokio::test]
@@ -465,9 +472,7 @@ mod tests {
             schema,
         );
 
-        let reader: Box<dyn RecordBatchReader + Send> = Box::new(batches);
-        let mut table = Table::create(uri, "test", reader, None).await.unwrap();
-
+        let mut table = Table::create(uri, "test", batches, None).await.unwrap();
         let mut i = IvfPQIndexBuilder::new();
         let index_builder = i
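
Usage sketch for the AwsCredentials option wired through above, from the Node side. This is a minimal illustration, not part of the diff: the bucket URI, table name, rows, and credential values are placeholders, and it assumes the package is imported as 'vectordb' with connect, WriteMode, and AwsCredentials exported from node/src/index.ts. Note that the diff adds the awsCredentials parameter only to the createTable implementation signature, not to the declared overloads, so a matching overload (or a cast) may be needed for the call below to type-check.

import { connect, WriteMode, type AwsCredentials } from 'vectordb'

async function main (): Promise<void> {
  // Placeholder credentials; real values would come from your AWS environment.
  const awsCredentials: AwsCredentials = {
    accessKeyId: 'AKIA...',
    secretKey: '...',
    sessionToken: '...' // optional; omit when not using temporary credentials
  }

  const db = await connect('s3://my-bucket/lancedb')

  // The credentials ride along as the trailing createTable argument and are also
  // remembered by the returned table, so the add() call below reuses them.
  const table = await db.createTable(
    'my_table',
    [{ id: 1, vector: [0.1, 0.2] }],
    WriteMode.Create,
    undefined, // no embedding function
    awsCredentials
  )

  await table.add([{ id: 2, vector: [0.3, 0.4] }])
}

main().catch(console.error)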