fix(node): Handle overflows in the node bridge (#372)

- Fixes many numeric conversions that results in hard to reproduce issues
- JsObjectExt extends JsObject with safe methods to extract numericvalues
This commit is contained in:
gsilvestrin
2023-07-28 13:15:21 -07:00
committed by GitHub
parent 1daecac648
commit bcd7f66dc7
7 changed files with 155 additions and 46 deletions

View File

@@ -250,6 +250,14 @@ describe('LanceDB client', function () {
const createIndex = table.createIndex({ type: 'ivf_pq', column: 'name', num_partitions: 2, max_iters: 2, num_sub_vectors: 2 })
await expect(createIndex).to.be.rejectedWith(/VectorIndex requires the column data type to be fixed size list of float32s/)
})
it('it should fail when the column is not a vector', async function () {
const uri = await createTestDB(32, 300)
const con = await lancedb.connect(uri)
const table = await con.openTable('vectors')
const createIndex = table.createIndex({ type: 'ivf_pq', column: 'name', num_partitions: -1, max_iters: 2, num_sub_vectors: 2 })
await expect(createIndex).to.be.rejectedWith('num_partitions: must be > 0')
})
})
describe('when using a custom embedding function', function () {

View File

@@ -13,6 +13,7 @@ crate-type = ["cdylib"]
arrow-array = { workspace = true }
arrow-ipc = { workspace = true }
arrow-schema = { workspace = true }
conv = "0.3.3"
once_cell = "1"
futures = "0.3"
half = { workspace = true }

View File

@@ -22,8 +22,15 @@ use snafu::Snafu;
pub enum Error {
#[snafu(display("column '{name}' is missing"))]
MissingColumn { name: String },
#[snafu(display("{name}: {message}"))]
RangeError { name: String, message: String },
#[snafu(display("{index_type} is not a valid index type"))]
InvalidIndexType { index_type: String },
#[snafu(display("{message}"))]
LanceDB { message: String },
#[snafu(display("{message}"))]
Neon { message: String },
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -52,6 +59,14 @@ impl From<ArrowError> for Error {
}
}
impl From<neon::result::Throw> for Error {
fn from(value: neon::result::Throw) -> Self {
Self::Neon {
message: value.to_string(),
}
}
}
/// ResultExt is used to transform a [`Result`] into a [`NeonResult`],
/// so it can be returned as a JavaScript error
/// Copied from [Neon](https://github.com/neon-bindings/neon/blob/4c2e455a9e6814f1ba0178616d63caec7f4df317/crates/neon/src/result/mod.rs#L88)

View File

@@ -22,12 +22,15 @@ use neon::prelude::*;
use vectordb::index::vector::{IvfPQIndexBuilder, VectorIndexBuilder};
use crate::error::Error::InvalidIndexType;
use crate::error::ResultExt;
use crate::neon_ext::js_object_ext::JsObjectExt;
use crate::{runtime, JsTable};
pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let index_params = cx.argument::<JsObject>(0)?;
let index_params_builder = get_index_params_builder(&mut cx, index_params).unwrap();
let index_params_builder = get_index_params_builder(&mut cx, index_params).or_throw(&mut cx)?;
let rt = runtime(&mut cx)?;
let channel = cx.channel();
@@ -54,27 +57,21 @@ pub(crate) fn table_create_vector_index(mut cx: FunctionContext) -> JsResult<JsP
fn get_index_params_builder(
cx: &mut FunctionContext,
obj: Handle<JsObject>,
) -> Result<impl VectorIndexBuilder, String> {
let idx_type = obj
.get::<JsString, _, _>(cx, "type")
.map_err(|t| t.to_string())?
.value(cx);
) -> crate::error::Result<impl VectorIndexBuilder> {
let idx_type = obj.get::<JsString, _, _>(cx, "type")?.value(cx);
match idx_type.as_str() {
"ivf_pq" => {
let mut index_builder: IvfPQIndexBuilder = IvfPQIndexBuilder::new();
let mut pq_params = PQBuildParams::default();
obj.get_opt::<JsString, _, _>(cx, "column")
.map_err(|t| t.to_string())?
obj.get_opt::<JsString, _, _>(cx, "column")?
.map(|s| index_builder.column(s.value(cx)));
obj.get_opt::<JsString, _, _>(cx, "index_name")
.map_err(|t| t.to_string())?
obj.get_opt::<JsString, _, _>(cx, "index_name")?
.map(|s| index_builder.index_name(s.value(cx)));
obj.get_opt::<JsString, _, _>(cx, "metric_type")
.map_err(|t| t.to_string())?
obj.get_opt::<JsString, _, _>(cx, "metric_type")?
.map(|s| MetricType::try_from(s.value(cx).as_str()))
.map(|mt| {
let metric_type = mt.unwrap();
@@ -82,15 +79,8 @@ fn get_index_params_builder(
pq_params.metric_type = metric_type;
});
let num_partitions = obj
.get_opt::<JsNumber, _, _>(cx, "num_partitions")
.map_err(|t| t.to_string())?
.map(|s| s.value(cx) as usize);
let max_iters = obj
.get_opt::<JsNumber, _, _>(cx, "max_iters")
.map_err(|t| t.to_string())?
.map(|s| s.value(cx) as usize);
let num_partitions = obj.get_opt_usize(cx, "num_partitions")?;
let max_iters = obj.get_opt_usize(cx, "max_iters")?;
num_partitions.map(|np| {
let max_iters = max_iters.unwrap_or(50);
@@ -102,32 +92,28 @@ fn get_index_params_builder(
index_builder.ivf_params(ivf_params)
});
obj.get_opt::<JsBoolean, _, _>(cx, "use_opq")
.map_err(|t| t.to_string())?
obj.get_opt::<JsBoolean, _, _>(cx, "use_opq")?
.map(|s| pq_params.use_opq = s.value(cx));
obj.get_opt::<JsNumber, _, _>(cx, "num_sub_vectors")
.map_err(|t| t.to_string())?
.map(|s| pq_params.num_sub_vectors = s.value(cx) as usize);
obj.get_opt_usize(cx, "num_sub_vectors")?
.map(|s| pq_params.num_sub_vectors = s);
obj.get_opt::<JsNumber, _, _>(cx, "num_bits")
.map_err(|t| t.to_string())?
.map(|s| pq_params.num_bits = s.value(cx) as usize);
obj.get_opt_usize(cx, "num_bits")?
.map(|s| pq_params.num_bits = s);
obj.get_opt::<JsNumber, _, _>(cx, "max_iters")
.map_err(|t| t.to_string())?
.map(|s| pq_params.max_iters = s.value(cx) as usize);
obj.get_opt_usize(cx, "max_iters")?
.map(|s| pq_params.max_iters = s);
obj.get_opt::<JsNumber, _, _>(cx, "max_opq_iters")
.map_err(|t| t.to_string())?
.map(|s| pq_params.max_opq_iters = s.value(cx) as usize);
obj.get_opt_usize(cx, "max_opq_iters")?
.map(|s| pq_params.max_opq_iters = s);
obj.get_opt::<JsBoolean, _, _>(cx, "replace")
.map_err(|t| t.to_string())?
obj.get_opt::<JsBoolean, _, _>(cx, "replace")?
.map(|s| index_builder.replace(s.value(cx)));
Ok(index_builder)
}
t => Err(format!("{} is not a valid index type", t).to_string()),
index_type => Err(InvalidIndexType {
index_type: index_type.into(),
}),
}
}

View File

@@ -31,16 +31,17 @@ use once_cell::sync::OnceCell;
use tokio::runtime::Runtime;
use vectordb::database::Database;
use vectordb::error::Error;
use vectordb::table::{ReadParams, Table};
use crate::arrow::{arrow_buffer_to_record_batch, record_batch_to_buffer};
use crate::error::ResultExt;
use crate::neon_ext::js_object_ext::JsObjectExt;
mod arrow;
mod convert;
mod error;
mod index;
mod neon_ext;
struct JsDatabase {
database: Arc<Database>,
@@ -245,12 +246,9 @@ fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
.get_opt::<JsString, _, _>(&mut cx, "_filter")?
.map(|s| s.value(&mut cx));
let refine_factor = query_obj
.get_opt::<JsNumber, _, _>(&mut cx, "_refineFactor")?
.map(|s| s.value(&mut cx))
.map(|i| i as u32);
let nprobes = query_obj
.get::<JsNumber, _, _>(&mut cx, "_nprobes")?
.value(&mut cx) as usize;
.get_opt_u32(&mut cx, "_refineFactor")
.or_throw(&mut cx)?;
let nprobes = query_obj.get_usize(&mut cx, "_nprobes").or_throw(&mut cx)?;
let metric_type = query_obj
.get_opt::<JsString, _, _>(&mut cx, "_metricType")?
.map(|s| s.value(&mut cx))
@@ -277,7 +275,11 @@ fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
.select(select);
let record_batch_stream = builder.execute();
let results = record_batch_stream
.and_then(|stream| stream.try_collect::<Vec<_>>().map_err(Error::from))
.and_then(|stream| {
stream
.try_collect::<Vec<_>>()
.map_err(vectordb::error::Error::from)
})
.await;
deferred.settle_with(&channel, move |mut cx| {

View File

@@ -0,0 +1,15 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod js_object_ext;

View File

@@ -0,0 +1,82 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::error::{Error, Result};
use neon::prelude::*;
// extends neon's [JsObject] with helper functions to extract properties
pub trait JsObjectExt {
fn get_opt_u32(&self, cx: &mut FunctionContext, key: &str) -> Result<Option<u32>>;
fn get_usize(&self, cx: &mut FunctionContext, key: &str) -> Result<usize>;
fn get_opt_usize(&self, cx: &mut FunctionContext, key: &str) -> Result<Option<usize>>;
}
impl JsObjectExt for JsObject {
fn get_opt_u32(&self, cx: &mut FunctionContext, key: &str) -> Result<Option<u32>> {
let val_opt = self
.get_opt::<JsNumber, _, _>(cx, key)?
.map(|s| f64_to_u32_safe(s.value(cx), key));
val_opt.transpose()
}
fn get_usize(&self, cx: &mut FunctionContext, key: &str) -> Result<usize> {
let val = self.get::<JsNumber, _, _>(cx, key)?.value(cx);
f64_to_usize_safe(val, key)
}
fn get_opt_usize(&self, cx: &mut FunctionContext, key: &str) -> Result<Option<usize>> {
let val_opt = self
.get_opt::<JsNumber, _, _>(cx, key)?
.map(|s| f64_to_usize_safe(s.value(cx), key));
val_opt.transpose()
}
}
fn f64_to_u32_safe(n: f64, key: &str) -> Result<u32> {
use conv::*;
n.approx_as::<u32>().map_err(|e| match e {
FloatError::NegOverflow(_) => Error::RangeError {
name: key.into(),
message: "must be > 0".to_string(),
},
FloatError::PosOverflow(_) => Error::RangeError {
name: key.into(),
message: format!("must be < {}", u32::MAX),
},
FloatError::NotANumber(_) => Error::RangeError {
name: key.into(),
message: "not a valid number".to_string(),
},
})
}
fn f64_to_usize_safe(n: f64, key: &str) -> Result<usize> {
use conv::*;
n.approx_as::<usize>().map_err(|e| match e {
FloatError::NegOverflow(_) => Error::RangeError {
name: key.into(),
message: "must be > 0".to_string(),
},
FloatError::PosOverflow(_) => Error::RangeError {
name: key.into(),
message: format!("must be < {}", usize::MAX),
},
FloatError::NotANumber(_) => Error::RangeError {
name: key.into(),
message: "not a valid number".to_string(),
},
})
}