Upgrade lance to 0.5.5, and plumb thru new features from the upgrade (#279)

* upgrade
* fixes for the upgrade
* allow JS users to pass custom AWS credentials
This commit is contained in:
Rob Meng
2023-07-11 16:33:39 -04:00
committed by GitHub
parent 80c25f9896
commit ace6aa883a
9 changed files with 226 additions and 81 deletions

View File

@@ -6,9 +6,9 @@ members = [
resolver = "2"
[workspace.dependencies]
lance = "=0.5.3"
arrow-array = "40.0"
arrow-data = "40.0"
arrow-schema = "40.0"
arrow-ipc = "40.0"
lance = "=0.5.5"
arrow-array = "42.0"
arrow-data = "42.0"
arrow-schema = "42.0"
arrow-ipc = "42.0"
object_store = "0.6.1"

View File

@@ -1,12 +1,12 @@
{
"name": "vectordb",
"version": "0.1.9",
"version": "0.1.10",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "vectordb",
"version": "0.1.9",
"version": "0.1.10",
"license": "Apache-2.0",
"dependencies": {
"@apache-arrow/ts": "^12.0.0",

View File

@@ -122,6 +122,14 @@ export interface Table<T = number[]> {
delete: (filter: string) => Promise<void>
}
export interface AwsCredentials {
accessKeyId: string
secretKey: string
sessionToken?: string
}
/**
* A connection to a LanceDB database.
*/
@@ -186,16 +194,23 @@ export class LocalConnection implements Connection {
* @param embeddings An embedding function to use on this Table
*/
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>> {
if (mode === undefined) {
mode = WriteMode.Create
}
const tbl = await tableCreate.call(this._db, name, await fromRecordsToBuffer(data, embeddings), mode.toLowerCase())
if (embeddings !== undefined) {
return new LocalTable(tbl, name, embeddings)
} else {
return new LocalTable(tbl, name)
const createArgs = [this._db, name, await fromRecordsToBuffer(data, embeddings), mode.toLowerCase()]
if (awsCredentials !== undefined) {
createArgs.push(awsCredentials.accessKeyId)
createArgs.push(awsCredentials.secretKey)
if (awsCredentials.sessionToken !== undefined) {
createArgs.push(awsCredentials.sessionToken)
}
}
const tbl = await tableCreate.call(...createArgs)
return new LocalTable(tbl, name, embeddings, awsCredentials)
}
async createTableArrow (name: string, table: ArrowTable): Promise<Table> {
@@ -217,6 +232,7 @@ export class LocalTable<T = number[]> implements Table<T> {
private readonly _tbl: any
private readonly _name: string
private readonly _embeddings?: EmbeddingFunction<T>
private readonly _awsCredentials?: AwsCredentials
constructor (tbl: any, name: string)
/**
@@ -225,10 +241,12 @@ export class LocalTable<T = number[]> implements Table<T> {
* @param embeddings An embedding function to use when interacting with this table
*/
constructor (tbl: any, name: string, embeddings: EmbeddingFunction<T>)
constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>) {
constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials)
constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials) {
this._tbl = tbl
this._name = name
this._embeddings = embeddings
this._awsCredentials = awsCredentials
}
get name (): string {
@@ -250,7 +268,15 @@ export class LocalTable<T = number[]> implements Table<T> {
* @return The number of rows added to the table
*/
async add (data: Array<Record<string, unknown>>): Promise<number> {
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString())
const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString()]
if (this._awsCredentials !== undefined) {
callArgs.push(this._awsCredentials.accessKeyId)
callArgs.push(this._awsCredentials.secretKey)
if (this._awsCredentials.sessionToken !== undefined) {
callArgs.push(this._awsCredentials.sessionToken)
}
}
return tableAdd.call(...callArgs)
}
/**
@@ -260,6 +286,14 @@ export class LocalTable<T = number[]> implements Table<T> {
* @return The number of rows added to the table
*/
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString()]
if (this._awsCredentials !== undefined) {
callArgs.push(this._awsCredentials.accessKeyId)
callArgs.push(this._awsCredentials.secretKey)
if (this._awsCredentials.sessionToken !== undefined) {
callArgs.push(this._awsCredentials.sessionToken)
}
}
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString())
}

View File

@@ -19,3 +19,5 @@ lance = { workspace = true }
vectordb = { path = "../../vectordb" }
tokio = { version = "1.23", features = ["rt-multi-thread"] }
neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] }
object_store = { workspace = true, features = ["aws"] }
async-trait = "0"

View File

@@ -13,7 +13,6 @@
// limitations under the License.
use std::io::Cursor;
use std::ops::Deref;
use std::sync::Arc;
use arrow_array::cast::as_list_array;
@@ -25,10 +24,13 @@ use lance::arrow::{FixedSizeListArrayExt, RecordBatchExt};
pub(crate) fn convert_record_batch(record_batch: RecordBatch) -> RecordBatch {
let column = record_batch
.column_by_name("vector")
.cloned()
.expect("vector column is missing");
let arr = as_list_array(column.deref());
// TODO: we should just consume the underlaying js buffer in the future instead of this arrow around a bunch of times
let arr = as_list_array(column.as_ref());
let list_size = arr.values().len() / record_batch.num_rows();
let r = FixedSizeListArray::try_new(arr.values(), list_size as i32).unwrap();
let r =
FixedSizeListArray::try_new_from_values(arr.values().to_owned(), list_size as i32).unwrap();
let schema = Arc::new(Schema::new(vec![Field::new(
"vector",

View File

@@ -17,19 +17,23 @@ use std::convert::TryFrom;
use std::ops::Deref;
use std::sync::{Arc, Mutex};
use arrow_array::{Float32Array, RecordBatchIterator, RecordBatchReader};
use arrow_array::{Float32Array, RecordBatchIterator};
use arrow_ipc::writer::FileWriter;
use async_trait::async_trait;
use futures::{TryFutureExt, TryStreamExt};
use lance::dataset::{WriteMode, WriteParams};
use lance::dataset::{ReadParams, WriteMode, WriteParams};
use lance::index::vector::MetricType;
use lance::io::object_store::ObjectStoreParams;
use neon::prelude::*;
use neon::types::buffer::TypedArray;
use object_store::aws::{AwsCredential, AwsCredentialProvider};
use object_store::CredentialProvider;
use once_cell::sync::OnceCell;
use tokio::runtime::Runtime;
use vectordb::database::Database;
use vectordb::error::Error;
use vectordb::table::Table;
use vectordb::table::{OpenTableParams, Table};
use crate::arrow::arrow_buffer_to_record_batch;
@@ -49,6 +53,33 @@ struct JsTable {
impl Finalize for JsTable {}
// TODO: object_store didn't export this type so I copied it.
// Make a requiest to object_store to export this type
#[derive(Debug)]
pub struct StaticCredentialProvider<T> {
credential: Arc<T>,
}
impl<T> StaticCredentialProvider<T> {
pub fn new(credential: T) -> Self {
Self {
credential: Arc::new(credential),
}
}
}
#[async_trait]
impl<T> CredentialProvider for StaticCredentialProvider<T>
where
T: std::fmt::Debug + Send + Sync,
{
type Credential = T;
async fn get_credential(&self) -> object_store::Result<Arc<T>> {
Ok(Arc::clone(&self.credential))
}
}
fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
static RUNTIME: OnceCell<Runtime> = OnceCell::new();
@@ -97,19 +128,74 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult<JsPromise> {
Ok(promise)
}
fn get_aws_creds<T>(
cx: &mut FunctionContext,
arg_starting_location: i32,
) -> Result<Option<AwsCredentialProvider>, NeonResult<T>> {
let secret_key_id = cx
.argument_opt(arg_starting_location)
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
.flatten()
.map(|v| v.value(cx));
let secret_key = cx
.argument_opt(arg_starting_location + 1)
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
.flatten()
.map(|v| v.value(cx));
let temp_token = cx
.argument_opt(arg_starting_location + 2)
.map(|arg| arg.downcast_or_throw::<JsString, FunctionContext>(cx).ok())
.flatten()
.map(|v| v.value(cx));
match (secret_key_id, secret_key, temp_token) {
(Some(key_id), Some(key), optional_token) => Ok(Some(Arc::new(
StaticCredentialProvider::new(AwsCredential {
key_id: key_id,
secret_key: key,
token: optional_token,
}),
))),
(None, None, None) => Ok(None),
_ => Err(cx.throw_error("Invalid credentials configuration")),
}
}
fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
let db = cx
.this()
.downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?;
let table_name = cx.argument::<JsString>(0)?.value(&mut cx);
let aws_creds = match get_aws_creds(&mut cx, 1) {
Ok(creds) => creds,
Err(err) => return err,
};
let param = ReadParams {
store_options: Some(ObjectStoreParams {
aws_credentials: aws_creds,
..ObjectStoreParams::default()
}),
..ReadParams::default()
};
let rt = runtime(&mut cx)?;
let channel = cx.channel();
let database = db.database.clone();
let (deferred, promise) = cx.promise();
rt.spawn(async move {
let table_rst = database.open_table(&table_name).await;
let table_rst = database
.open_table_with_params(
&table_name,
OpenTableParams {
open_table_params: param,
},
)
.await;
deferred.settle_with(&channel, move |mut cx| {
let table = Arc::new(Mutex::new(
@@ -241,8 +327,6 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
"create" => WriteMode::Create,
_ => return cx.throw_error("Table::create only supports 'overwrite' and 'create' modes"),
};
let mut params = WriteParams::default();
params.mode = mode;
let rt = runtime(&mut cx)?;
let channel = cx.channel();
@@ -250,11 +334,22 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
let (deferred, promise) = cx.promise();
let database = db.database.clone();
let aws_creds = match get_aws_creds(&mut cx, 3) {
Ok(creds) => creds,
Err(err) => return err,
};
let params = WriteParams {
store_params: Some(ObjectStoreParams {
aws_credentials: aws_creds,
..ObjectStoreParams::default()
}),
mode: mode,
..WriteParams::default()
};
rt.block_on(async move {
let batch_reader: Box<dyn RecordBatchReader> = Box::new(RecordBatchIterator::new(
batches.into_iter().map(Ok),
schema,
));
let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let table_rst = database
.create_table(&table_name, batch_reader, Some(params))
.await;
@@ -289,16 +384,27 @@ fn table_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
let table = js_table.table.clone();
let write_mode = write_mode_map.get(write_mode.as_str()).cloned();
let aws_creds = match get_aws_creds(&mut cx, 2) {
Ok(creds) => creds,
Err(err) => return err,
};
let params = WriteParams {
store_params: Some(ObjectStoreParams {
aws_credentials: aws_creds,
..ObjectStoreParams::default()
}),
mode: write_mode.unwrap_or(WriteMode::Append),
..WriteParams::default()
};
rt.block_on(async move {
let batch_reader: Box<dyn RecordBatchReader> = Box::new(RecordBatchIterator::new(
batches.into_iter().map(Ok),
schema,
));
let add_result = table.lock().unwrap().add(batch_reader, write_mode).await;
let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema);
let add_result = table.lock().unwrap().add(batch_reader, Some(params)).await;
deferred.settle_with(&channel, move |mut cx| {
let added = add_result.or_else(|err| cx.throw_error(err.to_string()))?;
Ok(cx.number(added as f64))
let _added = add_result.or_else(|err| cx.throw_error(err.to_string()))?;
Ok(cx.boolean(true))
});
});
Ok(promise)

View File

@@ -100,7 +100,7 @@ impl Database {
pub async fn create_table(
&self,
name: &str,
batches: Box<dyn RecordBatchReader>,
batches: impl RecordBatchReader + Send + 'static,
params: Option<WriteParams>,
) -> Result<Table> {
Table::create(&self.uri, name, batches, params).await

View File

@@ -173,10 +173,8 @@ mod tests {
#[tokio::test]
async fn test_setters_getters() {
let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
let ds = Dataset::write(&mut batches, "memory://foo", None)
.await
.unwrap();
let batches = make_test_batches();
let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
let vector = Float32Array::from_iter_values([0.1, 0.2]);
let query = Query::new(Arc::new(ds), vector.clone());
@@ -202,10 +200,8 @@ mod tests {
#[tokio::test]
async fn test_execute() {
let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
let ds = Dataset::write(&mut batches, "memory://foo", None)
.await
.unwrap();
let batches = make_test_batches();
let ds = Dataset::write(batches, "memory://foo", None).await.unwrap();
let vector = Float32Array::from_iter_values([0.1; 128]);
let query = Query::new(Arc::new(ds), vector.clone());
@@ -213,7 +209,7 @@ mod tests {
assert_eq!(result.is_ok(), true);
}
fn make_test_batches() -> Box<dyn RecordBatchReader> {
fn make_test_batches() -> impl RecordBatchReader + Send + 'static {
let dim: usize = 128;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("key", DataType::Int32, false),
@@ -227,11 +223,11 @@ mod tests {
),
ArrowField::new("uri", DataType::Utf8, true),
]));
Box::new(RecordBatchIterator::new(
RecordBatchIterator::new(
vec![RecordBatch::new_empty(schema.clone())]
.into_iter()
.map(Ok),
schema,
))
)
}
}

View File

@@ -22,8 +22,8 @@ use snafu::prelude::*;
use crate::error::{Error, InvalidTableNameSnafu, Result};
use crate::index::vector::VectorIndexBuilder;
use crate::WriteMode;
use crate::query::Query;
use crate::WriteMode;
pub const VECTOR_COLUMN_NAME: &str = "vector";
pub const LANCE_FILE_EXTENSION: &str = "lance";
@@ -117,7 +117,7 @@ impl Table {
pub async fn create(
base_uri: &str,
name: &str,
mut batches: Box<dyn RecordBatchReader>,
batches: impl RecordBatchReader + Send + 'static,
params: Option<WriteParams>,
) -> Result<Self> {
let base_path = Path::new(base_uri);
@@ -127,7 +127,7 @@ impl Table {
.to_str()
.context(InvalidTableNameSnafu { name })?
.to_string();
let dataset = Dataset::write(&mut batches, &uri, params)
let dataset = Dataset::write(batches, &uri, params)
.await
.map_err(|e| match e {
lance::Error::DatasetAlreadyExists { .. } => Error::TableAlreadyExists {
@@ -176,14 +176,16 @@ impl Table {
/// * The number of rows added
pub async fn add(
&mut self,
mut batches: Box<dyn RecordBatchReader>,
write_mode: Option<WriteMode>,
) -> Result<usize> {
let mut params = WriteParams::default();
params.mode = write_mode.unwrap_or(WriteMode::Append);
batches: impl RecordBatchReader + Send + 'static,
params: Option<WriteParams>,
) -> Result<()> {
let params = params.unwrap_or(WriteParams {
mode: WriteMode::Append,
..WriteParams::default()
});
self.dataset = Arc::new(Dataset::write(&mut batches, &self.uri, Some(params)).await?);
Ok(batches.count())
self.dataset = Arc::new(Dataset::write(batches, &self.uri, Some(params)).await?);
Ok(())
}
/// Creates a new Query object that can be executed.
@@ -207,12 +209,12 @@ impl Table {
/// Merge new data into this table.
pub async fn merge(
&mut self,
mut batches: Box<dyn RecordBatchReader>,
batches: impl RecordBatchReader + Send + 'static,
left_on: &str,
right_on: &str,
) -> Result<()> {
let mut dataset = self.dataset.as_ref().clone();
dataset.merge(&mut batches, left_on, right_on).await?;
dataset.merge(batches, left_on, right_on).await?;
self.dataset = Arc::new(dataset);
Ok(())
}
@@ -253,8 +255,8 @@ mod tests {
let dataset_path = tmp_dir.path().join("test.lance");
let uri = tmp_dir.path().to_str().unwrap();
let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
let batches = make_test_batches();
Dataset::write(batches, dataset_path.to_str().unwrap(), None)
.await
.unwrap();
@@ -284,11 +286,11 @@ mod tests {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let batches: Box<dyn RecordBatchReader> = make_test_batches();
let batches = make_test_batches();
let _ = batches.schema().clone();
Table::create(&uri, "test", batches, None).await.unwrap();
let batches: Box<dyn RecordBatchReader> = make_test_batches();
let batches = make_test_batches();
let result = Table::create(&uri, "test", batches, None).await;
assert!(matches!(
result.unwrap_err(),
@@ -301,12 +303,12 @@ mod tests {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let batches: Box<dyn RecordBatchReader> = make_test_batches();
let batches = make_test_batches();
let schema = batches.schema().clone();
let mut table = Table::create(&uri, "test", batches, None).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
let new_batches: Box<dyn RecordBatchReader> = Box::new(RecordBatchIterator::new(
let new_batches = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(100..110))],
@@ -315,7 +317,7 @@ mod tests {
.into_iter()
.map(Ok),
schema.clone(),
));
);
table.add(new_batches, None).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 20);
@@ -327,12 +329,12 @@ mod tests {
let tmp_dir = tempdir().unwrap();
let uri = tmp_dir.path().to_str().unwrap();
let batches: Box<dyn RecordBatchReader> = make_test_batches();
let batches = make_test_batches();
let schema = batches.schema().clone();
let mut table = Table::create(uri, "test", batches, None).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
let new_batches: Box<dyn RecordBatchReader> = Box::new(RecordBatchIterator::new(
let new_batches = RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(100..110))],
@@ -341,10 +343,15 @@ mod tests {
.into_iter()
.map(Ok),
schema.clone(),
));
);
let param: WriteParams = WriteParams {
mode: WriteMode::Overwrite,
..Default::default()
};
table
.add(new_batches, Some(WriteMode::Overwrite))
.add(new_batches, Some(param))
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
@@ -357,8 +364,8 @@ mod tests {
let dataset_path = tmp_dir.path().join("test.lance");
let uri = tmp_dir.path().to_str().unwrap();
let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
let batches = make_test_batches();
Dataset::write(batches, dataset_path.to_str().unwrap(), None)
.await
.unwrap();
@@ -369,7 +376,7 @@ mod tests {
assert_eq!(vector, query.query_vector);
}
#[derive(Default)]
#[derive(Default, Debug)]
struct NoOpCacheWrapper {
called: AtomicBool,
}
@@ -396,8 +403,8 @@ mod tests {
let dataset_path = tmp_dir.path().join("test.lance");
let uri = tmp_dir.path().to_str().unwrap();
let mut batches: Box<dyn RecordBatchReader> = make_test_batches();
Dataset::write(&mut batches, dataset_path.to_str().unwrap(), None)
let batches = make_test_batches();
Dataset::write(batches, dataset_path.to_str().unwrap(), None)
.await
.unwrap();
@@ -417,15 +424,15 @@ mod tests {
assert!(wrapper.called());
}
fn make_test_batches() -> Box<dyn RecordBatchReader> {
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
Box::new(RecordBatchIterator::new(
RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))],
)],
schema,
))
)
}
#[tokio::test]
@@ -465,9 +472,7 @@ mod tests {
schema,
);
let reader: Box<dyn RecordBatchReader + Send> = Box::new(batches);
let mut table = Table::create(uri, "test", reader, None).await.unwrap();
let mut table = Table::create(uri, "test", batches, None).await.unwrap();
let mut i = IvfPQIndexBuilder::new();
let index_builder = i