feat(node): add option object to connect method (#286)

This commit is contained in:
gsilvestrin
2023-07-13 11:03:48 -07:00
committed by GitHub
parent 08cc483ec9
commit 826dc90151
3 changed files with 113 additions and 50 deletions

View File

@@ -27,13 +27,38 @@ const { databaseNew, databaseTableNames, databaseOpenTable, databaseDropTable, t
export type { EmbeddingFunction }
export { OpenAIEmbeddingFunction } from './embedding/openai'
export interface AwsCredentials {
accessKeyId: string
secretKey: string
sessionToken?: string
}
export interface ConnectionOptions {
uri: string
awsCredentials?: AwsCredentials
}
/**
* Connect to a LanceDB instance at the given URI
* @param uri The uri of the database.
*/
export async function connect (uri: string): Promise<Connection> {
const db = await databaseNew(uri)
return new LocalConnection(db, uri)
export async function connect (uri: string): Promise<Connection>
export async function connect (opts: Partial<ConnectionOptions>): Promise<Connection>
export async function connect (arg: string | Partial<ConnectionOptions>): Promise<Connection> {
let opts: ConnectionOptions
if (typeof arg === 'string') {
opts = { uri: arg }
} else {
// opts = { uri: arg.uri, awsCredentials = arg.awsCredentials }
opts = Object.assign({
uri: '',
awsCredentials: undefined
}, arg)
}
const db = await databaseNew(opts.uri)
return new LocalConnection(db, opts)
}
/**
@@ -122,28 +147,20 @@ export interface Table<T = number[]> {
delete: (filter: string) => Promise<void>
}
export interface AwsCredentials {
accessKeyId: string
secretKey: string
sessionToken?: string
}
/**
* A connection to a LanceDB database.
*/
export class LocalConnection implements Connection {
private readonly _uri: string
private readonly _options: ConnectionOptions
private readonly _db: any
constructor (db: any, uri: string) {
this._uri = uri
constructor (db: any, options: ConnectionOptions) {
this._options = options
this._db = db
}
get uri (): string {
return this._uri
return this._options.uri
}
/**
@@ -166,10 +183,14 @@ export class LocalConnection implements Connection {
* @param embeddings An embedding function to use on this Table
*/
async openTable<T> (name: string, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>>
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>> {
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>>
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
const tbl = await databaseOpenTable.call(this._db, name)
return new LocalTable(tbl, name, embeddings, awsCredentials)
if (embeddings !== undefined) {
return new LocalTable(tbl, name, this._options, embeddings)
} else {
return new LocalTable(tbl, name, this._options)
}
}
/**
@@ -191,24 +212,28 @@ export class LocalConnection implements Connection {
* @param embeddings An embedding function to use on this Table
*/
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>>
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials): Promise<Table<T>> {
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>): Promise<Table<T>>
async createTable<T> (name: string, data: Array<Record<string, unknown>>, mode: WriteMode, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
if (mode === undefined) {
mode = WriteMode.Create
}
const createArgs = [this._db, name, await fromRecordsToBuffer(data, embeddings), mode.toLowerCase()]
if (awsCredentials !== undefined) {
createArgs.push(awsCredentials.accessKeyId)
createArgs.push(awsCredentials.secretKey)
if (awsCredentials.sessionToken !== undefined) {
createArgs.push(awsCredentials.sessionToken)
if (this._options.awsCredentials !== undefined) {
createArgs.push(this._options.awsCredentials.accessKeyId)
createArgs.push(this._options.awsCredentials.secretKey)
if (this._options.awsCredentials.sessionToken !== undefined) {
createArgs.push(this._options.awsCredentials.sessionToken)
}
}
const tbl = await tableCreate.call(...createArgs)
return new LocalTable(tbl, name, embeddings, awsCredentials)
if (embeddings !== undefined) {
return new LocalTable(tbl, name, this._options, embeddings)
} else {
return new LocalTable(tbl, name, this._options)
}
}
async createTableArrow (name: string, table: ArrowTable): Promise<Table> {
@@ -230,21 +255,21 @@ export class LocalTable<T = number[]> implements Table<T> {
private readonly _tbl: any
private readonly _name: string
private readonly _embeddings?: EmbeddingFunction<T>
private readonly _awsCredentials?: AwsCredentials
private readonly _options: ConnectionOptions
constructor (tbl: any, name: string)
constructor (tbl: any, name: string, options: ConnectionOptions)
/**
* @param tbl
* @param name
* @param options
* @param embeddings An embedding function to use when interacting with this table
*/
constructor (tbl: any, name: string, embeddings: EmbeddingFunction<T>)
constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials)
constructor (tbl: any, name: string, embeddings?: EmbeddingFunction<T>, awsCredentials?: AwsCredentials) {
constructor (tbl: any, name: string, options: ConnectionOptions, embeddings: EmbeddingFunction<T>)
constructor (tbl: any, name: string, options: ConnectionOptions, embeddings?: EmbeddingFunction<T>) {
this._tbl = tbl
this._name = name
this._embeddings = embeddings
this._awsCredentials = awsCredentials
this._options = options
}
get name (): string {
@@ -267,11 +292,11 @@ export class LocalTable<T = number[]> implements Table<T> {
*/
async add (data: Array<Record<string, unknown>>): Promise<number> {
const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Append.toString()]
if (this._awsCredentials !== undefined) {
callArgs.push(this._awsCredentials.accessKeyId)
callArgs.push(this._awsCredentials.secretKey)
if (this._awsCredentials.sessionToken !== undefined) {
callArgs.push(this._awsCredentials.sessionToken)
if (this._options.awsCredentials !== undefined) {
callArgs.push(this._options.awsCredentials.accessKeyId)
callArgs.push(this._options.awsCredentials.secretKey)
if (this._options.awsCredentials.sessionToken !== undefined) {
callArgs.push(this._options.awsCredentials.sessionToken)
}
}
return tableAdd.call(...callArgs)
@@ -285,11 +310,11 @@ export class LocalTable<T = number[]> implements Table<T> {
*/
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
const callArgs = [this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString()]
if (this._awsCredentials !== undefined) {
callArgs.push(this._awsCredentials.accessKeyId)
callArgs.push(this._awsCredentials.secretKey)
if (this._awsCredentials.sessionToken !== undefined) {
callArgs.push(this._awsCredentials.sessionToken)
if (this._options.awsCredentials !== undefined) {
callArgs.push(this._options.awsCredentials.accessKeyId)
callArgs.push(this._options.awsCredentials.secretKey)
if (this._options.awsCredentials.sessionToken !== undefined) {
callArgs.push(this._options.awsCredentials.sessionToken)
}
}
return tableAdd.call(this._tbl, await fromRecordsToBuffer(data, this._embeddings), WriteMode.Overwrite.toString())