mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 23:12:58 +00:00
Set default to error instead of drop (#259)
when encountering bad input data, we can default to principle of least surprise and raise an exception. Co-authored-by: Chang She <chang@lancedb.com>
This commit is contained in:
@@ -158,7 +158,7 @@ class LanceDBConnection:
|
||||
data: DATA = None,
|
||||
schema: pa.Schema = None,
|
||||
mode: str = "create",
|
||||
on_bad_vectors: str = "drop",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> LanceTable:
|
||||
"""Create a table in the database.
|
||||
@@ -175,9 +175,9 @@ class LanceDBConnection:
|
||||
The mode to use when creating the table. Can be either "create" or "overwrite".
|
||||
By default, if the table already exists, an exception is raised.
|
||||
If you want to overwrite the table, use mode="overwrite".
|
||||
on_bad_vectors: str
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "raise", "drop", "fill".
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Awaitable, Literal
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
@@ -15,7 +15,6 @@ import abc
|
||||
from typing import List, Optional
|
||||
|
||||
import attr
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import Any, List, Union
|
||||
from typing import List, Union
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
@@ -258,7 +258,7 @@ class LanceTable:
|
||||
self,
|
||||
data: DATA,
|
||||
mode: str = "append",
|
||||
on_bad_vectors: str = "drop",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> int:
|
||||
"""Add data to the table.
|
||||
@@ -270,9 +270,9 @@ class LanceTable:
|
||||
mode: str
|
||||
The mode to use when writing the data. Valid values are
|
||||
"append" and "overwrite".
|
||||
on_bad_vectors: str
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "raise", "drop", "fill".
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
@@ -329,7 +329,7 @@ class LanceTable:
|
||||
data,
|
||||
schema=None,
|
||||
mode="create",
|
||||
on_bad_vectors: str = "drop",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
"""
|
||||
@@ -361,9 +361,9 @@ class LanceTable:
|
||||
mode: str, default "create"
|
||||
The mode to use when writing the data. Valid values are
|
||||
"create", "overwrite", and "append".
|
||||
on_bad_vectors: str
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "raise", "drop", "fill".
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
@@ -419,7 +419,7 @@ class LanceTable:
|
||||
def _sanitize_schema(
|
||||
data: pa.Table,
|
||||
schema: pa.Schema = None,
|
||||
on_bad_vectors: str = "drop",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> pa.Table:
|
||||
"""Ensure that the table has the expected schema.
|
||||
@@ -431,10 +431,10 @@ def _sanitize_schema(
|
||||
schema: pa.Schema; optional
|
||||
The expected schema. If not provided, this just converts the
|
||||
vector column to fixed_size_list(float32) if necessary.
|
||||
on_bad_vectors: str
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "raise", "drop", "fill".
|
||||
fill_value: float
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
if schema is not None:
|
||||
@@ -463,7 +463,7 @@ def _sanitize_schema(
|
||||
def _sanitize_vector_column(
|
||||
data: pa.Table,
|
||||
vector_column_name: str,
|
||||
on_bad_vectors: str = "drop",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> pa.Table:
|
||||
"""
|
||||
@@ -475,10 +475,10 @@ def _sanitize_vector_column(
|
||||
The table to sanitize.
|
||||
vector_column_name: str
|
||||
The name of the vector column.
|
||||
on_bad_vectors: str
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "raise", "drop", "fill".
|
||||
fill_value: float
|
||||
One of "error", "drop", "fill".
|
||||
fill_value: float, default 0.0
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
if vector_column_name not in data.column_names:
|
||||
@@ -524,7 +524,7 @@ def ensure_fixed_size_list_of_f32(vec_arr):
|
||||
|
||||
def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
|
||||
"""Sanitize jagged vectors."""
|
||||
if on_bad_vectors == "raise":
|
||||
if on_bad_vectors == "error":
|
||||
raise ValueError(
|
||||
f"Vector column {vector_column_name} has variable length vectors "
|
||||
"Set on_bad_vectors='drop' to remove them, or "
|
||||
@@ -538,7 +538,7 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na
|
||||
if on_bad_vectors == "fill":
|
||||
if fill_value is None:
|
||||
raise ValueError(
|
||||
f"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
||||
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
||||
)
|
||||
fill_arr = pa.scalar([float(fill_value)] * ndims)
|
||||
vec_arr = pc.if_else(correct_ndims, vec_arr, fill_arr)
|
||||
@@ -552,7 +552,7 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na
|
||||
|
||||
def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
|
||||
"""Sanitize NaNs in vectors"""
|
||||
if on_bad_vectors == "raise":
|
||||
if on_bad_vectors == "error":
|
||||
raise ValueError(
|
||||
f"Vector column {vector_column_name} has NaNs. "
|
||||
"Set on_bad_vectors='drop' to remove them, or "
|
||||
@@ -561,7 +561,7 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
elif on_bad_vectors == "fill":
|
||||
if fill_value is None:
|
||||
raise ValueError(
|
||||
f"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
||||
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
||||
)
|
||||
fill_value = float(fill_value)
|
||||
values = pc.if_else(vec_arr.values.is_nan(), fill_value, vec_arr.values)
|
||||
|
||||
@@ -11,9 +11,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from urllib.parse import ParseResult, urlparse
|
||||
|
||||
from pyarrow import fs
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
def get_uri_scheme(uri: str) -> str:
|
||||
|
||||
@@ -181,7 +181,21 @@ def test_create_index_method():
|
||||
|
||||
|
||||
def test_add_with_nans(db):
|
||||
# By default we drop bad input vectors
|
||||
# by default we raise an error on bad input vectors
|
||||
bad_data = [
|
||||
{"vector": [np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [5], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
|
||||
]
|
||||
for row in bad_data:
|
||||
with pytest.raises(ValueError):
|
||||
LanceTable.create(
|
||||
db,
|
||||
"error_test",
|
||||
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row],
|
||||
)
|
||||
|
||||
table = LanceTable.create(
|
||||
db,
|
||||
"drop_test",
|
||||
@@ -191,6 +205,7 @@ def test_add_with_nans(db):
|
||||
{"vector": [5], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
assert len(table) == 1
|
||||
|
||||
@@ -210,18 +225,3 @@ def test_add_with_nans(db):
|
||||
arrow_tbl = table.to_lance().to_table(filter="item == 'bar'")
|
||||
v = arrow_tbl["vector"].to_pylist()[0]
|
||||
assert np.allclose(v, np.array([0.0, 0.0]))
|
||||
|
||||
bad_data = [
|
||||
{"vector": [np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [5], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, np.nan], "item": "bar", "price": 20.0},
|
||||
{"vector": [np.nan, 5.0], "item": "bar", "price": 20.0},
|
||||
]
|
||||
for row in bad_data:
|
||||
with pytest.raises(ValueError):
|
||||
LanceTable.create(
|
||||
db,
|
||||
"raise_test",
|
||||
data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row],
|
||||
on_bad_vectors="raise",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user