diff --git a/node/src/arrow.ts b/node/src/arrow.ts index 36f278e6..ebd8d652 100644 --- a/node/src/arrow.ts +++ b/node/src/arrow.ts @@ -15,15 +15,16 @@ import { Field, Float32, - List, + List, type ListBuilder, makeBuilder, RecordBatchFileWriter, Table, Utf8, type Vector, vectorFromArray } from 'apache-arrow' +import { type EmbeddingFunction } from './index' -export function convertToTable (data: Array>): Table { +export function convertToTable (data: Array>, embeddings?: EmbeddingFunction): Table { if (data.length === 0) { throw new Error('At least one record needs to be provided') } @@ -33,11 +34,7 @@ export function convertToTable (data: Array>): Table { for (const columnsKey of columns) { if (columnsKey === 'vector') { - const children = new Field('item', new Float32()) - const list = new List(children) - const listBuilder = makeBuilder({ - type: list - }) + const listBuilder = newVectorListBuilder() const vectorSize = (data[0].vector as any[]).length for (const datum of data) { if ((datum[columnsKey] as any[]).length !== vectorSize) { @@ -52,6 +49,14 @@ export function convertToTable (data: Array>): Table { for (const datum of data) { values.push(datum[columnsKey]) } + + if (columnsKey === embeddings?.sourceColumn) { + const vectors = embeddings.embed(values as T[]) + const listBuilder = newVectorListBuilder() + vectors.map(v => listBuilder.append(v)) + records.vector = listBuilder.finish().toVector() + } + if (typeof values[0] === 'string') { // `vectorFromArray` converts strings into dictionary vectors, forcing it back to a string column records[columnsKey] = vectorFromArray(values, new Utf8()) @@ -64,8 +69,17 @@ export function convertToTable (data: Array>): Table { return new Table(records) } -export async function fromRecordsToBuffer (data: Array>): Promise { - const table = convertToTable(data) +// Creates a new Arrow ListBuilder that stores a Vector column +function newVectorListBuilder (): ListBuilder { + const children = new Field('item', new 
Float32()) + const list = new List(children) + return makeBuilder({ + type: list + }) +} + +export async function fromRecordsToBuffer (data: Array>, embeddings?: EmbeddingFunction): Promise { + const table = convertToTable(data, embeddings) const writer = RecordBatchFileWriter.writeAll(table) return Buffer.from(await writer.toUint8Array()) } diff --git a/node/src/index.ts b/node/src/index.ts index 486690e2..3462ee88 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -55,17 +55,50 @@ export class Connection { } /** - * Open a table in the database. - * @param name The name of the table. - */ - async openTable (name: string): Promise { + * Open a table in the database. + * + * @param name The name of the table. + */ + async openTable (name: string): Promise
+ /** + * Open a table in the database. + * + * @param name The name of the table. + * @param embeddings An embedding function to use on this Table + */ + async openTable (name: string, embeddings: EmbeddingFunction): Promise> + async openTable (name: string, embeddings?: EmbeddingFunction): Promise> { const tbl = await databaseOpenTable.call(this._db, name) - return new Table(tbl, name) + if (embeddings !== undefined) { + return new Table(tbl, name, embeddings) + } else { + return new Table(tbl, name) + } } - async createTable (name: string, data: Array>): Promise
{ - await tableCreate.call(this._db, name, await fromRecordsToBuffer(data)) - return await this.openTable(name) + /** + * Creates a new Table and initializes it with new data. + * + * @param name The name of the table. + * @param data Non-empty Array of Records to be inserted into the Table + */ + + async createTable (name: string, data: Array>): Promise
+ /** + * Creates a new Table and initializes it with new data. + * + * @param name The name of the table. + * @param data Non-empty Array of Records to be inserted into the Table + * @param embeddings An embedding function to use on this Table + */ + async createTable (name: string, data: Array>, embeddings: EmbeddingFunction): Promise> + async createTable (name: string, data: Array>, embeddings?: EmbeddingFunction): Promise> { + const tbl = await tableCreate.call(this._db, name, await fromRecordsToBuffer(data, embeddings)) + if (embeddings !== undefined) { + return new Table(tbl, name, embeddings) + } else { + return new Table(tbl, name) + } } async createTableArrow (name: string, table: ArrowTable): Promise
{ @@ -75,16 +108,22 @@ export class Connection { } } -/** - * A table in a LanceDB database. - */ -export class Table { +export class Table { private readonly _tbl: any private readonly _name: string + private readonly _embeddings?: EmbeddingFunction - constructor (tbl: any, name: string) { + constructor (tbl: any, name: string) + /** + * @param tbl + * @param name + * @param embeddings An embedding function to use when interacting with this table + */ + constructor (tbl: any, name: string, embeddings: EmbeddingFunction) + constructor (tbl: any, name: string, embeddings?: EmbeddingFunction) { this._tbl = tbl this._name = name + this._embeddings = embeddings } get name (): string { @@ -92,10 +131,16 @@ export class Table { } /** - * Create a search query to find the nearest neighbors of the given query vector. - * @param queryVector The query vector. - */ - search (queryVector: number[]): Query { + * Creates a search query to find the nearest neighbors of the given search term + * @param query The query search term + */ + search (query: T): Query { + let queryVector: number[] + if (this._embeddings !== undefined) { + queryVector = this._embeddings.embed([query])[0] + } else { + queryVector = query as number[] + } return new Query(this._tbl, queryVector) } @@ -106,7 +151,7 @@ export class Table { * @return The number of rows added to the table */ async add (data: Array>): Promise { - return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Append.toString()) + return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString()) } /** @@ -116,9 +161,14 @@ export class Table { * @return The number of rows added to the table */ async overwrite (data: Array>): Promise { - return tableAdd.call(this._tbl, await fromRecordsToBuffer(data), WriteMode.Overwrite.toString()) + return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString()) } + /** + * Create an ANN 
index on this Table's vector column. + * + * @param indexParams The parameters of this Index, @see VectorIndexParams. + */ async create_index (indexParams: VectorIndexParams): Promise { return tableCreateVectorIndex.call(this._tbl, indexParams) } @@ -268,6 +318,21 @@ export enum WriteMode { Append = 'append' } +/** + * An embedding function that automatically creates a vector representation for a given column. + */ +export interface EmbeddingFunction { + /** + * The name of the column that will be used as input for the Embedding Function. + */ + sourceColumn: string + + /** + * Creates a vector representation for the given values. + */ + embed: (data: T[]) => number[][] +} + /** * Distance metrics type. */ diff --git a/node/src/test/test.ts b/node/src/test/test.ts index 2a927167..f785b9f9 100644 --- a/node/src/test/test.ts +++ b/node/src/test/test.ts @@ -17,7 +17,7 @@ import { assert } from 'chai' import { track } from 'temp' import * as lancedb from '../index' -import { MetricType, Query } from '../index' +import { type EmbeddingFunction, MetricType, Query } from '../index' describe('LanceDB client', function () { describe('when creating a connection to lancedb', function () { @@ -140,6 +140,39 @@ describe('LanceDB client', function () { await table.create_index({ type: 'ivf_pq', column: 'vector', num_partitions: 2, max_iters: 2 }) }).timeout(10_000) // Timeout is high partially because GH macos runner is pretty slow }) + + describe('when using a custom embedding function', function () { + class TextEmbedding implements EmbeddingFunction { + sourceColumn: string + + constructor (targetColumn: string) { + this.sourceColumn = targetColumn + } + + _embedding_map = new Map([ + ['foo', [2.1, 2.2]], + ['bar', [3.1, 3.2]] + ]) + + embed (data: string[]): number[][] { + return data.map(datum => this._embedding_map.get(datum) ??
[0.0, 0.0]) + } + } + + it('should encode the original data into embeddings', async function () { + const dir = await track().mkdir('lancejs') + const con = await lancedb.connect(dir) + const embeddings = new TextEmbedding('name') + + const data = [ + { price: 10, name: 'foo' }, + { price: 50, name: 'bar' } + ] + const table = await con.createTable('vectors', data, embeddings) + const results = await table.search('foo').execute() + assert.equal(results.length, 2) + }) + }) }) describe('Query object', function () {