diff --git a/node/src/index.ts b/node/src/index.ts index 830bedd4..f8072799 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -27,13 +27,38 @@ const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, t export type { EmbeddingFunction } export { OpenAIEmbeddingFunction } from './embedding/openai' +export interface AwsCredentials { + accessKeyId: string + + secretKey: string + + sessionToken?: string +} + +export interface ConnectionOptions { + uri: string + awsCredentials?: AwsCredentials +} + /** * Connect to a LanceDB instance at the given URI * @param uri The uri of the database. */ -export async function connect (uri: string): Promise { - const db = await databaseNew(uri) - return new LocalConnection(db, uri) +export async function connect (uri: string): Promise +export async function connect (opts: Partial): Promise +export async function connect (arg: string | Partial): Promise { + let opts: ConnectionOptions + if (typeof arg === 'string') { + opts = { uri: arg } + } else { + // opts = { uri: arg.uri, awsCredentials = arg.awsCredentials } + opts = Object.assign({ + uri: '', + awsCredentials: undefined + }, arg) + } + const db = await databaseNew(opts.uri) + return new LocalConnection(db, opts) } /** @@ -122,28 +147,20 @@ export interface Table { delete: (filter: string) => Promise } -export interface AwsCredentials { - accessKeyId: string - - secretKey: string - - sessionToken?: string -} - /** * A connection to a LanceDB database. */ export class LocalConnection implements Connection { - private readonly _uri: string + private readonly _options: ConnectionOptions private readonly _db: any - constructor (db: any, uri: string) { - this._uri = uri + constructor (db: any, options: ConnectionOptions) { + this._options = options this._db = db } get uri (): string { - return this._uri + return this._options.uri } /** @@ -166,10 +183,14 @@ export class LocalConnection implements Connection { * @param embeddings An embedding function to use on this Table */ async openTable (name: string, embeddings: EmbeddingFunction): Promise> - async openTable (name: string, embeddings?: EmbeddingFunction, awsCredentials?: AwsCredentials): Promise> - async openTable (name: string, embeddings?: EmbeddingFunction, awsCredentials?: AwsCredentials): Promise> { + async openTable (name: string, embeddings?: EmbeddingFunction): Promise> + async openTable (name: string, embeddings?: EmbeddingFunction): Promise> { const tbl = await databaseOpenTable.call(this._db, name) - return new LocalTable(tbl, name, embeddings, awsCredentials) + if (embeddings !== undefined) { + return new LocalTable(tbl, name, this._options, embeddings) + } else { + return new LocalTable(tbl, name, this._options) + } } /** @@ -191,24 +212,28 @@ export class LocalConnection implements Connection { * @param embeddings An embedding function to use on this Table */ async createTable (name: string, data: Array>, mode: WriteMode, embeddings: EmbeddingFunction): Promise> - async createTable (name: string, data: Array>, mode: WriteMode, embeddings?: EmbeddingFunction, awsCredentials?: AwsCredentials): Promise> - async createTable (name: string, data: Array>, mode: WriteMode, embeddings?: EmbeddingFunction, awsCredentials?: AwsCredentials): Promise> { + async createTable (name: string, data: Array>, mode: WriteMode, embeddings?: EmbeddingFunction): Promise> + async createTable (name: string, data: Array>, mode: WriteMode, embeddings?: EmbeddingFunction): Promise> { if (mode === undefined) { mode = WriteMode.Create } const createArgs = [this._db, name, await fromRecordsToBuffer(data, embeddings), mode.toLowerCase()] - if (awsCredentials !== undefined) { - createArgs.push(awsCredentials.accessKeyId) - createArgs.push(awsCredentials.secretKey) - if (awsCredentials.sessionToken !== undefined) { - createArgs.push(awsCredentials.sessionToken) + if (this._options.awsCredentials !== undefined) { + createArgs.push(this._options.awsCredentials.accessKeyId) + createArgs.push(this._options.awsCredentials.secretKey) + if (this._options.awsCredentials.sessionToken !== undefined) { + createArgs.push(this._options.awsCredentials.sessionToken) } } const tbl = await tableCreate.call(...createArgs) - return new LocalTable(tbl, name, embeddings, awsCredentials) + if (embeddings !== undefined) { + return new LocalTable(tbl, name, this._options, embeddings) + } else { + return new LocalTable(tbl, name, this._options) + } } async createTableArrow (name: string, table: ArrowTable): Promise { @@ -230,21 +255,21 @@ export class LocalTable implements Table { private readonly _tbl: any private readonly _name: string private readonly _embeddings?: EmbeddingFunction - private readonly _awsCredentials?: AwsCredentials + private readonly _options: ConnectionOptions - constructor (tbl: any, name: string) + constructor (tbl: any, name: string, options: ConnectionOptions) /** * @param tbl * @param name + * @param options * @param embeddings An embedding function to use when interacting with this table */ - constructor (tbl: any, name: string, embeddings: EmbeddingFunction) - constructor (tbl: any, name: string, embeddings?: EmbeddingFunction, awsCredentials?: AwsCredentials) - constructor (tbl: any, name: string, embeddings?: EmbeddingFunction, awsCredentials?: AwsCredentials) { + constructor (tbl: any, name: string, options: ConnectionOptions, embeddings: EmbeddingFunction) + constructor (tbl: any, name: string, options: ConnectionOptions, embeddings?: EmbeddingFunction) { this._tbl = tbl this._name = name this._embeddings = embeddings - this._awsCredentials = awsCredentials + this._options = options } get name (): string { @@ -267,11 +292,11 @@ export class LocalTable implements Table { */ async add (data: Array>): Promise { const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString()] - if (this._awsCredentials !== undefined) { - callArgs.push(this._awsCredentials.accessKeyId) - callArgs.push(this._awsCredentials.secretKey) - if (this._awsCredentials.sessionToken !== undefined) { - callArgs.push(this._awsCredentials.sessionToken) + if (this._options.awsCredentials !== undefined) { + callArgs.push(this._options.awsCredentials.accessKeyId) + callArgs.push(this._options.awsCredentials.secretKey) + if (this._options.awsCredentials.sessionToken !== undefined) { + callArgs.push(this._options.awsCredentials.sessionToken) } } return tableAdd.call(...callArgs) @@ -285,11 +310,11 @@ export class LocalTable implements Table { */ async overwrite (data: Array>): Promise { const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString()] - if (this._awsCredentials !== undefined) { - callArgs.push(this._awsCredentials.accessKeyId) - callArgs.push(this._awsCredentials.secretKey) - if (this._awsCredentials.sessionToken !== undefined) { - callArgs.push(this._awsCredentials.sessionToken) + if (this._options.awsCredentials !== undefined) { + callArgs.push(this._options.awsCredentials.accessKeyId) + callArgs.push(this._options.awsCredentials.secretKey) + if (this._options.awsCredentials.sessionToken !== undefined) { + callArgs.push(this._options.awsCredentials.sessionToken) } } return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString()) diff --git a/node/src/test/io.ts b/node/src/test/io.ts index fb667ba9..6f383513 100644 --- a/node/src/test/io.ts +++ b/node/src/test/io.ts @@ -18,26 +18,48 @@ import { describe } from 'mocha' import { assert } from 'chai' import * as lancedb from '../index' +import { type ConnectionOptions } from '../index' describe('LanceDB S3 client', function () { if (process.env.TEST_S3_BASE_URL != null) { const baseUri = process.env.TEST_S3_BASE_URL it('should have a valid url', async function () { - const uri = `${baseUri}/valid_url` - const table = await createTestDB(uri, 2, 20) - const con = await lancedb.connect(uri) - assert.equal(con.uri, uri) + const opts = { uri: `${baseUri}/valid_url` } + const table = await createTestDB(opts, 2, 20) + const con = await lancedb.connect(opts) + assert.equal(con.uri, opts.uri) const results = await table.search([0.1, 0.3]).limit(5).execute() assert.equal(results.length, 5) - }) + }).timeout(10_000) + } else { + describe.skip('Skip S3 test', function () {}) + } + + if (process.env.TEST_S3_BASE_URL != null && process.env.TEST_AWS_ACCESS_KEY_ID != null && process.env.TEST_AWS_SECRET_ACCESS_KEY != null) { + const baseUri = process.env.TEST_S3_BASE_URL + it('use custom credentials', async function () { + const opts: ConnectionOptions = { + uri: `${baseUri}/custom_credentials`, + awsCredentials: { + accessKeyId: process.env.TEST_AWS_ACCESS_KEY_ID as string, + secretKey: process.env.TEST_AWS_SECRET_ACCESS_KEY as string + } + } + const table = await createTestDB(opts, 2, 20) + const con = await lancedb.connect(opts) + assert.equal(con.uri, opts.uri) + + const results = await table.search([0.1, 0.3]).limit(5).execute() + assert.equal(results.length, 5) + }).timeout(10_000) } else { describe.skip('Skip S3 test', function () {}) } }) -async function createTestDB (uri: string, numDimensions: number = 2, numRows: number = 2): Promise { - const con = await lancedb.connect(uri) +async function createTestDB (opts: ConnectionOptions, numDimensions: number = 2, numRows: number = 2): Promise { + const con = await lancedb.connect(opts) const data = [] for (let i = 0; i < numRows; i++) { diff --git a/node/src/test/test.ts b/node/src/test/test.ts index 882055cf..aefb6f18 100644 --- a/node/src/test/test.ts +++ b/node/src/test/test.ts @@ -18,7 +18,7 @@ import * as chai from 'chai' import * as chaiAsPromised from 'chai-as-promised' import * as lancedb from '../index' -import { type EmbeddingFunction, MetricType, Query, WriteMode } from '../index' +import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode } from '../index' const expect = chai.expect const assert = chai.assert @@ -32,6 +32,22 @@ describe('LanceDB client', function () { assert.equal(con.uri, uri) }) + it('should accept an options object', async function () { + const uri = await createTestDB() + const con = await lancedb.connect({ uri }) + assert.equal(con.uri, uri) + }) + + it('should accept custom aws credentials', async function () { + const uri = await createTestDB() + const awsCredentials: AwsCredentials = { + accessKeyId: '', + secretKey: '' + } + const con = await lancedb.connect({ uri, awsCredentials }) + assert.equal(con.uri, uri) + }) + it('should return the existing table names', async function () { const uri = await createTestDB() const con = await lancedb.connect(uri)