fix(python): fill bad vector values element-wise

This commit is contained in:
prrao87
2026-07-02 15:05:33 -04:00
parent 37466a0390
commit 3446d02f4e
3 changed files with 47 additions and 15 deletions

View File

@@ -6,6 +6,7 @@ from __future__ import annotations
import asyncio
import inspect
import deprecation
import math
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
@@ -4055,17 +4056,35 @@ def _handle_bad_vector_column(
raise ValueError(
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
)
vec_arr = pc.if_else(
is_bad,
pa.scalar([fill_value] * dim, type=vec_arr.type),
vec_arr,
)
vec_arr = _fill_bad_vector_values(vec_arr, is_bad, dim, fill_value)
else:
raise ValueError(f"Invalid value for on_bad_vectors: {on_bad_vectors}")
return data.set_column(position, vector_column_name, vec_arr)
def _fill_bad_vector_values(
    arr: Union[pa.ListArray, pa.FixedSizeListArray],
    is_bad: pa.BooleanArray,
    dim: int,
    fill_value: float,
) -> Union[pa.ListArray, pa.FixedSizeListArray]:
    """Repair the rows of ``arr`` flagged in ``is_bad``, element by element.

    For every flagged row, valid elements are kept and only the unusable
    ones are replaced, so ``[nan, 5.0]`` becomes ``[fill_value, 5.0]``
    rather than a vector made entirely of ``fill_value``:

    * NaN and null elements are replaced by ``fill_value``;
    * rows longer than ``dim`` are truncated to their first ``dim`` elements;
    * rows shorter than ``dim`` (including null rows, treated as empty) are
      right-padded with ``fill_value``.

    Rows whose flag is false or null are passed through untouched.

    Parameters
    ----------
    arr
        The list-typed vector column to repair.
    is_bad
        One flag per row of ``arr``; truthy marks a row to repair.
    dim
        Length every repaired vector must have.
    fill_value
        Replacement for missing, null, or NaN elements.

    Returns
    -------
    A new array built with ``arr.type`` holding the repaired rows.
    """
    rows = arr.to_pylist()
    for idx, bad in enumerate(is_bad.to_pylist()):
        if not bad:
            continue
        vector = [] if rows[idx] is None else rows[idx]
        filled = [
            # A null element is as unusable as NaN inside a vector, so both
            # are replaced; every other element is preserved as-is.
            fill_value
            if value is None or (isinstance(value, float) and math.isnan(value))
            else value
            for value in vector[:dim]
        ]
        # Pad short (or empty/null) rows up to the expected dimension.
        filled.extend([fill_value] * (dim - len(filled)))
        rows[idx] = filled
    return pa.array(rows, type=arr.type)
def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray:
if isinstance(arr, pa.ChunkedArray):
values = pa.chunked_array([chunk.flatten() for chunk in arr.chunks])

View File

@@ -1568,16 +1568,23 @@ def test_create_with_nans(mem_db: DBConnection):
"fill_test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [2.1, 4.1], "item": "foo", "price": 9.0},
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, 5.0], "item": "bar", "price": 21.0},
{"vector": [5], "item": "bar", "price": 22.0},
],
on_bad_vectors="fill",
fill_value=0.0,
)
assert len(table) == 3
assert len(table) == 5
arrow_tbl = table.search().where("item == 'bar'").to_arrow()
v = arrow_tbl["vector"].to_pylist()[0]
assert np.allclose(v, np.array([0.0, 0.0]))
filled_vectors = {
row["price"]: row["vector"]
for row in arrow_tbl.select(["price", "vector"]).to_pylist()
}
assert np.allclose(filled_vectors[20.0], np.array([0.0, 0.0]))
assert np.allclose(filled_vectors[21.0], np.array([0.0, 5.0]))
assert np.allclose(filled_vectors[22.0], np.array([5.0, 0.0]))
def test_add_with_nans(mem_db: DBConnection):
@@ -1620,15 +1627,21 @@ def test_add_with_nans(mem_db: DBConnection):
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, 5.0], "item": "bar", "price": 21.0},
{"vector": [5], "item": "bar", "price": 22.0},
],
on_bad_vectors="fill",
fill_value=0.0,
)
assert len(table) == 3
assert len(table) == 4
arrow_tbl = table.search().where("item == 'bar'").to_arrow()
v = arrow_tbl["vector"].to_pylist()[0]
assert np.allclose(v, np.array([0.0, 0.0]))
filled_vectors = {
row["price"]: row["vector"]
for row in arrow_tbl.select(["price", "vector"]).to_pylist()
}
assert np.allclose(filled_vectors[20.0], np.array([0.0, 0.0]))
assert np.allclose(filled_vectors[21.0], np.array([0.0, 5.0]))
assert np.allclose(filled_vectors[22.0], np.array([5.0, 0.0]))
def test_add_with_empty_fixed_size_list_drops_bad_rows(mem_db: DBConnection):

View File

@@ -283,7 +283,7 @@ def test_handle_bad_vectors_jagged(on_bad_vectors):
if on_bad_vectors == "drop":
expected = pa.array([[1.0, 2.0], [4.0, 5.0]])
elif on_bad_vectors == "fill":
expected = pa.array([[1.0, 2.0], [42.0, 42.0], [4.0, 5.0]])
expected = pa.array([[1.0, 2.0], [3.0, 42.0], [4.0, 5.0]])
elif on_bad_vectors == "null":
expected = pa.array([[1.0, 2.0], None, [4.0, 5.0]])
@@ -319,7 +319,7 @@ def test_handle_bad_vectors_nan(on_bad_vectors):
if on_bad_vectors == "drop":
expected = pa.array([[3.0, 4.0]])
elif on_bad_vectors == "fill":
expected = pa.array([[42.0, 42.0], [3.0, 4.0]])
expected = pa.array([[1.0, 42.0], [3.0, 4.0]])
elif on_bad_vectors == "null":
expected = pa.array([None, [3.0, 4.0]])