fix: insert structs in non-alphabetical order (#2222)

Closes #2114

Starting in #1965, we no longer pass the table schema into
`pa.Table.from_pylist()`. This means PyArrow is choosing the order of
the struct subfields, and apparently it does them in alphabetical order.
This is fine in theory, since in Lance we support providing fields in
any order. However, before we pass it to Lance, we call
`pa.Table.cast()` to align column types to the table types.
`pa.Table.cast()` is strict about field order, so we need to create a
cast target schema that aligns with the input data. We were doing this
for the top-level fields, but weren't doing it for nested fields. This
PR adds support to do this for nested ones.
This commit is contained in:
Will Jones
2025-03-13 14:46:05 -07:00
committed by GitHub
parent 6c321c694a
commit a207213358
2 changed files with 105 additions and 8 deletions

View File

@@ -245,7 +245,6 @@ def _sanitize_data(
target_schema = target_schema.with_metadata(new_metadata)
_validate_schema(target_schema)
reader = _cast_to_target_schema(reader, target_schema, allow_subschema)
return reader
@@ -263,12 +262,7 @@ def _cast_to_target_schema(
# Fast path when the schemas are already the same
return reader
fields = []
for field in reader.schema:
target_field = target_schema.field(field.name)
if target_field is None:
raise ValueError(f"Field {field.name} not found in target schema")
fields.append(target_field)
fields = _align_field_types(list(iter(reader.schema)), list(iter(target_schema)))
reordered_schema = pa.schema(fields, metadata=target_schema.metadata)
if not allow_subschema and len(reordered_schema) != len(target_schema):
raise ValueError(
@@ -289,6 +283,53 @@ def _cast_to_target_schema(
return pa.RecordBatchReader.from_batches(reordered_schema, gen())
def _align_field_types(
    fields: List[pa.Field],
    target_fields: List[pa.Field],
) -> List[pa.Field]:
    """
    Return ``fields`` with each field's data type replaced by the matching
    type from ``target_fields``, recursing into struct and list types.

    The field names and nullability are kept from ``fields``; only the data
    types come from the target. Raises ``ValueError`` if a field has no
    same-named counterpart in the target.
    """
    # Build a name -> field lookup. setdefault keeps the *first* occurrence,
    # matching the semantics of a linear scan over target_fields.
    by_name = {}
    for candidate in target_fields:
        by_name.setdefault(candidate.name, candidate)

    aligned = []
    for source in fields:
        target = by_name.get(source.name)
        if target is None:
            raise ValueError(f"Field '{source.name}' not found in target schema")
        target_type = target.type
        if pa.types.is_struct(target_type):
            # Recurse into struct children so nested field order is aligned too.
            aligned_type = pa.struct(
                _align_field_types(source.type.fields, target_type.fields)
            )
        elif (
            pa.types.is_list(target_type)
            or pa.types.is_large_list(target_type)
            or pa.types.is_fixed_size_list(target_type)
        ):
            # Align the single child field, then rebuild the same list flavor.
            child = _align_field_types(
                [source.type.value_field], [target_type.value_field]
            )[0]
            if pa.types.is_list(target_type):
                aligned_type = pa.list_(child)
            elif pa.types.is_large_list(target_type):
                aligned_type = pa.large_list(child)
            else:
                aligned_type = pa.list_(child, target_type.list_size)
        else:
            # Scalar (or otherwise non-nested) type: take it verbatim.
            aligned_type = target_type
        aligned.append(pa.field(source.name, aligned_type, source.nullable))
    return aligned
def _infer_subschema(
schema: List[pa.Field],
reference_fields: List[pa.Field],

View File

@@ -231,6 +231,59 @@ def test_add(mem_db: DBConnection):
_add(table, schema)
def test_add_struct(mem_db: DBConnection):
    # Regression test for https://github.com/lancedb/lancedb/issues/2114:
    # pa.Table.from_pylist() may reorder struct subfields alphabetically, so
    # the cast target schema must be realigned with the input data.
    schema = pa.schema(
        [
            (
                "stuff",
                pa.struct(
                    [
                        ("b", pa.int64()),
                        ("a", pa.int64()),
                        # TODO: also test subset of nested.
                    ]
                ),
            )
        ]
    )
    tbl = mem_db.create_table("test", schema=schema)
    # Subfields given in schema order ("b" before "a") — the add path must
    # restore this order before casting.
    tbl.add([{"stuff": {"b": 1, "a": 2}}])
    # A row missing subfield "a" exercises the nested-subschema path.
    tbl.add([{"stuff": {"b": 4}}])
    expected = pa.table(
        {"stuff": [{"b": 1, "a": 2}, {"b": 4, "a": None}]},
        schema=schema,
    )
    assert tbl.to_arrow() == expected

    # A struct nested inside a list must get the same alignment treatment.
    nested_schema = pa.schema(
        {
            "s_list": pa.list_(
                pa.struct(
                    [
                        ("b", pa.int64()),
                        ("a", pa.int64()),
                    ]
                )
            )
        }
    )
    tbl = mem_db.create_table("test", schema=nested_schema)
    tbl.add([{"s_list": [{"b": 1, "a": 2}, {"b": 4}]}])
def test_add_subschema(mem_db: DBConnection):
schema = pa.schema(
[
@@ -324,7 +377,10 @@ def test_add_nullability(mem_db: DBConnection):
# We can't add nullable schema if it contains nulls
with pytest.raises(
Exception,
match="Casting field 'vector' with null values to non-nullable",
match=(
"The field `vector` contained null values even though "
"the field is marked non-null in the schema"
),
):
table.add(data)