mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 19:02:58 +00:00
feat(node): add option object to connect method (#286)
This commit is contained in:
@@ -27,13 +27,38 @@ const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, t
|
||||
export type { EmbeddingFunction }
|
||||
export { OpenAIEmbeddingFunction } from './embedding/openai'
|
||||
|
||||
export interface AwsCredentials {
|
||||
accessKeyId: string
|
||||
|
||||
secretKey: string
|
||||
|
||||
sessionToken?: string
|
||||
}
|
||||
|
||||
export interface ConnectionOptions {
|
||||
uri: string
|
||||
awsCredentials?: AwsCredentials
|
||||
}
|
||||
|
||||
/**
|
||||
* Connect to a LanceDB instance at the given URI
|
||||
* @param uri The uri of the database.
|
||||
*/
|
||||
export async function connect (uri: string): Promise<Connection> {
|
||||
const db = await databaseNew(uri)
|
||||
return new LocalConnection(db, uri)
|
||||
export async function connect (uri: string): Promise<Connection>
|
||||
export async function connect (opts: Partial<ConnectionOptions>): Promise<Connection>
|
||||
export async function connect (arg: string | Partial<ConnectionOptions>): Promise<Connection> {
|
||||
let opts: ConnectionOptions
|
||||
if (typeof arg === 'string') {
|
||||
opts = { uri: arg }
|
||||
} else {
|
||||
// opts = { uri: arg.uri, awsCredentials = arg.awsCredentials }
|
||||
opts = Object.assign({
|
||||
uri: '',
|
||||
awsCredentials: undefined
|
||||
}, arg)
|
||||
}
|
||||
const db = await databaseNew(opts.uri)
|
||||
return new LocalConnection(db, opts)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -122,28 +147,20 @@ export interface Table<T = number[]> {
|
||||
delete: (filter: string) => Promise<void>
|
||||
}
|
||||
|
||||
export interface AwsCredentials {
|
||||
accessKeyId: string
|
||||
|
||||
secretKey: string
|
||||
|
||||
sessionToken?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* A connection to a LanceDB database.
|
||||
*/
|
||||
export class LocalConnection implements Connection {
|
||||
private readonly _uri: string
|
||||
private readonly _options: ConnectionOptions
|
||||
private readonly _db: any
|
||||
|
||||
constructor (db: any, uri: string) {
|
||||
this._uri = uri
|
||||
constructor (db: any, options: ConnectionOptions) {
|
||||
this._options = options
|
||||
this._db = db
|
||||
}
|
||||
|
||||
get uri (): string {
|
||||
return this._uri
|
||||
return this._options.uri
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -166,10 +183,14 @@ export class LocalConnection implements Connection {
|
||||
* @param embeddings An embedding function to use on this Table
|
||||
*/
|
||||
async openTable<T> (name: string, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>>
|
||||
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>> {
|
||||
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
|
||||
const tbl = await databaseOpenTable.call(this._db, name)
|
||||
return new LocalTable(tbl, name, embeddings, awsCredentials)
|
||||
if (embeddings !== undefined) {
|
||||
return new LocalTable(tbl, name, this._options, embeddings)
|
||||
} else {
|
||||
return new LocalTable(tbl, name, this._options)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -191,24 +212,28 @@ export class LocalConnection implements Connection {
|
||||
* @param embeddings An embedding function to use on this Table
|
||||
*/
|
||||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>>
|
||||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>> {
|
||||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
|
||||
if (mode === undefined) {
|
||||
mode = WriteMode.Create
|
||||
}
|
||||
|
||||
const createArgs = [this._db, name, await fromRecordsToBuffer(data, embeddings), mode.toLowerCase()]
|
||||
if (awsCredentials !== undefined) {
|
||||
createArgs.push(awsCredentials.accessKeyId)
|
||||
createArgs.push(awsCredentials.secretKey)
|
||||
if (awsCredentials.sessionToken !== undefined) {
|
||||
createArgs.push(awsCredentials.sessionToken)
|
||||
if (this._options.awsCredentials !== undefined) {
|
||||
createArgs.push(this._options.awsCredentials.accessKeyId)
|
||||
createArgs.push(this._options.awsCredentials.secretKey)
|
||||
if (this._options.awsCredentials.sessionToken !== undefined) {
|
||||
createArgs.push(this._options.awsCredentials.sessionToken)
|
||||
}
|
||||
}
|
||||
|
||||
const tbl = await tableCreate.call(...createArgs)
|
||||
|
||||
return new LocalTable(tbl, name, embeddings, awsCredentials)
|
||||
if (embeddings !== undefined) {
|
||||
return new LocalTable(tbl, name, this._options, embeddings)
|
||||
} else {
|
||||
return new LocalTable(tbl, name, this._options)
|
||||
}
|
||||
}
|
||||
|
||||
async createTableArrow (name: string, table: ArrowTable): Promise<Table> {
|
||||
@@ -230,21 +255,21 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
private readonly _tbl: any
|
||||
private readonly _name: string
|
||||
private readonly _embeddings?: EmbeddingFunction<T>
|
||||
private readonly _awsCredentials?: AwsCredentials
|
||||
private readonly _options: ConnectionOptions
|
||||
|
||||
constructor (tbl: any, name: string)
|
||||
constructor (tbl: any, name: string, options: ConnectionOptions)
|
||||
/**
|
||||
* @param tbl
|
||||
* @param name
|
||||
* @param options
|
||||
* @param embeddings An embedding function to use when interacting with this table
|
||||
*/
|
||||
constructor (tbl: any, name: string, embeddings: EmbeddingFunction<T>)
|
||||
constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials)
|
||||
constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials) {
|
||||
constructor (tbl: any, name: string, options: ConnectionOptions, embeddings: EmbeddingFunction<T>)
|
||||
constructor (tbl: any, name: string, options: ConnectionOptions, embeddings?: EmbeddingFunction<T>) {
|
||||
this._tbl = tbl
|
||||
this._name = name
|
||||
this._embeddings = embeddings
|
||||
this._awsCredentials = awsCredentials
|
||||
this._options = options
|
||||
}
|
||||
|
||||
get name (): string {
|
||||
@@ -267,11 +292,11 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
*/
|
||||
async add (data: Array<Record<string, unknown>>): Promise<number> {
|
||||
const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString()]
|
||||
if (this._awsCredentials !== undefined) {
|
||||
callArgs.push(this._awsCredentials.accessKeyId)
|
||||
callArgs.push(this._awsCredentials.secretKey)
|
||||
if (this._awsCredentials.sessionToken !== undefined) {
|
||||
callArgs.push(this._awsCredentials.sessionToken)
|
||||
if (this._options.awsCredentials !== undefined) {
|
||||
callArgs.push(this._options.awsCredentials.accessKeyId)
|
||||
callArgs.push(this._options.awsCredentials.secretKey)
|
||||
if (this._options.awsCredentials.sessionToken !== undefined) {
|
||||
callArgs.push(this._options.awsCredentials.sessionToken)
|
||||
}
|
||||
}
|
||||
return tableAdd.call(...callArgs)
|
||||
@@ -285,11 +310,11 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
*/
|
||||
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
|
||||
const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString()]
|
||||
if (this._awsCredentials !== undefined) {
|
||||
callArgs.push(this._awsCredentials.accessKeyId)
|
||||
callArgs.push(this._awsCredentials.secretKey)
|
||||
if (this._awsCredentials.sessionToken !== undefined) {
|
||||
callArgs.push(this._awsCredentials.sessionToken)
|
||||
if (this._options.awsCredentials !== undefined) {
|
||||
callArgs.push(this._options.awsCredentials.accessKeyId)
|
||||
callArgs.push(this._options.awsCredentials.secretKey)
|
||||
if (this._options.awsCredentials.sessionToken !== undefined) {
|
||||
callArgs.push(this._options.awsCredentials.sessionToken)
|
||||
}
|
||||
}
|
||||
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString())
|
||||
|
||||
@@ -18,26 +18,48 @@ import { describe } from 'mocha'
|
||||
import { assert } from 'chai'
|
||||
|
||||
import * as lancedb from '../index'
|
||||
import { type ConnectionOptions } from '../index'
|
||||
|
||||
describe('LanceDB S3 client', function () {
|
||||
if (process.env.TEST_S3_BASE_URL != null) {
|
||||
const baseUri = process.env.TEST_S3_BASE_URL
|
||||
it('should have a valid url', async function () {
|
||||
const uri = `${baseUri}/valid_url`
|
||||
const table = await createTestDB(uri, 2, 20)
|
||||
const con = await lancedb.connect(uri)
|
||||
assert.equal(con.uri, uri)
|
||||
const opts = { uri: `${baseUri}/valid_url` }
|
||||
const table = await createTestDB(opts, 2, 20)
|
||||
const con = await lancedb.connect(opts)
|
||||
assert.equal(con.uri, opts.uri)
|
||||
|
||||
const results = await table.search([0.1, 0.3]).limit(5).execute()
|
||||
assert.equal(results.length, 5)
|
||||
})
|
||||
}).timeout(10_000)
|
||||
} else {
|
||||
describe.skip('Skip S3 test', function () {})
|
||||
}
|
||||
|
||||
if (process.env.TEST_S3_BASE_URL != null && process.env.TEST_AWS_ACCESS_KEY_ID != null && process.env.TEST_AWS_SECRET_ACCESS_KEY != null) {
|
||||
const baseUri = process.env.TEST_S3_BASE_URL
|
||||
it('use custom credentials', async function () {
|
||||
const opts: ConnectionOptions = {
|
||||
uri: `${baseUri}/custom_credentials`,
|
||||
awsCredentials: {
|
||||
accessKeyId: process.env.TEST_AWS_ACCESS_KEY_ID as string,
|
||||
secretKey: process.env.TEST_AWS_SECRET_ACCESS_KEY as string
|
||||
}
|
||||
}
|
||||
const table = await createTestDB(opts, 2, 20)
|
||||
const con = await lancedb.connect(opts)
|
||||
assert.equal(con.uri, opts.uri)
|
||||
|
||||
const results = await table.search([0.1, 0.3]).limit(5).execute()
|
||||
assert.equal(results.length, 5)
|
||||
}).timeout(10_000)
|
||||
} else {
|
||||
describe.skip('Skip S3 test', function () {})
|
||||
}
|
||||
})
|
||||
|
||||
async function createTestDB (uri: string, numDimensions: number = 2, numRows: number = 2): Promise<lancedb.Table> {
|
||||
const con = await lancedb.connect(uri)
|
||||
async function createTestDB (opts: ConnectionOptions, numDimensions: number = 2, numRows: number = 2): Promise<lancedb.Table> {
|
||||
const con = await lancedb.connect(opts)
|
||||
|
||||
const data = []
|
||||
for (let i = 0; i < numRows; i++) {
|
||||
|
||||
@@ -18,7 +18,7 @@ import * as chai from 'chai'
|
||||
import * as chaiAsPromised from 'chai-as-promised'
|
||||
|
||||
import * as lancedb from '../index'
|
||||
import { type EmbeddingFunction, MetricType, Query, WriteMode } from '../index'
|
||||
import { type AwsCredentials, type EmbeddingFunction, MetricType, Query, WriteMode } from '../index'
|
||||
|
||||
const expect = chai.expect
|
||||
const assert = chai.assert
|
||||
@@ -32,6 +32,22 @@ describe('LanceDB client', function () {
|
||||
assert.equal(con.uri, uri)
|
||||
})
|
||||
|
||||
it('should accept an options object', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect({ uri })
|
||||
assert.equal(con.uri, uri)
|
||||
})
|
||||
|
||||
it('should accept custom aws credentials', async function () {
|
||||
const uri = await createTestDB()
|
||||
const awsCredentials: AwsCredentials = {
|
||||
accessKeyId: '',
|
||||
secretKey: ''
|
||||
}
|
||||
const con = await lancedb.connect({ uri, awsCredentials })
|
||||
assert.equal(con.uri, uri)
|
||||
})
|
||||
|
||||
it('should return the existing table names', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
|
||||
Reference in New Issue
Block a user