Handle NaN input data (#241)

Sometimes LangChain would insert a single `[np.nan]` as a placeholder if
the embedding function failed. This causes a problem for the Lance format
because the array then can't be stored as a FixedSizeListArray.

Instead:
1. By default we remove rows with embedding lengths less than the
maximum length in the batch
2. If `strict=True` is passed as a keyword argument, a `ValueError` is raised
if the embeddings aren't all the same length

---------

Co-authored-by: Chang She <chang@lancedb.com>
This commit is contained in:
Chang She
2023-07-04 20:00:46 -07:00
committed by GitHub
parent 9600a38ff0
commit 3c46d7f268
7 changed files with 273 additions and 24 deletions

View File

@@ -158,6 +158,8 @@ class LanceDBConnection:
data: DATA = None,
schema: pa.Schema = None,
mode: str = "create",
on_bad_vectors: str = "drop",
fill_value: float = 0.0,
) -> LanceTable:
"""Create a table in the database.
@@ -173,6 +175,11 @@ class LanceDBConnection:
The mode to use when creating the table. Can be either "create" or "overwrite".
By default, if the table already exists, an exception is raised.
If you want to overwrite the table, use mode="overwrite".
on_bad_vectors: str
What to do if any of the vectors are not the same size or contains NaNs.
One of "raise", "drop", "fill".
fill_value: float
The value to use when filling vectors. Only used if on_bad_vectors="fill".
Note
----
@@ -253,7 +260,15 @@ class LanceDBConnection:
raise ValueError("mode must be either 'create' or 'overwrite'")
if data is not None:
tbl = LanceTable.create(self, name, data, schema, mode=mode)
tbl = LanceTable.create(
self,
name,
data,
schema,
mode=mode,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
else:
tbl = LanceTable.open(self, name)
return tbl

View File

@@ -15,12 +15,13 @@ from __future__ import annotations
import os
from functools import cached_property
from typing import List, Union
from typing import Any, List, Union
import lance
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
from lance import LanceDataset
from lance.vector import vec_to_table
@@ -28,15 +29,19 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .query import LanceFtsQueryBuilder, LanceQueryBuilder
def _sanitize_data(data, schema, on_bad_vectors="drop", fill_value=0.0):
    """Convert supported input data into a ``pa.Table``.

    Parameters
    ----------
    data: list-of-dict, dict, pd.DataFrame, or pa.Table
        The input data. Lists and DataFrames are converted to a table and
        then passed through ``_sanitize_schema``. Dicts are converted with
        ``vec_to_table`` only. Tables are returned unchanged.
    schema: pa.Schema or None
        The expected schema, forwarded to ``_sanitize_schema``.
    on_bad_vectors: str, default "drop"
        Forwarded to ``_sanitize_schema``.
        One of "raise", "drop", "fill".
    fill_value: float, default 0.
        Forwarded to ``_sanitize_schema``.
        Only used if on_bad_vectors="fill".

    Returns
    -------
    pa.Table

    Raises
    ------
    TypeError
        If ``data`` is not one of the supported types.
    """
    if isinstance(data, list):
        data = pa.Table.from_pylist(data)
        data = _sanitize_schema(
            data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
        )
    if isinstance(data, dict):
        data = vec_to_table(data)
    if isinstance(data, pd.DataFrame):
        data = pa.Table.from_pandas(data)
        data = _sanitize_schema(
            data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
        )
    # Anything that is still not a table at this point is unsupported input.
    if not isinstance(data, pa.Table):
        raise TypeError(f"Unsupported data type: {type(data)}")
    return data
@@ -249,7 +254,13 @@ class LanceTable:
"""Return the LanceDataset backing this table."""
return self._dataset
def add(self, data: DATA, mode: str = "append") -> int:
def add(
self,
data: DATA,
mode: str = "append",
on_bad_vectors: str = "drop",
fill_value: float = 0.0,
) -> int:
"""Add data to the table.
Parameters
@@ -259,13 +270,20 @@ class LanceTable:
mode: str
The mode to use when writing the data. Valid values are
"append" and "overwrite".
on_bad_vectors: str
What to do if any of the vectors are not the same size or contains NaNs.
One of "raise", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
Returns
-------
int
The number of vectors in the table.
"""
data = _sanitize_data(data, self.schema)
data = _sanitize_data(
data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
lance.write_dataset(data, self._dataset_uri, mode=mode)
self._reset_dataset()
return len(self)
@@ -280,6 +298,8 @@ class LanceTable:
----------
query: list, np.ndarray
The query vector.
vector_column_name: str, default "vector"
The name of the vector column to search.
Returns
-------
@@ -302,9 +322,55 @@ class LanceTable:
return LanceQueryBuilder(self, query, vector_column_name)
@classmethod
def create(cls, db, name, data, schema=None, mode="create"):
def create(
cls,
db,
name,
data,
schema=None,
mode="create",
on_bad_vectors: str = "drop",
fill_value: float = 0.0,
):
"""
Create a new table.
Examples
--------
>>> import lancedb
>>> import pandas as pd
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data)
>>> table.to_pandas()
x vector
0 1 [1.0, 2.0]
1 2 [3.0, 4.0]
2 3 [5.0, 6.0]
Parameters
----------
db: LanceDB
The LanceDB instance to create the table in.
name: str
The name of the table to create.
data: list-of-dict, dict, pd.DataFrame
The data to insert into the table.
schema: dict, optional
The schema of the table. If not provided, the schema is inferred from the data.
mode: str, default "create"
The mode to use when writing the data. Valid values are
"create", "overwrite", and "append".
on_bad_vectors: str
What to do if any of the vectors are not the same size or contains NaNs.
One of "raise", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
tbl = LanceTable(db, name)
data = _sanitize_data(data, schema)
data = _sanitize_data(
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
)
lance.write_dataset(data, tbl._dataset_uri, mode=mode)
return tbl
@@ -350,7 +416,12 @@ class LanceTable:
self._dataset.delete(where)
def _sanitize_schema(data: pa.Table, schema: pa.Schema = None) -> pa.Table:
def _sanitize_schema(
data: pa.Table,
schema: pa.Schema = None,
on_bad_vectors: str = "drop",
fill_value: float = 0.0,
) -> pa.Table:
"""Ensure that the table has the expected schema.
Parameters
@@ -360,21 +431,41 @@ def _sanitize_schema(data: pa.Table, schema: pa.Schema = None) -> pa.Table:
schema: pa.Schema; optional
The expected schema. If not provided, this just converts the
vector column to fixed_size_list(float32) if necessary.
on_bad_vectors: str
What to do if any of the vectors are not the same size or contains NaNs.
One of "raise", "drop", "fill".
fill_value: float
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
if schema is not None:
if data.schema == schema:
return data
# cast the columns to the expected types
data = data.combine_chunks()
data = _sanitize_vector_column(data, vector_column_name=VECTOR_COLUMN_NAME)
data = _sanitize_vector_column(
data,
vector_column_name=VECTOR_COLUMN_NAME,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
return pa.Table.from_arrays(
[data[name] for name in schema.names], schema=schema
)
# just check the vector column
return _sanitize_vector_column(data, vector_column_name=VECTOR_COLUMN_NAME)
return _sanitize_vector_column(
data,
vector_column_name=VECTOR_COLUMN_NAME,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
def _sanitize_vector_column(data: pa.Table, vector_column_name: str) -> pa.Table:
def _sanitize_vector_column(
data: pa.Table,
vector_column_name: str,
on_bad_vectors: str = "drop",
fill_value: float = 0.0,
) -> pa.Table:
"""
Ensure that the vector column exists and has type fixed_size_list(float32)
@@ -384,19 +475,103 @@ def _sanitize_vector_column(data: pa.Table, vector_column_name: str) -> pa.Table
The table to sanitize.
vector_column_name: str
The name of the vector column.
on_bad_vectors: str
What to do if any of the vectors are not the same size or contains NaNs.
One of "raise", "drop", "fill".
fill_value: float
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
if vector_column_name not in data.column_names:
raise ValueError(f"Missing vector column: {vector_column_name}")
# ChunkedArray is annoying to work with, so we combine chunks here
vec_arr = data[vector_column_name].combine_chunks()
if pa.types.is_fixed_size_list(vec_arr.type):
return data
if not pa.types.is_list(vec_arr.type):
if pa.types.is_list(data[vector_column_name].type):
# if it's a variable size list array we make sure the dimensions are all the same
has_jagged_ndims = len(vec_arr.values) % len(data) != 0
if has_jagged_ndims:
data = _sanitize_jagged(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
)
vec_arr = data[vector_column_name].combine_chunks()
elif not pa.types.is_fixed_size_list(vec_arr.type):
raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
vec_arr = ensure_fixed_size_list_of_f32(vec_arr)
data = data.set_column(
data.column_names.index(vector_column_name), vector_column_name, vec_arr
)
has_nans = pc.any(vec_arr.values.is_nan()).as_py()
if has_nans:
data = _sanitize_nans(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
)
return data
def ensure_fixed_size_list_of_f32(vec_arr):
    """Return ``vec_arr`` as a ``FixedSizeListArray`` with float32 values.

    Parameters
    ----------
    vec_arr: pa.Array
        A list or fixed-size-list array (it must expose ``.values`` and
        ``.type``). For a variable-size list array the list size is
        derived from the total number of values divided by the number of
        rows.

    Returns
    -------
    pa.FixedSizeListArray
        An array built from the (possibly cast) float32 values.
    """
    values = vec_arr.values
    if not pa.types.is_float32(values.type):
        values = values.cast(pa.float32())
    if pa.types.is_fixed_size_list(vec_arr.type):
        list_size = vec_arr.type.list_size
    else:
        # Floor division keeps the list size an int; true division would
        # hand a float to FixedSizeListArray.from_arrays.
        list_size = len(values) // len(vec_arr)
    return pa.FixedSizeListArray.from_arrays(values, list_size)
def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
"""Sanitize jagged vectors."""
if on_bad_vectors == "raise":
raise ValueError(
f"Vector column {vector_column_name} has variable length vectors "
"Set on_bad_vectors='drop' to remove them, or "
"set on_bad_vectors='fill' and fill_value=<value> to replace them."
)
lst_lengths = pc.list_value_length(vec_arr)
ndims = pc.max(lst_lengths).as_py()
correct_ndims = pc.equal(lst_lengths, ndims)
if on_bad_vectors == "fill":
if fill_value is None:
raise ValueError(
f"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
)
fill_arr = pa.scalar([float(fill_value)] * ndims)
vec_arr = pc.if_else(correct_ndims, vec_arr, fill_arr)
data = data.set_column(
data.column_names.index(vector_column_name), vector_column_name, vec_arr
)
elif on_bad_vectors == "drop":
data = data.filter(correct_ndims)
return data
def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
"""Sanitize NaNs in vectors"""
if on_bad_vectors == "raise":
raise ValueError(
f"Vector column {vector_column_name} has NaNs. "
"Set on_bad_vectors='drop' to remove them, or "
"set on_bad_vectors='fill' and fill_value=<value> to replace them."
)
elif on_bad_vectors == "fill":
if fill_value is None:
raise ValueError(
f"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
)
fill_value = float(fill_value)
values = pc.if_else(vec_arr.values.is_nan(), fill_value, vec_arr.values)
ndims = len(vec_arr[0])
vec_arr = pa.FixedSizeListArray.from_arrays(values, ndims)
data = data.set_column(
data.column_names.index(vector_column_name), vector_column_name, vec_arr
)
elif on_bad_vectors == "drop":
is_value_nan = pc.is_nan(vec_arr.values).to_numpy(zero_copy_only=False)
is_full = np.any(~is_value_nan.reshape(-1, vec_arr.type.list_size), axis=1)
data = data.filter(is_full)
return data

View File

@@ -15,7 +15,6 @@ import unittest.mock as mock
import lance
import numpy as np
import pandas as pd
import pandas.testing as tm
import pyarrow as pa
import pytest

View File

@@ -30,7 +30,7 @@ class MockLanceDBServer:
table_name = request.match_info["table_name"]
assert table_name == "test_table"
request_json = await request.json()
await request.json()
# TODO: do some matching
vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector")

View File

@@ -15,6 +15,7 @@ import functools
from pathlib import Path
from unittest.mock import PropertyMock, patch
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
@@ -167,7 +168,8 @@ def test_create_index_method():
replace=True,
)
# Check that the _dataset.create_index method was called with the right parameters
# Check that the _dataset.create_index method was called
# with the right parameters
mock_dataset.return_value.create_index.assert_called_once_with(
column="vector",
index_type="IVF_PQ",
@@ -176,3 +178,50 @@ def test_create_index_method():
num_sub_vectors=96,
replace=True,
)
def test_add_with_nans(db):
    """Exercise the "drop", "fill" and "raise" modes of ``on_bad_vectors``."""

    def foo_row():
        # The single well-formed row used by every scenario.
        return {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}

    def bar_row(vector):
        # A row whose vector is malformed in some way.
        return {"vector": vector, "item": "bar", "price": 20.0}

    # Default mode: bad input vectors are dropped
    dropped = LanceTable.create(
        db,
        "drop_test",
        data=[
            foo_row(),
            bar_row([np.nan]),
            bar_row([5]),
            bar_row([np.nan, np.nan]),
        ],
    )
    assert len(dropped) == 1

    # Fill mode: bad input is replaced with the given value
    filled = LanceTable.create(
        db,
        "fill_test",
        data=[
            foo_row(),
            bar_row([np.nan]),
            bar_row([np.nan, np.nan]),
        ],
        on_bad_vectors="fill",
        fill_value=0.0,
    )
    assert len(filled) == 3
    bars = filled.to_lance().to_table(filter="item == 'bar'")
    first_vector = bars["vector"].to_pylist()[0]
    assert np.allclose(first_vector, np.array([0.0, 0.0]))

    # Raise mode: each kind of bad vector triggers a ValueError
    bad_vectors = [[np.nan], [5], [np.nan, np.nan], [np.nan, 5.0]]
    for vector in bad_vectors:
        with pytest.raises(ValueError):
            LanceTable.create(
                db,
                "raise_test",
                data=[foo_row(), bar_row(vector)],
                on_bad_vectors="raise",
            )