mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-19 21:10:41 +00:00
Handle NaN input data (#241)
Sometimes LangChain would insert a single `[np.nan]` as a placeholder if the embedding function failed. This causes a problem for Lance format because then the array can't be stored as a FixedSizedListArray. Instead: 1. By default we remove rows with embedding lengths less than the maximum length in the batch 2. If `strict=True` kwargs is set to True, then a `ValueError` is raised if the embeddings aren't all the same length --------- Co-authored-by: Chang She <chang@lancedb.com>
This commit is contained in:
@@ -15,6 +15,7 @@ import functools
|
||||
from pathlib import Path
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
@@ -167,7 +168,8 @@ def test_create_index_method():
|
||||
replace=True,
|
||||
)
|
||||
|
||||
# Check that the _dataset.create_index method was called with the right parameters
|
||||
# Check that the _dataset.create_index method was called
|
||||
# with the right parameters
|
||||
mock_dataset.return_value.create_index.assert_called_once_with(
|
||||
column="vector",
|
||||
index_type="IVF_PQ",
|
||||
@@ -176,3 +178,50 @@ def test_create_index_method():
|
||||
num_sub_vectors=96,
|
||||
replace=True,
|
||||
)
|
||||
|
||||
|
||||
def test_add_with_nans(db):
|
||||
# By default we drop bad input vectors
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"drop_test",
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [5], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
],
|
||||
)
|
||||
assert len(table) == 1
|
||||
|
||||
# We can fill bad input with some value
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"fill_test",
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
],
|
||||
on_bad_vectors="fill",
|
||||
fill_value=0.0,
|
||||
)
|
||||
assert len(table) == 3
|
||||
arrow_tbl = table.to_lance().to_table(filter="item == 'bar'")
|
||||
v = arrow_tbl["vector"].to_pylist()[0]
|
||||
assert np.allclose(v, np.array([0.0, 0.0]))
|
||||
|
||||
bad_data = [
|
||||
{"vector": [np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [5], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
|
||||
]
|
||||
for row in bad_data:
|
||||
with pytest.raises(ValueError):
|
||||
LanceTable.create(
|
||||
db,
|
||||
"raise_test",
|
||||
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row],
|
||||
on_bad_vectors="raise",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user