mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-26 16:30:41 +00:00
feat(node): align incoming data to table schema (#802)
This commit is contained in:
@@ -17,10 +17,9 @@ import {
|
||||
Float32,
|
||||
makeBuilder,
|
||||
RecordBatchFileWriter,
|
||||
Utf8,
|
||||
type Vector,
|
||||
Utf8, type Vector,
|
||||
FixedSizeList,
|
||||
vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter, List, Float64
|
||||
vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter, List, Float64, RecordBatch, makeData, Struct
|
||||
} from 'apache-arrow'
|
||||
import { type EmbeddingFunction } from './index'
|
||||
|
||||
@@ -78,6 +77,7 @@ export async function convertToTable<T> (data: Array<Record<string, unknown>>, e
|
||||
}
|
||||
records[columnsKey] = listBuilder.finish().toVector()
|
||||
} else {
|
||||
// TODO if this is a struct field then recursively align the subfields
|
||||
records[columnsKey] = vectorFromArray(values)
|
||||
}
|
||||
}
|
||||
@@ -110,21 +110,27 @@ function newVectorType (dim: number): FixedSizeList<Float32> {
|
||||
}
|
||||
|
||||
// Converts an Array of records into Arrow IPC format
|
||||
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
|
||||
const table = await convertToTable(data, embeddings)
|
||||
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
|
||||
let table = await convertToTable(data, embeddings)
|
||||
if (schema !== undefined) {
|
||||
table = alignTable(table, schema)
|
||||
}
|
||||
const writer = RecordBatchFileWriter.writeAll(table)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
// Converts an Array of records into Arrow IPC stream format
|
||||
export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
|
||||
const table = await convertToTable(data, embeddings)
|
||||
export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
|
||||
let table = await convertToTable(data, embeddings)
|
||||
if (schema !== undefined) {
|
||||
table = alignTable(table, schema)
|
||||
}
|
||||
const writer = RecordBatchStreamWriter.writeAll(table)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
// Converts an Arrow Table into Arrow IPC format
|
||||
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
|
||||
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
|
||||
if (embeddings !== undefined) {
|
||||
const source = table.getChild(embeddings.sourceColumn)
|
||||
|
||||
@@ -136,12 +142,15 @@ export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: Embe
|
||||
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
|
||||
table = table.assign(new ArrowTable({ vector: column }))
|
||||
}
|
||||
if (schema !== undefined) {
|
||||
table = alignTable(table, schema)
|
||||
}
|
||||
const writer = RecordBatchFileWriter.writeAll(table)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
// Converts an Arrow Table into Arrow IPC stream format
|
||||
export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
|
||||
export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
|
||||
if (embeddings !== undefined) {
|
||||
const source = table.getChild(embeddings.sourceColumn)
|
||||
|
||||
@@ -153,10 +162,36 @@ export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?
|
||||
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
|
||||
table = table.assign(new ArrowTable({ vector: column }))
|
||||
}
|
||||
if (schema !== undefined) {
|
||||
table = alignTable(table, schema)
|
||||
}
|
||||
const writer = RecordBatchStreamWriter.writeAll(table)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch {
|
||||
const alignedChildren = []
|
||||
for (const field of schema.fields) {
|
||||
const indexInBatch = batch.schema.fields?.findIndex((f) => f.name === field.name)
|
||||
if (indexInBatch < 0) {
|
||||
throw new Error(`The column ${field.name} was not found in the Arrow Table`)
|
||||
}
|
||||
alignedChildren.push(batch.data.children[indexInBatch])
|
||||
}
|
||||
const newData = makeData({
|
||||
type: new Struct(schema.fields),
|
||||
length: batch.numRows,
|
||||
nullCount: batch.nullCount,
|
||||
children: alignedChildren
|
||||
})
|
||||
return new RecordBatch(schema, newData)
|
||||
}
|
||||
|
||||
function alignTable (table: ArrowTable, schema: Schema): ArrowTable {
|
||||
const alignedBatches = table.batches.map(batch => alignBatch(batch, schema))
|
||||
return new ArrowTable(schema, alignedBatches)
|
||||
}
|
||||
|
||||
// Creates an empty Arrow Table
|
||||
export function createEmptyTable (schema: Schema): ArrowTable {
|
||||
return new ArrowTable(schema)
|
||||
|
||||
Reference in New Issue
Block a user