Compare commits

..

1 Commits

Author SHA1 Message Date
prrao87
3446d02f4e fix(python): fill bad vector values element-wise 2026-07-02 15:05:33 -04:00
6 changed files with 102 additions and 70 deletions

80
Cargo.lock generated
View File

@@ -3423,8 +3423,8 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "fsst"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrow-array",
"rand 0.9.4",
@@ -4726,8 +4726,8 @@ checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a"
[[package]]
name = "lance"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arc-swap",
"arrow",
@@ -4801,8 +4801,8 @@ dependencies = [
[[package]]
name = "lance-arrow"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4823,7 +4823,7 @@ dependencies = [
[[package]]
name = "lance-arrow-scalar"
version = "58.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4837,7 +4837,7 @@ dependencies = [
[[package]]
name = "lance-arrow-stats"
version = "58.0.0"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrow-array",
"arrow-schema",
@@ -4846,8 +4846,8 @@ dependencies = [
[[package]]
name = "lance-bitpacking"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrayref",
"crunchy",
@@ -4857,8 +4857,8 @@ dependencies = [
[[package]]
name = "lance-core"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -4896,8 +4896,8 @@ dependencies = [
[[package]]
name = "lance-datafusion"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrow",
"arrow-array",
@@ -4927,8 +4927,8 @@ dependencies = [
[[package]]
name = "lance-datagen"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrow",
"arrow-array",
@@ -4945,8 +4945,8 @@ dependencies = [
[[package]]
name = "lance-derive"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"proc-macro2",
"quote",
@@ -4955,8 +4955,8 @@ dependencies = [
[[package]]
name = "lance-encoding"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -4991,8 +4991,8 @@ dependencies = [
[[package]]
name = "lance-file"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -5022,8 +5022,8 @@ dependencies = [
[[package]]
name = "lance-index"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arc-swap",
"arrow",
@@ -5088,8 +5088,8 @@ dependencies = [
[[package]]
name = "lance-io"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrow",
"arrow-arith",
@@ -5130,8 +5130,8 @@ dependencies = [
[[package]]
name = "lance-linalg"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -5147,8 +5147,8 @@ dependencies = [
[[package]]
name = "lance-namespace"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrow",
"async-trait",
@@ -5160,8 +5160,8 @@ dependencies = [
[[package]]
name = "lance-namespace-impls"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrow",
"arrow-ipc",
@@ -5215,8 +5215,8 @@ dependencies = [
[[package]]
name = "lance-select"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -5231,8 +5231,8 @@ dependencies = [
[[package]]
name = "lance-table"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrow",
"arrow-array",
@@ -5271,8 +5271,8 @@ dependencies = [
[[package]]
name = "lance-testing"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"arrow-array",
"arrow-schema",
@@ -5285,8 +5285,8 @@ dependencies = [
[[package]]
name = "lance-tokenizer"
version = "9.0.0-beta.11"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.11#9644d163c6e1ed309eea9ec2960c000a17663ed6"
version = "9.0.0-beta.10"
source = "git+https://github.com/lance-format/lance.git?tag=v9.0.0-beta.10#e25b71e74b89d10c57b412d111bde087117383f3"
dependencies = [
"icu_segmenter",
"jieba-rs",

View File

@@ -13,20 +13,20 @@ categories = ["database-implementations"]
rust-version = "1.91.0"
[workspace.dependencies]
lance = { "version" = "=9.0.0-beta.11", default-features = false, "tag" = "v9.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=9.0.0-beta.11", "tag" = "v9.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=9.0.0-beta.11", "tag" = "v9.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=9.0.0-beta.11", "tag" = "v9.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=9.0.0-beta.11", default-features = false, "tag" = "v9.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=9.0.0-beta.11", "tag" = "v9.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=9.0.0-beta.11", "tag" = "v9.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=9.0.0-beta.11", "tag" = "v9.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=9.0.0-beta.11", default-features = false, "tag" = "v9.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=9.0.0-beta.11", "tag" = "v9.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=9.0.0-beta.11", "tag" = "v9.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=9.0.0-beta.11", "tag" = "v9.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=9.0.0-beta.11", "tag" = "v9.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=9.0.0-beta.11", "tag" = "v9.0.0-beta.11", "git" = "https://github.com/lance-format/lance.git" }
lance = { "version" = "=9.0.0-beta.10", default-features = false, "tag" = "v9.0.0-beta.10", "git" = "https://github.com/lance-format/lance.git" }
lance-core = { "version" = "=9.0.0-beta.10", "tag" = "v9.0.0-beta.10", "git" = "https://github.com/lance-format/lance.git" }
lance-datagen = { "version" = "=9.0.0-beta.10", "tag" = "v9.0.0-beta.10", "git" = "https://github.com/lance-format/lance.git" }
lance-file = { "version" = "=9.0.0-beta.10", "tag" = "v9.0.0-beta.10", "git" = "https://github.com/lance-format/lance.git" }
lance-io = { "version" = "=9.0.0-beta.10", default-features = false, "tag" = "v9.0.0-beta.10", "git" = "https://github.com/lance-format/lance.git" }
lance-index = { "version" = "=9.0.0-beta.10", "tag" = "v9.0.0-beta.10", "git" = "https://github.com/lance-format/lance.git" }
lance-linalg = { "version" = "=9.0.0-beta.10", "tag" = "v9.0.0-beta.10", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace = { "version" = "=9.0.0-beta.10", "tag" = "v9.0.0-beta.10", "git" = "https://github.com/lance-format/lance.git" }
lance-namespace-impls = { "version" = "=9.0.0-beta.10", default-features = false, "tag" = "v9.0.0-beta.10", "git" = "https://github.com/lance-format/lance.git" }
lance-table = { "version" = "=9.0.0-beta.10", "tag" = "v9.0.0-beta.10", "git" = "https://github.com/lance-format/lance.git" }
lance-testing = { "version" = "=9.0.0-beta.10", "tag" = "v9.0.0-beta.10", "git" = "https://github.com/lance-format/lance.git" }
lance-datafusion = { "version" = "=9.0.0-beta.10", "tag" = "v9.0.0-beta.10", "git" = "https://github.com/lance-format/lance.git" }
lance-encoding = { "version" = "=9.0.0-beta.10", "tag" = "v9.0.0-beta.10", "git" = "https://github.com/lance-format/lance.git" }
lance-arrow = { "version" = "=9.0.0-beta.10", "tag" = "v9.0.0-beta.10", "git" = "https://github.com/lance-format/lance.git" }
ahash = "0.8"
# Note that this one does not include pyarrow
arrow = { version = "58.0.0", optional = false }

View File

@@ -28,7 +28,7 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<arrow.version>15.0.0</arrow.version>
<lance-core.version>9.0.0-beta.11</lance-core.version>
<lance-core.version>9.0.0-beta.10</lance-core.version>
<spotless.skip>false</spotless.skip>
<spotless.version>2.30.0</spotless.version>
<spotless.java.googlejavaformat.version>1.7</spotless.java.googlejavaformat.version>

View File

@@ -6,6 +6,7 @@ from __future__ import annotations
import asyncio
import inspect
import deprecation
import math
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
@@ -4055,17 +4056,35 @@ def _handle_bad_vector_column(
raise ValueError(
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
)
vec_arr = pc.if_else(
is_bad,
pa.scalar([fill_value] * dim, type=vec_arr.type),
vec_arr,
)
vec_arr = _fill_bad_vector_values(vec_arr, is_bad, dim, fill_value)
else:
raise ValueError(f"Invalid value for on_bad_vectors: {on_bad_vectors}")
return data.set_column(position, vector_column_name, vec_arr)
def _fill_bad_vector_values(
    arr: Union[pa.ListArray, pa.FixedSizeListArray],
    is_bad: pa.BooleanArray,
    dim: int,
    fill_value: float,
) -> Union[pa.ListArray, pa.FixedSizeListArray]:
    """Repair the rows flagged in ``is_bad`` element by element.

    For every flagged row the vector is truncated to ``dim`` elements, each
    missing (null) or NaN element is replaced by ``fill_value``, and the
    vector is then right-padded with ``fill_value`` up to ``dim`` elements.
    A flagged row that is entirely null becomes ``[fill_value] * dim``.
    Rows whose flag is false or null are passed through unchanged.

    Parameters
    ----------
    arr:
        The vector column to repair.
    is_bad:
        Per-row flags, indexed in lockstep with ``arr``.
    dim:
        Number of elements every repaired vector must have.
    fill_value:
        Replacement for null/NaN elements and for padding.

    Returns
    -------
    A new array built with ``arr.type`` containing the repaired rows.
    """

    def _needs_fill(value) -> bool:
        # A null element surfaces as None after ``to_pylist()``; without this
        # check it would survive the fill and leave the "repaired" vector
        # with a hole in it.
        if value is None:
            return True
        return isinstance(value, float) and math.isnan(value)

    rows = arr.to_pylist()
    for idx, bad in enumerate(is_bad.to_pylist()):
        if not bad:
            continue
        vector = [] if rows[idx] is None else rows[idx]
        # Keep the usable elements (at most ``dim`` of them) and patch the
        # unusable ones in place so good values are not thrown away.
        filled = [
            fill_value if _needs_fill(value) else value for value in vector[:dim]
        ]
        # Short (or fully missing) vectors are padded up to ``dim``.
        filled.extend([fill_value] * (dim - len(filled)))
        rows[idx] = filled
    return pa.array(rows, type=arr.type)
def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray:
if isinstance(arr, pa.ChunkedArray):
values = pa.chunked_array([chunk.flatten() for chunk in arr.chunks])

View File

@@ -1568,16 +1568,23 @@ def test_create_with_nans(mem_db: DBConnection):
"fill_test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [2.1, 4.1], "item": "foo", "price": 9.0},
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, 5.0], "item": "bar", "price": 21.0},
{"vector": [5], "item": "bar", "price": 22.0},
],
on_bad_vectors="fill",
fill_value=0.0,
)
assert len(table) == 3
assert len(table) == 5
arrow_tbl = table.search().where("item == 'bar'").to_arrow()
v = arrow_tbl["vector"].to_pylist()[0]
assert np.allclose(v, np.array([0.0, 0.0]))
filled_vectors = {
row["price"]: row["vector"]
for row in arrow_tbl.select(["price", "vector"]).to_pylist()
}
assert np.allclose(filled_vectors[20.0], np.array([0.0, 0.0]))
assert np.allclose(filled_vectors[21.0], np.array([0.0, 5.0]))
assert np.allclose(filled_vectors[22.0], np.array([5.0, 0.0]))
def test_add_with_nans(mem_db: DBConnection):
@@ -1620,15 +1627,21 @@ def test_add_with_nans(mem_db: DBConnection):
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, 5.0], "item": "bar", "price": 21.0},
{"vector": [5], "item": "bar", "price": 22.0},
],
on_bad_vectors="fill",
fill_value=0.0,
)
assert len(table) == 3
assert len(table) == 4
arrow_tbl = table.search().where("item == 'bar'").to_arrow()
v = arrow_tbl["vector"].to_pylist()[0]
assert np.allclose(v, np.array([0.0, 0.0]))
filled_vectors = {
row["price"]: row["vector"]
for row in arrow_tbl.select(["price", "vector"]).to_pylist()
}
assert np.allclose(filled_vectors[20.0], np.array([0.0, 0.0]))
assert np.allclose(filled_vectors[21.0], np.array([0.0, 5.0]))
assert np.allclose(filled_vectors[22.0], np.array([5.0, 0.0]))
def test_add_with_empty_fixed_size_list_drops_bad_rows(mem_db: DBConnection):

View File

@@ -283,7 +283,7 @@ def test_handle_bad_vectors_jagged(on_bad_vectors):
if on_bad_vectors == "drop":
expected = pa.array([[1.0, 2.0], [4.0, 5.0]])
elif on_bad_vectors == "fill":
expected = pa.array([[1.0, 2.0], [42.0, 42.0], [4.0, 5.0]])
expected = pa.array([[1.0, 2.0], [3.0, 42.0], [4.0, 5.0]])
elif on_bad_vectors == "null":
expected = pa.array([[1.0, 2.0], None, [4.0, 5.0]])
@@ -319,7 +319,7 @@ def test_handle_bad_vectors_nan(on_bad_vectors):
if on_bad_vectors == "drop":
expected = pa.array([[3.0, 4.0]])
elif on_bad_vectors == "fill":
expected = pa.array([[42.0, 42.0], [3.0, 4.0]])
expected = pa.array([[1.0, 42.0], [3.0, 4.0]])
elif on_bad_vectors == "null":
expected = pa.array([None, [3.0, 4.0]])