[Node] initial support of nodejs remote sdk (#333)

This commit is contained in:
Lei Xu
2023-07-18 16:15:27 -07:00
committed by GitHub
parent fb97b03a51
commit 980f910f50
7 changed files with 438 additions and 151 deletions

View File

@@ -14,15 +14,15 @@
import {
RecordBatchFileWriter,
type Table as ArrowTable,
tableFromIPC,
Vector
type Table as ArrowTable
} from 'apache-arrow'
import { fromRecordsToBuffer } from './arrow'
import type { EmbeddingFunction } from './embedding/embedding_function'
import { RemoteConnection } from './remote'
import { Query } from './query'
// eslint-disable-next-line @typescript-eslint/no-var-requires
const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableSearch, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete } = require('../native.js')
const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete } = require('../native.js')
export type { EmbeddingFunction }
export { OpenAIEmbeddingFunction } from './embedding/openai'
@@ -37,7 +37,13 @@ export interface AwsCredentials {
export interface ConnectionOptions {
uri: string
awsCredentials?: AwsCredentials
// API key for the remote connections
apiKey?: string
// Region to connect
region?: string
}
/**
@@ -54,9 +60,16 @@ export async function connect (arg: string | Partial<ConnectionOptions>): Promis
// opts = { uri: arg.uri, awsCredentials = arg.awsCredentials }
opts = Object.assign({
uri: '',
awsCredentials: undefined
awsCredentials: undefined,
apiKey: undefined,
region: 'us-west-2'
}, arg)
}
if (opts.uri.startsWith('db://')) {
// Remote connection
return new RemoteConnection(opts)
}
const db = await databaseNew(opts.uri)
return new LocalConnection(db, opts)
}
@@ -191,8 +204,8 @@ export class LocalConnection implements Connection {
}
/**
* Get the names of all tables in the database.
*/
* Get the names of all tables in the database.
*/
async tableNames (): Promise<string[]> {
return databaseTableNames.call(this._db)
}
@@ -203,6 +216,7 @@ export class LocalConnection implements Connection {
* @param name The name of the table.
*/
async openTable (name: string): Promise<Table>
/**
* Open a table in the database.
*
@@ -308,7 +322,7 @@ export class LocalTable<T = number[]> implements Table<T> {
* @param query The query search term
*/
search (query: T): Query<T> {
return new Query(this._tbl, query, this._embeddings)
return new Query(query, this._tbl, this._embeddings)
}
/**
@@ -430,116 +444,6 @@ export interface IvfPQIndexConfig {
export type VectorIndexParams = IvfPQIndexConfig
/**
* A builder for nearest neighbor queries for LanceDB.
*/
export class Query<T = number[]> {
private readonly _tbl: any
private readonly _query: T
private _queryVector?: number[]
private _limit: number
private _refineFactor?: number
private _nprobes: number
private _select?: string[]
private _filter?: string
private _metricType?: MetricType
private readonly _embeddings?: EmbeddingFunction<T>
constructor (tbl: any, query: T, embeddings?: EmbeddingFunction<T>) {
this._tbl = tbl
this._query = query
this._limit = 10
this._nprobes = 20
this._refineFactor = undefined
this._select = undefined
this._filter = undefined
this._metricType = undefined
this._embeddings = embeddings
}
/***
* Sets the number of results that will be returned
* @param value number of results
*/
limit (value: number): Query<T> {
this._limit = value
return this
}
/**
* Refine the results by reading extra elements and re-ranking them in memory.
* @param value refine factor to use in this query.
*/
refineFactor (value: number): Query<T> {
this._refineFactor = value
return this
}
/**
* The number of probes used. A higher number makes search more accurate but also slower.
* @param value The number of probes used.
*/
nprobes (value: number): Query<T> {
this._nprobes = value
return this
}
/**
* A filter statement to be applied to this query.
* @param value A filter in the same format used by a sql WHERE clause.
*/
filter (value: string): Query<T> {
this._filter = value
return this
}
where = this.filter
/** Return only the specified columns.
*
* @param value Only select the specified columns. If not specified, all columns will be returned.
*/
select (value: string[]): Query<T> {
this._select = value
return this
}
/**
* The MetricType used for this Query.
* @param value The metric to the. @see MetricType for the different options
*/
metricType (value: MetricType): Query<T> {
this._metricType = value
return this
}
/**
* Execute the query and return the results as an Array of Objects
*/
async execute<T = Record<string, unknown>> (): Promise<T[]> {
if (this._embeddings !== undefined) {
this._queryVector = (await this._embeddings.embed([this._query]))[0]
} else {
this._queryVector = this._query as number[]
}
const buffer = await tableSearch.call(this._tbl, this)
const data = tableFromIPC(buffer)
return data.toArray().map((entry: Record<string, unknown>) => {
const newObject: Record<string, unknown> = {}
Object.keys(entry).forEach((key: string) => {
if (entry[key] instanceof Vector) {
newObject[key] = (entry[key] as Vector).toArray()
} else {
newObject[key] = entry[key]
}
})
return newObject as unknown as T
})
}
}
/**
* Write mode for writing a table.
*/