mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 19:02:58 +00:00
feat(node): align incoming data to table schema (#802)
This commit is contained in:
@@ -17,10 +17,9 @@ import {
|
||||
Float32,
|
||||
makeBuilder,
|
||||
RecordBatchFileWriter,
|
||||
Utf8,
|
||||
type Vector,
|
||||
Utf8, type Vector,
|
||||
FixedSizeList,
|
||||
vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter, List, Float64
|
||||
vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter, List, Float64, RecordBatch, makeData, Struct
|
||||
} from 'apache-arrow'
|
||||
import { type EmbeddingFunction } from './index'
|
||||
|
||||
@@ -78,6 +77,7 @@ export async function convertToTable<T> (data: Array<Record<string, unknown>>, e
|
||||
}
|
||||
records[columnsKey] = listBuilder.finish().toVector()
|
||||
} else {
|
||||
// TODO if this is a struct field then recursively align the subfields
|
||||
records[columnsKey] = vectorFromArray(values)
|
||||
}
|
||||
}
|
||||
@@ -110,21 +110,27 @@ function newVectorType (dim: number): FixedSizeList<Float32> {
|
||||
}
|
||||
|
||||
// Converts an Array of records into Arrow IPC format
|
||||
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
|
||||
const table = await convertToTable(data, embeddings)
|
||||
/**
 * Serializes an array of records into an Arrow IPC file buffer.
 *
 * The records are first turned into an Arrow table via `convertToTable`.
 * When `schema` is supplied, the table is passed through `alignTable`
 * with that schema before being written.
 *
 * @param data records to serialize
 * @param embeddings optional embedding function forwarded to `convertToTable`
 * @param schema optional schema the resulting table is aligned to
 * @returns the bytes produced by `RecordBatchFileWriter`, wrapped in a Buffer
 */
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
  const converted = await convertToTable(data, embeddings)
  const aligned = schema === undefined ? converted : alignTable(converted, schema)
  const bytes = await RecordBatchFileWriter.writeAll(aligned).toUint8Array()
  return Buffer.from(bytes)
}
|
||||
|
||||
// Converts an Array of records into Arrow IPC stream format
|
||||
export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
|
||||
const table = await convertToTable(data, embeddings)
|
||||
/**
 * Serializes an array of records into an Arrow IPC stream buffer.
 *
 * The records are first turned into an Arrow table via `convertToTable`.
 * When `schema` is supplied, the table is passed through `alignTable`
 * with that schema before being written.
 *
 * @param data records to serialize
 * @param embeddings optional embedding function forwarded to `convertToTable`
 * @param schema optional schema the resulting table is aligned to
 * @returns the bytes produced by `RecordBatchStreamWriter`, wrapped in a Buffer
 */
export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
  const converted = await convertToTable(data, embeddings)
  const aligned = schema === undefined ? converted : alignTable(converted, schema)
  const bytes = await RecordBatchStreamWriter.writeAll(aligned).toUint8Array()
  return Buffer.from(bytes)
}
|
||||
|
||||
// Converts an Arrow Table into Arrow IPC format
|
||||
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
|
||||
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
|
||||
if (embeddings !== undefined) {
|
||||
const source = table.getChild(embeddings.sourceColumn)
|
||||
|
||||
@@ -136,12 +142,15 @@ export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: Embe
|
||||
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
|
||||
table = table.assign(new ArrowTable({ vector: column }))
|
||||
}
|
||||
if (schema !== undefined) {
|
||||
table = alignTable(table, schema)
|
||||
}
|
||||
const writer = RecordBatchFileWriter.writeAll(table)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
// Converts an Arrow Table into Arrow IPC stream format
|
||||
export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>): Promise<Buffer> {
|
||||
export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
|
||||
if (embeddings !== undefined) {
|
||||
const source = table.getChild(embeddings.sourceColumn)
|
||||
|
||||
@@ -153,10 +162,36 @@ export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?
|
||||
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
|
||||
table = table.assign(new ArrowTable({ vector: column }))
|
||||
}
|
||||
if (schema !== undefined) {
|
||||
table = alignTable(table, schema)
|
||||
}
|
||||
const writer = RecordBatchStreamWriter.writeAll(table)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
/**
 * Rebuilds a record batch so that its columns follow the field order of `schema`.
 *
 * Columns are matched to schema fields by name only; no type comparison or
 * casting is performed here. Columns present in `batch` but absent from
 * `schema` are not carried over into the result.
 *
 * @param batch the record batch whose columns should be reordered
 * @param schema the target schema that dictates column order
 * @returns a new RecordBatch constructed with `schema` and the reordered children
 * @throws Error if a field of `schema` has no same-named column in `batch`
 */
function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch {
  const alignedChildren = []
  for (const field of schema.fields) {
    // Coalesce to -1 so that a missing `fields` list is reported as a missing
    // column instead of slipping past the guard below (`undefined < 0` is false).
    const indexInBatch = batch.schema.fields?.findIndex((f) => f.name === field.name) ?? -1
    if (indexInBatch < 0) {
      throw new Error(`The column ${field.name} was not found in the Arrow Table`)
    }
    alignedChildren.push(batch.data.children[indexInBatch])
  }
  // Wrap the reordered children in struct data typed by the target schema's fields,
  // keeping the batch's row count and null count.
  const newData = makeData({
    type: new Struct(schema.fields),
    length: batch.numRows,
    nullCount: batch.nullCount,
    children: alignedChildren
  })
  return new RecordBatch(schema, newData)
}
|
||||
|
||||
/**
 * Aligns every record batch of `table` to `schema` via `alignBatch` and
 * assembles the results into a new Arrow table built with that schema.
 *
 * @param table the table whose batches should be aligned
 * @param schema the target schema
 * @returns a new ArrowTable made of the aligned batches
 */
function alignTable (table: ArrowTable, schema: Schema): ArrowTable {
  const reordered: RecordBatch[] = []
  for (const batch of table.batches) {
    reordered.push(alignBatch(batch, schema))
  }
  return new ArrowTable(schema, reordered)
}
|
||||
|
||||
// Creates an empty Arrow Table
|
||||
export function createEmptyTable (schema: Schema): ArrowTable {
|
||||
return new ArrowTable(schema)
|
||||
|
||||
@@ -485,10 +485,10 @@ export class LocalConnection implements Connection {
|
||||
}
|
||||
buffer = await fromTableToBuffer(createEmptyTable(schema))
|
||||
} else if (data instanceof ArrowTable) {
|
||||
buffer = await fromTableToBuffer(data, embeddingFunction)
|
||||
buffer = await fromTableToBuffer(data, embeddingFunction, schema)
|
||||
} else {
|
||||
// data is Array<Record<...>>
|
||||
buffer = await fromRecordsToBuffer(data, embeddingFunction)
|
||||
buffer = await fromRecordsToBuffer(data, embeddingFunction, schema)
|
||||
}
|
||||
|
||||
const tbl = await tableCreate.call(this._db, name, buffer, writeOptions?.writeMode?.toString(), ...getAwsArgs(this._options()))
|
||||
@@ -560,9 +560,10 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
* @return The number of rows added to the table
|
||||
*/
|
||||
async add (data: Array<Record<string, unknown>>): Promise<number> {
|
||||
const schema = await this.schema
|
||||
return tableAdd.call(
|
||||
this._tbl,
|
||||
await fromRecordsToBuffer(data, this._embeddings),
|
||||
await fromRecordsToBuffer(data, this._embeddings, schema),
|
||||
WriteMode.Append.toString(),
|
||||
...getAwsArgs(this._options())
|
||||
).then((newTable: any) => { this._tbl = newTable })
|
||||
|
||||
@@ -176,6 +176,26 @@ describe('LanceDB client', function () {
|
||||
assert.deepEqual(await con.tableNames(), ['vectors'])
|
||||
})
|
||||
|
||||
it('create a table with a schema and records', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const schema = new Schema(
|
||||
[new Field('id', new Int32()),
|
||||
new Field('name', new Utf8()),
|
||||
new Field('vector', new FixedSizeList(2, new Field('item', new Float32(), true)), false)
|
||||
]
|
||||
)
|
||||
const data = [
|
||||
{ vector: [0.5, 0.2], name: 'foo', id: 0 },
|
||||
{ vector: [0.3, 0.1], name: 'bar', id: 1 }
|
||||
]
|
||||
// even though the keys in data are out of order, it should still work
|
||||
const table = await con.createTable({ name: 'vectors', data, schema })
|
||||
assert.equal(table.name, 'vectors')
|
||||
assert.deepEqual(await con.tableNames(), ['vectors'])
|
||||
})
|
||||
|
||||
it('create a table with a empty data array', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
@@ -294,6 +314,25 @@ describe('LanceDB client', function () {
|
||||
assert.equal(await table.countRows(), 4)
|
||||
})
|
||||
|
||||
it('appends records with fields in a different order', async function () {
|
||||
const dir = await track().mkdir('lancejs')
|
||||
const con = await lancedb.connect(dir)
|
||||
|
||||
const data = [
|
||||
{ id: 1, vector: [0.1, 0.2], price: 10, name: 'a' },
|
||||
{ id: 2, vector: [1.1, 1.2], price: 50, name: 'b' }
|
||||
]
|
||||
|
||||
const table = await con.createTable('vectors', data)
|
||||
|
||||
const dataAdd = [
|
||||
{ id: 3, vector: [2.1, 2.2], name: 'c', price: 10 },
|
||||
{ id: 4, vector: [3.1, 3.2], name: 'd', price: 50 }
|
||||
]
|
||||
await table.add(dataAdd)
|
||||
assert.equal(await table.countRows(), 4)
|
||||
})
|
||||
|
||||
it('overwrite all records in a table', async function () {
|
||||
const uri = await createTestDB()
|
||||
const con = await lancedb.connect(uri)
|
||||
|
||||
Reference in New Issue
Block a user