diff --git a/python/python/lancedb/util.py b/python/python/lancedb/util.py
index b8e519933..e9d01407b 100644
--- a/python/python/lancedb/util.py
+++ b/python/python/lancedb/util.py
@@ -385,6 +385,21 @@ def _(value: np.ndarray):
     return value_to_sql(value.tolist())
 
 
+@value_to_sql.register(np.bool_)
+def _(value: np.bool_):
+    return value_to_sql(bool(value))  # re-dispatch on the builtin bool
+
+
+@value_to_sql.register(np.integer)
+def _(value: np.integer):
+    return value_to_sql(int(value))  # re-dispatch on the builtin int
+
+
+@value_to_sql.register(np.floating)
+def _(value: np.floating):
+    return value_to_sql(float(value))  # re-dispatch on the builtin float
+
+
 def deprecated(func):
     """This is a decorator which can be used to mark functions
     as deprecated. It will result in a warning being emitted
diff --git a/python/python/tests/test_util.py b/python/python/tests/test_util.py
index b5ab159b7..f3051d45a 100644
--- a/python/python/tests/test_util.py
+++ b/python/python/tests/test_util.py
@@ -149,6 +149,21 @@ def test_value_to_sql_dict():
     assert value_to_sql({}) == "named_struct()"
 
 
+def test_value_to_sql_numpy_scalars():
+    # numpy scalars (e.g. pulled from an ndarray or a pandas column) must
+    # convert the same way as their native Python counterparts. np.float64
+    # already worked by virtue of subclassing float, but the integer / bool
+    # / float32 scalars previously raised NotImplementedError.
+    import numpy as np
+
+    assert value_to_sql(np.int32(5)) == "5"
+    assert value_to_sql(np.int64(5)) == "5"
+    assert value_to_sql(np.float32(1.5)) == "1.5"
+    assert value_to_sql(np.float64(1.5)) == "1.5"
+    assert value_to_sql(np.bool_(True)) == "TRUE"
+    assert value_to_sql(np.bool_(False)) == "FALSE"
+
+
 def test_append_vector_columns():
     registry = EmbeddingFunctionRegistry.get_instance()
     registry.register("test")(MockTextEmbeddingFunction)