feat: flexible null handling and insert subschemas in Python (#1827)

* Test that we can insert subschemas (omit nullable columns) in Python.
* More work is needed to support this in Node. See:
https://github.com/lancedb/lancedb/issues/1832
* Test that we can insert data whose schema is nullable into a table whose
schema is non-nullable, as long as the data contains no nulls.
* Add a `"null"` option for `on_bad_vectors`, which replaces a bad vector
with null instead of raising, dropping, or filling.
* Make null values not considered bad if the field itself is nullable.
This commit is contained in:
Will Jones
2024-11-15 11:33:00 -08:00
committed by GitHub
parent b38a4269d0
commit 587c0824af
7 changed files with 288 additions and 27 deletions

View File

@@ -1567,7 +1567,7 @@ class LanceTable(Table):
"append" and "overwrite".
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
One of "error", "drop", "fill", "null".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
@@ -1851,7 +1851,7 @@ class LanceTable(Table):
data but will validate against any schema that's specified.
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
One of "error", "drop", "fill", "null".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
embedding_functions: list of EmbeddingFunctionModel, default None
@@ -2151,13 +2151,11 @@ def _sanitize_schema(
vector column to fixed_size_list(float32) if necessary.
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
One of "error", "drop", "fill", "null".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
if schema is not None:
if data.schema == schema:
return data
# cast the columns to the expected types
data = data.combine_chunks()
for field in schema:
@@ -2177,6 +2175,7 @@ def _sanitize_schema(
vector_column_name=field.name,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
table_schema=schema,
)
return pa.Table.from_arrays(
[data[name] for name in schema.names], schema=schema
@@ -2197,6 +2196,7 @@ def _sanitize_schema(
def _sanitize_vector_column(
data: pa.Table,
vector_column_name: str,
table_schema: Optional[pa.Schema] = None,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> pa.Table:
@@ -2211,12 +2211,16 @@ def _sanitize_vector_column(
The name of the vector column.
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
One of "error", "drop", "fill", "null".
fill_value: float, default 0.0
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
# ChunkedArray is annoying to work with, so we combine chunks here
vec_arr = data[vector_column_name].combine_chunks()
if table_schema is not None:
field = table_schema.field(vector_column_name)
else:
field = None
typ = data[vector_column_name].type
if pa.types.is_list(typ) or pa.types.is_large_list(typ):
# if it's a variable size list array,
@@ -2243,7 +2247,11 @@ def _sanitize_vector_column(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
)
else:
if pc.any(pc.is_null(vec_arr.values, nan_is_null=True)).as_py():
if (
field is not None
and not field.nullable
and pc.any(pc.is_null(vec_arr.values)).as_py()
) or (pc.any(pc.is_nan(vec_arr.values)).as_py()):
data = _sanitize_nans(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
)
@@ -2287,6 +2295,12 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na
)
elif on_bad_vectors == "drop":
data = data.filter(correct_ndims)
elif on_bad_vectors == "null":
data = data.set_column(
data.column_names.index(vector_column_name),
vector_column_name,
pc.if_else(correct_ndims, vec_arr, pa.scalar(None)),
)
return data
@@ -2303,7 +2317,8 @@ def _sanitize_nans(
raise ValueError(
f"Vector column {vector_column_name} has NaNs. "
"Set on_bad_vectors='drop' to remove them, or "
"set on_bad_vectors='fill' and fill_value=<value> to replace them."
"set on_bad_vectors='fill' and fill_value=<value> to replace them. "
"Or set on_bad_vectors='null' to replace them with null."
)
elif on_bad_vectors == "fill":
if fill_value is None:
@@ -2323,6 +2338,17 @@ def _sanitize_nans(
np_arr = np_arr.reshape(-1, vec_arr.type.list_size)
not_nulls = np.any(np_arr, axis=1)
data = data.filter(~not_nulls)
elif on_bad_vectors == "null":
# null = pa.nulls(len(vec_arr)).cast(vec_arr.type)
# values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values)
np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False))
np_arr = np_arr.reshape(-1, vec_arr.type.list_size)
no_nans = np.any(np_arr, axis=1)
data = data.set_column(
data.column_names.index(vector_column_name),
vector_column_name,
pc.if_else(no_nans, vec_arr, pa.scalar(None)),
)
return data
@@ -2588,7 +2614,7 @@ class AsyncTable:
"append" and "overwrite".
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill".
One of "error", "drop", "fill", "null".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".

View File

@@ -81,14 +81,15 @@ def test_embedding_function(tmp_path):
def test_embedding_with_bad_results(tmp_path):
@register("mock-embedding")
class MockEmbeddingFunction(TextEmbeddingFunction):
@register("null-embedding")
class NullEmbeddingFunction(TextEmbeddingFunction):
def ndims(self):
return 128
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> list[Union[np.array, None]]:
# Return None, which is bad if field is non-nullable
return [
None if i % 2 == 0 else np.random.randn(self.ndims())
for i in range(len(texts))
@@ -96,13 +97,17 @@ def test_embedding_with_bad_results(tmp_path):
db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()
model = registry.get("mock-embedding").create()
model = registry.get("null-embedding").create()
class Schema(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
table = db.create_table("test", schema=Schema, mode="overwrite")
with pytest.raises(ValueError):
# Default on_bad_vectors is "error"
table.add([{"text": "hello world"}])
table.add(
[{"text": "hello world"}, {"text": "bar"}],
on_bad_vectors="drop",
@@ -112,13 +117,33 @@ def test_embedding_with_bad_results(tmp_path):
assert len(table) == 1
assert df.iloc[0]["text"] == "bar"
# table = db.create_table("test2", schema=Schema, mode="overwrite")
# table.add(
# [{"text": "hello world"}, {"text": "bar"}],
# )
# assert len(table) == 2
# tbl = table.to_arrow()
# assert tbl["vector"].null_count == 1
@register("nan-embedding")
class NanEmbeddingFunction(TextEmbeddingFunction):
def ndims(self):
return 128
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> list[Union[np.array, None]]:
# Return NaN to produce bad vectors
return [
[np.NAN] * 128 if i % 2 == 0 else np.random.randn(self.ndims())
for i in range(len(texts))
]
db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()
model = registry.get("nan-embedding").create()
table = db.create_table("test2", schema=Schema, mode="overwrite")
table.alter_columns(dict(path="vector", nullable=True))
table.add(
[{"text": "hello world"}, {"text": "bar"}],
on_bad_vectors="null",
)
assert len(table) == 2
tbl = table.to_arrow()
assert tbl["vector"].null_count == 1
def test_with_existing_vectors(tmp_path):

View File

@@ -240,6 +240,121 @@ def test_add(db):
_add(table, schema)
def test_add_subschema(tmp_path):
    """Rows may omit nullable columns; omitting a non-nullable one is rejected.

    Three partial rows are appended one at a time to a table whose ``vector``
    and ``item`` columns are nullable and whose ``price`` column is not. The
    omitted cells must read back as null. A row without ``price`` is refused
    until ``price`` is altered to be nullable.
    """
    vector_type = pa.list_(pa.float32(), 2)
    schema = pa.schema(
        [
            pa.field("vector", vector_type, nullable=True),
            pa.field("item", pa.string(), nullable=True),
            pa.field("price", pa.float64(), nullable=False),
        ]
    )
    db = lancedb.connect(tmp_path)
    table = db.create_table("test", schema=schema)

    # Each row leaves out a different subset of the nullable columns and is
    # written with its own add() call.
    partial_rows = (
        {"price": 10.0, "item": "foo"},
        {"price": 2.0, "vector": [3.1, 4.1]},
        {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"},
    )
    for row in partial_rows:
        table.add([row])

    expected = pa.table(
        {
            "vector": [None, [3.1, 4.1], [5.9, 26.5]],
            "item": ["foo", None, "bar"],
            "price": [10.0, 2.0, 3.0],
        },
        schema=schema,
    )
    assert table.to_arrow() == expected

    row_without_price = {"item": "foo"}
    # We can't omit a column if it's not nullable
    with pytest.raises(OSError, match="Invalid user input"):
        table.add([row_without_price])

    # We can add it if we make the column nullable
    table.alter_columns({"path": "price", "nullable": True})
    table.add([row_without_price])

    expected_schema = pa.schema(
        [
            pa.field("vector", vector_type, nullable=True),
            pa.field("item", pa.string(), nullable=True),
            pa.field("price", pa.float64(), nullable=True),
        ]
    )
    expected = pa.table(
        {
            "vector": [None, [3.1, 4.1], [5.9, 26.5], None],
            "item": ["foo", None, "bar", "foo"],
            "price": [10.0, 2.0, 3.0, None],
        },
        schema=expected_schema,
    )
    assert table.to_arrow() == expected
def test_add_nullability(tmp_path):
    """Nullable input is accepted by a non-nullable table only without nulls.

    The table is created with non-nullable ``vector`` and ``id`` columns.
    Data declared with a nullable schema is accepted when it holds no nulls
    and reads back cast to the table schema. A null vector is rejected with
    the "has NaNs" ``ValueError`` until the ``vector`` column is altered to
    be nullable, after which the same data is stored as a null.
    """
    db = lancedb.connect(tmp_path)
    schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2), nullable=False),
            pa.field("id", pa.string(), nullable=False),
        ]
    )
    table = db.create_table("test", schema=schema)

    nullable_schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2), nullable=True),
            pa.field("id", pa.string(), nullable=True),
        ]
    )
    data = pa.table(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5]],
            "id": ["foo", "bar"],
        },
        schema=nullable_schema,
    )
    # We can add nullable schema if it doesn't actually contain nulls
    table.add(data)

    expected = data.cast(schema)
    assert table.to_arrow() == expected

    data = pa.table(
        {
            "vector": [None],
            "id": ["baz"],
        },
        schema=nullable_schema,
    )
    # We can't add nullable schema if it contains nulls. The message matched
    # here is raised as a ValueError, so expect that type rather than the
    # catch-all Exception; an unrelated failure should not satisfy the test.
    with pytest.raises(ValueError, match="Vector column vector has NaNs"):
        table.add(data)

    # But we can make it nullable
    table.alter_columns(dict(path="vector", nullable=True))
    table.add(data)

    # Only "vector" was altered, so "id" stays non-nullable.
    expected_schema = pa.schema(
        [
            pa.field("vector", pa.list_(pa.float32(), 2), nullable=True),
            pa.field("id", pa.string(), nullable=False),
        ]
    )
    expected = pa.table(
        {
            "vector": [[3.1, 4.1], [5.9, 26.5], None],
            "id": ["foo", "bar", "baz"],
        },
        schema=expected_schema,
    )
    assert table.to_arrow() == expected
def test_add_pydantic_model(db):
# https://github.com/lancedb/lancedb/issues/562