mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-10 05:42:58 +00:00
feat: flexible null handling and insert subschemas in Python (#1827)
* Test that we can insert subschemas (omit nullable columns) in Python. * More work is needed to support this in Node. See: https://github.com/lancedb/lancedb/issues/1832 * Test that we can insert data with nullable schema but no nulls in non-nullable schema. * Add `"null"` option for `on_bad_vectors` where we fill with null if the vector is bad. * Make null values not considered bad if the field itself is nullable.
This commit is contained in:
1
.github/workflows/nodejs.yml
vendored
1
.github/workflows/nodejs.yml
vendored
@@ -104,7 +104,6 @@ jobs:
|
||||
OPENAI_BASE_URL: http://0.0.0.0:8000
|
||||
run: |
|
||||
python ci/mock_openai.py &
|
||||
ss -ltnp | grep :8000
|
||||
cd nodejs/examples
|
||||
npm test
|
||||
macos:
|
||||
|
||||
14
Cargo.toml
14
Cargo.toml
@@ -23,13 +23,13 @@ rust-version = "1.80.0" # TODO: lower this once we upgrade Lance again.
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.19.2", "features" = [
|
||||
"dynamodb",
|
||||
], git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||
lance-index = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||
lance-linalg = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||
lance-table = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||
lance-testing = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||
lance-datafusion = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||
lance-encoding = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" }
|
||||
]}
|
||||
lance-index = "=0.19.2"
|
||||
lance-linalg = "=0.19.2"
|
||||
lance-table = "=0.19.2"
|
||||
lance-testing = "=0.19.2"
|
||||
lance-datafusion = "=0.19.2"
|
||||
lance-encoding = "=0.19.2"
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "52.2", optional = false }
|
||||
arrow-array = "52.2"
|
||||
|
||||
@@ -790,6 +790,27 @@ Use the `drop_table()` method on the database to remove a table.
|
||||
This permanently removes the table and is not recoverable, unlike deleting rows.
|
||||
If the table does not exist an exception is raised.
|
||||
|
||||
## Handling bad vectors
|
||||
|
||||
In LanceDB Python, you can use the `on_bad_vectors` parameter to choose how
|
||||
invalid vector values are handled. Invalid vectors are vectors that are not valid
|
||||
because:
|
||||
|
||||
1. They are the wrong dimension
|
||||
2. They contain NaN values
|
||||
3. They are null but are on a non-nullable field
|
||||
|
||||
By default, LanceDB will raise an error if it encounters a bad vector. You can
|
||||
also choose one of the following options:
|
||||
|
||||
* `drop`: Ignore rows with bad vectors
|
||||
* `fill`: Replace bad values (NaNs) or missing values (too few dimensions) with
|
||||
the fill value specified in the `fill_value` parameter. An input like
|
||||
`[1.0, NaN, 3.0]` will be replaced with `[1.0, 0.0, 3.0]` if `fill_value=0.0`.
|
||||
* `null`: Replace bad vectors with null (only works if the column is nullable).
|
||||
A bad vector `[1.0, NaN, 3.0]` will be replaced with `null` if the column is
|
||||
nullable. If the vector column is non-nullable, then bad vectors will cause an
|
||||
error
|
||||
|
||||
## Consistency
|
||||
|
||||
|
||||
@@ -187,6 +187,81 @@ describe.each([arrow13, arrow14, arrow15, arrow16, arrow17])(
|
||||
},
|
||||
);
|
||||
|
||||
// TODO: https://github.com/lancedb/lancedb/issues/1832
|
||||
it.skip("should be able to omit nullable fields", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const schema = new arrow.Schema([
|
||||
new arrow.Field(
|
||||
"vector",
|
||||
new arrow.FixedSizeList(
|
||||
2,
|
||||
new arrow.Field("item", new arrow.Float64()),
|
||||
),
|
||||
true,
|
||||
),
|
||||
new arrow.Field("item", new arrow.Utf8(), true),
|
||||
new arrow.Field("price", new arrow.Float64(), false),
|
||||
]);
|
||||
const table = await db.createEmptyTable("test", schema);
|
||||
|
||||
const data1 = { item: "foo", price: 10.0 };
|
||||
await table.add([data1]);
|
||||
const data2 = { vector: [3.1, 4.1], price: 2.0 };
|
||||
await table.add([data2]);
|
||||
const data3 = { vector: [5.9, 26.5], item: "bar", price: 3.0 };
|
||||
await table.add([data3]);
|
||||
|
||||
let res = await table.query().limit(10).toArray();
|
||||
const resVector = res.map((r) => r.get("vector").toArray());
|
||||
expect(resVector).toEqual([null, data2.vector, data3.vector]);
|
||||
const resItem = res.map((r) => r.get("item").toArray());
|
||||
expect(resItem).toEqual(["foo", null, "bar"]);
|
||||
const resPrice = res.map((r) => r.get("price").toArray());
|
||||
expect(resPrice).toEqual([10.0, 2.0, 3.0]);
|
||||
|
||||
const data4 = { item: "foo" };
|
||||
// We can't omit a column if it's not nullable
|
||||
await expect(table.add([data4])).rejects.toThrow("Invalid user input");
|
||||
|
||||
// But we can alter columns to make them nullable
|
||||
await table.alterColumns([{ path: "price", nullable: true }]);
|
||||
await table.add([data4]);
|
||||
|
||||
res = (await table.query().limit(10).toArray()).map((r) => r.toJSON());
|
||||
expect(res).toEqual([data1, data2, data3, data4]);
|
||||
});
|
||||
|
||||
it("should be able to insert nullable data for non-nullable fields", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const schema = new arrow.Schema([
|
||||
new arrow.Field("x", new arrow.Float64(), false),
|
||||
new arrow.Field("id", new arrow.Utf8(), false),
|
||||
]);
|
||||
const table = await db.createEmptyTable("test", schema);
|
||||
|
||||
const data1 = { x: 4.1, id: "foo" };
|
||||
await table.add([data1]);
|
||||
const res = (await table.query().toArray())[0];
|
||||
expect(res.x).toEqual(data1.x);
|
||||
expect(res.id).toEqual(data1.id);
|
||||
|
||||
const data2 = { x: null, id: "bar" };
|
||||
await expect(table.add([data2])).rejects.toThrow(
|
||||
"declared as non-nullable but contains null values",
|
||||
);
|
||||
|
||||
// But we can alter columns to make them nullable
|
||||
await table.alterColumns([{ path: "x", nullable: true }]);
|
||||
await table.add([data2]);
|
||||
|
||||
const res2 = await table.query().toArray();
|
||||
expect(res2.length).toBe(2);
|
||||
expect(res2[0].x).toEqual(data1.x);
|
||||
expect(res2[0].id).toEqual(data1.id);
|
||||
expect(res2[1].x).toBeNull();
|
||||
expect(res2[1].id).toEqual(data2.id);
|
||||
});
|
||||
|
||||
it("should return the table as an instance of an arrow table", async () => {
|
||||
const arrowTbl = await table.toArrow();
|
||||
expect(arrowTbl).toBeInstanceOf(ArrowTable);
|
||||
|
||||
@@ -1567,7 +1567,7 @@ class LanceTable(Table):
|
||||
"append" and "overwrite".
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "error", "drop", "fill".
|
||||
One of "error", "drop", "fill", "null".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
@@ -1851,7 +1851,7 @@ class LanceTable(Table):
|
||||
data but will validate against any schema that's specified.
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "error", "drop", "fill".
|
||||
One of "error", "drop", "fill", "null".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
embedding_functions: list of EmbeddingFunctionModel, default None
|
||||
@@ -2151,13 +2151,11 @@ def _sanitize_schema(
|
||||
vector column to fixed_size_list(float32) if necessary.
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "error", "drop", "fill".
|
||||
One of "error", "drop", "fill", "null".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
if schema is not None:
|
||||
if data.schema == schema:
|
||||
return data
|
||||
# cast the columns to the expected types
|
||||
data = data.combine_chunks()
|
||||
for field in schema:
|
||||
@@ -2177,6 +2175,7 @@ def _sanitize_schema(
|
||||
vector_column_name=field.name,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
table_schema=schema,
|
||||
)
|
||||
return pa.Table.from_arrays(
|
||||
[data[name] for name in schema.names], schema=schema
|
||||
@@ -2197,6 +2196,7 @@ def _sanitize_schema(
|
||||
def _sanitize_vector_column(
|
||||
data: pa.Table,
|
||||
vector_column_name: str,
|
||||
table_schema: Optional[pa.Schema] = None,
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> pa.Table:
|
||||
@@ -2211,12 +2211,16 @@ def _sanitize_vector_column(
|
||||
The name of the vector column.
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "error", "drop", "fill".
|
||||
One of "error", "drop", "fill", "null".
|
||||
fill_value: float, default 0.0
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
# ChunkedArray is annoying to work with, so we combine chunks here
|
||||
vec_arr = data[vector_column_name].combine_chunks()
|
||||
if table_schema is not None:
|
||||
field = table_schema.field(vector_column_name)
|
||||
else:
|
||||
field = None
|
||||
typ = data[vector_column_name].type
|
||||
if pa.types.is_list(typ) or pa.types.is_large_list(typ):
|
||||
# if it's a variable size list array,
|
||||
@@ -2243,7 +2247,11 @@ def _sanitize_vector_column(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
else:
|
||||
if pc.any(pc.is_null(vec_arr.values, nan_is_null=True)).as_py():
|
||||
if (
|
||||
field is not None
|
||||
and not field.nullable
|
||||
and pc.any(pc.is_null(vec_arr.values)).as_py()
|
||||
) or (pc.any(pc.is_nan(vec_arr.values)).as_py()):
|
||||
data = _sanitize_nans(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
@@ -2287,6 +2295,12 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na
|
||||
)
|
||||
elif on_bad_vectors == "drop":
|
||||
data = data.filter(correct_ndims)
|
||||
elif on_bad_vectors == "null":
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column_name),
|
||||
vector_column_name,
|
||||
pc.if_else(correct_ndims, vec_arr, pa.scalar(None)),
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@@ -2303,7 +2317,8 @@ def _sanitize_nans(
|
||||
raise ValueError(
|
||||
f"Vector column {vector_column_name} has NaNs. "
|
||||
"Set on_bad_vectors='drop' to remove them, or "
|
||||
"set on_bad_vectors='fill' and fill_value=<value> to replace them."
|
||||
"set on_bad_vectors='fill' and fill_value=<value> to replace them. "
|
||||
"Or set on_bad_vectors='null' to replace them with null."
|
||||
)
|
||||
elif on_bad_vectors == "fill":
|
||||
if fill_value is None:
|
||||
@@ -2323,6 +2338,17 @@ def _sanitize_nans(
|
||||
np_arr = np_arr.reshape(-1, vec_arr.type.list_size)
|
||||
not_nulls = np.any(np_arr, axis=1)
|
||||
data = data.filter(~not_nulls)
|
||||
elif on_bad_vectors == "null":
|
||||
# null = pa.nulls(len(vec_arr)).cast(vec_arr.type)
|
||||
# values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values)
|
||||
np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False))
|
||||
np_arr = np_arr.reshape(-1, vec_arr.type.list_size)
|
||||
no_nans = np.any(np_arr, axis=1)
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column_name),
|
||||
vector_column_name,
|
||||
pc.if_else(no_nans, vec_arr, pa.scalar(None)),
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@@ -2588,7 +2614,7 @@ class AsyncTable:
|
||||
"append" and "overwrite".
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "error", "drop", "fill".
|
||||
One of "error", "drop", "fill", "null".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
|
||||
|
||||
@@ -81,14 +81,15 @@ def test_embedding_function(tmp_path):
|
||||
|
||||
|
||||
def test_embedding_with_bad_results(tmp_path):
|
||||
@register("mock-embedding")
|
||||
class MockEmbeddingFunction(TextEmbeddingFunction):
|
||||
@register("null-embedding")
|
||||
class NullEmbeddingFunction(TextEmbeddingFunction):
|
||||
def ndims(self):
|
||||
return 128
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> list[Union[np.array, None]]:
|
||||
# Return None, which is bad if field is non-nullable
|
||||
return [
|
||||
None if i % 2 == 0 else np.random.randn(self.ndims())
|
||||
for i in range(len(texts))
|
||||
@@ -96,13 +97,17 @@ def test_embedding_with_bad_results(tmp_path):
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
model = registry.get("mock-embedding").create()
|
||||
model = registry.get("null-embedding").create()
|
||||
|
||||
class Schema(LanceModel):
|
||||
text: str = model.SourceField()
|
||||
vector: Vector(model.ndims()) = model.VectorField()
|
||||
|
||||
table = db.create_table("test", schema=Schema, mode="overwrite")
|
||||
with pytest.raises(ValueError):
|
||||
# Default on_bad_vectors is "error"
|
||||
table.add([{"text": "hello world"}])
|
||||
|
||||
table.add(
|
||||
[{"text": "hello world"}, {"text": "bar"}],
|
||||
on_bad_vectors="drop",
|
||||
@@ -112,13 +117,33 @@ def test_embedding_with_bad_results(tmp_path):
|
||||
assert len(table) == 1
|
||||
assert df.iloc[0]["text"] == "bar"
|
||||
|
||||
# table = db.create_table("test2", schema=Schema, mode="overwrite")
|
||||
# table.add(
|
||||
# [{"text": "hello world"}, {"text": "bar"}],
|
||||
# )
|
||||
# assert len(table) == 2
|
||||
# tbl = table.to_arrow()
|
||||
# assert tbl["vector"].null_count == 1
|
||||
@register("nan-embedding")
|
||||
class NanEmbeddingFunction(TextEmbeddingFunction):
|
||||
def ndims(self):
|
||||
return 128
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> list[Union[np.array, None]]:
|
||||
# Return NaN to produce bad vectors
|
||||
return [
|
||||
[np.NAN] * 128 if i % 2 == 0 else np.random.randn(self.ndims())
|
||||
for i in range(len(texts))
|
||||
]
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
model = registry.get("nan-embedding").create()
|
||||
|
||||
table = db.create_table("test2", schema=Schema, mode="overwrite")
|
||||
table.alter_columns(dict(path="vector", nullable=True))
|
||||
table.add(
|
||||
[{"text": "hello world"}, {"text": "bar"}],
|
||||
on_bad_vectors="null",
|
||||
)
|
||||
assert len(table) == 2
|
||||
tbl = table.to_arrow()
|
||||
assert tbl["vector"].null_count == 1
|
||||
|
||||
|
||||
def test_with_existing_vectors(tmp_path):
|
||||
|
||||
@@ -240,6 +240,121 @@ def test_add(db):
|
||||
_add(table, schema)
|
||||
|
||||
|
||||
def test_add_subschema(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 2), nullable=True),
|
||||
pa.field("item", pa.string(), nullable=True),
|
||||
pa.field("price", pa.float64(), nullable=False),
|
||||
]
|
||||
)
|
||||
table = db.create_table("test", schema=schema)
|
||||
|
||||
data = {"price": 10.0, "item": "foo"}
|
||||
table.add([data])
|
||||
data = {"price": 2.0, "vector": [3.1, 4.1]}
|
||||
table.add([data])
|
||||
data = {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"}
|
||||
table.add([data])
|
||||
|
||||
expected = pa.table(
|
||||
{
|
||||
"vector": [None, [3.1, 4.1], [5.9, 26.5]],
|
||||
"item": ["foo", None, "bar"],
|
||||
"price": [10.0, 2.0, 3.0],
|
||||
},
|
||||
schema=schema,
|
||||
)
|
||||
assert table.to_arrow() == expected
|
||||
|
||||
data = {"item": "foo"}
|
||||
# We can't omit a column if it's not nullable
|
||||
with pytest.raises(OSError, match="Invalid user input"):
|
||||
table.add([data])
|
||||
|
||||
# We can add it if we make the column nullable
|
||||
table.alter_columns(dict(path="price", nullable=True))
|
||||
table.add([data])
|
||||
|
||||
expected_schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 2), nullable=True),
|
||||
pa.field("item", pa.string(), nullable=True),
|
||||
pa.field("price", pa.float64(), nullable=True),
|
||||
]
|
||||
)
|
||||
expected = pa.table(
|
||||
{
|
||||
"vector": [None, [3.1, 4.1], [5.9, 26.5], None],
|
||||
"item": ["foo", None, "bar", "foo"],
|
||||
"price": [10.0, 2.0, 3.0, None],
|
||||
},
|
||||
schema=expected_schema,
|
||||
)
|
||||
assert table.to_arrow() == expected
|
||||
|
||||
|
||||
def test_add_nullability(tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 2), nullable=False),
|
||||
pa.field("id", pa.string(), nullable=False),
|
||||
]
|
||||
)
|
||||
table = db.create_table("test", schema=schema)
|
||||
|
||||
nullable_schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 2), nullable=True),
|
||||
pa.field("id", pa.string(), nullable=True),
|
||||
]
|
||||
)
|
||||
data = pa.table(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
"id": ["foo", "bar"],
|
||||
},
|
||||
schema=nullable_schema,
|
||||
)
|
||||
# We can add nullable schema if it doesn't actually contain nulls
|
||||
table.add(data)
|
||||
|
||||
expected = data.cast(schema)
|
||||
assert table.to_arrow() == expected
|
||||
|
||||
data = pa.table(
|
||||
{
|
||||
"vector": [None],
|
||||
"id": ["baz"],
|
||||
},
|
||||
schema=nullable_schema,
|
||||
)
|
||||
# We can't add nullable schema if it contains nulls
|
||||
with pytest.raises(Exception, match="Vector column vector has NaNs"):
|
||||
table.add(data)
|
||||
|
||||
# But we can make it nullable
|
||||
table.alter_columns(dict(path="vector", nullable=True))
|
||||
table.add(data)
|
||||
|
||||
expected_schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 2), nullable=True),
|
||||
pa.field("id", pa.string(), nullable=False),
|
||||
]
|
||||
)
|
||||
expected = pa.table(
|
||||
{
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5], None],
|
||||
"id": ["foo", "bar", "baz"],
|
||||
},
|
||||
schema=expected_schema,
|
||||
)
|
||||
assert table.to_arrow() == expected
|
||||
|
||||
|
||||
def test_add_pydantic_model(db):
|
||||
# https://github.com/lancedb/lancedb/issues/562
|
||||
|
||||
|
||||
Reference in New Issue
Block a user