diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index ccd40b4a7..45f76bdc9 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -270,15 +270,17 @@ def _sanitize_data( reader, on_bad_vectors=on_bad_vectors, fill_value=fill_value, + target_schema=target_schema, + metadata=metadata, ) if target_schema is None: target_schema, reader = _infer_target_schema(reader) if metadata: - new_metadata = target_schema.metadata or {} - new_metadata.update(metadata) - target_schema = target_schema.with_metadata(new_metadata) + target_schema = target_schema.with_metadata( + _merge_metadata(target_schema.metadata, metadata) + ) _validate_schema(target_schema) reader = _cast_to_target_schema(reader, target_schema, allow_subschema) @@ -294,7 +296,7 @@ def _cast_to_target_schema( # pa.Table.cast expects field order not to be changed. # Lance doesn't care about field order, so we don't need to rearrange fields # to match the target schema. We just need to correctly cast the fields. - if reader.schema == target_schema: + if reader.schema.equals(target_schema, check_metadata=True): # Fast path when the schemas are already the same return reader @@ -314,7 +316,13 @@ def _cast_to_target_schema( def gen(): for batch in reader: # Table but not RecordBatch has cast. - yield pa.Table.from_batches([batch]).cast(reordered_schema).to_batches()[0] + cast_batches = ( + pa.Table.from_batches([batch]).cast(reordered_schema).to_batches() + ) + if cast_batches: + yield pa.RecordBatch.from_arrays( + cast_batches[0].columns, schema=reordered_schema + ) return pa.RecordBatchReader.from_batches(reordered_schema, gen()) @@ -332,37 +340,51 @@ def _align_field_types( if target_field is None: raise ValueError(f"Field '{field.name}' not found in target schema") if pa.types.is_struct(target_field.type): - new_type = pa.struct( - _align_field_types( - field.type.fields, - target_field.type.fields, + if pa.types.is_struct(field.type): + new_type = pa.struct( + _align_field_types( + field.type.fields, + target_field.type.fields, + ) ) - ) + else: + new_type = target_field.type elif pa.types.is_list(target_field.type): - new_type = pa.list_( - _align_field_types( - [field.type.value_field], - [target_field.type.value_field], - )[0] - ) + if _is_list_like(field.type): + new_type = pa.list_( + _align_field_types( + [field.type.value_field], + [target_field.type.value_field], + )[0] + ) + else: + new_type = target_field.type elif pa.types.is_large_list(target_field.type): - new_type = pa.large_list( - _align_field_types( - [field.type.value_field], - [target_field.type.value_field], - )[0] - ) + if _is_list_like(field.type): + new_type = pa.large_list( + _align_field_types( + [field.type.value_field], + [target_field.type.value_field], + )[0] + ) + else: + new_type = target_field.type elif pa.types.is_fixed_size_list(target_field.type): - new_type = pa.list_( - _align_field_types( - [field.type.value_field], - [target_field.type.value_field], - )[0], - target_field.type.list_size, - ) + if _is_list_like(field.type): + new_type = pa.list_( + _align_field_types( + [field.type.value_field], + [target_field.type.value_field], + )[0], + target_field.type.list_size, + ) + else: + new_type = target_field.type else: new_type = target_field.type - new_fields.append(pa.field(field.name, new_type, field.nullable)) + new_fields.append( + pa.field(field.name, new_type, field.nullable, target_field.metadata) + ) return new_fields @@ -440,6 +462,7 @@ def sanitize_create_table( schema = data.schema if metadata: + metadata = _merge_metadata(schema.metadata, metadata) schema = schema.with_metadata(metadata) # Need to apply metadata to the data as well if isinstance(data, pa.Table): @@ -492,9 +515,9 @@ def _append_vector_columns( vector columns to the table. """ if schema is None: - metadata = metadata or {} + metadata = _merge_metadata(metadata) else: - metadata = schema.metadata or metadata or {} + metadata = _merge_metadata(schema.metadata, metadata) functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata) if not functions: @@ -3211,43 +3234,157 @@ def _handle_bad_vectors( reader: pa.RecordBatchReader, on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error", fill_value: float = 0.0, + target_schema: Optional[pa.Schema] = None, + metadata: Optional[dict] = None, ) -> pa.RecordBatchReader: - vector_columns = [] + vector_columns = _find_vector_columns(reader.schema, target_schema, metadata) + if not vector_columns: + return reader - for field in reader.schema: - # They can provide a 'vector' column that isn't yet a FSL - named_vector_col = ( - ( - pa.types.is_list(field.type) - or pa.types.is_large_list(field.type) - or pa.types.is_fixed_size_list(field.type) - ) - and pa.types.is_floating(field.type.value_type) - and field.name == VECTOR_COLUMN_NAME - ) - # TODO: we're making an assumption that fixed size list of 10 or more - # is a vector column. This is definitely a bit hacky. - likely_vector_col = ( - pa.types.is_fixed_size_list(field.type) - and pa.types.is_floating(field.type.value_type) - and (field.type.list_size >= 10) - ) - - if named_vector_col or likely_vector_col: - vector_columns.append(field.name) + output_schema = _vector_output_schema(reader.schema, vector_columns) def gen(): for batch in reader: - for name in vector_columns: + pending_dims = [] + for vector_column in vector_columns: + dim = vector_column["expected_dim"] + if target_schema is not None and dim is None: + dim = _infer_vector_dim(batch[vector_column["name"]]) + pending_dims.append(vector_column) batch = _handle_bad_vector_column( batch, - vector_column_name=name, + vector_column_name=vector_column["name"], on_bad_vectors=on_bad_vectors, fill_value=fill_value, + expected_dim=dim, + expected_value_type=vector_column["expected_value_type"], ) - yield batch + for vector_column in pending_dims: + if vector_column["expected_dim"] is None: + vector_column["expected_dim"] = _infer_vector_dim( + batch[vector_column["name"]] + ) + if batch.schema.equals(output_schema, check_metadata=True): + yield batch + continue - return pa.RecordBatchReader.from_batches(reader.schema, gen()) + cast_batches = ( + pa.Table.from_batches([batch]).cast(output_schema).to_batches() + ) + if cast_batches: + yield pa.RecordBatch.from_arrays( + cast_batches[0].columns, + schema=output_schema, + ) + + return pa.RecordBatchReader.from_batches(output_schema, gen()) + + +def _find_vector_columns( + reader_schema: pa.Schema, + target_schema: Optional[pa.Schema], + metadata: Optional[dict], +) -> List[dict]: + if target_schema is None: + vector_columns = [] + for field in reader_schema: + named_vector_col = ( + _is_list_like(field.type) + and pa.types.is_floating(field.type.value_type) + and field.name == VECTOR_COLUMN_NAME + ) + likely_vector_col = ( + pa.types.is_fixed_size_list(field.type) + and pa.types.is_floating(field.type.value_type) + and (field.type.list_size >= 10) + ) + if named_vector_col or likely_vector_col: + vector_columns.append( + { + "name": field.name, + "expected_dim": None, + "expected_value_type": None, + } + ) + return vector_columns + + reader_column_names = set(reader_schema.names) + active_metadata = _merge_metadata(target_schema.metadata, metadata) + embedding_function_columns = set( + EmbeddingFunctionRegistry.get_instance().parse_functions(active_metadata).keys() + ) + vector_columns = [] + for field in target_schema: + if field.name not in reader_column_names: + continue + if not _is_list_like(field.type) or not pa.types.is_floating( + field.type.value_type + ): + continue + + reader_field = reader_schema.field(field.name) + named_vector_col = ( + field.name in embedding_function_columns + or field.name == VECTOR_COLUMN_NAME + or (field.name == "embedding" and pa.types.is_fixed_size_list(field.type)) + ) + typed_fixed_vector_col = ( + pa.types.is_fixed_size_list(reader_field.type) + and pa.types.is_floating(reader_field.type.value_type) + and reader_field.type.list_size >= 10 + ) + + if named_vector_col or typed_fixed_vector_col: + vector_columns.append( + { + "name": field.name, + "expected_dim": ( + field.type.list_size + if pa.types.is_fixed_size_list(field.type) + else None + ), + "expected_value_type": field.type.value_type, + } + ) + + return vector_columns + + +def _vector_output_schema( + reader_schema: pa.Schema, + vector_columns: List[dict], +) -> pa.Schema: + columns_by_name = {column["name"]: column for column in vector_columns} + fields = [] + for field in reader_schema: + column = columns_by_name.get(field.name) + if column is None: + output_type = field.type + else: + output_type = _vector_output_type(field, column) + fields.append(pa.field(field.name, output_type, field.nullable, field.metadata)) + return pa.schema(fields, metadata=reader_schema.metadata) + + +def _vector_output_type(field: pa.Field, vector_column: dict) -> pa.DataType: + if not _is_list_like(field.type): + return field.type + + if vector_column["expected_value_type"] is not None and ( + pa.types.is_null(field.type.value_type) + or pa.types.is_integer(field.type.value_type) + or pa.types.is_unsigned_integer(field.type.value_type) + ): + return pa.list_(vector_column["expected_value_type"]) + + if ( + vector_column["expected_dim"] is not None + and pa.types.is_fixed_size_list(field.type) + and field.type.list_size != vector_column["expected_dim"] + ): + return pa.list_(field.type.value_type) + + return field.type def _handle_bad_vector_column( @@ -3255,6 +3392,8 @@ def _handle_bad_vector_column( vector_column_name: str, on_bad_vectors: str = "error", fill_value: float = 0.0, + expected_dim: Optional[int] = None, + expected_value_type: Optional[pa.DataType] = None, ) -> pa.RecordBatch: """ Ensure that the vector column exists and has type fixed_size_list(float) @@ -3271,14 +3410,39 @@ def _handle_bad_vector_column( fill_value: float, default 0.0 The value to use when filling vectors. Only used if on_bad_vectors="fill". """ + position = data.column_names.index(vector_column_name) vec_arr = data[vector_column_name] + if not _is_list_like(vec_arr.type): + return data - has_nan = has_nan_values(vec_arr) + if ( + expected_dim is not None + and pa.types.is_fixed_size_list(vec_arr.type) + and vec_arr.type.list_size != expected_dim + ): + vec_arr = pa.array(vec_arr.to_pylist(), type=pa.list_(vec_arr.type.value_type)) + data = data.set_column(position, vector_column_name, vec_arr) - if pa.types.is_fixed_size_list(vec_arr.type): + if expected_value_type is not None and ( + pa.types.is_integer(vec_arr.type.value_type) + or pa.types.is_unsigned_integer(vec_arr.type.value_type) + ): + vec_arr = pa.array(vec_arr.to_pylist(), type=pa.list_(expected_value_type)) + data = data.set_column(position, vector_column_name, vec_arr) + + if pa.types.is_floating(vec_arr.type.value_type): + has_nan = has_nan_values(vec_arr) + else: + has_nan = pa.array([False] * len(vec_arr)) + + if expected_dim is not None: + dim = expected_dim + elif pa.types.is_fixed_size_list(vec_arr.type): dim = vec_arr.type.list_size else: - dim = _modal_list_size(vec_arr) + dim = _infer_vector_dim(vec_arr) + if dim is None: + return data has_wrong_dim = pc.not_equal(pc.list_value_length(vec_arr), dim) has_bad_vectors = pc.any(has_nan).as_py() or pc.any(has_wrong_dim).as_py() @@ -3316,13 +3480,12 @@ def _handle_bad_vector_column( ) vec_arr = pc.if_else( is_bad, - pa.scalar([fill_value] * dim), + pa.scalar([fill_value] * dim, type=vec_arr.type), vec_arr, ) else: raise ValueError(f"Invalid value for on_bad_vectors: {on_bad_vectors}") - position = data.column_names.index(vector_column_name) return data.set_column(position, vector_column_name, vec_arr) @@ -3343,6 +3506,28 @@ def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray return pc.is_in(indices, has_nan_indices) +def _is_list_like(data_type: pa.DataType) -> bool: + return ( + pa.types.is_list(data_type) + or pa.types.is_large_list(data_type) + or pa.types.is_fixed_size_list(data_type) + ) + + +def _merge_metadata(*metadata_dicts: Optional[dict]) -> dict: + merged = {} + for metadata in metadata_dicts: + if metadata is None: + continue + for key, value in metadata.items(): + if isinstance(key, str): + key = key.encode("utf-8") + if isinstance(value, str): + value = value.encode("utf-8") + merged[key] = value + return merged + + def _name_suggests_vector_column(field_name: str) -> bool: """Check if a field name indicates a vector column.""" name_lower = field_name.lower() @@ -3410,6 +3595,16 @@ def _modal_list_size(arr: Union[pa.ListArray, pa.ChunkedArray]) -> int: return pc.mode(pc.list_value_length(arr))[0].as_py()["mode"] +def _infer_vector_dim(arr: Union[pa.Array, pa.ChunkedArray]) -> Optional[int]: + if not _is_list_like(arr.type): + return None + lengths = pc.list_value_length(arr) + lengths = pc.filter(lengths, pc.greater(lengths, 0)) + if len(lengths) == 0: + return None + return pc.mode(lengths)[0].as_py()["mode"] + + def _validate_schema(schema: pa.Schema): """ Make sure the metadata is valid utf8 diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index f1e71e4cb..639afe903 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -1049,6 +1049,231 @@ def test_add_with_nans(mem_db: DBConnection): assert np.allclose(v, np.array([0.0, 0.0])) +def test_add_with_empty_fixed_size_list_drops_bad_rows(mem_db: DBConnection): + class Schema(LanceModel): + text: str + embedding: Vector(16) + + table = mem_db.create_table("test_empty_embeddings", schema=Schema) + table.add( + [ + {"text": "hello", "embedding": []}, + {"text": "bar", "embedding": [0.1] * 16}, + ], + on_bad_vectors="drop", + ) + + data = table.to_arrow() + assert data["text"].to_pylist() == ["bar"] + assert np.allclose(data["embedding"].to_pylist()[0], np.array([0.1] * 16)) + + +def test_add_with_integer_embeddings_preserves_casting(mem_db: DBConnection): + class Schema(LanceModel): + text: str + embedding: Vector(4) + + table = mem_db.create_table("test_integer_embeddings", schema=Schema) + table.add( + [{"text": "foo", "embedding": [1, 2, 3, 4]}], + on_bad_vectors="drop", + ) + + assert table.to_arrow()["embedding"].to_pylist() == [[1.0, 2.0, 3.0, 4.0]] + + +def test_on_bad_vectors_does_not_handle_non_vector_fixed_size_lists( + mem_db: DBConnection, +): + schema = pa.schema( + [ + pa.field("vector", pa.list_(pa.float32(), 4)), + pa.field("bbox", pa.list_(pa.float32(), 4)), + ] + ) + table = mem_db.create_table("test_bbox_schema", schema=schema) + + with pytest.raises(RuntimeError, match="FixedSizeListType"): + table.add( + [{"vector": [1.0, 2.0, 3.0, 4.0], "bbox": [0.0, 1.0]}], + on_bad_vectors="drop", + ) + + +def test_on_bad_vectors_does_not_handle_custom_named_fixed_size_lists( + mem_db: DBConnection, +): + schema = pa.schema([pa.field("features", pa.list_(pa.float32(), 16))]) + table = mem_db.create_table("test_custom_named_fixed_size_vector", schema=schema) + + with pytest.raises(RuntimeError, match="FixedSizeListType"): + table.add( + [ + {"features": []}, + {"features": [0.1] * 16}, + ], + on_bad_vectors="drop", + ) + + +def test_on_bad_vectors_with_schema_list_vector_still_sanitizes(mem_db: DBConnection): + schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))]) + table = mem_db.create_table("test_schema_list_vector", schema=schema) + table.add( + [ + {"vector": [1.0, 2.0]}, + {"vector": [3.0]}, + {"vector": [4.0, 5.0]}, + ], + on_bad_vectors="drop", + ) + + assert table.to_arrow()["vector"].to_pylist() == [[1.0, 2.0], [4.0, 5.0]] + + +def test_on_bad_vectors_handles_typed_custom_fixed_vectors_for_list_schema( + mem_db: DBConnection, +): + schema = pa.schema([pa.field("vec", pa.list_(pa.float32()))]) + table = mem_db.create_table("test_typed_custom_fixed_vector", schema=schema) + data = pa.table( + { + "vec": pa.array( + [[float("nan")] * 16, [1.0] * 16], + type=pa.list_(pa.float32(), 16), + ) + } + ) + + table.add(data, on_bad_vectors="drop") + + assert table.to_arrow()["vec"].to_pylist() == [[1.0] * 16] + + +def test_on_bad_vectors_fill_preserves_arrow_nested_vector_type(mem_db: DBConnection): + schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))]) + table = mem_db.create_table("test_fill_arrow_nested_type", schema=schema) + data = pa.table( + { + "vector": pa.array( + [[1.0, 2.0], [float("nan"), 3.0]], + type=pa.list_(pa.float32(), 2), + ) + } + ) + + table.add( + data, + on_bad_vectors="fill", + fill_value=0.0, + ) + + assert table.to_arrow()["vector"].to_pylist() == [[1.0, 2.0], [0.0, 0.0]] + + +@pytest.mark.parametrize( + ("table_name", "batch1", "expected"), + [ + ( + "test_schema_list_vector_empty_prefix", + pa.record_batch({"vector": [[], []]}), + [[], [], [1.0, 2.0], [3.0, 4.0]], + ), + ( + "test_schema_list_vector_all_bad_prefix", + pa.record_batch({"vector": [[float("nan")] * 3, [float("nan")] * 3]}), + [[1.0, 2.0], [3.0, 4.0]], + ), + ], +) +def test_on_bad_vectors_with_schema_list_vector_ignores_invalid_prefix_batches( + mem_db: DBConnection, + table_name: str, + batch1: pa.RecordBatch, + expected: list, +): + schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))]) + table = mem_db.create_table(table_name, schema=schema) + batch2 = pa.record_batch({"vector": [[1.0, 2.0], [3.0, 4.0]]}) + reader = pa.RecordBatchReader.from_batches(batch1.schema, [batch1, batch2]) + + table.add(reader, on_bad_vectors="drop") + + assert table.to_arrow()["vector"].to_pylist() == expected + + +def test_on_bad_vectors_with_multiple_vectors_locks_dim_after_final_drop( + mem_db: DBConnection, +): + registry = EmbeddingFunctionRegistry.get_instance() + func = MockTextEmbeddingFunction.create() + metadata = registry.get_table_metadata( + [ + EmbeddingFunctionConfig( + source_column="text1", vector_column="vec1", function=func + ), + EmbeddingFunctionConfig( + source_column="text2", vector_column="vec2", function=func + ), + ] + ) + schema = pa.schema( + [ + pa.field("vec1", pa.list_(pa.float32())), + pa.field("vec2", pa.list_(pa.float32())), + ], + metadata=metadata, + ) + table = mem_db.create_table("test_multi_vector_dim_lock", schema=schema) + batch1 = pa.record_batch( + { + "vec1": [[1.0, 2.0, 3.0], [10.0, 11.0]], + "vec2": [[float("nan"), 0.0], [5.0, 6.0]], + } + ) + batch2 = pa.record_batch( + { + "vec1": [[20.0, 21.0], [30.0, 31.0]], + "vec2": [[7.0, 8.0], [9.0, 10.0]], + } + ) + reader = pa.RecordBatchReader.from_batches(batch1.schema, [batch1, batch2]) + + table.add(reader, on_bad_vectors="drop") + + data = table.to_arrow() + assert data["vec1"].to_pylist() == [[10.0, 11.0], [20.0, 21.0], [30.0, 31.0]] + assert data["vec2"].to_pylist() == [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]] + + +def test_on_bad_vectors_does_not_handle_non_vector_list_columns(mem_db: DBConnection): + schema = pa.schema([pa.field("embedding_history", pa.list_(pa.float32()))]) + table = mem_db.create_table("test_non_vector_list_schema", schema=schema) + table.add( + [ + {"embedding_history": [1.0, 2.0]}, + {"embedding_history": [3.0]}, + ], + on_bad_vectors="drop", + ) + + assert table.to_arrow()["embedding_history"].to_pylist() == [ + [1.0, 2.0], + [3.0], + ] + + +def test_on_bad_vectors_all_null_schema_vector_batches_do_not_crash( + mem_db: DBConnection, +): + schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2), nullable=True)]) + table = mem_db.create_table("test_all_null_vector_batch", schema=schema) + + table.add([{"vector": None}], on_bad_vectors="drop") + + assert table.to_arrow()["vector"].to_pylist() == [None] + + def test_restore(mem_db: DBConnection): table = mem_db.create_table( "my_table", diff --git a/python/python/tests/test_util.py b/python/python/tests/test_util.py index 74296a221..b5ab159b7 100644 --- a/python/python/tests/test_util.py +++ b/python/python/tests/test_util.py @@ -15,8 +15,10 @@ from lancedb.table import ( _cast_to_target_schema, _handle_bad_vectors, _into_pyarrow_reader, - _sanitize_data, _infer_target_schema, + _merge_metadata, + _sanitize_data, + sanitize_create_table, ) import pyarrow as pa import pandas as pd @@ -304,6 +306,117 @@ def test_handle_bad_vectors_noop(): assert output["vector"] == vector +def test_handle_bad_vectors_updates_reader_schema_for_target_schema(): + data = pa.table({"vector": [[1, 2, 3, 4]]}) + target_schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 4))]) + + output = _handle_bad_vectors( + data.to_reader(), + on_bad_vectors="drop", + target_schema=target_schema, + ) + + assert output.schema == pa.schema([pa.field("vector", pa.list_(pa.float32()))]) + assert output.read_all()["vector"].to_pylist() == [[1.0, 2.0, 3.0, 4.0]] + + +def test_sanitize_data_keeps_target_field_metadata(): + source_field = pa.field( + "vector", + pa.list_(pa.float32(), 2), + metadata={b"source": b"drop-me"}, + ) + target_field = pa.field( + "vector", + pa.list_(pa.float32(), 2), + metadata={b"target": b"keep-me"}, + ) + data = pa.table( + {"vector": pa.array([[1.0, 2.0]], type=pa.list_(pa.float32(), 2))}, + schema=pa.schema([source_field]), + ) + + output = _sanitize_data( + data, + target_schema=pa.schema([target_field]), + on_bad_vectors="drop", + ).read_all() + + assert output.schema.field("vector").metadata == {b"target": b"keep-me"} + + +def test_sanitize_data_uses_separate_embedding_metadata_for_bad_vectors(): + registry = EmbeddingFunctionRegistry.get_instance() + conf = EmbeddingFunctionConfig( + source_column="text", + vector_column="custom_vector", + function=MockTextEmbeddingFunction.create(), + ) + metadata = registry.get_table_metadata([conf]) + schema = pa.schema( + { + "text": pa.string(), + "custom_vector": pa.list_(pa.float32(), 10), + }, + metadata={b"note": b"keep-me"}, + ) + data = pa.table( + { + "text": ["bad", "good"], + "custom_vector": [[1.0] * 9, [2.0] * 10], + } + ) + + output = _sanitize_data( + data, + target_schema=schema, + metadata=metadata, + on_bad_vectors="drop", + ).read_all() + + assert output["text"].to_pylist() == ["good"] + assert output.schema.metadata[b"note"] == b"keep-me" + assert b"embedding_functions" in output.schema.metadata + + +def test_sanitize_create_table_merges_and_overrides_embedding_metadata(): + registry = EmbeddingFunctionRegistry.get_instance() + old_conf = EmbeddingFunctionConfig( + source_column="text", + vector_column="old_vector", + function=MockTextEmbeddingFunction.create(), + ) + new_conf = EmbeddingFunctionConfig( + source_column="text", + vector_column="custom_vector", + function=MockTextEmbeddingFunction.create(), + ) + metadata = registry.get_table_metadata([new_conf]) + schema = pa.schema( + { + "text": pa.string(), + "custom_vector": pa.list_(pa.float32(), 10), + }, + metadata=_merge_metadata( + {b"note": b"keep-me"}, + registry.get_table_metadata([old_conf]), + ), + ) + + data, schema = sanitize_create_table( + pa.table({"text": ["good"]}), + schema, + metadata=metadata, + on_bad_vectors="drop", + ) + + assert schema.metadata[b"note"] == b"keep-me" + assert b"embedding_functions" in schema.metadata + assert data.schema.metadata[b"note"] == b"keep-me" + funcs = EmbeddingFunctionRegistry.get_instance().parse_functions(schema.metadata) + assert set(funcs.keys()) == {"custom_vector"} + + class TestModel(lancedb.pydantic.LanceModel): a: Optional[int] b: Optional[int]