From 30f5bc5865c7913ce30d804a39591951520f5e3a Mon Sep 17 00:00:00 2001 From: Rob Meng Date: Tue, 22 Aug 2023 16:00:14 -0400 Subject: [PATCH] expose awsRegion to be configurable (#441) --- Cargo.toml | 2 +- node/src/index.ts | 69 ++++++++++++++++++-------------------- rust/ffi/node/src/lib.rs | 42 ++++++++++++++++------- rust/ffi/node/src/query.rs | 11 ++++-- rust/ffi/node/src/table.rs | 16 ++++----- 5 files changed, 79 insertions(+), 61 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0ef5d7dd..cfb6a652 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ members = [ resolver = "2" [workspace.dependencies] -lance = "=0.6.2" +lance = "=0.6.3" arrow-array = "43.0" arrow-data = "43.0" arrow-schema = "43.0" diff --git a/node/src/index.ts b/node/src/index.ts index fb88725c..30dff810 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -42,6 +42,8 @@ export interface ConnectionOptions { awsCredentials?: AwsCredentials + awsRegion?: string + // API key for the remote connections apiKey?: string // Region to connect @@ -51,6 +53,23 @@ export interface ConnectionOptions { hostOverride?: string } +function getAwsArgs (opts: ConnectionOptions): any[] { + const callArgs = [] + const awsCredentials = opts.awsCredentials + if (awsCredentials !== undefined) { + callArgs.push(awsCredentials.accessKeyId) + callArgs.push(awsCredentials.secretKey) + callArgs.push(awsCredentials.sessionToken) + } else { + callArgs.push(undefined) + callArgs.push(undefined) + callArgs.push(undefined) + } + + callArgs.push(opts.awsRegion) + return callArgs +} + export interface CreateTableOptions { // Name of Table name: string @@ -282,7 +301,7 @@ export class LocalConnection implements Connection { async openTable (name: string, embeddings: EmbeddingFunction): Promise> async openTable (name: string, embeddings?: EmbeddingFunction): Promise> async openTable (name: string, embeddings?: EmbeddingFunction): Promise> { - const tbl = await databaseOpenTable.call(this._db, name, ...this.awsParams()) + const tbl = await databaseOpenTable.call(this._db, name, ...getAwsArgs(this._options())) if (embeddings !== undefined) { return new LocalTable(tbl, name, this._options(), embeddings) } else { @@ -336,7 +355,7 @@ export class LocalConnection implements Connection { buffer = await fromRecordsToBuffer(data, embeddingFunction) } - const tbl = await tableCreate.call(this._db, name, buffer, writeOptions?.writeMode?.toString(), ...this.awsParams()) + const tbl = await tableCreate.call(this._db, name, buffer, writeOptions?.writeMode?.toString(), ...getAwsArgs(this._options())) if (embeddingFunction !== undefined) { return new LocalTable(tbl, name, this._options(), embeddingFunction) } else { @@ -344,20 +363,6 @@ export class LocalConnection implements Connection { } } - private awsParams (): any[] { - // TODO: move this thing into rust - const awsCredentials = this._options().awsCredentials - const params = [] - if (awsCredentials !== undefined) { - params.push(awsCredentials.accessKeyId) - params.push(awsCredentials.secretKey) - if (awsCredentials.sessionToken !== undefined) { - params.push(awsCredentials.sessionToken) - } - } - return params - } - /** * Drop an existing table. * @param name The name of the table to drop. @@ -407,16 +412,12 @@ export class LocalTable implements Table { * @return The number of rows added to the table */ async add (data: Array>): Promise { - const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString()] - const awsCredentials = this._options().awsCredentials - if (awsCredentials !== undefined) { - callArgs.push(awsCredentials.accessKeyId) - callArgs.push(awsCredentials.secretKey) - if (awsCredentials.sessionToken !== undefined) { - callArgs.push(awsCredentials.sessionToken) - } - } - return tableAdd.call(...callArgs).then((newTable: any) => { this._tbl = newTable }) + return tableAdd.call( + this._tbl, + await fromRecordsToBuffer(data, this._embeddings), + WriteMode.Append.toString(), + ...getAwsArgs(this._options()) + ).then((newTable: any) => { this._tbl = newTable }) } /** @@ -426,16 +427,12 @@ export class LocalTable implements Table { * @return The number of rows added to the table */ async overwrite (data: Array>): Promise { - const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString()] - const awsCredentials = this._options().awsCredentials - if (awsCredentials !== undefined) { - callArgs.push(awsCredentials.accessKeyId) - callArgs.push(awsCredentials.secretKey) - if (awsCredentials.sessionToken !== undefined) { - callArgs.push(awsCredentials.sessionToken) - } - } - return tableAdd.call(...callArgs).then((newTable: any) => { this._tbl = newTable }) + return tableAdd.call( + this._tbl, + await fromRecordsToBuffer(data, this._embeddings), + WriteMode.Overwrite.toString(), + ...getAwsArgs(this._options()) + ).then((newTable: any) => { this._tbl = newTable }) } /** diff --git a/rust/ffi/node/src/lib.rs b/rust/ffi/node/src/lib.rs index 64fa0f84..6791bc69 100644 --- a/rust/ffi/node/src/lib.rs +++ b/rust/ffi/node/src/lib.rs @@ -121,26 +121,28 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult { Ok(promise) } -fn get_aws_creds( +/// Get AWS creds arguments from the context +/// Consumes 3 arguments +fn get_aws_creds( cx: &mut FunctionContext, arg_starting_location: i32, -) -> Result, NeonResult> { +) -> NeonResult> { let secret_key_id = cx .argument_opt(arg_starting_location) - .map(|arg| arg.downcast_or_throw::(cx).ok()) - .flatten() + .filter(|arg| arg.is_a::(cx)) + .and_then(|arg| arg.downcast_or_throw::(cx).ok()) .map(|v| v.value(cx)); let secret_key = cx .argument_opt(arg_starting_location + 1) - .map(|arg| arg.downcast_or_throw::(cx).ok()) - .flatten() + .filter(|arg| arg.is_a::(cx)) + .and_then(|arg| arg.downcast_or_throw::(cx).ok()) .map(|v| v.value(cx)); let temp_token = cx .argument_opt(arg_starting_location + 2) - .map(|arg| arg.downcast_or_throw::(cx).ok()) - .flatten() + .filter(|arg| arg.is_a::(cx)) + .and_then(|arg| arg.downcast_or_throw::(cx).ok()) .map(|v| v.value(cx)); match (secret_key_id, secret_key, temp_token) { @@ -152,7 +154,21 @@ fn get_aws_creds( }), ))), (None, None, None) => Ok(None), - _ => Err(cx.throw_error("Invalid credentials configuration")), + _ => cx.throw_error("Invalid credentials configuration"), + } +} + +/// Get AWS region arguments from the context +fn get_aws_region(cx: &mut FunctionContext, arg_location: i32) -> NeonResult> { + let region = cx + .argument_opt(arg_location) + .filter(|arg| arg.is_a::(cx)) + .map(|arg| arg.downcast_or_throw::(cx)); + + match region { + Some(Ok(region)) => Ok(Some(region.value(cx))), + None => Ok(None), + Some(Err(e)) => Err(e), } } @@ -162,14 +178,14 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult { .downcast_or_throw::, _>(&mut cx)?; let table_name = cx.argument::(0)?.value(&mut cx); - let aws_creds = match get_aws_creds(&mut cx, 1) { - Ok(creds) => creds, - Err(err) => return err, - }; + let aws_creds = get_aws_creds(&mut cx, 1)?; + + let aws_region = get_aws_region(&mut cx, 4)?; let params = ReadParams { store_options: Some(ObjectStoreParams { aws_credentials: aws_creds, + aws_region, ..ObjectStoreParams::default() }), ..ReadParams::default() diff --git a/rust/ffi/node/src/query.rs b/rust/ffi/node/src/query.rs index 4bab78c9..94923e09 100644 --- a/rust/ffi/node/src/query.rs +++ b/rust/ffi/node/src/query.rs @@ -48,7 +48,10 @@ impl JsQuery { .map(|s| s.value(&mut cx)) .map(|s| MetricType::try_from(s.as_str()).unwrap()); - let is_electron = cx.argument::(1).or_throw(&mut cx)?.value(&mut cx); + let is_electron = cx + .argument::(1) + .or_throw(&mut cx)? + .value(&mut cx); let rt = runtime(&mut cx)?; @@ -86,7 +89,11 @@ impl JsQuery { } // Creates a new JsBuffer from a rust buffer with a special logic for electron - fn new_js_buffer<'a>(buffer: Vec, cx: &mut TaskContext<'a>, is_electron: bool) -> NeonResult> { + fn new_js_buffer<'a>( + buffer: Vec, + cx: &mut TaskContext<'a>, + is_electron: bool, + ) -> NeonResult> { if is_electron { // Electron does not support `external`: https://github.com/neon-bindings/neon/pull/937 let mut js_buffer = JsBuffer::new(cx, buffer.len()).or_throw(cx)?; diff --git a/rust/ffi/node/src/table.rs b/rust/ffi/node/src/table.rs index 7a96bbd8..0d7bf3e3 100644 --- a/rust/ffi/node/src/table.rs +++ b/rust/ffi/node/src/table.rs @@ -22,7 +22,7 @@ use neon::types::buffer::TypedArray; use vectordb::Table; use crate::error::ResultExt; -use crate::{get_aws_creds, runtime, JsDatabase}; +use crate::{get_aws_creds, get_aws_region, runtime, JsDatabase}; pub(crate) struct JsTable { pub table: Table, @@ -61,14 +61,13 @@ impl JsTable { let (deferred, promise) = cx.promise(); let database = db.database.clone(); - let aws_creds = match get_aws_creds(&mut cx, 3) { - Ok(creds) => creds, - Err(err) => return err, - }; + let aws_creds = get_aws_creds(&mut cx, 3)?; + let aws_region = get_aws_region(&mut cx, 6)?; let params = WriteParams { store_params: Some(ObjectStoreParams { aws_credentials: aws_creds, + aws_region, ..ObjectStoreParams::default() }), mode: mode, @@ -105,14 +104,13 @@ impl JsTable { "overwrite" => WriteMode::Overwrite, s => return cx.throw_error(format!("invalid write mode {}", s)), }; - let aws_creds = match get_aws_creds(&mut cx, 2) { - Ok(creds) => creds, - Err(err) => return err, - }; + let aws_creds = get_aws_creds(&mut cx, 2)?; + let aws_region = get_aws_region(&mut cx, 5)?; let params = WriteParams { store_params: Some(ObjectStoreParams { aws_credentials: aws_creds, + aws_region, ..ObjectStoreParams::default() }), mode: write_mode,