From 587c0824af821e1cf19fcd79ab67b272cd93f35f Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 15 Nov 2024 11:33:00 -0800 Subject: [PATCH] feat: flexible null handling and insert subschemas in Python (#1827) * Test that we can insert subschemas (omit nullable columns) in Python. * More work is needed to support this in Node. See: https://github.com/lancedb/lancedb/issues/1832 * Test that we can insert data with nullable schema but no nulls in non-nullable schema. * Add `"null"` option for `on_bad_vectors` where we fill with null if the vector is bad. * Make null values not considered bad if the field itself is nullable. --- .github/workflows/nodejs.yml | 1 - Cargo.toml | 14 +-- docs/src/guides/tables.md | 21 +++++ nodejs/__test__/table.test.ts | 75 ++++++++++++++++ python/python/lancedb/table.py | 44 ++++++++-- python/python/tests/test_embeddings.py | 45 +++++++--- python/python/tests/test_table.py | 115 +++++++++++++++++++++++++ 7 files changed, 288 insertions(+), 27 deletions(-) diff --git a/.github/workflows/nodejs.yml b/.github/workflows/nodejs.yml index cbbef821..af2a4619 100644 --- a/.github/workflows/nodejs.yml +++ b/.github/workflows/nodejs.yml @@ -104,7 +104,6 @@ jobs: OPENAI_BASE_URL: http://0.0.0.0:8000 run: | python ci/mock_openai.py & - ss -ltnp | grep :8000 cd nodejs/examples npm test macos: diff --git a/Cargo.toml b/Cargo.toml index 7122cdf2..971c8e46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,13 +23,13 @@ rust-version = "1.80.0" # TODO: lower this once we upgrade Lance again. 
[workspace.dependencies] lance = { "version" = "=0.19.2", "features" = [ "dynamodb", -], git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" } -lance-index = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" } -lance-linalg = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" } -lance-table = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" } -lance-testing = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" } -lance-datafusion = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" } -lance-encoding = { "version" = "=0.19.2", git = "https://github.com/lancedb/lance.git", tag = "v0.19.2" } +]} +lance-index = "=0.19.2" +lance-linalg = "=0.19.2" +lance-table = "=0.19.2" +lance-testing = "=0.19.2" +lance-datafusion = "=0.19.2" +lance-encoding = "=0.19.2" # Note that this one does not include pyarrow arrow = { version = "52.2", optional = false } arrow-array = "52.2" diff --git a/docs/src/guides/tables.md b/docs/src/guides/tables.md index 97b12b83..a6973512 100644 --- a/docs/src/guides/tables.md +++ b/docs/src/guides/tables.md @@ -790,6 +790,27 @@ Use the `drop_table()` method on the database to remove a table. This permanently removes the table and is not recoverable, unlike deleting rows. If the table does not exist an exception is raised. +## Handling bad vectors + +In LanceDB Python, you can use the `on_bad_vectors` parameter to choose how +invalid vector values are handled. Vectors are considered invalid for any of +the following reasons: + +1. They are the wrong dimension +2. They contain NaN values +3. They are null but are on a non-nullable field + +By default, LanceDB will raise an error if it encounters a bad vector.
You can +also choose one of the following options: + +* `drop`: Ignore rows with bad vectors +* `fill`: Replace bad values (NaNs) or missing values (too few dimensions) with + the fill value specified in the `fill_value` parameter. An input like + `[1.0, NaN, 3.0]` will be replaced with `[1.0, 0.0, 3.0]` if `fill_value=0.0`. +* `null`: Replace bad vectors with null (only works if the column is nullable). + A bad vector `[1.0, NaN, 3.0]` will be replaced with `null` if the column is + nullable. If the vector column is non-nullable, then bad vectors will cause an + error. ## Consistency diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index e9c06a82..39289c0d 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -187,6 +187,81 @@ describe.each([arrow13, arrow14, arrow15, arrow16, arrow17])( }, ); + // TODO: https://github.com/lancedb/lancedb/issues/1832 + it.skip("should be able to omit nullable fields", async () => { + const db = await connect(tmpDir.name); + const schema = new arrow.Schema([ + new arrow.Field( + "vector", + new arrow.FixedSizeList( + 2, + new arrow.Field("item", new arrow.Float64()), + ), + true, + ), + new arrow.Field("item", new arrow.Utf8(), true), + new arrow.Field("price", new arrow.Float64(), false), + ]); + const table = await db.createEmptyTable("test", schema); + + const data1 = { item: "foo", price: 10.0 }; + await table.add([data1]); + const data2 = { vector: [3.1, 4.1], price: 2.0 }; + await table.add([data2]); + const data3 = { vector: [5.9, 26.5], item: "bar", price: 3.0 }; + await table.add([data3]); + + let res = await table.query().limit(10).toArray(); + const resVector = res.map((r) => r.get("vector").toArray()); + expect(resVector).toEqual([null, data2.vector, data3.vector]); + const resItem = res.map((r) => r.get("item").toArray()); + expect(resItem).toEqual(["foo", null, "bar"]); + const resPrice = res.map((r) => r.get("price").toArray()); + expect(resPrice).toEqual([10.0,
2.0, 3.0]); + + const data4 = { item: "foo" }; + // We can't omit a column if it's not nullable + await expect(table.add([data4])).rejects.toThrow("Invalid user input"); + + // But we can alter columns to make them nullable + await table.alterColumns([{ path: "price", nullable: true }]); + await table.add([data4]); + + res = (await table.query().limit(10).toArray()).map((r) => r.toJSON()); + expect(res).toEqual([data1, data2, data3, data4]); + }); + + it("should be able to insert nullable data for non-nullable fields", async () => { + const db = await connect(tmpDir.name); + const schema = new arrow.Schema([ + new arrow.Field("x", new arrow.Float64(), false), + new arrow.Field("id", new arrow.Utf8(), false), + ]); + const table = await db.createEmptyTable("test", schema); + + const data1 = { x: 4.1, id: "foo" }; + await table.add([data1]); + const res = (await table.query().toArray())[0]; + expect(res.x).toEqual(data1.x); + expect(res.id).toEqual(data1.id); + + const data2 = { x: null, id: "bar" }; + await expect(table.add([data2])).rejects.toThrow( + "declared as non-nullable but contains null values", + ); + + // But we can alter columns to make them nullable + await table.alterColumns([{ path: "x", nullable: true }]); + await table.add([data2]); + + const res2 = await table.query().toArray(); + expect(res2.length).toBe(2); + expect(res2[0].x).toEqual(data1.x); + expect(res2[0].id).toEqual(data1.id); + expect(res2[1].x).toBeNull(); + expect(res2[1].id).toEqual(data2.id); + }); + it("should return the table as an instance of an arrow table", async () => { const arrowTbl = await table.toArrow(); expect(arrowTbl).toBeInstanceOf(ArrowTable); diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 684998b6..c737cf36 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -1567,7 +1567,7 @@ class LanceTable(Table): "append" and "overwrite". 
on_bad_vectors: str, default "error" What to do if any of the vectors are not the same size or contains NaNs. - One of "error", "drop", "fill". + One of "error", "drop", "fill", "null". fill_value: float, default 0. The value to use when filling vectors. Only used if on_bad_vectors="fill". @@ -1851,7 +1851,7 @@ class LanceTable(Table): data but will validate against any schema that's specified. on_bad_vectors: str, default "error" What to do if any of the vectors are not the same size or contains NaNs. - One of "error", "drop", "fill". + One of "error", "drop", "fill", "null". fill_value: float, default 0. The value to use when filling vectors. Only used if on_bad_vectors="fill". embedding_functions: list of EmbeddingFunctionModel, default None @@ -2151,13 +2151,11 @@ def _sanitize_schema( vector column to fixed_size_list(float32) if necessary. on_bad_vectors: str, default "error" What to do if any of the vectors are not the same size or contains NaNs. - One of "error", "drop", "fill". + One of "error", "drop", "fill", "null". fill_value: float, default 0. The value to use when filling vectors. Only used if on_bad_vectors="fill". """ if schema is not None: - if data.schema == schema: - return data # cast the columns to the expected types data = data.combine_chunks() for field in schema: @@ -2177,6 +2175,7 @@ def _sanitize_schema( vector_column_name=field.name, on_bad_vectors=on_bad_vectors, fill_value=fill_value, + table_schema=schema, ) return pa.Table.from_arrays( [data[name] for name in schema.names], schema=schema @@ -2197,6 +2196,7 @@ def _sanitize_schema( def _sanitize_vector_column( data: pa.Table, vector_column_name: str, + table_schema: Optional[pa.Schema] = None, on_bad_vectors: str = "error", fill_value: float = 0.0, ) -> pa.Table: @@ -2211,12 +2211,16 @@ def _sanitize_vector_column( The name of the vector column. on_bad_vectors: str, default "error" What to do if any of the vectors are not the same size or contains NaNs. 
- One of "error", "drop", "fill". + One of "error", "drop", "fill", "null". fill_value: float, default 0.0 The value to use when filling vectors. Only used if on_bad_vectors="fill". """ # ChunkedArray is annoying to work with, so we combine chunks here vec_arr = data[vector_column_name].combine_chunks() + if table_schema is not None: + field = table_schema.field(vector_column_name) + else: + field = None typ = data[vector_column_name].type if pa.types.is_list(typ) or pa.types.is_large_list(typ): # if it's a variable size list array, @@ -2243,7 +2247,11 @@ def _sanitize_vector_column( data, fill_value, on_bad_vectors, vec_arr, vector_column_name ) else: - if pc.any(pc.is_null(vec_arr.values, nan_is_null=True)).as_py(): + if ( + field is not None + and not field.nullable + and pc.any(pc.is_null(vec_arr.values)).as_py() + ) or (pc.any(pc.is_nan(vec_arr.values)).as_py()): data = _sanitize_nans( data, fill_value, on_bad_vectors, vec_arr, vector_column_name ) @@ -2287,6 +2295,12 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na ) elif on_bad_vectors == "drop": data = data.filter(correct_ndims) + elif on_bad_vectors == "null": + data = data.set_column( + data.column_names.index(vector_column_name), + vector_column_name, + pc.if_else(correct_ndims, vec_arr, pa.scalar(None)), + ) return data @@ -2303,7 +2317,8 @@ def _sanitize_nans( raise ValueError( f"Vector column {vector_column_name} has NaNs. " "Set on_bad_vectors='drop' to remove them, or " - "set on_bad_vectors='fill' and fill_value= to replace them." + "set on_bad_vectors='fill' and fill_value= to replace them. " + "Or set on_bad_vectors='null' to replace them with null." 
) elif on_bad_vectors == "fill": if fill_value is None: @@ -2323,6 +2338,17 @@ def _sanitize_nans( np_arr = np_arr.reshape(-1, vec_arr.type.list_size) not_nulls = np.any(np_arr, axis=1) data = data.filter(~not_nulls) + elif on_bad_vectors == "null": + # null = pa.nulls(len(vec_arr)).cast(vec_arr.type) + # values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values) + np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False)) + np_arr = np_arr.reshape(-1, vec_arr.type.list_size) + no_nans = np.any(np_arr, axis=1) + data = data.set_column( + data.column_names.index(vector_column_name), + vector_column_name, + pc.if_else(no_nans, vec_arr, pa.scalar(None)), + ) return data @@ -2588,7 +2614,7 @@ class AsyncTable: "append" and "overwrite". on_bad_vectors: str, default "error" What to do if any of the vectors are not the same size or contains NaNs. - One of "error", "drop", "fill". + One of "error", "drop", "fill", "null". fill_value: float, default 0. The value to use when filling vectors. Only used if on_bad_vectors="fill". 
diff --git a/python/python/tests/test_embeddings.py b/python/python/tests/test_embeddings.py index 59a9ee4b..32394009 100644 --- a/python/python/tests/test_embeddings.py +++ b/python/python/tests/test_embeddings.py @@ -81,14 +81,15 @@ def test_embedding_function(tmp_path): def test_embedding_with_bad_results(tmp_path): - @register("mock-embedding") - class MockEmbeddingFunction(TextEmbeddingFunction): + @register("null-embedding") + class NullEmbeddingFunction(TextEmbeddingFunction): def ndims(self): return 128 def generate_embeddings( self, texts: Union[List[str], np.ndarray] ) -> list[Union[np.array, None]]: + # Return None, which is bad if field is non-nullable return [ None if i % 2 == 0 else np.random.randn(self.ndims()) for i in range(len(texts)) @@ -96,13 +97,17 @@ def test_embedding_with_bad_results(tmp_path): db = lancedb.connect(tmp_path) registry = EmbeddingFunctionRegistry.get_instance() - model = registry.get("mock-embedding").create() + model = registry.get("null-embedding").create() class Schema(LanceModel): text: str = model.SourceField() vector: Vector(model.ndims()) = model.VectorField() table = db.create_table("test", schema=Schema, mode="overwrite") + with pytest.raises(ValueError): + # Default on_bad_vectors is "error" + table.add([{"text": "hello world"}]) + table.add( [{"text": "hello world"}, {"text": "bar"}], on_bad_vectors="drop", @@ -112,13 +117,33 @@ def test_embedding_with_bad_results(tmp_path): assert len(table) == 1 assert df.iloc[0]["text"] == "bar" - # table = db.create_table("test2", schema=Schema, mode="overwrite") - # table.add( - # [{"text": "hello world"}, {"text": "bar"}], - # ) - # assert len(table) == 2 - # tbl = table.to_arrow() - # assert tbl["vector"].null_count == 1 + @register("nan-embedding") + class NanEmbeddingFunction(TextEmbeddingFunction): + def ndims(self): + return 128 + + def generate_embeddings( + self, texts: Union[List[str], np.ndarray] + ) -> list[Union[np.array, None]]: + # Return NaN to produce bad 
vectors + return [ + [np.NAN] * 128 if i % 2 == 0 else np.random.randn(self.ndims()) + for i in range(len(texts)) + ] + + db = lancedb.connect(tmp_path) + registry = EmbeddingFunctionRegistry.get_instance() + model = registry.get("nan-embedding").create() + + table = db.create_table("test2", schema=Schema, mode="overwrite") + table.alter_columns(dict(path="vector", nullable=True)) + table.add( + [{"text": "hello world"}, {"text": "bar"}], + on_bad_vectors="null", + ) + assert len(table) == 2 + tbl = table.to_arrow() + assert tbl["vector"].null_count == 1 def test_with_existing_vectors(tmp_path): diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 5f597ef6..2b65a620 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -240,6 +240,121 @@ def test_add(db): _add(table, schema) +def test_add_subschema(tmp_path): + db = lancedb.connect(tmp_path) + schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), 2), nullable=True), + pa.field("item", pa.string(), nullable=True), + pa.field("price", pa.float64(), nullable=False), + ] + ) + table = db.create_table("test", schema=schema) + + data = {"price": 10.0, "item": "foo"} + table.add([data]) + data = {"price": 2.0, "vector": [3.1, 4.1]} + table.add([data]) + data = {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"} + table.add([data]) + + expected = pa.table( + { + "vector": [None, [3.1, 4.1], [5.9, 26.5]], + "item": ["foo", None, "bar"], + "price": [10.0, 2.0, 3.0], + }, + schema=schema, + ) + assert table.to_arrow() == expected + + data = {"item": "foo"} + # We can't omit a column if it's not nullable + with pytest.raises(OSError, match="Invalid user input"): + table.add([data]) + + # We can add it if we make the column nullable + table.alter_columns(dict(path="price", nullable=True)) + table.add([data]) + + expected_schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), 2), nullable=True), + pa.field("item", pa.string(), 
nullable=True), + pa.field("price", pa.float64(), nullable=True), + ] + ) + expected = pa.table( + { + "vector": [None, [3.1, 4.1], [5.9, 26.5], None], + "item": ["foo", None, "bar", "foo"], + "price": [10.0, 2.0, 3.0, None], + }, + schema=expected_schema, + ) + assert table.to_arrow() == expected + + +def test_add_nullability(tmp_path): + db = lancedb.connect(tmp_path) + schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), 2), nullable=False), + pa.field("id", pa.string(), nullable=False), + ] + ) + table = db.create_table("test", schema=schema) + + nullable_schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), 2), nullable=True), + pa.field("id", pa.string(), nullable=True), + ] + ) + data = pa.table( + { + "vector": [[3.1, 4.1], [5.9, 26.5]], + "id": ["foo", "bar"], + }, + schema=nullable_schema, + ) + # We can add nullable schema if it doesn't actually contain nulls + table.add(data) + + expected = data.cast(schema) + assert table.to_arrow() == expected + + data = pa.table( + { + "vector": [None], + "id": ["baz"], + }, + schema=nullable_schema, + ) + # We can't add nullable schema if it contains nulls + with pytest.raises(Exception, match="Vector column vector has NaNs"): + table.add(data) + + # But we can make it nullable + table.alter_columns(dict(path="vector", nullable=True)) + table.add(data) + + expected_schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), 2), nullable=True), + pa.field("id", pa.string(), nullable=False), + ] + ) + expected = pa.table( + { + "vector": [[3.1, 4.1], [5.9, 26.5], None], + "id": ["foo", "bar", "baz"], + }, + schema=expected_schema, + ) + assert table.to_arrow() == expected + + def test_add_pydantic_model(db): # https://github.com/lancedb/lancedb/issues/562