mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-08 12:52:58 +00:00
fix(node): pass AWS credentials to db level operations (#908)
Passed the following tests
```ts
const keyId = process.env.AWS_ACCESS_KEY_ID;
const secretKey = process.env.AWS_SECRET_ACCESS_KEY;
const sessionToken = process.env.AWS_SESSION_TOKEN;
const region = process.env.AWS_REGION;
const db = await lancedb.connect({
uri: "s3://bucket/path",
awsCredentials: {
accessKeyId: keyId,
secretKey: secretKey,
sessionToken: sessionToken,
},
awsRegion: region,
} as lancedb.ConnectionOptions);
console.log(await db.createTable("test", [{ vector: [1, 2, 3] }]));
console.log(await db.tableNames());
console.log(await db.dropTable("test"))
```
This commit is contained in:
@@ -163,6 +163,7 @@ export async function connect (
|
||||
{
|
||||
uri: '',
|
||||
awsCredentials: undefined,
|
||||
awsRegion: defaultAwsRegion,
|
||||
apiKey: undefined,
|
||||
region: defaultAwsRegion
|
||||
},
|
||||
@@ -174,7 +175,13 @@ export async function connect (
|
||||
// Remote connection
|
||||
return new RemoteConnection(opts)
|
||||
}
|
||||
const db = await databaseNew(opts.uri)
|
||||
const db = await databaseNew(
|
||||
opts.uri,
|
||||
opts.awsCredentials?.accessKeyId,
|
||||
opts.awsCredentials?.secretKey,
|
||||
opts.awsCredentials?.sessionToken,
|
||||
opts.awsRegion
|
||||
)
|
||||
return new LocalConnection(db, opts)
|
||||
}
|
||||
|
||||
@@ -443,7 +450,7 @@ export interface Table<T = number[]> {
|
||||
*/
|
||||
indexStats: (indexUuid: string) => Promise<IndexStats>
|
||||
|
||||
filter (value: string): Query<T>
|
||||
filter(value: string): Query<T>
|
||||
|
||||
schema: Promise<Schema>
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ use tokio::runtime::Runtime;
|
||||
|
||||
use vectordb::connection::Database;
|
||||
use vectordb::table::ReadParams;
|
||||
use vectordb::Connection;
|
||||
use vectordb::{ConnectOptions, Connection};
|
||||
|
||||
use crate::error::ResultExt;
|
||||
use crate::query::JsQuery;
|
||||
@@ -82,13 +82,26 @@ fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
|
||||
|
||||
fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
let path = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
let aws_creds = get_aws_creds(&mut cx, 1)?;
|
||||
let region = get_aws_region(&mut cx, 4)?;
|
||||
|
||||
let rt = runtime(&mut cx)?;
|
||||
let channel = cx.channel();
|
||||
let (deferred, promise) = cx.promise();
|
||||
|
||||
let mut conn_options = ConnectOptions::new(&path);
|
||||
if let Some(region) = region {
|
||||
conn_options = conn_options.region(®ion);
|
||||
}
|
||||
if let Some(aws_creds) = aws_creds {
|
||||
conn_options = conn_options.aws_creds(AwsCredential {
|
||||
key_id: aws_creds.key_id,
|
||||
secret_key: aws_creds.secret_key,
|
||||
token: aws_creds.token,
|
||||
});
|
||||
}
|
||||
rt.spawn(async move {
|
||||
let database = Database::connect(&path).await;
|
||||
let database = Database::connect_with_options(&conn_options).await;
|
||||
|
||||
deferred.settle_with(&channel, move |mut cx| {
|
||||
let db = JsDatabase {
|
||||
@@ -127,7 +140,7 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
fn get_aws_creds(
|
||||
cx: &mut FunctionContext,
|
||||
arg_starting_location: i32,
|
||||
) -> NeonResult<Option<AwsCredentialProvider>> {
|
||||
) -> NeonResult<Option<AwsCredential>> {
|
||||
let secret_key_id = cx
|
||||
.argument_opt(arg_starting_location)
|
||||
.filter(|arg| arg.is_a::<JsString, _>(cx))
|
||||
@@ -147,18 +160,26 @@ fn get_aws_creds(
|
||||
.map(|v| v.value(cx));
|
||||
|
||||
match (secret_key_id, secret_key, temp_token) {
|
||||
(Some(key_id), Some(key), optional_token) => Ok(Some(Arc::new(
|
||||
StaticCredentialProvider::new(AwsCredential {
|
||||
key_id,
|
||||
secret_key: key,
|
||||
token: optional_token,
|
||||
}),
|
||||
))),
|
||||
(Some(key_id), Some(key), optional_token) => Ok(Some(AwsCredential {
|
||||
key_id,
|
||||
secret_key: key,
|
||||
token: optional_token,
|
||||
})),
|
||||
(None, None, None) => Ok(None),
|
||||
_ => cx.throw_error("Invalid credentials configuration"),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_aws_credential_provider(
|
||||
cx: &mut FunctionContext,
|
||||
arg_starting_location: i32,
|
||||
) -> NeonResult<Option<AwsCredentialProvider>> {
|
||||
Ok(get_aws_creds(cx, arg_starting_location)?.map(|aws_cred| {
|
||||
Arc::new(StaticCredentialProvider::new(aws_cred))
|
||||
as Arc<dyn CredentialProvider<Credential = AwsCredential>>
|
||||
}))
|
||||
}
|
||||
|
||||
/// Get AWS region arguments from the context
|
||||
fn get_aws_region(cx: &mut FunctionContext, arg_location: i32) -> NeonResult<Option<String>> {
|
||||
let region = cx
|
||||
@@ -179,7 +200,7 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
|
||||
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
|
||||
let aws_creds = get_aws_creds(&mut cx, 1)?;
|
||||
let aws_creds = get_aws_credential_provider(&mut cx, 1)?;
|
||||
|
||||
let aws_region = get_aws_region(&mut cx, 4)?;
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ use neon::types::buffer::TypedArray;
|
||||
use vectordb::TableRef;
|
||||
|
||||
use crate::error::ResultExt;
|
||||
use crate::{convert, get_aws_creds, get_aws_region, runtime, JsDatabase};
|
||||
use crate::{convert, get_aws_credential_provider, get_aws_region, runtime, JsDatabase};
|
||||
|
||||
pub(crate) struct JsTable {
|
||||
pub table: TableRef,
|
||||
@@ -64,7 +64,7 @@ impl JsTable {
|
||||
let (deferred, promise) = cx.promise();
|
||||
let database = db.database.clone();
|
||||
|
||||
let aws_creds = get_aws_creds(&mut cx, 3)?;
|
||||
let aws_creds = get_aws_credential_provider(&mut cx, 3)?;
|
||||
let aws_region = get_aws_region(&mut cx, 6)?;
|
||||
|
||||
let params = WriteParams {
|
||||
@@ -106,7 +106,7 @@ impl JsTable {
|
||||
"overwrite" => WriteMode::Overwrite,
|
||||
s => return cx.throw_error(format!("invalid write mode {}", s)),
|
||||
};
|
||||
let aws_creds = get_aws_creds(&mut cx, 2)?;
|
||||
let aws_creds = get_aws_credential_provider(&mut cx, 2)?;
|
||||
let aws_region = get_aws_region(&mut cx, 5)?;
|
||||
|
||||
let params = WriteParams {
|
||||
|
||||
@@ -21,8 +21,10 @@ use std::sync::Arc;
|
||||
|
||||
use arrow_array::RecordBatchReader;
|
||||
use lance::dataset::WriteParams;
|
||||
use lance::io::{ObjectStore, WrappingObjectStore};
|
||||
use object_store::local::LocalFileSystem;
|
||||
use lance::io::{ObjectStore, ObjectStoreParams, WrappingObjectStore};
|
||||
use object_store::{
|
||||
aws::AwsCredential, local::LocalFileSystem, CredentialProvider, StaticCredentialProvider,
|
||||
};
|
||||
use snafu::prelude::*;
|
||||
|
||||
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
|
||||
@@ -86,6 +88,9 @@ pub struct ConnectOptions {
|
||||
/// Lance Cloud host override
|
||||
pub host_override: Option<String>,
|
||||
|
||||
/// User provided AWS credentials
|
||||
pub aws_creds: Option<AwsCredential>,
|
||||
|
||||
/// The maximum number of indices to cache in memory. Defaults to 256.
|
||||
pub index_cache_size: u32,
|
||||
}
|
||||
@@ -98,6 +103,7 @@ impl ConnectOptions {
|
||||
api_key: None,
|
||||
region: None,
|
||||
host_override: None,
|
||||
aws_creds: None,
|
||||
index_cache_size: 256,
|
||||
}
|
||||
}
|
||||
@@ -117,6 +123,13 @@ impl ConnectOptions {
|
||||
self
|
||||
}
|
||||
|
||||
/// [`AwsCredential`] to use when connecting to S3.
|
||||
///
|
||||
pub fn aws_creds(mut self, aws_creds: AwsCredential) -> Self {
|
||||
self.aws_creds = Some(aws_creds);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn index_cache_size(mut self, index_cache_size: u32) -> Self {
|
||||
self.index_cache_size = index_cache_size;
|
||||
self
|
||||
@@ -176,6 +189,12 @@ impl Database {
|
||||
///
|
||||
/// * A [Database] object.
|
||||
pub async fn connect(uri: &str) -> Result<Database> {
|
||||
let options = ConnectOptions::new(uri);
|
||||
Self::connect_with_options(&options).await
|
||||
}
|
||||
|
||||
pub async fn connect_with_options(options: &ConnectOptions) -> Result<Database> {
|
||||
let uri = &options.uri;
|
||||
let parse_res = url::Url::parse(uri);
|
||||
|
||||
match parse_res {
|
||||
@@ -227,7 +246,23 @@ impl Database {
|
||||
};
|
||||
|
||||
let plain_uri = url.to_string();
|
||||
let (object_store, base_path) = ObjectStore::from_uri(&plain_uri).await?;
|
||||
let os_params: ObjectStoreParams = if let Some(aws_creds) = &options.aws_creds {
|
||||
let credential_provider: Arc<
|
||||
dyn CredentialProvider<Credential = AwsCredential>,
|
||||
> = Arc::new(StaticCredentialProvider::new(AwsCredential {
|
||||
key_id: aws_creds.key_id.clone(),
|
||||
secret_key: aws_creds.secret_key.clone(),
|
||||
token: aws_creds.token.clone(),
|
||||
}));
|
||||
ObjectStoreParams::with_aws_credentials(
|
||||
Some(credential_provider),
|
||||
options.region.clone(),
|
||||
)
|
||||
} else {
|
||||
ObjectStoreParams::default()
|
||||
};
|
||||
let (object_store, base_path) =
|
||||
ObjectStore::from_uri_and_params(&plain_uri, &os_params).await?;
|
||||
if object_store.is_local() {
|
||||
Self::try_create_dir(&plain_uri).context(CreateDirSnafu { path: plain_uri })?;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user