mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-11 06:12:58 +00:00
Handle NaN input data (#241)
Sometimes LangChain would insert a single `[np.nan]` as a placeholder if the embedding function failed. This causes a problem for Lance format because then the array can't be stored as a FixedSizedListArray. Instead: 1. By default we remove rows with embedding lengths less than the maximum length in the batch 2. If `strict=True` kwargs is set to True, then a `ValueError` is raised if the embeddings aren't all the same length --------- Co-authored-by: Chang She <chang@lancedb.com>
This commit is contained in:
@@ -158,6 +158,8 @@ class LanceDBConnection:
|
||||
data: DATA = None,
|
||||
schema: pa.Schema = None,
|
||||
mode: str = "create",
|
||||
on_bad_vectors: str = "drop",
|
||||
fill_value: float = 0.0,
|
||||
) -> LanceTable:
|
||||
"""Create a table in the database.
|
||||
|
||||
@@ -173,6 +175,11 @@ class LanceDBConnection:
|
||||
The mode to use when creating the table. Can be either "create" or "overwrite".
|
||||
By default, if the table already exists, an exception is raised.
|
||||
If you want to overwrite the table, use mode="overwrite".
|
||||
on_bad_vectors: str
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "raise", "drop", "fill".
|
||||
fill_value: float
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
Note
|
||||
----
|
||||
@@ -253,7 +260,15 @@ class LanceDBConnection:
|
||||
raise ValueError("mode must be either 'create' or 'overwrite'")
|
||||
|
||||
if data is not None:
|
||||
tbl = LanceTable.create(self, name, data, schema, mode=mode)
|
||||
tbl = LanceTable.create(
|
||||
self,
|
||||
name,
|
||||
data,
|
||||
schema,
|
||||
mode=mode,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
else:
|
||||
tbl = LanceTable.open(self, name)
|
||||
return tbl
|
||||
|
||||
@@ -15,12 +15,13 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import List, Union
|
||||
from typing import Any, List, Union
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pyarrow.compute as pc
|
||||
from lance import LanceDataset
|
||||
from lance.vector import vec_to_table
|
||||
|
||||
@@ -28,15 +29,19 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .query import LanceFtsQueryBuilder, LanceQueryBuilder
|
||||
|
||||
|
||||
def _sanitize_data(data, schema):
|
||||
def _sanitize_data(data, schema, on_bad_vectors, fill_value):
|
||||
if isinstance(data, list):
|
||||
data = pa.Table.from_pylist(data)
|
||||
data = _sanitize_schema(data, schema=schema)
|
||||
data = _sanitize_schema(
|
||||
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
data = vec_to_table(data)
|
||||
if isinstance(data, pd.DataFrame):
|
||||
data = pa.Table.from_pandas(data)
|
||||
data = _sanitize_schema(data, schema=schema)
|
||||
data = _sanitize_schema(
|
||||
data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
if not isinstance(data, pa.Table):
|
||||
raise TypeError(f"Unsupported data type: {type(data)}")
|
||||
return data
|
||||
@@ -249,7 +254,13 @@ class LanceTable:
|
||||
"""Return the LanceDataset backing this table."""
|
||||
return self._dataset
|
||||
|
||||
def add(self, data: DATA, mode: str = "append") -> int:
|
||||
def add(
|
||||
self,
|
||||
data: DATA,
|
||||
mode: str = "append",
|
||||
on_bad_vectors: str = "drop",
|
||||
fill_value: float = 0.0,
|
||||
) -> int:
|
||||
"""Add data to the table.
|
||||
|
||||
Parameters
|
||||
@@ -259,13 +270,20 @@ class LanceTable:
|
||||
mode: str
|
||||
The mode to use when writing the data. Valid values are
|
||||
"append" and "overwrite".
|
||||
on_bad_vectors: str
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "raise", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The number of vectors in the table.
|
||||
"""
|
||||
data = _sanitize_data(data, self.schema)
|
||||
data = _sanitize_data(
|
||||
data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
lance.write_dataset(data, self._dataset_uri, mode=mode)
|
||||
self._reset_dataset()
|
||||
return len(self)
|
||||
@@ -280,6 +298,8 @@ class LanceTable:
|
||||
----------
|
||||
query: list, np.ndarray
|
||||
The query vector.
|
||||
vector_column_name: str, default "vector"
|
||||
The name of the vector column to search.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -302,9 +322,55 @@ class LanceTable:
|
||||
return LanceQueryBuilder(self, query, vector_column_name)
|
||||
|
||||
@classmethod
|
||||
def create(cls, db, name, data, schema=None, mode="create"):
|
||||
def create(
|
||||
cls,
|
||||
db,
|
||||
name,
|
||||
data,
|
||||
schema=None,
|
||||
mode="create",
|
||||
on_bad_vectors: str = "drop",
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Create a new table.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> import pandas as pd
|
||||
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
1 2 [3.0, 4.0]
|
||||
2 3 [5.0, 6.0]
|
||||
|
||||
Parameters
|
||||
----------
|
||||
db: LanceDB
|
||||
The LanceDB instance to create the table in.
|
||||
name: str
|
||||
The name of the table to create.
|
||||
data: list-of-dict, dict, pd.DataFrame
|
||||
The data to insert into the table.
|
||||
schema: dict, optional
|
||||
The schema of the table. If not provided, the schema is inferred from the data.
|
||||
mode: str, default "create"
|
||||
The mode to use when writing the data. Valid values are
|
||||
"create", "overwrite", and "append".
|
||||
on_bad_vectors: str
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "raise", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
tbl = LanceTable(db, name)
|
||||
data = _sanitize_data(data, schema)
|
||||
data = _sanitize_data(
|
||||
data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value
|
||||
)
|
||||
lance.write_dataset(data, tbl._dataset_uri, mode=mode)
|
||||
return tbl
|
||||
|
||||
@@ -350,7 +416,12 @@ class LanceTable:
|
||||
self._dataset.delete(where)
|
||||
|
||||
|
||||
def _sanitize_schema(data: pa.Table, schema: pa.Schema = None) -> pa.Table:
|
||||
def _sanitize_schema(
|
||||
data: pa.Table,
|
||||
schema: pa.Schema = None,
|
||||
on_bad_vectors: str = "drop",
|
||||
fill_value: float = 0.0,
|
||||
) -> pa.Table:
|
||||
"""Ensure that the table has the expected schema.
|
||||
|
||||
Parameters
|
||||
@@ -360,21 +431,41 @@ def _sanitize_schema(data: pa.Table, schema: pa.Schema = None) -> pa.Table:
|
||||
schema: pa.Schema; optional
|
||||
The expected schema. If not provided, this just converts the
|
||||
vector column to fixed_size_list(float32) if necessary.
|
||||
on_bad_vectors: str
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "raise", "drop", "fill".
|
||||
fill_value: float
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
if schema is not None:
|
||||
if data.schema == schema:
|
||||
return data
|
||||
# cast the columns to the expected types
|
||||
data = data.combine_chunks()
|
||||
data = _sanitize_vector_column(data, vector_column_name=VECTOR_COLUMN_NAME)
|
||||
data = _sanitize_vector_column(
|
||||
data,
|
||||
vector_column_name=VECTOR_COLUMN_NAME,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
return pa.Table.from_arrays(
|
||||
[data[name] for name in schema.names], schema=schema
|
||||
)
|
||||
# just check the vector column
|
||||
return _sanitize_vector_column(data, vector_column_name=VECTOR_COLUMN_NAME)
|
||||
return _sanitize_vector_column(
|
||||
data,
|
||||
vector_column_name=VECTOR_COLUMN_NAME,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_vector_column(data: pa.Table, vector_column_name: str) -> pa.Table:
|
||||
def _sanitize_vector_column(
|
||||
data: pa.Table,
|
||||
vector_column_name: str,
|
||||
on_bad_vectors: str = "drop",
|
||||
fill_value: float = 0.0,
|
||||
) -> pa.Table:
|
||||
"""
|
||||
Ensure that the vector column exists and has type fixed_size_list(float32)
|
||||
|
||||
@@ -384,19 +475,103 @@ def _sanitize_vector_column(data: pa.Table, vector_column_name: str) -> pa.Table
|
||||
The table to sanitize.
|
||||
vector_column_name: str
|
||||
The name of the vector column.
|
||||
on_bad_vectors: str
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "raise", "drop", "fill".
|
||||
fill_value: float
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
if vector_column_name not in data.column_names:
|
||||
raise ValueError(f"Missing vector column: {vector_column_name}")
|
||||
# ChunkedArray is annoying to work with, so we combine chunks here
|
||||
vec_arr = data[vector_column_name].combine_chunks()
|
||||
if pa.types.is_fixed_size_list(vec_arr.type):
|
||||
return data
|
||||
if not pa.types.is_list(vec_arr.type):
|
||||
if pa.types.is_list(data[vector_column_name].type):
|
||||
# if it's a variable size list array we make sure the dimensions are all the same
|
||||
has_jagged_ndims = len(vec_arr.values) % len(data) != 0
|
||||
if has_jagged_ndims:
|
||||
data = _sanitize_jagged(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
vec_arr = data[vector_column_name].combine_chunks()
|
||||
elif not pa.types.is_fixed_size_list(vec_arr.type):
|
||||
raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
|
||||
|
||||
vec_arr = ensure_fixed_size_list_of_f32(vec_arr)
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
|
||||
has_nans = pc.any(vec_arr.values.is_nan()).as_py()
|
||||
if has_nans:
|
||||
data = _sanitize_nans(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def ensure_fixed_size_list_of_f32(vec_arr):
|
||||
values = vec_arr.values
|
||||
if not pa.types.is_float32(values.type):
|
||||
values = values.cast(pa.float32())
|
||||
list_size = len(values) / len(data)
|
||||
if pa.types.is_fixed_size_list(vec_arr.type):
|
||||
list_size = vec_arr.type.list_size
|
||||
else:
|
||||
list_size = len(values) / len(vec_arr)
|
||||
vec_arr = pa.FixedSizeListArray.from_arrays(values, list_size)
|
||||
return data.set_column(
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
return vec_arr
|
||||
|
||||
|
||||
def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
|
||||
"""Sanitize jagged vectors."""
|
||||
if on_bad_vectors == "raise":
|
||||
raise ValueError(
|
||||
f"Vector column {vector_column_name} has variable length vectors "
|
||||
"Set on_bad_vectors='drop' to remove them, or "
|
||||
"set on_bad_vectors='fill' and fill_value=<value> to replace them."
|
||||
)
|
||||
|
||||
lst_lengths = pc.list_value_length(vec_arr)
|
||||
ndims = pc.max(lst_lengths).as_py()
|
||||
correct_ndims = pc.equal(lst_lengths, ndims)
|
||||
|
||||
if on_bad_vectors == "fill":
|
||||
if fill_value is None:
|
||||
raise ValueError(
|
||||
f"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
||||
)
|
||||
fill_arr = pa.scalar([float(fill_value)] * ndims)
|
||||
vec_arr = pc.if_else(correct_ndims, vec_arr, fill_arr)
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
elif on_bad_vectors == "drop":
|
||||
data = data.filter(correct_ndims)
|
||||
return data
|
||||
|
||||
|
||||
def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
|
||||
"""Sanitize NaNs in vectors"""
|
||||
if on_bad_vectors == "raise":
|
||||
raise ValueError(
|
||||
f"Vector column {vector_column_name} has NaNs. "
|
||||
"Set on_bad_vectors='drop' to remove them, or "
|
||||
"set on_bad_vectors='fill' and fill_value=<value> to replace them."
|
||||
)
|
||||
elif on_bad_vectors == "fill":
|
||||
if fill_value is None:
|
||||
raise ValueError(
|
||||
f"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
||||
)
|
||||
fill_value = float(fill_value)
|
||||
values = pc.if_else(vec_arr.values.is_nan(), fill_value, vec_arr.values)
|
||||
ndims = len(vec_arr[0])
|
||||
vec_arr = pa.FixedSizeListArray.from_arrays(values, ndims)
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
elif on_bad_vectors == "drop":
|
||||
is_value_nan = pc.is_nan(vec_arr.values).to_numpy(zero_copy_only=False)
|
||||
is_full = np.any(~is_value_nan.reshape(-1, vec_arr.type.list_size), axis=1)
|
||||
data = data.filter(is_full)
|
||||
return data
|
||||
|
||||
@@ -15,7 +15,6 @@ import unittest.mock as mock
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pandas.testing as tm
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
@@ -30,7 +30,7 @@ class MockLanceDBServer:
|
||||
table_name = request.match_info["table_name"]
|
||||
assert table_name == "test_table"
|
||||
|
||||
request_json = await request.json()
|
||||
await request.json()
|
||||
# TODO: do some matching
|
||||
|
||||
vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector")
|
||||
|
||||
@@ -15,6 +15,7 @@ import functools
|
||||
from pathlib import Path
|
||||
from unittest.mock import PropertyMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
@@ -167,7 +168,8 @@ def test_create_index_method():
|
||||
replace=True,
|
||||
)
|
||||
|
||||
# Check that the _dataset.create_index method was called with the right parameters
|
||||
# Check that the _dataset.create_index method was called
|
||||
# with the right parameters
|
||||
mock_dataset.return_value.create_index.assert_called_once_with(
|
||||
column="vector",
|
||||
index_type="IVF_PQ",
|
||||
@@ -176,3 +178,50 @@ def test_create_index_method():
|
||||
num_sub_vectors=96,
|
||||
replace=True,
|
||||
)
|
||||
|
||||
|
||||
def test_add_with_nans(db):
|
||||
# By default we drop bad input vectors
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"drop_test",
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [5], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
],
|
||||
)
|
||||
assert len(table) == 1
|
||||
|
||||
# We can fill bad input with some value
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"fill_test",
|
||||
data=[
|
||||
{"vector": [3.1, 4.1], "item": "foo", "price": 10.0},
|
||||
{"vector": [np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
],
|
||||
on_bad_vectors="fill",
|
||||
fill_value=0.0,
|
||||
)
|
||||
assert len(table) == 3
|
||||
arrow_tbl = table.to_lance().to_table(filter="item == 'bar'")
|
||||
v = arrow_tbl["vector"].to_pylist()[0]
|
||||
assert np.allclose(v, np.array([0.0, 0.0]))
|
||||
|
||||
bad_data = [
|
||||
{"vector": [np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [5], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
|
||||
]
|
||||
for row in bad_data:
|
||||
with pytest.raises(ValueError):
|
||||
LanceTable.create(
|
||||
db,
|
||||
"raise_test",
|
||||
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row],
|
||||
on_bad_vectors="raise",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user