# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors

import os
import pathlib
from typing import Optional

import lance
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.embeddings.base import EmbeddingFunctionConfig
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
from lancedb.table import (
    _append_vector_columns,
    _cast_to_target_schema,
    _handle_bad_vectors,
    _into_pyarrow_reader,
    _infer_target_schema,
    _merge_metadata,
    _sanitize_data,
    sanitize_create_table,
)
import pyarrow as pa
import pandas as pd
import polars as pl
import pytest

import lancedb
from lancedb.util import get_uri_scheme, join_uri, value_to_sql
from utils import exception_output


def test_normalize_uri():
    """``get_uri_scheme`` reports the expected scheme for each URI form.

    Plain relative/absolute paths, ``file://`` URIs and Windows drive paths
    all count as ``file``; object-store URIs keep their own scheme.
    """
    scheme_by_uri = {
        "relative/path": "file",
        "/absolute/path": "file",
        "file:///absolute/path": "file",
        "s3://bucket/path": "s3",
        "gs://bucket/path": "gs",
        "c:\\windows\\path": "file",
    }
    for uri, want in scheme_by_uri.items():
        assert get_uri_scheme(uri) == want


def test_join_uri_remote():
    """Joining onto a remote URI yields the same ``/``-separated result
    whether the base carries the directories or they come in as parts."""
    for scheme in ("s3", "az", "gs"):
        want = f"{scheme}://bucket/path/to/table.lance"

        # The base already holds the directories (with a trailing slash).
        assert join_uri(f"{scheme}://bucket/path/to/", "table.lance") == want

        # The base is only the bucket; directories arrive as separate parts.
        assert join_uri(f"{scheme}://bucket", "path", "to", "table.lance") == want


# skip this test if on windows
@pytest.mark.skipif(os.name == "nt", reason="Windows paths are not POSIX")
def test_join_uri_posix():
    """On POSIX, joining local bases matches ``pathlib.Path`` joining.

    A ``str`` base must produce the ``str`` form of the joined path, and a
    ``pathlib.Path`` base must produce the joined ``Path`` itself.
    """
    local_bases = (
        # relative path
        "relative/path",
        "relative/path/",
        # an absolute path
        "/absolute/path",
        "/absolute/path/",
        # a file URI
        "file:///absolute/path",
        "file:///absolute/path/",
    )
    for base in local_bases:
        want_path = pathlib.Path(base) / "table.lance"
        assert join_uri(base, "table.lance") == str(want_path)
        assert join_uri(pathlib.Path(base), "table.lance") == want_path


# skip this test if not on windows
@pytest.mark.skipif(os.name != "nt", reason="Windows-only path handling test")
def test_local_join_uri_windows():
    """On Windows, joining local bases matches ``pathlib.Path`` joining."""
    # https://learn.microsoft.com/en-us/dotnet/standard/io/file-path-formats
    for base in [
        # windows relative path
        "relative\\path",
        "relative\\path\\",
        # windows absolute path on drive C:
        "c:\\absolute\\path",
        # absolute path from the root of the current drive
        "\\relative\\path",
    ]:
        joined = join_uri(base, "table.lance")
        assert joined == str(pathlib.Path(base) / "table.lance")

        joined = join_uri(pathlib.Path(base), "table.lance")
        assert joined == pathlib.Path(base) / "table.lance"


def test_value_to_sql_string(tmp_path):
    """String literals are escaped for SQL and round-trip through ``update``."""
    # Make sure we can convert Python string literals to SQL strings, even if
    # they contain characters meaningful in SQL, such as ' and ".
    values = [
        "anthony's",
        'a "test" string',
        "anthony's \"favorite color\" wasn't red",
    ]
    expected_values = [
        "'anthony''s'",
        "'a \"test\" string'",
        "'anthony''s \"favorite color\" wasn''t red'",
    ]
    for value, expected in zip(values, expected_values):
        assert value_to_sql(value) == expected

    # Also test we can roundtrip those strings through update.
    # This validates the query parser understands the strings we
    # are creating.
    db = lancedb.connect(tmp_path)
    table = db.create_table(
        "test",
        [{"search": value, "replace": "something"} for value in values],
    )
    for value in values:
        table.update(where=f"search = {value_to_sql(value)}", values={"replace": value})
        assert table.to_pandas().query("search == @value")["replace"].item() == value


def test_value_to_sql_dict():
    """Dicts become (possibly nested) ``named_struct(...)`` SQL expressions."""
    # Simple flat struct
    assert value_to_sql({"a": 1, "b": "hello"}) == "named_struct('a', 1, 'b', 'hello')"

    # Nested struct
    assert (
        value_to_sql({"outer": {"inner": 1}})
        == "named_struct('outer', named_struct('inner', 1))"
    )

    # List inside struct
    assert value_to_sql({"a": [1, 2]}) == "named_struct('a', [1, 2])"

    # Mixed types
    assert (
        value_to_sql({"name": "test", "count": 42, "rate": 3.14, "active": True})
        == "named_struct('name', 'test', 'count', 42, 'rate', 3.14, 'active', TRUE)"
    )

    # Null value inside struct
    assert value_to_sql({"a": None}) == "named_struct('a', NULL)"

    # Empty dict
    assert value_to_sql({}) == "named_struct()"


def test_append_vector_columns():
    """Embeddings are filled in for null/missing vector columns only."""
    registry = EmbeddingFunctionRegistry.get_instance()
    registry.register("test")(MockTextEmbeddingFunction)
    conf = EmbeddingFunctionConfig(
        source_column="text",
        vector_column="vector",
        function=MockTextEmbeddingFunction.create(),
    )
    metadata = registry.get_table_metadata([conf])

    schema = pa.schema(
        {
            "text": pa.string(),
            "vector": pa.list_(pa.float64(), 10),
        }
    )
    data = pa.table(
        {
            "text": ["hello"],
            "vector": [None],  # Replaces null
        },
        schema=schema,
    )
    output = _append_vector_columns(
        data.to_reader(),
        schema,  # metadata passed separate from schema
        metadata=metadata,
    ).read_all()
    assert output.schema == schema
    assert output["vector"].null_count == 0

    # Adds if missing
    data = pa.table({"text": ["hello"]})
    output = _append_vector_columns(
        data.to_reader(),
        schema.with_metadata(metadata),
    ).read_all()
    assert output.schema == schema
    assert output["vector"].null_count == 0

    # doesn't embed if already there
    data = pa.table(
        {
            "text": ["hello"],
            "vector": [[42.0] * 10],
        },
        schema=schema,
    )
    output = _append_vector_columns(
        data.to_reader(),
        schema.with_metadata(metadata),
    ).read_all()
    assert output == data  # No change

    # No provided schema
    data = pa.table(
        {
            "text": ["hello"],
        }
    )
    output = _append_vector_columns(
        data.to_reader(),
        metadata=metadata,
    ).read_all()
    expected_schema = pa.schema(
        {
            "text": pa.string(),
            "vector": pa.list_(pa.float32(), 10),
        }
    )
    assert output.schema == expected_schema
    assert output["vector"].null_count == 0


@pytest.mark.parametrize("on_bad_vectors", ["error", "drop", "fill", "null"])
def test_handle_bad_vectors_jagged(on_bad_vectors):
    """Variable-length vectors are rejected, dropped, filled, or nulled."""
    vector = pa.array([[1.0, 2.0], [3.0], [4.0, 5.0]])
    schema = pa.schema({"vector": pa.list_(pa.float64())})
    data = pa.table({"vector": vector}, schema=schema)

    if on_bad_vectors == "error":
        with pytest.raises(ValueError) as e:
            _handle_bad_vectors(
                data.to_reader(),
                on_bad_vectors=on_bad_vectors,
            ).read_all()
        output = exception_output(e)
        assert output == (
            "ValueError: Vector column 'vector' has variable length vectors. Set "
            "on_bad_vectors='drop' to remove them, set on_bad_vectors='fill' "
            "and fill_value= to replace them, or set on_bad_vectors='null' "
            "to replace them with null."
        )
        return

    output = _handle_bad_vectors(
        data.to_reader(),
        on_bad_vectors=on_bad_vectors,
        fill_value=42.0,
    ).read_all()
    expected = {
        "drop": pa.array([[1.0, 2.0], [4.0, 5.0]]),
        "fill": pa.array([[1.0, 2.0], [42.0, 42.0], [4.0, 5.0]]),
        "null": pa.array([[1.0, 2.0], None, [4.0, 5.0]]),
    }[on_bad_vectors]
    assert output["vector"].combine_chunks() == expected


@pytest.mark.parametrize("on_bad_vectors", ["error", "drop", "fill", "null"])
def test_handle_bad_vectors_nan(on_bad_vectors):
    """Vectors containing NaN are rejected, dropped, filled, or nulled."""
    vector = pa.array([[1.0, float("nan")], [3.0, 4.0]])
    data = pa.table({"vector": vector})

    if on_bad_vectors == "error":
        with pytest.raises(ValueError) as e:
            _handle_bad_vectors(
                data.to_reader(),
                on_bad_vectors=on_bad_vectors,
            ).read_all()
        output = exception_output(e)
        assert output == (
            "ValueError: Vector column 'vector' has NaNs. Set "
            "on_bad_vectors='drop' to remove them, set on_bad_vectors='fill' "
            "and fill_value= to replace them, or set on_bad_vectors='null' "
            "to replace them with null."
        )
        return

    output = _handle_bad_vectors(
        data.to_reader(),
        on_bad_vectors=on_bad_vectors,
        fill_value=42.0,
    ).read_all()
    expected = {
        "drop": pa.array([[3.0, 4.0]]),
        "fill": pa.array([[42.0, 42.0], [3.0, 4.0]]),
        "null": pa.array([None, [3.0, 4.0]]),
    }[on_bad_vectors]
    assert output["vector"].combine_chunks() == expected


def test_handle_bad_vectors_noop():
    """Well-formed fixed-size vectors pass through unchanged."""
    # ChunkedArray should be preserved as-is
    vector = pa.chunked_array(
        [[[1.0, 2.0], [3.0, 4.0]]], type=pa.list_(pa.float64(), 2)
    )
    data = pa.table({"vector": vector})
    output = _handle_bad_vectors(data.to_reader()).read_all()
    assert output["vector"] == vector


def test_handle_bad_vectors_updates_reader_schema_for_target_schema():
    """The reader schema reflects the target's value type (variable length)."""
    data = pa.table({"vector": [[1, 2, 3, 4]]})
    target_schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 4))])
    output = _handle_bad_vectors(
        data.to_reader(),
        on_bad_vectors="drop",
        target_schema=target_schema,
    )
    assert output.schema == pa.schema([pa.field("vector", pa.list_(pa.float32()))])
    assert output.read_all()["vector"].to_pylist() == [[1.0, 2.0, 3.0, 4.0]]


def test_sanitize_data_keeps_target_field_metadata():
    """Field metadata comes from the target schema, not the source data."""
    source_field = pa.field(
        "vector",
        pa.list_(pa.float32(), 2),
        metadata={b"source": b"drop-me"},
    )
    target_field = pa.field(
        "vector",
        pa.list_(pa.float32(), 2),
        metadata={b"target": b"keep-me"},
    )
    data = pa.table(
        {"vector": pa.array([[1.0, 2.0]], type=pa.list_(pa.float32(), 2))},
        schema=pa.schema([source_field]),
    )
    output = _sanitize_data(
        data,
        target_schema=pa.schema([target_field]),
        on_bad_vectors="drop",
    ).read_all()
    assert output.schema.field("vector").metadata == {b"target": b"keep-me"}


def test_sanitize_data_uses_separate_embedding_metadata_for_bad_vectors():
    """Embedding metadata passed separately drives bad-vector handling and is
    merged with the target schema's own metadata."""
    registry = EmbeddingFunctionRegistry.get_instance()
    conf = EmbeddingFunctionConfig(
        source_column="text",
        vector_column="custom_vector",
        function=MockTextEmbeddingFunction.create(),
    )
    metadata = registry.get_table_metadata([conf])
    schema = pa.schema(
        {
            "text": pa.string(),
            "custom_vector": pa.list_(pa.float32(), 10),
        },
        metadata={b"note": b"keep-me"},
    )
    data = pa.table(
        {
            "text": ["bad", "good"],
            "custom_vector": [[1.0] * 9, [2.0] * 10],
        }
    )
    output = _sanitize_data(
        data,
        target_schema=schema,
        metadata=metadata,
        on_bad_vectors="drop",
    ).read_all()
    assert output["text"].to_pylist() == ["good"]
    assert output.schema.metadata[b"note"] == b"keep-me"
    assert b"embedding_functions" in output.schema.metadata


def test_sanitize_create_table_merges_and_overrides_embedding_metadata():
    """Explicit embedding metadata replaces the schema's embedding config while
    unrelated schema metadata is preserved."""
    registry = EmbeddingFunctionRegistry.get_instance()
    old_conf = EmbeddingFunctionConfig(
        source_column="text",
        vector_column="old_vector",
        function=MockTextEmbeddingFunction.create(),
    )
    new_conf = EmbeddingFunctionConfig(
        source_column="text",
        vector_column="custom_vector",
        function=MockTextEmbeddingFunction.create(),
    )
    metadata = registry.get_table_metadata([new_conf])
    schema = pa.schema(
        {
            "text": pa.string(),
            "custom_vector": pa.list_(pa.float32(), 10),
        },
        metadata=_merge_metadata(
            {b"note": b"keep-me"},
            registry.get_table_metadata([old_conf]),
        ),
    )

    data, schema = sanitize_create_table(
        pa.table({"text": ["good"]}),
        schema,
        metadata=metadata,
        on_bad_vectors="drop",
    )

    assert schema.metadata[b"note"] == b"keep-me"
    assert b"embedding_functions" in schema.metadata
    assert data.schema.metadata[b"note"] == b"keep-me"
    funcs = EmbeddingFunctionRegistry.get_instance().parse_functions(schema.metadata)
    assert set(funcs.keys()) == {"custom_vector"}


class TestModel(lancedb.pydantic.LanceModel):
    # Not a test class: keep pytest from trying to collect it because of its
    # ``Test`` prefix (it has a pydantic ``__init__``, which pytest warns on).
    __test__ = False

    a: Optional[int]
    b: Optional[int]


# TODO: huggingface,
@pytest.mark.parametrize(
    "data",
    [
        lambda: [{"a": 1, "b": 2}],
        lambda: pa.RecordBatch.from_pylist([{"a": 1, "b": 2}]),
        lambda: pa.table({"a": [1], "b": [2]}),
        lambda: pa.table({"a": [1], "b": [2]}).to_reader(),
        lambda: iter(pa.table({"a": [1], "b": [2]}).to_batches()),
        lambda: lance.write_dataset(
            pa.table({"a": [1], "b": [2]}),
            "memory://test",
        ),
        lambda: lance.write_dataset(
            pa.table({"a": [1], "b": [2]}),
            "memory://test",
        ).scanner(),
        lambda: pd.DataFrame({"a": [1], "b": [2]}),
        lambda: pl.DataFrame({"a": [1], "b": [2]}),
        lambda: pl.LazyFrame({"a": [1], "b": [2]}),
        lambda: [TestModel(a=1, b=2)],
    ],
    ids=[
        "rows",
        "pa.RecordBatch",
        "pa.Table",
        "pa.RecordBatchReader",
        "batch_iter",
        "lance.LanceDataset",
        "lance.LanceScanner",
        "pd.DataFrame",
        "pl.DataFrame",
        "pl.LazyFrame",
        "pydantic",
    ],
)
def test_into_pyarrow_table(data):
    """Every supported input kind converts to the same Arrow table."""
    # ``data`` is a factory so each case builds fresh (single-use) input.
    expected = pa.table({"a": [1], "b": [2]})
    output = _into_pyarrow_reader(data()).read_all()
    assert output == expected


def test_infer_target_schema():
    """Vector-like list columns get a fixed size and float32 values."""
    example = pa.schema(
        {
            "vec1": pa.list_(pa.float64(), 2),
            "vector": pa.list_(pa.float64()),
        }
    )
    data = pa.table(
        {
            "vec1": [[0.0] * 2],
            "vector": [[0.0] * 2],
        },
        schema=example,
    )
    expected = pa.schema(
        {
            "vec1": pa.list_(pa.float64(), 2),
            "vector": pa.list_(pa.float32(), 2),
        }
    )
    output, _ = _infer_target_schema(data.to_reader())
    assert output == expected

    # Handle large list and use modal size
    # Most vectors are of length 2, so we should infer that as the target dimension
    example = pa.schema(
        {
            "vector": pa.large_list(pa.float64()),
        }
    )
    data = pa.table(
        {
            "vector": [[0.0] * 2, [0.0], [0.0] * 2],
        },
        schema=example,
    )
    expected = pa.schema(
        {
            "vector": pa.list_(pa.float32(), 2),
        }
    )
    output, _ = _infer_target_schema(data.to_reader())
    assert output == expected

    # ignore if not list
    example = pa.schema(
        {
            "vector": pa.float64(),
        }
    )
    data = pa.table(
        {
            "vector": [0.0],
        },
        schema=example,
    )
    expected = example
    output, _ = _infer_target_schema(data.to_reader())
    assert output == expected


def test_infer_target_schema_with_vector_embedding_names():
    """Test that _infer_target_schema detects vector columns with 'vector'/'embedding'.

    This tests the enhanced column name detection for vector inference.
    """
    # Test float vectors with various naming patterns
    example = pa.schema(
        {
            "user_vector": pa.list_(pa.float64()),
            "text_embedding": pa.list_(pa.float64()),
            "doc_embeddings": pa.list_(pa.float64()),
            "my_vector_field": pa.list_(pa.float64()),
            "embedding_model": pa.list_(pa.float64()),
            "VECTOR_COL": pa.list_(pa.float64()),  # uppercase
            "Vector_Mixed": pa.list_(pa.float64()),  # mixed case
            "normal_list": pa.list_(pa.float64()),  # should not be converted
        }
    )
    data = pa.table(
        {
            "user_vector": [[1.0, 2.0]],
            "text_embedding": [[3.0, 4.0]],
            "doc_embeddings": [[5.0, 6.0]],
            "my_vector_field": [[7.0, 8.0]],
            "embedding_model": [[9.0, 10.0]],
            "VECTOR_COL": [[11.0, 12.0]],
            "Vector_Mixed": [[13.0, 14.0]],
            "normal_list": [[15.0, 16.0]],
        },
        schema=example,
    )
    expected = pa.schema(
        {
            "user_vector": pa.list_(pa.float32(), 2),  # converted
            "text_embedding": pa.list_(pa.float32(), 2),  # converted
            "doc_embeddings": pa.list_(pa.float32(), 2),  # converted
            "my_vector_field": pa.list_(pa.float32(), 2),  # converted
            "embedding_model": pa.list_(pa.float32(), 2),  # converted
            "VECTOR_COL": pa.list_(pa.float32(), 2),  # converted
            "Vector_Mixed": pa.list_(pa.float32(), 2),  # converted
            "normal_list": pa.list_(pa.float64()),  # not converted
        }
    )
    output, _ = _infer_target_schema(data.to_reader())
    assert output == expected

    # Test integer vectors with various naming patterns
    example_int = pa.schema(
        {
            "user_vector": pa.list_(pa.int32()),
            "text_embedding": pa.list_(pa.int64()),
            "doc_embeddings": pa.list_(pa.int16()),
            "normal_list": pa.list_(pa.int32()),  # should not be converted
        }
    )
    data_int = pa.table(
        {
            "user_vector": [[1, 2]],
            "text_embedding": [[3, 4]],
            "doc_embeddings": [[5, 6]],
            "normal_list": [[7, 8]],
        },
        schema=example_int,
    )
    expected_int = pa.schema(
        {
            "user_vector": pa.list_(pa.uint8(), 2),  # converted
            "text_embedding": pa.list_(pa.uint8(), 2),  # converted
            "doc_embeddings": pa.list_(pa.uint8(), 2),  # converted
            "normal_list": pa.list_(pa.int32()),  # not converted
        }
    )
    output_int, _ = _infer_target_schema(data_int.to_reader())
    assert output_int == expected_int


@pytest.mark.parametrize(
    "data",
    [
        [{"id": 1, "text": "hello"}],
        pa.RecordBatch.from_pylist([{"id": 1, "text": "hello"}]),
        pd.DataFrame({"id": [1], "text": ["hello"]}),
        pl.DataFrame({"id": [1], "text": ["hello"]}),
    ],
    ids=["rows", "pa.RecordBatch", "pd.DataFrame", "pl.DataFrame"],
)
@pytest.mark.parametrize(
    "schema",
    [
        None,
        pa.schema(
            {
                "id": pa.int32(),
                "text": pa.string(),
                "vector": pa.list_(pa.float32(), 10),
            }
        ),
        pa.schema(
            {
                "id": pa.int64(),
                "text": pa.string(),
                "vector": pa.list_(pa.float32(), 10),
                "extra": pa.int64(),
            }
        ),
    ],
    ids=["infer", "explicit", "subschema"],
)
@pytest.mark.parametrize("with_embedding", [True, False])
def test_sanitize_data(
    data,
    schema: Optional[pa.Schema],
    with_embedding: bool,
):
    """``_sanitize_data`` casts input to the (inferred or given) schema and
    appends embeddings when an embedding config is supplied."""
    if with_embedding:
        registry = EmbeddingFunctionRegistry.get_instance()
        registry.register("test")(MockTextEmbeddingFunction)
        conf = EmbeddingFunctionConfig(
            source_column="text",
            vector_column="vector",
            function=MockTextEmbeddingFunction.create(),
        )
        metadata = registry.get_table_metadata([conf])
    else:
        metadata = None

    if schema is not None:
        to_remove = schema.get_field_index("extra")
        if to_remove >= 0:
            expected_schema = schema.remove(to_remove)
        else:
            expected_schema = schema
    else:
        from conftest import pandas_string_type

        # polars uses large_string, pandas 3.0+ uses large_string, others use string
        if isinstance(data, pl.DataFrame):
            text_type = pa.large_utf8()
        elif isinstance(data, pd.DataFrame):
            text_type = pandas_string_type()
        else:
            text_type = pa.string()
        expected_schema = pa.schema(
            {
                "id": pa.int64(),
                "text": text_type,
                "vector": pa.list_(pa.float32(), 10),
            }
        )

    if not with_embedding:
        to_remove = expected_schema.get_field_index("vector")
        if to_remove >= 0:
            expected_schema = expected_schema.remove(to_remove)

    expected = pa.table(
        {
            "id": [1],
            "text": ["hello"],
            "vector": [[0.0] * 10],
        },
        schema=expected_schema,
    )

    output_data = _sanitize_data(
        data,
        target_schema=schema,
        metadata=metadata,
        allow_subschema=True,
    ).read_all()

    assert output_data == expected


def test_cast_to_target_schema():
    """Data is cast to the target schema, optionally as a subschema."""
    original_schema = pa.schema(
        {
            "id": pa.int32(),
            "struct": pa.struct(
                [
                    pa.field("a", pa.int32()),
                ]
            ),
            "vector": pa.list_(pa.float64()),
            "vec1": pa.list_(pa.float64(), 2),
            "vec2": pa.list_(pa.float32(), 2),
        }
    )
    data = pa.table(
        {
            "id": [1],
            "struct": [{"a": 1}],
            "vector": [[0.0] * 2],
            "vec1": [[0.0] * 2],
            "vec2": [[0.0] * 2],
        },
        schema=original_schema,
    )

    target = pa.schema(
        {
            "id": pa.int64(),
            "struct": pa.struct(
                [
                    pa.field("a", pa.int64()),
                ]
            ),
            "vector": pa.list_(pa.float32(), 2),
            "vec1": pa.list_(pa.float32(), 2),
            "vec2": pa.list_(pa.float32(), 2),
        }
    )
    # Materialize the reader so the full-schema cast is actually checked.
    output = _cast_to_target_schema(data.to_reader(), target).read_all()
    expected = pa.table(
        {
            "id": [1],
            "struct": [{"a": 1}],
            "vector": [[0.0] * 2],
            "vec1": [[0.0] * 2],
            "vec2": [[0.0] * 2],
        },
        schema=target,
    )
    assert output == expected

    # Data can be a subschema of the target
    target = pa.schema(
        {
            "id": pa.int64(),
            "struct": pa.struct(
                [
                    pa.field("a", pa.int64()),
                    # Additional nested field
                    pa.field("b", pa.int64()),
                ]
            ),
            "vector": pa.list_(pa.float32(), 2),
            "vec1": pa.list_(pa.float32(), 2),
            "vec2": pa.list_(pa.float32(), 2),
            # Additional field
            "extra": pa.int64(),
        }
    )
    with pytest.raises(Exception):
        _cast_to_target_schema(data.to_reader(), target)
    output = _cast_to_target_schema(
        data.to_reader(), target, allow_subschema=True
    ).read_all()
    expected_schema = pa.schema(
        {
            "id": pa.int64(),
            "struct": pa.struct(
                [
                    pa.field("a", pa.int64()),
                ]
            ),
            "vector": pa.list_(pa.float32(), 2),
            "vec1": pa.list_(pa.float32(), 2),
            "vec2": pa.list_(pa.float32(), 2),
        }
    )
    expected = pa.table(
        {
            "id": [1],
            "struct": [{"a": 1}],
            "vector": [[0.0] * 2],
            "vec1": [[0.0] * 2],
            "vec2": [[0.0] * 2],
        },
        schema=expected_schema,
    )
    assert output == expected


def test_sanitize_data_stream():
    """``_sanitize_data`` streams batches lazily instead of collecting them."""
    # Make sure we don't collect the whole stream when running sanitize_data
    schema = pa.schema({"a": pa.int32()})

    def stream():
        yield pa.record_batch([pa.array([1, 2, 3])], schema=schema)
        # Only reached if the consumer pulls a second batch.
        raise ValueError("error")

    reader = pa.RecordBatchReader.from_batches(schema, stream())

    output = _sanitize_data(reader)
    first = next(output)
    assert first == pa.record_batch([pa.array([1, 2, 3])], schema=schema)

    with pytest.raises(ValueError):
        next(output)