fix(nodejs): add better error handling when missing embedding functions (#1290)

note: 
running the default lint command `npm run lint -- --fix` seems to have
made a lot of unrelated changes.
This commit is contained in:
Cory Grinstead
2024-05-14 08:43:39 -05:00
committed by GitHub
parent df9c41f342
commit bc582bb702
5 changed files with 1242 additions and 1017 deletions

View File

@@ -27,23 +27,23 @@ import {
RecordBatch,
makeData,
Struct,
Float,
type Float,
DataType,
Binary,
Float32
} from 'apache-arrow'
import { type EmbeddingFunction } from './index'
import { sanitizeSchema } from './sanitize'
} from "apache-arrow";
import { type EmbeddingFunction } from "./index";
import { sanitizeSchema } from "./sanitize";
/*
* Options to control how a column should be converted to a vector array
*/
export class VectorColumnOptions {
/** Vector column type. */
type: Float = new Float32()
type: Float = new Float32();
constructor (values?: Partial<VectorColumnOptions>) {
Object.assign(this, values)
constructor(values?: Partial<VectorColumnOptions>) {
Object.assign(this, values);
}
}
@@ -60,7 +60,7 @@ export class MakeArrowTableOptions {
* The schema must be specified if there are no records (e.g. to make
* an empty table)
*/
schema?: Schema
schema?: Schema;
/*
* Mapping from vector column name to expected type
@@ -80,7 +80,9 @@ export class MakeArrowTableOptions {
*/
vectorColumns: Record<string, VectorColumnOptions> = {
vector: new VectorColumnOptions()
}
};
embeddings?: EmbeddingFunction<any>;
/**
* If true then string columns will be encoded with dictionary encoding
@@ -91,10 +93,10 @@ export class MakeArrowTableOptions {
*
* If `schema` is provided then this property is ignored.
*/
dictionaryEncodeStrings: boolean = false
dictionaryEncodeStrings: boolean = false;
constructor (values?: Partial<MakeArrowTableOptions>) {
Object.assign(this, values)
constructor(values?: Partial<MakeArrowTableOptions>) {
Object.assign(this, values);
}
}
@@ -193,59 +195,68 @@ export class MakeArrowTableOptions {
* assert.deepEqual(table.schema, schema)
* ```
*/
export function makeArrowTable (
export function makeArrowTable(
data: Array<Record<string, any>>,
options?: Partial<MakeArrowTableOptions>
): ArrowTable {
if (data.length === 0 && (options?.schema === undefined || options?.schema === null)) {
throw new Error('At least one record or a schema needs to be provided')
if (
data.length === 0 &&
(options?.schema === undefined || options?.schema === null)
) {
throw new Error("At least one record or a schema needs to be provided");
}
const opt = new MakeArrowTableOptions(options !== undefined ? options : {})
const opt = new MakeArrowTableOptions(options !== undefined ? options : {});
if (opt.schema !== undefined && opt.schema !== null) {
opt.schema = sanitizeSchema(opt.schema)
opt.schema = sanitizeSchema(opt.schema);
opt.schema = validateSchemaEmbeddings(opt.schema, data, opt.embeddings);
}
const columns: Record<string, Vector> = {}
const columns: Record<string, Vector> = {};
// TODO: sample dataset to find missing columns
// Prefer the field ordering of the schema, if present
const columnNames = ((opt.schema) != null) ? (opt.schema.names as string[]) : Object.keys(data[0])
const columnNames =
opt.schema != null ? (opt.schema.names as string[]) : Object.keys(data[0]);
for (const colName of columnNames) {
if (data.length !== 0 && !Object.prototype.hasOwnProperty.call(data[0], colName)) {
if (
data.length !== 0 &&
!Object.prototype.hasOwnProperty.call(data[0], colName)
) {
// The field is present in the schema, but not in the data, skip it
continue
continue;
}
// Extract a single column from the records (transpose from row-major to col-major)
let values = data.map((datum) => datum[colName])
let values = data.map((datum) => datum[colName]);
// By default (type === undefined) arrow will infer the type from the JS type
let type
let type;
if (opt.schema !== undefined) {
// If there is a schema provided, then use that for the type instead
type = opt.schema?.fields.filter((f) => f.name === colName)[0]?.type
type = opt.schema?.fields.filter((f) => f.name === colName)[0]?.type;
if (DataType.isInt(type) && type.bitWidth === 64) {
// wrap in BigInt to avoid bug: https://github.com/apache/arrow/issues/40051
values = values.map((v) => {
if (v === null) {
return v
return v;
}
return BigInt(v)
})
return BigInt(v);
});
}
} else {
// Otherwise, check to see if this column is one of the vector columns
// defined by opt.vectorColumns and, if so, use the fixed size list type
const vectorColumnOptions = opt.vectorColumns[colName]
const vectorColumnOptions = opt.vectorColumns[colName];
if (vectorColumnOptions !== undefined) {
type = newVectorType(values[0].length, vectorColumnOptions.type)
type = newVectorType(values[0].length, vectorColumnOptions.type);
}
}
try {
// Convert an Array of JS values to an arrow vector
columns[colName] = makeVector(values, type, opt.dictionaryEncodeStrings)
columns[colName] = makeVector(values, type, opt.dictionaryEncodeStrings);
} catch (error: unknown) {
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
throw Error(`Could not convert column "${colName}" to Arrow: ${error}`)
throw Error(`Could not convert column "${colName}" to Arrow: ${error}`);
}
}
@@ -260,97 +271,116 @@ export function makeArrowTable (
// To work around this we first create a table with the wrong schema and
// then patch the schema of the batches so we can use
// `new ArrowTable(schema, batches)` which does not do any schema inference
const firstTable = new ArrowTable(columns)
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const batchesFixed = firstTable.batches.map(batch => new RecordBatch(opt.schema!, batch.data))
return new ArrowTable(opt.schema, batchesFixed)
const firstTable = new ArrowTable(columns);
const batchesFixed = firstTable.batches.map(
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
(batch) => new RecordBatch(opt.schema!, batch.data)
);
return new ArrowTable(opt.schema, batchesFixed);
} else {
return new ArrowTable(columns)
return new ArrowTable(columns);
}
}
/**
* Create an empty Arrow table with the provided schema
*/
export function makeEmptyTable (schema: Schema): ArrowTable {
return makeArrowTable([], { schema })
export function makeEmptyTable(schema: Schema): ArrowTable {
return makeArrowTable([], { schema });
}
// Helper function to convert Array<Array<any>> to a variable sized list array
function makeListVector (lists: any[][]): Vector<any> {
function makeListVector(lists: any[][]): Vector<any> {
if (lists.length === 0 || lists[0].length === 0) {
throw Error('Cannot infer list vector from empty array or empty list')
throw Error("Cannot infer list vector from empty array or empty list");
}
const sampleList = lists[0]
let inferredType
const sampleList = lists[0];
let inferredType;
try {
const sampleVector = makeVector(sampleList)
inferredType = sampleVector.type
const sampleVector = makeVector(sampleList);
inferredType = sampleVector.type;
} catch (error: unknown) {
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
throw Error(`Cannot infer list vector. Cannot infer inner type: ${error}`)
throw Error(`Cannot infer list vector. Cannot infer inner type: ${error}`);
}
const listBuilder = makeBuilder({
type: new List(new Field('item', inferredType, true))
})
type: new List(new Field("item", inferredType, true))
});
for (const list of lists) {
listBuilder.append(list)
listBuilder.append(list);
}
return listBuilder.finish().toVector()
return listBuilder.finish().toVector();
}
// Helper function to convert an Array of JS values to an Arrow Vector
function makeVector (values: any[], type?: DataType, stringAsDictionary?: boolean): Vector<any> {
function makeVector(
values: any[],
type?: DataType,
stringAsDictionary?: boolean
): Vector<any> {
if (type !== undefined) {
// No need for inference, let Arrow create it
return vectorFromArray(values, type)
return vectorFromArray(values, type);
}
if (values.length === 0) {
throw Error('makeVector requires at least one value or the type must be specfied')
throw Error(
"makeVector requires at least one value or the type must be specfied"
);
}
const sampleValue = values.find(val => val !== null && val !== undefined)
const sampleValue = values.find((val) => val !== null && val !== undefined);
if (sampleValue === undefined) {
throw Error('makeVector cannot infer the type if all values are null or undefined')
throw Error(
"makeVector cannot infer the type if all values are null or undefined"
);
}
if (Array.isArray(sampleValue)) {
// Default Arrow inference doesn't handle list types
return makeListVector(values)
return makeListVector(values);
} else if (Buffer.isBuffer(sampleValue)) {
// Default Arrow inference doesn't handle Buffer
return vectorFromArray(values, new Binary())
} else if (!(stringAsDictionary ?? false) && (typeof sampleValue === 'string' || sampleValue instanceof String)) {
return vectorFromArray(values, new Binary());
} else if (
!(stringAsDictionary ?? false) &&
(typeof sampleValue === "string" || sampleValue instanceof String)
) {
// If the type is string then don't use Arrow's default inference unless dictionaries are requested
// because it will always use dictionary encoding for strings
return vectorFromArray(values, new Utf8())
return vectorFromArray(values, new Utf8());
} else {
// Convert a JS array of values to an arrow vector
return vectorFromArray(values)
return vectorFromArray(values);
}
}
async function applyEmbeddings<T> (table: ArrowTable, embeddings?: EmbeddingFunction<T>, schema?: Schema): Promise<ArrowTable> {
async function applyEmbeddings<T>(
table: ArrowTable,
embeddings?: EmbeddingFunction<T>,
schema?: Schema
): Promise<ArrowTable> {
if (embeddings == null) {
return table
return table;
}
if (schema !== undefined && schema !== null) {
schema = sanitizeSchema(schema)
schema = sanitizeSchema(schema);
}
// Convert from ArrowTable to Record<String, Vector>
const colEntries = [...Array(table.numCols).keys()].map((_, idx) => {
const name = table.schema.fields[idx].name
const name = table.schema.fields[idx].name;
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const vec = table.getChildAt(idx)!
return [name, vec]
})
const newColumns = Object.fromEntries(colEntries)
const vec = table.getChildAt(idx)!;
return [name, vec];
});
const newColumns = Object.fromEntries(colEntries);
const sourceColumn = newColumns[embeddings.sourceColumn]
const destColumn = embeddings.destColumn ?? 'vector'
const innerDestType = embeddings.embeddingDataType ?? new Float32()
const sourceColumn = newColumns[embeddings.sourceColumn];
const destColumn = embeddings.destColumn ?? "vector";
const innerDestType = embeddings.embeddingDataType ?? new Float32();
if (sourceColumn === undefined) {
throw new Error(`Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`)
throw new Error(
`Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`
);
}
if (table.numRows === 0) {
@@ -358,45 +388,60 @@ async function applyEmbeddings<T> (table: ArrowTable, embeddings?: EmbeddingFunc
// We have an empty table and it already has the embedding column so no work needs to be done
// Note: we don't return an error like we did below because this is a common occurrence. For example,
// if we call convertToTable with 0 records and a schema that includes the embedding
return table
return table;
}
if (embeddings.embeddingDimension !== undefined) {
const destType = newVectorType(embeddings.embeddingDimension, innerDestType)
newColumns[destColumn] = makeVector([], destType)
const destType = newVectorType(
embeddings.embeddingDimension,
innerDestType
);
newColumns[destColumn] = makeVector([], destType);
} else if (schema != null) {
const destField = schema.fields.find(f => f.name === destColumn)
const destField = schema.fields.find((f) => f.name === destColumn);
if (destField != null) {
newColumns[destColumn] = makeVector([], destField.type)
newColumns[destColumn] = makeVector([], destField.type);
} else {
throw new Error(`Attempt to apply embeddings to an empty table failed because schema was missing embedding column '${destColumn}'`)
throw new Error(
`Attempt to apply embeddings to an empty table failed because schema was missing embedding column '${destColumn}'`
);
}
} else {
throw new Error('Attempt to apply embeddings to an empty table when the embeddings function does not specify `embeddingDimension`')
throw new Error(
"Attempt to apply embeddings to an empty table when the embeddings function does not specify `embeddingDimension`"
);
}
} else {
if (Object.prototype.hasOwnProperty.call(newColumns, destColumn)) {
throw new Error(`Attempt to apply embeddings to table failed because column ${destColumn} already existed`)
throw new Error(
`Attempt to apply embeddings to table failed because column ${destColumn} already existed`
);
}
if (table.batches.length > 1) {
throw new Error('Internal error: `makeArrowTable` unexpectedly created a table with more than one batch')
throw new Error(
"Internal error: `makeArrowTable` unexpectedly created a table with more than one batch"
);
}
const values = sourceColumn.toArray()
const vectors = await embeddings.embed(values as T[])
const values = sourceColumn.toArray();
const vectors = await embeddings.embed(values as T[]);
if (vectors.length !== values.length) {
throw new Error('Embedding function did not return an embedding for each input element')
throw new Error(
"Embedding function did not return an embedding for each input element"
);
}
const destType = newVectorType(vectors[0].length, innerDestType)
newColumns[destColumn] = makeVector(vectors, destType)
const destType = newVectorType(vectors[0].length, innerDestType);
newColumns[destColumn] = makeVector(vectors, destType);
}
const newTable = new ArrowTable(newColumns)
const newTable = new ArrowTable(newColumns);
if (schema != null) {
if (schema.fields.find(f => f.name === destColumn) === undefined) {
throw new Error(`When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`)
if (schema.fields.find((f) => f.name === destColumn) === undefined) {
throw new Error(
`When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`
);
}
return alignTable(newTable, schema)
return alignTable(newTable, schema);
}
return newTable
return newTable;
}
/*
@@ -417,21 +462,24 @@ async function applyEmbeddings<T> (table: ArrowTable, embeddings?: EmbeddingFunc
* embedding columns. If no schema is provded then embedding columns will
* be placed at the end of the table, after all of the input columns.
*/
export async function convertToTable<T> (
export async function convertToTable<T>(
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>,
makeTableOptions?: Partial<MakeArrowTableOptions>
): Promise<ArrowTable> {
const table = makeArrowTable(data, makeTableOptions)
return await applyEmbeddings(table, embeddings, makeTableOptions?.schema)
const table = makeArrowTable(data, makeTableOptions);
return await applyEmbeddings(table, embeddings, makeTableOptions?.schema);
}
// Creates the Arrow Type for a Vector column with dimension `dim`
function newVectorType <T extends Float> (dim: number, innerType: T): FixedSizeList<T> {
function newVectorType<T extends Float>(
dim: number,
innerType: T
): FixedSizeList<T> {
// Somewhere we always default to have the elements nullable, so we need to set it to true
// otherwise we often get schema mismatches because the stored data always has schema with nullable elements
const children = new Field<T>('item', innerType, true)
return new FixedSizeList(dim, children)
const children = new Field<T>("item", innerType, true);
return new FixedSizeList(dim, children);
}
/**
@@ -441,17 +489,17 @@ function newVectorType <T extends Float> (dim: number, innerType: T): FixedSizeL
*
* `schema` is required if data is empty
*/
export async function fromRecordsToBuffer<T> (
export async function fromRecordsToBuffer<T>(
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>,
schema?: Schema
): Promise<Buffer> {
if (schema !== undefined && schema !== null) {
schema = sanitizeSchema(schema)
schema = sanitizeSchema(schema);
}
const table = await convertToTable(data, embeddings, { schema })
const writer = RecordBatchFileWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array())
const table = await convertToTable(data, embeddings, { schema, embeddings });
const writer = RecordBatchFileWriter.writeAll(table);
return Buffer.from(await writer.toUint8Array());
}
/**
@@ -461,17 +509,17 @@ export async function fromRecordsToBuffer<T> (
*
* `schema` is required if data is empty
*/
export async function fromRecordsToStreamBuffer<T> (
export async function fromRecordsToStreamBuffer<T>(
data: Array<Record<string, unknown>>,
embeddings?: EmbeddingFunction<T>,
schema?: Schema
): Promise<Buffer> {
if (schema !== null && schema !== undefined) {
schema = sanitizeSchema(schema)
schema = sanitizeSchema(schema);
}
const table = await convertToTable(data, embeddings, { schema })
const writer = RecordBatchStreamWriter.writeAll(table)
return Buffer.from(await writer.toUint8Array())
const table = await convertToTable(data, embeddings, { schema });
const writer = RecordBatchStreamWriter.writeAll(table);
return Buffer.from(await writer.toUint8Array());
}
/**
@@ -482,17 +530,17 @@ export async function fromRecordsToStreamBuffer<T> (
*
* `schema` is required if the table is empty
*/
export async function fromTableToBuffer<T> (
export async function fromTableToBuffer<T>(
table: ArrowTable,
embeddings?: EmbeddingFunction<T>,
schema?: Schema
): Promise<Buffer> {
if (schema !== null && schema !== undefined) {
schema = sanitizeSchema(schema)
schema = sanitizeSchema(schema);
}
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema)
const writer = RecordBatchFileWriter.writeAll(tableWithEmbeddings)
return Buffer.from(await writer.toUint8Array())
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema);
const writer = RecordBatchFileWriter.writeAll(tableWithEmbeddings);
return Buffer.from(await writer.toUint8Array());
}
/**
@@ -503,49 +551,87 @@ export async function fromTableToBuffer<T> (
*
* `schema` is required if the table is empty
*/
export async function fromTableToStreamBuffer<T> (
export async function fromTableToStreamBuffer<T>(
table: ArrowTable,
embeddings?: EmbeddingFunction<T>,
schema?: Schema
): Promise<Buffer> {
if (schema !== null && schema !== undefined) {
schema = sanitizeSchema(schema)
schema = sanitizeSchema(schema);
}
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema)
const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings)
return Buffer.from(await writer.toUint8Array())
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema);
const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings);
return Buffer.from(await writer.toUint8Array());
}
function alignBatch (batch: RecordBatch, schema: Schema): RecordBatch {
const alignedChildren = []
function alignBatch(batch: RecordBatch, schema: Schema): RecordBatch {
const alignedChildren = [];
for (const field of schema.fields) {
const indexInBatch = batch.schema.fields?.findIndex(
(f) => f.name === field.name
)
);
if (indexInBatch < 0) {
throw new Error(
`The column ${field.name} was not found in the Arrow Table`
)
);
}
alignedChildren.push(batch.data.children[indexInBatch])
alignedChildren.push(batch.data.children[indexInBatch]);
}
const newData = makeData({
type: new Struct(schema.fields),
length: batch.numRows,
nullCount: batch.nullCount,
children: alignedChildren
})
return new RecordBatch(schema, newData)
});
return new RecordBatch(schema, newData);
}
function alignTable (table: ArrowTable, schema: Schema): ArrowTable {
function alignTable(table: ArrowTable, schema: Schema): ArrowTable {
const alignedBatches = table.batches.map((batch) =>
alignBatch(batch, schema)
)
return new ArrowTable(schema, alignedBatches)
);
return new ArrowTable(schema, alignedBatches);
}
// Creates an empty Arrow Table
export function createEmptyTable (schema: Schema): ArrowTable {
return new ArrowTable(sanitizeSchema(schema))
export function createEmptyTable(schema: Schema): ArrowTable {
return new ArrowTable(sanitizeSchema(schema));
}
function validateSchemaEmbeddings(
schema: Schema<any>,
data: Array<Record<string, unknown>>,
embeddings: EmbeddingFunction<any> | undefined
) {
const fields = [];
const missingEmbeddingFields = [];
// First we check if the field is a `FixedSizeList`
// Then we check if the data contains the field
// if it does not, we add it to the list of missing embedding fields
// Finally, we check if those missing embedding fields are `this._embeddings`
// if they are not, we throw an error
for (const field of schema.fields) {
if (field.type instanceof FixedSizeList) {
if (data.length !== 0 && data?.[0]?.[field.name] === undefined) {
missingEmbeddingFields.push(field);
} else {
fields.push(field);
}
} else {
fields.push(field);
}
}
if (missingEmbeddingFields.length > 0 && embeddings === undefined) {
console.log({ missingEmbeddingFields, embeddings });
throw new Error(
`Table has embeddings: "${missingEmbeddingFields
.map((f) => f.name)
.join(",")}", but no embedding function was provided`
);
}
return new Schema(fields);
}