From 3446d02f4ec092d48b5f919baffbb0c04e0d6945 Mon Sep 17 00:00:00 2001 From: prrao87 <35005448+prrao87@users.noreply.github.com> Date: Thu, 2 Jul 2026 15:05:33 -0400 Subject: [PATCH] fix(python): fill bad vector values element-wise --- python/python/lancedb/table.py | 29 ++++++++++++++++++++++++----- python/python/tests/test_table.py | 29 +++++++++++++++++++++-------- python/python/tests/test_util.py | 4 ++-- 3 files changed, 47 insertions(+), 15 deletions(-) diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 28ee37fbf..cbf06b525 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -6,6 +6,7 @@ from __future__ import annotations import asyncio import inspect import deprecation +import math import warnings from abc import ABC, abstractmethod from dataclasses import dataclass @@ -4055,17 +4056,35 @@ def _handle_bad_vector_column( raise ValueError( "`fill_value` must not be None if `on_bad_vectors` is 'fill'" ) - vec_arr = pc.if_else( - is_bad, - pa.scalar([fill_value] * dim, type=vec_arr.type), - vec_arr, - ) + vec_arr = _fill_bad_vector_values(vec_arr, is_bad, dim, fill_value) else: raise ValueError(f"Invalid value for on_bad_vectors: {on_bad_vectors}") return data.set_column(position, vector_column_name, vec_arr) +def _fill_bad_vector_values( + arr: Union[pa.ListArray, pa.FixedSizeListArray], + is_bad: pa.BooleanArray, + dim: int, + fill_value: float, +) -> Union[pa.ListArray, pa.FixedSizeListArray]: + values = arr.to_pylist() + for idx, bad in enumerate(is_bad.to_pylist()): + if not bad: + continue + + vector = [] if values[idx] is None else values[idx] + filled = [ + fill_value if isinstance(value, float) and math.isnan(value) else value + for value in vector[:dim] + ] + filled.extend([fill_value] * (dim - len(filled))) + values[idx] = filled + + return pa.array(values, type=arr.type) + + def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray: if isinstance(arr, pa.ChunkedArray): values = pa.chunked_array([chunk.flatten() for chunk in arr.chunks]) diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index dae9d5b2e..f39c5cadc 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -1568,16 +1568,23 @@ def test_create_with_nans(mem_db: DBConnection): "fill_test", data=[ {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, + {"vector": [2.1, 4.1], "item": "foo", "price": 9.0}, {"vector": [np.nan], "item": "bar", "price": 20.0}, - {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0}, + {"vector": [np.nan, 5.0], "item": "bar", "price": 21.0}, + {"vector": [5], "item": "bar", "price": 22.0}, ], on_bad_vectors="fill", fill_value=0.0, ) - assert len(table) == 3 + assert len(table) == 5 arrow_tbl = table.search().where("item == 'bar'").to_arrow() - v = arrow_tbl["vector"].to_pylist()[0] - assert np.allclose(v, np.array([0.0, 0.0])) + filled_vectors = { + row["price"]: row["vector"] + for row in arrow_tbl.select(["price", "vector"]).to_pylist() + } + assert np.allclose(filled_vectors[20.0], np.array([0.0, 0.0])) + assert np.allclose(filled_vectors[21.0], np.array([0.0, 5.0])) + assert np.allclose(filled_vectors[22.0], np.array([5.0, 0.0])) def test_add_with_nans(mem_db: DBConnection): @@ -1620,15 +1627,21 @@ def test_add_with_nans(mem_db: DBConnection): data=[ {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, {"vector": [np.nan], "item": "bar", "price": 20.0}, - {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0}, + {"vector": [np.nan, 5.0], "item": "bar", "price": 21.0}, + {"vector": [5], "item": "bar", "price": 22.0}, ], on_bad_vectors="fill", fill_value=0.0, ) - assert len(table) == 3 + assert len(table) == 4 arrow_tbl = table.search().where("item == 'bar'").to_arrow() - v = arrow_tbl["vector"].to_pylist()[0] - assert np.allclose(v, np.array([0.0, 0.0])) + filled_vectors = { + row["price"]: row["vector"] + for row in arrow_tbl.select(["price", "vector"]).to_pylist() + } + assert np.allclose(filled_vectors[20.0], np.array([0.0, 0.0])) + assert np.allclose(filled_vectors[21.0], np.array([0.0, 5.0])) + assert np.allclose(filled_vectors[22.0], np.array([5.0, 0.0])) def test_add_with_empty_fixed_size_list_drops_bad_rows(mem_db: DBConnection): diff --git a/python/python/tests/test_util.py b/python/python/tests/test_util.py index c96407779..ed1b794db 100644 --- a/python/python/tests/test_util.py +++ b/python/python/tests/test_util.py @@ -283,7 +283,7 @@ def test_handle_bad_vectors_jagged(on_bad_vectors): if on_bad_vectors == "drop": expected = pa.array([[1.0, 2.0], [4.0, 5.0]]) elif on_bad_vectors == "fill": - expected = pa.array([[1.0, 2.0], [42.0, 42.0], [4.0, 5.0]]) + expected = pa.array([[1.0, 2.0], [3.0, 42.0], [4.0, 5.0]]) elif on_bad_vectors == "null": expected = pa.array([[1.0, 2.0], None, [4.0, 5.0]]) @@ -319,7 +319,7 @@ def test_handle_bad_vectors_nan(on_bad_vectors): if on_bad_vectors == "drop": expected = pa.array([[3.0, 4.0]]) elif on_bad_vectors == "fill": - expected = pa.array([[42.0, 42.0], [3.0, 4.0]]) + expected = pa.array([[1.0, 42.0], [3.0, 4.0]]) elif on_bad_vectors == "null": expected = pa.array([None, [3.0, 4.0]])