mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-04 12:50:40 +00:00
fix(python): sanitize bad vectors before Arrow cast (#3158)
## Problem
`on_bad_vectors="drop"` is supposed to remove invalid vector rows before
write, but for some schema-defined vector columns it can still fail
later during Arrow cast instead of dropping the bad row.
Repro:
```python
class MySchema(LanceModel):
text: str
embedding: Vector(16)
table = db.create_table("test", schema=MySchema)
table.add(
[
{"text": "hello", "embedding": []},
{"text": "bar", "embedding": [0.1] * 16},
],
on_bad_vectors="drop",
)
```
Before:
```
RuntimeError
Arrow error: C Data interface error: Invalid: ListType can only be casted to FixedSizeListType if the lists are all the expected size.
```
After:
```
rows 1
texts ['bar']
```
## Solution
Make bad-vector sanitization use schema dimensions before cast, while
keeping the handling scoped to vector columns identified by schema
metadata or existing vector-name heuristics.
This also preserves existing integer vector inputs and avoids applying
on_bad_vectors to unrelated fixed-size float columns.
Fixes #1670
Signed-off-by: yaommen <myanstu@163.com>
This commit is contained in:
@@ -1049,6 +1049,231 @@ def test_add_with_nans(mem_db: DBConnection):
|
||||
assert np.allclose(v, np.array([0.0, 0.0]))
|
||||
|
||||
|
||||
def test_add_with_empty_fixed_size_list_drops_bad_rows(mem_db: DBConnection):
|
||||
class Schema(LanceModel):
|
||||
text: str
|
||||
embedding: Vector(16)
|
||||
|
||||
table = mem_db.create_table("test_empty_embeddings", schema=Schema)
|
||||
table.add(
|
||||
[
|
||||
{"text": "hello", "embedding": []},
|
||||
{"text": "bar", "embedding": [0.1] * 16},
|
||||
],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
|
||||
data = table.to_arrow()
|
||||
assert data["text"].to_pylist() == ["bar"]
|
||||
assert np.allclose(data["embedding"].to_pylist()[0], np.array([0.1] * 16))
|
||||
|
||||
|
||||
def test_add_with_integer_embeddings_preserves_casting(mem_db: DBConnection):
|
||||
class Schema(LanceModel):
|
||||
text: str
|
||||
embedding: Vector(4)
|
||||
|
||||
table = mem_db.create_table("test_integer_embeddings", schema=Schema)
|
||||
table.add(
|
||||
[{"text": "foo", "embedding": [1, 2, 3, 4]}],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
|
||||
assert table.to_arrow()["embedding"].to_pylist() == [[1.0, 2.0, 3.0, 4.0]]
|
||||
|
||||
|
||||
def test_on_bad_vectors_does_not_handle_non_vector_fixed_size_lists(
|
||||
mem_db: DBConnection,
|
||||
):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 4)),
|
||||
pa.field("bbox", pa.list_(pa.float32(), 4)),
|
||||
]
|
||||
)
|
||||
table = mem_db.create_table("test_bbox_schema", schema=schema)
|
||||
|
||||
with pytest.raises(RuntimeError, match="FixedSizeListType"):
|
||||
table.add(
|
||||
[{"vector": [1.0, 2.0, 3.0, 4.0], "bbox": [0.0, 1.0]}],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
|
||||
|
||||
def test_on_bad_vectors_does_not_handle_custom_named_fixed_size_lists(
|
||||
mem_db: DBConnection,
|
||||
):
|
||||
schema = pa.schema([pa.field("features", pa.list_(pa.float32(), 16))])
|
||||
table = mem_db.create_table("test_custom_named_fixed_size_vector", schema=schema)
|
||||
|
||||
with pytest.raises(RuntimeError, match="FixedSizeListType"):
|
||||
table.add(
|
||||
[
|
||||
{"features": []},
|
||||
{"features": [0.1] * 16},
|
||||
],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
|
||||
|
||||
def test_on_bad_vectors_with_schema_list_vector_still_sanitizes(mem_db: DBConnection):
|
||||
schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))])
|
||||
table = mem_db.create_table("test_schema_list_vector", schema=schema)
|
||||
table.add(
|
||||
[
|
||||
{"vector": [1.0, 2.0]},
|
||||
{"vector": [3.0]},
|
||||
{"vector": [4.0, 5.0]},
|
||||
],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
|
||||
assert table.to_arrow()["vector"].to_pylist() == [[1.0, 2.0], [4.0, 5.0]]
|
||||
|
||||
|
||||
def test_on_bad_vectors_handles_typed_custom_fixed_vectors_for_list_schema(
|
||||
mem_db: DBConnection,
|
||||
):
|
||||
schema = pa.schema([pa.field("vec", pa.list_(pa.float32()))])
|
||||
table = mem_db.create_table("test_typed_custom_fixed_vector", schema=schema)
|
||||
data = pa.table(
|
||||
{
|
||||
"vec": pa.array(
|
||||
[[float("nan")] * 16, [1.0] * 16],
|
||||
type=pa.list_(pa.float32(), 16),
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
table.add(data, on_bad_vectors="drop")
|
||||
|
||||
assert table.to_arrow()["vec"].to_pylist() == [[1.0] * 16]
|
||||
|
||||
|
||||
def test_on_bad_vectors_fill_preserves_arrow_nested_vector_type(mem_db: DBConnection):
|
||||
schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))])
|
||||
table = mem_db.create_table("test_fill_arrow_nested_type", schema=schema)
|
||||
data = pa.table(
|
||||
{
|
||||
"vector": pa.array(
|
||||
[[1.0, 2.0], [float("nan"), 3.0]],
|
||||
type=pa.list_(pa.float32(), 2),
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
table.add(
|
||||
data,
|
||||
on_bad_vectors="fill",
|
||||
fill_value=0.0,
|
||||
)
|
||||
|
||||
assert table.to_arrow()["vector"].to_pylist() == [[1.0, 2.0], [0.0, 0.0]]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("table_name", "batch1", "expected"),
|
||||
[
|
||||
(
|
||||
"test_schema_list_vector_empty_prefix",
|
||||
pa.record_batch({"vector": [[], []]}),
|
||||
[[], [], [1.0, 2.0], [3.0, 4.0]],
|
||||
),
|
||||
(
|
||||
"test_schema_list_vector_all_bad_prefix",
|
||||
pa.record_batch({"vector": [[float("nan")] * 3, [float("nan")] * 3]}),
|
||||
[[1.0, 2.0], [3.0, 4.0]],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_on_bad_vectors_with_schema_list_vector_ignores_invalid_prefix_batches(
|
||||
mem_db: DBConnection,
|
||||
table_name: str,
|
||||
batch1: pa.RecordBatch,
|
||||
expected: list,
|
||||
):
|
||||
schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))])
|
||||
table = mem_db.create_table(table_name, schema=schema)
|
||||
batch2 = pa.record_batch({"vector": [[1.0, 2.0], [3.0, 4.0]]})
|
||||
reader = pa.RecordBatchReader.from_batches(batch1.schema, [batch1, batch2])
|
||||
|
||||
table.add(reader, on_bad_vectors="drop")
|
||||
|
||||
assert table.to_arrow()["vector"].to_pylist() == expected
|
||||
|
||||
|
||||
def test_on_bad_vectors_with_multiple_vectors_locks_dim_after_final_drop(
|
||||
mem_db: DBConnection,
|
||||
):
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
func = MockTextEmbeddingFunction.create()
|
||||
metadata = registry.get_table_metadata(
|
||||
[
|
||||
EmbeddingFunctionConfig(
|
||||
source_column="text1", vector_column="vec1", function=func
|
||||
),
|
||||
EmbeddingFunctionConfig(
|
||||
source_column="text2", vector_column="vec2", function=func
|
||||
),
|
||||
]
|
||||
)
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vec1", pa.list_(pa.float32())),
|
||||
pa.field("vec2", pa.list_(pa.float32())),
|
||||
],
|
||||
metadata=metadata,
|
||||
)
|
||||
table = mem_db.create_table("test_multi_vector_dim_lock", schema=schema)
|
||||
batch1 = pa.record_batch(
|
||||
{
|
||||
"vec1": [[1.0, 2.0, 3.0], [10.0, 11.0]],
|
||||
"vec2": [[float("nan"), 0.0], [5.0, 6.0]],
|
||||
}
|
||||
)
|
||||
batch2 = pa.record_batch(
|
||||
{
|
||||
"vec1": [[20.0, 21.0], [30.0, 31.0]],
|
||||
"vec2": [[7.0, 8.0], [9.0, 10.0]],
|
||||
}
|
||||
)
|
||||
reader = pa.RecordBatchReader.from_batches(batch1.schema, [batch1, batch2])
|
||||
|
||||
table.add(reader, on_bad_vectors="drop")
|
||||
|
||||
data = table.to_arrow()
|
||||
assert data["vec1"].to_pylist() == [[10.0, 11.0], [20.0, 21.0], [30.0, 31.0]]
|
||||
assert data["vec2"].to_pylist() == [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
|
||||
|
||||
|
||||
def test_on_bad_vectors_does_not_handle_non_vector_list_columns(mem_db: DBConnection):
|
||||
schema = pa.schema([pa.field("embedding_history", pa.list_(pa.float32()))])
|
||||
table = mem_db.create_table("test_non_vector_list_schema", schema=schema)
|
||||
table.add(
|
||||
[
|
||||
{"embedding_history": [1.0, 2.0]},
|
||||
{"embedding_history": [3.0]},
|
||||
],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
|
||||
assert table.to_arrow()["embedding_history"].to_pylist() == [
|
||||
[1.0, 2.0],
|
||||
[3.0],
|
||||
]
|
||||
|
||||
|
||||
def test_on_bad_vectors_all_null_schema_vector_batches_do_not_crash(
|
||||
mem_db: DBConnection,
|
||||
):
|
||||
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2), nullable=True)])
|
||||
table = mem_db.create_table("test_all_null_vector_batch", schema=schema)
|
||||
|
||||
table.add([{"vector": None}], on_bad_vectors="drop")
|
||||
|
||||
assert table.to_arrow()["vector"].to_pylist() == [None]
|
||||
|
||||
|
||||
def test_restore(mem_db: DBConnection):
|
||||
table = mem_db.create_table(
|
||||
"my_table",
|
||||
|
||||
@@ -15,8 +15,10 @@ from lancedb.table import (
|
||||
_cast_to_target_schema,
|
||||
_handle_bad_vectors,
|
||||
_into_pyarrow_reader,
|
||||
_sanitize_data,
|
||||
_infer_target_schema,
|
||||
_merge_metadata,
|
||||
_sanitize_data,
|
||||
sanitize_create_table,
|
||||
)
|
||||
import pyarrow as pa
|
||||
import pandas as pd
|
||||
@@ -304,6 +306,117 @@ def test_handle_bad_vectors_noop():
|
||||
assert output["vector"] == vector
|
||||
|
||||
|
||||
def test_handle_bad_vectors_updates_reader_schema_for_target_schema():
|
||||
data = pa.table({"vector": [[1, 2, 3, 4]]})
|
||||
target_schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 4))])
|
||||
|
||||
output = _handle_bad_vectors(
|
||||
data.to_reader(),
|
||||
on_bad_vectors="drop",
|
||||
target_schema=target_schema,
|
||||
)
|
||||
|
||||
assert output.schema == pa.schema([pa.field("vector", pa.list_(pa.float32()))])
|
||||
assert output.read_all()["vector"].to_pylist() == [[1.0, 2.0, 3.0, 4.0]]
|
||||
|
||||
|
||||
def test_sanitize_data_keeps_target_field_metadata():
|
||||
source_field = pa.field(
|
||||
"vector",
|
||||
pa.list_(pa.float32(), 2),
|
||||
metadata={b"source": b"drop-me"},
|
||||
)
|
||||
target_field = pa.field(
|
||||
"vector",
|
||||
pa.list_(pa.float32(), 2),
|
||||
metadata={b"target": b"keep-me"},
|
||||
)
|
||||
data = pa.table(
|
||||
{"vector": pa.array([[1.0, 2.0]], type=pa.list_(pa.float32(), 2))},
|
||||
schema=pa.schema([source_field]),
|
||||
)
|
||||
|
||||
output = _sanitize_data(
|
||||
data,
|
||||
target_schema=pa.schema([target_field]),
|
||||
on_bad_vectors="drop",
|
||||
).read_all()
|
||||
|
||||
assert output.schema.field("vector").metadata == {b"target": b"keep-me"}
|
||||
|
||||
|
||||
def test_sanitize_data_uses_separate_embedding_metadata_for_bad_vectors():
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text",
|
||||
vector_column="custom_vector",
|
||||
function=MockTextEmbeddingFunction.create(),
|
||||
)
|
||||
metadata = registry.get_table_metadata([conf])
|
||||
schema = pa.schema(
|
||||
{
|
||||
"text": pa.string(),
|
||||
"custom_vector": pa.list_(pa.float32(), 10),
|
||||
},
|
||||
metadata={b"note": b"keep-me"},
|
||||
)
|
||||
data = pa.table(
|
||||
{
|
||||
"text": ["bad", "good"],
|
||||
"custom_vector": [[1.0] * 9, [2.0] * 10],
|
||||
}
|
||||
)
|
||||
|
||||
output = _sanitize_data(
|
||||
data,
|
||||
target_schema=schema,
|
||||
metadata=metadata,
|
||||
on_bad_vectors="drop",
|
||||
).read_all()
|
||||
|
||||
assert output["text"].to_pylist() == ["good"]
|
||||
assert output.schema.metadata[b"note"] == b"keep-me"
|
||||
assert b"embedding_functions" in output.schema.metadata
|
||||
|
||||
|
||||
def test_sanitize_create_table_merges_and_overrides_embedding_metadata():
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
old_conf = EmbeddingFunctionConfig(
|
||||
source_column="text",
|
||||
vector_column="old_vector",
|
||||
function=MockTextEmbeddingFunction.create(),
|
||||
)
|
||||
new_conf = EmbeddingFunctionConfig(
|
||||
source_column="text",
|
||||
vector_column="custom_vector",
|
||||
function=MockTextEmbeddingFunction.create(),
|
||||
)
|
||||
metadata = registry.get_table_metadata([new_conf])
|
||||
schema = pa.schema(
|
||||
{
|
||||
"text": pa.string(),
|
||||
"custom_vector": pa.list_(pa.float32(), 10),
|
||||
},
|
||||
metadata=_merge_metadata(
|
||||
{b"note": b"keep-me"},
|
||||
registry.get_table_metadata([old_conf]),
|
||||
),
|
||||
)
|
||||
|
||||
data, schema = sanitize_create_table(
|
||||
pa.table({"text": ["good"]}),
|
||||
schema,
|
||||
metadata=metadata,
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
|
||||
assert schema.metadata[b"note"] == b"keep-me"
|
||||
assert b"embedding_functions" in schema.metadata
|
||||
assert data.schema.metadata[b"note"] == b"keep-me"
|
||||
funcs = EmbeddingFunctionRegistry.get_instance().parse_functions(schema.metadata)
|
||||
assert set(funcs.keys()) == {"custom_vector"}
|
||||
|
||||
|
||||
class TestModel(lancedb.pydantic.LanceModel):
|
||||
a: Optional[int]
|
||||
b: Optional[int]
|
||||
|
||||
Reference in New Issue
Block a user