From 6036cf48a74a374306d051f06a9f0951e0642df6 Mon Sep 17 00:00:00 2001 From: gsilvestrin Date: Wed, 26 Jul 2023 13:44:58 -0700 Subject: [PATCH] fix(node) Replace panic errors with friendlier ones (#366) - Implement Result/Error in the node FFI - Implement a trait (ResultExt) to make error handling less verbose - Refactor some parts of the code that touch arrow into arrow.rs --- Cargo.toml | 1 + node/src/test/test.ts | 20 +++++++++++ rust/ffi/node/Cargo.toml | 1 + rust/ffi/node/src/arrow.rs | 59 ++++++++++++++++++++---------- rust/ffi/node/src/error.rs | 73 ++++++++++++++++++++++++++++++++++++++ rust/ffi/node/src/lib.rs | 55 ++++++++++------------------ rust/vectordb/Cargo.toml | 2 +- 7 files changed, 155 insertions(+), 56 deletions(-) create mode 100644 rust/ffi/node/src/error.rs diff --git a/Cargo.toml b/Cargo.toml index 715a20a3..19440f2c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,4 +13,5 @@ arrow-schema = "42.0" arrow-ipc = "42.0" half = { "version" = "=2.2.1", default-features = false } object_store = "0.6.1" +snafu = "0.7.4" diff --git a/node/src/test/test.ts b/node/src/test/test.ts index e0e76114..de5ecd70 100644 --- a/node/src/test/test.ts +++ b/node/src/test/test.ts @@ -134,6 +134,18 @@ describe('LanceDB client', function () { assert.equal(await table.countRows(), 2) }) + it('fails to create a new table when the vector column is missing', async function () { + const dir = await track().mkdir('lancejs') + const con = await lancedb.connect(dir) + + const data = [ + { id: 1, price: 10 } + ] + + const create = con.createTable('missing_vector', data) + await expect(create).to.be.rejectedWith(Error, 'column \'vector\' is missing') + }) + it('use overwrite flag to overwrite existing table', async function () { const dir = await track().mkdir('lancejs') const con = await lancedb.connect(dir) @@ -230,6 +242,14 @@ describe('LanceDB client', function () { // Default replace = true await table.createIndex({ type: 'ivf_pq', column: 'vector', num_partitions: 2, 
max_iters: 2, num_sub_vectors: 2 }) }).timeout(50_000) + + it('it should fail when the column is not a vector', async function () { + const uri = await createTestDB(32, 300) + const con = await lancedb.connect(uri) + const table = await con.openTable('vectors') + const createIndex = table.createIndex({ type: 'ivf_pq', column: 'name', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 }) + await expect(createIndex).to.be.rejectedWith(/VectorIndex requires the column data type to be fixed size list of float32s/) + }) }) describe('when using a custom embedding function', function () { diff --git a/rust/ffi/node/Cargo.toml b/rust/ffi/node/Cargo.toml index f507b6c7..679ee3f0 100644 --- a/rust/ffi/node/Cargo.toml +++ b/rust/ffi/node/Cargo.toml @@ -21,5 +21,6 @@ vectordb = { path = "../../vectordb" } tokio = { version = "1.23", features = ["rt-multi-thread"] } neon = {version = "0.10.1", default-features = false, features = ["channel-api", "napi-6", "promise-api", "task-api"] } object_store = { workspace = true, features = ["aws"] } +snafu = { workspace = true } async-trait = "0" env_logger = "0" diff --git a/rust/ffi/node/src/arrow.rs b/rust/ffi/node/src/arrow.rs index 494cc1b4..64dbe1e9 100644 --- a/rust/ffi/node/src/arrow.rs +++ b/rust/ffi/node/src/arrow.rs @@ -13,27 +13,30 @@ // limitations under the License. 
use std::io::Cursor; +use std::ops::Deref; use std::sync::Arc; use arrow_array::cast::as_list_array; -use arrow_array::{Array, FixedSizeListArray, RecordBatch}; +use arrow_array::{Array, ArrayRef, FixedSizeListArray, RecordBatch}; use arrow_ipc::reader::FileReader; +use arrow_ipc::writer::FileWriter; use arrow_schema::{DataType, Field, Schema}; use lance::arrow::{FixedSizeListArrayExt, RecordBatchExt}; +use vectordb::table::VECTOR_COLUMN_NAME; + +use crate::error::{MissingColumnSnafu, Result}; +use snafu::prelude::*; + +pub(crate) fn convert_record_batch(record_batch: RecordBatch) -> Result<RecordBatch> { + let column = get_column(VECTOR_COLUMN_NAME, &record_batch)?; -pub(crate) fn convert_record_batch(record_batch: RecordBatch) -> RecordBatch { - let column = record_batch - .column_by_name("vector") - .cloned() - .expect("vector column is missing"); // TODO: we should just consume the underlying js buffer in the future instead of this arrow around a bunch of times let arr = as_list_array(column.as_ref()); let list_size = arr.values().len() / record_batch.num_rows(); - let r = - FixedSizeListArray::try_new_from_values(arr.values().to_owned(), list_size as i32).unwrap(); + let r = FixedSizeListArray::try_new_from_values(arr.values().to_owned(), list_size as i32)?; let schema = Arc::new(Schema::new(vec![Field::new( - "vector", + VECTOR_COLUMN_NAME, DataType::FixedSizeList( Arc::new(Field::new("item", DataType::Float32, true)), list_size as i32, @@ -41,22 +44,42 @@ pub(crate) fn convert_record_batch(record_batch: RecordBatch) -> RecordBatch { true, )]));
-pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Vec<RecordBatch> { +fn get_column(column_name: &str, record_batch: &RecordBatch) -> Result<ArrayRef> { + record_batch + .column_by_name(column_name) + .cloned() + .context(MissingColumnSnafu { name: column_name }) +} + +pub(crate) fn arrow_buffer_to_record_batch(slice: &[u8]) -> Result<Vec<RecordBatch>> { let mut batches: Vec<RecordBatch> = Vec::new(); - let fr = FileReader::try_new(Cursor::new(slice), None); - let file_reader = fr.unwrap(); + let file_reader = FileReader::try_new(Cursor::new(slice), None)?; for b in file_reader { - let record_batch = convert_record_batch(b.unwrap()); + let record_batch = convert_record_batch(b?)?; batches.push(record_batch); } - batches + Ok(batches) +} + +pub(crate) fn record_batch_to_buffer(batches: Vec<RecordBatch>) -> Result<Vec<u8>> { + if batches.is_empty() { + return Ok(Vec::new()); + } + + let schema = batches.get(0).unwrap().schema(); + let mut fr = FileWriter::try_new(Vec::new(), schema.deref())?; + for batch in batches.iter() { + fr.write(batch)? + } + fr.finish()?; + Ok(fr.into_inner()?) } diff --git a/rust/ffi/node/src/error.rs b/rust/ffi/node/src/error.rs new file mode 100644 index 00000000..3a164251 --- /dev/null +++ b/rust/ffi/node/src/error.rs @@ -0,0 +1,73 @@ +// Copyright 2023 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +use arrow_schema::ArrowError; +use neon::context::Context; +use neon::prelude::NeonResult; +use snafu::Snafu; + +#[derive(Debug, Snafu)] +#[snafu(visibility(pub(crate)))] +pub enum Error { + #[snafu(display("column '{name}' is missing"))] + MissingColumn { name: String }, + #[snafu(display("{message}"))] + LanceDB { message: String }, +} + +pub type Result<T> = std::result::Result<T, Error>; + +impl From<vectordb::error::Error> for Error { + fn from(e: vectordb::error::Error) -> Self { + Self::LanceDB { + message: e.to_string(), + } + } +} + +impl From<lance::Error> for Error { + fn from(e: lance::Error) -> Self { + Self::LanceDB { + message: e.to_string(), + } + } +} + +impl From<ArrowError> for Error { + fn from(value: ArrowError) -> Self { + Self::LanceDB { + message: value.to_string(), + } + } +} + +/// ResultExt is used to transform a [`Result`] into a [`NeonResult`], +/// so it can be returned as a JavaScript error +/// Copied from [Neon](https://github.com/neon-bindings/neon/blob/4c2e455a9e6814f1ba0178616d63caec7f4df317/crates/neon/src/result/mod.rs#L88) +pub trait ResultExt<T> { + fn or_throw<'a, C: Context<'a>>(self, cx: &mut C) -> NeonResult<T>; +} + +/// Implement ResultExt for the std Result so it can be used any Result type +impl<T, E> ResultExt<T> for std::result::Result<T, E> +where + E: std::fmt::Display, +{ + fn or_throw<'a, C: Context<'a>>(self, cx: &mut C) -> NeonResult<T> { + match self { + Ok(value) => Ok(value), + Err(error) => cx.throw_error(error.to_string()), + } + } +} diff --git a/rust/ffi/node/src/lib.rs b/rust/ffi/node/src/lib.rs index dbc3a02f..d40448af 100644 --- a/rust/ffi/node/src/lib.rs +++ b/rust/ffi/node/src/lib.rs @@ -18,7 +18,6 @@ use std::ops::Deref; use std::sync::{Arc, Mutex}; use arrow_array::{Float32Array, RecordBatchIterator}; -use arrow_ipc::writer::FileWriter; use async_trait::async_trait; use futures::{TryFutureExt, TryStreamExt}; use lance::dataset::{WriteMode, WriteParams}; @@ -35,10 +34,12 @@ use vectordb::database::Database; use vectordb::error::Error; use vectordb::table::{ReadParams, Table};
-use crate::arrow::arrow_buffer_to_record_batch; +use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer}; +use crate::error::ResultExt; mod arrow; mod convert; +mod error; mod index; struct JsDatabase { @@ -86,7 +87,7 @@ fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> { LOG.get_or_init(|| env_logger::init()); - RUNTIME.get_or_try_init(|| Runtime::new().or_else(|err| cx.throw_error(err.to_string()))) + RUNTIME.get_or_try_init(|| Runtime::new().or_throw(cx)) } fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> { @@ -101,7 +102,7 @@ fn database_new(mut cx: FunctionContext) -> JsResult<JsPromise> { deferred.settle_with(&channel, move |mut cx| { let db = JsDatabase { - database: Arc::new(database.or_else(|err| cx.throw_error(err.to_string()))?), + database: Arc::new(database.or_throw(&mut cx)?), }; Ok(cx.boxed(db)) }); @@ -123,7 +124,7 @@ fn database_table_names(mut cx: FunctionContext) -> JsResult<JsPromise> { let tables_rst = database.table_names().await; deferred.settle_with(&channel, move |mut cx| { - let tables = tables_rst.or_else(|err| cx.throw_error(err.to_string()))?; + let tables = tables_rst.or_throw(&mut cx)?; let table_names = convert::vec_str_to_array(&tables, &mut cx); table_names }); @@ -194,9 +195,7 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> { let table_rst = database.open_table_with_params(&table_name, &params).await; deferred.settle_with(&channel, move |mut cx| { - let table = Arc::new(Mutex::new( - table_rst.or_else(|err| cx.throw_error(err.to_string()))?, - )); + let table = Arc::new(Mutex::new(table_rst.or_throw(&mut cx)?)); Ok(cx.boxed(JsTable { table })) }); }); @@ -217,7 +216,7 @@ fn database_drop_table(mut cx: FunctionContext) -> JsResult<JsPromise> { rt.spawn(async move { let result = database.drop_table(&table_name).await; deferred.settle_with(&channel, move |mut cx| { - result.or_else(|err| cx.throw_error(err.to_string()))?; + result.or_throw(&mut cx)?; Ok(cx.null()) }); }); @@ -282,26 +281,9 @@ fn 
table_search(mut cx: FunctionContext) -> JsResult<JsPromise> { .await; deferred.settle_with(&channel, move |mut cx| { - let results = results.or_else(|err| cx.throw_error(err.to_string()))?; - let vector: Vec<u8> = Vec::new(); - - if results.is_empty() { - return cx.buffer(0); - } - - let schema = results.get(0).unwrap().schema(); - let mut fr = FileWriter::try_new(vector, schema.deref()) - .or_else(|err| cx.throw_error(err.to_string()))?; - - for batch in results.iter() { - fr.write(batch) - .or_else(|err| cx.throw_error(err.to_string()))?; - } - fr.finish().or_else(|err| cx.throw_error(err.to_string()))?; - let buf = fr - .into_inner() - .or_else(|err| cx.throw_error(err.to_string()))?; - Ok(JsBuffer::external(&mut cx, buf)) + let results = results.or_throw(&mut cx)?; + let buffer = record_batch_to_buffer(results).or_throw(&mut cx)?; + Ok(JsBuffer::external(&mut cx, buffer)) }); }); Ok(promise) @@ -313,7 +295,7 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> { .downcast_or_throw::<JsBox<JsDatabase>, _>(&mut cx)?; let table_name = cx.argument::<JsString>(0)?.value(&mut cx); let buffer = cx.argument::<JsBuffer>(1)?; - let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)); + let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)).or_throw(&mut cx)?; let schema = batches[0].schema(); // Write mode @@ -351,9 +333,7 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> { .await; deferred.settle_with(&channel, move |mut cx| { - let table = Arc::new(Mutex::new( - table_rst.or_else(|err| cx.throw_error(err.to_string()))?, - )); + let table = Arc::new(Mutex::new(table_rst.or_throw(&mut cx)?)); Ok(cx.boxed(JsTable { table })) }); }); @@ -370,7 +350,8 @@ fn table_add(mut cx: FunctionContext) -> JsResult<JsPromise> { let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?; let buffer = cx.argument::<JsBuffer>(0)?; let write_mode = cx.argument::<JsString>(1)?.value(&mut cx); - let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx)); + + let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut 
cx)).or_throw(&mut cx)?; let schema = batches[0].schema(); let rt = runtime(&mut cx)?; @@ -399,7 +380,7 @@ fn table_add(mut cx: FunctionContext) -> JsResult<JsPromise> { let add_result = table.lock().unwrap().add(batch_reader, Some(params)).await; deferred.settle_with(&channel, move |mut cx| { - let _added = add_result.or_else(|err| cx.throw_error(err.to_string()))?; + let _added = add_result.or_throw(&mut cx)?; Ok(cx.boolean(true)) }); }); @@ -418,7 +399,7 @@ fn table_count_rows(mut cx: FunctionContext) -> JsResult<JsPromise> { let num_rows_result = table.lock().unwrap().count_rows().await; deferred.settle_with(&channel, move |mut cx| { - let num_rows = num_rows_result.or_else(|err| cx.throw_error(err.to_string()))?; + let num_rows = num_rows_result.or_throw(&mut cx)?; Ok(cx.number(num_rows as f64)) }); }); @@ -438,7 +419,7 @@ fn table_delete(mut cx: FunctionContext) -> JsResult<JsPromise> { let delete_result = rt.block_on(async move { table.lock().unwrap().delete(&predicate).await }); deferred.settle_with(&channel, move |mut cx| { - delete_result.or_else(|err| cx.throw_error(err.to_string()))?; + delete_result.or_throw(&mut cx)?; Ok(cx.undefined()) }); diff --git a/rust/vectordb/Cargo.toml b/rust/vectordb/Cargo.toml index 66a24f63..1c7939c2 100644 --- a/rust/vectordb/Cargo.toml +++ b/rust/vectordb/Cargo.toml @@ -12,7 +12,7 @@ arrow-array = { workspace = true } arrow-data = { workspace = true } arrow-schema = { workspace = true } object_store = { workspace = true } -snafu = "0.7.4" +snafu = { workspace = true } half = { workspace = true } lance = { workspace = true } tokio = { version = "1.23", features = ["rt-multi-thread"] }