Mirror of https://github.com/lancedb/lancedb.git (synced 2026-01-05 19:32:56 +00:00)
add query params to nodejs client (#87)
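The commit threads the Query builder's parameters (limit, nprobes, refineFactor, filter, metricType) through to the native search call. Below is a minimal usage sketch of the resulting nodejs API, not part of the commit itself: it borrows the test file's '../index' import path, assumes the connection exposes an openTable method and that the table's search(vector) returns a Query as this client wires it, and uses a placeholder '/tmp/lancedb' path and 'vectors' table name.

import * as lancedb from '../index'
import { MetricType } from '../index'

async function example (): Promise<void> {
  // Placeholder path and table name; adjust to an existing dataset.
  const con = await lancedb.connect('/tmp/lancedb')
  const table = await con.openTable('vectors')

  // Each setter returns the Query, so the new parameters can be chained
  // before execute() hands them to the native tableSearch call.
  const results = await table.search([0.1, 0.3])
    .limit(5)                      // max results (default 10)
    .nprobes(20)                   // more probes = more accurate, slower
    .refineFactor(100)             // re-rank extra candidates in memory
    .filter('id > 10')             // SQL WHERE-style predicate
    .metricType(MetricType.Cosine) // distance metric (default L2)
    .execute()

  console.log(results.length)
}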
@@ -100,16 +100,21 @@ export class Table {
  }

  /**
   * Insert records into this Table
   * @param data Records to be inserted into the Table
   * Insert records into this Table.
   *
   * @param mode Append / Overwrite existing records. Default: Append
   * @param data Records to be inserted into the Table
   * @return The number of rows added to the table
   */
  async add (data: Array<Record<string, unknown>>): Promise<number> {
    return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Append.toString())
  }

  /**
   * Insert records into this Table, replacing its contents.
   *
   * @param data Records to be inserted into the Table
   * @return The number of rows added to the table
   */
  async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
    return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Overwrite.toString())
  }
@@ -120,44 +125,75 @@ export class Table {
 */
export class Query {
  private readonly _tbl: any
  private readonly _query_vector: number[]
  private readonly _queryVector: number[]
  private _limit: number
  private readonly _refine_factor?: number
  private readonly _nprobes: number
  private _refineFactor?: number
  private _nprobes: number
  private readonly _columns?: string[]
  private _filter?: string
  private readonly _metric = 'L2'
  private _metricType?: MetricType

  constructor (tbl: any, queryVector: number[]) {
    this._tbl = tbl
    this._query_vector = queryVector
    this._queryVector = queryVector
    this._limit = 10
    this._nprobes = 20
    this._refine_factor = undefined
    this._refineFactor = undefined
    this._columns = undefined
    this._filter = undefined
    this._metricType = undefined
  }

  /***
   * Sets the number of results that will be returned
   * @param value number of results
   */
  limit (value: number): Query {
    this._limit = value
    return this
  }

  /**
   * Refine the results by reading extra elements and re-ranking them in memory.
   * @param value refine factor to use in this query.
   */
  refineFactor (value: number): Query {
    this._refineFactor = value
    return this
  }

  /**
   * The number of probes used. A higher number makes search more accurate but also slower.
   * @param value The number of probes used.
   */
  nprobes (value: number): Query {
    this._nprobes = value
    return this
  }

  /**
   * A filter statement to be applied to this query.
   * @param value A filter in the same format used by a sql WHERE clause.
   */
  filter (value: string): Query {
    this._filter = value
    return this
  }

  /**
   * Execute the query and return the results as an Array of Objects
   */
   * The MetricType used for this Query.
   * @param value The metric to use. @see MetricType for the different options
   */
  metricType (value: MetricType): Query {
    this._metricType = value
    return this
  }

  /**
   * Execute the query and return the results as an Array of Objects
   */
  async execute<T = Record<string, unknown>> (): Promise<T[]> {
    let buffer
    if (this._filter != null) {
      buffer = await tableSearch.call(this._tbl, this._query_vector, this._limit, this._filter)
    } else {
      buffer = await tableSearch.call(this._tbl, this._query_vector, this._limit)
    }
    const buffer = await tableSearch.call(this._tbl, this)
    const data = tableFromIPC(buffer)
    return data.toArray().map((entry: Record<string, unknown>) => {
      const newObject: Record<string, unknown> = {}

@@ -177,3 +213,18 @@ export enum WriteMode {
  Overwrite = 'overwrite',
  Append = 'append'
}

/**
 * Distance metrics type.
 */
export enum MetricType {
  /**
   * Euclidean distance
   */
  L2 = 'l2',

  /**
   * Cosine distance
   */
  Cosine = 'cosine'
}
@@ -17,6 +17,7 @@ import { assert } from 'chai'
import { track } from 'temp'

import * as lancedb from '../index'
import { MetricType, Query } from '../index'

describe('LanceDB client', function () {
  describe('when creating a connection to lancedb', function () {

@@ -132,6 +133,20 @@ describe('LanceDB client', function () {
    })
  })

  describe('Query object', function () {
    it('sets custom parameters', async function () {
      const query = new Query(undefined, [0.1, 0.3])
        .limit(1)
        .metricType(MetricType.Cosine)
        .refineFactor(100)
        .nprobes(20) as Record<string, any>
      assert.equal(query._limit, 1)
      assert.equal(query._metricType, MetricType.Cosine)
      assert.equal(query._refineFactor, 100)
      assert.equal(query._nprobes, 20)
    })
  })

async function createTestDB (): Promise<string> {
  const dir = await track().mkdir('lancejs')
  const con = await lancedb.connect(dir)
@@ -13,6 +13,7 @@
// limitations under the License.

use std::collections::HashMap;
use std::convert::TryFrom;
use std::ops::Deref;
use std::sync::{Arc, Mutex};

@@ -21,6 +22,7 @@ use arrow_ipc::writer::FileWriter;
use futures::{TryFutureExt, TryStreamExt};
use lance::arrow::RecordBatchBuffer;
use lance::dataset::WriteMode;
use lance::index::vector::MetricType;
use neon::prelude::*;
use neon::types::buffer::TypedArray;
use once_cell::sync::OnceCell;

@@ -39,12 +41,12 @@ struct JsDatabase {
    database: Arc<Database>,
}

impl Finalize for JsDatabase {}

struct JsTable {
    table: Arc<Mutex<Table>>,
}

impl Finalize for JsDatabase {}

impl Finalize for JsTable {}

fn runtime<'a, C: Context<'a>>(cx: &mut C) -> NeonResult<&'static Runtime> {

@@ -87,7 +89,9 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {
        let table_rst = database.open_table(table_name).await;

        deferred.settle_with(&channel, move |mut cx| {
            let table = Arc::new(Mutex::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?));
            let table = Arc::new(Mutex::new(
                table_rst.or_else(|err| cx.throw_error(err.to_string()))?,
            ));
            Ok(cx.boxed(JsTable { table }))
        });
    });
@@ -96,15 +100,32 @@ fn database_open_table(mut cx: FunctionContext) -> JsResult<JsPromise> {

fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
    let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
    let query_vector = cx.argument::<JsArray>(0)?; //. .as_value(&mut cx);
    let limit = cx.argument::<JsNumber>(1)?.value(&mut cx);
    let filter = cx.argument_opt(2).map(|f| f.downcast_or_throw::<JsString, _>(&mut cx).unwrap().value(&mut cx));
    let query_obj = cx.argument::<JsObject>(0)?;

    let limit = query_obj
        .get::<JsNumber, _, _>(&mut cx, "_limit")?
        .value(&mut cx);
    let filter = query_obj
        .get_opt::<JsString, _, _>(&mut cx, "_filter")?
        .map(|s| s.value(&mut cx));
    let refine_factor = query_obj
        .get_opt::<JsNumber, _, _>(&mut cx, "_refineFactor")?
        .map(|s| s.value(&mut cx))
        .map(|i| i as u32);
    let nprobes = query_obj
        .get::<JsNumber, _, _>(&mut cx, "_nprobes")?
        .value(&mut cx) as usize;
    let metric_type = query_obj
        .get_opt::<JsString, _, _>(&mut cx, "_metricType")?
        .map(|s| s.value(&mut cx))
        .map(|s| MetricType::try_from(s.as_str()).unwrap());

    let rt = runtime(&mut cx)?;
    let channel = cx.channel();

    let (deferred, promise) = cx.promise();
    let table = js_table.table.clone();
    let query_vector = query_obj.get::<JsArray, _, _>(&mut cx, "_queryVector")?;
    let query = convert::js_array_to_vec(query_vector.deref(), &mut cx);

    rt.spawn(async move {

@@ -113,7 +134,10 @@ fn table_search(mut cx: FunctionContext) -> JsResult<JsPromise> {
            .unwrap()
            .search(Float32Array::from(query))
            .limit(limit as usize)
            .filter(filter);
            .refine_factor(refine_factor)
            .nprobes(nprobes)
            .filter(filter)
            .metric_type(metric_type);
        let record_batch_stream = builder.execute();
        let results = record_batch_stream
            .and_then(|stream| stream.try_collect::<Vec<_>>().map_err(Error::from))
@@ -164,7 +188,9 @@ fn table_create(mut cx: FunctionContext) -> JsResult<JsPromise> {
        let table_rst = database.create_table(table_name, batch_reader).await;

        deferred.settle_with(&channel, move |mut cx| {
            let table = Arc::new(Mutex::new(table_rst.or_else(|err| cx.throw_error(err.to_string()))?));
            let table = Arc::new(Mutex::new(
                table_rst.or_else(|err| cx.throw_error(err.to_string()))?,
            ));
            Ok(cx.boxed(JsTable { table }))
        });
    });

@@ -178,9 +204,7 @@ fn table_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
        ("overwrite", WriteMode::Overwrite),
    ]);

    let js_table = cx
        .this()
        .downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
    let js_table = cx.this().downcast_or_throw::<JsBox<JsTable>, _>(&mut cx)?;
    let buffer = cx.argument::<JsBuffer>(0)?;
    let write_mode = cx.argument::<JsString>(1)?.value(&mut cx);
    let batches = arrow_buffer_to_record_batch(buffer.as_slice(&mut cx));

@@ -204,7 +228,6 @@ fn table_add(mut cx: FunctionContext) -> JsResult<JsPromise> {
    Ok(promise)
}

#[neon::main]
fn main(mut cx: ModuleContext) -> NeonResult<()> {
    cx.export_function("databaseNew", database_new)?;
@@ -29,7 +29,7 @@ pub struct Query {
    pub filter: Option<String>,
    pub nprobes: usize,
    pub refine_factor: Option<u32>,
    pub metric_type: MetricType,
    pub metric_type: Option<MetricType>,
    pub use_index: bool,
}

@@ -51,9 +51,9 @@ impl Query {
            limit: 10,
            nprobes: 20,
            refine_factor: None,
            metric_type: MetricType::L2,
            metric_type: None,
            use_index: false,
            filter: None
            filter: None,
        }
    }

@@ -71,10 +71,10 @@ impl Query {
            self.limit,
        )?;
        scanner.nprobs(self.nprobes);
        scanner.distance_metric(self.metric_type);
        scanner.use_index(self.use_index);
        self.filter.as_ref().map(|f| scanner.filter(f));
        self.refine_factor.map(|rf| scanner.refine(rf));
        self.metric_type.map(|mt| scanner.distance_metric(mt));
        Ok(scanner.try_into_stream().await?)
    }

@@ -123,7 +123,7 @@ impl Query {
    /// # Arguments
    ///
    /// * `metric_type` - The distance metric to use. By default [MetricType::L2] is used.
    pub fn metric_type(mut self, metric_type: MetricType) -> Query {
    pub fn metric_type(mut self, metric_type: Option<MetricType>) -> Query {
        self.metric_type = metric_type;
        self
    }

@@ -174,14 +174,14 @@ mod tests {
            .limit(100)
            .nprobes(1000)
            .use_index(true)
            .metric_type(MetricType::Cosine)
            .metric_type(Some(MetricType::Cosine))
            .refine_factor(Some(999));

        assert_eq!(query.query_vector, new_vector);
        assert_eq!(query.limit, 100);
        assert_eq!(query.nprobes, 1000);
        assert_eq!(query.use_index, true);
        assert_eq!(query.metric_type, MetricType::Cosine);
        assert_eq!(query.metric_type, Some(MetricType::Cosine));
        assert_eq!(query.refine_factor, Some(999));
    }
@@ -80,7 +80,11 @@ impl Table {

        let dataset =
            Arc::new(Dataset::write(&mut batches, path, Some(WriteParams::default())).await?);
        Ok(Table { name, path: path.to_string(), dataset })
        Ok(Table {
            name,
            path: path.to_string(),
            dataset,
        })
    }

    /// Insert records into this Table

@@ -95,12 +99,13 @@ impl Table {
    pub async fn add(
        &mut self,
        mut batches: Box<dyn RecordBatchReader>,
        write_mode: Option<WriteMode>
        write_mode: Option<WriteMode>,
    ) -> Result<usize> {
        let mut params = WriteParams::default();
        params.mode = write_mode.unwrap_or(WriteMode::Append);

        self.dataset = Arc::new(Dataset::write(&mut batches, self.path.as_str(), Some(params)).await?);
        self.dataset =
            Arc::new(Dataset::write(&mut batches, self.path.as_str(), Some(params)).await?);
        Ok(batches.count())
    }
@@ -171,14 +176,17 @@ mod tests {

        let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
        let schema = batches.schema().clone();
        let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches).await.unwrap();
        let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches)
            .await
            .unwrap();
        assert_eq!(table.count_rows().await.unwrap(), 10);

        let new_batches: Box<dyn RecordBatchReader> = Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from_iter_values(100..110))],
        )
        .unwrap()]));
        let new_batches: Box<dyn RecordBatchReader> =
            Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
                schema,
                vec![Arc::new(Int32Array::from_iter_values(100..110))],
            )
            .unwrap()]));

        table.add(new_batches, None).await.unwrap();
        assert_eq!(table.count_rows().await.unwrap(), 20);
@@ -192,15 +200,22 @@ mod tests {

        let batches: Box<dyn RecordBatchReader> = Box::new(make_test_batches());
        let schema = batches.schema().clone();
        let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches).await.unwrap();
        let mut table = Table::create(Arc::new(path_buf), "test".to_string(), batches)
            .await
            .unwrap();
        assert_eq!(table.count_rows().await.unwrap(), 10);

        let new_batches: Box<dyn RecordBatchReader> = Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from_iter_values(100..110))],
        ).unwrap()]));
        let new_batches: Box<dyn RecordBatchReader> =
            Box::new(RecordBatchBuffer::new(vec![RecordBatch::try_new(
                schema,
                vec![Arc::new(Int32Array::from_iter_values(100..110))],
            )
            .unwrap()]));

        table.add(new_batches, Some(WriteMode::Overwrite)).await.unwrap();
        table
            .add(new_batches, Some(WriteMode::Overwrite))
            .await
            .unwrap();
        assert_eq!(table.count_rows().await.unwrap(), 10);
        assert_eq!(table.name, "test");
    }