From 507eeae9c803cf20d9d71d6161bb67bdd4b9308e Mon Sep 17 00:00:00 2001 From: Chang She <759245+changhiskhan@users.noreply.github.com> Date: Wed, 5 Jul 2023 22:44:18 -0700 Subject: [PATCH] Set default to error instead of drop (#259) when encountering bad input data, we can default to principle of least surprise and raise an exception. Co-authored-by: Chang She --- python/lancedb/db.py | 6 ++--- python/lancedb/query.py | 2 +- python/lancedb/remote/__init__.py | 1 - python/lancedb/table.py | 38 +++++++++++++++---------------- python/lancedb/util.py | 4 +--- python/tests/test_table.py | 32 +++++++++++++------------- 6 files changed, 40 insertions(+), 43 deletions(-) diff --git a/python/lancedb/db.py b/python/lancedb/db.py index 50b2147d..fceb2e26 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -158,7 +158,7 @@ class LanceDBConnection: data: DATA = None, schema: pa.Schema = None, mode: str = "create", - on_bad_vectors: str = "drop", + on_bad_vectors: str = "error", fill_value: float = 0.0, ) -> LanceTable: """Create a table in the database. @@ -175,9 +175,9 @@ class LanceDBConnection: The mode to use when creating the table. Can be either "create" or "overwrite". By default, if the table already exists, an exception is raised. If you want to overwrite the table, use mode="overwrite". - on_bad_vectors: str + on_bad_vectors: str, default "error" What to do if any of the vectors are not the same size or contains NaNs. - One of "raise", "drop", "fill". + One of "error", "drop", "fill". fill_value: float The value to use when filling vectors. Only used if on_bad_vectors="fill". diff --git a/python/lancedb/query.py b/python/lancedb/query.py index bccc6905..3414dff4 100644 --- a/python/lancedb/query.py +++ b/python/lancedb/query.py @@ -13,7 +13,7 @@ from __future__ import annotations import asyncio -from typing import Awaitable, Literal +from typing import Literal import numpy as np import pandas as pd diff --git a/python/lancedb/remote/__init__.py b/python/lancedb/remote/__init__.py index 090b124a..57cc98fa 100644 --- a/python/lancedb/remote/__init__.py +++ b/python/lancedb/remote/__init__.py @@ -15,7 +15,6 @@ import abc from typing import List, Optional import attr -import pandas as pd import pyarrow as pa from pydantic import BaseModel diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 111f0bae..4541a32e 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -15,7 +15,7 @@ from __future__ import annotations import os from functools import cached_property -from typing import Any, List, Union +from typing import List, Union import lance import numpy as np @@ -258,7 +258,7 @@ class LanceTable: self, data: DATA, mode: str = "append", - on_bad_vectors: str = "drop", + on_bad_vectors: str = "error", fill_value: float = 0.0, ) -> int: """Add data to the table. @@ -270,9 +270,9 @@ class LanceTable: mode: str The mode to use when writing the data. Valid values are "append" and "overwrite". - on_bad_vectors: str + on_bad_vectors: str, default "error" What to do if any of the vectors are not the same size or contains NaNs. - One of "raise", "drop", "fill". + One of "error", "drop", "fill". fill_value: float, default 0. The value to use when filling vectors. Only used if on_bad_vectors="fill". @@ -329,7 +329,7 @@ class LanceTable: data, schema=None, mode="create", - on_bad_vectors: str = "drop", + on_bad_vectors: str = "error", fill_value: float = 0.0, ): """ @@ -361,9 +361,9 @@ class LanceTable: mode: str, default "create" The mode to use when writing the data. Valid values are "create", "overwrite", and "append". - on_bad_vectors: str + on_bad_vectors: str, default "error" What to do if any of the vectors are not the same size or contains NaNs. - One of "raise", "drop", "fill". + One of "error", "drop", "fill". fill_value: float, default 0. The value to use when filling vectors. Only used if on_bad_vectors="fill". """ @@ -419,7 +419,7 @@ class LanceTable: def _sanitize_schema( data: pa.Table, schema: pa.Schema = None, - on_bad_vectors: str = "drop", + on_bad_vectors: str = "error", fill_value: float = 0.0, ) -> pa.Table: """Ensure that the table has the expected schema. @@ -431,10 +431,10 @@ def _sanitize_schema( schema: pa.Schema; optional The expected schema. If not provided, this just converts the vector column to fixed_size_list(float32) if necessary. - on_bad_vectors: str + on_bad_vectors: str, default "error" What to do if any of the vectors are not the same size or contains NaNs. - One of "raise", "drop", "fill". - fill_value: float + One of "error", "drop", "fill". + fill_value: float, default 0. The value to use when filling vectors. Only used if on_bad_vectors="fill". """ if schema is not None: @@ -463,7 +463,7 @@ def _sanitize_schema( def _sanitize_vector_column( data: pa.Table, vector_column_name: str, - on_bad_vectors: str = "drop", + on_bad_vectors: str = "error", fill_value: float = 0.0, ) -> pa.Table: """ @@ -475,10 +475,10 @@ def _sanitize_vector_column( The table to sanitize. vector_column_name: str The name of the vector column. - on_bad_vectors: str + on_bad_vectors: str, default "error" What to do if any of the vectors are not the same size or contains NaNs. - One of "raise", "drop", "fill". - fill_value: float + One of "error", "drop", "fill". + fill_value: float, default 0.0 The value to use when filling vectors. Only used if on_bad_vectors="fill". """ if vector_column_name not in data.column_names: @@ -524,7 +524,7 @@ def ensure_fixed_size_list_of_f32(vec_arr): def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name): """Sanitize jagged vectors.""" - if on_bad_vectors == "raise": + if on_bad_vectors == "error": raise ValueError( f"Vector column {vector_column_name} has variable length vectors " "Set on_bad_vectors='drop' to remove them, or " @@ -538,7 +538,7 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na if on_bad_vectors == "fill": if fill_value is None: raise ValueError( - f"`fill_value` must not be None if `on_bad_vectors` is 'fill'" + "`fill_value` must not be None if `on_bad_vectors` is 'fill'" ) fill_arr = pa.scalar([float(fill_value)] * ndims) vec_arr = pc.if_else(correct_ndims, vec_arr, fill_arr) @@ -552,7 +552,7 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name): """Sanitize NaNs in vectors""" - if on_bad_vectors == "raise": + if on_bad_vectors == "error": raise ValueError( f"Vector column {vector_column_name} has NaNs. " "Set on_bad_vectors='drop' to remove them, or " @@ -561,7 +561,7 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name elif on_bad_vectors == "fill": if fill_value is None: raise ValueError( - f"`fill_value` must not be None if `on_bad_vectors` is 'fill'" + "`fill_value` must not be None if `on_bad_vectors` is 'fill'" ) fill_value = float(fill_value) values = pc.if_else(vec_arr.values.is_nan(), fill_value, vec_arr.values) diff --git a/python/lancedb/util.py b/python/lancedb/util.py index bc5cc7ba..47865b07 100644 --- a/python/lancedb/util.py +++ b/python/lancedb/util.py @@ -11,9 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from urllib.parse import ParseResult, urlparse - -from pyarrow import fs +from urllib.parse import urlparse def get_uri_scheme(uri: str) -> str: diff --git a/python/tests/test_table.py b/python/tests/test_table.py index 1dc10d65..019bf22b 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -181,7 +181,21 @@ def test_create_index_method(): def test_add_with_nans(db): - # By default we drop bad input vectors + # by default we raise an error on bad input vectors + bad_data = [ + {"vector": [np.nan], "item": "bar", "price": 20.0}, + {"vector": [5], "item": "bar", "price": 20.0}, + {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0}, + {"vector": [np.nan, 5.0], "item": "bar", "price": 20.0}, + ] + for row in bad_data: + with pytest.raises(ValueError): + LanceTable.create( + db, + "error_test", + data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row], + ) + table = LanceTable.create( db, "drop_test", @@ -191,6 +205,7 @@ def test_add_with_nans(db): {"vector": [5], "item": "bar", "price": 20.0}, {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0}, ], + on_bad_vectors="drop", ) assert len(table) == 1 @@ -210,18 +225,3 @@ def test_add_with_nans(db): arrow_tbl = table.to_lance().to_table(filter="item == 'bar'") v = arrow_tbl["vector"].to_pylist()[0] assert np.allclose(v, np.array([0.0, 0.0])) - - bad_data = [ - {"vector": [np.nan], "item": "bar", "price": 20.0}, - {"vector": [5], "item": "bar", "price": 20.0}, - {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0}, - {"vector": [np.nan, 5.0], "item": "bar", "price": 20.0}, - ] - for row in bad_data: - with pytest.raises(ValueError): - LanceTable.create( - db, - "raise_test", - data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row], - on_bad_vectors="raise", - )