add openai embedding function to nodejs client (#107)

- openai is an optional dependency for lancedb
- added an example to show how to use it
This commit is contained in:
gsilvestrin
2023-06-01 10:25:00 -07:00
committed by GitHub
parent 99cbda8b07
commit 3e14b357e7
12 changed files with 683 additions and 38 deletions

View File

@@ -19,10 +19,14 @@ import {
Vector
} from 'apache-arrow'
import { fromRecordsToBuffer } from './arrow'
import type { EmbeddingFunction } from './embedding/embedding_function'
// eslint-disable-next-line @typescript-eslint/no-var-requires
const { databaseNew, databaseTableNames, databaseOpenTable, tableCreate, tableSearch, tableAdd, tableCreateVectorIndex } = require('../native.js')
export type { EmbeddingFunction }
export { OpenAIEmbeddingFunction } from './embedding/openai'
/**
* Connect to a LanceDB instance at the given URI
* @param uri The uri of the database.
@@ -135,14 +139,8 @@ export class Table<T = number[]> {
* Creates a search query to find the nearest neighbors of the given search term
* @param query The query search term
*/
search (query: T): Query {
let queryVector: number[]
if (this._embeddings !== undefined) {
queryVector = this._embeddings.embed([query])[0]
} else {
queryVector = query as number[]
}
return new Query(this._tbl, queryVector)
search (query: T): Query<T> {
return new Query(this._tbl, query, this._embeddings)
}
/**
@@ -228,32 +226,35 @@ export type VectorIndexParams = IvfPQIndexConfig
/**
* A builder for nearest neighbor queries for LanceDB.
*/
export class Query {
export class Query<T = number[]> {
private readonly _tbl: any
private readonly _queryVector: number[]
private readonly _query: T
private _queryVector?: number[]
private _limit: number
private _refineFactor?: number
private _nprobes: number
private readonly _columns?: string[]
private _filter?: string
private _metricType?: MetricType
private readonly _embeddings?: EmbeddingFunction<T>
constructor (tbl: any, queryVector: number[]) {
constructor (tbl: any, query: T, embeddings?: EmbeddingFunction<T>) {
this._tbl = tbl
this._queryVector = queryVector
this._query = query
this._limit = 10
this._nprobes = 20
this._refineFactor = undefined
this._columns = undefined
this._filter = undefined
this._metricType = undefined
this._embeddings = embeddings
}
/***
* Sets the number of results that will be returned
* @param value number of results
*/
limit (value: number): Query {
limit (value: number): Query<T> {
this._limit = value
return this
}
@@ -262,7 +263,7 @@ export class Query {
* Refine the results by reading extra elements and re-ranking them in memory.
* @param value refine factor to use in this query.
*/
refineFactor (value: number): Query {
refineFactor (value: number): Query<T> {
this._refineFactor = value
return this
}
@@ -271,7 +272,7 @@ export class Query {
* The number of probes used. A higher number makes search more accurate but also slower.
* @param value The number of probes used.
*/
nprobes (value: number): Query {
nprobes (value: number): Query<T> {
this._nprobes = value
return this
}
@@ -280,7 +281,7 @@ export class Query {
* A filter statement to be applied to this query.
* @param value A filter in the same format used by a sql WHERE clause.
*/
filter (value: string): Query {
filter (value: string): Query<T> {
this._filter = value
return this
}
@@ -289,7 +290,7 @@ export class Query {
* The MetricType used for this Query.
* @param value The metric type to use. @see MetricType for the different options
*/
metricType (value: MetricType): Query {
metricType (value: MetricType): Query<T> {
this._metricType = value
return this
}
@@ -298,6 +299,12 @@ export class Query {
* Execute the query and return the results as an Array of Objects
*/
async execute<T = Record<string, unknown>> (): Promise<T[]> {
if (this._embeddings !== undefined) {
this._queryVector = (await this._embeddings.embed([this._query]))[0]
} else {
this._queryVector = this._query as number[]
}
const buffer = await tableSearch.call(this._tbl, this)
const data = tableFromIPC(buffer)
return data.toArray().map((entry: Record<string, unknown>) => {
@@ -319,21 +326,6 @@ export enum WriteMode {
Append = 'append'
}
/**
* An embedding function that automatically creates vector representation for a given column.
*
* The type parameter T is the type of the raw values passed to `embed`.
*/
export interface EmbeddingFunction<T> {
/**
* The name of the column that will be used as input for the Embedding Function.
*/
sourceColumn: string
/**
* Creates a vector representation for the given values.
*
* @param data The batch of values to embed.
* @returns An array of vectors (`number[][]`). Callers in this file pass a
* single-element batch and read index 0 of the result as the query vector.
* NOTE(review): presumably the i-th vector corresponds to the i-th input
* value — confirm against the concrete implementations.
*/
embed: (data: T[]) => number[][]
}
/**
* Distance metrics type.
*/