lancedb/python/python/tests/test_util.py
Will Jones c557e77f09 feat(python)!: support inserting and upserting subschemas (#1965)
BREAKING CHANGE: For a field "vector", list of integers will now be
converted to binary (uint8) vectors instead of f32 vectors. Use float
values instead for f32 vectors.
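
A minimal sketch of the change in inference behavior (the table names here
are hypothetical; the inferred types follow the note above):

```python
import lancedb

db = lancedb.connect("memory://demo")

# Integer lists are now inferred as binary (uint8) vectors.
db.create_table("binary_vectors", [{"vector": [1, 2, 3]}])

# Use float values to get f32 vectors, as before.
db.create_table("f32_vectors", [{"vector": [1.0, 2.0, 3.0]}])
```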

* Adds proper support for inserting and upserting subsets of the full
schema. I thought I had previously implemented this in #1827, but it
turns out I had not tested carefully enough.
* Refactors `_sanitize_data` and other utility functions to be simpler
and not require `numpy` or `combine_chunks()` (see the sketch after this
list).
* Added a new suite of unit tests to validate sanitization utilities.
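
For the refactor, a quick sketch of the sanitization entry point as
exercised by the tests below (arguments mirror the calls in
`test_sanitize_data`):

```python
import pyarrow as pa
from lancedb.table import _sanitize_data

data = pa.table({"id": [1], "text": ["hello"]})
target = pa.schema(
    {"id": pa.int64(), "text": pa.string(), "extra": pa.int64()}
)

# Conforms data to the target schema; with allow_subschema=True, columns
# missing from the data (like "extra" here) are permitted rather than an
# error.
sanitized = _sanitize_data(
    data, target_schema=target, metadata=None, allow_subschema=True
)
```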

## Examples

```python
import pandas as pd
import lancedb

db = lancedb.connect("memory://demo")
initial_data = pd.DataFrame({
    "a": [1, 2, 3],
    "b": [4, 5, 6],
    "c": [7, 8, 9]
})
table = db.create_table("demo", initial_data)

# Insert a subschema
new_data = pd.DataFrame({"a": [10, 11]})
table.add(new_data)
table.to_pandas()
```
```
    a    b    c
0   1  4.0  7.0
1   2  5.0  8.0
2   3  6.0  9.0
3  10  NaN  NaN
4  11  NaN  NaN
```


```python
# Upsert a subschema
upsert_data = pd.DataFrame({
    "a": [3, 10, 15],
    "b": [6, 7, 8],
})
(
    table.merge_insert(on="a")
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .execute(upsert_data)
)
table.to_pandas()
```
```
    a    b    c
0   1  4.0  7.0
1   2  5.0  8.0
2   3  6.0  9.0
3  10  7.0  NaN
4  11  NaN  NaN
5  15  8.0  NaN
```
2025-01-08 10:11:10 -08:00


# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib
from typing import Optional
import lance
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.embeddings.base import EmbeddingFunctionConfig
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
from lancedb.table import (
    _append_vector_columns,
    _cast_to_target_schema,
    _handle_bad_vectors,
    _into_pyarrow_table,
    _sanitize_data,
    _infer_target_schema,
)
import pyarrow as pa
import pandas as pd
import polars as pl
import pytest
import lancedb
from lancedb.util import get_uri_scheme, join_uri, value_to_sql
from utils import exception_output


def test_normalize_uri():
    uris = [
        "relative/path",
        "/absolute/path",
        "file:///absolute/path",
        "s3://bucket/path",
        "gs://bucket/path",
        "c:\\windows\\path",
    ]
    schemes = ["file", "file", "file", "s3", "gs", "file"]
    for uri, expected_scheme in zip(uris, schemes):
        parsed_scheme = get_uri_scheme(uri)
        assert parsed_scheme == expected_scheme


def test_join_uri_remote():
    schemes = ["s3", "az", "gs"]
    for scheme in schemes:
        expected = f"{scheme}://bucket/path/to/table.lance"

        base_uri = f"{scheme}://bucket/path/to/"
        parts = ["table.lance"]
        assert join_uri(base_uri, *parts) == expected

        base_uri = f"{scheme}://bucket"
        parts = ["path", "to", "table.lance"]
        assert join_uri(base_uri, *parts) == expected


# skip this test if on windows
@pytest.mark.skipif(os.name == "nt", reason="Windows paths are not POSIX")
def test_join_uri_posix():
    for base in [
        # relative path
        "relative/path",
        "relative/path/",
        # an absolute path
        "/absolute/path",
        "/absolute/path/",
        # a file URI
        "file:///absolute/path",
        "file:///absolute/path/",
    ]:
        joined = join_uri(base, "table.lance")
        assert joined == str(pathlib.Path(base) / "table.lance")
        joined = join_uri(pathlib.Path(base), "table.lance")
        assert joined == pathlib.Path(base) / "table.lance"


# skip this test if not on windows
@pytest.mark.skipif(os.name != "nt", reason="Windows paths are not POSIX")
def test_local_join_uri_windows():
    # https://learn.microsoft.com/en-us/dotnet/standard/io/file-path-formats
    for base in [
        # windows relative path
        "relative\\path",
        "relative\\path\\",
        # windows absolute path from current drive
        "c:\\absolute\\path",
        # relative path from root of current drive
        "\\relative\\path",
    ]:
        joined = join_uri(base, "table.lance")
        assert joined == str(pathlib.Path(base) / "table.lance")
        joined = join_uri(pathlib.Path(base), "table.lance")
        assert joined == pathlib.Path(base) / "table.lance"


def test_value_to_sql_string(tmp_path):
    # Make sure we can convert Python string literals to SQL strings, even if
    # they contain characters meaningful in SQL, such as ' and \.
    values = ["anthony's", 'a "test" string', "anthony's \"favorite color\" wasn't red"]
    expected_values = [
        "'anthony''s'",
        "'a \"test\" string'",
        "'anthony''s \"favorite color\" wasn''t red'",
    ]
    for value, expected in zip(values, expected_values):
        assert value_to_sql(value) == expected

    # Also test we can roundtrip those strings through update.
    # This validates the query parser understands the strings we
    # are creating.
    db = lancedb.connect(tmp_path)
    table = db.create_table(
        "test",
        [{"search": value, "replace": "something"} for value in values],
    )
    for value in values:
        table.update(where=f"search = {value_to_sql(value)}", values={"replace": value})
        assert table.to_pandas().query("search == @value")["replace"].item() == value


def test_append_vector_columns():
    registry = EmbeddingFunctionRegistry.get_instance()
    registry.register("test")(MockTextEmbeddingFunction)
    conf = EmbeddingFunctionConfig(
        source_column="text",
        vector_column="vector",
        function=MockTextEmbeddingFunction(),
    )
    metadata = registry.get_table_metadata([conf])

    schema = pa.schema(
        {
            "text": pa.string(),
            "vector": pa.list_(pa.float64(), 10),
        }
    )
    data = pa.table(
        {
            "text": ["hello"],
            "vector": [None],  # Replaces null
        },
        schema=schema,
    )
    output = _append_vector_columns(
        data,
        schema,  # metadata passed separate from schema
        metadata=metadata,
    )
    assert output.schema == schema
    assert output["vector"].null_count == 0

    # Adds if missing
    data = pa.table({"text": ["hello"]})
    output = _append_vector_columns(
        data,
        schema.with_metadata(metadata),
    )
    assert output.schema == schema
    assert output["vector"].null_count == 0

    # doesn't embed if already there
    data = pa.table(
        {
            "text": ["hello"],
            "vector": [[42.0] * 10],
        },
        schema=schema,
    )
    output = _append_vector_columns(
        data,
        schema.with_metadata(metadata),
    )
    assert output == data  # No change

    # No provided schema
    data = pa.table(
        {
            "text": ["hello"],
        }
    )
    output = _append_vector_columns(
        data,
        metadata=metadata,
    )
    expected_schema = pa.schema(
        {
            "text": pa.string(),
            "vector": pa.list_(pa.float32(), 10),
        }
    )
    assert output.schema == expected_schema
    assert output["vector"].null_count == 0


@pytest.mark.parametrize("on_bad_vectors", ["error", "drop", "fill", "null"])
def test_handle_bad_vectors_jagged(on_bad_vectors):
    vector = pa.array([[1.0, 2.0], [3.0], [4.0, 5.0]])
    schema = pa.schema({"vector": pa.list_(pa.float64())})
    data = pa.table({"vector": vector}, schema=schema)

    if on_bad_vectors == "error":
        with pytest.raises(ValueError) as e:
            output = _handle_bad_vectors(
                data,
                on_bad_vectors=on_bad_vectors,
            )
        output = exception_output(e)
        assert output == (
            "ValueError: Vector column 'vector' has variable length vectors. Set "
            "on_bad_vectors='drop' to remove them, set on_bad_vectors='fill' "
            "and fill_value=<value> to replace them, or set on_bad_vectors='null' "
            "to replace them with null."
        )
        return
    else:
        output = _handle_bad_vectors(
            data,
            on_bad_vectors=on_bad_vectors,
            fill_value=42.0,
        )

    if on_bad_vectors == "drop":
        expected = pa.array([[1.0, 2.0], [4.0, 5.0]])
    elif on_bad_vectors == "fill":
        expected = pa.array([[1.0, 2.0], [42.0, 42.0], [4.0, 5.0]])
    elif on_bad_vectors == "null":
        expected = pa.array([[1.0, 2.0], None, [4.0, 5.0]])

    assert output["vector"].combine_chunks() == expected


@pytest.mark.parametrize("on_bad_vectors", ["error", "drop", "fill", "null"])
def test_handle_bad_vectors_nan(on_bad_vectors):
    vector = pa.array([[1.0, float("nan")], [3.0, 4.0]])
    data = pa.table({"vector": vector})

    if on_bad_vectors == "error":
        with pytest.raises(ValueError) as e:
            output = _handle_bad_vectors(
                data,
                on_bad_vectors=on_bad_vectors,
            )
        output = exception_output(e)
        assert output == (
            "ValueError: Vector column 'vector' has NaNs. Set "
            "on_bad_vectors='drop' to remove them, set on_bad_vectors='fill' "
            "and fill_value=<value> to replace them, or set on_bad_vectors='null' "
            "to replace them with null."
        )
        return
    else:
        output = _handle_bad_vectors(
            data,
            on_bad_vectors=on_bad_vectors,
            fill_value=42.0,
        )

    if on_bad_vectors == "drop":
        expected = pa.array([[3.0, 4.0]])
    elif on_bad_vectors == "fill":
        expected = pa.array([[42.0, 42.0], [3.0, 4.0]])
    elif on_bad_vectors == "null":
        expected = pa.array([None, [3.0, 4.0]])

    assert output["vector"].combine_chunks() == expected


def test_handle_bad_vectors_noop():
    # ChunkedArray should be preserved as-is
    vector = pa.chunked_array(
        [[[1.0, 2.0], [3.0, 4.0]]], type=pa.list_(pa.float64(), 2)
    )
    data = pa.table({"vector": vector})
    output = _handle_bad_vectors(data)
    assert output["vector"] == vector


class TestModel(lancedb.pydantic.LanceModel):
    a: Optional[int]
    b: Optional[int]


# TODO: huggingface,
@pytest.mark.parametrize(
    "data",
    [
        lambda: [{"a": 1, "b": 2}],
        lambda: pa.RecordBatch.from_pylist([{"a": 1, "b": 2}]),
        lambda: pa.table({"a": [1], "b": [2]}),
        lambda: pa.table({"a": [1], "b": [2]}).to_reader(),
        lambda: iter(pa.table({"a": [1], "b": [2]}).to_batches()),
        lambda: (
            lance.write_dataset(
                pa.table({"a": [1], "b": [2]}),
                "memory://test",
            )
        ),
        lambda: (
            lance.write_dataset(
                pa.table({"a": [1], "b": [2]}),
                "memory://test",
            ).scanner()
        ),
        lambda: pd.DataFrame({"a": [1], "b": [2]}),
        lambda: pl.DataFrame({"a": [1], "b": [2]}),
        lambda: pl.LazyFrame({"a": [1], "b": [2]}),
        lambda: [TestModel(a=1, b=2)],
    ],
    ids=[
        "rows",
        "pa.RecordBatch",
        "pa.Table",
        "pa.RecordBatchReader",
        "batch_iter",
        "lance.LanceDataset",
        "lance.LanceScanner",
        "pd.DataFrame",
        "pl.DataFrame",
        "pl.LazyFrame",
        "pydantic",
    ],
)
def test_into_pyarrow_table(data):
    expected = pa.table({"a": [1], "b": [2]})
    output = _into_pyarrow_table(data())
    assert output == expected


def test_infer_target_schema():
    example = pa.schema(
        {
            "vec1": pa.list_(pa.float64(), 2),
            "vector": pa.list_(pa.float64()),
        }
    )
    data = pa.table(
        {
            "vec1": [[0.0] * 2],
            "vector": [[0.0] * 2],
        },
        schema=example,
    )
    expected = pa.schema(
        {
            "vec1": pa.list_(pa.float64(), 2),
            "vector": pa.list_(pa.float32(), 2),
        }
    )
    output = _infer_target_schema(data)
    assert output == expected

    # Handle large list and use modal size
    # Most vectors are of length 2, so we should infer that as the target dimension
    example = pa.schema(
        {
            "vector": pa.large_list(pa.float64()),
        }
    )
    data = pa.table(
        {
            "vector": [[0.0] * 2, [0.0], [0.0] * 2],
        },
        schema=example,
    )
    expected = pa.schema(
        {
            "vector": pa.list_(pa.float32(), 2),
        }
    )
    output = _infer_target_schema(data)
    assert output == expected

    # ignore if not list
    example = pa.schema(
        {
            "vector": pa.float64(),
        }
    )
    data = pa.table(
        {
            "vector": [0.0],
        },
        schema=example,
    )
    expected = example
    output = _infer_target_schema(data)
    assert output == expected


@pytest.mark.parametrize(
    "data",
    [
        [{"id": 1, "text": "hello"}],
        pa.RecordBatch.from_pylist([{"id": 1, "text": "hello"}]),
        pd.DataFrame({"id": [1], "text": ["hello"]}),
        pl.DataFrame({"id": [1], "text": ["hello"]}),
    ],
    ids=["rows", "pa.RecordBatch", "pd.DataFrame", "pl.DataFrame"],
)
@pytest.mark.parametrize(
    "schema",
    [
        None,
        pa.schema(
            {
                "id": pa.int32(),
                "text": pa.string(),
                "vector": pa.list_(pa.float32(), 10),
            }
        ),
        pa.schema(
            {
                "id": pa.int64(),
                "text": pa.string(),
                "vector": pa.list_(pa.float32(), 10),
                "extra": pa.int64(),
            }
        ),
    ],
    ids=["infer", "explicit", "subschema"],
)
@pytest.mark.parametrize("with_embedding", [True, False])
def test_sanitize_data(
    data,
    schema: Optional[pa.Schema],
    with_embedding: bool,
):
    if with_embedding:
        registry = EmbeddingFunctionRegistry.get_instance()
        registry.register("test")(MockTextEmbeddingFunction)
        conf = EmbeddingFunctionConfig(
            source_column="text",
            vector_column="vector",
            function=MockTextEmbeddingFunction(),
        )
        metadata = registry.get_table_metadata([conf])
    else:
        metadata = None

    if schema is not None:
        to_remove = schema.get_field_index("extra")
        if to_remove >= 0:
            expected_schema = schema.remove(to_remove)
        else:
            expected_schema = schema
    else:
        expected_schema = pa.schema(
            {
                "id": pa.int64(),
                "text": pa.large_utf8()
                if isinstance(data, pl.DataFrame)
                else pa.string(),
                "vector": pa.list_(pa.float32(), 10),
            }
        )
    if not with_embedding:
        to_remove = expected_schema.get_field_index("vector")
        if to_remove >= 0:
            expected_schema = expected_schema.remove(to_remove)

    expected = pa.table(
        {
            "id": [1],
            "text": ["hello"],
            "vector": [[0.0] * 10],
        },
        schema=expected_schema,
    )
    output_data = _sanitize_data(
        data,
        target_schema=schema,
        metadata=metadata,
        allow_subschema=True,
    )
    assert output_data == expected


def test_cast_to_target_schema():
    original_schema = pa.schema(
        {
            "id": pa.int32(),
            "struct": pa.struct(
                [
                    pa.field("a", pa.int32()),
                ]
            ),
            "vector": pa.list_(pa.float64()),
            "vec1": pa.list_(pa.float64(), 2),
            "vec2": pa.list_(pa.float32(), 2),
        }
    )
    data = pa.table(
        {
            "id": [1],
            "struct": [{"a": 1}],
            "vector": [[0.0] * 2],
            "vec1": [[0.0] * 2],
            "vec2": [[0.0] * 2],
        },
        schema=original_schema,
    )

    target = pa.schema(
        {
            "id": pa.int64(),
            "struct": pa.struct(
                [
                    pa.field("a", pa.int64()),
                ]
            ),
            "vector": pa.list_(pa.float32(), 2),
            "vec1": pa.list_(pa.float32(), 2),
            "vec2": pa.list_(pa.float32(), 2),
        }
    )
    output = _cast_to_target_schema(data, target)
    expected = pa.table(
        {
            "id": [1],
            "struct": [{"a": 1}],
            "vector": [[0.0] * 2],
            "vec1": [[0.0] * 2],
            "vec2": [[0.0] * 2],
        },
        schema=target,
    )
    assert output == expected

    # Data can be a subschema of the target
    target = pa.schema(
        {
            "id": pa.int64(),
            "struct": pa.struct(
                [
                    pa.field("a", pa.int64()),
                    # Additional nested field
                    pa.field("b", pa.int64()),
                ]
            ),
            "vector": pa.list_(pa.float32(), 2),
            "vec1": pa.list_(pa.float32(), 2),
            "vec2": pa.list_(pa.float32(), 2),
            # Additional field
            "extra": pa.int64(),
        }
    )
    with pytest.raises(Exception):
        _cast_to_target_schema(data, target)
    output = _cast_to_target_schema(data, target, allow_subschema=True)
    expected_schema = pa.schema(
        {
            "id": pa.int64(),
            "struct": pa.struct(
                [
                    pa.field("a", pa.int64()),
                ]
            ),
            "vector": pa.list_(pa.float32(), 2),
            "vec1": pa.list_(pa.float32(), 2),
            "vec2": pa.list_(pa.float32(), 2),
        }
    )
    expected = pa.table(
        {
            "id": [1],
            "struct": [{"a": 1}],
            "vector": [[0.0] * 2],
            "vec1": [[0.0] * 2],
            "vec2": [[0.0] * 2],
        },
        schema=expected_schema,
    )
    assert output == expected