diff --git a/node/package.json b/node/package.json index 0769de35..6a382db2 100644 --- a/node/package.json +++ b/node/package.json @@ -8,7 +8,7 @@ "tsc": "tsc -b", "build": "cargo-cp-artifact --artifact cdylib vectordb-node index.node -- cargo build --message-format=json", "build-release": "npm run build -- --release", - "test": "npm run tsc && mocha -recursive dist/test --exit", + "test": "npm run tsc && mocha -recursive dist/test", "lint": "eslint native.js src --ext .js,.ts", "clean": "rm -rf node_modules *.node dist/", "pack-build": "neon pack-build", diff --git a/node/src/index.ts b/node/src/index.ts index 2dd1ed16..fc634d85 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -23,7 +23,7 @@ import { Query } from './query' import { isEmbeddingFunction } from './embedding/embedding_function' // eslint-disable-next-line @typescript-eslint/no-var-requires -const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete, tableClose } = require('../native.js') +const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete } = require('../native.js') export { Query } export type { EmbeddingFunction } @@ -215,12 +215,6 @@ export interface Table { * ``` */ delete: (filter: string) => Promise - - /** - * Immediately closes the connection to this Table. After close is called, - * all operations on this Table will fail. - */ - close: () => Promise } /** @@ -316,7 +310,7 @@ export class LocalConnection implements Connection { } export class LocalTable implements Table { - private readonly _tbl: any + private _tbl: any private readonly _name: string private readonly _embeddings?: EmbeddingFunction private readonly _options: ConnectionOptions @@ -363,7 +357,7 @@ export class LocalTable implements Table { callArgs.push(this._options.awsCredentials.sessionToken) } } - return tableAdd.call(...callArgs) + return tableAdd.call(...callArgs).then((newTable: any) => { this._tbl = newTable }) } /** @@ -381,7 +375,7 @@ export class LocalTable implements Table { callArgs.push(this._options.awsCredentials.sessionToken) } } - return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString()) + return tableAdd.call(...callArgs).then((newTable: any) => { this._tbl = newTable }) } /** @@ -390,7 +384,7 @@ export class LocalTable implements Table { * @param indexParams The parameters of this Index, @see VectorIndexParams. */ async createIndex (indexParams: VectorIndexParams): Promise { - return tableCreateVectorIndex.call(this._tbl, indexParams) + return tableCreateVectorIndex.call(this._tbl, indexParams).then((newTable: any) => { this._tbl = newTable }) } /** @@ -406,15 +400,7 @@ export class LocalTable implements Table { * @param filter A filter in the same format used by a sql WHERE clause. */ async delete (filter: string): Promise { - return tableDelete.call(this._tbl, filter) - } - - /** - * Immediately closes the connection to this Table. After close is called, - * all operations on this Table will fail. - */ - async close (): Promise { - return tableClose.call(this._tbl) + return tableDelete.call(this._tbl, filter).then((newTable: any) => { this._tbl = newTable }) } } diff --git a/node/src/remote/index.ts b/node/src/remote/index.ts index 0eb9f24d..4078a04d 100644 --- a/node/src/remote/index.ts +++ b/node/src/remote/index.ts @@ -165,8 +165,4 @@ export class RemoteTable implements Table { async delete (filter: string): Promise { throw new Error('Not implemented') } - - async close (): Promise { - throw new Error('Not implemented') - } } diff --git a/rust/ffi/node/src/index/vector.rs b/rust/ffi/node/src/index/vector.rs index 9d3014f6..cd431c5d 100644 --- a/rust/ffi/node/src/index/vector.rs +++ b/rust/ffi/node/src/index/vector.rs @@ -12,13 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::convert::TryFrom; - use lance::index::vector::ivf::IvfBuildParams; use lance::index::vector::pq::PQBuildParams; use lance::index::vector::MetricType; use neon::context::FunctionContext; use neon::prelude::*; +use std::convert::TryFrom; use vectordb::index::vector::{IvfPQIndexBuilder, VectorIndexBuilder}; @@ -36,21 +35,17 @@ pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult, @@ -184,8 +184,7 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult { let table_rst = database.open_table_with_params(&table_name, ¶ms).await; deferred.settle_with(&channel, move |mut cx| { - let table = table_rst.or_throw(&mut cx)?; - let js_table = JsTable::new(&mut cx, table).or_throw(&mut cx)?; + let js_table = JsTable::from(table_rst.or_throw(&mut cx)?); Ok(cx.boxed(js_table)) }); }); @@ -213,8 +212,6 @@ fn database_drop_table(mut cx: FunctionContext) -> JsResult { Ok(promise) } - - #[neon::main] fn main(mut cx: ModuleContext) -> NeonResult<()> { cx.export_function("databaseNew", database_new)?; @@ -226,7 +223,6 @@ fn main(mut cx: ModuleContext) -> NeonResult<()> { cx.export_function("tableAdd", JsTable::js_add)?; cx.export_function("tableCountRows", JsTable::js_count_rows)?; cx.export_function("tableDelete", JsTable::js_delete)?; - cx.export_function("tableClose", JsTable::js_close)?; cx.export_function( "tableCreateVectorIndex", index::vector::table_create_vector_index, diff --git a/rust/ffi/node/src/query.rs b/rust/ffi/node/src/query.rs index b3f89109..e3f2ec49 100644 --- a/rust/ffi/node/src/query.rs +++ b/rust/ffi/node/src/query.rs @@ -4,19 +4,17 @@ use std::ops::Deref; use arrow_array::Float32Array; use futures::{TryFutureExt, TryStreamExt}; use lance::index::vector::MetricType; -use neon::prelude::*; use neon::context::FunctionContext; use neon::handle::Handle; +use neon::prelude::*; -use crate::{convert, runtime}; use crate::arrow::record_batch_to_buffer; use crate::error::ResultExt; use crate::neon_ext::js_object_ext::JsObjectExt; use crate::table::JsTable; +use crate::{convert, runtime}; -pub(crate) struct JsQuery { - -} +pub(crate) struct JsQuery {} impl JsQuery { pub(crate) fn js_search(mut cx: FunctionContext) -> JsResult { @@ -52,37 +50,35 @@ impl JsQuery { let rt = runtime(&mut cx)?; let (deferred, promise) = cx.promise(); + let channel = cx.channel(); let query_vector = query_obj.get::(&mut cx, "_queryVector")?; let query = convert::js_array_to_vec(query_vector.deref(), &mut cx); + let table = js_table.table.clone(); - js_table - .send(deferred, move |table, channel, deferred| { - rt.block_on(async move { - let builder = table - .search(Float32Array::from(query)) - .limit(limit as usize) - .refine_factor(refine_factor) - .nprobes(nprobes) - .filter(filter) - .metric_type(metric_type) - .select(select); - let record_batch_stream = builder.execute(); - let results = record_batch_stream - .and_then(|stream| { - stream - .try_collect::>() - .map_err(vectordb::error::Error::from) - }) - .await; + rt.spawn(async move { + let builder = table + .search(Float32Array::from(query)) + .limit(limit as usize) + .refine_factor(refine_factor) + .nprobes(nprobes) + .filter(filter) + .metric_type(metric_type) + .select(select); + let record_batch_stream = builder.execute(); + let results = record_batch_stream + .and_then(|stream| { + stream + .try_collect::>() + .map_err(vectordb::error::Error::from) + }) + .await; - deferred.settle_with(&channel, move |mut cx| { - let results = results.or_throw(&mut cx)?; - let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?; - Ok(JsBuffer::external(&mut cx, buffer)) - }); - }); - }) - .or_throw(&mut cx)?; + deferred.settle_with(&channel, move |mut cx| { + let results = results.or_throw(&mut cx)?; + let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?; + Ok(JsBuffer::external(&mut cx, buffer)) + }); + }); Ok(promise) } -} \ No newline at end of file +} diff --git a/rust/ffi/node/src/table.rs b/rust/ffi/node/src/table.rs index e041e18c..f3f517d6 100644 --- a/rust/ffi/node/src/table.rs +++ b/rust/ffi/node/src/table.rs @@ -12,82 +12,31 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::sync::mpsc; -use std::thread; use arrow_array::RecordBatchIterator; use lance::dataset::{WriteMode, WriteParams}; use lance::io::object_store::ObjectStoreParams; -use neon::{prelude::*, types::Deferred}; -use neon::types::buffer::TypedArray; use crate::arrow::arrow_buffer_to_record_batch; +use neon::prelude::*; +use neon::types::buffer::TypedArray; +use vectordb::Table; -use crate::error::{Error, Result, ResultExt}; -use crate::{get_aws_creds, JsDatabase, runtime}; +use crate::error::ResultExt; +use crate::{get_aws_creds, runtime, JsDatabase}; -type TableCallback = Box; - -// Wraps a LanceDB table into a channel, allowing concurrent access pub(crate) struct JsTable { - tx: mpsc::Sender, + pub table: Table, } impl Finalize for JsTable {} -// Messages sent on the table channel -pub(crate) enum JsTableMessage { - // Promise to resolve and callback to be executed - Callback(Deferred, TableCallback), - // Forces to shutdown the thread - Close, -} - -impl JsTable { - pub(crate) fn new<'a, C>(cx: &mut C, mut table: vectordb::Table) -> Result - where - C: Context<'a>, - { - // Creates a mpsc Channel to receive messages / commands from Javascript - let (tx, rx) = mpsc::channel::(); - let channel = cx.channel(); - - // Spawn a new thread to receive messages without blocking the main JS thread - thread::spawn(move || { - // Runs until the channel is closed - while let Ok(message) = rx.recv() { - match message { - JsTableMessage::Callback(deferred, f) => { - f(&mut table, &channel, deferred); - }, - JsTableMessage::Close => break - } - } - }); - - Ok(Self { tx }) - } - - // It is not necessary to call `close` since the database will be closed when the wrapping - // `JsBox` is garbage collected. However, calling `close` allows the process to exit - // immediately instead of waiting on garbage collection. This is useful in tests. - pub(crate) fn close(&self) -> Result<()> { - self.tx.send(JsTableMessage::Close) - .map_err(Error::from) - } - - pub(crate) fn send( - &self, - deferred: Deferred, - callback: impl FnOnce(&mut vectordb::Table, &Channel, Deferred) + Send + 'static, - ) -> Result<()> { - self.tx - .send(JsTableMessage::Callback(deferred, Box::new(callback))) - .map_err(Error::from) +impl From for JsTable { + fn from(table: Table) -> Self { + JsTable { table } } } impl JsTable { - pub(crate) fn js_create(mut cx: FunctionContext) -> JsResult { let db = cx .this() @@ -102,7 +51,9 @@ impl JsTable { "overwrite" => WriteMode::Overwrite, "append" => WriteMode::Append, "create" => WriteMode::Create, - _ => return cx.throw_error("Table::create only supports 'overwrite' and 'create' modes"), + _ => { + return cx.throw_error("Table::create only supports 'overwrite' and 'create' modes") + } }; let rt = runtime(&mut cx)?; @@ -133,8 +84,7 @@ impl JsTable { deferred.settle_with(&channel, move |mut cx| { let table = table_rst.or_throw(&mut cx)?; - let js_table = JsTable::new(&mut cx, table).or_throw(&mut cx)?; - Ok(cx.boxed(js_table)) + Ok(cx.boxed(JsTable::from(table))) }); }); Ok(promise) @@ -149,13 +99,15 @@ impl JsTable { let schema = batches[0].schema(); let rt = runtime(&mut cx)?; + let channel = cx.channel(); + let mut table = js_table.table.clone(); let (deferred, promise) = cx.promise(); let write_mode = match write_mode.as_str() { "create" => WriteMode::Create, "append" => WriteMode::Append, "overwrite" => WriteMode::Overwrite, - s => return cx.throw_error(format!("invalid write mode {}", s)), + s => return cx.throw_error(format!("invalid write mode {}", s)), }; let aws_creds = match get_aws_creds(&mut cx, 2) { Ok(creds) => creds, @@ -171,40 +123,33 @@ impl JsTable { ..WriteParams::default() }; - js_table - .send(deferred, move |table, channel, deferred| { - rt.block_on(async move { - let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); - let add_result = table.add(batch_reader, Some(params)).await; + rt.spawn(async move { + let batch_reader = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); + let add_result = table.add(batch_reader, Some(params)).await; - deferred.settle_with(&channel, move |mut cx| { - let _added = add_result.or_throw(&mut cx)?; - Ok(cx.boolean(true)) - }); - }); - }) - .or_throw(&mut cx)?; + deferred.settle_with(&channel, move |mut cx| { + let _added = add_result.or_throw(&mut cx)?; + Ok(cx.boxed(JsTable::from(table))) + }); + }); Ok(promise) } pub(crate) fn js_count_rows(mut cx: FunctionContext) -> JsResult { let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; let rt = runtime(&mut cx)?; - let (deferred, promise) = cx.promise(); + let channel = cx.channel(); + let table = js_table.table.clone(); - js_table - .send(deferred, move |table, channel, deferred| { - rt.block_on(async move { - let num_rows_result = table.count_rows().await; + rt.spawn(async move { + let num_rows_result = table.count_rows().await; - deferred.settle_with(&channel, move |mut cx| { - let num_rows = num_rows_result.or_throw(&mut cx)?; - Ok(cx.number(num_rows as f64)) - }); - }); - }) - .or_throw(&mut cx)?; + deferred.settle_with(&channel, move |mut cx| { + let num_rows = num_rows_result.or_throw(&mut cx)?; + Ok(cx.number(num_rows as f64)) + }); + }); Ok(promise) } @@ -213,25 +158,17 @@ impl JsTable { let rt = runtime(&mut cx)?; let (deferred, promise) = cx.promise(); let predicate = cx.argument::(0)?.value(&mut cx); + let channel = cx.channel(); + let mut table = js_table.table.clone(); - js_table - .send(deferred, move |table, channel, deferred| { - let delete_result = rt.block_on(async move { table.delete(&predicate).await }); + rt.spawn(async move { + let delete_result = table.delete(&predicate).await; - deferred.settle_with(&channel, move |mut cx| { - delete_result.or_throw(&mut cx)?; - Ok(cx.undefined()) - }); + deferred.settle_with(&channel, move |mut cx| { + delete_result.or_throw(&mut cx)?; + Ok(cx.boxed(JsTable::from(table))) }) - .or_throw(&mut cx)?; + }); Ok(promise) } - - pub(crate) fn js_close(mut cx: FunctionContext) -> JsResult { - cx.this() - .downcast_or_throw::, _>(&mut cx)? - .close() - .or_throw(&mut cx)?; - Ok(cx.undefined()) - } }