mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-24 15:30:38 +00:00
fix(python): sanitize bad vectors before Arrow cast (#3158)
## Problem
`on_bad_vectors="drop"` is supposed to remove invalid vector rows before
the write, but for some schema-defined vector columns it can still fail
later, during the Arrow cast, instead of dropping the bad row.
Repro:
```python
class MySchema(LanceModel):
    text: str
    embedding: Vector(16)

table = db.create_table("test", schema=MySchema)
table.add(
    [
        {"text": "hello", "embedding": []},
        {"text": "bar", "embedding": [0.1] * 16},
    ],
    on_bad_vectors="drop",
)
```
Before:
```
RuntimeError
Arrow error: C Data interface error: Invalid: ListType can only be casted to FixedSizeListType if the lists are all the expected size.
```
After:
```
rows 1
texts ['bar']
```
## Solution
Make bad-vector sanitization use the schema dimensions before the cast,
while keeping the handling scoped to vector columns identified by schema
metadata or the existing vector-name heuristics.
This also preserves existing integer vector inputs and avoids applying
`on_bad_vectors` to unrelated fixed-size float columns.
Fixes #1670
Signed-off-by: yaommen <myanstu@163.com>
This commit is contained in:
@@ -1049,6 +1049,231 @@ def test_add_with_nans(mem_db: DBConnection):
|
||||
assert np.allclose(v, np.array([0.0, 0.0]))
|
||||
|
||||
|
||||
def test_add_with_empty_fixed_size_list_drops_bad_rows(mem_db: DBConnection):
    """Check that an empty embedding row is dropped for a ``Vector(16)`` schema.

    Two rows are added with ``on_bad_vectors="drop"``: one whose embedding
    is an empty list and one with 16 values. Only the 16-value row is
    expected in the stored table.
    """

    class Schema(LanceModel):
        text: str
        embedding: Vector(16)

    good_embedding = [0.1] * 16
    rows = [
        {"text": "hello", "embedding": []},
        {"text": "bar", "embedding": good_embedding},
    ]

    tbl = mem_db.create_table("test_empty_embeddings", schema=Schema)
    tbl.add(rows, on_bad_vectors="drop")

    stored = tbl.to_arrow()
    assert stored["text"].to_pylist() == ["bar"]

    first_embedding = stored["embedding"].to_pylist()[0]
    assert np.allclose(first_embedding, np.array([0.1] * 16))
|
||||
|
||||
|
||||
def test_add_with_integer_embeddings_preserves_casting(mem_db: DBConnection):
    """Check that integer embedding values are stored as floats.

    A single row with an all-integer embedding is added to a ``Vector(4)``
    schema with ``on_bad_vectors="drop"``; the row must be kept and read
    back as ``[1.0, 2.0, 3.0, 4.0]``.
    """

    class Schema(LanceModel):
        text: str
        embedding: Vector(4)

    rows = [{"text": "foo", "embedding": [1, 2, 3, 4]}]

    tbl = mem_db.create_table("test_integer_embeddings", schema=Schema)
    tbl.add(rows, on_bad_vectors="drop")

    stored_embeddings = tbl.to_arrow()["embedding"].to_pylist()
    assert stored_embeddings == [[1.0, 2.0, 3.0, 4.0]]
|
||||
|
||||
|
||||
def test_on_bad_vectors_does_not_handle_non_vector_fixed_size_lists(
    mem_db: DBConnection,
):
    """Check that a short ``bbox`` value still raises under ``"drop"``.

    The schema has two fixed-size float32 list columns of length 4,
    ``vector`` and ``bbox``. The row supplies a full ``vector`` but only
    two ``bbox`` values, and the add is expected to raise a
    ``RuntimeError`` whose message matches ``FixedSizeListType``.
    """
    fields = [
        pa.field("vector", pa.list_(pa.float32(), 4)),
        pa.field("bbox", pa.list_(pa.float32(), 4)),
    ]
    tbl = mem_db.create_table("test_bbox_schema", schema=pa.schema(fields))

    rows = [{"vector": [1.0, 2.0, 3.0, 4.0], "bbox": [0.0, 1.0]}]
    with pytest.raises(RuntimeError, match="FixedSizeListType"):
        tbl.add(rows, on_bad_vectors="drop")
|
||||
|
||||
|
||||
def test_on_bad_vectors_does_not_handle_custom_named_fixed_size_lists(
    mem_db: DBConnection,
):
    """Check that a ``features`` fixed-size list is not sanitized.

    The only column is a fixed-size float32 list of length 16 named
    ``features``. Adding one empty and one full row with
    ``on_bad_vectors="drop"`` is expected to raise a ``RuntimeError``
    whose message matches ``FixedSizeListType``.
    """
    features_field = pa.field("features", pa.list_(pa.float32(), 16))
    tbl = mem_db.create_table(
        "test_custom_named_fixed_size_vector",
        schema=pa.schema([features_field]),
    )

    rows = [
        {"features": []},
        {"features": [0.1] * 16},
    ]
    with pytest.raises(RuntimeError, match="FixedSizeListType"):
        tbl.add(rows, on_bad_vectors="drop")
|
||||
|
||||
|
||||
def test_on_bad_vectors_with_schema_list_vector_still_sanitizes(mem_db: DBConnection):
    """Check that a ``vector`` list column drops the odd-length row.

    Three rows are added with ``on_bad_vectors="drop"``: two 2-element
    vectors and one 1-element vector. Only the 2-element vectors are
    expected to be stored.
    """
    vector_field = pa.field("vector", pa.list_(pa.float32()))
    tbl = mem_db.create_table(
        "test_schema_list_vector", schema=pa.schema([vector_field])
    )

    rows = [
        {"vector": [1.0, 2.0]},
        {"vector": [3.0]},
        {"vector": [4.0, 5.0]},
    ]
    tbl.add(rows, on_bad_vectors="drop")

    stored_vectors = tbl.to_arrow()["vector"].to_pylist()
    assert stored_vectors == [[1.0, 2.0], [4.0, 5.0]]
|
||||
|
||||
|
||||
def test_on_bad_vectors_handles_typed_custom_fixed_vectors_for_list_schema(
    mem_db: DBConnection,
):
    """Check that an all-NaN row is dropped from typed fixed-size input.

    The table column ``vec`` is a variable-length float32 list, while the
    input Arrow table types it as a fixed-size list of 16. With
    ``on_bad_vectors="drop"`` only the all-ones row is expected to remain.
    """
    tbl = mem_db.create_table(
        "test_typed_custom_fixed_vector",
        schema=pa.schema([pa.field("vec", pa.list_(pa.float32()))]),
    )

    vec_values = [[float("nan")] * 16, [1.0] * 16]
    vec_array = pa.array(vec_values, type=pa.list_(pa.float32(), 16))
    payload = pa.table({"vec": vec_array})

    tbl.add(payload, on_bad_vectors="drop")

    stored_vectors = tbl.to_arrow()["vec"].to_pylist()
    assert stored_vectors == [[1.0] * 16]
|
||||
|
||||
|
||||
def test_on_bad_vectors_fill_preserves_arrow_nested_vector_type(mem_db: DBConnection):
    """Check that ``"fill"`` replaces a NaN-containing vector with zeros.

    The input Arrow table types ``vector`` as a fixed-size float32 list of
    length 2 and the table schema uses a variable-length list. After adding
    with ``on_bad_vectors="fill"`` and ``fill_value=0.0``, the row that
    contained a NaN is expected to read back as ``[0.0, 0.0]``.
    """
    tbl = mem_db.create_table(
        "test_fill_arrow_nested_type",
        schema=pa.schema([pa.field("vector", pa.list_(pa.float32()))]),
    )

    vector_values = [[1.0, 2.0], [float("nan"), 3.0]]
    vector_array = pa.array(vector_values, type=pa.list_(pa.float32(), 2))
    payload = pa.table({"vector": vector_array})

    tbl.add(payload, on_bad_vectors="fill", fill_value=0.0)

    stored_vectors = tbl.to_arrow()["vector"].to_pylist()
    assert stored_vectors == [[1.0, 2.0], [0.0, 0.0]]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("table_name", "batch1", "expected"),
    [
        (
            "test_schema_list_vector_empty_prefix",
            pa.record_batch({"vector": [[], []]}),
            [[], [], [1.0, 2.0], [3.0, 4.0]],
        ),
        (
            "test_schema_list_vector_all_bad_prefix",
            pa.record_batch({"vector": [[float("nan")] * 3, [float("nan")] * 3]}),
            [[1.0, 2.0], [3.0, 4.0]],
        ),
    ],
)
def test_on_bad_vectors_with_schema_list_vector_ignores_invalid_prefix_batches(
    mem_db: DBConnection,
    table_name: str,
    batch1: pa.RecordBatch,
    expected: list,
):
    """Stream a leading batch followed by a valid batch through ``"drop"``.

    ``batch1`` is the parametrized first batch (all empty vectors, or all
    NaN vectors of length 3). It is followed by a batch of two 2-element
    vectors, and both are added through a ``RecordBatchReader`` built with
    ``batch1``'s schema. The stored ``vector`` column must equal
    ``expected``.
    """
    tbl = mem_db.create_table(
        table_name,
        schema=pa.schema([pa.field("vector", pa.list_(pa.float32()))]),
    )

    valid_batch = pa.record_batch({"vector": [[1.0, 2.0], [3.0, 4.0]]})
    stream = pa.RecordBatchReader.from_batches(
        batch1.schema, [batch1, valid_batch]
    )

    tbl.add(stream, on_bad_vectors="drop")

    stored_vectors = tbl.to_arrow()["vector"].to_pylist()
    assert stored_vectors == expected
|
||||
|
||||
|
||||
def test_on_bad_vectors_with_multiple_vectors_locks_dim_after_final_drop(
    mem_db: DBConnection,
):
    """Check ``"drop"`` across two embedding vector columns and two batches.

    ``vec1`` and ``vec2`` are registered as vector columns through
    embedding-function metadata on the schema. In the first streamed row
    ``vec1`` has three elements and ``vec2`` contains a NaN; that row is
    expected to be absent from the stored table, while the remaining three
    rows (all 2-element vectors) are kept in order.
    """
    registry = EmbeddingFunctionRegistry.get_instance()
    func = MockTextEmbeddingFunction.create()
    configs = [
        EmbeddingFunctionConfig(
            source_column=source, vector_column=vector, function=func
        )
        for source, vector in (("text1", "vec1"), ("text2", "vec2"))
    ]
    metadata = registry.get_table_metadata(configs)

    fields = [
        pa.field(column, pa.list_(pa.float32())) for column in ("vec1", "vec2")
    ]
    schema = pa.schema(fields, metadata=metadata)
    tbl = mem_db.create_table("test_multi_vector_dim_lock", schema=schema)

    first_batch = pa.record_batch(
        {
            "vec1": [[1.0, 2.0, 3.0], [10.0, 11.0]],
            "vec2": [[float("nan"), 0.0], [5.0, 6.0]],
        }
    )
    second_batch = pa.record_batch(
        {
            "vec1": [[20.0, 21.0], [30.0, 31.0]],
            "vec2": [[7.0, 8.0], [9.0, 10.0]],
        }
    )
    stream = pa.RecordBatchReader.from_batches(
        first_batch.schema, [first_batch, second_batch]
    )

    tbl.add(stream, on_bad_vectors="drop")

    stored = tbl.to_arrow()
    expected_vec1 = [[10.0, 11.0], [20.0, 21.0], [30.0, 31.0]]
    expected_vec2 = [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
    assert stored["vec1"].to_pylist() == expected_vec1
    assert stored["vec2"].to_pylist() == expected_vec2
|
||||
|
||||
|
||||
def test_on_bad_vectors_does_not_handle_non_vector_list_columns(mem_db: DBConnection):
    """Check that an ``embedding_history`` list column is left untouched.

    Two rows with lists of different lengths are added with
    ``on_bad_vectors="drop"``; both rows are expected to be stored as
    given.
    """
    history_field = pa.field("embedding_history", pa.list_(pa.float32()))
    tbl = mem_db.create_table(
        "test_non_vector_list_schema", schema=pa.schema([history_field])
    )

    rows = [
        {"embedding_history": [1.0, 2.0]},
        {"embedding_history": [3.0]},
    ]
    tbl.add(rows, on_bad_vectors="drop")

    stored_history = tbl.to_arrow()["embedding_history"].to_pylist()
    assert stored_history == [
        [1.0, 2.0],
        [3.0],
    ]
|
||||
|
||||
|
||||
def test_on_bad_vectors_all_null_schema_vector_batches_do_not_crash(
    mem_db: DBConnection,
):
    """Check that a batch whose only vector value is null is accepted.

    The ``vector`` column is a nullable fixed-size float32 list of length
    2. Adding a single ``None`` row with ``on_bad_vectors="drop"`` is
    expected to succeed and store that null row.
    """
    vector_field = pa.field("vector", pa.list_(pa.float32(), 2), nullable=True)
    tbl = mem_db.create_table(
        "test_all_null_vector_batch", schema=pa.schema([vector_field])
    )

    rows = [{"vector": None}]
    tbl.add(rows, on_bad_vectors="drop")

    stored_vectors = tbl.to_arrow()["vector"].to_pylist()
    assert stored_vectors == [None]
|
||||
|
||||
|
||||
def test_restore(mem_db: DBConnection):
|
||||
table = mem_db.create_table(
|
||||
"my_table",
|
||||
|
||||
Reference in New Issue
Block a user