feat: change create table to accept Arrow table (#845)

This commit is contained in:
Lei Xu
2024-01-23 13:25:15 -08:00
committed by Weston Pace
parent 5ecbf971e2
commit 65c1d8bc4c
5 changed files with 586 additions and 160 deletions

View File

@@ -13,18 +13,29 @@
// limitations under the License.
import {
type EmbeddingFunction, type Table, type VectorIndexParams, type Connection,
type ConnectionOptions, type CreateTableOptions, type VectorIndex,
type EmbeddingFunction,
type Table,
type VectorIndexParams,
type Connection,
type ConnectionOptions,
type CreateTableOptions,
type VectorIndex,
type WriteOptions,
type IndexStats,
type UpdateArgs, type UpdateSqlArgs
type UpdateArgs,
type UpdateSqlArgs,
makeArrowTable
} from '../index'
import { Query } from '../query'
import { Vector, Table as ArrowTable } from 'apache-arrow'
import { HttpLancedbClient } from './client'
import { isEmbeddingFunction } from '../embedding/embedding_function'
import { createEmptyTable, fromRecordsToStreamBuffer, fromTableToStreamBuffer } from '../arrow'
import {
createEmptyTable,
fromRecordsToStreamBuffer,
fromTableToStreamBuffer
} from '../arrow'
import { toSQL } from '../util'
/**
@@ -54,7 +65,11 @@ export class RemoteConnection implements Connection {
} else {
server = opts.hostOverride
}
this._client = new HttpLancedbClient(server, opts.apiKey, opts.hostOverride === undefined ? undefined : this._dbName)
this._client = new HttpLancedbClient(
server,
opts.apiKey,
opts.hostOverride === undefined ? undefined : this._dbName
)
}
get uri (): string {
@@ -62,14 +77,26 @@ export class RemoteConnection implements Connection {
return 'db://' + this._client.uri
}
async tableNames (pageToken: string = '', limit: number = 10): Promise<string[]> {
const response = await this._client.get('/v1/table/', { limit, page_token: pageToken })
async tableNames (
pageToken: string = '',
limit: number = 10
): Promise<string[]> {
const response = await this._client.get('/v1/table/', {
limit,
page_token: pageToken
})
return response.data.tables
}
async openTable (name: string): Promise<Table>
async openTable<T> (name: string, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
async openTable<T>(
name: string,
embeddings: EmbeddingFunction<T>
): Promise<Table<T>>
async openTable<T>(
name: string,
embeddings?: EmbeddingFunction<T>
): Promise<Table<T>> {
if (embeddings !== undefined) {
return new RemoteTable(this._client, name, embeddings)
} else {
@@ -77,13 +104,21 @@ export class RemoteConnection implements Connection {
}
}
async createTable<T> (nameOrOpts: string | CreateTableOptions<T>, data?: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>> {
async createTable<T>(
nameOrOpts: string | CreateTableOptions<T>,
data?: Array<Record<string, unknown>> | ArrowTable,
optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>,
opt?: WriteOptions
): Promise<Table<T>> {
// Logic copied from LocalConnection; refactor these to a base class + connectionImpl pattern
let schema
let embeddings: undefined | EmbeddingFunction<T>
let tableName: string
if (typeof nameOrOpts === 'string') {
if (optsOrEmbedding !== undefined && isEmbeddingFunction(optsOrEmbedding)) {
if (
optsOrEmbedding !== undefined &&
isEmbeddingFunction(optsOrEmbedding)
) {
embeddings = optsOrEmbedding
}
tableName = nameOrOpts
@@ -95,14 +130,16 @@ export class RemoteConnection implements Connection {
let buffer: Buffer
function isEmpty (data: Array<Record<string, unknown>> | ArrowTable<any>): boolean {
function isEmpty (
data: Array<Record<string, unknown>> | ArrowTable<any>
): boolean {
if (data instanceof ArrowTable) {
return data.data.length === 0
return data.numRows === 0
}
return data.length === 0
}
if ((data === undefined) || isEmpty(data)) {
if (data === undefined || isEmpty(data)) {
if (schema === undefined) {
throw new Error('Either data or schema needs to defined')
}
@@ -121,9 +158,11 @@ export class RemoteConnection implements Connection {
'application/vnd.apache.arrow.stream'
)
if (res.status !== 200) {
throw new Error(`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`)
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
)
}
if (embeddings === undefined) {
@@ -139,8 +178,12 @@ export class RemoteConnection implements Connection {
}
export class RemoteQuery<T = number[]> extends Query<T> {
constructor (query: T, private readonly _client: HttpLancedbClient,
private readonly _name: string, embeddings?: EmbeddingFunction<T>) {
constructor (
query: T,
private readonly _client: HttpLancedbClient,
private readonly _name: string,
embeddings?: EmbeddingFunction<T>
) {
super(query, undefined, embeddings)
}
@@ -189,8 +232,16 @@ export class RemoteTable<T = number[]> implements Table<T> {
private readonly _name: string
constructor (client: HttpLancedbClient, name: string)
constructor (client: HttpLancedbClient, name: string, embeddings: EmbeddingFunction<T>)
constructor (client: HttpLancedbClient, name: string, embeddings?: EmbeddingFunction<T>) {
constructor (
client: HttpLancedbClient,
name: string,
embeddings: EmbeddingFunction<T>
)
constructor (
client: HttpLancedbClient,
name: string,
embeddings?: EmbeddingFunction<T>
) {
this._client = client
this._name = name
this._embeddings = embeddings
@@ -201,22 +252,33 @@ export class RemoteTable<T = number[]> implements Table<T> {
}
get schema (): Promise<any> {
return this._client.post(`/v1/table/${this._name}/describe/`).then(res => {
if (res.status !== 200) {
throw new Error(`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`)
}
return res.data?.schema
})
return this._client
.post(`/v1/table/${this._name}/describe/`)
.then((res) => {
if (res.status !== 200) {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
)
}
return res.data?.schema
})
}
search (query: T): Query<T> {
return new RemoteQuery(query, this._client, this._name)//, this._embeddings_new)
return new RemoteQuery(query, this._client, this._name) //, this._embeddings_new)
}
async add (data: Array<Record<string, unknown>>): Promise<number> {
const buffer = await fromRecordsToStreamBuffer(data, this._embeddings)
async add (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> {
let tbl: ArrowTable
if (data instanceof ArrowTable) {
tbl = data
} else {
tbl = makeArrowTable(data, await this.schema)
}
const buffer = await fromTableToStreamBuffer(tbl, this._embeddings)
const res = await this._client.post(
`/v1/table/${this._name}/insert/`,
buffer,
@@ -226,15 +288,23 @@ export class RemoteTable<T = number[]> implements Table<T> {
'application/vnd.apache.arrow.stream'
)
if (res.status !== 200) {
throw new Error(`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`)
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
)
}
return data.length
return tbl.numRows
}
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
const buffer = await fromRecordsToStreamBuffer(data, this._embeddings)
async overwrite (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> {
let tbl: ArrowTable
if (data instanceof ArrowTable) {
tbl = data
} else {
tbl = makeArrowTable(data)
}
const buffer = await fromTableToStreamBuffer(tbl, this._embeddings)
const res = await this._client.post(
`/v1/table/${this._name}/insert/`,
buffer,
@@ -244,11 +314,13 @@ export class RemoteTable<T = number[]> implements Table<T> {
'application/vnd.apache.arrow.stream'
)
if (res.status !== 200) {
throw new Error(`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`)
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
)
}
return data.length
return tbl.numRows
}
async createIndex (indexParams: VectorIndexParams): Promise<void> {
@@ -280,11 +352,16 @@ export class RemoteTable<T = number[]> implements Table<T> {
metric_type: metricType,
index_cache_size: indexCacheSize
}
const res = await this._client.post(`/v1/table/${this._name}/create_index/`, data)
const res = await this._client.post(
`/v1/table/${this._name}/create_index/`,
data
)
if (res.status !== 200) {
throw new Error(`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`)
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
)
}
}
@@ -298,7 +375,9 @@ export class RemoteTable<T = number[]> implements Table<T> {
}
async delete (filter: string): Promise<void> {
await this._client.post(`/v1/table/${this._name}/delete/`, { predicate: filter })
await this._client.post(`/v1/table/${this._name}/delete/`, {
predicate: filter
})
}
async update (args: UpdateArgs | UpdateSqlArgs): Promise<void> {
@@ -322,7 +401,9 @@ export class RemoteTable<T = number[]> implements Table<T> {
}
async listIndices (): Promise<VectorIndex[]> {
const results = await this._client.post(`/v1/table/${this._name}/index/list/`)
const results = await this._client.post(
`/v1/table/${this._name}/index/list/`
)
return results.data.indexes?.map((index: any) => ({
columns: index.columns,
name: index.index_name,
@@ -331,7 +412,9 @@ export class RemoteTable<T = number[]> implements Table<T> {
}
async indexStats (indexUuid: string): Promise<IndexStats> {
const results = await this._client.post(`/v1/table/${this._name}/index/${indexUuid}/stats/`)
const results = await this._client.post(
`/v1/table/${this._name}/index/${indexUuid}/stats/`
)
return {
numIndexedRows: results.data.num_indexed_rows,
numUnindexedRows: results.data.num_unindexed_rows