From d8befeeea2d35afd03644f8f71cf8716cffbe73a Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 22 Jan 2024 11:49:44 -0800 Subject: [PATCH] feat(js): add helper function to create Arrow Table with schema (#838) Support to make Apache Arrow Table from an array of javascript Records, with optionally provided Schema. --- node/src/arrow.ts | 207 +++++++++++++++++++++++++++++++++--- node/src/index.ts | 8 +- node/src/test/arrow.test.ts | 108 +++++++++++++++++++ node/tsconfig.json | 12 ++- 4 files changed, 315 insertions(+), 20 deletions(-) create mode 100644 node/src/test/arrow.test.ts diff --git a/node/src/arrow.ts b/node/src/arrow.ts index 2010c220..09f83116 100644 --- a/node/src/arrow.ts +++ b/node/src/arrow.ts @@ -13,18 +13,168 @@ // limitations under the License. import { - Field, type FixedSizeListBuilder, + Field, + type FixedSizeListBuilder, Float32, makeBuilder, RecordBatchFileWriter, - Utf8, type Vector, + Utf8, + type Vector, FixedSizeList, - vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter, List, Float64, RecordBatch, makeData, Struct + vectorFromArray, + type Schema, + Table as ArrowTable, + RecordBatchStreamWriter, + List, + Float64, + RecordBatch, + makeData, + Struct, + type Float } from 'apache-arrow' import { type EmbeddingFunction } from './index' +export class VectorColumnOptions { + /** Vector column type. */ + type: Float = new Float32() + + constructor (values?: Partial) { + Object.assign(this, values) + } +} + +/** Options to control the makeArrowTable call. */ +export class MakeArrowTableOptions { + /** Provided schema. */ + schema?: Schema + + /** Vector columns */ + vectorColumns: Record = { + vector: new VectorColumnOptions() + } + + constructor (values?: Partial) { + Object.assign(this, values) + } +} + +/** + * An enhanced version of the {@link makeTable} function from Apache Arrow + * that supports nested fields and embeddings columns. + * + * Note that it currently does not support nulls. + * + * @param data input data + * @param options options to control the makeArrowTable call. + * + * @example + * + * ```ts + * + * import { fromTableToBuffer, makeArrowTable } from "../arrow"; + * import { Field, FixedSizeList, Float16, Float32, Int32, Schema } from "apache-arrow"; + * + * const schema = new Schema([ + * new Field("a", new Int32()), + * new Field("b", new Float32()), + * new Field("c", new FixedSizeList(3, new Field("item", new Float16()))), + * ]); + * const table = makeArrowTable([ + * { a: 1, b: 2, c: [1, 2, 3] }, + * { a: 4, b: 5, c: [4, 5, 6] }, + * { a: 7, b: 8, c: [7, 8, 9] }, + * ], { schema }); + * ``` + * + * It guesses the vector columns if the schema is not provided. For example, + * by default it assumes that the column named `vector` is a vector column. + * + * ```ts + * + * const schema = new Schema([ + new Field("a", new Float64()), + new Field("b", new Float64()), + new Field( + "vector", + new FixedSizeList(3, new Field("item", new Float32())) + ), + ]); + const table = makeArrowTable([ + { a: 1, b: 2, vector: [1, 2, 3] }, + { a: 4, b: 5, vector: [4, 5, 6] }, + { a: 7, b: 8, vector: [7, 8, 9] }, + ]); + assert.deepEqual(table.schema, schema); + * ``` + * + * You can specify the vector column types and names using the options as well + * + * ```typescript + * + * const schema = new Schema([ + new Field('a', new Float64()), + new Field('b', new Float64()), + new Field('vec1', new FixedSizeList(3, new Field('item', new Float16()))), + new Field('vec2', new FixedSizeList(3, new Field('item', new Float16()))) + ]); + * const table = makeArrowTable([ + { a: 1, b: 2, vec1: [1, 2, 3], vec2: [2, 4, 6] }, + { a: 4, b: 5, vec1: [4, 5, 6], vec2: [8, 10, 12] }, + { a: 7, b: 8, vec1: [7, 8, 9], vec2: [14, 16, 18] } + ], { + vectorColumns: { + vec1: { type: new Float16() }, + vec2: { type: new Float16() } + } + } + * assert.deepEqual(table.schema, schema) + * ``` + */ +export function makeArrowTable ( + data: Array>, + options?: Partial +): ArrowTable { + if (data.length === 0) { + throw new Error('At least one record needs to be provided') + } + const opt = new MakeArrowTableOptions(options !== undefined ? options : {}) + const columns: Record = {} + // TODO: sample dataset to find missing columns + const columnNames = Object.keys(data[0]) + for (const colName of columnNames) { + const values = data.map((datum) => datum[colName]) + let vector: Vector + + if (opt.schema !== undefined) { + // Explicit schema is provided, highest priority + vector = vectorFromArray( + values, + opt.schema?.fields.filter((f) => f.name === colName)[0]?.type + ) + } else { + const vectorColumnOptions = opt.vectorColumns[colName] + if (vectorColumnOptions !== undefined) { + const fslType = new FixedSizeList( + values[0].length, + new Field('item', vectorColumnOptions.type, false) + ) + vector = vectorFromArray(values, fslType) + } else { + // Normal case + vector = vectorFromArray(values) + } + } + columns[colName] = vector + } + + return new ArrowTable(columns) +} + // Converts an Array of records into an Arrow Table, optionally applying an embeddings function to it. -export async function convertToTable (data: Array>, embeddings?: EmbeddingFunction): Promise { +export async function convertToTable ( + data: Array>, + embeddings?: EmbeddingFunction +): Promise { if (data.length === 0) { throw new Error('At least one record needs to be provided') } @@ -52,7 +202,10 @@ export async function convertToTable (data: Array>, e if (columnsKey === embeddings?.sourceColumn) { const vectors = await embeddings.embed(values as T[]) - records.vector = vectorFromArray(vectors, newVectorType(vectors[0].length)) + records.vector = vectorFromArray( + vectors, + newVectorType(vectors[0].length) + ) } if (typeof values[0] === 'string') { @@ -110,7 +263,11 @@ function newVectorType (dim: number): FixedSizeList { } // Converts an Array of records into Arrow IPC format -export async function fromRecordsToBuffer (data: Array>, embeddings?: EmbeddingFunction, schema?: Schema): Promise { +export async function fromRecordsToBuffer ( + data: Array>, + embeddings?: EmbeddingFunction, + schema?: Schema +): Promise { let table = await convertToTable(data, embeddings) if (schema !== undefined) { table = alignTable(table, schema) @@ -120,7 +277,11 @@ export async function fromRecordsToBuffer (data: Array (data: Array>, embeddings?: EmbeddingFunction, schema?: Schema): Promise { +export async function fromRecordsToStreamBuffer ( + data: Array>, + embeddings?: EmbeddingFunction, + schema?: Schema +): Promise { let table = await convertToTable(data, embeddings) if (schema !== undefined) { table = alignTable(table, schema) @@ -130,12 +291,18 @@ export async function fromRecordsToStreamBuffer (data: Array (table: ArrowTable, embeddings?: EmbeddingFunction, schema?: Schema): Promise { +export async function fromTableToBuffer ( + table: ArrowTable, + embeddings?: EmbeddingFunction, + schema?: Schema +): Promise { if (embeddings !== undefined) { const source = table.getChild(embeddings.sourceColumn) if (source === null) { - throw new Error(`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`) + throw new Error( + `The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table` + ) } const vectors = await embeddings.embed(source.toArray() as T[]) @@ -150,12 +317,18 @@ export async function fromTableToBuffer (table: ArrowTable, embeddings?: Embe } // Converts an Arrow Table into Arrow IPC stream format -export async function fromTableToStreamBuffer (table: ArrowTable, embeddings?: EmbeddingFunction, schema?: Schema): Promise { +export async function fromTableToStreamBuffer ( + table: ArrowTable, + embeddings?: EmbeddingFunction, + schema?: Schema +): Promise { if (embeddings !== undefined) { const source = table.getChild(embeddings.sourceColumn) if (source === null) { - throw new Error(`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`) + throw new Error( + `The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table` + ) } const vectors = await embeddings.embed(source.toArray() as T[]) @@ -172,9 +345,13 @@ export async function fromTableToStreamBuffer (table: ArrowTable, embeddings? function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch { const alignedChildren = [] for (const field of schema.fields) { - const indexInBatch = batch.schema.fields?.findIndex((f) => f.name === field.name) + const indexInBatch = batch.schema.fields?.findIndex( + (f) => f.name === field.name + ) if (indexInBatch < 0) { - throw new Error(`The column ${field.name} was not found in the Arrow Table`) + throw new Error( + `The column ${field.name} was not found in the Arrow Table` + ) } alignedChildren.push(batch.data.children[indexInBatch]) } @@ -188,7 +365,9 @@ function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch { } function alignTable (table: ArrowTable, schema: Schema): ArrowTable { - const alignedBatches = table.batches.map(batch => alignBatch(batch, schema)) + const alignedBatches = table.batches.map((batch) => + alignBatch(batch, schema) + ) return new ArrowTable(schema, alignedBatches) } diff --git a/node/src/index.ts b/node/src/index.ts index 469e4356..d5fa7056 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -41,12 +41,13 @@ const { tableListIndices, tableIndexStats, tableSchema -// eslint-disable-next-line @typescript-eslint/no-var-requires + // eslint-disable-next-line @typescript-eslint/no-var-requires } = require('../native.js') export { Query } export type { EmbeddingFunction } export { OpenAIEmbeddingFunction } from './embedding/openai' +export { makeArrowTable, type MakeArrowTableOptions } from './arrow' const defaultAwsRegion = 'us-west-2' @@ -859,7 +860,10 @@ export class LocalTable implements Table { private checkElectron (): boolean { try { // eslint-disable-next-line no-prototype-builtins - return (process?.versions?.hasOwnProperty('electron') || navigator?.userAgent?.toLowerCase()?.includes(' electron')) + return ( + Object.prototype.hasOwnProperty.call(process?.versions, 'electron') || + navigator?.userAgent?.toLowerCase()?.includes(' electron') + ) } catch (e) { return false } diff --git a/node/src/test/arrow.test.ts b/node/src/test/arrow.test.ts new file mode 100644 index 00000000..9be44377 --- /dev/null +++ b/node/src/test/arrow.test.ts @@ -0,0 +1,108 @@ +// Copyright 2024 Lance Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import { describe } from 'mocha' +import { assert } from 'chai' + +import { fromTableToBuffer, makeArrowTable } from '../arrow' +import { + Field, + FixedSizeList, + Float16, + Float32, + Int32, + tableFromIPC, + Schema, + Float64 +} from 'apache-arrow' + +describe('Apache Arrow tables', function () { + it('customized schema', async function () { + const schema = new Schema([ + new Field('a', new Int32()), + new Field('b', new Float32()), + new Field('c', new FixedSizeList(3, new Field('item', new Float16()))) + ]) + const table = makeArrowTable( + [ + { a: 1, b: 2, c: [1, 2, 3] }, + { a: 4, b: 5, c: [4, 5, 6] }, + { a: 7, b: 8, c: [7, 8, 9] } + ], + { schema } + ) + + const buf = await fromTableToBuffer(table) + assert.isAbove(buf.byteLength, 0) + + const actual = tableFromIPC(buf) + assert.equal(actual.numRows, 3) + const actualSchema = actual.schema + assert.deepEqual(actualSchema, schema) + }) + + it('default vector column', async function () { + const schema = new Schema([ + new Field('a', new Float64()), + new Field('b', new Float64()), + new Field( + 'vector', + new FixedSizeList(3, new Field('item', new Float32())) + ) + ]) + const table = makeArrowTable([ + { a: 1, b: 2, vector: [1, 2, 3] }, + { a: 4, b: 5, vector: [4, 5, 6] }, + { a: 7, b: 8, vector: [7, 8, 9] } + ]) + + const buf = await fromTableToBuffer(table) + assert.isAbove(buf.byteLength, 0) + + const actual = tableFromIPC(buf) + assert.equal(actual.numRows, 3) + const actualSchema = actual.schema + assert.deepEqual(actualSchema, schema) + }) + + it('2 vector columns', async function () { + const schema = new Schema([ + new Field('a', new Float64()), + new Field('b', new Float64()), + new Field('vec1', new FixedSizeList(3, new Field('item', new Float16()))), + new Field('vec2', new FixedSizeList(3, new Field('item', new Float16()))) + ]) + const table = makeArrowTable( + [ + { a: 1, b: 2, vec1: [1, 2, 3], vec2: [2, 4, 6] }, + { a: 4, b: 5, vec1: [4, 5, 6], vec2: [8, 10, 12] }, + { a: 7, b: 8, vec1: [7, 8, 9], vec2: [14, 16, 18] } + ], + { + vectorColumns: { + vec1: { type: new Float16() }, + vec2: { type: new Float16() } + } + } + ) + + const buf = await fromTableToBuffer(table) + assert.isAbove(buf.byteLength, 0) + + const actual = tableFromIPC(buf) + assert.equal(actual.numRows, 3) + const actualSchema = actual.schema + assert.deepEqual(actualSchema, schema) + }) +}) diff --git a/node/tsconfig.json b/node/tsconfig.json index a3fe259b..abb6b947 100644 --- a/node/tsconfig.json +++ b/node/tsconfig.json @@ -1,10 +1,14 @@ { - "include": ["src/**/*.ts"], + "include": [ + "src/**/*.ts", + "src/*.ts" + ], "compilerOptions": { - "target": "es2016", + "target": "ES2020", "module": "commonjs", "declaration": true, "outDir": "./dist", - "strict": true + "strict": true, + // "esModuleInterop": true, } -} +} \ No newline at end of file