diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py
index 8b5f44c9..b11b6d9f 100644
--- a/python/python/lancedb/table.py
+++ b/python/python/lancedb/table.py
@@ -245,7 +245,6 @@ def _sanitize_data(
         target_schema = target_schema.with_metadata(new_metadata)
 
     _validate_schema(target_schema)
-
     reader = _cast_to_target_schema(reader, target_schema, allow_subschema)
 
     return reader
@@ -263,12 +262,7 @@ def _cast_to_target_schema(
         # Fast path when the schemas are already the same
         return reader
 
-    fields = []
-    for field in reader.schema:
-        target_field = target_schema.field(field.name)
-        if target_field is None:
-            raise ValueError(f"Field {field.name} not found in target schema")
-        fields.append(target_field)
+    fields = _align_field_types(list(iter(reader.schema)), list(iter(target_schema)))
     reordered_schema = pa.schema(fields, metadata=target_schema.metadata)
     if not allow_subschema and len(reordered_schema) != len(target_schema):
         raise ValueError(
@@ -289,6 +283,64 @@ def _cast_to_target_schema(
     return pa.RecordBatchReader.from_batches(reordered_schema, gen())
 
 
+def _align_field_types(
+    fields: List[pa.Field],
+    target_fields: List[pa.Field],
+) -> List[pa.Field]:
+    """
+    Apply the data types from the target_fields to the fields.
+
+    Each field keeps its own name, nullability and position; only its type
+    is replaced by the type of the same-named target field. Struct and
+    list-like types are aligned recursively, so nested fields keep the
+    order (and possibly the subset) found in ``fields``.
+
+    Raises a ValueError if a field has no same-named field in
+    ``target_fields``.
+    """
+    # Index once instead of scanning per field; setdefault keeps the first
+    # occurrence when names are duplicated.
+    targets_by_name = {}
+    for target_field in target_fields:
+        targets_by_name.setdefault(target_field.name, target_field)
+
+    new_fields = []
+    for field in fields:
+        target_field = targets_by_name.get(field.name)
+        if target_field is None:
+            raise ValueError(f"Field '{field.name}' not found in target schema")
+        target_type = target_field.type
+        # Default to the target type as-is. This also covers inputs whose
+        # type is not nested like the target (e.g. a null-typed column for
+        # a struct target): there is nothing to recurse into, so leave it
+        # to the cast to convert the data or report the mismatch.
+        new_type = target_type
+        if pa.types.is_struct(target_type):
+            if pa.types.is_struct(field.type):
+                new_type = pa.struct(
+                    _align_field_types(field.type.fields, target_type.fields)
+                )
+        elif (
+            pa.types.is_list(target_type)
+            or pa.types.is_large_list(target_type)
+            or pa.types.is_fixed_size_list(target_type)
+        ):
+            # Recurse only when the input type has a value field of its own.
+            value_field = getattr(field.type, "value_field", None)
+            if value_field is not None:
+                new_value_field = _align_field_types(
+                    [value_field], [target_type.value_field]
+                )[0]
+                if pa.types.is_list(target_type):
+                    new_type = pa.list_(new_value_field)
+                elif pa.types.is_large_list(target_type):
+                    new_type = pa.large_list(new_value_field)
+                else:
+                    new_type = pa.list_(new_value_field, target_type.list_size)
+        new_fields.append(pa.field(field.name, new_type, field.nullable))
+    return new_fields
+
+
 def _infer_subschema(
     schema: List[pa.Field],
     reference_fields: List[pa.Field],
diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py
index d1da31ca..ad5c6fe3 100644
--- a/python/python/tests/test_table.py
+++ b/python/python/tests/test_table.py
@@ -231,6 +231,59 @@ def test_add(mem_db: DBConnection):
     _add(table, schema)
 
 
+def test_add_struct(mem_db: DBConnection):
+    # https://github.com/lancedb/lancedb/issues/2114
+    schema = pa.schema(
+        [
+            (
+                "stuff",
+                pa.struct(
+                    [
+                        ("b", pa.int64()),
+                        ("a", pa.int64()),
+                        # TODO: also test subset of nested.
+                    ]
+                ),
+            )
+        ]
+    )
+
+    # Create test data with fields in same order
+    data = [{"stuff": {"b": 1, "a": 2}}]
+    # pa.Table.from_pylist() will reorder the fields. We need to make sure
+    # we fix the field order later, before casting.
+    table = mem_db.create_table("test", schema=schema)
+    table.add(data)
+
+    data = [{"stuff": {"b": 4}}]
+    table.add(data)
+
+    expected = pa.table(
+        {
+            "stuff": [{"b": 1, "a": 2}, {"b": 4, "a": None}],
+        },
+        schema=schema,
+    )
+    assert table.to_arrow() == expected
+
+    # Also check struct in list
+    schema = pa.schema(
+        {
+            "s_list": pa.list_(
+                pa.struct(
+                    [
+                        ("b", pa.int64()),
+                        ("a", pa.int64()),
+                    ]
+                )
+            )
+        }
+    )
+    data = [{"s_list": [{"b": 1, "a": 2}, {"b": 4}]}]
+    table = mem_db.create_table("test", schema=schema)
+    table.add(data)
+
+
 def test_add_subschema(mem_db: DBConnection):
     schema = pa.schema(
         [
@@ -324,7 +377,10 @@ def test_add_nullability(mem_db: DBConnection):
     # We can't add nullable schema if it contains nulls
     with pytest.raises(
         Exception,
-        match="Casting field 'vector' with null values to non-nullable",
+        match=(
+            "The field `vector` contained null values even though "
+            "the field is marked non-null in the schema"
+        ),
     ):
         table.add(data)
 