mirror of
https://github.com/lancedb/lancedb.git
synced 2026-07-03 19:10:41 +00:00
fix(python): fill bad vector values element-wise
This commit is contained in:
@@ -6,6 +6,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import inspect
|
||||
import deprecation
|
||||
import math
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
@@ -4055,17 +4056,35 @@ def _handle_bad_vector_column(
|
||||
raise ValueError(
|
||||
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
||||
)
|
||||
vec_arr = pc.if_else(
|
||||
is_bad,
|
||||
pa.scalar([fill_value] * dim, type=vec_arr.type),
|
||||
vec_arr,
|
||||
)
|
||||
vec_arr = _fill_bad_vector_values(vec_arr, is_bad, dim, fill_value)
|
||||
else:
|
||||
raise ValueError(f"Invalid value for on_bad_vectors: {on_bad_vectors}")
|
||||
|
||||
return data.set_column(position, vector_column_name, vec_arr)
|
||||
|
||||
|
||||
def _fill_bad_vector_values(
|
||||
arr: Union[pa.ListArray, pa.FixedSizeListArray],
|
||||
is_bad: pa.BooleanArray,
|
||||
dim: int,
|
||||
fill_value: float,
|
||||
) -> Union[pa.ListArray, pa.FixedSizeListArray]:
|
||||
values = arr.to_pylist()
|
||||
for idx, bad in enumerate(is_bad.to_pylist()):
|
||||
if not bad:
|
||||
continue
|
||||
|
||||
vector = [] if values[idx] is None else values[idx]
|
||||
filled = [
|
||||
fill_value if isinstance(value, float) and math.isnan(value) else value
|
||||
for value in vector[:dim]
|
||||
]
|
||||
filled.extend([fill_value] * (dim - len(filled)))
|
||||
values[idx] = filled
|
||||
|
||||
return pa.array(values, type=arr.type)
|
||||
|
||||
|
||||
def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray:
|
||||
if isinstance(arr, pa.ChunkedArray):
|
||||
values = pa.chunked_array([chunk.flatten() for chunk in arr.chunks])
|
||||
|
||||
@@ -1568,16 +1568,23 @@ def test_create_with_nans(mem_db: DBConnection):
|
||||
"fill_test",
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [2.1, 4.1], "item": "foo", "price": 9.0},
|
||||
{"vector": [np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, 5.0], "item": "bar", "price": 21.0},
|
||||
{"vector": [5], "item": "bar", "price": 22.0},
|
||||
],
|
||||
on_bad_vectors="fill",
|
||||
fill_value=0.0,
|
||||
)
|
||||
assert len(table) == 3
|
||||
assert len(table) == 5
|
||||
arrow_tbl = table.search().where("item == 'bar'").to_arrow()
|
||||
v = arrow_tbl["vector"].to_pylist()[0]
|
||||
assert np.allclose(v, np.array([0.0, 0.0]))
|
||||
filled_vectors = {
|
||||
row["price"]: row["vector"]
|
||||
for row in arrow_tbl.select(["price", "vector"]).to_pylist()
|
||||
}
|
||||
assert np.allclose(filled_vectors[20.0], np.array([0.0, 0.0]))
|
||||
assert np.allclose(filled_vectors[21.0], np.array([0.0, 5.0]))
|
||||
assert np.allclose(filled_vectors[22.0], np.array([5.0, 0.0]))
|
||||
|
||||
|
||||
def test_add_with_nans(mem_db: DBConnection):
|
||||
@@ -1620,15 +1627,21 @@ def test_add_with_nans(mem_db: DBConnection):
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, 5.0], "item": "bar", "price": 21.0},
|
||||
{"vector": [5], "item": "bar", "price": 22.0},
|
||||
],
|
||||
on_bad_vectors="fill",
|
||||
fill_value=0.0,
|
||||
)
|
||||
assert len(table) == 3
|
||||
assert len(table) == 4
|
||||
arrow_tbl = table.search().where("item == 'bar'").to_arrow()
|
||||
v = arrow_tbl["vector"].to_pylist()[0]
|
||||
assert np.allclose(v, np.array([0.0, 0.0]))
|
||||
filled_vectors = {
|
||||
row["price"]: row["vector"]
|
||||
for row in arrow_tbl.select(["price", "vector"]).to_pylist()
|
||||
}
|
||||
assert np.allclose(filled_vectors[20.0], np.array([0.0, 0.0]))
|
||||
assert np.allclose(filled_vectors[21.0], np.array([0.0, 5.0]))
|
||||
assert np.allclose(filled_vectors[22.0], np.array([5.0, 0.0]))
|
||||
|
||||
|
||||
def test_add_with_empty_fixed_size_list_drops_bad_rows(mem_db: DBConnection):
|
||||
|
||||
@@ -283,7 +283,7 @@ def test_handle_bad_vectors_jagged(on_bad_vectors):
|
||||
if on_bad_vectors == "drop":
|
||||
expected = pa.array([[1.0, 2.0], [4.0, 5.0]])
|
||||
elif on_bad_vectors == "fill":
|
||||
expected = pa.array([[1.0, 2.0], [42.0, 42.0], [4.0, 5.0]])
|
||||
expected = pa.array([[1.0, 2.0], [3.0, 42.0], [4.0, 5.0]])
|
||||
elif on_bad_vectors == "null":
|
||||
expected = pa.array([[1.0, 2.0], None, [4.0, 5.0]])
|
||||
|
||||
@@ -319,7 +319,7 @@ def test_handle_bad_vectors_nan(on_bad_vectors):
|
||||
if on_bad_vectors == "drop":
|
||||
expected = pa.array([[3.0, 4.0]])
|
||||
elif on_bad_vectors == "fill":
|
||||
expected = pa.array([[42.0, 42.0], [3.0, 4.0]])
|
||||
expected = pa.array([[1.0, 42.0], [3.0, 4.0]])
|
||||
elif on_bad_vectors == "null":
|
||||
expected = pa.array([None, [3.0, 4.0]])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user