Mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-31 17:02:58 +00:00)
fix: insert structs in non-alphabetical order (#2222)
Closes #2114.

Starting in #1965, we no longer pass the table schema into `pa.Table.from_pylist()`. This means PyArrow chooses the order of the struct subfields, and it apparently orders them alphabetically. In theory this is fine, since Lance supports providing fields in any order. However, before we pass the data to Lance, we call `pa.Table.cast()` to align column types with the table's types, and `pa.Table.cast()` is strict about field order, so we need to build a cast target schema whose field order matches the input data. We were already doing this for top-level fields, but not for nested fields. This PR adds that handling for nested fields.
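For context, a minimal repro sketch (not part of the PR, assuming a recent pyarrow): without an explicit schema, `pa.Table.from_pylist()` picks the struct subfield order itself, which may not match the order declared in the table schema, so a strict cast target has to be reordered to match the incoming data first.

import pyarrow as pa

# Table schema declares the struct subfields as (b, a).
table_schema = pa.schema(
    [("stuff", pa.struct([("b", pa.int64()), ("a", pa.int64())]))]
)

# No schema is passed, so PyArrow infers the struct subfield order on its own.
data = pa.Table.from_pylist([{"stuff": {"b": 1, "a": 2}}])

# The two struct types can disagree on subfield order, which a strict
# cast target must account for.
print(data.schema.field("stuff").type)   # subfield order chosen by PyArrow
print(table_schema.field("stuff").type)  # struct<b: int64, a: int64>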
@@ -245,7 +245,6 @@ def _sanitize_data(
     target_schema = target_schema.with_metadata(new_metadata)

     _validate_schema(target_schema)

     reader = _cast_to_target_schema(reader, target_schema, allow_subschema)

     return reader
@@ -263,12 +262,7 @@ def _cast_to_target_schema(
         # Fast path when the schemas are already the same
         return reader

-    fields = []
-    for field in reader.schema:
-        target_field = target_schema.field(field.name)
-        if target_field is None:
-            raise ValueError(f"Field {field.name} not found in target schema")
-        fields.append(target_field)
+    fields = _align_field_types(list(iter(reader.schema)), list(iter(target_schema)))
     reordered_schema = pa.schema(fields, metadata=target_schema.metadata)
     if not allow_subschema and len(reordered_schema) != len(target_schema):
         raise ValueError(
@@ -289,6 +283,53 @@ def _cast_to_target_schema(
     return pa.RecordBatchReader.from_batches(reordered_schema, gen())


+def _align_field_types(
+    fields: List[pa.Field],
+    target_fields: List[pa.Field],
+) -> List[pa.Field]:
+    """
+    Apply the data types from the target_fields to the fields.
+    """
+    new_fields = []
+    for field in fields:
+        target_field = next((f for f in target_fields if f.name == field.name), None)
+        if target_field is None:
+            raise ValueError(f"Field '{field.name}' not found in target schema")
+        if pa.types.is_struct(target_field.type):
+            new_type = pa.struct(
+                _align_field_types(
+                    field.type.fields,
+                    target_field.type.fields,
+                )
+            )
+        elif pa.types.is_list(target_field.type):
+            new_type = pa.list_(
+                _align_field_types(
+                    [field.type.value_field],
+                    [target_field.type.value_field],
+                )[0]
+            )
+        elif pa.types.is_large_list(target_field.type):
+            new_type = pa.large_list(
+                _align_field_types(
+                    [field.type.value_field],
+                    [target_field.type.value_field],
+                )[0]
+            )
+        elif pa.types.is_fixed_size_list(target_field.type):
+            new_type = pa.list_(
+                _align_field_types(
+                    [field.type.value_field],
+                    [target_field.type.value_field],
+                )[0],
+                target_field.type.list_size,
+            )
+        else:
+            new_type = target_field.type
+        new_fields.append(pa.field(field.name, new_type, field.nullable))
+    return new_fields
+
+
 def _infer_subschema(
     schema: List[pa.Field],
     reference_fields: List[pa.Field],
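A small usage sketch (not part of the diff; it assumes `pyarrow` is imported as `pa` and that the `_align_field_types` shown above is importable): given data whose struct subfields arrive as (a, b) and a table schema declared as (b, a), the helper keeps the data's field order but takes the table's types, so the resulting cast target lines up with the incoming batches.

import pyarrow as pa

# Field order as it arrives in the data (e.g. after pa.Table.from_pylist()).
data_schema = pa.schema(
    [("stuff", pa.struct([("a", pa.int32()), ("b", pa.int32())]))]
)
# Field order and types as declared in the table schema.
table_schema = pa.schema(
    [("stuff", pa.struct([("b", pa.int64()), ("a", pa.int64())]))]
)

aligned = _align_field_types(list(data_schema), list(table_schema))
# aligned[0].type is struct<a: int64, b: int64>: the data's subfield order
# with the table's types, which is a usable target for a strict cast.
print(aligned[0])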
@@ -231,6 +231,59 @@ def test_add(mem_db: DBConnection):
     _add(table, schema)


+def test_add_struct(mem_db: DBConnection):
+    # https://github.com/lancedb/lancedb/issues/2114
+    schema = pa.schema(
+        [
+            (
+                "stuff",
+                pa.struct(
+                    [
+                        ("b", pa.int64()),
+                        ("a", pa.int64()),
+                        # TODO: also test subset of nested.
+                    ]
+                ),
+            )
+        ]
+    )
+
+    # Create test data with fields in same order
+    data = [{"stuff": {"b": 1, "a": 2}}]
+    # pa.Table.from_pylist() will reorder the fields. We need to make sure
+    # we fix the field order later, before casting.
+    table = mem_db.create_table("test", schema=schema)
+    table.add(data)
+
+    data = [{"stuff": {"b": 4}}]
+    table.add(data)
+
+    expected = pa.table(
+        {
+            "stuff": [{"b": 1, "a": 2}, {"b": 4, "a": None}],
+        },
+        schema=schema,
+    )
+    assert table.to_arrow() == expected
+
+    # Also check struct in list
+    schema = pa.schema(
+        {
+            "s_list": pa.list_(
+                pa.struct(
+                    [
+                        ("b", pa.int64()),
+                        ("a", pa.int64()),
+                    ]
+                )
+            )
+        }
+    )
+    data = [{"s_list": [{"b": 1, "a": 2}, {"b": 4}]}]
+    table = mem_db.create_table("test", schema=schema)
+    table.add(data)
+
+
 def test_add_subschema(mem_db: DBConnection):
     schema = pa.schema(
         [
@@ -324,7 +377,10 @@ def test_add_nullability(mem_db: DBConnection):
     # We can't add nullable schema if it contains nulls
     with pytest.raises(
         Exception,
-        match="Casting field 'vector' with null values to non-nullable",
+        match=(
+            "The field `vector` contained null values even though "
+            "the field is marked non-null in the schema"
+        ),
     ):
         table.add(data)