add query params to to nodejs client (#87)

This commit is contained in:
gsilvestrin
2023-05-24 15:48:31 -06:00
committed by GitHub
parent bdef634954
commit 06cb7b6458
5 changed files with 155 additions and 51 deletions

View File

@@ -100,16 +100,21 @@ export class Table {
}
/**
* Insert records into this Table
* @param data Records to be inserted into the Table
* Insert records into this Table.
*
* @param mode Append / Overwrite existing records. Default: Append
* @param data Records to be inserted into the Table
* @return The number of rows added to the table
*/
async add (data: Array<Record<string, unknown>>): Promise<number> {
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Append.toString())
}
/**
* Insert records into this Table, replacing its contents.
*
* @param data Records to be inserted into the Table
* @return The number of rows added to the table
*/
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Overwrite.toString())
}
@@ -120,44 +125,75 @@ export class Table {
*/
export class Query {
private readonly _tbl: any
private readonly _query_vector: number[]
private readonly _queryVector: number[]
private _limit: number
private readonly _refine_factor?: number
private readonly _nprobes: number
private _refineFactor?: number
private _nprobes: number
private readonly _columns?: string[]
private _filter?: string
private readonly _metric = 'L2'
private _metricType?: MetricType
constructor (tbl: any, queryVector: number[]) {
this._tbl = tbl
this._query_vector = queryVector
this._queryVector = queryVector
this._limit = 10
this._nprobes = 20
this._refine_factor = undefined
this._refineFactor = undefined
this._columns = undefined
this._filter = undefined
this._metricType = undefined
}
/***
* Sets the number of results that will be returned
* @param value number of results
*/
limit (value: number): Query {
this._limit = value
return this
}
/**
* Refine the results by reading extra elements and re-ranking them in memory.
* @param value refine factor to use in this query.
*/
refineFactor (value: number): Query {
this._refineFactor = value
return this
}
/**
* The number of probes used. A higher number makes search more accurate but also slower.
* @param value The number of probes used.
*/
nprobes (value: number): Query {
this._nprobes = value
return this
}
/**
* A filter statement to be applied to this query.
* @param value A filter in the same format used by a sql WHERE clause.
*/
filter (value: string): Query {
this._filter = value
return this
}
/**
* Execute the query and return the results as an Array of Objects
*/
* The MetricType used for this Query.
* @param value The metric to the. @see MetricType for the different options
*/
metricType (value: MetricType): Query {
this._metricType = value
return this
}
/**
* Execute the query and return the results as an Array of Objects
*/
async execute<T = Record<string, unknown>> (): Promise<T[]> {
let buffer
if (this._filter != null) {
buffer = await tableSearch.call(this._tbl, this._query_vector, this._limit, this._filter)
} else {
buffer = await tableSearch.call(this._tbl, this._query_vector, this._limit)
}
const buffer = await tableSearch.call(this._tbl, this)
const data = tableFromIPC(buffer)
return data.toArray().map((entry: Record<string, unknown>) => {
const newObject: Record<string, unknown> = {}
@@ -177,3 +213,18 @@ export enum WriteMode {
Overwrite = 'overwrite',
Append = 'append'
}
/**
* Distance metrics type.
*/
export enum MetricType {
/**
* Euclidean distance
*/
L2 = 'l2',
/**
* Cosine distance
*/
Cosine = 'cosine'
}

View File

@@ -17,6 +17,7 @@ import { assert } from 'chai'
import { track } from 'temp'
import * as lancedb from '../index'
import { MetricType, Query } from '../index'
describe('LanceDB client', function () {
describe('when creating a connection to lancedb', function () {
@@ -132,6 +133,20 @@ describe('LanceDB client', function () {
})
})
describe('Query object', function () {
it('sets custom parameters', async function () {
const query = new Query(undefined, [0.1, 0.3])
.limit(1)
.metricType(MetricType.Cosine)
.refineFactor(100)
.nprobes(20) as Record<string, any>
assert.equal(query._limit, 1)
assert.equal(query._metricType, MetricType.Cosine)
assert.equal(query._refineFactor, 100)
assert.equal(query._nprobes, 20)
})
})
async function createTestDB (): Promise<string> {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)

View File

@@ -13,6 +13,7 @@
// limitations under the License.
use std::collections::HashMap;
use std::convert::TryFrom;
use std::ops::Deref;
use std::sync::{Arc, Mutex};
@@ -21,6 +22,7 @@ use arrow_ipc::writer::FileWriter;
use futures::{TryFutureExt, TryStreamExt};
use lance::arrow::RecordBatchBuffer;
use lance::dataset::WriteMode;
use lance::index::vector::MetricType;
use neon::prelude::*;
use neon::types::buffer::TypedArray;
use once_cell::sync::OnceCell;
@@ -39,12 +41,12 @@ struct JsDatabase {
database: Arc<Database>,
}
impl Finalize for JsDatabase {}
struct JsTable {
table: Arc<Mutex<Table>>,
}
impl Finalize for JsDatabase {}
impl Finalize for JsTable {}
fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {
@@ -87,7 +89,9 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
let table_rst = database.open_table(table_name).await;
deferred.settle_with(&channel, move |mut cx| {
let table = Arc::new(Mutex::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?));
let table = Arc::new(Mutex::new(
table_rst.or_else(|err| cx.throw_error(err.to_string()))?,
));
Ok(cx.boxed(JsTable { table }))
});
});
@@ -96,15 +100,32 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let query_vector = cx.argument::<JsArray>(0)?; //. .as_value(&mut cx);
let limit = cx.argument::<JsNumber>(1)?.value(&mut cx);
let filter = cx.argument_opt(2).map(|f| f.downcast_or_throw::<JsString, _>(&mut cx).unwrap().value(&mut cx));
let query_obj = cx.argument::<JsObject>(0)?;
let limit = query_obj
.get::<JsNumber, _, _>(&mut cx, "_limit")?
.value(&mut cx);
let filter = query_obj
.get_opt::<JsString, _, _>(&mut cx, "_filter")?
.map(|s| s.value(&mut cx));
let refine_factor = query_obj
.get_opt::<JsNumber, _, _>(&mut cx, "_refineFactor")?
.map(|s| s.value(&mut cx))
.map(|i| i as u32);
let nprobes = query_obj
.get::<JsNumber, _, _>(&mut cx, "_nprobes")?
.value(&mut cx) as usize;
let metric_type = query_obj
.get_opt::<JsString, _, _>(&mut cx, "_metricType")?
.map(|s| s.value(&mut cx))
.map(|s| MetricType::try_from(s.as_str()).unwrap());
let rt = runtime(&mut cx)?;
let channel = cx.channel();
let (deferred, promise) = cx.promise();
let table = js_table.table.clone();
let query_vector = query_obj.get::<JsArray, _, _>(&mut cx, "_queryVector")?;
let query = convert::js_array_to_vec(query_vector.deref(), &mut cx);
rt.spawn(async move {
@@ -113,7 +134,10 @@ fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
.unwrap()
.search(Float32Array::from(query))
.limit(limit as usize)
.filter(filter);
.refine_factor(refine_factor)
.nprobes(nprobes)
.filter(filter)
.metric_type(metric_type);
let record_batch_stream = builder.execute();
let results = record_batch_stream
.and_then(|stream| stream.try_collect::<Vec<_>>().map_err(Error::from))
@@ -164,7 +188,9 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
let table_rst = database.create_table(table_name, batch_reader).await;
deferred.settle_with(&channel, move |mut cx| {
let table = Arc::new(Mutex::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?));
let table = Arc::new(Mutex::new(
table_rst.or_else(|err| cx.throw_error(err.to_string()))?,
));
Ok(cx.boxed(JsTable { table }))
});
});
@@ -178,9 +204,7 @@ fn table_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
("overwrite", WriteMode::Overwrite),
]);
let js_table = cx
.this()
.downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
let buffer = cx.argument::<JsBuffer>(0)?;
let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx));
@@ -204,7 +228,6 @@ fn table_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
Ok(promise)
}
#[neon::main]
fn main(mut cx: ModuleContext) -> NeonResult<()> {
cx.export_function("databaseNew", database_new)?;

View File

@@ -29,7 +29,7 @@ pub struct Query {
pub filter: Option<String>,
pub nprobes: usize,
pub refine_factor: Option<u32>,
pub metric_type: MetricType,
pub metric_type: Option<MetricType>,
pub use_index: bool,
}
@@ -51,9 +51,9 @@ impl Query {
limit: 10,
nprobes: 20,
refine_factor: None,
metric_type: MetricType::L2,
metric_type: None,
use_index: false,
filter: None
filter: None,
}
}
@@ -71,10 +71,10 @@ impl Query {
self.limit,
)?;
scanner.nprobs(self.nprobes);
scanner.distance_metric(self.metric_type);
scanner.use_index(self.use_index);
self.filter.as_ref().map(|f| scanner.filter(f));
self.refine_factor.map(|rf| scanner.refine(rf));
self.metric_type.map(|mt| scanner.distance_metric(mt));
Ok(scanner.try_into_stream().await?)
}
@@ -123,7 +123,7 @@ impl Query {
/// # Arguments
///
/// * `metric_type` - The distance metric to use. By default [MetricType::L2] is used.
pub fn metric_type(mut self, metric_type: MetricType) -> Query {
pub fn metric_type(mut self, metric_type: Option<MetricType>) -> Query {
self.metric_type = metric_type;
self
}
@@ -174,14 +174,14 @@ mod tests {
.limit(100)
.nprobes(1000)
.use_index(true)
.metric_type(MetricType::Cosine)
.metric_type(Some(MetricType::Cosine))
.refine_factor(Some(999));
assert_eq!(query.query_vector, new_vector);
assert_eq!(query.limit, 100);
assert_eq!(query.nprobes, 1000);
assert_eq!(query.use_index, true);
assert_eq!(query.metric_type, MetricType::Cosine);
assert_eq!(query.metric_type, Some(MetricType::Cosine));
assert_eq!(query.refine_factor, Some(999));
}

View File

@@ -80,7 +80,11 @@ impl Table {
let dataset =
Arc::new(Dataset::write(&mut batches, path, Some(WriteParams::default())).await?);
Ok(Table { name, path: path.to_string(), dataset })
Ok(Table {
name,
path: path.to_string(),
dataset,
})
}
/// Insert records into this Table
@@ -95,12 +99,13 @@ impl Table {
pub async fn add(
&mut self,
mut batches: Box<dyn RecordBatchReader>,
write_mode: Option<WriteMode>
write_mode: Option<WriteMode>,
) -> Result<usize> {
let mut params = WriteParams::default();
params.mode = write_mode.unwrap_or(WriteMode::Append);
self.dataset = Arc::new(Dataset::write(&mut batches, self.path.as_str(), Some(params)).await?);
self.dataset =
Arc::new(Dataset::write(&mut batches, self.path.as_str(), Some(params)).await?);
Ok(batches.count())
}
@@ -171,14 +176,17 @@ mod tests {
let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
let schema = batches.schema().clone();
let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches).await.unwrap();
let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches)
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
let new_batches: Box<dyn RecordBatchReader> = Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
schema,
vec![Arc::new(Int32Array::from_iter_values(100..110))],
)
.unwrap()]));
let new_batches: Box<dyn RecordBatchReader> =
Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
schema,
vec![Arc::new(Int32Array::from_iter_values(100..110))],
)
.unwrap()]));
table.add(new_batches, None).await.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 20);
@@ -192,15 +200,22 @@ mod tests {
let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
let schema = batches.schema().clone();
let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches).await.unwrap();
let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches)
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
let new_batches: Box<dyn RecordBatchReader> = Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
schema,
vec![Arc::new(Int32Array::from_iter_values(100..110))],
).unwrap()]));
let new_batches: Box<dyn RecordBatchReader> =
Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
schema,
vec![Arc::new(Int32Array::from_iter_values(100..110))],
)
.unwrap()]));
table.add(new_batches, Some(WriteMode::Overwrite)).await.unwrap();
table
.add(new_batches, Some(WriteMode::Overwrite))
.await
.unwrap();
assert_eq!(table.count_rows().await.unwrap(), 10);
assert_eq!(table.name, "test");
}