feat(node): align incoming data to table schema (#802)

This commit is contained in:
Chang She
2024-01-10 16:44:00 -08:00
committed by GitHub
parent 99adfe065a
commit 81af350d85
5 changed files with 307 additions and 316 deletions

View File

@@ -17,10 +17,9 @@ import {
Float32,
makeBuilder,
RecordBatchFileWriter,
Utf8,
type Vector,
Utf8, type Vector,
FixedSizeList,
vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter, List, Float64
vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter, List, Float64, RecordBatch, makeData, Struct
} from 'apache-arrow'
import { type EmbeddingFunction } from './index'
@@ -78,6 +77,7 @@ export async function convertToTable<T> (data: Array<Record<string, unknown>>, e
}
records[columnsKey] = listBuilder.finish().toVector()
} else {
// TODO if this is a struct field then recursively align the subfields
records[columnsKey] = vectorFromArray(values)
}
}
@@ -110,21 +110,27 @@ function newVectorType (dim: number): FixedSizeList<Float32> {
}
// Converts an Array of records into Arrow IPC format
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
const table = await convertToTable(data, embeddings)
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
let table = await convertToTable(data, embeddings)
if (schema !== undefined) {
table = alignTable(table, schema)
}
const writer = RecordBatchFileWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array())
}
// Converts an Array of records into Arrow IPC stream format
export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
const table = await convertToTable(data, embeddings)
export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
let table = await convertToTable(data, embeddings)
if (schema !== undefined) {
table = alignTable(table, schema)
}
const writer = RecordBatchStreamWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array())
}
// Converts an Arrow Table into Arrow IPC format
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
if (embeddings !== undefined) {
const source = table.getChild(embeddings.sourceColumn)
@@ -136,12 +142,15 @@ export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: Embe
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
table = table.assign(new ArrowTable({ vector: column }))
}
if (schema !== undefined) {
table = alignTable(table, schema)
}
const writer = RecordBatchFileWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array())
}
// Converts an Arrow Table into Arrow IPC stream format
export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
if (embeddings !== undefined) {
const source = table.getChild(embeddings.sourceColumn)
@@ -153,10 +162,36 @@ export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
table = table.assign(new ArrowTable({ vector: column }))
}
if (schema !== undefined) {
table = alignTable(table, schema)
}
const writer = RecordBatchStreamWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array())
}
/**
 * Rebuilds a record batch so that its columns follow the field order of `schema`.
 *
 * Columns are matched by field name. Columns present in `batch` but absent from
 * `schema` are left out of the result. The child data is reused as-is: this
 * function only reorders columns, it does not convert them to the types declared
 * in `schema`.
 * NOTE(review): confirm callers guarantee that the column types already match.
 *
 * @param batch - The record batch whose columns should be reordered.
 * @param schema - The target schema that dictates the column order.
 * @returns A new record batch that carries `schema` and the reordered columns.
 * @throws Error if a field of `schema` has no column of the same name in `batch`.
 */
function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch {
  // Fall back to an empty list so that a missing field list is reported through
  // the "column not found" error below instead of slipping past the index check.
  const batchFields = batch.schema.fields ?? []
  const batchChildren = batch.data.children
  const alignedChildren = []
  for (const field of schema.fields) {
    const indexInBatch = batchFields.findIndex((f) => f.name === field.name)
    if (indexInBatch < 0) {
      throw new Error(`The column ${field.name} was not found in the Arrow Table`)
    }
    alignedChildren.push(batchChildren[indexInBatch])
  }
  // A record batch is backed by struct-typed data whose children are the columns.
  const newData = makeData({
    type: new Struct(schema.fields),
    length: batch.numRows,
    nullCount: batch.nullCount,
    children: alignedChildren
  })
  return new RecordBatch(schema, newData)
}
/**
 * Produces a table whose record batches all follow the column order of `schema`.
 *
 * Each batch of `table` is passed through `alignBatch`, and the results are
 * assembled into a new table that carries `schema`.
 *
 * @param table - The table whose batches should be reordered.
 * @param schema - The target schema that dictates the column order.
 * @returns A new table built from `schema` and the aligned batches.
 */
function alignTable (table: ArrowTable, schema: Schema): ArrowTable {
  const reordered = []
  for (const sourceBatch of table.batches) {
    reordered.push(alignBatch(sourceBatch, schema))
  }
  return new ArrowTable(schema, reordered)
}
// Creates an empty Arrow Table
export function createEmptyTable (schema: Schema): ArrowTable {
return new ArrowTable(schema)

View File

@@ -485,10 +485,10 @@ export class LocalConnection implements Connection {
}
buffer = await fromTableToBuffer(createEmptyTable(schema))
} else if (data instanceof ArrowTable) {
buffer = await fromTableToBuffer(data, embeddingFunction)
buffer = await fromTableToBuffer(data, embeddingFunction, schema)
} else {
// data is Array<Record<...>>
buffer = await fromRecordsToBuffer(data, embeddingFunction)
buffer = await fromRecordsToBuffer(data, embeddingFunction, schema)
}
const tbl = await tableCreate.call(this._db, name, buffer, writeOptions?.writeMode?.toString(), ...getAwsArgs(this._options()))
@@ -560,9 +560,10 @@ export class LocalTable<T = number[]> implements Table<T> {
* @return The number of rows added to the table
*/
async add (data: Array<Record<string, unknown>>): Promise<number> {
const schema = await this.schema
return tableAdd.call(
this._tbl,
await fromRecordsToBuffer(data, this._embeddings),
await fromRecordsToBuffer(data, this._embeddings, schema),
WriteMode.Append.toString(),
...getAwsArgs(this._options())
).then((newTable: any) => { this._tbl = newTable })

View File

@@ -176,6 +176,26 @@ describe('LanceDB client', function () {
assert.deepEqual(await con.tableNames(), ['vectors'])
})
// Verifies that createTable accepts records together with an explicit schema
// when the record keys are not listed in the schema's field order.
it('create a table with a schema and records', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
// Schema field order: id, name, vector
const schema = new Schema(
[new Field('id', new Int32()),
new Field('name', new Utf8()),
new Field('vector', new FixedSizeList(2, new Field('item', new Float32(), true)), false)
]
)
// Record key order: vector, name, id (deliberately different from the schema)
const data = [
{ vector: [0.5, 0.2], name: 'foo', id: 0 },
{ vector: [0.3, 0.1], name: 'bar', id: 1 }
]
// Even though the keys in data are out of order, it should still work
const table = await con.createTable({ name: 'vectors', data, schema })
assert.equal(table.name, 'vectors')
assert.deepEqual(await con.tableNames(), ['vectors'])
})
it('create a table with a empty data array', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
@@ -294,6 +314,25 @@ describe('LanceDB client', function () {
assert.equal(await table.countRows(), 4)
})
// Verifies that records can be appended when their keys are listed in a
// different order than the records the table was created from.
it('appends records with fields in a different order', async function () {
const dir = await track().mkdir('lancejs')
const con = await lancedb.connect(dir)
// Initial key order: id, vector, price, name
const data = [
{ id: 1, vector: [0.1, 0.2], price: 10, name: 'a' },
{ id: 2, vector: [1.1, 1.2], price: 50, name: 'b' }
]
const table = await con.createTable('vectors', data)
// Appended key order: id, vector, name, price (name and price swapped)
const dataAdd = [
{ id: 3, vector: [2.1, 2.2], name: 'c', price: 10 },
{ id: 4, vector: [3.1, 3.2], name: 'd', price: 50 }
]
await table.add(dataAdd)
// Two initial rows plus two appended rows
assert.equal(await table.countRows(), 4)
})
it('overwrite all records in a table', async function () {
const uri = await createTestDB()
const con = await lancedb.connect(uri)