diff --git a/node/src/index.ts b/node/src/index.ts
index 07b96c23..9fc541f1 100644
--- a/node/src/index.ts
+++ b/node/src/index.ts
@@ -100,16 +100,21 @@ export class Table {
   }
 
   /**
-   * Insert records into this Table
-   * @param data Records to be inserted into the Table
+   * Insert records into this Table.
    *
-   * @param mode Append / Overwrite existing records. Default: Append
+   * @param data Records to be inserted into the Table
    * @return The number of rows added to the table
    */
   async add (data: Array<Record<string, unknown>>): Promise<number> {
     return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Append.toString())
   }
 
+  /**
+   * Insert records into this Table, replacing its contents.
+   *
+   * @param data Records to be inserted into the Table
+   * @return The number of rows added to the table
+   */
   async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
     return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Overwrite.toString())
   }
@@ -120,44 +125,75 @@ export class Query {
   private readonly _tbl: any
-  private readonly _query_vector: number[]
+  private readonly _queryVector: number[]
   private _limit: number
-  private readonly _refine_factor?: number
-  private readonly _nprobes: number
+  private _refineFactor?: number
+  private _nprobes: number
   private readonly _columns?: string[]
   private _filter?: string
-  private readonly _metric = 'L2'
+  private _metricType?: MetricType
 
   constructor (tbl: any, queryVector: number[]) {
     this._tbl = tbl
-    this._query_vector = queryVector
+    this._queryVector = queryVector
     this._limit = 10
     this._nprobes = 20
-    this._refine_factor = undefined
+    this._refineFactor = undefined
    this._columns = undefined
     this._filter = undefined
+    this._metricType = undefined
   }
 
+  /**
+   * Sets the number of results that will be returned
+   * @param value number of results
+   */
   limit (value: number): Query {
     this._limit = value
     return this
   }
 
+  /**
+   * Refine the results by reading extra elements and re-ranking them in memory.
+   * @param value refine factor to use in this query.
+   */
+  refineFactor (value: number): Query {
+    this._refineFactor = value
+    return this
+  }
+
+  /**
+   * The number of probes used. A higher number makes search more accurate but also slower.
+   * @param value The number of probes used.
+   */
+  nprobes (value: number): Query {
+    this._nprobes = value
+    return this
+  }
+
+  /**
+   * A filter statement to be applied to this query.
+   * @param value A filter in the same format used by a sql WHERE clause.
+   */
   filter (value: string): Query {
     this._filter = value
     return this
   }
 
   /**
-   * Execute the query and return the results as an Array of Objects
-   */
+   * The MetricType used for this Query.
+   * @param value The metric to use. @see MetricType for the different options
+   */
+  metricType (value: MetricType): Query {
+    this._metricType = value
+    return this
+  }
+
+  /**
+   * Execute the query and return the results as an Array of Objects
+   */
   async execute<T = Record<string, unknown>> (): Promise<T[]> {
-    let buffer
-    if (this._filter != null) {
-      buffer = await tableSearch.call(this._tbl, this._queryVector, this._limit, this._filter)
-    } else {
-      buffer = await tableSearch.call(this._tbl, this._queryVector, this._limit)
-    }
+    const buffer = await tableSearch.call(this._tbl, this)
     const data = tableFromIPC(buffer)
     return data.toArray().map((entry: Record<string, unknown>) => {
       const newObject: Record<string, unknown> = {}
@@ -177,3 +213,18 @@ export enum WriteMode {
   Overwrite = 'overwrite',
   Append = 'append'
 }
+
+/**
+ * Distance metric type.
+ */
+export enum MetricType {
+  /**
+   * Euclidean distance
+   */
+  L2 = 'l2',
+
+  /**
+   * Cosine distance
+   */
+  Cosine = 'cosine'
+}
diff --git a/node/src/test/test.ts b/node/src/test/test.ts
index b185e17d..121923bd 100644
--- a/node/src/test/test.ts
+++ b/node/src/test/test.ts
@@ -17,6 +17,7 @@
 import { assert } from 'chai'
 import { track } from 'temp'
 import * as lancedb from '../index'
+import { MetricType, Query } from '../index'
 
 describe('LanceDB client', function () {
   describe('when creating a connection to lancedb', function () {
@@ -132,6 +133,20 @@
   })
 })
 
+describe('Query object', function () {
+  it('sets custom parameters', async function () {
+    const query = new Query(undefined, [0.1, 0.3])
+      .limit(1)
+      .metricType(MetricType.Cosine)
+      .refineFactor(100)
+      .nprobes(20) as Record<string, unknown>
+    assert.equal(query._limit, 1)
+    assert.equal(query._metricType, MetricType.Cosine)
+    assert.equal(query._refineFactor, 100)
+    assert.equal(query._nprobes, 20)
+  })
+})
+
 async function createTestDB (): Promise<string> {
   const dir = await track().mkdir('lancejs')
   const con = await lancedb.connect(dir)
diff --git a/rust/ffi/node/src/lib.rs b/rust/ffi/node/src/lib.rs
index d4fc814d..48a639c9 100644
--- a/rust/ffi/node/src/lib.rs
+++ b/rust/ffi/node/src/lib.rs
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 use std::collections::HashMap;
+use std::convert::TryFrom;
 use std::ops::Deref;
 use std::sync::{Arc, Mutex};
 
@@ -21,6 +22,7 @@ use arrow_ipc::writer::FileWriter;
 use futures::{TryFutureExt, TryStreamExt};
 use lance::arrow::RecordBatchBuffer;
 use lance::dataset::WriteMode;
+use lance::index::vector::MetricType;
 use neon::prelude::*;
 use neon::types::buffer::TypedArray;
 use once_cell::sync::OnceCell;
@@ -39,12 +41,12 @@
 struct JsDatabase {
     database: Arc<Database>,
 }
 
+impl Finalize for JsDatabase {}
+
 struct JsTable {
     table: Arc<Mutex<Table>>,
 }
 
-impl Finalize for JsDatabase {}
-
 impl Finalize for JsTable {}
 
 fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
@@ -87,7 +89,9 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
         let table_rst = database.open_table(table_name).await;
 
         deferred.settle_with(&channel, move |mut cx| {
-            let table = Arc::new(Mutex::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?));
+            let table = Arc::new(Mutex::new(
+                table_rst.or_else(|err| cx.throw_error(err.to_string()))?,
+            ));
             Ok(cx.boxed(JsTable { table }))
         });
     });
@@ -96,15 +100,32 @@
 fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
     let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
-    let query_vector = cx.argument::<JsArray>(0)?; //. .as_value(&mut cx);
-    let limit = cx.argument::<JsNumber>(1)?.value(&mut cx);
-    let filter = cx.argument_opt(2).map(|f| f.downcast_or_throw::<JsString, _>(&mut cx).unwrap().value(&mut cx));
+    let query_obj = cx.argument::<JsObject>(0)?;
+
+    let limit = query_obj
+        .get::<JsNumber, _, _>(&mut cx, "_limit")?
+        .value(&mut cx);
+    let filter = query_obj
+        .get_opt::<JsString, _, _>(&mut cx, "_filter")?
+        .map(|s| s.value(&mut cx));
+    let refine_factor = query_obj
+        .get_opt::<JsNumber, _, _>(&mut cx, "_refineFactor")?
+        .map(|s| s.value(&mut cx))
+        .map(|i| i as u32);
+    let nprobes = query_obj
+        .get::<JsNumber, _, _>(&mut cx, "_nprobes")?
+        .value(&mut cx) as usize;
+    let metric_type = query_obj
+        .get_opt::<JsString, _, _>(&mut cx, "_metricType")?
+        .map(|s| s.value(&mut cx))
+        .map(|s| MetricType::try_from(s.as_str()).unwrap());
 
     let rt = runtime(&mut cx)?;
     let channel = cx.channel();
     let (deferred, promise) = cx.promise();
 
     let table = js_table.table.clone();
+    let query_vector = query_obj.get::<JsArray, _, _>(&mut cx, "_queryVector")?;
     let query = convert::js_array_to_vec(query_vector.deref(), &mut cx);
 
     rt.spawn(async move {
@@ -113,7 +134,10 @@ fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
             .unwrap()
             .search(Float32Array::from(query))
             .limit(limit as usize)
-            .filter(filter);
+            .refine_factor(refine_factor)
+            .nprobes(nprobes)
+            .filter(filter)
+            .metric_type(metric_type);
         let record_batch_stream = builder.execute();
         let results = record_batch_stream
             .and_then(|stream| stream.try_collect::<Vec<_>>().map_err(Error::from))
            .await;
@@ -164,7 +188,9 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
         let table_rst = database.create_table(table_name, batch_reader).await;
 
         deferred.settle_with(&channel, move |mut cx| {
-            let table = Arc::new(Mutex::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?));
+            let table = Arc::new(Mutex::new(
+                table_rst.or_else(|err| cx.throw_error(err.to_string()))?,
+            ));
             Ok(cx.boxed(JsTable { table }))
         });
     });
@@ -178,9 +204,7 @@ fn table_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
     let write_mode_map = HashMap::from([
         ("overwrite", WriteMode::Overwrite),
     ]);
-    let js_table = cx
-        .this()
-        .downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
+    let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
     let buffer = cx.argument::<JsBuffer>(0)?;
     let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
     let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx));
@@ -204,7 +228,6 @@
     Ok(promise)
 }
 
-
 #[neon::main]
 fn main(mut cx: ModuleContext) -> NeonResult<()> {
     cx.export_function("databaseNew", database_new)?;
diff --git a/rust/vectordb/src/query.rs b/rust/vectordb/src/query.rs
index fcbda05f..aac6134d 100644
--- a/rust/vectordb/src/query.rs
+++ b/rust/vectordb/src/query.rs
@@ -29,7 +29,7 @@ pub struct Query {
     pub filter: Option<String>,
     pub nprobes: usize,
     pub refine_factor: Option<u32>,
-    pub metric_type: MetricType,
+    pub metric_type: Option<MetricType>,
     pub use_index: bool,
 }
@@ -51,9 +51,9 @@ impl Query {
             limit: 10,
             nprobes: 20,
             refine_factor: None,
-            metric_type: MetricType::L2,
+            metric_type: None,
             use_index: false,
-            filter: None
+            filter: None,
         }
     }
@@ -71,10 +71,10 @@
             self.limit,
         )?;
         scanner.nprobs(self.nprobes);
-        scanner.distance_metric(self.metric_type);
         scanner.use_index(self.use_index);
         self.filter.as_ref().map(|f| scanner.filter(f));
         self.refine_factor.map(|rf| scanner.refine(rf));
+        self.metric_type.map(|mt| scanner.distance_metric(mt));
 
         Ok(scanner.try_into_stream().await?)
     }
@@ -123,7 +123,7 @@ impl Query {
     /// # Arguments
     ///
     /// * `metric_type` - The distance metric to use. By default [MetricType::L2] is used.
-    pub fn metric_type(mut self, metric_type: MetricType) -> Query {
+    pub fn metric_type(mut self, metric_type: Option<MetricType>) -> Query {
         self.metric_type = metric_type;
         self
     }
@@ -174,14 +174,14 @@ mod tests {
             .limit(100)
             .nprobes(1000)
             .use_index(true)
-            .metric_type(MetricType::Cosine)
+            .metric_type(Some(MetricType::Cosine))
             .refine_factor(Some(999));
 
         assert_eq!(query.query_vector, new_vector);
         assert_eq!(query.limit, 100);
         assert_eq!(query.nprobes, 1000);
         assert_eq!(query.use_index, true);
-        assert_eq!(query.metric_type, MetricType::Cosine);
+        assert_eq!(query.metric_type, Some(MetricType::Cosine));
         assert_eq!(query.refine_factor, Some(999));
     }
diff --git a/rust/vectordb/src/table.rs b/rust/vectordb/src/table.rs
index bab07d77..d5781954 100644
--- a/rust/vectordb/src/table.rs
+++ b/rust/vectordb/src/table.rs
@@ -80,7 +80,11 @@ impl Table {
         let dataset =
             Arc::new(Dataset::write(&mut batches, path, Some(WriteParams::default())).await?);
-        Ok(Table { name, path: path.to_string(), dataset })
+        Ok(Table {
+            name,
+            path: path.to_string(),
+            dataset,
+        })
     }
 
     /// Insert records into this Table
@@ -95,12 +99,13 @@ impl Table {
     pub async fn add(
         &mut self,
         mut batches: Box<dyn RecordBatchReader>,
-        write_mode: Option<WriteMode>
+        write_mode: Option<WriteMode>,
     ) -> Result<usize> {
         let mut params = WriteParams::default();
         params.mode = write_mode.unwrap_or(WriteMode::Append);
 
-        self.dataset = Arc::new(Dataset::write(&mut batches, self.path.as_str(), Some(params)).await?);
+        self.dataset =
+            Arc::new(Dataset::write(&mut batches, self.path.as_str(), Some(params)).await?);
         Ok(batches.count())
     }
@@ -171,14 +176,17 @@ mod tests {
         let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
         let schema = batches.schema().clone();
-        let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches).await.unwrap();
+        let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches)
+            .await
+            .unwrap();
         assert_eq!(table.count_rows().await.unwrap(), 10);
 
-        let new_batches: Box<dyn RecordBatchReader> = Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
-            schema,
-            vec![Arc::new(Int32Array::from_iter_values(100..110))],
-        )
-        .unwrap()]));
+        let new_batches: Box<dyn RecordBatchReader> =
+            Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
+                schema,
+                vec![Arc::new(Int32Array::from_iter_values(100..110))],
+            )
+            .unwrap()]));
         table.add(new_batches, None).await.unwrap();
         assert_eq!(table.count_rows().await.unwrap(), 20);
@@ -192,15 +200,22 @@ mod tests {
         let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
         let schema = batches.schema().clone();
-        let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches).await.unwrap();
+        let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches)
+            .await
+            .unwrap();
         assert_eq!(table.count_rows().await.unwrap(), 10);
 
-        let new_batches: Box<dyn RecordBatchReader> = Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
-            schema,
-            vec![Arc::new(Int32Array::from_iter_values(100..110))],
-        ).unwrap()]));
+        let new_batches: Box<dyn RecordBatchReader> =
+            Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
+                schema,
+                vec![Arc::new(Int32Array::from_iter_values(100..110))],
+            )
+            .unwrap()]));
 
-        table.add(new_batches, Some(WriteMode::Overwrite)).await.unwrap();
+        table
+            .add(new_batches, Some(WriteMode::Overwrite))
+            .await
+            .unwrap();
         assert_eq!(table.count_rows().await.unwrap(), 10);
         assert_eq!(table.name, "test");
     }
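
For reference, the new fluent options added to the TypeScript Query class above can be chained as in the following sketch. Illustrative only: the `tbl` native table handle, the vector values, and the import path are assumptions, not part of this diff.

import { MetricType, Query } from '../index'

// Build a query against a native table handle (assumed to come from an opened table)
// and chain the new setters. execute() hands the whole Query object to the native
// tableSearch binding, which reads _queryVector, _limit, _nprobes, _refineFactor,
// _filter and _metricType from it.
async function exampleSearch (tbl: any): Promise<Array<Record<string, unknown>>> {
  return await new Query(tbl, [0.1, 0.3])
    .limit(1)
    .metricType(MetricType.Cosine)
    .refineFactor(100)
    .nprobes(20)
    .execute()
}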