mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-24 15:30:38 +00:00
expose awsRegion to be configurable (#441)
This commit is contained in:
@@ -6,7 +6,7 @@ members = [
|
||||
resolver = "2"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = "=0.6.2"
|
||||
lance = "=0.6.3"
|
||||
arrow-array = "43.0"
|
||||
arrow-data = "43.0"
|
||||
arrow-schema = "43.0"
|
||||
|
||||
@@ -42,6 +42,8 @@ export interface ConnectionOptions {
|
||||
|
||||
awsCredentials?: AwsCredentials
|
||||
|
||||
awsRegion?: string
|
||||
|
||||
// API key for the remote connections
|
||||
apiKey?: string
|
||||
// Region to connect
|
||||
@@ -51,6 +53,23 @@ export interface ConnectionOptions {
|
||||
hostOverride?: string
|
||||
}
|
||||
|
||||
function getAwsArgs (opts: ConnectionOptions): any[] {
|
||||
const callArgs = []
|
||||
const awsCredentials = opts.awsCredentials
|
||||
if (awsCredentials !== undefined) {
|
||||
callArgs.push(awsCredentials.accessKeyId)
|
||||
callArgs.push(awsCredentials.secretKey)
|
||||
callArgs.push(awsCredentials.sessionToken)
|
||||
} else {
|
||||
callArgs.push(undefined)
|
||||
callArgs.push(undefined)
|
||||
callArgs.push(undefined)
|
||||
}
|
||||
|
||||
callArgs.push(opts.awsRegion)
|
||||
return callArgs
|
||||
}
|
||||
|
||||
export interface CreateTableOptions<T> {
|
||||
// Name of Table
|
||||
name: string
|
||||
@@ -282,7 +301,7 @@ export class LocalConnection implements Connection {
|
||||
async openTable<T> (name: string, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
|
||||
const tbl = await databaseOpenTable.call(this._db, name, ...this.awsParams())
|
||||
const tbl = await databaseOpenTable.call(this._db, name, ...getAwsArgs(this._options()))
|
||||
if (embeddings !== undefined) {
|
||||
return new LocalTable(tbl, name, this._options(), embeddings)
|
||||
} else {
|
||||
@@ -336,7 +355,7 @@ export class LocalConnection implements Connection {
|
||||
buffer = await fromRecordsToBuffer(data, embeddingFunction)
|
||||
}
|
||||
|
||||
const tbl = await tableCreate.call(this._db, name, buffer, writeOptions?.writeMode?.toString(), ...this.awsParams())
|
||||
const tbl = await tableCreate.call(this._db, name, buffer, writeOptions?.writeMode?.toString(), ...getAwsArgs(this._options()))
|
||||
if (embeddingFunction !== undefined) {
|
||||
return new LocalTable(tbl, name, this._options(), embeddingFunction)
|
||||
} else {
|
||||
@@ -344,20 +363,6 @@ export class LocalConnection implements Connection {
|
||||
}
|
||||
}
|
||||
|
||||
private awsParams (): any[] {
|
||||
// TODO: move this thing into rust
|
||||
const awsCredentials = this._options().awsCredentials
|
||||
const params = []
|
||||
if (awsCredentials !== undefined) {
|
||||
params.push(awsCredentials.accessKeyId)
|
||||
params.push(awsCredentials.secretKey)
|
||||
if (awsCredentials.sessionToken !== undefined) {
|
||||
params.push(awsCredentials.sessionToken)
|
||||
}
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
/**
|
||||
* Drop an existing table.
|
||||
* @param name The name of the table to drop.
|
||||
@@ -407,16 +412,12 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
* @return The number of rows added to the table
|
||||
*/
|
||||
async add (data: Array<Record<string, unknown>>): Promise<number> {
|
||||
const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString()]
|
||||
const awsCredentials = this._options().awsCredentials
|
||||
if (awsCredentials !== undefined) {
|
||||
callArgs.push(awsCredentials.accessKeyId)
|
||||
callArgs.push(awsCredentials.secretKey)
|
||||
if (awsCredentials.sessionToken !== undefined) {
|
||||
callArgs.push(awsCredentials.sessionToken)
|
||||
}
|
||||
}
|
||||
return tableAdd.call(...callArgs).then((newTable: any) => { this._tbl = newTable })
|
||||
return tableAdd.call(
|
||||
this._tbl,
|
||||
await fromRecordsToBuffer(data, this._embeddings),
|
||||
WriteMode.Append.toString(),
|
||||
...getAwsArgs(this._options())
|
||||
).then((newTable: any) => { this._tbl = newTable })
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -426,16 +427,12 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
* @return The number of rows added to the table
|
||||
*/
|
||||
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
|
||||
const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString()]
|
||||
const awsCredentials = this._options().awsCredentials
|
||||
if (awsCredentials !== undefined) {
|
||||
callArgs.push(awsCredentials.accessKeyId)
|
||||
callArgs.push(awsCredentials.secretKey)
|
||||
if (awsCredentials.sessionToken !== undefined) {
|
||||
callArgs.push(awsCredentials.sessionToken)
|
||||
}
|
||||
}
|
||||
return tableAdd.call(...callArgs).then((newTable: any) => { this._tbl = newTable })
|
||||
return tableAdd.call(
|
||||
this._tbl,
|
||||
await fromRecordsToBuffer(data, this._embeddings),
|
||||
WriteMode.Overwrite.toString(),
|
||||
...getAwsArgs(this._options())
|
||||
).then((newTable: any) => { this._tbl = newTable })
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -121,26 +121,28 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
Ok(promise)
|
||||
}
|
||||
|
||||
fn get_aws_creds<T>(
|
||||
/// Get AWS creds arguments from the context
|
||||
/// Consumes 3 arguments
|
||||
fn get_aws_creds(
|
||||
cx: &mut FunctionContext,
|
||||
arg_starting_location: i32,
|
||||
) -> Result<Option<AwsCredentialProvider>, NeonResult<T>> {
|
||||
) -> NeonResult<Option<AwsCredentialProvider>> {
|
||||
let secret_key_id = cx
|
||||
.argument_opt(arg_starting_location)
|
||||
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
|
||||
.flatten()
|
||||
.filter(|arg| arg.is_a::<JsString, _>(cx))
|
||||
.and_then(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
|
||||
.map(|v| v.value(cx));
|
||||
|
||||
let secret_key = cx
|
||||
.argument_opt(arg_starting_location + 1)
|
||||
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
|
||||
.flatten()
|
||||
.filter(|arg| arg.is_a::<JsString, _>(cx))
|
||||
.and_then(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
|
||||
.map(|v| v.value(cx));
|
||||
|
||||
let temp_token = cx
|
||||
.argument_opt(arg_starting_location + 2)
|
||||
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
|
||||
.flatten()
|
||||
.filter(|arg| arg.is_a::<JsString, _>(cx))
|
||||
.and_then(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
|
||||
.map(|v| v.value(cx));
|
||||
|
||||
match (secret_key_id, secret_key, temp_token) {
|
||||
@@ -152,7 +154,21 @@ fn get_aws_creds<T>(
|
||||
}),
|
||||
))),
|
||||
(None, None, None) => Ok(None),
|
||||
_ => Err(cx.throw_error("Invalid credentials configuration")),
|
||||
_ => cx.throw_error("Invalid credentials configuration"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get AWS region arguments from the context
|
||||
fn get_aws_region(cx: &mut FunctionContext, arg_location: i32) -> NeonResult<Option<String>> {
|
||||
let region = cx
|
||||
.argument_opt(arg_location)
|
||||
.filter(|arg| arg.is_a::<JsString, _>(cx))
|
||||
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx));
|
||||
|
||||
match region {
|
||||
Some(Ok(region)) => Ok(Some(region.value(cx))),
|
||||
None => Ok(None),
|
||||
Some(Err(e)) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,14 +178,14 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
|
||||
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
|
||||
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
|
||||
|
||||
let aws_creds = match get_aws_creds(&mut cx, 1) {
|
||||
Ok(creds) => creds,
|
||||
Err(err) => return err,
|
||||
};
|
||||
let aws_creds = get_aws_creds(&mut cx, 1)?;
|
||||
|
||||
let aws_region = get_aws_region(&mut cx, 4)?;
|
||||
|
||||
let params = ReadParams {
|
||||
store_options: Some(ObjectStoreParams {
|
||||
aws_credentials: aws_creds,
|
||||
aws_region,
|
||||
..ObjectStoreParams::default()
|
||||
}),
|
||||
..ReadParams::default()
|
||||
|
||||
@@ -48,7 +48,10 @@ impl JsQuery {
|
||||
.map(|s| s.value(&mut cx))
|
||||
.map(|s| MetricType::try_from(s.as_str()).unwrap());
|
||||
|
||||
let is_electron = cx.argument::<JsBoolean>(1).or_throw(&mut cx)?.value(&mut cx);
|
||||
let is_electron = cx
|
||||
.argument::<JsBoolean>(1)
|
||||
.or_throw(&mut cx)?
|
||||
.value(&mut cx);
|
||||
|
||||
let rt = runtime(&mut cx)?;
|
||||
|
||||
@@ -86,7 +89,11 @@ impl JsQuery {
|
||||
}
|
||||
|
||||
// Creates a new JsBuffer from a rust buffer with a special logic for electron
|
||||
fn new_js_buffer<'a>(buffer: Vec<u8>, cx: &mut TaskContext<'a>, is_electron: bool) -> NeonResult<Handle<'a, JsBuffer>> {
|
||||
fn new_js_buffer<'a>(
|
||||
buffer: Vec<u8>,
|
||||
cx: &mut TaskContext<'a>,
|
||||
is_electron: bool,
|
||||
) -> NeonResult<Handle<'a, JsBuffer>> {
|
||||
if is_electron {
|
||||
// Electron does not support `external`: https://github.com/neon-bindings/neon/pull/937
|
||||
let mut js_buffer = JsBuffer::new(cx, buffer.len()).or_throw(cx)?;
|
||||
|
||||
@@ -22,7 +22,7 @@ use neon::types::buffer::TypedArray;
|
||||
use vectordb::Table;
|
||||
|
||||
use crate::error::ResultExt;
|
||||
use crate::{get_aws_creds, runtime, JsDatabase};
|
||||
use crate::{get_aws_creds, get_aws_region, runtime, JsDatabase};
|
||||
|
||||
pub(crate) struct JsTable {
|
||||
pub table: Table,
|
||||
@@ -61,14 +61,13 @@ impl JsTable {
|
||||
let (deferred, promise) = cx.promise();
|
||||
let database = db.database.clone();
|
||||
|
||||
let aws_creds = match get_aws_creds(&mut cx, 3) {
|
||||
Ok(creds) => creds,
|
||||
Err(err) => return err,
|
||||
};
|
||||
let aws_creds = get_aws_creds(&mut cx, 3)?;
|
||||
let aws_region = get_aws_region(&mut cx, 6)?;
|
||||
|
||||
let params = WriteParams {
|
||||
store_params: Some(ObjectStoreParams {
|
||||
aws_credentials: aws_creds,
|
||||
aws_region,
|
||||
..ObjectStoreParams::default()
|
||||
}),
|
||||
mode: mode,
|
||||
@@ -105,14 +104,13 @@ impl JsTable {
|
||||
"overwrite" => WriteMode::Overwrite,
|
||||
s => return cx.throw_error(format!("invalid write mode {}", s)),
|
||||
};
|
||||
let aws_creds = match get_aws_creds(&mut cx, 2) {
|
||||
Ok(creds) => creds,
|
||||
Err(err) => return err,
|
||||
};
|
||||
let aws_creds = get_aws_creds(&mut cx, 2)?;
|
||||
let aws_region = get_aws_region(&mut cx, 5)?;
|
||||
|
||||
let params = WriteParams {
|
||||
store_params: Some(ObjectStoreParams {
|
||||
aws_credentials: aws_creds,
|
||||
aws_region,
|
||||
..ObjectStoreParams::default()
|
||||
}),
|
||||
mode: write_mode,
|
||||
|
||||
Reference in New Issue
Block a user