fix(node): pass AWS credentials to db level operations (#908)

Passed the following tests

```ts
const keyId = process.env.AWS_ACCESS_KEY_ID;
const secretKey = process.env.AWS_SECRET_ACCESS_KEY;
const sessionToken = process.env.AWS_SESSION_TOKEN;
const region = process.env.AWS_REGION;

const db = await lancedb.connect({
  uri: "s3://bucket/path",
  awsCredentials: {
    accessKeyId: keyId,
    secretKey: secretKey,
    sessionToken: sessionToken,
  },
  awsRegion: region,
} as lancedb.ConnectionOptions);

  console.log(await db.createTable("test", [{ vector: [1, 2, 3] }]));
  console.log(await db.tableNames());
  console.log(await db.dropTable("test"))
```
This commit is contained in:
Lei Xu
2024-01-31 12:05:01 -08:00
committed by GitHub
parent 8d0ea29f89
commit 5f59e51583
4 changed files with 82 additions and 19 deletions

View File

@@ -163,6 +163,7 @@ export async function connect (
{
uri: '',
awsCredentials: undefined,
awsRegion: defaultAwsRegion,
apiKey: undefined,
region: defaultAwsRegion
},
@@ -174,7 +175,13 @@ export async function connect (
// Remote connection
return new RemoteConnection(opts)
}
const db = await databaseNew(opts.uri)
const db = await databaseNew(
opts.uri,
opts.awsCredentials?.accessKeyId,
opts.awsCredentials?.secretKey,
opts.awsCredentials?.sessionToken,
opts.awsRegion
)
return new LocalConnection(db, opts)
}
@@ -443,7 +450,7 @@ export interface Table<T = number[]> {
*/
indexStats: (indexUuid: string) => Promise<IndexStats>
filter (value: string): Query<T>
filter(value: string): Query<T>
schema: Promise<Schema>
}

View File

@@ -24,7 +24,7 @@ use tokio::runtime::Runtime;
use vectordb::connection::Database;
use vectordb::table::ReadParams;
use vectordb::Connection;
use vectordb::{ConnectOptions, Connection};
use crate::error::ResultExt;
use crate::query::JsQuery;
@@ -82,13 +82,26 @@ fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
let path = cx.argument::<JsString>(0)?.value(&mut cx);
let aws_creds = get_aws_creds(&mut cx, 1)?;
let region = get_aws_region(&mut cx, 4)?;
let rt = runtime(&mut cx)?;
let channel = cx.channel();
let (deferred, promise) = cx.promise();
let mut conn_options = ConnectOptions::new(&path);
if let Some(region) = region {
conn_options = conn_options.region(&region);
}
if let Some(aws_creds) = aws_creds {
conn_options = conn_options.aws_creds(AwsCredential {
key_id: aws_creds.key_id,
secret_key: aws_creds.secret_key,
token: aws_creds.token,
});
}
rt.spawn(async move {
let database = Database::connect(&path).await;
let database = Database::connect_with_options(&conn_options).await;
deferred.settle_with(&channel, move |mut cx| {
let db = JsDatabase {
@@ -127,7 +140,7 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult<JsPromise> {
fn get_aws_creds(
cx: &mut FunctionContext,
arg_starting_location: i32,
) -> NeonResult<Option<AwsCredentialProvider>> {
) -> NeonResult<Option<AwsCredential>> {
let secret_key_id = cx
.argument_opt(arg_starting_location)
.filter(|arg| arg.is_a::<JsString, _>(cx))
@@ -147,18 +160,26 @@ fn get_aws_creds(
.map(|v| v.value(cx));
match (secret_key_id, secret_key, temp_token) {
(Some(key_id), Some(key), optional_token) => Ok(Some(Arc::new(
StaticCredentialProvider::new(AwsCredential {
key_id,
secret_key: key,
token: optional_token,
}),
))),
(Some(key_id), Some(key), optional_token) => Ok(Some(AwsCredential {
key_id,
secret_key: key,
token: optional_token,
})),
(None, None, None) => Ok(None),
_ => cx.throw_error("Invalid credentials configuration"),
}
}
fn get_aws_credential_provider(
cx: &mut FunctionContext,
arg_starting_location: i32,
) -> NeonResult<Option<AwsCredentialProvider>> {
Ok(get_aws_creds(cx, arg_starting_location)?.map(|aws_cred| {
Arc::new(StaticCredentialProvider::new(aws_cred))
as Arc<dyn CredentialProvider<Credential = AwsCredential>>
}))
}
/// Get AWS region arguments from the context
fn get_aws_region(cx: &mut FunctionContext, arg_location: i32) -> NeonResult<Option<String>> {
let region = cx
@@ -179,7 +200,7 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
let aws_creds = get_aws_creds(&mut cx, 1)?;
let aws_creds = get_aws_credential_provider(&mut cx, 1)?;
let aws_region = get_aws_region(&mut cx, 4)?;

View File

@@ -24,7 +24,7 @@ use neon::types::buffer::TypedArray;
use vectordb::TableRef;
use crate::error::ResultExt;
use crate::{convert, get_aws_creds, get_aws_region, runtime, JsDatabase};
use crate::{convert, get_aws_credential_provider, get_aws_region, runtime, JsDatabase};
pub(crate) struct JsTable {
pub table: TableRef,
@@ -64,7 +64,7 @@ impl JsTable {
let (deferred, promise) = cx.promise();
let database = db.database.clone();
let aws_creds = get_aws_creds(&mut cx, 3)?;
let aws_creds = get_aws_credential_provider(&mut cx, 3)?;
let aws_region = get_aws_region(&mut cx, 6)?;
let params = WriteParams {
@@ -106,7 +106,7 @@ impl JsTable {
"overwrite" => WriteMode::Overwrite,
s => return cx.throw_error(format!("invalid write mode {}", s)),
};
let aws_creds = get_aws_creds(&mut cx, 2)?;
let aws_creds = get_aws_credential_provider(&mut cx, 2)?;
let aws_region = get_aws_region(&mut cx, 5)?;
let params = WriteParams {

View File

@@ -21,8 +21,10 @@ use std::sync::Arc;
use arrow_array::RecordBatchReader;
use lance::dataset::WriteParams;
use lance::io::{ObjectStore, WrappingObjectStore};
use object_store::local::LocalFileSystem;
use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore};
use object_store::{
aws::AwsCredential, local::LocalFileSystem, CredentialProvider, StaticCredentialProvider,
};
use snafu::prelude::*;
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
@@ -86,6 +88,9 @@ pub struct ConnectOptions {
/// Lance Cloud host override
pub host_override: Option<String>,
/// User provided AWS credentials
pub aws_creds: Option<AwsCredential>,
/// The maximum number of indices to cache in memory. Defaults to 256.
pub index_cache_size: u32,
}
@@ -98,6 +103,7 @@ impl ConnectOptions {
api_key: None,
region: None,
host_override: None,
aws_creds: None,
index_cache_size: 256,
}
}
@@ -117,6 +123,13 @@ impl ConnectOptions {
self
}
/// [`AwsCredential`] to use when connecting to S3.
///
pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
self.aws_creds = Some(aws_creds);
self
}
pub fn index_cache_size(mut self, index_cache_size: u32) -> Self {
self.index_cache_size = index_cache_size;
self
@@ -176,6 +189,12 @@ impl Database {
///
/// * A [Database] object.
pub async fn connect(uri: &str) -> Result<Database> {
let options = ConnectOptions::new(uri);
Self::connect_with_options(&options).await
}
pub async fn connect_with_options(options: &ConnectOptions) -> Result<Database> {
let uri = &options.uri;
let parse_res = url::Url::parse(uri);
match parse_res {
@@ -227,7 +246,23 @@ impl Database {
};
let plain_uri = url.to_string();
let (object_store, base_path) = ObjectStore::from_uri(&plain_uri).await?;
let os_params: ObjectStoreParams = if let Some(aws_creds) = &options.aws_creds {
let credential_provider: Arc<
dyn CredentialProvider<Credential = AwsCredential>,
> = Arc::new(StaticCredentialProvider::new(AwsCredential {
key_id: aws_creds.key_id.clone(),
secret_key: aws_creds.secret_key.clone(),
token: aws_creds.token.clone(),
}));
ObjectStoreParams::with_aws_credentials(
Some(credential_provider),
options.region.clone(),
)
} else {
ObjectStoreParams::default()
};
let (object_store, base_path) =
ObjectStore::from_uri_and_params(&plain_uri, &os_params).await?;
if object_store.is_local() {
Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
}