feat(js): add helper function to create Arrow Table with schema (#838)

Support to make Apache Arrow Table from an array of javascript Records,
with optionally provided Schema.
This commit is contained in:
Lei Xu
2024-01-22 11:49:44 -08:00
committed by Weston Pace
parent b699b5c42b
commit d8befeeea2
4 changed files with 315 additions and 20 deletions

View File

@@ -13,18 +13,168 @@
// limitations under the License.
import {
Field, type FixedSizeListBuilder,
Field,
type FixedSizeListBuilder,
Float32,
makeBuilder,
RecordBatchFileWriter,
Utf8, type Vector,
Utf8,
type Vector,
FixedSizeList,
vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter, List, Float64, RecordBatch, makeData, Struct
vectorFromArray,
type Schema,
Table as ArrowTable,
RecordBatchStreamWriter,
List,
Float64,
RecordBatch,
makeData,
Struct,
type Float
} from 'apache-arrow'
import { type EmbeddingFunction } from './index'
export class VectorColumnOptions {
/** Vector column type. */
type: Float = new Float32()
constructor (values?: Partial<VectorColumnOptions>) {
Object.assign(this, values)
}
}
/** Options to control the makeArrowTable call. */
export class MakeArrowTableOptions {
/** Provided schema. */
schema?: Schema
/** Vector columns */
vectorColumns: Record<string, VectorColumnOptions> = {
vector: new VectorColumnOptions()
}
constructor (values?: Partial<MakeArrowTableOptions>) {
Object.assign(this, values)
}
}
/**
* An enhanced version of the {@link makeTable} function from Apache Arrow
* that supports nested fields and embeddings columns.
*
* Note that it currently does not support nulls.
*
* @param data input data
* @param options options to control the makeArrowTable call.
*
* @example
*
* ```ts
*
* import { fromTableToBuffer, makeArrowTable } from "../arrow";
* import { Field, FixedSizeList, Float16, Float32, Int32, Schema } from "apache-arrow";
*
* const schema = new Schema([
* new Field("a", new Int32()),
* new Field("b", new Float32()),
* new Field("c", new FixedSizeList(3, new Field("item", new Float16()))),
* ]);
* const table = makeArrowTable([
* { a: 1, b: 2, c: [1, 2, 3] },
* { a: 4, b: 5, c: [4, 5, 6] },
* { a: 7, b: 8, c: [7, 8, 9] },
* ], { schema });
* ```
*
* It guesses the vector columns if the schema is not provided. For example,
* by default it assumes that the column named `vector` is a vector column.
*
* ```ts
*
* const schema = new Schema([
new Field("a", new Float64()),
new Field("b", new Float64()),
new Field(
"vector",
new FixedSizeList(3, new Field("item", new Float32()))
),
]);
const table = makeArrowTable([
{ a: 1, b: 2, vector: [1, 2, 3] },
{ a: 4, b: 5, vector: [4, 5, 6] },
{ a: 7, b: 8, vector: [7, 8, 9] },
]);
assert.deepEqual(table.schema, schema);
* ```
*
* You can specify the vector column types and names using the options as well
*
* ```typescript
*
* const schema = new Schema([
new Field('a', new Float64()),
new Field('b', new Float64()),
new Field('vec1', new FixedSizeList(3, new Field('item', new Float16()))),
new Field('vec2', new FixedSizeList(3, new Field('item', new Float16())))
]);
* const table = makeArrowTable([
{ a: 1, b: 2, vec1: [1, 2, 3], vec2: [2, 4, 6] },
{ a: 4, b: 5, vec1: [4, 5, 6], vec2: [8, 10, 12] },
{ a: 7, b: 8, vec1: [7, 8, 9], vec2: [14, 16, 18] }
], {
vectorColumns: {
vec1: { type: new Float16() },
vec2: { type: new Float16() }
}
}
* assert.deepEqual(table.schema, schema)
* ```
*/
export function makeArrowTable (
data: Array<Record<string, any>>,
options?: Partial<MakeArrowTableOptions>
): ArrowTable {
if (data.length === 0) {
throw new Error('At least one record needs to be provided')
}
const opt = new MakeArrowTableOptions(options !== undefined ? options : {})
const columns: Record<string, Vector> = {}
// TODO: sample dataset to find missing columns
const columnNames = Object.keys(data[0])
for (const colName of columnNames) {
const values = data.map((datum) => datum[colName])
let vector: Vector
if (opt.schema !== undefined) {
// Explicit schema is provided, highest priority
vector = vectorFromArray(
values,
opt.schema?.fields.filter((f) => f.name === colName)[0]?.type
)
} else {
const vectorColumnOptions = opt.vectorColumns[colName]
if (vectorColumnOptions !== undefined) {
const fslType = new FixedSizeList(
values[0].length,
new Field('item', vectorColumnOptions.type, false)
)
vector = vectorFromArray(values, fslType)
} else {
// Normal case
vector = vectorFromArray(values)
}
}
columns[colName] = vector
}
return new ArrowTable(columns)
}
// Converts an Array of records into an Arrow Table, optionally applying an embeddings function to it.
export async function convertToTable<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<ArrowTable> {
export async function convertToTable<T> (
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>
): Promise<ArrowTable> {
if (data.length === 0) {
throw new Error('At least one record needs to be provided')
}
@@ -52,7 +202,10 @@ export async function convertToTable<T> (data: Array<Record<string, unknown>>, e
if (columnsKey === embeddings?.sourceColumn) {
const vectors = await embeddings.embed(values as T[])
records.vector = vectorFromArray(vectors, newVectorType(vectors[0].length))
records.vector = vectorFromArray(
vectors,
newVectorType(vectors[0].length)
)
}
if (typeof values[0] === 'string') {
@@ -110,7 +263,11 @@ function newVectorType (dim: number): FixedSizeList<Float32> {
}
// Converts an Array of records into Arrow IPC format
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
export async function fromRecordsToBuffer<T> (
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>,
schema?: Schema
): Promise<Buffer> {
let table = await convertToTable(data, embeddings)
if (schema !== undefined) {
table = alignTable(table, schema)
@@ -120,7 +277,11 @@ export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown
}
// Converts an Array of records into Arrow IPC stream format
export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
export async function fromRecordsToStreamBuffer<T> (
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>,
schema?: Schema
): Promise<Buffer> {
let table = await convertToTable(data, embeddings)
if (schema !== undefined) {
table = alignTable(table, schema)
@@ -130,12 +291,18 @@ export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, u
}
// Converts an Arrow Table into Arrow IPC format
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
export async function fromTableToBuffer<T> (
table: ArrowTable,
embeddings?: EmbeddingFunction<T>,
schema?: Schema
): Promise<Buffer> {
if (embeddings !== undefined) {
const source = table.getChild(embeddings.sourceColumn)
if (source === null) {
throw new Error(`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`)
throw new Error(
`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`
)
}
const vectors = await embeddings.embed(source.toArray() as T[])
@@ -150,12 +317,18 @@ export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: Embe
}
// Converts an Arrow Table into Arrow IPC stream format
export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
export async function fromTableToStreamBuffer<T> (
table: ArrowTable,
embeddings?: EmbeddingFunction<T>,
schema?: Schema
): Promise<Buffer> {
if (embeddings !== undefined) {
const source = table.getChild(embeddings.sourceColumn)
if (source === null) {
throw new Error(`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`)
throw new Error(
`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`
)
}
const vectors = await embeddings.embed(source.toArray() as T[])
@@ -172,9 +345,13 @@ export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?
function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch {
const alignedChildren = []
for (const field of schema.fields) {
const indexInBatch = batch.schema.fields?.findIndex((f) => f.name === field.name)
const indexInBatch = batch.schema.fields?.findIndex(
(f) => f.name === field.name
)
if (indexInBatch < 0) {
throw new Error(`The column ${field.name} was not found in the Arrow Table`)
throw new Error(
`The column ${field.name} was not found in the Arrow Table`
)
}
alignedChildren.push(batch.data.children[indexInBatch])
}
@@ -188,7 +365,9 @@ function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch {
}
function alignTable (table: ArrowTable, schema: Schema): ArrowTable {
const alignedBatches = table.batches.map(batch => alignBatch(batch, schema))
const alignedBatches = table.batches.map((batch) =>
alignBatch(batch, schema)
)
return new ArrowTable(schema, alignedBatches)
}

View File

@@ -41,12 +41,13 @@ const {
tableListIndices,
tableIndexStats,
tableSchema
// eslint-disable-next-line @typescript-eslint/no-var-requires
// eslint-disable-next-line @typescript-eslint/no-var-requires
} = require('../native.js')
export { Query }
export type { EmbeddingFunction }
export { OpenAIEmbeddingFunction } from './embedding/openai'
export { makeArrowTable, type MakeArrowTableOptions } from './arrow'
const defaultAwsRegion = 'us-west-2'
@@ -859,7 +860,10 @@ export class LocalTable<T = number[]> implements Table<T> {
private checkElectron (): boolean {
try {
// eslint-disable-next-line no-prototype-builtins
return (process?.versions?.hasOwnProperty('electron') || navigator?.userAgent?.toLowerCase()?.includes(' electron'))
return (
Object.prototype.hasOwnProperty.call(process?.versions, 'electron') ||
navigator?.userAgent?.toLowerCase()?.includes(' electron')
)
} catch (e) {
return false
}

108
node/src/test/arrow.test.ts Normal file
View File

@@ -0,0 +1,108 @@
// Copyright 2024 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { describe } from 'mocha'
import { assert } from 'chai'
import { fromTableToBuffer, makeArrowTable } from '../arrow'
import {
Field,
FixedSizeList,
Float16,
Float32,
Int32,
tableFromIPC,
Schema,
Float64
} from 'apache-arrow'
describe('Apache Arrow tables', function () {
it('customized schema', async function () {
const schema = new Schema([
new Field('a', new Int32()),
new Field('b', new Float32()),
new Field('c', new FixedSizeList(3, new Field('item', new Float16())))
])
const table = makeArrowTable(
[
{ a: 1, b: 2, c: [1, 2, 3] },
{ a: 4, b: 5, c: [4, 5, 6] },
{ a: 7, b: 8, c: [7, 8, 9] }
],
{ schema }
)
const buf = await fromTableToBuffer(table)
assert.isAbove(buf.byteLength, 0)
const actual = tableFromIPC(buf)
assert.equal(actual.numRows, 3)
const actualSchema = actual.schema
assert.deepEqual(actualSchema, schema)
})
it('default vector column', async function () {
const schema = new Schema([
new Field('a', new Float64()),
new Field('b', new Float64()),
new Field(
'vector',
new FixedSizeList(3, new Field('item', new Float32()))
)
])
const table = makeArrowTable([
{ a: 1, b: 2, vector: [1, 2, 3] },
{ a: 4, b: 5, vector: [4, 5, 6] },
{ a: 7, b: 8, vector: [7, 8, 9] }
])
const buf = await fromTableToBuffer(table)
assert.isAbove(buf.byteLength, 0)
const actual = tableFromIPC(buf)
assert.equal(actual.numRows, 3)
const actualSchema = actual.schema
assert.deepEqual(actualSchema, schema)
})
it('2 vector columns', async function () {
const schema = new Schema([
new Field('a', new Float64()),
new Field('b', new Float64()),
new Field('vec1', new FixedSizeList(3, new Field('item', new Float16()))),
new Field('vec2', new FixedSizeList(3, new Field('item', new Float16())))
])
const table = makeArrowTable(
[
{ a: 1, b: 2, vec1: [1, 2, 3], vec2: [2, 4, 6] },
{ a: 4, b: 5, vec1: [4, 5, 6], vec2: [8, 10, 12] },
{ a: 7, b: 8, vec1: [7, 8, 9], vec2: [14, 16, 18] }
],
{
vectorColumns: {
vec1: { type: new Float16() },
vec2: { type: new Float16() }
}
}
)
const buf = await fromTableToBuffer(table)
assert.isAbove(buf.byteLength, 0)
const actual = tableFromIPC(buf)
assert.equal(actual.numRows, 3)
const actualSchema = actual.schema
assert.deepEqual(actualSchema, schema)
})
})

View File

@@ -1,10 +1,14 @@
{
"include": ["src/**/*.ts"],
"include": [
"src/**/*.ts",
"src/*.ts"
],
"compilerOptions": {
"target": "es2016",
"target": "ES2020",
"module": "commonjs",
"declaration": true,
"outDir": "./dist",
"strict": true
"strict": true,
// "esModuleInterop": true,
}
}
}