mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 19:32:56 +00:00
feat: make it easier to create empty tables (#942)
This PR also reworks the table creation utilities significantly so that they are more consistent, built on top of each other, and thoroughly documented.
This commit is contained in:
@@ -13,10 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
import {
|
||||
Int64,
|
||||
Field,
|
||||
type FixedSizeListBuilder,
|
||||
Float32,
|
||||
makeBuilder,
|
||||
RecordBatchFileWriter,
|
||||
Utf8,
|
||||
@@ -27,15 +24,19 @@ import {
|
||||
Table as ArrowTable,
|
||||
RecordBatchStreamWriter,
|
||||
List,
|
||||
Float64,
|
||||
RecordBatch,
|
||||
makeData,
|
||||
Struct,
|
||||
type Float,
|
||||
type DataType
|
||||
DataType,
|
||||
Binary,
|
||||
Float32
|
||||
} from 'apache-arrow'
|
||||
import { type EmbeddingFunction } from './index'
|
||||
|
||||
/*
|
||||
* Options to control how a column should be converted to a vector array
|
||||
*/
|
||||
export class VectorColumnOptions {
|
||||
/** Vector column type. */
|
||||
type: Float = new Float32()
|
||||
@@ -47,14 +48,50 @@ export class VectorColumnOptions {
|
||||
|
||||
/** Options to control the makeArrowTable call. */
|
||||
export class MakeArrowTableOptions {
|
||||
/** Provided schema. */
|
||||
/*
|
||||
* Schema of the data.
|
||||
*
|
||||
* If this is not provided then the data type will be inferred from the
|
||||
* JS type. Integer numbers will become int64, floating point numbers
|
||||
* will become float64 and arrays will become variable sized lists with
|
||||
* the data type inferred from the first element in the array.
|
||||
*
|
||||
* The schema must be specified if there are no records (e.g. to make
|
||||
* an empty table)
|
||||
*/
|
||||
schema?: Schema
|
||||
|
||||
/** Vector columns */
|
||||
/*
|
||||
* Mapping from vector column name to expected type
|
||||
*
|
||||
* Lance expects vector columns to be fixed size list arrays (i.e. tensors)
|
||||
* However, `makeArrowTable` will not infer this by default (it creates
|
||||
* variable size list arrays). This field can be used to indicate that a column
|
||||
* should be treated as a vector column and converted to a fixed size list.
|
||||
*
|
||||
* The keys should be the names of the vector columns. The value specifies the
|
||||
* expected data type of the vector columns.
|
||||
*
|
||||
* If `schema` is provided then this field is ignored.
|
||||
*
|
||||
* By default, the column named "vector" will be assumed to be a float32
|
||||
* vector column.
|
||||
*/
|
||||
vectorColumns: Record<string, VectorColumnOptions> = {
|
||||
vector: new VectorColumnOptions()
|
||||
}
|
||||
|
||||
/**
|
||||
* If true then string columns will be encoded with dictionary encoding
|
||||
*
|
||||
* Set this to true if your string columns tend to repeat the same values
|
||||
* often. For more precise control use the `schema` property to specify the
|
||||
* data type for individual columns.
|
||||
*
|
||||
* If `schema` is provided then this property is ignored.
|
||||
*/
|
||||
dictionaryEncodeStrings: boolean = false
|
||||
|
||||
constructor (values?: Partial<MakeArrowTableOptions>) {
|
||||
Object.assign(this, values)
|
||||
}
|
||||
@@ -64,8 +101,29 @@ export class MakeArrowTableOptions {
|
||||
* An enhanced version of the {@link makeTable} function from Apache Arrow
|
||||
* that supports nested fields and embeddings columns.
|
||||
*
|
||||
* This function converts an array of Record<String, any> (row-major JS objects)
|
||||
* to an Arrow Table (a columnar structure)
|
||||
*
|
||||
* Note that it currently does not support nulls.
|
||||
*
|
||||
* If a schema is provided then it will be used to determine the resulting array
|
||||
* types. Fields will also be reordered to fit the order defined by the schema.
|
||||
*
|
||||
* If a schema is not provided then the types will be inferred and the field order
|
||||
* will be controlled by the order of properties in the first record.
|
||||
*
|
||||
* If the input is empty then a schema must be provided to create an empty table.
|
||||
*
|
||||
* When a schema is not specified then data types will be inferred. The inference
|
||||
* rules are as follows:
|
||||
*
|
||||
* - boolean => Bool
|
||||
* - number => Float64
|
||||
* - String => Utf8
|
||||
* - Buffer => Binary
|
||||
* - Record<String, any> => Struct
|
||||
* - Array<any> => List
|
||||
*
|
||||
* @param data input data
|
||||
* @param options options to control the makeArrowTable call.
|
||||
*
|
||||
@@ -88,8 +146,10 @@ export class MakeArrowTableOptions {
|
||||
* ], { schema });
|
||||
* ```
|
||||
*
|
||||
* It guesses the vector columns if the schema is not provided. For example,
|
||||
* by default it assumes that the column named `vector` is a vector column.
|
||||
* By default it assumes that the column named `vector` is a vector column
|
||||
* and it will be converted into a fixed size list array of type float32.
|
||||
* The `vectorColumns` option can be used to support other vector column
|
||||
* names and data types.
|
||||
*
|
||||
* ```ts
|
||||
*
|
||||
@@ -136,214 +196,304 @@ export function makeArrowTable (
|
||||
data: Array<Record<string, any>>,
|
||||
options?: Partial<MakeArrowTableOptions>
|
||||
): ArrowTable {
|
||||
if (data.length === 0) {
|
||||
throw new Error('At least one record needs to be provided')
|
||||
if (data.length === 0 && (options?.schema === undefined || options?.schema === null)) {
|
||||
throw new Error('At least one record or a schema needs to be provided')
|
||||
}
|
||||
|
||||
const opt = new MakeArrowTableOptions(options !== undefined ? options : {})
|
||||
const columns: Record<string, Vector> = {}
|
||||
// TODO: sample dataset to find missing columns
|
||||
const columnNames = Object.keys(data[0])
|
||||
// Prefer the field ordering of the schema, if present
|
||||
const columnNames = ((options?.schema) != null) ? (options?.schema?.names as string[]) : Object.keys(data[0])
|
||||
for (const colName of columnNames) {
|
||||
if (data.length !== 0 && !Object.prototype.hasOwnProperty.call(data[0], colName)) {
|
||||
// The field is present in the schema, but not in the data, skip it
|
||||
continue
|
||||
}
|
||||
// Extract a single column from the records (transpose from row-major to col-major)
|
||||
let values = data.map((datum) => datum[colName])
|
||||
let vector: Vector
|
||||
|
||||
// By default (type === undefined) arrow will infer the type from the JS type
|
||||
let type
|
||||
if (opt.schema !== undefined) {
|
||||
// Explicit schema is provided, highest priority
|
||||
const fieldType: DataType | undefined = opt.schema.fields.filter((f) => f.name === colName)[0]?.type as DataType
|
||||
if (fieldType instanceof Int64) {
|
||||
// If there is a schema provided, then use that for the type instead
|
||||
type = opt.schema?.fields.filter((f) => f.name === colName)[0]?.type
|
||||
if (DataType.isInt(type) && type.bitWidth === 64) {
|
||||
// wrap in BigInt to avoid bug: https://github.com/apache/arrow/issues/40051
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-argument
|
||||
values = values.map((v) => BigInt(v))
|
||||
values = values.map((v) => {
|
||||
if (v === null) {
|
||||
return v
|
||||
}
|
||||
return BigInt(v)
|
||||
})
|
||||
}
|
||||
vector = vectorFromArray(values, fieldType)
|
||||
} else {
|
||||
// Otherwise, check to see if this column is one of the vector columns
|
||||
// defined by opt.vectorColumns and, if so, use the fixed size list type
|
||||
const vectorColumnOptions = opt.vectorColumns[colName]
|
||||
if (vectorColumnOptions !== undefined) {
|
||||
const fslType = new FixedSizeList(
|
||||
values[0].length,
|
||||
new Field('item', vectorColumnOptions.type, false)
|
||||
)
|
||||
vector = vectorFromArray(values, fslType)
|
||||
} else {
|
||||
// Normal case
|
||||
vector = vectorFromArray(values)
|
||||
type = newVectorType(values[0].length, vectorColumnOptions.type)
|
||||
}
|
||||
}
|
||||
columns[colName] = vector
|
||||
|
||||
try {
|
||||
// Convert an Array of JS values to an arrow vector
|
||||
columns[colName] = makeVector(values, type, opt.dictionaryEncodeStrings)
|
||||
} catch (error: unknown) {
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
throw Error(`Could not convert column "${colName}" to Arrow: ${error}`)
|
||||
}
|
||||
}
|
||||
|
||||
return new ArrowTable(columns)
|
||||
if (opt.schema != null) {
|
||||
// `new ArrowTable(columns)` infers a schema which may sometimes have
|
||||
// incorrect nullability (it assumes nullable=true if there are 0 rows)
|
||||
//
|
||||
// `new ArrowTable(schema, columns)` will also fail because it will create a
|
||||
// batch with an inferred schema and then complain that the batch schema
|
||||
// does not match the provided schema.
|
||||
//
|
||||
// To work around this we first create a table with the wrong schema and
|
||||
// then patch the schema of the batches so we can use
|
||||
// `new ArrowTable(schema, batches)` which does not do any schema inference
|
||||
const firstTable = new ArrowTable(columns)
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
const batchesFixed = firstTable.batches.map(batch => new RecordBatch(opt.schema!, batch.data))
|
||||
return new ArrowTable(opt.schema, batchesFixed)
|
||||
} else {
|
||||
return new ArrowTable(columns)
|
||||
}
|
||||
}
|
||||
|
||||
// Converts an Array of records into an Arrow Table, optionally applying an embeddings function to it.
|
||||
/**
|
||||
* Create an empty Arrow table with the provided schema
|
||||
*/
|
||||
export function makeEmptyTable (schema: Schema): ArrowTable {
|
||||
return makeArrowTable([], { schema })
|
||||
}
|
||||
|
||||
// Helper function to convert Array<Array<any>> to a variable sized list array
|
||||
function makeListVector (lists: any[][]): Vector<any> {
|
||||
if (lists.length === 0 || lists[0].length === 0) {
|
||||
throw Error('Cannot infer list vector from empty array or empty list')
|
||||
}
|
||||
const sampleList = lists[0]
|
||||
let inferredType
|
||||
try {
|
||||
const sampleVector = makeVector(sampleList)
|
||||
inferredType = sampleVector.type
|
||||
} catch (error: unknown) {
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
throw Error(`Cannot infer list vector. Cannot infer inner type: ${error}`)
|
||||
}
|
||||
|
||||
const listBuilder = makeBuilder({
|
||||
type: new List(new Field('item', inferredType, true))
|
||||
})
|
||||
for (const list of lists) {
|
||||
listBuilder.append(list)
|
||||
}
|
||||
return listBuilder.finish().toVector()
|
||||
}
|
||||
|
||||
// Helper function to convert an Array of JS values to an Arrow Vector
|
||||
function makeVector (values: any[], type?: DataType, stringAsDictionary?: boolean): Vector<any> {
|
||||
if (type !== undefined) {
|
||||
// No need for inference, let Arrow create it
|
||||
return vectorFromArray(values, type)
|
||||
}
|
||||
if (values.length === 0) {
|
||||
throw Error('makeVector requires at least one value or the type must be specfied')
|
||||
}
|
||||
const sampleValue = values.find(val => val !== null && val !== undefined)
|
||||
if (sampleValue === undefined) {
|
||||
throw Error('makeVector cannot infer the type if all values are null or undefined')
|
||||
}
|
||||
if (Array.isArray(sampleValue)) {
|
||||
// Default Arrow inference doesn't handle list types
|
||||
return makeListVector(values)
|
||||
} else if (Buffer.isBuffer(sampleValue)) {
|
||||
// Default Arrow inference doesn't handle Buffer
|
||||
return vectorFromArray(values, new Binary())
|
||||
} else if (!(stringAsDictionary ?? false) && (typeof sampleValue === 'string' || sampleValue instanceof String)) {
|
||||
// If the type is string then don't use Arrow's default inference unless dictionaries are requested
|
||||
// because it will always use dictionary encoding for strings
|
||||
return vectorFromArray(values, new Utf8())
|
||||
} else {
|
||||
// Convert a JS array of values to an arrow vector
|
||||
return vectorFromArray(values)
|
||||
}
|
||||
}
|
||||
|
||||
async function applyEmbeddings<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<ArrowTable> {
|
||||
if (embeddings == null) {
|
||||
return table
|
||||
}
|
||||
|
||||
// Convert from ArrowTable to Record<String, Vector>
|
||||
const colEntries = [...Array(table.numCols).keys()].map((_, idx) => {
|
||||
const name = table.schema.fields[idx].name
|
||||
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
|
||||
const vec = table.getChildAt(idx)!
|
||||
return [name, vec]
|
||||
})
|
||||
const newColumns = Object.fromEntries(colEntries)
|
||||
|
||||
const sourceColumn = newColumns[embeddings.sourceColumn]
|
||||
const destColumn = embeddings.destColumn ?? 'vector'
|
||||
const innerDestType = embeddings.embeddingDataType ?? new Float32()
|
||||
if (sourceColumn === undefined) {
|
||||
throw new Error(`Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`)
|
||||
}
|
||||
|
||||
if (table.numRows === 0) {
|
||||
if (Object.prototype.hasOwnProperty.call(newColumns, destColumn)) {
|
||||
// We have an empty table and it already has the embedding column so no work needs to be done
|
||||
// Note: we don't return an error like we did below because this is a common occurrence. For example,
|
||||
// if we call convertToTable with 0 records and a schema that includes the embedding
|
||||
return table
|
||||
}
|
||||
if (embeddings.embeddingDimension !== undefined) {
|
||||
const destType = newVectorType(embeddings.embeddingDimension, innerDestType)
|
||||
newColumns[destColumn] = makeVector([], destType)
|
||||
} else if (schema != null) {
|
||||
const destField = schema.fields.find(f => f.name === destColumn)
|
||||
if (destField != null) {
|
||||
newColumns[destColumn] = makeVector([], destField.type)
|
||||
} else {
|
||||
throw new Error(`Attempt to apply embeddings to an empty table failed because schema was missing embedding column '${destColumn}'`)
|
||||
}
|
||||
} else {
|
||||
throw new Error('Attempt to apply embeddings to an empty table when the embeddings function does not specify `embeddingDimension`')
|
||||
}
|
||||
} else {
|
||||
if (Object.prototype.hasOwnProperty.call(newColumns, destColumn)) {
|
||||
throw new Error(`Attempt to apply embeddings to table failed because column ${destColumn} already existed`)
|
||||
}
|
||||
if (table.batches.length > 1) {
|
||||
throw new Error('Internal error: `makeArrowTable` unexpectedly created a table with more than one batch')
|
||||
}
|
||||
const values = sourceColumn.toArray()
|
||||
const vectors = await embeddings.embed(values as T[])
|
||||
if (vectors.length !== values.length) {
|
||||
throw new Error('Embedding function did not return an embedding for each input element')
|
||||
}
|
||||
const destType = newVectorType(vectors[0].length, innerDestType)
|
||||
newColumns[destColumn] = makeVector(vectors, destType)
|
||||
}
|
||||
|
||||
const newTable = new ArrowTable(newColumns)
|
||||
if (schema != null) {
|
||||
if (schema.fields.find(f => f.name === destColumn) === undefined) {
|
||||
throw new Error(`When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`)
|
||||
}
|
||||
return alignTable(newTable, schema)
|
||||
}
|
||||
return newTable
|
||||
}
|
||||
|
||||
/*
|
||||
* Convert an Array of records into an Arrow Table, optionally applying an
|
||||
* embeddings function to it.
|
||||
*
|
||||
* This function calls `makeArrowTable` first to create the Arrow Table.
|
||||
* Any provided `makeTableOptions` (e.g. a schema) will be passed on to
|
||||
* that call.
|
||||
*
|
||||
* The embedding function will be passed a column of values (based on the
|
||||
* `sourceColumn` of the embedding function) and expects to receive back
|
||||
* number[][] which will be converted into a fixed size list column. By
|
||||
* default this will be a fixed size list of Float32 but that can be
|
||||
* customized by the `embeddingDataType` property of the embedding function.
|
||||
*
|
||||
* If a schema is provided in `makeTableOptions` then it should include the
|
||||
* embedding columns. If no schema is provded then embedding columns will
|
||||
* be placed at the end of the table, after all of the input columns.
|
||||
*/
|
||||
export async function convertToTable<T> (
|
||||
data: Array<Record<string, unknown>>,
|
||||
embeddings?: EmbeddingFunction<T>
|
||||
embeddings?: EmbeddingFunction<T>,
|
||||
makeTableOptions?: Partial<MakeArrowTableOptions>
|
||||
): Promise<ArrowTable> {
|
||||
if (data.length === 0) {
|
||||
throw new Error('At least one record needs to be provided')
|
||||
}
|
||||
|
||||
const columns = Object.keys(data[0])
|
||||
const records: Record<string, Vector> = {}
|
||||
|
||||
for (const columnsKey of columns) {
|
||||
if (columnsKey === 'vector') {
|
||||
const vectorSize = (data[0].vector as any[]).length
|
||||
const listBuilder = newVectorBuilder(vectorSize)
|
||||
for (const datum of data) {
|
||||
if ((datum[columnsKey] as any[]).length !== vectorSize) {
|
||||
throw new Error(`Invalid vector size, expected ${vectorSize}`)
|
||||
}
|
||||
|
||||
listBuilder.append(datum[columnsKey])
|
||||
}
|
||||
records[columnsKey] = listBuilder.finish().toVector()
|
||||
} else {
|
||||
const values = []
|
||||
for (const datum of data) {
|
||||
values.push(datum[columnsKey])
|
||||
}
|
||||
|
||||
if (columnsKey === embeddings?.sourceColumn) {
|
||||
const vectors = await embeddings.embed(values as T[])
|
||||
records.vector = vectorFromArray(
|
||||
vectors,
|
||||
newVectorType(vectors[0].length)
|
||||
)
|
||||
}
|
||||
|
||||
if (typeof values[0] === 'string') {
|
||||
// `vectorFromArray` converts strings into dictionary vectors, forcing it back to a string column
|
||||
records[columnsKey] = vectorFromArray(values, new Utf8())
|
||||
} else if (Array.isArray(values[0])) {
|
||||
const elementType = getElementType(values[0])
|
||||
let innerType
|
||||
if (elementType === 'string') {
|
||||
innerType = new Utf8()
|
||||
} else if (elementType === 'number') {
|
||||
innerType = new Float64()
|
||||
} else {
|
||||
// TODO: pass in schema if it exists, else keep going to the next element
|
||||
throw new Error(`Unsupported array element type ${elementType}`)
|
||||
}
|
||||
const listBuilder = makeBuilder({
|
||||
type: new List(new Field('item', innerType, true))
|
||||
})
|
||||
for (const value of values) {
|
||||
listBuilder.append(value)
|
||||
}
|
||||
records[columnsKey] = listBuilder.finish().toVector()
|
||||
} else {
|
||||
// TODO if this is a struct field then recursively align the subfields
|
||||
records[columnsKey] = vectorFromArray(values)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return new ArrowTable(records)
|
||||
}
|
||||
|
||||
function getElementType (arr: any[]): string {
|
||||
if (arr.length === 0) {
|
||||
return 'undefined'
|
||||
}
|
||||
|
||||
return typeof arr[0]
|
||||
}
|
||||
|
||||
// Creates a new Arrow ListBuilder that stores a Vector column
|
||||
function newVectorBuilder (dim: number): FixedSizeListBuilder<Float32> {
|
||||
return makeBuilder({
|
||||
type: newVectorType(dim)
|
||||
})
|
||||
const table = makeArrowTable(data, makeTableOptions)
|
||||
return await applyEmbeddings(table, embeddings, makeTableOptions?.schema)
|
||||
}
|
||||
|
||||
// Creates the Arrow Type for a Vector column with dimension `dim`
|
||||
function newVectorType (dim: number): FixedSizeList<Float32> {
|
||||
function newVectorType <T extends Float> (dim: number, innerType: T): FixedSizeList<T> {
|
||||
// Somewhere we always default to have the elements nullable, so we need to set it to true
|
||||
// otherwise we often get schema mismatches because the stored data always has schema with nullable elements
|
||||
const children = new Field<Float32>('item', new Float32(), true)
|
||||
const children = new Field<T>('item', innerType, true)
|
||||
return new FixedSizeList(dim, children)
|
||||
}
|
||||
|
||||
// Converts an Array of records into Arrow IPC format
|
||||
/**
|
||||
* Serialize an Array of records into a buffer using the Arrow IPC File serialization
|
||||
*
|
||||
* This function will call `convertToTable` and pass on `embeddings` and `schema`
|
||||
*
|
||||
* `schema` is required if data is empty
|
||||
*/
|
||||
export async function fromRecordsToBuffer<T> (
|
||||
data: Array<Record<string, unknown>>,
|
||||
embeddings?: EmbeddingFunction<T>,
|
||||
schema?: Schema
|
||||
): Promise<Buffer> {
|
||||
let table = await convertToTable(data, embeddings)
|
||||
if (schema !== undefined) {
|
||||
table = alignTable(table, schema)
|
||||
}
|
||||
const table = await convertToTable(data, embeddings, { schema })
|
||||
const writer = RecordBatchFileWriter.writeAll(table)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
// Converts an Array of records into Arrow IPC stream format
|
||||
/**
|
||||
* Serialize an Array of records into a buffer using the Arrow IPC Stream serialization
|
||||
*
|
||||
* This function will call `convertToTable` and pass on `embeddings` and `schema`
|
||||
*
|
||||
* `schema` is required if data is empty
|
||||
*/
|
||||
export async function fromRecordsToStreamBuffer<T> (
|
||||
data: Array<Record<string, unknown>>,
|
||||
embeddings?: EmbeddingFunction<T>,
|
||||
schema?: Schema
|
||||
): Promise<Buffer> {
|
||||
let table = await convertToTable(data, embeddings)
|
||||
if (schema !== undefined) {
|
||||
table = alignTable(table, schema)
|
||||
}
|
||||
const table = await convertToTable(data, embeddings, { schema })
|
||||
const writer = RecordBatchStreamWriter.writeAll(table)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
// Converts an Arrow Table into Arrow IPC format
|
||||
/**
|
||||
* Serialize an Arrow Table into a buffer using the Arrow IPC File serialization
|
||||
*
|
||||
* This function will apply `embeddings` to the table in a manner similar to
|
||||
* `convertToTable`.
|
||||
*
|
||||
* `schema` is required if the table is empty
|
||||
*/
|
||||
export async function fromTableToBuffer<T> (
|
||||
table: ArrowTable,
|
||||
embeddings?: EmbeddingFunction<T>,
|
||||
schema?: Schema
|
||||
): Promise<Buffer> {
|
||||
if (embeddings !== undefined) {
|
||||
const source = table.getChild(embeddings.sourceColumn)
|
||||
|
||||
if (source === null) {
|
||||
throw new Error(
|
||||
`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`
|
||||
)
|
||||
}
|
||||
|
||||
const vectors = await embeddings.embed(source.toArray() as T[])
|
||||
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
|
||||
table = table.assign(new ArrowTable({ vector: column }))
|
||||
}
|
||||
if (schema !== undefined) {
|
||||
table = alignTable(table, schema)
|
||||
}
|
||||
const writer = RecordBatchFileWriter.writeAll(table)
|
||||
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema)
|
||||
const writer = RecordBatchFileWriter.writeAll(tableWithEmbeddings)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
// Converts an Arrow Table into Arrow IPC stream format
|
||||
/**
|
||||
* Serialize an Arrow Table into a buffer using the Arrow IPC Stream serialization
|
||||
*
|
||||
* This function will apply `embeddings` to the table in a manner similar to
|
||||
* `convertToTable`.
|
||||
*
|
||||
* `schema` is required if the table is empty
|
||||
*/
|
||||
export async function fromTableToStreamBuffer<T> (
|
||||
table: ArrowTable,
|
||||
embeddings?: EmbeddingFunction<T>,
|
||||
schema?: Schema
|
||||
): Promise<Buffer> {
|
||||
if (embeddings !== undefined) {
|
||||
const source = table.getChild(embeddings.sourceColumn)
|
||||
|
||||
if (source === null) {
|
||||
throw new Error(
|
||||
`The embedding source column ${embeddings.sourceColumn} was not found in the Arrow Table`
|
||||
)
|
||||
}
|
||||
|
||||
const vectors = await embeddings.embed(source.toArray() as T[])
|
||||
const column = vectorFromArray(vectors, newVectorType(vectors[0].length))
|
||||
table = table.assign(new ArrowTable({ vector: column }))
|
||||
}
|
||||
if (schema !== undefined) {
|
||||
table = alignTable(table, schema)
|
||||
}
|
||||
const writer = RecordBatchStreamWriter.writeAll(table)
|
||||
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema)
|
||||
const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings)
|
||||
return Buffer.from(await writer.toUint8Array())
|
||||
}
|
||||
|
||||
|
||||
@@ -12,18 +12,53 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { type Float } from 'apache-arrow'
|
||||
|
||||
/**
|
||||
* An embedding function that automatically creates vector representation for a given column.
|
||||
*/
|
||||
export interface EmbeddingFunction<T> {
|
||||
/**
|
||||
* The name of the column that will be used as input for the Embedding Function.
|
||||
*/
|
||||
* The name of the column that will be used as input for the Embedding Function.
|
||||
*/
|
||||
sourceColumn: string
|
||||
|
||||
/**
|
||||
* Creates a vector representation for the given values.
|
||||
*/
|
||||
* The data type of the embedding
|
||||
*
|
||||
* The embedding function should return `number`. This will be converted into
|
||||
* an Arrow float array. By default this will be Float32 but this property can
|
||||
* be used to control the conversion.
|
||||
*/
|
||||
embeddingDataType?: Float
|
||||
|
||||
/**
|
||||
* The dimension of the embedding
|
||||
*
|
||||
* This is optional, normally this can be determined by looking at the results of
|
||||
* `embed`. If this is not specified, and there is an attempt to apply the embedding
|
||||
* to an empty table, then that process will fail.
|
||||
*/
|
||||
embeddingDimension?: number
|
||||
|
||||
/**
|
||||
* The name of the column that will contain the embedding
|
||||
*
|
||||
* By default this is "vector"
|
||||
*/
|
||||
destColumn?: string
|
||||
|
||||
/**
|
||||
* Should the source column be excluded from the resulting table
|
||||
*
|
||||
* By default the source column is included. Set this to true and
|
||||
* only the embedding will be stored.
|
||||
*/
|
||||
excludeSource?: boolean
|
||||
|
||||
/**
|
||||
* Creates a vector representation for the given values.
|
||||
*/
|
||||
embed: (data: T[]) => Promise<number[][]>
|
||||
}
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ const {
|
||||
export { Query }
|
||||
export type { EmbeddingFunction }
|
||||
export { OpenAIEmbeddingFunction } from './embedding/openai'
|
||||
export { makeArrowTable, type MakeArrowTableOptions } from './arrow'
|
||||
export { convertToTable, makeArrowTable, type MakeArrowTableOptions } from './arrow'
|
||||
|
||||
const defaultAwsRegion = 'us-west-2'
|
||||
|
||||
|
||||
@@ -13,9 +13,10 @@
|
||||
// limitations under the License.
|
||||
|
||||
import { describe } from 'mocha'
|
||||
import { assert } from 'chai'
|
||||
import { assert, expect, use as chaiUse } from 'chai'
|
||||
import * as chaiAsPromised from 'chai-as-promised'
|
||||
|
||||
import { fromTableToBuffer, makeArrowTable } from '../arrow'
|
||||
import { convertToTable, fromTableToBuffer, makeArrowTable, makeEmptyTable } from '../arrow'
|
||||
import {
|
||||
Field,
|
||||
FixedSizeList,
|
||||
@@ -24,21 +25,79 @@ import {
|
||||
Int32,
|
||||
tableFromIPC,
|
||||
Schema,
|
||||
Float64
|
||||
Float64,
|
||||
type Table,
|
||||
Binary,
|
||||
Bool,
|
||||
Utf8,
|
||||
Struct,
|
||||
List,
|
||||
DataType,
|
||||
Dictionary,
|
||||
Int64
|
||||
} from 'apache-arrow'
|
||||
import { type EmbeddingFunction } from '../embedding/embedding_function'
|
||||
|
||||
describe('Apache Arrow tables', function () {
|
||||
it('customized schema', async function () {
|
||||
chaiUse(chaiAsPromised)
|
||||
|
||||
function sampleRecords (): Array<Record<string, any>> {
|
||||
return [
|
||||
{
|
||||
binary: Buffer.alloc(5),
|
||||
boolean: false,
|
||||
number: 7,
|
||||
string: 'hello',
|
||||
struct: { x: 0, y: 0 },
|
||||
list: ['anime', 'action', 'comedy']
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
// Helper method to verify various ways to create a table
|
||||
async function checkTableCreation (tableCreationMethod: (records: any, recordsReversed: any, schema: Schema) => Promise<Table>): Promise<void> {
|
||||
const records = sampleRecords()
|
||||
const recordsReversed = [{
|
||||
list: ['anime', 'action', 'comedy'],
|
||||
struct: { x: 0, y: 0 },
|
||||
string: 'hello',
|
||||
number: 7,
|
||||
boolean: false,
|
||||
binary: Buffer.alloc(5)
|
||||
}]
|
||||
const schema = new Schema([
|
||||
new Field('binary', new Binary(), false),
|
||||
new Field('boolean', new Bool(), false),
|
||||
new Field('number', new Float64(), false),
|
||||
new Field('string', new Utf8(), false),
|
||||
new Field('struct', new Struct([
|
||||
new Field('x', new Float64(), false),
|
||||
new Field('y', new Float64(), false)
|
||||
])),
|
||||
new Field('list', new List(new Field('item', new Utf8(), false)), false)
|
||||
])
|
||||
|
||||
const table = await tableCreationMethod(records, recordsReversed, schema)
|
||||
schema.fields.forEach((field, idx) => {
|
||||
const actualField = table.schema.fields[idx]
|
||||
assert.isFalse(actualField.nullable)
|
||||
assert.equal(table.getChild(field.name)?.type.toString(), field.type.toString())
|
||||
assert.equal(table.getChildAt(idx)?.type.toString(), field.type.toString())
|
||||
})
|
||||
}
|
||||
|
||||
describe('The function makeArrowTable', function () {
|
||||
it('will use data types from a provided schema instead of inference', async function () {
|
||||
const schema = new Schema([
|
||||
new Field('a', new Int32()),
|
||||
new Field('b', new Float32()),
|
||||
new Field('c', new FixedSizeList(3, new Field('item', new Float16())))
|
||||
new Field('c', new FixedSizeList(3, new Field('item', new Float16()))),
|
||||
new Field('d', new Int64())
|
||||
])
|
||||
const table = makeArrowTable(
|
||||
[
|
||||
{ a: 1, b: 2, c: [1, 2, 3] },
|
||||
{ a: 4, b: 5, c: [4, 5, 6] },
|
||||
{ a: 7, b: 8, c: [7, 8, 9] }
|
||||
{ a: 1, b: 2, c: [1, 2, 3], d: 9 },
|
||||
{ a: 4, b: 5, c: [4, 5, 6], d: 10 },
|
||||
{ a: 7, b: 8, c: [7, 8, 9], d: null }
|
||||
],
|
||||
{ schema }
|
||||
)
|
||||
@@ -52,13 +111,13 @@ describe('Apache Arrow tables', function () {
|
||||
assert.deepEqual(actualSchema, schema)
|
||||
})
|
||||
|
||||
it('default vector column', async function () {
|
||||
it('will assume the column `vector` is FixedSizeList<Float32> by default', async function () {
|
||||
const schema = new Schema([
|
||||
new Field('a', new Float64()),
|
||||
new Field('b', new Float64()),
|
||||
new Field(
|
||||
'vector',
|
||||
new FixedSizeList(3, new Field('item', new Float32()))
|
||||
new FixedSizeList(3, new Field('item', new Float32(), true))
|
||||
)
|
||||
])
|
||||
const table = makeArrowTable([
|
||||
@@ -76,12 +135,12 @@ describe('Apache Arrow tables', function () {
|
||||
assert.deepEqual(actualSchema, schema)
|
||||
})
|
||||
|
||||
it('2 vector columns', async function () {
|
||||
it('can support multiple vector columns', async function () {
|
||||
const schema = new Schema([
|
||||
new Field('a', new Float64()),
|
||||
new Field('b', new Float64()),
|
||||
new Field('vec1', new FixedSizeList(3, new Field('item', new Float16()))),
|
||||
new Field('vec2', new FixedSizeList(3, new Field('item', new Float16())))
|
||||
new Field('vec1', new FixedSizeList(3, new Field('item', new Float16(), true))),
|
||||
new Field('vec2', new FixedSizeList(3, new Field('item', new Float16(), true)))
|
||||
])
|
||||
const table = makeArrowTable(
|
||||
[
|
||||
@@ -105,4 +164,157 @@ describe('Apache Arrow tables', function () {
|
||||
const actualSchema = actual.schema
|
||||
assert.deepEqual(actualSchema, schema)
|
||||
})
|
||||
|
||||
it('will allow different vector column types', async function () {
|
||||
const table = makeArrowTable(
|
||||
[
|
||||
{ fp16: [1], fp32: [1], fp64: [1] }
|
||||
],
|
||||
{
|
||||
vectorColumns: {
|
||||
fp16: { type: new Float16() },
|
||||
fp32: { type: new Float32() },
|
||||
fp64: { type: new Float64() }
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
assert.equal(table.getChild('fp16')?.type.children[0].type.toString(), new Float16().toString())
|
||||
assert.equal(table.getChild('fp32')?.type.children[0].type.toString(), new Float32().toString())
|
||||
assert.equal(table.getChild('fp64')?.type.children[0].type.toString(), new Float64().toString())
|
||||
})
|
||||
|
||||
it('will use dictionary encoded strings if asked', async function () {
|
||||
const table = makeArrowTable([{ str: 'hello' }])
|
||||
assert.isTrue(DataType.isUtf8(table.getChild('str')?.type))
|
||||
|
||||
const tableWithDict = makeArrowTable([{ str: 'hello' }], { dictionaryEncodeStrings: true })
|
||||
assert.isTrue(DataType.isDictionary(tableWithDict.getChild('str')?.type))
|
||||
|
||||
const schema = new Schema([
|
||||
new Field('str', new Dictionary(new Utf8(), new Int32()))
|
||||
])
|
||||
|
||||
const tableWithDict2 = makeArrowTable([{ str: 'hello' }], { schema })
|
||||
assert.isTrue(DataType.isDictionary(tableWithDict2.getChild('str')?.type))
|
||||
})
|
||||
|
||||
it('will infer data types correctly', async function () {
|
||||
await checkTableCreation(async (records) => makeArrowTable(records))
|
||||
})
|
||||
|
||||
it('will allow a schema to be provided', async function () {
|
||||
await checkTableCreation(async (records, _, schema) => makeArrowTable(records, { schema }))
|
||||
})
|
||||
|
||||
it('will use the field order of any provided schema', async function () {
|
||||
await checkTableCreation(async (_, recordsReversed, schema) => makeArrowTable(recordsReversed, { schema }))
|
||||
})
|
||||
|
||||
it('will make an empty table', async function () {
|
||||
await checkTableCreation(async (_, __, schema) => makeArrowTable([], { schema }))
|
||||
})
|
||||
})
|
||||
|
||||
class DummyEmbedding implements EmbeddingFunction<string> {
|
||||
public readonly sourceColumn = 'string'
|
||||
public readonly embeddingDimension = 2
|
||||
public readonly embeddingDataType = new Float16()
|
||||
|
||||
async embed (data: string[]): Promise<number[][]> {
|
||||
return data.map(
|
||||
() => [0.0, 0.0]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
class DummyEmbeddingWithNoDimension implements EmbeddingFunction<string> {
|
||||
public readonly sourceColumn = 'string'
|
||||
|
||||
async embed (data: string[]): Promise<number[][]> {
|
||||
return data.map(
|
||||
() => [0.0, 0.0]
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
describe('convertToTable', function () {
|
||||
it('will infer data types correctly', async function () {
|
||||
await checkTableCreation(async (records) => await convertToTable(records))
|
||||
})
|
||||
|
||||
it('will allow a schema to be provided', async function () {
|
||||
await checkTableCreation(async (records, _, schema) => await convertToTable(records, undefined, { schema }))
|
||||
})
|
||||
|
||||
it('will use the field order of any provided schema', async function () {
|
||||
await checkTableCreation(async (_, recordsReversed, schema) => await convertToTable(recordsReversed, undefined, { schema }))
|
||||
})
|
||||
|
||||
it('will make an empty table', async function () {
|
||||
await checkTableCreation(async (_, __, schema) => await convertToTable([], undefined, { schema }))
|
||||
})
|
||||
|
||||
it('will apply embeddings', async function () {
|
||||
const records = sampleRecords()
|
||||
const table = await convertToTable(records, new DummyEmbedding())
|
||||
assert.isTrue(DataType.isFixedSizeList(table.getChild('vector')?.type))
|
||||
assert.equal(table.getChild('vector')?.type.children[0].type.toString(), new Float16().toString())
|
||||
})
|
||||
|
||||
it('will fail if missing the embedding source column', async function () {
|
||||
return await expect(convertToTable([{ id: 1 }], new DummyEmbedding())).to.be.rejectedWith("'string' was not present")
|
||||
})
|
||||
|
||||
it('use embeddingDimension if embedding missing from table', async function () {
|
||||
const schema = new Schema([
|
||||
new Field('string', new Utf8(), false)
|
||||
])
|
||||
// Simulate getting an empty Arrow table (minus embedding) from some other source
|
||||
// In other words, we aren't starting with records
|
||||
const table = makeEmptyTable(schema)
|
||||
|
||||
// If the embedding specifies the dimension we are fine
|
||||
await fromTableToBuffer(table, new DummyEmbedding())
|
||||
|
||||
// We can also supply a schema and should be ok
|
||||
const schemaWithEmbedding = new Schema([
|
||||
new Field('string', new Utf8(), false),
|
||||
new Field('vector', new FixedSizeList(2, new Field('item', new Float16(), false)), false)
|
||||
])
|
||||
await fromTableToBuffer(table, new DummyEmbeddingWithNoDimension(), schemaWithEmbedding)
|
||||
|
||||
// Otherwise we will get an error
|
||||
return await expect(fromTableToBuffer(table, new DummyEmbeddingWithNoDimension())).to.be.rejectedWith('does not specify `embeddingDimension`')
|
||||
})
|
||||
|
||||
it('will apply embeddings to an empty table', async function () {
|
||||
const schema = new Schema([
|
||||
new Field('string', new Utf8(), false),
|
||||
new Field('vector', new FixedSizeList(2, new Field('item', new Float16(), false)), false)
|
||||
])
|
||||
const table = await convertToTable([], new DummyEmbedding(), { schema })
|
||||
assert.isTrue(DataType.isFixedSizeList(table.getChild('vector')?.type))
|
||||
assert.equal(table.getChild('vector')?.type.children[0].type.toString(), new Float16().toString())
|
||||
})
|
||||
|
||||
it('will complain if embeddings present but schema missing embedding column', async function () {
|
||||
const schema = new Schema([
|
||||
new Field('string', new Utf8(), false)
|
||||
])
|
||||
return await expect(convertToTable([], new DummyEmbedding(), { schema })).to.be.rejectedWith('column vector was missing')
|
||||
})
|
||||
|
||||
it('will provide a nice error if run twice', async function () {
|
||||
const records = sampleRecords()
|
||||
const table = await convertToTable(records, new DummyEmbedding())
|
||||
// fromTableToBuffer will try and apply the embeddings again
|
||||
return await expect(fromTableToBuffer(table, new DummyEmbedding())).to.be.rejectedWith('already existed')
|
||||
})
|
||||
})
|
||||
|
||||
describe('makeEmptyTable', function () {
|
||||
it('will make an empty table', async function () {
|
||||
await checkTableCreation(async (_, __, schema) => makeEmptyTable(schema))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -9,6 +9,6 @@
|
||||
"declaration": true,
|
||||
"outDir": "./dist",
|
||||
"strict": true,
|
||||
// "esModuleInterop": true,
|
||||
"sourceMap": true,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user