From 593b5939becbc10f2e1b97f972dec8d5045fdc68 Mon Sep 17 00:00:00 2001 From: gsilvestrin Date: Tue, 1 Aug 2023 14:22:04 -0700 Subject: [PATCH] feat(node): Improve concurrency (#376) - Moved computation out of JS main thread by using a mpsc - Removes the Arc/Mutex since Table is owned by JsTable now - Moved table / query methods to their own files - Fixed js-transformers example --- node/examples/js-transformers/index.js | 2 +- node/examples/js-transformers/package.json | 2 +- rust/ffi/node/src/error.rs | 8 + rust/ffi/node/src/index/vector.rs | 30 +-- rust/ffi/node/src/lib.rs | 241 ++------------------- rust/ffi/node/src/query.rs | 88 ++++++++ rust/ffi/node/src/table.rs | 218 +++++++++++++++++++ 7 files changed, 345 insertions(+), 244 deletions(-) create mode 100644 rust/ffi/node/src/query.rs create mode 100644 rust/ffi/node/src/table.rs diff --git a/node/examples/js-transformers/index.js b/node/examples/js-transformers/index.js index ccf21f63..4400cbc6 100644 --- a/node/examples/js-transformers/index.js +++ b/node/examples/js-transformers/index.js @@ -50,7 +50,7 @@ async function example() { { id: 5, text: 'Banana', type: 'fruit' } ] - const table = await db.createTable('food_table', data, "create", embed_fun) + const table = await db.createTable('food_table', data, embed_fun) // Query the table diff --git a/node/examples/js-transformers/package.json b/node/examples/js-transformers/package.json index 4255e27a..b823f90b 100644 --- a/node/examples/js-transformers/package.json +++ b/node/examples/js-transformers/package.json @@ -10,7 +10,7 @@ "license": "Apache-2.0", "dependencies": { "@xenova/transformers": "^2.4.1", - "vectordb": "^0.1.12" + "vectordb": "file:../.." } } diff --git a/rust/ffi/node/src/error.rs b/rust/ffi/node/src/error.rs index b4f5ac3d..924cf544 100644 --- a/rust/ffi/node/src/error.rs +++ b/rust/ffi/node/src/error.rs @@ -67,6 +67,14 @@ impl From for Error { } } +impl From> for Error { + fn from(value: std::sync::mpsc::SendError) -> Self { + Self::Neon { + message: value.to_string(), + } + } +} + /// ResultExt is used to transform a [`Result`] into a [`NeonResult`], /// so it can be returned as a JavaScript error /// Copied from [Neon](https://github.com/neon-bindings/neon/blob/4c2e455a9e6814f1ba0178616d63caec7f4df317/crates/neon/src/result/mod.rs#L88) diff --git a/rust/ffi/node/src/index/vector.rs b/rust/ffi/node/src/index/vector.rs index 4f710e7b..9d3014f6 100644 --- a/rust/ffi/node/src/index/vector.rs +++ b/rust/ffi/node/src/index/vector.rs @@ -25,7 +25,8 @@ use vectordb::index::vector::{IvfPQIndexBuilder, VectorIndexBuilder}; use crate::error::Error::InvalidIndexType; use crate::error::ResultExt; use crate::neon_ext::js_object_ext::JsObjectExt; -use crate::{runtime, JsTable}; +use crate::runtime; +use crate::table::JsTable; pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult { let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; @@ -33,24 +34,23 @@ pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult, @@ -49,12 +43,6 @@ struct JsDatabase { impl Finalize for JsDatabase {} -struct JsTable { - table: Arc>, -} - -impl Finalize for JsTable {} - // TODO: object_store didn't export this type so I copied it. // Make a request to object_store to export this type #[derive(Debug)] @@ -196,8 +184,9 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult { let table_rst = database.open_table_with_params(&table_name, ¶ms).await; deferred.settle_with(&channel, move |mut cx| { - let table = Arc::new(Mutex::new(table_rst.or_throw(&mut cx)?)); - Ok(cx.boxed(JsTable { table })) + let table = table_rst.or_throw(&mut cx)?; + let js_table = JsTable::new(&mut cx, table).or_throw(&mut cx)?; + Ok(cx.boxed(js_table)) }); }); Ok(promise) @@ -224,209 +213,7 @@ fn database_drop_table(mut cx: FunctionContext) -> JsResult { Ok(promise) } -fn table_search(mut cx: FunctionContext) -> JsResult { - let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; - let query_obj = cx.argument::(0)?; - let limit = query_obj - .get::(&mut cx, "_limit")? - .value(&mut cx); - let select = query_obj - .get_opt::(&mut cx, "_select")? - .map(|arr| { - let js_array = arr.deref(); - let mut projection_vec: Vec = Vec::new(); - for i in 0..js_array.len(&mut cx) { - let entry: Handle = js_array.get(&mut cx, i).unwrap(); - projection_vec.push(entry.value(&mut cx)); - } - projection_vec - }); - let filter = query_obj - .get_opt::(&mut cx, "_filter")? - .map(|s| s.value(&mut cx)); - let refine_factor = query_obj - .get_opt_u32(&mut cx, "_refineFactor") - .or_throw(&mut cx)?; - let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?; - let metric_type = query_obj - .get_opt::(&mut cx, "_metricType")? - .map(|s| s.value(&mut cx)) - .map(|s| MetricType::try_from(s.as_str()).unwrap()); - - let rt = runtime(&mut cx)?; - let channel = cx.channel(); - - let (deferred, promise) = cx.promise(); - let table = js_table.table.clone(); - let query_vector = query_obj.get::(&mut cx, "_queryVector")?; - let query = convert::js_array_to_vec(query_vector.deref(), &mut cx); - - rt.spawn(async move { - let builder = table - .lock() - .unwrap() - .search(Float32Array::from(query)) - .limit(limit as usize) - .refine_factor(refine_factor) - .nprobes(nprobes) - .filter(filter) - .metric_type(metric_type) - .select(select); - let record_batch_stream = builder.execute(); - let results = record_batch_stream - .and_then(|stream| { - stream - .try_collect::>() - .map_err(vectordb::error::Error::from) - }) - .await; - - deferred.settle_with(&channel, move |mut cx| { - let results = results.or_throw(&mut cx)?; - let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?; - Ok(JsBuffer::external(&mut cx, buffer)) - }); - }); - Ok(promise) -} - -fn table_create(mut cx: FunctionContext) -> JsResult { - let db = cx - .this() - .downcast_or_throw::, _>(&mut cx)?; - let table_name = cx.argument::(0)?.value(&mut cx); - let buffer = cx.argument::(1)?; - let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?; - let schema = batches[0].schema(); - - // Write mode - let mode = match cx.argument::(2)?.value(&mut cx).as_str() { - "overwrite" => WriteMode::Overwrite, - "append" => WriteMode::Append, - "create" => WriteMode::Create, - _ => return cx.throw_error("Table::create only supports 'overwrite' and 'create' modes"), - }; - - let rt = runtime(&mut cx)?; - let channel = cx.channel(); - - let (deferred, promise) = cx.promise(); - let database = db.database.clone(); - - let aws_creds = match get_aws_creds(&mut cx, 3) { - Ok(creds) => creds, - Err(err) => return err, - }; - - let params = WriteParams { - store_params: Some(ObjectStoreParams { - aws_credentials: aws_creds, - ..ObjectStoreParams::default() - }), - mode: mode, - ..WriteParams::default() - }; - - rt.block_on(async move { - let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); - let table_rst = database - .create_table(&table_name, batch_reader, Some(params)) - .await; - - deferred.settle_with(&channel, move |mut cx| { - let table = Arc::new(Mutex::new(table_rst.or_throw(&mut cx)?)); - Ok(cx.boxed(JsTable { table })) - }); - }); - Ok(promise) -} - -fn table_add(mut cx: FunctionContext) -> JsResult { - let write_mode_map: HashMap<&str, WriteMode> = HashMap::from([ - ("create", WriteMode::Create), - ("append", WriteMode::Append), - ("overwrite", WriteMode::Overwrite), - ]); - - let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; - let buffer = cx.argument::(0)?; - let write_mode = cx.argument::(1)?.value(&mut cx); - - let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?; - let schema = batches[0].schema(); - - let rt = runtime(&mut cx)?; - let channel = cx.channel(); - - let (deferred, promise) = cx.promise(); - let table = js_table.table.clone(); - let write_mode = write_mode_map.get(write_mode.as_str()).cloned(); - - let aws_creds = match get_aws_creds(&mut cx, 2) { - Ok(creds) => creds, - Err(err) => return err, - }; - - let params = WriteParams { - store_params: Some(ObjectStoreParams { - aws_credentials: aws_creds, - ..ObjectStoreParams::default() - }), - mode: write_mode.unwrap_or(WriteMode::Append), - ..WriteParams::default() - }; - - rt.block_on(async move { - let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); - let add_result = table.lock().unwrap().add(batch_reader, Some(params)).await; - - deferred.settle_with(&channel, move |mut cx| { - let _added = add_result.or_throw(&mut cx)?; - Ok(cx.boolean(true)) - }); - }); - Ok(promise) -} - -fn table_count_rows(mut cx: FunctionContext) -> JsResult { - let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; - let rt = runtime(&mut cx)?; - let channel = cx.channel(); - - let (deferred, promise) = cx.promise(); - let table = js_table.table.clone(); - - rt.block_on(async move { - let num_rows_result = table.lock().unwrap().count_rows().await; - - deferred.settle_with(&channel, move |mut cx| { - let num_rows = num_rows_result.or_throw(&mut cx)?; - Ok(cx.number(num_rows as f64)) - }); - }); - Ok(promise) -} - -fn table_delete(mut cx: FunctionContext) -> JsResult { - let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; - let rt = runtime(&mut cx)?; - let channel = cx.channel(); - - let (deferred, promise) = cx.promise(); - let table = js_table.table.clone(); - - let predicate = cx.argument::(0)?.value(&mut cx); - - let delete_result = rt.block_on(async move { table.lock().unwrap().delete(&predicate).await }); - - deferred.settle_with(&channel, move |mut cx| { - delete_result.or_throw(&mut cx)?; - Ok(cx.undefined()) - }); - - Ok(promise) -} #[neon::main] fn main(mut cx: ModuleContext) -> NeonResult<()> { @@ -434,11 +221,11 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> { cx.export_function("databaseTableNames", database_table_names)?; cx.export_function("databaseOpenTable", database_open_table)?; cx.export_function("databaseDropTable", database_drop_table)?; - cx.export_function("tableSearch", table_search)?; - cx.export_function("tableCreate", table_create)?; - cx.export_function("tableAdd", table_add)?; - cx.export_function("tableCountRows", table_count_rows)?; - cx.export_function("tableDelete", table_delete)?; + cx.export_function("tableSearch", JsQuery::js_search)?; + cx.export_function("tableCreate", JsTable::js_create)?; + cx.export_function("tableAdd", JsTable::js_add)?; + cx.export_function("tableCountRows", JsTable::js_count_rows)?; + cx.export_function("tableDelete", JsTable::js_delete)?; cx.export_function( "tableCreateVectorIndex", index::vector::table_create_vector_index, diff --git a/rust/ffi/node/src/query.rs b/rust/ffi/node/src/query.rs new file mode 100644 index 00000000..b3f89109 --- /dev/null +++ b/rust/ffi/node/src/query.rs @@ -0,0 +1,88 @@ +use std::convert::TryFrom; +use std::ops::Deref; + +use arrow_array::Float32Array; +use futures::{TryFutureExt, TryStreamExt}; +use lance::index::vector::MetricType; +use neon::prelude::*; +use neon::context::FunctionContext; +use neon::handle::Handle; + +use crate::{convert, runtime}; +use crate::arrow::record_batch_to_buffer; +use crate::error::ResultExt; +use crate::neon_ext::js_object_ext::JsObjectExt; +use crate::table::JsTable; + +pub(crate) struct JsQuery { + +} + +impl JsQuery { + pub(crate) fn js_search(mut cx: FunctionContext) -> JsResult { + let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; + let query_obj = cx.argument::(0)?; + + let limit = query_obj + .get::(&mut cx, "_limit")? + .value(&mut cx); + let select = query_obj + .get_opt::(&mut cx, "_select")? + .map(|arr| { + let js_array = arr.deref(); + let mut projection_vec: Vec = Vec::new(); + for i in 0..js_array.len(&mut cx) { + let entry: Handle = js_array.get(&mut cx, i).unwrap(); + projection_vec.push(entry.value(&mut cx)); + } + projection_vec + }); + let filter = query_obj + .get_opt::(&mut cx, "_filter")? + .map(|s| s.value(&mut cx)); + let refine_factor = query_obj + .get_opt_u32(&mut cx, "_refineFactor") + .or_throw(&mut cx)?; + let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?; + let metric_type = query_obj + .get_opt::(&mut cx, "_metricType")? + .map(|s| s.value(&mut cx)) + .map(|s| MetricType::try_from(s.as_str()).unwrap()); + + let rt = runtime(&mut cx)?; + + let (deferred, promise) = cx.promise(); + let query_vector = query_obj.get::(&mut cx, "_queryVector")?; + let query = convert::js_array_to_vec(query_vector.deref(), &mut cx); + + js_table + .send(deferred, move |table, channel, deferred| { + rt.block_on(async move { + let builder = table + .search(Float32Array::from(query)) + .limit(limit as usize) + .refine_factor(refine_factor) + .nprobes(nprobes) + .filter(filter) + .metric_type(metric_type) + .select(select); + let record_batch_stream = builder.execute(); + let results = record_batch_stream + .and_then(|stream| { + stream + .try_collect::>() + .map_err(vectordb::error::Error::from) + }) + .await; + + deferred.settle_with(&channel, move |mut cx| { + let results = results.or_throw(&mut cx)?; + let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?; + Ok(JsBuffer::external(&mut cx, buffer)) + }); + }); + }) + .or_throw(&mut cx)?; + Ok(promise) + } +} \ No newline at end of file diff --git a/rust/ffi/node/src/table.rs b/rust/ffi/node/src/table.rs new file mode 100644 index 00000000..e956e2db --- /dev/null +++ b/rust/ffi/node/src/table.rs @@ -0,0 +1,218 @@ +// Copyright 2023 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::mpsc; +use std::thread; +use arrow_array::RecordBatchIterator; +use lance::dataset::{WriteMode, WriteParams}; +use lance::io::object_store::ObjectStoreParams; + +use neon::{prelude::*, types::Deferred}; +use neon::types::buffer::TypedArray; +use crate::arrow::arrow_buffer_to_record_batch; + +use crate::error::{Error, Result, ResultExt}; +use crate::{get_aws_creds, JsDatabase, runtime}; + +type TableCallback = Box; + +// Wraps a LanceDB table into a channel, allowing concurrent access +pub(crate) struct JsTable { + tx: mpsc::Sender, +} + +impl Finalize for JsTable {} + +// Messages sent on the table channel +pub(crate) enum JsTableMessage { + // Promise to resolve and callback to be executed + Callback(Deferred, TableCallback), +} + +impl JsTable { + pub(crate) fn new<'a, C>(cx: &mut C, mut table: vectordb::Table) -> Result + where + C: Context<'a>, + { + // Creates a mpsc Channel to receive messages / commands from Javascript + let (tx, rx) = mpsc::channel::(); + let channel = cx.channel(); + + // Spawn a new thread to receive messages without blocking the main JS thread + thread::spawn(move || { + // Runs until the channel is closed + while let Ok(message) = rx.recv() { + match message { + JsTableMessage::Callback(deferred, f) => { + f(&mut table, &channel, deferred); + } + } + } + }); + + Ok(Self { tx }) + } + + pub(crate) fn send( + &self, + deferred: Deferred, + callback: impl FnOnce(&mut vectordb::Table, &Channel, Deferred) + Send + 'static, + ) -> Result<()> { + self.tx + .send(JsTableMessage::Callback(deferred, Box::new(callback))) + .map_err(Error::from) + } +} + +impl JsTable { + + pub(crate) fn js_create(mut cx: FunctionContext) -> JsResult { + let db = cx + .this() + .downcast_or_throw::, _>(&mut cx)?; + let table_name = cx.argument::(0)?.value(&mut cx); + let buffer = cx.argument::(1)?; + let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?; + let schema = batches[0].schema(); + + // Write mode + let mode = match cx.argument::(2)?.value(&mut cx).as_str() { + "overwrite" => WriteMode::Overwrite, + "append" => WriteMode::Append, + "create" => WriteMode::Create, + _ => return cx.throw_error("Table::create only supports 'overwrite' and 'create' modes"), + }; + + let rt = runtime(&mut cx)?; + let channel = cx.channel(); + + let (deferred, promise) = cx.promise(); + let database = db.database.clone(); + + let aws_creds = match get_aws_creds(&mut cx, 3) { + Ok(creds) => creds, + Err(err) => return err, + }; + + let params = WriteParams { + store_params: Some(ObjectStoreParams { + aws_credentials: aws_creds, + ..ObjectStoreParams::default() + }), + mode: mode, + ..WriteParams::default() + }; + + rt.spawn(async move { + let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); + let table_rst = database + .create_table(&table_name, batch_reader, Some(params)) + .await; + + deferred.settle_with(&channel, move |mut cx| { + let table = table_rst.or_throw(&mut cx)?; + let js_table = JsTable::new(&mut cx, table).or_throw(&mut cx)?; + Ok(cx.boxed(js_table)) + }); + }); + Ok(promise) + } + + pub(crate) fn js_add(mut cx: FunctionContext) -> JsResult { + let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; + let buffer = cx.argument::(0)?; + let write_mode = cx.argument::(1)?.value(&mut cx); + + let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?; + let schema = batches[0].schema(); + + let rt = runtime(&mut cx)?; + + let (deferred, promise) = cx.promise(); + let write_mode = match write_mode.as_str() { + "create" => WriteMode::Create, + "append" => WriteMode::Append, + "overwrite" => WriteMode::Overwrite, + s => return cx.throw_error(format!("invalid write mode {}", s)), + }; + let aws_creds = match get_aws_creds(&mut cx, 2) { + Ok(creds) => creds, + Err(err) => return err, + }; + + let params = WriteParams { + store_params: Some(ObjectStoreParams { + aws_credentials: aws_creds, + ..ObjectStoreParams::default() + }), + mode: write_mode, + ..WriteParams::default() + }; + + js_table + .send(deferred, move |table, channel, deferred| { + rt.block_on(async move { + let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); + let add_result = table.add(batch_reader, Some(params)).await; + + deferred.settle_with(&channel, move |mut cx| { + let _added = add_result.or_throw(&mut cx)?; + Ok(cx.boolean(true)) + }); + }); + }) + .or_throw(&mut cx)?; + Ok(promise) + } + + pub(crate) fn js_count_rows(mut cx: FunctionContext) -> JsResult { + let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; + let rt = runtime(&mut cx)?; + + let (deferred, promise) = cx.promise(); + + js_table + .send(deferred, move |table, channel, deferred| { + rt.block_on(async move { + let num_rows_result = table.count_rows().await; + + deferred.settle_with(&channel, move |mut cx| { + let num_rows = num_rows_result.or_throw(&mut cx)?; + Ok(cx.number(num_rows as f64)) + }); + }); + }) + .or_throw(&mut cx)?; + Ok(promise) + } + + pub(crate) fn js_delete(mut cx: FunctionContext) -> JsResult { + let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; + let rt = runtime(&mut cx)?; + let (deferred, promise) = cx.promise(); + let predicate = cx.argument::(0)?.value(&mut cx); + + js_table + .send(deferred, move |table, channel, deferred| { + let delete_result = rt.block_on(async move { table.delete(&predicate).await }); + + deferred.settle_with(&channel, move |mut cx| { + delete_result.or_throw(&mut cx)?; + Ok(cx.undefined()) + }); + }) + .or_throw(&mut cx)?; + Ok(promise) + } +}