feat(js): add helper function to create Arrow Table with schema (#838)

Support to make Apache Arrow Table from an array of javascript Records,
with optionally provided Schema.
This commit is contained in:
Lei Xu
2024-01-22 11:49:44 -08:00
committed by Weston Pace
parent b699b5c42b
commit d8befeeea2
4 changed files with 315 additions and 20 deletions

View File

@@ -13,18 +13,168 @@
// limitations under the License.
import {
Field, type FixedSizeListBuilder,
Field,
type FixedSizeListBuilder,
Float32,
makeBuilder,
RecordBatchFileWriter,
Utf8, type Vector,
Utf8,
type Vector,
FixedSizeList,
vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter, List, Float64, RecordBatch, makeData, Struct
vectorFromArray,
type Schema,
Table as ArrowTable,
RecordBatchStreamWriter,
List,
Float64,
RecordBatch,
makeData,
Struct,
type Float
} from 'apache-arrow'
import { type EmbeddingFunction } from './index'
export class VectorColumnOptions {
/** Vector column type. */
type: Float = new Float32()
constructor (values?: Partial<VectorColumnOptions>) {
Object.assign(this, values)
}
}
/** Options to control the makeArrowTable call. */
export class MakeArrowTableOptions {
/** Provided schema. */
schema?: Schema
/** Vector columns */
vectorColumns: Record<string, VectorColumnOptions> = {
vector: new VectorColumnOptions()
}
constructor (values?: Partial<MakeArrowTableOptions>) {
Object.assign(this, values)
}
}
/**
* An enhanced version of the {@link makeTable} function from Apache Arrow
* that supports nested fields and embeddings columns.
*
* Note that it currently does not support nulls.
*
* @param data input data
* @param options options to control the makeArrowTable call.
*
* @example
*
* ```ts
*
* import { fromTableToBuffer, makeArrowTable } from "../arrow";
* import { Field, FixedSizeList, Float16, Float32, Int32, Schema } from "apache-arrow";
*
* const schema = new Schema([
* new Field("a", new Int32()),
* new Field("b", new Float32()),
* new Field("c", new FixedSizeList(3, new Field("item", new Float16()))),
* ]);
* const table = makeArrowTable([
* { a: 1, b: 2, c: [1, 2, 3] },
* { a: 4, b: 5, c: [4, 5, 6] },
* { a: 7, b: 8, c: [7, 8, 9] },
* ], { schema });
* ```
*
* It guesses the vector columns if the schema is not provided. For example,
* by default it assumes that the column named `vector` is a vector column.
*
* ```ts
*
* const schema = new Schema([
new Field("a", new Float64()),
new Field("b", new Float64()),
new Field(
"vector",
new FixedSizeList(3, new Field("item", new Float32()))
),
]);
const table = makeArrowTable([
{ a: 1, b: 2, vector: [1, 2, 3] },
{ a: 4, b: 5, vector: [4, 5, 6] },
{ a: 7, b: 8, vector: [7, 8, 9] },
]);
assert.deepEqual(table.schema, schema);
* ```
*
* You can specify the vector column types and names using the options as well
*
* ```typescript
*
* const schema = new Schema([
new Field('a', new Float64()),
new Field('b', new Float64()),
new Field('vec1', new FixedSizeList(3, new Field('item', new Float16()))),
new Field('vec2', new FixedSizeList(3, new Field('item', new Float16())))
]);
* const table = makeArrowTable([
{ a: 1, b: 2, vec1: [1, 2, 3], vec2: [2, 4, 6] },
{ a: 4, b: 5, vec1: [4, 5, 6], vec2: [8, 10, 12] },
{ a: 7, b: 8, vec1: [7, 8, 9], vec2: [14, 16, 18] }
], {
vectorColumns: {
vec1: { type: new Float16() },
vec2: { type: new Float16() }
}
}
* assert.deepEqual(table.schema, schema)
* ```
*/
export function makeArrowTable (
data: Array<Record<string, any>>,
options?: Partial<MakeArrowTableOptions>
): ArrowTable {
if (data.length === 0) {
throw new Error('At least one record needs to be provided')
}
const opt = new MakeArrowTableOptions(options !== undefined ? options : {})
const columns: Record<string, Vector> = {}
// TODO: sample dataset to find missing columns
const columnNames = Object.keys(data[0])
for (const colName of columnNames) {
const values = data.map((datum) => datum[colName])
let vector: Vector
if (opt.schema !== undefined) {
// Explicit schema is provided, highest priority
vector = vectorFromArray(
values,
opt.schema?.fields.filter((f) => f.name === colName)[0]?.type
)
} else {
const vectorColumnOptions = opt.vectorColumns[colName]
if (vectorColumnOptions !== undefined) {
const fslType = new FixedSizeList(
values[0].length,
new Field('item', vectorColumnOptions.type, false)
)
vector = vectorFromArray(values, fslType)
} else {
// Normal case
vector = vectorFromArray(values)
}
}
columns[colName] = vector
}
return new ArrowTable(columns)
}
// Converts an Array of records into an Arrow Table, optionally applying an embeddings function to it.
export async function convertToTable<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<ArrowTable> {
export async function convertToTable<T> (
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>
): Promise<ArrowTable> {
if (data.length === 0) {
throw new Error('At least one record needs to be provided')
}
@@ -52,7 +202,10 @@ export async function convertToTable<T> (data: Array<Record<string, unknown>>, e
if (columnsKey === embeddings?.sourceColumn) {
const vectors = await embeddings.embed(values as T[])
records.vector = vectorFromArray(vectors, newVectorType(vectors[0].length))
records.vector = vectorFromArray(
vectors,
newVectorType(vectors[0].length)
)
}
if (typeof values[0] === 'string') {
@@ -110,7 +263,11 @@ function newVectorType (dim: number): FixedSizeList<Float32> {
}
// Converts an Array of records into Arrow IPC format
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
export async function fromRecordsToBuffer<T> (
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>,
schema?: Schema
): Promise<Buffer> {
let table = await convertToTable(data, embeddings)
if (schema !== undefined) {
table = alignTable(table, schema)
@@ -120,7 +277,11 @@ export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown
}
// Converts an Array of records into Arrow IPC stream format
export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
export async function fromRecordsToStreamBuffer<T> (
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>,
schema?: Schema
): Promise<Buffer> {
let table = await convertToTable(data, embeddings)
if (schema !== undefined) {
table = alignTable(table, schema)
@@ -130,12 +291,18 @@ export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, u
}
// Converts an Arrow Table into Arrow IPC format
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
export async function fromTableToBuffer<T> (
table: ArrowTable,
embeddings?: EmbeddingFunction<T>,
schema?: Schema
): Promise<Buffer> {
if (embeddings !== undefined) {
const source = table.getChild(embeddings.sourceColumn)
if (source === null) {
throw new Error(`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`)
throw new Error(
`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`
)
}
const vectors = await embeddings.embed(source.toArray() as T[])
@@ -150,12 +317,18 @@ export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: Embe
}
// Converts an Arrow Table into Arrow IPC stream format
export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
export async function fromTableToStreamBuffer<T> (
table: ArrowTable,
embeddings?: EmbeddingFunction<T>,
schema?: Schema
): Promise<Buffer> {
if (embeddings !== undefined) {
const source = table.getChild(embeddings.sourceColumn)
if (source === null) {
throw new Error(`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`)
throw new Error(
`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`
)
}
const vectors = await embeddings.embed(source.toArray() as T[])
@@ -172,9 +345,13 @@ export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?
function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch {
const alignedChildren = []
for (const field of schema.fields) {
const indexInBatch = batch.schema.fields?.findIndex((f) => f.name === field.name)
const indexInBatch = batch.schema.fields?.findIndex(
(f) => f.name === field.name
)
if (indexInBatch < 0) {
throw new Error(`The column ${field.name} was not found in the Arrow Table`)
throw new Error(
`The column ${field.name} was not found in the Arrow Table`
)
}
alignedChildren.push(batch.data.children[indexInBatch])
}
@@ -188,7 +365,9 @@ function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch {
}
function alignTable (table: ArrowTable, schema: Schema): ArrowTable {
const alignedBatches = table.batches.map(batch => alignBatch(batch, schema))
const alignedBatches = table.batches.map((batch) =>
alignBatch(batch, schema)
)
return new ArrowTable(schema, alignedBatches)
}