mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 10:52:56 +00:00
feat: change create table to accept Arrow table (#845)
This commit is contained in:
@@ -13,18 +13,29 @@
|
||||
// limitations under the License.
|
||||
|
||||
import {
|
||||
type EmbeddingFunction, type Table, type VectorIndexParams, type Connection,
|
||||
type ConnectionOptions, type CreateTableOptions, type VectorIndex,
|
||||
type EmbeddingFunction,
|
||||
type Table,
|
||||
type VectorIndexParams,
|
||||
type Connection,
|
||||
type ConnectionOptions,
|
||||
type CreateTableOptions,
|
||||
type VectorIndex,
|
||||
type WriteOptions,
|
||||
type IndexStats,
|
||||
type UpdateArgs, type UpdateSqlArgs
|
||||
type UpdateArgs,
|
||||
type UpdateSqlArgs,
|
||||
makeArrowTable
|
||||
} from '../index'
|
||||
import { Query } from '../query'
|
||||
|
||||
import { Vector, Table as ArrowTable } from 'apache-arrow'
|
||||
import { HttpLancedbClient } from './client'
|
||||
import { isEmbeddingFunction } from '../embedding/embedding_function'
|
||||
import { createEmptyTable, fromRecordsToStreamBuffer, fromTableToStreamBuffer } from '../arrow'
|
||||
import {
|
||||
createEmptyTable,
|
||||
fromRecordsToStreamBuffer,
|
||||
fromTableToStreamBuffer
|
||||
} from '../arrow'
|
||||
import { toSQL } from '../util'
|
||||
|
||||
/**
|
||||
@@ -54,7 +65,11 @@ export class RemoteConnection implements Connection {
|
||||
} else {
|
||||
server = opts.hostOverride
|
||||
}
|
||||
this._client = new HttpLancedbClient(server, opts.apiKey, opts.hostOverride === undefined ? undefined : this._dbName)
|
||||
this._client = new HttpLancedbClient(
|
||||
server,
|
||||
opts.apiKey,
|
||||
opts.hostOverride === undefined ? undefined : this._dbName
|
||||
)
|
||||
}
|
||||
|
||||
get uri (): string {
|
||||
@@ -62,14 +77,26 @@ export class RemoteConnection implements Connection {
|
||||
return 'db://' + this._client.uri
|
||||
}
|
||||
|
||||
async tableNames (pageToken: string = '', limit: number = 10): Promise<string[]> {
|
||||
const response = await this._client.get('/v1/table/', { limit, page_token: pageToken })
|
||||
async tableNames (
|
||||
pageToken: string = '',
|
||||
limit: number = 10
|
||||
): Promise<string[]> {
|
||||
const response = await this._client.get('/v1/table/', {
|
||||
limit,
|
||||
page_token: pageToken
|
||||
})
|
||||
return response.data.tables
|
||||
}
|
||||
|
||||
async openTable (name: string): Promise<Table>
|
||||
async openTable<T> (name: string, embeddings: EmbeddingFunction<T>): Promise<Table<T>>
|
||||
async openTable<T> (name: string, embeddings?: EmbeddingFunction<T>): Promise<Table<T>> {
|
||||
async openTable<T>(
|
||||
name: string,
|
||||
embeddings: EmbeddingFunction<T>
|
||||
): Promise<Table<T>>
|
||||
async openTable<T>(
|
||||
name: string,
|
||||
embeddings?: EmbeddingFunction<T>
|
||||
): Promise<Table<T>> {
|
||||
if (embeddings !== undefined) {
|
||||
return new RemoteTable(this._client, name, embeddings)
|
||||
} else {
|
||||
@@ -77,13 +104,21 @@ export class RemoteConnection implements Connection {
|
||||
}
|
||||
}
|
||||
|
||||
async createTable<T> (nameOrOpts: string | CreateTableOptions<T>, data?: Array<Record<string, unknown>>, optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>, opt?: WriteOptions): Promise<Table<T>> {
|
||||
async createTable<T>(
|
||||
nameOrOpts: string | CreateTableOptions<T>,
|
||||
data?: Array<Record<string, unknown>> | ArrowTable,
|
||||
optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>,
|
||||
opt?: WriteOptions
|
||||
): Promise<Table<T>> {
|
||||
// Logic copied from LocatlConnection, refactor these to a base class + connectionImpl pattern
|
||||
let schema
|
||||
let embeddings: undefined | EmbeddingFunction<T>
|
||||
let tableName: string
|
||||
if (typeof nameOrOpts === 'string') {
|
||||
if (optsOrEmbedding !== undefined && isEmbeddingFunction(optsOrEmbedding)) {
|
||||
if (
|
||||
optsOrEmbedding !== undefined &&
|
||||
isEmbeddingFunction(optsOrEmbedding)
|
||||
) {
|
||||
embeddings = optsOrEmbedding
|
||||
}
|
||||
tableName = nameOrOpts
|
||||
@@ -95,14 +130,16 @@ export class RemoteConnection implements Connection {
|
||||
|
||||
let buffer: Buffer
|
||||
|
||||
function isEmpty (data: Array<Record<string, unknown>> | ArrowTable<any>): boolean {
|
||||
function isEmpty (
|
||||
data: Array<Record<string, unknown>> | ArrowTable<any>
|
||||
): boolean {
|
||||
if (data instanceof ArrowTable) {
|
||||
return data.data.length === 0
|
||||
return data.numRows === 0
|
||||
}
|
||||
return data.length === 0
|
||||
}
|
||||
|
||||
if ((data === undefined) || isEmpty(data)) {
|
||||
if (data === undefined || isEmpty(data)) {
|
||||
if (schema === undefined) {
|
||||
throw new Error('Either data or schema needs to defined')
|
||||
}
|
||||
@@ -121,9 +158,11 @@ export class RemoteConnection implements Connection {
|
||||
'application/vnd.apache.arrow.stream'
|
||||
)
|
||||
if (res.status !== 200) {
|
||||
throw new Error(`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`)
|
||||
throw new Error(
|
||||
`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`
|
||||
)
|
||||
}
|
||||
|
||||
if (embeddings === undefined) {
|
||||
@@ -139,8 +178,12 @@ export class RemoteConnection implements Connection {
|
||||
}
|
||||
|
||||
export class RemoteQuery<T = number[]> extends Query<T> {
|
||||
constructor (query: T, private readonly _client: HttpLancedbClient,
|
||||
private readonly _name: string, embeddings?: EmbeddingFunction<T>) {
|
||||
constructor (
|
||||
query: T,
|
||||
private readonly _client: HttpLancedbClient,
|
||||
private readonly _name: string,
|
||||
embeddings?: EmbeddingFunction<T>
|
||||
) {
|
||||
super(query, undefined, embeddings)
|
||||
}
|
||||
|
||||
@@ -189,8 +232,16 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
private readonly _name: string
|
||||
|
||||
constructor (client: HttpLancedbClient, name: string)
|
||||
constructor (client: HttpLancedbClient, name: string, embeddings: EmbeddingFunction<T>)
|
||||
constructor (client: HttpLancedbClient, name: string, embeddings?: EmbeddingFunction<T>) {
|
||||
constructor (
|
||||
client: HttpLancedbClient,
|
||||
name: string,
|
||||
embeddings: EmbeddingFunction<T>
|
||||
)
|
||||
constructor (
|
||||
client: HttpLancedbClient,
|
||||
name: string,
|
||||
embeddings?: EmbeddingFunction<T>
|
||||
) {
|
||||
this._client = client
|
||||
this._name = name
|
||||
this._embeddings = embeddings
|
||||
@@ -201,22 +252,33 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
}
|
||||
|
||||
get schema (): Promise<any> {
|
||||
return this._client.post(`/v1/table/${this._name}/describe/`).then(res => {
|
||||
if (res.status !== 200) {
|
||||
throw new Error(`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`)
|
||||
}
|
||||
return res.data?.schema
|
||||
})
|
||||
return this._client
|
||||
.post(`/v1/table/${this._name}/describe/`)
|
||||
.then((res) => {
|
||||
if (res.status !== 200) {
|
||||
throw new Error(
|
||||
`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`
|
||||
)
|
||||
}
|
||||
return res.data?.schema
|
||||
})
|
||||
}
|
||||
|
||||
search (query: T): Query<T> {
|
||||
return new RemoteQuery(query, this._client, this._name)//, this._embeddings_new)
|
||||
return new RemoteQuery(query, this._client, this._name) //, this._embeddings_new)
|
||||
}
|
||||
|
||||
async add (data: Array<Record<string, unknown>>): Promise<number> {
|
||||
const buffer = await fromRecordsToStreamBuffer(data, this._embeddings)
|
||||
async add (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> {
|
||||
let tbl: ArrowTable
|
||||
if (data instanceof ArrowTable) {
|
||||
tbl = data
|
||||
} else {
|
||||
tbl = makeArrowTable(data, await this.schema)
|
||||
}
|
||||
|
||||
const buffer = await fromTableToStreamBuffer(tbl, this._embeddings)
|
||||
const res = await this._client.post(
|
||||
`/v1/table/${this._name}/insert/`,
|
||||
buffer,
|
||||
@@ -226,15 +288,23 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
'application/vnd.apache.arrow.stream'
|
||||
)
|
||||
if (res.status !== 200) {
|
||||
throw new Error(`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`)
|
||||
throw new Error(
|
||||
`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`
|
||||
)
|
||||
}
|
||||
return data.length
|
||||
return tbl.numRows
|
||||
}
|
||||
|
||||
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
|
||||
const buffer = await fromRecordsToStreamBuffer(data, this._embeddings)
|
||||
async overwrite (data: Array<Record<string, unknown>> | ArrowTable): Promise<number> {
|
||||
let tbl: ArrowTable
|
||||
if (data instanceof ArrowTable) {
|
||||
tbl = data
|
||||
} else {
|
||||
tbl = makeArrowTable(data)
|
||||
}
|
||||
const buffer = await fromTableToStreamBuffer(tbl, this._embeddings)
|
||||
const res = await this._client.post(
|
||||
`/v1/table/${this._name}/insert/`,
|
||||
buffer,
|
||||
@@ -244,11 +314,13 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
'application/vnd.apache.arrow.stream'
|
||||
)
|
||||
if (res.status !== 200) {
|
||||
throw new Error(`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`)
|
||||
throw new Error(
|
||||
`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`
|
||||
)
|
||||
}
|
||||
return data.length
|
||||
return tbl.numRows
|
||||
}
|
||||
|
||||
async createIndex (indexParams: VectorIndexParams): Promise<void> {
|
||||
@@ -280,11 +352,16 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
metric_type: metricType,
|
||||
index_cache_size: indexCacheSize
|
||||
}
|
||||
const res = await this._client.post(`/v1/table/${this._name}/create_index/`, data)
|
||||
const res = await this._client.post(
|
||||
`/v1/table/${this._name}/create_index/`,
|
||||
data
|
||||
)
|
||||
if (res.status !== 200) {
|
||||
throw new Error(`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`)
|
||||
throw new Error(
|
||||
`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -298,7 +375,9 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
}
|
||||
|
||||
async delete (filter: string): Promise<void> {
|
||||
await this._client.post(`/v1/table/${this._name}/delete/`, { predicate: filter })
|
||||
await this._client.post(`/v1/table/${this._name}/delete/`, {
|
||||
predicate: filter
|
||||
})
|
||||
}
|
||||
|
||||
async update (args: UpdateArgs | UpdateSqlArgs): Promise<void> {
|
||||
@@ -322,7 +401,9 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
}
|
||||
|
||||
async listIndices (): Promise<VectorIndex[]> {
|
||||
const results = await this._client.post(`/v1/table/${this._name}/index/list/`)
|
||||
const results = await this._client.post(
|
||||
`/v1/table/${this._name}/index/list/`
|
||||
)
|
||||
return results.data.indexes?.map((index: any) => ({
|
||||
columns: index.columns,
|
||||
name: index.index_name,
|
||||
@@ -331,7 +412,9 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
}
|
||||
|
||||
async indexStats (indexUuid: string): Promise<IndexStats> {
|
||||
const results = await this._client.post(`/v1/table/${this._name}/index/${indexUuid}/stats/`)
|
||||
const results = await this._client.post(
|
||||
`/v1/table/${this._name}/index/${indexUuid}/stats/`
|
||||
)
|
||||
return {
|
||||
numIndexedRows: results.data.num_indexed_rows,
|
||||
numUnindexedRows: results.data.num_unindexed_rows
|
||||
|
||||
Reference in New Issue
Block a user