Handle NaN input data (#241)

Sometimes LangChain would insert a single `[np.nan]` as a placeholder if
the embedding function failed. This causes a problem for Lance format
because then the array can't be stored as a FixedSizedListArray.

Instead:
1. By default we remove rows with embedding lengths less than the
maximum length in the batch
2. If `strict=True` kwargs is set to True, then a `ValueError` is raised
if the embeddings aren't all the same length

---------

Co-authored-by: Chang She <chang@lancedb.com>
This commit is contained in:
Chang She
2023-07-04 20:00:46 -07:00
committed by GitHub
parent 9600a38ff0
commit 3c46d7f268
7 changed files with 273 additions and 24 deletions

View File

@@ -15,6 +15,7 @@ import functools
from pathlib import Path
from unittest.mock import PropertyMock, patch
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
@@ -167,7 +168,8 @@ def test_create_index_method():
replace=True,
)
# Check that the _dataset.create_index method was called with the right parameters
# Check that the _dataset.create_index method was called
# with the right parameters
mock_dataset.return_value.create_index.assert_called_once_with(
column="vector",
index_type="IVF_PQ",
@@ -176,3 +178,50 @@ def test_create_index_method():
num_sub_vectors=96,
replace=True,
)
def test_add_with_nans(db):
# By default we drop bad input vectors
table = LanceTable.create(
db,
"drop_test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [5], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
],
)
assert len(table) == 1
# We can fill bad input with some value
table = LanceTable.create(
db,
"fill_test",
data=[
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
],
on_bad_vectors="fill",
fill_value=0.0,
)
assert len(table) == 3
arrow_tbl = table.to_lance().to_table(filter="item == 'bar'")
v = arrow_tbl["vector"].to_pylist()[0]
assert np.allclose(v, np.array([0.0, 0.0]))
bad_data = [
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [5], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
]
for row in bad_data:
with pytest.raises(ValueError):
LanceTable.create(
db,
"raise_test",
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row],
on_bad_vectors="raise",
)