add openai embedding function to nodejs client (#107)

- openai is an optional dependency for lancedb
- added an example to show how to use it
This commit is contained in:
gsilvestrin
2023-06-01 10:25:00 -07:00
committed by GitHub
parent 99cbda8b07
commit 3e14b357e7
12 changed files with 683 additions and 38 deletions

View File

@@ -24,7 +24,7 @@ import {
} from 'apache-arrow'
import { type EmbeddingFunction } from './index'
export function convertToTable<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Table {
export async function convertToTable<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Table> {
if (data.length === 0) {
throw new Error('At least one record needs to be provided')
}
@@ -51,7 +51,7 @@ export function convertToTable<T> (data: Array<Record<string, unknown>>, embeddi
}
if (columnsKey === embeddings?.sourceColumn) {
const vectors = embeddings.embed(values as T[])
const vectors = await embeddings.embed(values as T[])
const listBuilder = newVectorListBuilder()
vectors.map(v => listBuilder.append(v))
records.vector = listBuilder.finish().toVector()
@@ -79,7 +79,7 @@ function newVectorListBuilder (): ListBuilder<Float32, any> {
}
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
const table = convertToTable(data, embeddings)
const table = await convertToTable(data, embeddings)
const writer = RecordBatchFileWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array())
}

View File

@@ -0,0 +1,28 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* An embedding function that automatically creates vector representation for a given column.
*/
export interface EmbeddingFunction<T> {
/**
* The name of the column that will be used as input for the Embedding Function.
*/
sourceColumn: string
/**
* Creates a vector representation for the given values.
*/
embed: (data: T[]) => Promise<number[][]>
}

View File

@@ -0,0 +1,51 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { type EmbeddingFunction } from '../index'
export class OpenAIEmbeddingFunction implements EmbeddingFunction<string> {
private readonly _openai: any
private readonly _modelName: string
constructor (sourceColumn: string, openAIKey: string, modelName: string = 'text-embedding-ada-002') {
let openai
try {
// eslint-disable-next-line @typescript-eslint/no-var-requires
openai = require('openai')
} catch {
throw new Error('please install openai using npm install openai')
}
this.sourceColumn = sourceColumn
const configuration = new openai.Configuration({
apiKey: openAIKey
})
this._openai = new openai.OpenAIApi(configuration)
this._modelName = modelName
}
async embed (data: string[]): Promise<number[][]> {
const response = await this._openai.createEmbedding({
model: this._modelName,
input: data
})
const embeddings: number[][] = []
for (let i = 0; i < response.data.data.length; i++) {
embeddings.push(response.data.data[i].embedding as number[])
}
return embeddings
}
sourceColumn: string
}

View File

@@ -19,10 +19,14 @@ import {
Vector
} from 'apache-arrow'
import { fromRecordsToBuffer } from './arrow'
import type { EmbeddingFunction } from './embedding/embedding_function'
// eslint-disable-next-line @typescript-eslint/no-var-requires
const { databaseNew, databaseTableNames, databaseOpenTable, tableCreate, tableSearch, tableAdd, tableCreateVectorIndex } = require('../native.js')
export type { EmbeddingFunction }
export { OpenAIEmbeddingFunction } from './embedding/openai'
/**
* Connect to a LanceDB instance at the given URI
* @param uri The uri of the database.
@@ -135,14 +139,8 @@ export class Table<T = number[]> {
* Creates a search query to find the nearest neighbors of the given search term
* @param query The query search term
*/
search (query: T): Query {
let queryVector: number[]
if (this._embeddings !== undefined) {
queryVector = this._embeddings.embed([query])[0]
} else {
queryVector = query as number[]
}
return new Query(this._tbl, queryVector)
search (query: T): Query<T> {
return new Query(this._tbl, query, this._embeddings)
}
/**
@@ -228,32 +226,35 @@ export type VectorIndexParams = IvfPQIndexConfig
/**
* A builder for nearest neighbor queries for LanceDB.
*/
export class Query {
export class Query<T = number[]> {
private readonly _tbl: any
private readonly _queryVector: number[]
private readonly _query: T
private _queryVector?: number[]
private _limit: number
private _refineFactor?: number
private _nprobes: number
private readonly _columns?: string[]
private _filter?: string
private _metricType?: MetricType
private readonly _embeddings?: EmbeddingFunction<T>
constructor (tbl: any, queryVector: number[]) {
constructor (tbl: any, query: T, embeddings?: EmbeddingFunction<T>) {
this._tbl = tbl
this._queryVector = queryVector
this._query = query
this._limit = 10
this._nprobes = 20
this._refineFactor = undefined
this._columns = undefined
this._filter = undefined
this._metricType = undefined
this._embeddings = embeddings
}
/***
* Sets the number of results that will be returned
* @param value number of results
*/
limit (value: number): Query {
limit (value: number): Query<T> {
this._limit = value
return this
}
@@ -262,7 +263,7 @@ export class Query {
* Refine the results by reading extra elements and re-ranking them in memory.
* @param value refine factor to use in this query.
*/
refineFactor (value: number): Query {
refineFactor (value: number): Query<T> {
this._refineFactor = value
return this
}
@@ -271,7 +272,7 @@ export class Query {
* The number of probes used. A higher number makes search more accurate but also slower.
* @param value The number of probes used.
*/
nprobes (value: number): Query {
nprobes (value: number): Query<T> {
this._nprobes = value
return this
}
@@ -280,7 +281,7 @@ export class Query {
* A filter statement to be applied to this query.
* @param value A filter in the same format used by a sql WHERE clause.
*/
filter (value: string): Query {
filter (value: string): Query<T> {
this._filter = value
return this
}
@@ -289,7 +290,7 @@ export class Query {
* The MetricType used for this Query.
* @param value The metric to the. @see MetricType for the different options
*/
metricType (value: MetricType): Query {
metricType (value: MetricType): Query<T> {
this._metricType = value
return this
}
@@ -298,6 +299,12 @@ export class Query {
* Execute the query and return the results as an Array of Objects
*/
async execute<T = Record<string, unknown>> (): Promise<T[]> {
if (this._embeddings !== undefined) {
this._queryVector = (await this._embeddings.embed([this._query]))[0]
} else {
this._queryVector = this._query as number[]
}
const buffer = await tableSearch.call(this._tbl, this)
const data = tableFromIPC(buffer)
return data.toArray().map((entry: Record<string, unknown>) => {
@@ -319,21 +326,6 @@ export enum WriteMode {
Append = 'append'
}
/**
* An embedding function that automatically creates vector representation for a given column.
*/
export interface EmbeddingFunction<T> {
/**
* The name of the column that will be used as input for the Embedding Function.
*/
sourceColumn: string
/**
* Creates a vector representation for the given values.
*/
embed: (data: T[]) => number[][]
}
/**
* Distance metrics type.
*/

View File

@@ -0,0 +1,50 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { describe } from 'mocha'
import { assert } from 'chai'
import { OpenAIEmbeddingFunction } from '../../embedding/openai'
// eslint-disable-next-line @typescript-eslint/no-var-requires
const { OpenAIApi } = require('openai')
// eslint-disable-next-line @typescript-eslint/no-var-requires
const { stub } = require('sinon')
describe('OpenAPIEmbeddings', function () {
const stubValue = {
data: {
data: [
{
embedding: Array(1536).fill(1.0)
},
{
embedding: Array(1536).fill(2.0)
}
]
}
}
describe('#embed', function () {
it('should create vector embeddings', async function () {
const openAIStub = stub(OpenAIApi.prototype, 'createEmbedding').returns(stubValue)
const f = new OpenAIEmbeddingFunction('text', 'sk-key')
const vectors = await f.embed(['abc', 'def'])
assert.isTrue(openAIStub.calledOnce)
assert.equal(vectors.length, 2)
assert.deepEqual(vectors[0], stubValue.data.data[0].embedding)
assert.deepEqual(vectors[1], stubValue.data.data[1].embedding)
})
})
})

View File

@@ -154,7 +154,7 @@ describe('LanceDB client', function () {
['bar', [3.1, 3.2]]
])
embed (data: string[]): number[][] {
async embed (data: string[]): Promise<number[][]> {
return data.map(datum => this._embedding_map.get(datum) ?? [0.0, 0.0])
}
}