diff --git a/node/src/arrow.ts b/node/src/arrow.ts index ebbc6b24..5421654a 100644 --- a/node/src/arrow.ts +++ b/node/src/arrow.ts @@ -20,7 +20,7 @@ import { Utf8, type Vector, FixedSizeList, - vectorFromArray, type Schema, Table as ArrowTable + vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter } from 'apache-arrow' import { type EmbeddingFunction } from './index' @@ -77,7 +77,9 @@ function newVectorBuilder (dim: number): FixedSizeListBuilder { // Creates the Arrow Type for a Vector column with dimension `dim` function newVectorType (dim: number): FixedSizeList { - const children = new Field('item', new Float32()) + // Somewhere we always default to have the elements nullable, so we need to set it to true + // otherwise we often get schema mismatches because the stored data always has schema with nullable elements + const children = new Field('item', new Float32(), true) return new FixedSizeList(dim, children) } @@ -88,6 +90,13 @@ export async function fromRecordsToBuffer (data: Array (data: Array>, embeddings?: EmbeddingFunction): Promise { + const table = await convertToTable(data, embeddings) + const writer = RecordBatchStreamWriter.writeAll(table) + return Buffer.from(await writer.toUint8Array()) +} + // Converts an Arrow Table into Arrow IPC format export async function fromTableToBuffer (table: ArrowTable, embeddings?: EmbeddingFunction): Promise { if (embeddings !== undefined) { @@ -105,6 +114,23 @@ export async function fromTableToBuffer (table: ArrowTable, embeddings?: Embe return Buffer.from(await writer.toUint8Array()) } +// Converts an Arrow Table into Arrow IPC stream format +export async function fromTableToStreamBuffer (table: ArrowTable, embeddings?: EmbeddingFunction): Promise { + if (embeddings !== undefined) { + const source = table.getChild(embeddings.sourceColumn) + + if (source === null) { + throw new Error(`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`) + } + + const vectors 
= await embeddings.embed(source.toArray() as T[]) + const column = vectorFromArray(vectors, newVectorType(vectors[0].length)) + table = table.assign(new ArrowTable({ vector: column })) + } + const writer = RecordBatchStreamWriter.writeAll(table) + return Buffer.from(await writer.toUint8Array()) +} + // Creates an empty Arrow Table export function createEmptyTable (schema: Schema): ArrowTable { return new ArrowTable(schema) diff --git a/node/src/remote/client.ts b/node/src/remote/client.ts index 3bcbd421..5911cd48 100644 --- a/node/src/remote/client.ts +++ b/node/src/remote/client.ts @@ -108,13 +108,18 @@ export class HttpLancedbClient { /** * Sent POST request. */ - public async post (path: string, data?: any, params?: Record): Promise { + public async post ( + path: string, + data?: any, + params?: Record, + content?: string | undefined + ): Promise { const response = await axios.post( `${this._url}${path}`, data, { headers: { - 'Content-Type': 'application/json', + 'Content-Type': content ?? 'application/json', 'x-api-key': this._apiKey(), ...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {}) }, diff --git a/node/src/remote/index.ts b/node/src/remote/index.ts index 92219302..abb3d8f0 100644 --- a/node/src/remote/index.ts +++ b/node/src/remote/index.ts @@ -18,8 +18,10 @@ import { } from '../index' import { Query } from '../query' -import { Vector } from 'apache-arrow' +import { Vector, Table as ArrowTable } from 'apache-arrow' import { HttpLancedbClient } from './client' +import { isEmbeddingFunction } from '../embedding/embedding_function' +import { createEmptyTable, fromRecordsToStreamBuffer, fromTableToStreamBuffer } from '../arrow' /** * Remote connection. 
@@ -66,8 +68,60 @@ export class RemoteConnection implements Connection { } } - async createTable (name: string | CreateTableOptions, data?: Array>, optsOrEmbedding?: WriteOptions | EmbeddingFunction, opt?: WriteOptions): Promise> { - throw new Error('Not implemented') + async createTable (nameOrOpts: string | CreateTableOptions, data?: Array>, optsOrEmbedding?: WriteOptions | EmbeddingFunction, opt?: WriteOptions): Promise> { + // Logic copied from LocalConnection; refactor these into a base class + connectionImpl pattern + let schema + let embeddings: undefined | EmbeddingFunction + let tableName: string + if (typeof nameOrOpts === 'string') { + if (optsOrEmbedding !== undefined && isEmbeddingFunction(optsOrEmbedding)) { + embeddings = optsOrEmbedding + } + tableName = nameOrOpts + } else { + schema = nameOrOpts.schema + embeddings = nameOrOpts.embeddingFunction + tableName = nameOrOpts.name + } + + let buffer: Buffer + + function isEmpty (data: Array> | ArrowTable): boolean { + if (data instanceof ArrowTable) { + return data.data.length === 0 + } + return data.length === 0 + } + + if ((data === undefined) || isEmpty(data)) { + if (schema === undefined) { + throw new Error('Either data or schema needs to defined') + } + buffer = await fromTableToStreamBuffer(createEmptyTable(schema)) + } else if (data instanceof ArrowTable) { + buffer = await fromTableToStreamBuffer(data, embeddings) + } else { + // data is Array> + buffer = await fromRecordsToStreamBuffer(data, embeddings) + } + + const res = await this._client.post( + `/v1/table/${tableName}/create/`, + buffer, + undefined, + 'application/vnd.apache.arrow.stream' + ) + if (res.status !== 200) { + throw new Error(`Server Error, status: ${res.status}, ` + + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions + `message: ${res.statusText}: ${res.data}`) + } + + if (embeddings === undefined) { + return new RemoteTable(this._client, tableName) + } else { + return new RemoteTable(this._client, 
tableName, embeddings) + } } async dropTable (name: string): Promise { @@ -141,11 +195,39 @@ export class RemoteTable implements Table { } async add (data: Array>): Promise { - throw new Error('Not implemented') + const buffer = await fromRecordsToStreamBuffer(data, this._embeddings) + const res = await this._client.post( + `/v1/table/${this._name}/insert/`, + buffer, + { + mode: 'append' + }, + 'application/vnd.apache.arrow.stream' + ) + if (res.status !== 200) { + throw new Error(`Server Error, status: ${res.status}, ` + + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions + `message: ${res.statusText}: ${res.data}`) + } + return data.length } async overwrite (data: Array>): Promise { - throw new Error('Not implemented') + const buffer = await fromRecordsToStreamBuffer(data, this._embeddings) + const res = await this._client.post( + `/v1/table/${this._name}/insert/`, + buffer, + { + mode: 'overwrite' + }, + 'application/vnd.apache.arrow.stream' + ) + if (res.status !== 200) { + throw new Error(`Server Error, status: ${res.status}, ` + + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions + `message: ${res.statusText}: ${res.data}`) + } + return data.length } async createIndex (indexParams: VectorIndexParams): Promise {