diff --git a/node/src/embedding/embedding_function.ts b/node/src/embedding/embedding_function.ts index 2b662132..e152d41c 100644 --- a/node/src/embedding/embedding_function.ts +++ b/node/src/embedding/embedding_function.ts @@ -26,3 +26,9 @@ export interface EmbeddingFunction { */ embed: (data: T[]) => Promise } + +export function isEmbeddingFunction (value: any): value is EmbeddingFunction { + return Object.keys(value).length === 2 && + typeof value.sourceColumn === 'string' && + typeof value.embed === 'function' +} diff --git a/node/src/index.ts b/node/src/index.ts index e475e21a..b05729d9 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -20,10 +20,12 @@ import { fromRecordsToBuffer } from './arrow' import type { EmbeddingFunction } from './embedding/embedding_function' import { RemoteConnection } from './remote' import { Query } from './query' +import { isEmbeddingFunction } from './embedding/embedding_function' // eslint-disable-next-line @typescript-eslint/no-var-requires const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, tableCreate, tableAdd, tableCreateVectorIndex, tableCountRows, tableDelete } = require('../native.js') +export { Query } export type { EmbeddingFunction } export { OpenAIEmbeddingFunction } from './embedding/openai' @@ -100,10 +102,35 @@ export interface Connection { * * @param {string} name - The name of the table. * @param data - Non-empty Array of Records to be inserted into the table - * @param {WriteMode} mode - The write mode to use when creating the table. + */ + createTable (name: string, data: Array>): Promise + + /** + * Creates a new Table and initialize it with new data. + * + * @param {string} name - The name of the table. + * @param data - Non-empty Array of Records to be inserted into the table + * @param {WriteOptions} options - The write options to use when creating the table. + */ + createTable (name: string, data: Array>, options: WriteOptions): Promise
+ + /** + * Creates a new Table and initialize it with new data. + * + * @param {string} name - The name of the table. + * @param data - Non-empty Array of Records to be inserted into the table * @param {EmbeddingFunction} embeddings - An embedding function to use on this table */ - createTable(name: string, data: Array>, mode?: WriteMode, embeddings?: EmbeddingFunction): Promise> + createTable (name: string, data: Array>, embeddings: EmbeddingFunction): Promise> + /** + * Creates a new Table and initialize it with new data. + * + * @param {string} name - The name of the table. + * @param data - Non-empty Array of Records to be inserted into the table + * @param {EmbeddingFunction} embeddings - An embedding function to use on this table + * @param {WriteOptions} options - The write options to use when creating the table. + */ + createTable (name: string, data: Array>, embeddings: EmbeddingFunction, options: WriteOptions): Promise> createTableArrow(name: string, table: ArrowTable): Promise
@@ -237,32 +264,19 @@ export class LocalConnection implements Connection { } } - /** - * Creates a new Table and initialize it with new data. - * - * @param name The name of the table. - * @param data Non-empty Array of Records to be inserted into the Table - * @param mode The write mode to use when creating the table. - */ - async createTable (name: string, data: Array>, mode?: WriteMode): Promise
- async createTable (name: string, data: Array>, mode: WriteMode): Promise
- - /** - * Creates a new Table and initialize it with new data. - * - * @param name The name of the table. - * @param data Non-empty Array of Records to be inserted into the Table - * @param mode The write mode to use when creating the table. - * @param embeddings An embedding function to use on this Table - */ - async createTable (name: string, data: Array>, mode: WriteMode, embeddings: EmbeddingFunction): Promise> - async createTable (name: string, data: Array>, mode: WriteMode, embeddings?: EmbeddingFunction): Promise> - async createTable (name: string, data: Array>, mode: WriteMode, embeddings?: EmbeddingFunction): Promise> { - if (mode === undefined) { - mode = WriteMode.Create + async createTable (name: string, data: Array>, optsOrEmbedding?: WriteOptions | EmbeddingFunction, opt?: WriteOptions): Promise> { + let writeOptions: WriteOptions = new DefaultWriteOptions() + if (opt !== undefined && isWriteOptions(opt)) { + writeOptions = opt + } else if (optsOrEmbedding !== undefined && isWriteOptions(optsOrEmbedding)) { + writeOptions = optsOrEmbedding } - const createArgs = [this._db, name, await fromRecordsToBuffer(data, embeddings), mode.toLowerCase()] + let embeddings: undefined | EmbeddingFunction + if (optsOrEmbedding !== undefined && isEmbeddingFunction(optsOrEmbedding)) { + embeddings = optsOrEmbedding + } + const createArgs = [this._db, name, await fromRecordsToBuffer(data, embeddings), writeOptions.writeMode?.toString()] if (this._options.awsCredentials !== undefined) { createArgs.push(this._options.awsCredentials.accessKeyId) createArgs.push(this._options.awsCredentials.secretKey) @@ -459,6 +473,23 @@ export enum WriteMode { Append = 'append' } +/** + * Write options when creating a Table. + */ +export interface WriteOptions { + /** A {@link WriteMode} to use on this operation */ + writeMode?: WriteMode +} + +export class DefaultWriteOptions implements WriteOptions { + writeMode = WriteMode.Create +} + +export function isWriteOptions (value: any): value is WriteOptions { + return Object.keys(value).length === 1 && + (value.writeMode === undefined || typeof value.writeMode === 'string') +} + /** * Distance metrics type. */ diff --git a/node/src/test/test.ts b/node/src/test/test.ts index cdfca32e..e0e76114 100644 --- a/node/src/test/test.ts +++ b/node/src/test/test.ts @@ -18,8 +18,7 @@ import * as chai from 'chai' import * as chaiAsPromised from 'chai-as-promised' import * as lancedb from '../index' -import { type AwsCredentials, type EmbeddingFunction, MetricType, WriteMode } from '../index' -import { Query } from '../query' +import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode, DefaultWriteOptions, isWriteOptions } from '../index' const expect = chai.expect const assert = chai.assert @@ -145,7 +144,7 @@ describe('LanceDB client', function () { ] const tableName = 'overwrite' - await con.createTable(tableName, data, WriteMode.Create) + await con.createTable(tableName, data, { writeMode: WriteMode.Create }) const newData = [ { id: 1, vector: [0.1, 0.2], price: 10 }, @@ -155,7 +154,7 @@ describe('LanceDB client', function () { await expect(con.createTable(tableName, newData)).to.be.rejectedWith(Error, 'already exists') - const table = await con.createTable(tableName, newData, WriteMode.Overwrite) + const table = await con.createTable(tableName, newData, { writeMode: WriteMode.Overwrite }) assert.equal(table.name, tableName) assert.equal(await table.countRows(), 3) }) @@ -260,7 +259,7 @@ describe('LanceDB client', function () { { price: 10, name: 'foo' }, { price: 50, name: 'bar' } ] - const table = await con.createTable('vectors', data, WriteMode.Create, embeddings) + const table = await con.createTable('vectors', data, embeddings, { writeMode: WriteMode.Create }) const results = await table.search('foo').execute() assert.equal(results.length, 2) }) @@ -318,3 +317,20 @@ describe('Drop table', function () { assert.deepEqual(await con.tableNames(), ['t2']) }) }) + +describe('WriteOptions', function () { + context('#isWriteOptions', function () { + it('should not match empty object', function () { + assert.equal(isWriteOptions({}), false) + }) + it('should match write options', function () { + assert.equal(isWriteOptions({ writeMode: WriteMode.Create }), true) + }) + it('should match undefined write mode', function () { + assert.equal(isWriteOptions({ writeMode: undefined }), true) + }) + it('should match default write options', function () { + assert.equal(isWriteOptions(new DefaultWriteOptions()), true) + }) + }) +})