add query params to to nodejs client (#87)

This commit is contained in:
gsilvestrin
2023-05-24 15:48:31 -06:00
committed by GitHub
parent bdef634954
commit 06cb7b6458
5 changed files with 155 additions and 51 deletions

View File

@@ -100,16 +100,21 @@ export class Table {
}
/**
* Insert records into this Table
* @param data Records to be inserted into the Table
* Insert records into this Table.
*
* @param mode Append / Overwrite existing records. Default: Append
* @param data Records to be inserted into the Table
* @return The number of rows added to the table
*/
async add (data: Array<Record<string, unknown>>): Promise<number> {
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Append.toString())
}
/**
* Insert records into this Table, replacing its contents.
*
* @param data Records to be inserted into the Table
* @return The number of rows added to the table
*/
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Overwrite.toString())
}
@@ -120,44 +125,75 @@ export class Table {
*/
export class Query {
private readonly _tbl: any
private readonly _query_vector: number[]
private readonly _queryVector: number[]
private _limit: number
private readonly _refine_factor?: number
private readonly _nprobes: number
private _refineFactor?: number
private _nprobes: number
private readonly _columns?: string[]
private _filter?: string
private readonly _metric = 'L2'
private _metricType?: MetricType
constructor (tbl: any, queryVector: number[]) {
this._tbl = tbl
this._query_vector = queryVector
this._queryVector = queryVector
this._limit = 10
this._nprobes = 20
this._refine_factor = undefined
this._refineFactor = undefined
this._columns = undefined
this._filter = undefined
this._metricType = undefined
}
/***
* Sets the number of results that will be returned
* @param value number of results
*/
limit (value: number): Query {
this._limit = value
return this
}
/**
* Refine the results by reading extra elements and re-ranking them in memory.
* @param value refine factor to use in this query.
*/
refineFactor (value: number): Query {
this._refineFactor = value
return this
}
/**
* The number of probes used. A higher number makes search more accurate but also slower.
* @param value The number of probes used.
*/
nprobes (value: number): Query {
this._nprobes = value
return this
}
/**
* A filter statement to be applied to this query.
* @param value A filter in the same format used by a sql WHERE clause.
*/
filter (value: string): Query {
this._filter = value
return this
}
/**
* Execute the query and return the results as an Array of Objects
*/
* The MetricType used for this Query.
* @param value The metric to the. @see MetricType for the different options
*/
metricType (value: MetricType): Query {
this._metricType = value
return this
}
/**
* Execute the query and return the results as an Array of Objects
*/
async execute<T = Record<string, unknown>> (): Promise<T[]> {
let buffer
if (this._filter != null) {
buffer = await tableSearch.call(this._tbl, this._query_vector, this._limit, this._filter)
} else {
buffer = await tableSearch.call(this._tbl, this._query_vector, this._limit)
}
const buffer = await tableSearch.call(this._tbl, this)
const data = tableFromIPC(buffer)
return data.toArray().map((entry: Record<string, unknown>) => {
const newObject: Record<string, unknown> = {}
@@ -177,3 +213,18 @@ export enum WriteMode {
Overwrite = 'overwrite',
Append = 'append'
}
/**
* Distance metrics type.
*/
export enum MetricType {
/**
* Euclidean distance
*/
L2 = 'l2',
/**
* Cosine distance
*/
Cosine = 'cosine'
}

View File

@@ -17,6 +17,7 @@ import { assert } from 'chai'
import { track } from 'temp'
import * as lancedb from '../index'
import { MetricType, Query } from '../index'
describe('LanceDB client', function () {
describe('when creating a connection to lancedb', function () {
@@ -132,6 +133,20 @@ describe('LanceDB client', function () {
})
})
describe('Query object', function () {
it('sets custom parameters', async function () {
const query = new Query(undefined, [0.1, 0.3])
.limit(1)
.metricType(MetricType.Cosine)
.refineFactor(100)
.nprobes(20) as Record<string, any>
assert.equal(query._limit, 1)
assert.equal(query._metricType, MetricType.Cosine)
assert.equal(query._refineFactor, 100)
assert.equal(query._nprobes, 20)
})
})
async function createTestDB (): Promise<string> {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)