add embedding functions to the nodejs client (#95)

This commit is contained in:
gsilvestrin
2023-05-26 19:09:20 -06:00
committed by GitHub
parent 04d97347d7
commit d3aa8bfbc5
3 changed files with 141 additions and 29 deletions

View File

@@ -15,15 +15,16 @@
import {
Field,
Float32,
List,
List, type ListBuilder,
makeBuilder,
RecordBatchFileWriter,
Table, Utf8,
type Vector,
vectorFromArray
} from 'apache-arrow'
import { type EmbeddingFunction } from './index'
export function convertToTable (data: Array<Record<string, unknown>>): Table {
export function convertToTable<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Table {
if (data.length === 0) {
throw new Error('At least one record needs to be provided')
}
@@ -33,11 +34,7 @@ export function convertToTable (data: Array<Record<string, unknown>>): Table {
for (const columnsKey of columns) {
if (columnsKey === 'vector') {
const children = new Field<Float32>('item', new Float32())
const list = new List(children)
const listBuilder = makeBuilder({
type: list
})
const listBuilder = newVectorListBuilder()
const vectorSize = (data[0].vector as any[]).length
for (const datum of data) {
if ((datum[columnsKey] as any[]).length !== vectorSize) {
@@ -52,6 +49,14 @@ export function convertToTable (data: Array<Record<string, unknown>>): Table {
for (const datum of data) {
values.push(datum[columnsKey])
}
if (columnsKey === embeddings?.sourceColumn) {
const vectors = embeddings.embed(values as T[])
const listBuilder = newVectorListBuilder()
vectors.map(v => listBuilder.append(v))
records.vector = listBuilder.finish().toVector()
}
if (typeof values[0] === 'string') {
// `vectorFromArray` converts strings into dictionary vectors, forcing it back to a string column
records[columnsKey] = vectorFromArray(values, new Utf8())
@@ -64,8 +69,17 @@ export function convertToTable (data: Array<Record<string, unknown>>): Table {
return new Table(records)
}
export async function fromRecordsToBuffer (data: Array<Record<string, unknown>>): Promise<Buffer> {
const table = convertToTable(data)
// Creates a new Arrow ListBuilder that stores a Vector column
function newVectorListBuilder (): ListBuilder<Float32, any> {
const children = new Field<Float32>('item', new Float32())
const list = new List(children)
return makeBuilder({
type: list
})
}
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
const table = convertToTable(data, embeddings)
const writer = RecordBatchFileWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array())
}

View File

@@ -55,17 +55,50 @@ export class Connection {
}
/**
* Open a table in the database.
* @param name The name of the table.
*/
async openTable (name: string): Promise<Table> {
* Open a table in the database.
*
* @param name The name of the table.
*/
async openTable (name: string): Promise<Table>
/**
* Open a table in the database.
*
* @param name The name of the table.
* @param embeddings An embedding function to use on this Table
*/
async openTable<T> (name: string, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
const tbl = await databaseOpenTable.call(this._db, name)
return new Table(tbl, name)
if (embeddings !== undefined) {
return new Table(tbl, name, embeddings)
} else {
return new Table(tbl, name)
}
}
async createTable (name: string, data: Array<Record<string, unknown>>): Promise<Table> {
await tableCreate.call(this._db, name, await fromRecordsToBuffer(data))
return await this.openTable(name)
/**
* Creates a new Table and initialize it with new data.
*
* @param name The name of the table.
* @param data Non-empty Array of Records to be inserted into the Table
*/
async createTable (name: string, data: Array<Record<string, unknown>>): Promise<Table>
/**
* Creates a new Table and initialize it with new data.
*
* @param name The name of the table.
* @param data Non-empty Array of Records to be inserted into the Table
* @param embeddings An embedding function to use on this Table
*/
async createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
async createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
const tbl = await tableCreate.call(this._db, name, await fromRecordsToBuffer(data, embeddings))
if (embeddings !== undefined) {
return new Table(tbl, name, embeddings)
} else {
return new Table(tbl, name)
}
}
async createTableArrow (name: string, table: ArrowTable): Promise<Table> {
@@ -75,16 +108,22 @@ export class Connection {
}
}
/**
* A table in a LanceDB database.
*/
export class Table {
export class Table<T = number[]> {
private readonly _tbl: any
private readonly _name: string
private readonly _embeddings?: EmbeddingFunction<T>
constructor (tbl: any, name: string) {
constructor (tbl: any, name: string)
/**
* @param tbl
* @param name
* @param embeddings An embedding function to use when interacting with this table
*/
constructor (tbl: any, name: string, embeddings: EmbeddingFunction<T>)
constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>) {
this._tbl = tbl
this._name = name
this._embeddings = embeddings
}
get name (): string {
@@ -92,10 +131,16 @@ export class Table {
}
/**
* Create a search query to find the nearest neighbors of the given query vector.
* @param queryVector The query vector.
*/
search (queryVector: number[]): Query {
* Creates a search query to find the nearest neighbors of the given search term
* @param query The query search term
*/
search (query: T): Query {
let queryVector: number[]
if (this._embeddings !== undefined) {
queryVector = this._embeddings.embed([query])[0]
} else {
queryVector = query as number[]
}
return new Query(this._tbl, queryVector)
}
@@ -106,7 +151,7 @@ export class Table {
* @return The number of rows added to the table
*/
async add (data: Array<Record<string, unknown>>): Promise<number> {
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Append.toString())
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString())
}
/**
@@ -116,9 +161,14 @@ export class Table {
* @return The number of rows added to the table
*/
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Overwrite.toString())
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString())
}
/**
* Create an ANN index on this Table vector index.
*
* @param indexParams The parameters of this Index, @see VectorIndexParams.
*/
async create_index (indexParams: VectorIndexParams): Promise<any> {
return tableCreateVectorIndex.call(this._tbl, indexParams)
}
@@ -268,6 +318,21 @@ export enum WriteMode {
Append = 'append'
}
/**
* An embedding function that automatically creates vector representation for a given column.
*/
export interface EmbeddingFunction<T> {
/**
* The name of the column that will be used as input for the Embedding Function.
*/
sourceColumn: string
/**
* Creates a vector representation for the given values.
*/
embed: (data: T[]) => number[][]
}
/**
* Distance metrics type.
*/

View File

@@ -17,7 +17,7 @@ import { assert } from 'chai'
import { track } from 'temp'
import * as lancedb from '../index'
import { MetricType, Query } from '../index'
import { type EmbeddingFunction, MetricType, Query } from '../index'
describe('LanceDB client', function () {
describe('when creating a connection to lancedb', function () {
@@ -140,6 +140,39 @@ describe('LanceDB client', function () {
await table.create_index({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2 })
}).timeout(10_000) // Timeout is high partially because GH macos runner is pretty slow
})
describe('when using a custom embedding function', function () {
class TextEmbedding implements EmbeddingFunction<string> {
sourceColumn: string
constructor (targetColumn: string) {
this.sourceColumn = targetColumn
}
_embedding_map = new Map<string, number[]>([
['foo', [2.1, 2.2]],
['bar', [3.1, 3.2]]
])
embed (data: string[]): number[][] {
return data.map(datum => this._embedding_map.get(datum) ?? [0.0, 0.0])
}
}
it('should encode the original data into embeddings', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
const embeddings = new TextEmbedding('name')
const data = [
{ price: 10, name: 'foo' },
{ price: 50, name: 'bar' }
]
const table = await con.createTable('vectors', data, embeddings)
const results = await table.search('foo').execute()
assert.equal(results.length, 2)
})
})
})
describe('Query object', function () {