fix: handle input with missing columns when using embedding functions (#2516)

## Summary

Fixes #2515 by implementing comprehensive support for missing columns in
Arrow table inputs when using embedding functions.

### Problem
Previously, when an Arrow table was passed to `fromDataToBuffer` with
missing columns and a schema containing embedding functions, the system
would fail because `applyEmbeddingsFromMetadata` expected all columns to
be present in the table.

🤖 Generated with [Claude Code](https://claude.ai/code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Will Jones
2025-07-18 15:54:25 -07:00
committed by GitHub
parent b3a637fdeb
commit 88283110f4
2 changed files with 386 additions and 13 deletions

View File

@@ -1,7 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
import { Schema } from "apache-arrow";
import { Bool, Field, Int32, List, Schema, Struct, Utf8 } from "apache-arrow";
import * as arrow15 from "apache-arrow-15";
import * as arrow16 from "apache-arrow-16";
@@ -11,10 +11,12 @@ import * as arrow18 from "apache-arrow-18";
import {
convertToTable,
fromBufferToRecordBatch,
fromDataToBuffer,
fromRecordBatchToBuffer,
fromTableToBuffer,
makeArrowTable,
makeEmptyTable,
tableFromIPC,
} from "../lancedb/arrow";
import {
EmbeddingFunction,
@@ -375,8 +377,221 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
expect(table2.schema).toEqual(schema);
});
it("will handle missing columns in schema alignment when using embeddings", async function () {
const schema = new Schema(
[
new Field("domain", new Utf8(), true),
new Field("name", new Utf8(), true),
new Field("description", new Utf8(), true),
],
new Map([["embedding_functions", JSON.stringify([])]]),
);
const data = [
{ domain: "google.com", name: "Google" },
{ domain: "facebook.com", name: "Facebook" },
];
const table = await convertToTable(data, undefined, { schema });
expect(table.numCols).toBe(3);
expect(table.numRows).toBe(2);
const descriptionColumn = table.getChild("description");
expect(descriptionColumn).toBeDefined();
expect(descriptionColumn?.nullCount).toBe(2);
expect(descriptionColumn?.toArray()).toEqual([null, null]);
expect(table.getChild("domain")?.toArray()).toEqual([
"google.com",
"facebook.com",
]);
expect(table.getChild("name")?.toArray()).toEqual([
"Google",
"Facebook",
]);
});
it("will handle completely missing nested struct columns", async function () {
const schema = new Schema(
[
new Field("id", new Utf8(), true),
new Field("name", new Utf8(), true),
new Field(
"metadata",
new Struct([
new Field("version", new Int32(), true),
new Field("author", new Utf8(), true),
new Field(
"tags",
new List(new Field("item", new Utf8(), true)),
true,
),
]),
true,
),
],
new Map([["embedding_functions", JSON.stringify([])]]),
);
const data = [
{ id: "doc1", name: "Document 1" },
{ id: "doc2", name: "Document 2" },
];
const table = await convertToTable(data, undefined, { schema });
expect(table.numCols).toBe(3);
expect(table.numRows).toBe(2);
const buf = await fromTableToBuffer(table);
const retrievedTable = tableFromIPC(buf);
const rows = [];
for (let i = 0; i < retrievedTable.numRows; i++) {
rows.push(retrievedTable.get(i));
}
expect(rows[0].metadata.version).toBe(null);
expect(rows[0].metadata.author).toBe(null);
expect(rows[0].metadata.tags).toBe(null);
expect(rows[0].id).toBe("doc1");
expect(rows[0].name).toBe("Document 1");
});
it("will handle partially missing nested struct fields", async function () {
const schema = new Schema(
[
new Field("id", new Utf8(), true),
new Field(
"metadata",
new Struct([
new Field("version", new Int32(), true),
new Field("author", new Utf8(), true),
new Field("created_at", new Utf8(), true),
]),
true,
),
],
new Map([["embedding_functions", JSON.stringify([])]]),
);
const data = [
{ id: "doc1", metadata: { version: 1, author: "Alice" } },
{ id: "doc2", metadata: { version: 2 } },
];
const table = await convertToTable(data, undefined, { schema });
expect(table.numCols).toBe(2);
expect(table.numRows).toBe(2);
const metadataColumn = table.getChild("metadata");
expect(metadataColumn).toBeDefined();
expect(metadataColumn?.type.toString()).toBe(
"Struct<{version:Int32, author:Utf8, created_at:Utf8}>",
);
});
it("will handle multiple levels of nested structures", async function () {
const schema = new Schema(
[
new Field("id", new Utf8(), true),
new Field(
"config",
new Struct([
new Field("database", new Utf8(), true),
new Field(
"connection",
new Struct([
new Field("host", new Utf8(), true),
new Field("port", new Int32(), true),
new Field(
"ssl",
new Struct([
new Field("enabled", new Bool(), true),
new Field("cert_path", new Utf8(), true),
]),
true,
),
]),
true,
),
]),
true,
),
],
new Map([["embedding_functions", JSON.stringify([])]]),
);
const data = [
{
id: "config1",
config: {
database: "postgres",
connection: { host: "localhost" },
},
},
{
id: "config2",
config: { database: "mysql" },
},
{
id: "config3",
},
];
const table = await convertToTable(data, undefined, { schema });
expect(table.numCols).toBe(2);
expect(table.numRows).toBe(3);
const configColumn = table.getChild("config");
expect(configColumn).toBeDefined();
expect(configColumn?.type.toString()).toBe(
"Struct<{database:Utf8, connection:Struct<{host:Utf8, port:Int32, ssl:Struct<{enabled:Bool, cert_path:Utf8}>}>}>",
);
});
it("will handle missing columns in Arrow table input when using embeddings", async function () {
const incompleteTable = makeArrowTable([
{ domain: "google.com", name: "Google" },
{ domain: "facebook.com", name: "Facebook" },
]);
const schema = new Schema(
[
new Field("domain", new Utf8(), true),
new Field("name", new Utf8(), true),
new Field("description", new Utf8(), true),
],
new Map([["embedding_functions", JSON.stringify([])]]),
);
const buf = await fromDataToBuffer(incompleteTable, undefined, schema);
expect(buf.byteLength).toBeGreaterThan(0);
const retrievedTable = tableFromIPC(buf);
expect(retrievedTable.numCols).toBe(3);
expect(retrievedTable.numRows).toBe(2);
const descriptionColumn = retrievedTable.getChild("description");
expect(descriptionColumn).toBeDefined();
expect(descriptionColumn?.nullCount).toBe(2);
expect(descriptionColumn?.toArray()).toEqual([null, null]);
expect(retrievedTable.getChild("domain")?.toArray()).toEqual([
"google.com",
"facebook.com",
]);
expect(retrievedTable.getChild("name")?.toArray()).toEqual([
"Google",
"Facebook",
]);
});
it("should correctly retain values in nested struct fields", async function () {
// Define test data with nested struct
const testData = [
{
id: "doc1",
@@ -400,10 +615,8 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
},
];
// Create Arrow table from the data
const table = makeArrowTable(testData);
// Verify schema has the nested struct fields
const metadataField = table.schema.fields.find(
(f) => f.name === "metadata",
);
@@ -417,23 +630,17 @@ describe.each([arrow15, arrow16, arrow17, arrow18])(
"text",
]);
// Convert to buffer and back (simulating storage and retrieval)
const buf = await fromTableToBuffer(table);
const retrievedTable = tableFromIPC(buf);
// Verify the retrieved table has the same structure
const rows = [];
for (let i = 0; i < retrievedTable.numRows; i++) {
rows.push(retrievedTable.get(i));
}
// Check values in the first row
const firstRow = rows[0];
expect(firstRow.id).toBe("doc1");
expect(firstRow.vector.toJSON()).toEqual([1, 2, 3]);
// Verify metadata values are preserved (this is where the bug is)
expect(firstRow.metadata).toBeDefined();
expect(firstRow.metadata.filePath).toBe("/path/to/file1.ts");
expect(firstRow.metadata.startLine).toBe(10);
expect(firstRow.metadata.endLine).toBe(20);

View File

@@ -839,6 +839,15 @@ async function applyEmbeddingsFromMetadata(
const vector = makeVector(vectors, destType);
columns[destColumn] = vector;
}
// Add any missing columns from the schema as null vectors
for (const field of schema.fields) {
if (!(field.name in columns)) {
const nullValues = new Array(table.numRows).fill(null);
columns[field.name] = makeVector(nullValues, field.type);
}
}
const newTable = new ArrowTable(columns);
return alignTable(newTable, schema);
}
@@ -987,7 +996,21 @@ export async function convertToTable(
embeddings?: EmbeddingFunctionConfig,
makeTableOptions?: Partial<MakeArrowTableOptions>,
): Promise<ArrowTable> {
const table = makeArrowTable(data, makeTableOptions);
let processedData = data;
// If we have a schema with embedding metadata, we need to preprocess the data
// to ensure all nested fields are present
if (
makeTableOptions?.schema &&
makeTableOptions.schema.metadata?.has("embedding_functions")
) {
processedData = ensureNestedFieldsExist(
data,
makeTableOptions.schema as Schema,
);
}
const table = makeArrowTable(processedData, makeTableOptions);
return await applyEmbeddings(table, embeddings, makeTableOptions?.schema);
}
@@ -1080,7 +1103,16 @@ export async function fromDataToBuffer(
schema = sanitizeSchema(schema);
}
if (isArrowTable(data)) {
return fromTableToBuffer(sanitizeTable(data), embeddings, schema);
const table = sanitizeTable(data);
// If we have a schema with embedding functions, we need to ensure all columns exist
// before applying embeddings, since applyEmbeddingsFromMetadata expects all columns
// to be present in the table
if (schema && schema.metadata?.has("embedding_functions")) {
const alignedTable = alignTableToSchema(table, schema);
return fromTableToBuffer(alignedTable, embeddings, schema);
} else {
return fromTableToBuffer(table, embeddings, schema);
}
} else {
const table = await convertToTable(data, embeddings, { schema });
return fromTableToBuffer(table);
@@ -1149,7 +1181,7 @@ function alignBatch(batch: RecordBatch, schema: Schema): RecordBatch {
type: new Struct(schema.fields),
length: batch.numRows,
nullCount: batch.nullCount,
children: alignedChildren,
children: alignedChildren as unknown as ArrowData<DataType>[],
});
return new RecordBatch(schema, newData);
}
@@ -1221,6 +1253,79 @@ function validateSchemaEmbeddings(
return new Schema(fields, schema.metadata);
}
/**
* Ensures that all nested fields defined in the schema exist in the data,
* filling missing fields with null values.
*/
export function ensureNestedFieldsExist(
data: Array<Record<string, unknown>>,
schema: Schema,
): Array<Record<string, unknown>> {
return data.map((row) => {
const completeRow: Record<string, unknown> = {};
for (const field of schema.fields) {
if (field.name in row) {
if (
field.type.constructor.name === "Struct" &&
row[field.name] !== null &&
row[field.name] !== undefined
) {
// Handle nested struct
const nestedValue = row[field.name] as Record<string, unknown>;
completeRow[field.name] = ensureStructFieldsExist(
nestedValue,
field.type,
);
} else {
// Non-struct field or null struct value
completeRow[field.name] = row[field.name];
}
} else {
// Field is missing from the data - set to null
completeRow[field.name] = null;
}
}
return completeRow;
});
}
/**
* Recursively ensures that all fields in a struct type exist in the data,
* filling missing fields with null values.
*/
function ensureStructFieldsExist(
data: Record<string, unknown>,
structType: Struct,
): Record<string, unknown> {
const completeStruct: Record<string, unknown> = {};
for (const childField of structType.children) {
if (childField.name in data) {
if (
childField.type.constructor.name === "Struct" &&
data[childField.name] !== null &&
data[childField.name] !== undefined
) {
// Recursively handle nested struct
completeStruct[childField.name] = ensureStructFieldsExist(
data[childField.name] as Record<string, unknown>,
childField.type,
);
} else {
// Non-struct field or null struct value
completeStruct[childField.name] = data[childField.name];
}
} else {
// Field is missing - set to null
completeStruct[childField.name] = null;
}
}
return completeStruct;
}
interface JsonDataType {
type: string;
fields?: JsonField[];
@@ -1354,3 +1459,64 @@ function fieldToJson(field: Field): JsonField {
metadata: field.metadata,
};
}
function alignTableToSchema(
table: ArrowTable,
targetSchema: Schema,
): ArrowTable {
const existingColumns = new Map<string, Vector>();
// Map existing columns
for (const field of table.schema.fields) {
existingColumns.set(field.name, table.getChild(field.name)!);
}
// Create vectors for all fields in target schema
const alignedColumns: Record<string, Vector> = {};
for (const field of targetSchema.fields) {
if (existingColumns.has(field.name)) {
// Column exists, use it
alignedColumns[field.name] = existingColumns.get(field.name)!;
} else {
// Column missing, create null vector
alignedColumns[field.name] = createNullVector(field, table.numRows);
}
}
// Create new table with aligned schema and columns
return new ArrowTable(targetSchema, alignedColumns);
}
function createNullVector(field: Field, numRows: number): Vector {
if (field.type.constructor.name === "Struct") {
// For struct types, create a struct with null fields
const structType = field.type as Struct;
const childVectors = structType.children.map((childField) =>
createNullVector(childField, numRows),
);
// Create struct data
const structData = makeData({
type: structType,
length: numRows,
nullCount: 0,
children: childVectors.map((v) => v.data[0]),
});
return arrowMakeVector(structData);
} else {
// For other types, create a vector of nulls
const nullBitmap = new Uint8Array(Math.ceil(numRows / 8));
// All bits are 0, meaning all values are null
const data = makeData({
type: field.type,
length: numRows,
nullCount: numRows,
nullBitmap,
});
return arrowMakeVector(data);
}
}