diff --git a/node/src/test/test.ts b/node/src/test/test.ts index de5ecd70..4fb6fe36 100644 --- a/node/src/test/test.ts +++ b/node/src/test/test.ts @@ -250,6 +250,14 @@ describe('LanceDB client', function () { const createIndex = table.createIndex({ type: 'ivf_pq', column: 'name', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 }) await expect(createIndex).to.be.rejectedWith(/VectorIndex requires the column data type to be fixed size list of float32s/) }) + + it('it should fail when num_partitions is negative', async function () { + const uri = await createTestDB(32, 300) + const con = await lancedb.connect(uri) + const table = await con.openTable('vectors') + const createIndex = table.createIndex({ type: 'ivf_pq', column: 'name', num_partitions: -1, max_iters: 2, num_sub_vectors: 2 }) + await expect(createIndex).to.be.rejectedWith('num_partitions: must be > 0') + }) }) describe('when using a custom embedding function', function () { diff --git a/rust/ffi/node/Cargo.toml b/rust/ffi/node/Cargo.toml index da5b9fe1..86b4219b 100644 --- a/rust/ffi/node/Cargo.toml +++ b/rust/ffi/node/Cargo.toml @@ -13,6 +13,7 @@ crate-type = ["cdylib"] arrow-array = { workspace = true } arrow-ipc = { workspace = true } arrow-schema = { workspace = true } +conv = "0.3.3" once_cell = "1" futures = "0.3" half = { workspace = true } diff --git a/rust/ffi/node/src/error.rs b/rust/ffi/node/src/error.rs index 3a164251..b4f5ac3d 100644 --- a/rust/ffi/node/src/error.rs +++ b/rust/ffi/node/src/error.rs @@ -22,8 +22,15 @@ use snafu::Snafu; pub enum Error { #[snafu(display("column '{name}' is missing"))] MissingColumn { name: String }, + #[snafu(display("{name}: {message}"))] + RangeError { name: String, message: String }, + #[snafu(display("{index_type} is not a valid index type"))] + InvalidIndexType { index_type: String }, + #[snafu(display("{message}"))] LanceDB { message: String }, + #[snafu(display("{message}"))] + Neon { message: String }, } pub type Result = std::result::Result; @@ -52,6 
+59,14 @@ impl From for Error { } } +impl From for Error { + fn from(value: neon::result::Throw) -> Self { + Self::Neon { + message: value.to_string(), + } + } +} + /// ResultExt is used to transform a [`Result`] into a [`NeonResult`], /// so it can be returned as a JavaScript error /// Copied from [Neon](https://github.com/neon-bindings/neon/blob/4c2e455a9e6814f1ba0178616d63caec7f4df317/crates/neon/src/result/mod.rs#L88) diff --git a/rust/ffi/node/src/index/vector.rs b/rust/ffi/node/src/index/vector.rs index 495441a8..4f710e7b 100644 --- a/rust/ffi/node/src/index/vector.rs +++ b/rust/ffi/node/src/index/vector.rs @@ -22,12 +22,15 @@ use neon::prelude::*; use vectordb::index::vector::{IvfPQIndexBuilder, VectorIndexBuilder}; +use crate::error::Error::InvalidIndexType; +use crate::error::ResultExt; +use crate::neon_ext::js_object_ext::JsObjectExt; use crate::{runtime, JsTable}; pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult { let js_table = cx.this().downcast_or_throw::, _>(&mut cx)?; let index_params = cx.argument::(0)?; - let index_params_builder = get_index_params_builder(&mut cx, index_params).unwrap(); + let index_params_builder = get_index_params_builder(&mut cx, index_params).or_throw(&mut cx)?; let rt = runtime(&mut cx)?; let channel = cx.channel(); @@ -54,27 +57,21 @@ pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult, -) -> Result { - let idx_type = obj - .get::(cx, "type") - .map_err(|t| t.to_string())? - .value(cx); +) -> crate::error::Result { + let idx_type = obj.get::(cx, "type")?.value(cx); match idx_type.as_str() { "ivf_pq" => { let mut index_builder: IvfPQIndexBuilder = IvfPQIndexBuilder::new(); let mut pq_params = PQBuildParams::default(); - obj.get_opt::(cx, "column") - .map_err(|t| t.to_string())? + obj.get_opt::(cx, "column")? .map(|s| index_builder.column(s.value(cx))); - obj.get_opt::(cx, "index_name") - .map_err(|t| t.to_string())? + obj.get_opt::(cx, "index_name")? 
.map(|s| index_builder.index_name(s.value(cx))); - obj.get_opt::(cx, "metric_type") - .map_err(|t| t.to_string())? + obj.get_opt::(cx, "metric_type")? .map(|s| MetricType::try_from(s.value(cx).as_str())) .map(|mt| { let metric_type = mt.unwrap(); @@ -82,15 +79,8 @@ fn get_index_params_builder( pq_params.metric_type = metric_type; }); - let num_partitions = obj - .get_opt::(cx, "num_partitions") - .map_err(|t| t.to_string())? - .map(|s| s.value(cx) as usize); - - let max_iters = obj - .get_opt::(cx, "max_iters") - .map_err(|t| t.to_string())? - .map(|s| s.value(cx) as usize); + let num_partitions = obj.get_opt_usize(cx, "num_partitions")?; + let max_iters = obj.get_opt_usize(cx, "max_iters")?; num_partitions.map(|np| { let max_iters = max_iters.unwrap_or(50); @@ -102,32 +92,28 @@ fn get_index_params_builder( index_builder.ivf_params(ivf_params) }); - obj.get_opt::(cx, "use_opq") - .map_err(|t| t.to_string())? + obj.get_opt::(cx, "use_opq")? .map(|s| pq_params.use_opq = s.value(cx)); - obj.get_opt::(cx, "num_sub_vectors") - .map_err(|t| t.to_string())? - .map(|s| pq_params.num_sub_vectors = s.value(cx) as usize); + obj.get_opt_usize(cx, "num_sub_vectors")? + .map(|s| pq_params.num_sub_vectors = s); - obj.get_opt::(cx, "num_bits") - .map_err(|t| t.to_string())? - .map(|s| pq_params.num_bits = s.value(cx) as usize); + obj.get_opt_usize(cx, "num_bits")? + .map(|s| pq_params.num_bits = s); - obj.get_opt::(cx, "max_iters") - .map_err(|t| t.to_string())? - .map(|s| pq_params.max_iters = s.value(cx) as usize); + obj.get_opt_usize(cx, "max_iters")? + .map(|s| pq_params.max_iters = s); - obj.get_opt::(cx, "max_opq_iters") - .map_err(|t| t.to_string())? - .map(|s| pq_params.max_opq_iters = s.value(cx) as usize); + obj.get_opt_usize(cx, "max_opq_iters")? + .map(|s| pq_params.max_opq_iters = s); - obj.get_opt::(cx, "replace") - .map_err(|t| t.to_string())? + obj.get_opt::(cx, "replace")? 
.map(|s| index_builder.replace(s.value(cx))); Ok(index_builder) } - t => Err(format!("{} is not a valid index type", t).to_string()), + index_type => Err(InvalidIndexType { + index_type: index_type.into(), + }), } } diff --git a/rust/ffi/node/src/lib.rs b/rust/ffi/node/src/lib.rs index d40448af..75cd072d 100644 --- a/rust/ffi/node/src/lib.rs +++ b/rust/ffi/node/src/lib.rs @@ -31,16 +31,17 @@ use once_cell::sync::OnceCell; use tokio::runtime::Runtime; use vectordb::database::Database; -use vectordb::error::Error; use vectordb::table::{ReadParams, Table}; use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer}; use crate::error::ResultExt; +use crate::neon_ext::js_object_ext::JsObjectExt; mod arrow; mod convert; mod error; mod index; +mod neon_ext; struct JsDatabase { database: Arc, @@ -245,12 +246,9 @@ fn table_search(mut cx: FunctionContext) -> JsResult { .get_opt::(&mut cx, "_filter")? .map(|s| s.value(&mut cx)); let refine_factor = query_obj - .get_opt::(&mut cx, "_refineFactor")? - .map(|s| s.value(&mut cx)) - .map(|i| i as u32); - let nprobes = query_obj - .get::(&mut cx, "_nprobes")? - .value(&mut cx) as usize; + .get_opt_u32(&mut cx, "_refineFactor") + .or_throw(&mut cx)?; + let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?; let metric_type = query_obj .get_opt::(&mut cx, "_metricType")? .map(|s| s.value(&mut cx)) @@ -277,7 +275,11 @@ fn table_search(mut cx: FunctionContext) -> JsResult { .select(select); let record_batch_stream = builder.execute(); let results = record_batch_stream - .and_then(|stream| stream.try_collect::>().map_err(Error::from)) + .and_then(|stream| { + stream + .try_collect::>() + .map_err(vectordb::error::Error::from) + }) .await; deferred.settle_with(&channel, move |mut cx| { diff --git a/rust/ffi/node/src/neon_ext.rs b/rust/ffi/node/src/neon_ext.rs new file mode 100644 index 00000000..e2e4b2d2 --- /dev/null +++ b/rust/ffi/node/src/neon_ext.rs @@ -0,0 +1,15 @@ +// Copyright 2023 Lance Developers. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod js_object_ext; diff --git a/rust/ffi/node/src/neon_ext/js_object_ext.rs b/rust/ffi/node/src/neon_ext/js_object_ext.rs new file mode 100644 index 00000000..6ed0fa56 --- /dev/null +++ b/rust/ffi/node/src/neon_ext/js_object_ext.rs @@ -0,0 +1,82 @@ +// Copyright 2023 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::error::{Error, Result}; +use neon::prelude::*; + +/// Extends neon's [`JsObject`] with helpers that read numeric properties and return an error when the value does not fit the target integer type. +pub trait JsObjectExt { + fn get_opt_u32(&self, cx: &mut FunctionContext, key: &str) -> Result>; + fn get_usize(&self, cx: &mut FunctionContext, key: &str) -> Result; + fn get_opt_usize(&self, cx: &mut FunctionContext, key: &str) -> Result>; +} + +impl JsObjectExt for JsObject { + fn get_opt_u32(&self, cx: &mut FunctionContext, key: &str) -> Result> { + let val_opt = self + .get_opt::(cx, key)? 
+ .map(|s| f64_to_u32_safe(s.value(cx), key)); + val_opt.transpose() + } + + fn get_usize(&self, cx: &mut FunctionContext, key: &str) -> Result { + let val = self.get::(cx, key)?.value(cx); + f64_to_usize_safe(val, key) + } + + fn get_opt_usize(&self, cx: &mut FunctionContext, key: &str) -> Result> { + let val_opt = self + .get_opt::(cx, key)? + .map(|s| f64_to_usize_safe(s.value(cx), key)); + val_opt.transpose() + } +} + +fn f64_to_u32_safe(n: f64, key: &str) -> Result { + use conv::*; + + n.approx_as::().map_err(|e| match e { + FloatError::NegOverflow(_) => Error::RangeError { + name: key.into(), + message: "must be > 0".to_string(), + }, + FloatError::PosOverflow(_) => Error::RangeError { + name: key.into(), + message: format!("must be < {}", u32::MAX), + }, + FloatError::NotANumber(_) => Error::RangeError { + name: key.into(), + message: "not a valid number".to_string(), + }, + }) +} + +fn f64_to_usize_safe(n: f64, key: &str) -> Result { + use conv::*; + + n.approx_as::().map_err(|e| match e { + FloatError::NegOverflow(_) => Error::RangeError { + name: key.into(), + message: "must be > 0".to_string(), + }, + FloatError::PosOverflow(_) => Error::RangeError { + name: key.into(), + message: format!("must be < {}", usize::MAX), + }, + FloatError::NotANumber(_) => Error::RangeError { + name: key.into(), + message: "not a valid number".to_string(), + }, + }) +}