diff --git a/node/src/index.ts b/node/src/index.ts index bb2af069..ed34013a 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -163,6 +163,7 @@ export async function connect ( { uri: '', awsCredentials: undefined, + awsRegion: defaultAwsRegion, apiKey: undefined, region: defaultAwsRegion }, @@ -174,7 +175,13 @@ export async function connect ( // Remote connection return new RemoteConnection(opts) } - const db = await databaseNew(opts.uri) + const db = await databaseNew( + opts.uri, + opts.awsCredentials?.accessKeyId, + opts.awsCredentials?.secretKey, + opts.awsCredentials?.sessionToken, + opts.awsRegion + ) return new LocalConnection(db, opts) } @@ -443,7 +450,7 @@ export interface Table { */ indexStats: (indexUuid: string) => Promise - filter (value: string): Query + filter(value: string): Query schema: Promise } diff --git a/rust/ffi/node/src/lib.rs b/rust/ffi/node/src/lib.rs index 9547e044..4fed6e28 100644 --- a/rust/ffi/node/src/lib.rs +++ b/rust/ffi/node/src/lib.rs @@ -24,7 +24,7 @@ use tokio::runtime::Runtime; use vectordb::connection::Database; use vectordb::table::ReadParams; -use vectordb::Connection; +use vectordb::{ConnectOptions, Connection}; use crate::error::ResultExt; use crate::query::JsQuery; @@ -82,13 +82,26 @@ fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> { fn database_new(mut cx: FunctionContext) -> JsResult { let path = cx.argument::(0)?.value(&mut cx); + let aws_creds = get_aws_creds(&mut cx, 1)?; + let region = get_aws_region(&mut cx, 4)?; let rt = runtime(&mut cx)?; let channel = cx.channel(); let (deferred, promise) = cx.promise(); + let mut conn_options = ConnectOptions::new(&path); + if let Some(region) = region { + conn_options = conn_options.region(®ion); + } + if let Some(aws_creds) = aws_creds { + conn_options = conn_options.aws_creds(AwsCredential { + key_id: aws_creds.key_id, + secret_key: aws_creds.secret_key, + token: aws_creds.token, + }); + } rt.spawn(async move { - let database = Database::connect(&path).await; + let database = Database::connect_with_options(&conn_options).await; deferred.settle_with(&channel, move |mut cx| { let db = JsDatabase { @@ -127,7 +140,7 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult { fn get_aws_creds( cx: &mut FunctionContext, arg_starting_location: i32, -) -> NeonResult> { +) -> NeonResult> { let secret_key_id = cx .argument_opt(arg_starting_location) .filter(|arg| arg.is_a::(cx)) @@ -147,18 +160,26 @@ fn get_aws_creds( .map(|v| v.value(cx)); match (secret_key_id, secret_key, temp_token) { - (Some(key_id), Some(key), optional_token) => Ok(Some(Arc::new( - StaticCredentialProvider::new(AwsCredential { - key_id, - secret_key: key, - token: optional_token, - }), - ))), + (Some(key_id), Some(key), optional_token) => Ok(Some(AwsCredential { + key_id, + secret_key: key, + token: optional_token, + })), (None, None, None) => Ok(None), _ => cx.throw_error("Invalid credentials configuration"), } } +fn get_aws_credential_provider( + cx: &mut FunctionContext, + arg_starting_location: i32, +) -> NeonResult> { + Ok(get_aws_creds(cx, arg_starting_location)?.map(|aws_cred| { + Arc::new(StaticCredentialProvider::new(aws_cred)) + as Arc> + })) +} + /// Get AWS region arguments from the context fn get_aws_region(cx: &mut FunctionContext, arg_location: i32) -> NeonResult> { let region = cx @@ -179,7 +200,7 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult { .downcast_or_throw::, _>(&mut cx)?; let table_name = cx.argument::(0)?.value(&mut cx); - let aws_creds = get_aws_creds(&mut cx, 1)?; + let aws_creds = get_aws_credential_provider(&mut cx, 1)?; let aws_region = get_aws_region(&mut cx, 4)?; diff --git a/rust/ffi/node/src/table.rs b/rust/ffi/node/src/table.rs index 63125274..ae01bb75 100644 --- a/rust/ffi/node/src/table.rs +++ b/rust/ffi/node/src/table.rs @@ -24,7 +24,7 @@ use neon::types::buffer::TypedArray; use vectordb::TableRef; use crate::error::ResultExt; -use crate::{convert, get_aws_creds, get_aws_region, runtime, JsDatabase}; +use crate::{convert, get_aws_credential_provider, get_aws_region, runtime, JsDatabase}; pub(crate) struct JsTable { pub table: TableRef, @@ -64,7 +64,7 @@ impl JsTable { let (deferred, promise) = cx.promise(); let database = db.database.clone(); - let aws_creds = get_aws_creds(&mut cx, 3)?; + let aws_creds = get_aws_credential_provider(&mut cx, 3)?; let aws_region = get_aws_region(&mut cx, 6)?; let params = WriteParams { @@ -106,7 +106,7 @@ impl JsTable { "overwrite" => WriteMode::Overwrite, s => return cx.throw_error(format!("invalid write mode {}", s)), }; - let aws_creds = get_aws_creds(&mut cx, 2)?; + let aws_creds = get_aws_credential_provider(&mut cx, 2)?; let aws_region = get_aws_region(&mut cx, 5)?; let params = WriteParams { diff --git a/rust/vectordb/src/connection.rs b/rust/vectordb/src/connection.rs index c8fc7921..670b059e 100644 --- a/rust/vectordb/src/connection.rs +++ b/rust/vectordb/src/connection.rs @@ -21,8 +21,10 @@ use std::sync::Arc; use arrow_array::RecordBatchReader; use lance::dataset::WriteParams; -use lance::io::{ObjectStore, WrappingObjectStore}; -use object_store::local::LocalFileSystem; +use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore}; +use object_store::{ + aws::AwsCredential, local::LocalFileSystem, CredentialProvider, StaticCredentialProvider, +}; use snafu::prelude::*; use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result}; @@ -86,6 +88,9 @@ pub struct ConnectOptions { /// Lance Cloud host override pub host_override: Option, + /// User provided AWS credentials + pub aws_creds: Option, + /// The maximum number of indices to cache in memory. Defaults to 256. pub index_cache_size: u32, } @@ -98,6 +103,7 @@ impl ConnectOptions { api_key: None, region: None, host_override: None, + aws_creds: None, index_cache_size: 256, } } @@ -117,6 +123,13 @@ impl ConnectOptions { self } + /// [`AwsCredential`] to use when connecting to S3. + /// + pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self { + self.aws_creds = Some(aws_creds); + self + } + pub fn index_cache_size(mut self, index_cache_size: u32) -> Self { self.index_cache_size = index_cache_size; self @@ -176,6 +189,12 @@ impl Database { /// /// * A [Database] object. pub async fn connect(uri: &str) -> Result { + let options = ConnectOptions::new(uri); + Self::connect_with_options(&options).await + } + + pub async fn connect_with_options(options: &ConnectOptions) -> Result { + let uri = &options.uri; let parse_res = url::Url::parse(uri); match parse_res { @@ -227,7 +246,23 @@ impl Database { }; let plain_uri = url.to_string(); - let (object_store, base_path) = ObjectStore::from_uri(&plain_uri).await?; + let os_params: ObjectStoreParams = if let Some(aws_creds) = &options.aws_creds { + let credential_provider: Arc< + dyn CredentialProvider, + > = Arc::new(StaticCredentialProvider::new(AwsCredential { + key_id: aws_creds.key_id.clone(), + secret_key: aws_creds.secret_key.clone(), + token: aws_creds.token.clone(), + })); + ObjectStoreParams::with_aws_credentials( + Some(credential_provider), + options.region.clone(), + ) + } else { + ObjectStoreParams::default() + }; + let (object_store, base_path) = + ObjectStore::from_uri_and_params(&plain_uri, &os_params).await?; if object_store.is_local() { Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?; }