expose awsRegion to be configurable (#441)

This commit is contained in:
Rob Meng
2023-08-22 16:00:14 -04:00
committed by GitHub
parent 2737315cb2
commit 30f5bc5865
5 changed files with 79 additions and 61 deletions

View File

@@ -6,7 +6,7 @@ members = [
resolver = "2"
[workspace.dependencies]
lance = "=0.6.2"
lance = "=0.6.3"
arrow-array = "43.0"
arrow-data = "43.0"
arrow-schema = "43.0"

View File

@@ -42,6 +42,8 @@ export interface ConnectionOptions {
awsCredentials?: AwsCredentials
awsRegion?: string
// API key for the remote connections
apiKey?: string
// Region to connect
@@ -51,6 +53,23 @@ export interface ConnectionOptions {
hostOverride?: string
}
function getAwsArgs (opts: ConnectionOptions): any[] {
const callArgs = []
const awsCredentials = opts.awsCredentials
if (awsCredentials !== undefined) {
callArgs.push(awsCredentials.accessKeyId)
callArgs.push(awsCredentials.secretKey)
callArgs.push(awsCredentials.sessionToken)
} else {
callArgs.push(undefined)
callArgs.push(undefined)
callArgs.push(undefined)
}
callArgs.push(opts.awsRegion)
return callArgs
}
export interface CreateTableOptions<T> {
// Name of Table
name: string
@@ -282,7 +301,7 @@ export class LocalConnection implements Connection {
async openTable<T> (name: string, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>>
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
const tbl = await databaseOpenTable.call(this._db, name, ...this.awsParams())
const tbl = await databaseOpenTable.call(this._db, name, ...getAwsArgs(this._options()))
if (embeddings !== undefined) {
return new LocalTable(tbl, name, this._options(), embeddings)
} else {
@@ -336,7 +355,7 @@ export class LocalConnection implements Connection {
buffer = await fromRecordsToBuffer(data, embeddingFunction)
}
const tbl = await tableCreate.call(this._db, name, buffer, writeOptions?.writeMode?.toString(), ...this.awsParams())
const tbl = await tableCreate.call(this._db, name, buffer, writeOptions?.writeMode?.toString(), ...getAwsArgs(this._options()))
if (embeddingFunction !== undefined) {
return new LocalTable(tbl, name, this._options(), embeddingFunction)
} else {
@@ -344,20 +363,6 @@ export class LocalConnection implements Connection {
}
}
private awsParams (): any[] {
// TODO: move this thing into rust
const awsCredentials = this._options().awsCredentials
const params = []
if (awsCredentials !== undefined) {
params.push(awsCredentials.accessKeyId)
params.push(awsCredentials.secretKey)
if (awsCredentials.sessionToken !== undefined) {
params.push(awsCredentials.sessionToken)
}
}
return params
}
/**
* Drop an existing table.
* @param name The name of the table to drop.
@@ -407,16 +412,12 @@ export class LocalTable<T = number[]> implements Table<T> {
* @return The number of rows added to the table
*/
async add (data: Array<Record<string, unknown>>): Promise<number> {
const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString()]
const awsCredentials = this._options().awsCredentials
if (awsCredentials !== undefined) {
callArgs.push(awsCredentials.accessKeyId)
callArgs.push(awsCredentials.secretKey)
if (awsCredentials.sessionToken !== undefined) {
callArgs.push(awsCredentials.sessionToken)
}
}
return tableAdd.call(...callArgs).then((newTable: any) => { this._tbl = newTable })
return tableAdd.call(
this._tbl,
await fromRecordsToBuffer(data, this._embeddings),
WriteMode.Append.toString(),
...getAwsArgs(this._options())
).then((newTable: any) => { this._tbl = newTable })
}
/**
@@ -426,16 +427,12 @@ export class LocalTable<T = number[]> implements Table<T> {
* @return The number of rows added to the table
*/
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString()]
const awsCredentials = this._options().awsCredentials
if (awsCredentials !== undefined) {
callArgs.push(awsCredentials.accessKeyId)
callArgs.push(awsCredentials.secretKey)
if (awsCredentials.sessionToken !== undefined) {
callArgs.push(awsCredentials.sessionToken)
}
}
return tableAdd.call(...callArgs).then((newTable: any) => { this._tbl = newTable })
return tableAdd.call(
this._tbl,
await fromRecordsToBuffer(data, this._embeddings),
WriteMode.Overwrite.toString(),
...getAwsArgs(this._options())
).then((newTable: any) => { this._tbl = newTable })
}
/**

View File

@@ -121,26 +121,28 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult<JsPromise> {
Ok(promise)
}
fn get_aws_creds<T>(
/// Get AWS creds arguments from the context
/// Consumes 3 arguments
fn get_aws_creds(
cx: &mut FunctionContext,
arg_starting_location: i32,
) -> Result<Option<AwsCredentialProvider>, NeonResult<T>> {
) -> NeonResult<Option<AwsCredentialProvider>> {
let secret_key_id = cx
.argument_opt(arg_starting_location)
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
.flatten()
.filter(|arg| arg.is_a::<JsString, _>(cx))
.and_then(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
.map(|v| v.value(cx));
let secret_key = cx
.argument_opt(arg_starting_location + 1)
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
.flatten()
.filter(|arg| arg.is_a::<JsString, _>(cx))
.and_then(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
.map(|v| v.value(cx));
let temp_token = cx
.argument_opt(arg_starting_location + 2)
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
.flatten()
.filter(|arg| arg.is_a::<JsString, _>(cx))
.and_then(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
.map(|v| v.value(cx));
match (secret_key_id, secret_key, temp_token) {
@@ -152,7 +154,21 @@ fn get_aws_creds<T>(
}),
))),
(None, None, None) => Ok(None),
_ => Err(cx.throw_error("Invalid credentials configuration")),
_ => cx.throw_error("Invalid credentials configuration"),
}
}
/// Get AWS region arguments from the context
fn get_aws_region(cx: &mut FunctionContext, arg_location: i32) -> NeonResult<Option<String>> {
let region = cx
.argument_opt(arg_location)
.filter(|arg| arg.is_a::<JsString, _>(cx))
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx));
match region {
Some(Ok(region)) => Ok(Some(region.value(cx))),
None => Ok(None),
Some(Err(e)) => Err(e),
}
}
@@ -162,14 +178,14 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
let aws_creds = match get_aws_creds(&mut cx, 1) {
Ok(creds) => creds,
Err(err) => return err,
};
let aws_creds = get_aws_creds(&mut cx, 1)?;
let aws_region = get_aws_region(&mut cx, 4)?;
let params = ReadParams {
store_options: Some(ObjectStoreParams {
aws_credentials: aws_creds,
aws_region,
..ObjectStoreParams::default()
}),
..ReadParams::default()

View File

@@ -48,7 +48,10 @@ impl JsQuery {
.map(|s| s.value(&mut cx))
.map(|s| MetricType::try_from(s.as_str()).unwrap());
let is_electron = cx.argument::<JsBoolean>(1).or_throw(&mut cx)?.value(&mut cx);
let is_electron = cx
.argument::<JsBoolean>(1)
.or_throw(&mut cx)?
.value(&mut cx);
let rt = runtime(&mut cx)?;
@@ -86,7 +89,11 @@ impl JsQuery {
}
// Creates a new JsBuffer from a rust buffer with a special logic for electron
fn new_js_buffer<'a>(buffer: Vec<u8>, cx: &mut TaskContext<'a>, is_electron: bool) -> NeonResult<Handle<'a, JsBuffer>> {
fn new_js_buffer<'a>(
buffer: Vec<u8>,
cx: &mut TaskContext<'a>,
is_electron: bool,
) -> NeonResult<Handle<'a, JsBuffer>> {
if is_electron {
// Electron does not support `external`: https://github.com/neon-bindings/neon/pull/937
let mut js_buffer = JsBuffer::new(cx, buffer.len()).or_throw(cx)?;

View File

@@ -22,7 +22,7 @@ use neon::types::buffer::TypedArray;
use vectordb::Table;
use crate::error::ResultExt;
use crate::{get_aws_creds, runtime, JsDatabase};
use crate::{get_aws_creds, get_aws_region, runtime, JsDatabase};
pub(crate) struct JsTable {
pub table: Table,
@@ -61,14 +61,13 @@ impl JsTable {
let (deferred, promise) = cx.promise();
let database = db.database.clone();
let aws_creds = match get_aws_creds(&mut cx, 3) {
Ok(creds) => creds,
Err(err) => return err,
};
let aws_creds = get_aws_creds(&mut cx, 3)?;
let aws_region = get_aws_region(&mut cx, 6)?;
let params = WriteParams {
store_params: Some(ObjectStoreParams {
aws_credentials: aws_creds,
aws_region,
..ObjectStoreParams::default()
}),
mode: mode,
@@ -105,14 +104,13 @@ impl JsTable {
"overwrite" => WriteMode::Overwrite,
s => return cx.throw_error(format!("invalid write mode {}", s)),
};
let aws_creds = match get_aws_creds(&mut cx, 2) {
Ok(creds) => creds,
Err(err) => return err,
};
let aws_creds = get_aws_creds(&mut cx, 2)?;
let aws_region = get_aws_region(&mut cx, 5)?;
let params = WriteParams {
store_params: Some(ObjectStoreParams {
aws_credentials: aws_creds,
aws_region,
..ObjectStoreParams::default()
}),
mode: write_mode,