// Copyright 2023 Lance Developers. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. import { Field, type FixedSizeListBuilder, Float32, makeBuilder, RecordBatchFileWriter, Utf8, type Vector, FixedSizeList, vectorFromArray, type Schema, Table as ArrowTable, RecordBatchStreamWriter, List, Float64, RecordBatch, makeData, Struct } from 'apache-arrow' import { type EmbeddingFunction } from './index' // Converts an Array of records into an Arrow Table, optionally applying an embeddings function to it. 
// One Arrow column is built per key of the FIRST record; later records are only read for those keys.
//  - `vector`       -> FixedSizeList column (see `newVectorBuilder`); every record must have the same length
//  - string values  -> plain Utf8 column
//  - array values   -> List column whose items are Utf8 or Float64
//  - anything else  -> whatever `vectorFromArray` infers
// The kind of a column is decided from the value found in the first record.
// When `embeddings` is provided and its `sourceColumn` is one of the keys, the source column is kept
// and the embedded values are stored in an extra `vector` column.
//
// Throws when `data` is empty, when vector lengths differ, when the embedding function returns no
// vectors, or when an array column holds elements that are neither strings nor numbers.
export async function convertToTable<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>): Promise<ArrowTable> {
  if (data.length === 0) {
    throw new Error('At least one record needs to be provided')
  }

  const columns = Object.keys(data[0])
  const records: Record<string, Vector> = {}

  for (const columnsKey of columns) {
    if (columnsKey === 'vector') {
      // The first record fixes the dimension; all other records have to match it.
      const vectorSize = (data[0].vector as unknown[]).length
      const listBuilder = newVectorBuilder(vectorSize)
      for (const datum of data) {
        if ((datum[columnsKey] as unknown[]).length !== vectorSize) {
          throw new Error(`Invalid vector size, expected ${vectorSize}`)
        }
        listBuilder.append(datum[columnsKey])
      }
      records[columnsKey] = listBuilder.finish().toVector()
      continue
    }

    const values: unknown[] = data.map((datum) => datum[columnsKey])

    if (columnsKey === embeddings?.sourceColumn) {
      const vectors = await embeddings.embed(values as T[])
      // Without at least one vector the dimension is unknown; fail with a clear message
      // instead of a TypeError on `vectors[0].length`.
      if (vectors.length === 0) {
        throw new Error(`The embedding function returned no vectors for column ${columnsKey}`)
      }
      records.vector = vectorFromArray(vectors, newVectorType(vectors[0].length))
    }

    if (typeof values[0] === 'string') {
      // `vectorFromArray` converts strings into dictionary vectors, forcing it back to a string column
      records[columnsKey] = vectorFromArray(values, new Utf8())
    } else if (Array.isArray(values[0])) {
      // Infer the item type from the first non-empty array so that a leading `[]` does not
      // make the whole column fail. Falls back to the first value when every array is empty.
      const sample = values.find((value) => Array.isArray(value) && value.length > 0) ?? values[0]
      const elementType = getElementType(sample as any[])
      let innerType
      if (elementType === 'string') {
        innerType = new Utf8()
      } else if (elementType === 'number') {
        innerType = new Float64()
      } else {
        // TODO: pass in schema if it exists
        throw new Error(`Unsupported array element type ${elementType}`)
      }
      const listBuilder = makeBuilder({
        type: new List(new Field('item', innerType, true))
      })
      for (const value of values) {
        listBuilder.append(value)
      }
      records[columnsKey] = listBuilder.finish().toVector()
    } else {
      // TODO if this is a struct field then recursively align the subfields
      records[columnsKey] = vectorFromArray(values)
    }
  }

  return new ArrowTable(records)
}

// Returns the `typeof` of the first element of `arr`, or the string 'undefined' for an empty array.
function getElementType (arr: any[]): string {
  if (arr.length === 0) {
    return 'undefined'
  }

  return typeof arr[0]
}

// Creates a new Arrow
// ListBuilder that stores a Vector column
// (a FixedSizeList builder of width `dim`, see `newVectorType`).
function newVectorBuilder (dim: number): FixedSizeListBuilder<Float32> {
  return makeBuilder({
    type: newVectorType(dim)
  })
}

// Creates the Arrow Type for a Vector column with dimension `dim`
function newVectorType (dim: number): FixedSizeList<Float32> {
  // Somewhere we always default to have the elements nullable, so we need to set it to true
  // otherwise we often get schema mismatches because the stored data always has schema with nullable elements
  const children = new Field<Float32>('item', new Float32(), true)
  return new FixedSizeList(dim, children)
}

// Converts records into a Table (applying `embeddings` if given) and, when `schema` is provided,
// reorders the columns to match it.
async function recordsToAlignedTable<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<ArrowTable> {
  const table = await convertToTable(data, embeddings)
  return schema !== undefined ? alignTable(table, schema) : table
}

// Computes the embeddings of `embeddings.sourceColumn` and returns `table` with the result
// assigned as its `vector` column.
// Throws when the source column is missing or when the embedding function returns no vectors
// (the vector dimension cannot be determined in that case, e.g. for a table without rows).
async function appendEmbeddings<T> (table: ArrowTable, embeddings: EmbeddingFunction<T>): Promise<ArrowTable> {
  const source = table.getChild(embeddings.sourceColumn)
  if (source === null) {
    throw new Error(`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`)
  }

  const vectors = await embeddings.embed(source.toArray() as T[])
  if (vectors.length === 0) {
    throw new Error(`The embedding function returned no vectors for column ${embeddings.sourceColumn}`)
  }

  const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
  return table.assign(new ArrowTable({ vector: column }))
}

// Applies the optional embedding and schema alignment steps shared by the `fromTableTo*` functions.
async function prepareTable<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<ArrowTable> {
  const embedded = embeddings !== undefined ? await appendEmbeddings(table, embeddings) : table
  return schema !== undefined ? alignTable(embedded, schema) : embedded
}

// Converts an Array of records into Arrow IPC format
// `embeddings` and `schema` are optional; see `convertToTable` and `alignTable` for their effect.
export async function fromRecordsToBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
  const table = await recordsToAlignedTable(data, embeddings, schema)
  const writer = RecordBatchFileWriter.writeAll(table)
  return Buffer.from(await writer.toUint8Array())
}

// Converts an Array of records into Arrow IPC stream format
// `embeddings` and `schema` are optional; see `convertToTable` and `alignTable` for their effect.
export async function fromRecordsToStreamBuffer<T> (data: Array<Record<string, unknown>>, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
  const table = await recordsToAlignedTable(data, embeddings, schema)
  const writer = RecordBatchStreamWriter.writeAll(table)
  return Buffer.from(await writer.toUint8Array())
}

// Converts an Arrow Table into Arrow IPC format
// When `embeddings` is given, a `vector` column computed from its source column is assigned first;
// when `schema` is given, the columns are then reordered to match it.
export async function fromTableToBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
  const prepared = await prepareTable(table, embeddings, schema)
  const writer = RecordBatchFileWriter.writeAll(prepared)
  return Buffer.from(await writer.toUint8Array())
}

// Converts an Arrow Table into Arrow IPC stream format
// When `embeddings` is given, a `vector` column computed from its source column is assigned first;
// when `schema` is given, the columns are then reordered to match it.
export async function fromTableToStreamBuffer<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<Buffer> {
  const prepared = await prepareTable(table, embeddings, schema)
  const writer = RecordBatchStreamWriter.writeAll(prepared)
  return Buffer.from(await writer.toUint8Array())
}

// Rebuilds `batch` so that its columns appear in the order of `schema.fields`, matching by field name.
// Columns of the batch that are not named in `schema` are dropped.
// Throws when a schema field has no column of the same name in the batch.
function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch {
  const alignedChildren = schema.fields.map((field) => {
    // A missing field list is treated like a missing column rather than silently
    // producing an `undefined` child.
    const indexInBatch = batch.schema.fields?.findIndex((f) => f.name === field.name) ?? -1
    if (indexInBatch < 0) {
      throw new Error(`The column ${field.name} was not found in the Arrow Table`)
    }
    return batch.data.children[indexInBatch]
  })

  const newData = makeData({
    type: new Struct(schema.fields),
    length: batch.numRows,
    nullCount: batch.nullCount,
    children: alignedChildren
  })
  return new RecordBatch(schema, newData)
}

// Aligns every record batch of `table` to `schema` (see `alignBatch`) and wraps them in a new Table.
function alignTable (table: ArrowTable, schema: Schema): ArrowTable {
  const alignedBatches = table.batches.map((batch) => alignBatch(batch, schema))
  return new ArrowTable(schema, alignedBatches)
}

// Creates an empty Arrow Table
export function createEmptyTable (schema: Schema): ArrowTable {
  return new ArrowTable(schema)
}