feat: change create table to accept Arrow table (#845)

This commit is contained in:
Lei Xu
2024-01-23 13:25:15 -08:00
committed by Weston Pace
parent 5ecbf971e2
commit 65c1d8bc4c
5 changed files with 586 additions and 160 deletions

View File

@@ -16,7 +16,8 @@ import { type Schema, Table as ArrowTable, tableFromIPC } from 'apache-arrow'
import {
createEmptyTable,
fromRecordsToBuffer,
fromTableToBuffer
fromTableToBuffer,
makeArrowTable
} from './arrow'
import type { EmbeddingFunction } from './embedding/embedding_function'
import { RemoteConnection } from './remote'
@@ -223,7 +224,7 @@ export interface Connection {
*/
createTable(
name: string,
data: Array<Record<string, unknown>>
data: Array<Record<string, unknown>> | ArrowTable
): Promise<Table>
/**
@@ -235,7 +236,7 @@ export interface Connection {
*/
createTable(
name: string,
data: Array<Record<string, unknown>>,
data: Array<Record<string, unknown>> | ArrowTable,
options: WriteOptions
): Promise<Table>
@@ -248,7 +249,7 @@ export interface Connection {
*/
createTable<T>(
name: string,
data: Array<Record<string, unknown>>,
data: Array<Record<string, unknown>> | ArrowTable,
embeddings: EmbeddingFunction<T>
): Promise<Table<T>>
/**
@@ -261,7 +262,7 @@ export interface Connection {
*/
createTable<T>(
name: string,
data: Array<Record<string, unknown>>,
data: Array<Record<string, unknown>> | ArrowTable,
embeddings: EmbeddingFunction<T>,
options: WriteOptions
): Promise<Table<T>>
@@ -291,7 +292,7 @@ export interface Table<T = number[]> {
* @param data Records, or an Arrow table, to be inserted into the Table
* @return The number of rows added to the table
*/
add: (data: Array<Record<string, unknown>>) => Promise<number>
add: (data: Array<Record<string, unknown>> | ArrowTable) => Promise<number>
/**
* Insert records into this Table, replacing its contents.
@@ -299,7 +300,9 @@ export interface Table<T = number[]> {
* @param data Records, or an Arrow table, to be inserted into the Table
* @return The number of rows added to the table
*/
overwrite: (data: Array<Record<string, unknown>>) => Promise<number>
overwrite: (
data: Array<Record<string, unknown>> | ArrowTable
) => Promise<number>
/**
* Create an ANN vector index on this Table.
@@ -544,7 +547,7 @@ export class LocalConnection implements Connection {
async createTable<T>(
name: string | CreateTableOptions<T>,
data?: Array<Record<string, unknown>>,
data?: Array<Record<string, unknown>> | ArrowTable,
optsOrEmbedding?: WriteOptions | EmbeddingFunction<T>,
opt?: WriteOptions
): Promise<Table<T>> {
@@ -696,12 +699,20 @@ export class LocalTable<T = number[]> implements Table<T> {
* @param data Records, or an Arrow table, to be inserted into the Table
* @return The number of rows added to the table
*/
async add (data: Array<Record<string, unknown>>): Promise<number> {
async add (
data: Array<Record<string, unknown>> | ArrowTable
): Promise<number> {
const schema = await this.schema
let tbl: ArrowTable
if (data instanceof ArrowTable) {
tbl = data
} else {
tbl = makeArrowTable(data, { schema })
}
return tableAdd
.call(
this._tbl,
await fromRecordsToBuffer(data, this._embeddings, schema),
await fromTableToBuffer(tbl, this._embeddings, schema),
WriteMode.Append.toString(),
...getAwsArgs(this._options())
)
@@ -716,11 +727,19 @@ export class LocalTable<T = number[]> implements Table<T> {
* @param data Records, or an Arrow table, to be inserted into the Table
* @return The number of rows added to the table
*/
async overwrite (data: Array<Record<string, unknown>>): Promise<number> {
async overwrite (
data: Array<Record<string, unknown>> | ArrowTable
): Promise<number> {
let buffer: Buffer
if (data instanceof ArrowTable) {
buffer = await fromTableToBuffer(data, this._embeddings)
} else {
buffer = await fromRecordsToBuffer(data, this._embeddings)
}
return tableAdd
.call(
this._tbl,
await fromRecordsToBuffer(data, this._embeddings),
buffer,
WriteMode.Overwrite.toString(),
...getAwsArgs(this._options())
)