feat(node): align incoming data to table schema (#802)

This commit is contained in:
Chang She
2024-01-10 16:44:00 -08:00
committed by Weston Pace
parent 4aa7f58a07
commit 4b243c5ff8
5 changed files with 307 additions and 316 deletions

View File

@@ -17,10 +17,9 @@ import {
Float32,
makeBuilder,
RecordBatchFileWriter,
Utf8,
type Vector,
Utf8, type Vector,
FixedSizeList,
vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter, List, Float64
vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter, List, Float64, RecordBatch, makeData, Struct
} from 'apache-arrow'
import { type EmbeddingFunction } from './index'
@@ -78,6 +77,7 @@ export async function convertToTable<T> (data: Array<Record<string, unknown>>, e
}
records[columnsKey] = listBuilder.finish().toVector()
} else {
// TODO if this is a struct field then recursively align the subfields
records[columnsKey] = vectorFromArray(values)
}
}
@@ -110,21 +110,27 @@ function newVectorType (dim: number): FixedSizeList<Float32> {
}
// Converts an Array of records into Arrow IPC format
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
const table = await convertToTable(data, embeddings)
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
let table = await convertToTable(data, embeddings)
if (schema !== undefined) {
table = alignTable(table, schema)
}
const writer = RecordBatchFileWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array())
}
// Converts an Array of records into Arrow IPC stream format
export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
const table = await convertToTable(data, embeddings)
export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
let table = await convertToTable(data, embeddings)
if (schema !== undefined) {
table = alignTable(table, schema)
}
const writer = RecordBatchStreamWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array())
}
// Converts an Arrow Table into Arrow IPC format
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
if (embeddings !== undefined) {
const source = table.getChild(embeddings.sourceColumn)
@@ -136,12 +142,15 @@ export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: Embe
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
table = table.assign(new ArrowTable({ vector: column }))
}
if (schema !== undefined) {
table = alignTable(table, schema)
}
const writer = RecordBatchFileWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array())
}
// Converts an Arrow Table into Arrow IPC stream format
export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
if (embeddings !== undefined) {
const source = table.getChild(embeddings.sourceColumn)
@@ -153,10 +162,36 @@ export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
table = table.assign(new ArrowTable({ vector: column }))
}
if (schema !== undefined) {
table = alignTable(table, schema)
}
const writer = RecordBatchStreamWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array())
}
function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch {
const alignedChildren = []
for (const field of schema.fields) {
const indexInBatch = batch.schema.fields?.findIndex((f) => f.name === field.name)
if (indexInBatch < 0) {
throw new Error(`The column ${field.name} was not found in the Arrow Table`)
}
alignedChildren.push(batch.data.children[indexInBatch])
}
const newData = makeData({
type: new Struct(schema.fields),
length: batch.numRows,
nullCount: batch.nullCount,
children: alignedChildren
})
return new RecordBatch(schema, newData)
}
function alignTable (table: ArrowTable, schema: Schema): ArrowTable {
const alignedBatches = table.batches.map(batch => alignBatch(batch, schema))
return new ArrowTable(schema, alignedBatches)
}
// Creates an empty Arrow Table
export function createEmptyTable (schema: Schema): ArrowTable {
return new ArrowTable(schema)