mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 23:12:58 +00:00
add embedding functions to the nodejs client (#95)
This commit is contained in:
@@ -15,15 +15,16 @@
|
||||
import {
|
||||
Field,
|
||||
Float32,
|
||||
List,
|
||||
List, type ListBuilder,
|
||||
makeBuilder,
|
||||
RecordBatchFileWriter,
|
||||
Table, Utf8,
|
||||
type Vector,
|
||||
vectorFromArray
|
||||
} from 'apache-arrow'
|
||||
import { type EmbeddingFunction } from './index'
|
||||
|
||||
export function convertToTable (data: Array<Record<string, unknown>>): Table {
|
||||
export function convertToTable<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Table {
|
||||
if (data.length === 0) {
|
||||
throw new Error('At least one record needs to be provided')
|
||||
}
|
||||
@@ -33,11 +34,7 @@ export function convertToTable (data: Array<Record<string, unknown>>): Table {
|
||||
|
||||
for (const columnsKey of columns) {
|
||||
if (columnsKey === 'vector') {
|
||||
const children = new Field<Float32>('item', new Float32())
|
||||
const list = new List(children)
|
||||
const listBuilder = makeBuilder({
|
||||
type: list
|
||||
})
|
||||
const listBuilder = newVectorListBuilder()
|
||||
const vectorSize = (data[0].vector as any[]).length
|
||||
for (const datum of data) {
|
||||
if ((datum[columnsKey] as any[]).length !== vectorSize) {
|
||||
@@ -52,6 +49,14 @@ export function convertToTable (data: Array<Record<string, unknown>>): Table {
|
||||
for (const datum of data) {
|
||||
values.push(datum[columnsKey])
|
||||
}
|
||||
|
||||
if (columnsKey === embeddings?.sourceColumn) {
|
||||
const vectors = embeddings.embed(values as T[])
|
||||
const listBuilder = newVectorListBuilder()
|
||||
vectors.map(v => listBuilder.append(v))
|
||||
records.vector = listBuilder.finish().toVector()
|
||||
}
|
||||
|
||||
if (typeof values[0] === 'string') {
|
||||
// `vectorFromArray` converts strings into dictionary vectors, forcing it back to a string column
|
||||
records[columnsKey] = vectorFromArray(values, new Utf8())
|
||||
@@ -64,8 +69,17 @@ export function convertToTable (data: Array<Record<string, unknown>>): Table {
|
||||
return new Table(records)
|
||||
}
|
||||
|
||||
export async function fromRecordsToBuffer (data: Array<Record<string, unknown>>): Promise<Buffer> {
|
||||
const table = convertToTable(data)
|
||||
// Creates a new Arrow ListBuilder that stores a Vector column
|
||||
function newVectorListBuilder (): ListBuilder<Float32, any> {
|
||||
const children = new Field<Float32>('item', new Float32())
|
||||
const list = new List(children)
|
||||
return makeBuilder({
|
||||
type: list
|
||||
})
|
||||
}
|
||||
|
||||
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
|
||||
const table = convertToTable(data, embeddings)
|
||||
const writer = RecordBatchFileWriter.writeAll(table)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
@@ -55,17 +55,50 @@ export class Connection {
|
||||
}
|
||||
|
||||
/**
|
||||
* Open a table in the database.
|
||||
* @param name The name of the table.
|
||||
*/
|
||||
async openTable (name: string): Promise<Table> {
|
||||
* Open a table in the database.
|
||||
*
|
||||
* @param name The name of the table.
|
||||
*/
|
||||
async openTable (name: string): Promise<Table>
|
||||
/**
|
||||
* Open a table in the database.
|
||||
*
|
||||
* @param name The name of the table.
|
||||
* @param embeddings An embedding function to use on this Table
|
||||
*/
|
||||
async openTable<T> (name: string, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
|
||||
const tbl = await databaseOpenTable.call(this._db, name)
|
||||
return new Table(tbl, name)
|
||||
if (embeddings !== undefined) {
|
||||
return new Table(tbl, name, embeddings)
|
||||
} else {
|
||||
return new Table(tbl, name)
|
||||
}
|
||||
}
|
||||
|
||||
async createTable (name: string, data: Array<Record<string, unknown>>): Promise<Table> {
|
||||
await tableCreate.call(this._db, name, await fromRecordsToBuffer(data))
|
||||
return await this.openTable(name)
|
||||
/**
|
||||
* Creates a new Table and initialize it with new data.
|
||||
*
|
||||
* @param name The name of the table.
|
||||
* @param data Non-empty Array of Records to be inserted into the Table
|
||||
*/
|
||||
|
||||
async createTable (name: string, data: Array<Record<string, unknown>>): Promise<Table>
|
||||
/**
|
||||
* Creates a new Table and initialize it with new data.
|
||||
*
|
||||
* @param name The name of the table.
|
||||
* @param data Non-empty Array of Records to be inserted into the Table
|
||||
* @param embeddings An embedding function to use on this Table
|
||||
*/
|
||||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
|
||||
const tbl = await tableCreate.call(this._db, name, await fromRecordsToBuffer(data, embeddings))
|
||||
if (embeddings !== undefined) {
|
||||
return new Table(tbl, name, embeddings)
|
||||
} else {
|
||||
return new Table(tbl, name)
|
||||
}
|
||||
}
|
||||
|
||||
async createTableArrow (name: string, table: ArrowTable): Promise<Table> {
|
||||
@@ -75,16 +108,22 @@ export class Connection {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A table in a LanceDB database.
|
||||
*/
|
||||
export class Table {
|
||||
export class Table<T = number[]> {
|
||||
private readonly _tbl: any
|
||||
private readonly _name: string
|
||||
private readonly _embeddings?: EmbeddingFunction<T>
|
||||
|
||||
constructor (tbl: any, name: string) {
|
||||
constructor (tbl: any, name: string)
|
||||
/**
|
||||
* @param tbl
|
||||
* @param name
|
||||
* @param embeddings An embedding function to use when interacting with this table
|
||||
*/
|
||||
constructor (tbl: any, name: string, embeddings: EmbeddingFunction<T>)
|
||||
constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>) {
|
||||
this._tbl = tbl
|
||||
this._name = name
|
||||
this._embeddings = embeddings
|
||||
}
|
||||
|
||||
get name (): string {
|
||||
@@ -92,10 +131,16 @@ export class Table {
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a search query to find the nearest neighbors of the given query vector.
|
||||
* @param queryVector The query vector.
|
||||
*/
|
||||
search (queryVector: number[]): Query {
|
||||
* Creates a search query to find the nearest neighbors of the given search term
|
||||
* @param query The query search term
|
||||
*/
|
||||
search (query: T): Query {
|
||||
let queryVector: number[]
|
||||
if (this._embeddings !== undefined) {
|
||||
queryVector = this._embeddings.embed([query])[0]
|
||||
} else {
|
||||
queryVector = query as number[]
|
||||
}
|
||||
return new Query(this._tbl, queryVector)
|
||||
}
|
||||
|
||||
@@ -106,7 +151,7 @@ export class Table {
|
||||
* @return The number of rows added to the table
|
||||
*/
|
||||
async add (data: Array<Record<string, unknown>>): Promise<number> {
|
||||
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Append.toString())
|
||||
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString())
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -116,9 +161,14 @@ export class Table {
|
||||
* @return The number of rows added to the table
|
||||
*/
|
||||
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
|
||||
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Overwrite.toString())
|
||||
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString())
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an ANN index on this Table vector index.
|
||||
*
|
||||
* @param indexParams The parameters of this Index, @see VectorIndexParams.
|
||||
*/
|
||||
async create_index (indexParams: VectorIndexParams): Promise<any> {
|
||||
return tableCreateVectorIndex.call(this._tbl, indexParams)
|
||||
}
|
||||
@@ -268,6 +318,21 @@ export enum WriteMode {
|
||||
Append = 'append'
|
||||
}
|
||||
|
||||
/**
|
||||
* An embedding function that automatically creates vector representation for a given column.
|
||||
*/
|
||||
export interface EmbeddingFunction<T> {
|
||||
/**
|
||||
* The name of the column that will be used as input for the Embedding Function.
|
||||
*/
|
||||
sourceColumn: string
|
||||
|
||||
/**
|
||||
* Creates a vector representation for the given values.
|
||||
*/
|
||||
embed: (data: T[]) => number[][]
|
||||
}
|
||||
|
||||
/**
|
||||
* Distance metrics type.
|
||||
*/
|
||||
|
||||
@@ -17,7 +17,7 @@ import { assert } from 'chai'
|
||||
import { track } from 'temp'
|
||||
|
||||
import * as lancedb from '../index'
|
||||
import { MetricType, Query } from '../index'
|
||||
import { type EmbeddingFunction, MetricType, Query } from '../index'
|
||||
|
||||
describe('LanceDB client', function () {
|
||||
describe('when creating a connection to lancedb', function () {
|
||||
@@ -140,6 +140,39 @@ describe('LanceDB client', function () {
|
||||
await table.create_index({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2 })
|
||||
}).timeout(10_000) // Timeout is high partially because GH macos runner is pretty slow
|
||||
})
|
||||
|
||||
describe('when using a custom embedding function', function () {
|
||||
class TextEmbedding implements EmbeddingFunction<string> {
|
||||
sourceColumn: string
|
||||
|
||||
constructor (targetColumn: string) {
|
||||
this.sourceColumn = targetColumn
|
||||
}
|
||||
|
||||
_embedding_map = new Map<string, number[]>([
|
||||
['foo', [2.1, 2.2]],
|
||||
['bar', [3.1, 3.2]]
|
||||
])
|
||||
|
||||
embed (data: string[]): number[][] {
|
||||
return data.map(datum => this._embedding_map.get(datum) ?? [0.0, 0.0])
|
||||
}
|
||||
}
|
||||
|
||||
it('should encode the original data into embeddings', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
const embeddings = new TextEmbedding('name')
|
||||
|
||||
const data = [
|
||||
{ price: 10, name: 'foo' },
|
||||
{ price: 50, name: 'bar' }
|
||||
]
|
||||
const table = await con.createTable('vectors', data, embeddings)
|
||||
const results = await table.search('foo').execute()
|
||||
assert.equal(results.length, 2)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Query object', function () {
|
||||
|
||||
Reference in New Issue
Block a user