Set default to error instead of drop (#259)

when encountering bad input data, we can default to principle of least
surprise and raise an exception.

Co-authored-by: Chang She <chang@lancedb.com>
This commit is contained in:
Chang She
2023-07-05 22:44:18 -07:00
committed by GitHub
parent bb3df62dce
commit 507eeae9c8
6 changed files with 40 additions and 43 deletions

View File

@@ -158,7 +158,7 @@ class LanceDBConnection:
data: DATA = None,
schema: pa.Schema = None,
mode: str = "create",
on_bad_vectors: str = "drop",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> LanceTable:
"""Create a table in the database.
@@ -175,9 +175,9 @@ class LanceDBConnection:
The mode to use when creating the table. Can be either "create" or "overwrite".
By default, if the table already exists, an exception is raised.
If you want to overwrite the table, use mode="overwrite".
on_bad_vectors: str
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "raise", "drop", "fill".
One of "error", "drop", "fill".
fill_value: float
The value to use when filling vectors. Only used if on_bad_vectors="fill".

View File

@@ -13,7 +13,7 @@
from __future__ import annotations
import asyncio
from typing import Awaitable, Literal
from typing import Literal
import numpy as np
import pandas as pd

View File

@@ -15,7 +15,6 @@ import abc
from typing import List, Optional
import attr
import pandas as pd
import pyarrow as pa
from pydantic import BaseModel

View File

@@ -15,7 +15,7 @@ from __future__ import annotations
import os
from functools import cached_property
from typing import Any, List, Union
from typing import List, Union
import lance
import numpy as np
@@ -258,7 +258,7 @@ class LanceTable:
self,
data: DATA,
mode: str = "append",
on_bad_vectors: str = "drop",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> int:
"""Add data to the table.
@@ -270,9 +270,9 @@ class LanceTable:
mode: str
The mode to use when writing the data. Valid values are
"append" and "overwrite".
on_bad_vectors: str
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "raise", "drop", "fill".
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
@@ -329,7 +329,7 @@ class LanceTable:
data,
schema=None,
mode="create",
on_bad_vectors: str = "drop",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
):
"""
@@ -361,9 +361,9 @@ class LanceTable:
mode: str, default "create"
The mode to use when writing the data. Valid values are
"create", "overwrite", and "append".
on_bad_vectors: str
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "raise", "drop", "fill".
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
@@ -419,7 +419,7 @@ class LanceTable:
def _sanitize_schema(
data: pa.Table,
schema: pa.Schema = None,
on_bad_vectors: str = "drop",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> pa.Table:
"""Ensure that the table has the expected schema.
@@ -431,10 +431,10 @@ def _sanitize_schema(
schema: pa.Schema; optional
The expected schema. If not provided, this just converts the
vector column to fixed_size_list(float32) if necessary.
on_bad_vectors: str
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "raise", "drop", "fill".
fill_value: float
One of "error", "drop", "fill".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
if schema is not None:
@@ -463,7 +463,7 @@ def _sanitize_schema(
def _sanitize_vector_column(
data: pa.Table,
vector_column_name: str,
on_bad_vectors: str = "drop",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> pa.Table:
"""
@@ -475,10 +475,10 @@ def _sanitize_vector_column(
The table to sanitize.
vector_column_name: str
The name of the vector column.
on_bad_vectors: str
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "raise", "drop", "fill".
fill_value: float
One of "error", "drop", "fill".
fill_value: float, default 0.0
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
if vector_column_name not in data.column_names:
@@ -524,7 +524,7 @@ def ensure_fixed_size_list_of_f32(vec_arr):
def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
"""Sanitize jagged vectors."""
if on_bad_vectors == "raise":
if on_bad_vectors == "error":
raise ValueError(
f"Vector column {vector_column_name} has variable length vectors "
"Set on_bad_vectors='drop' to remove them, or "
@@ -538,7 +538,7 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na
if on_bad_vectors == "fill":
if fill_value is None:
raise ValueError(
f"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
)
fill_arr = pa.scalar([float(fill_value)] * ndims)
vec_arr = pc.if_else(correct_ndims, vec_arr, fill_arr)
@@ -552,7 +552,7 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na
def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
"""Sanitize NaNs in vectors"""
if on_bad_vectors == "raise":
if on_bad_vectors == "error":
raise ValueError(
f"Vector column {vector_column_name} has NaNs. "
"Set on_bad_vectors='drop' to remove them, or "
@@ -561,7 +561,7 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name
elif on_bad_vectors == "fill":
if fill_value is None:
raise ValueError(
f"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
)
fill_value = float(fill_value)
values = pc.if_else(vec_arr.values.is_nan(), fill_value, vec_arr.values)

View File

@@ -11,9 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from urllib.parse import ParseResult, urlparse
from pyarrow import fs
from urllib.parse import urlparse
def get_uri_scheme(uri: str) -> str:

View File

@@ -181,7 +181,21 @@ def test_create_index_method():
def test_add_with_nans(db):
# By default we drop bad input vectors
# by default we raise an error on bad input vectors
bad_data = [
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [5], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
]
for row in bad_data:
with pytest.raises(ValueError):
LanceTable.create(
db,
"error_test",
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row],
)
table = LanceTable.create(
db,
"drop_test",
@@ -191,6 +205,7 @@ def test_add_with_nans(db):
{"vector": [5], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
],
on_bad_vectors="drop",
)
assert len(table) == 1
@@ -210,18 +225,3 @@ def test_add_with_nans(db):
arrow_tbl = table.to_lance().to_table(filter="item == 'bar'")
v = arrow_tbl["vector"].to_pylist()[0]
assert np.allclose(v, np.array([0.0, 0.0]))
bad_data = [
{"vector": [np.nan], "item": "bar", "price": 20.0},
{"vector": [5], "item": "bar", "price": 20.0},
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
{"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
]
for row in bad_data:
with pytest.raises(ValueError):
LanceTable.create(
db,
"raise_test",
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row],
on_bad_vectors="raise",
)