mirror of
https://github.com/lancedb/lancedb.git
synced 2026-06-07 06:10:38 +00:00
feat(python)!: support inserting and upserting subschemas (#1965)
BREAKING CHANGE: For a field "vector", list of integers will now be converted to binary (uint8) vectors instead of f32 vectors. Use float values instead for f32 vectors. * Adds proper support for inserting and upserting subsets of the full schema. I thought I had previously implemented this in #1827, but it turns out I had not tested carefully enough. * Refactors `_santize_data` and other utility functions to be simpler and not require `numpy` or `combine_chunks()`. * Added a new suite of unit tests to validate sanitization utilities. ## Examples ```python import pandas as pd import lancedb db = lancedb.connect("memory://demo") intial_data = pd.DataFrame({ "a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9] }) table = db.create_table("demo", intial_data) # Insert a subschema new_data = pd.DataFrame({"a": [10, 11]}) table.add(new_data) table.to_pandas() ``` ``` a b c 0 1 4.0 7.0 1 2 5.0 8.0 2 3 6.0 9.0 3 10 NaN NaN 4 11 NaN NaN ``` ```python # Upsert a subschema upsert_data = pd.DataFrame({ "a": [3, 10, 15], "b": [6, 7, 8], }) table.merge_insert(on="a").when_matched_update_all().when_not_matched_insert_all().execute(upsert_data) table.to_pandas() ``` ``` a b c 0 1 4.0 7.0 1 2 5.0 8.0 2 3 6.0 9.0 3 10 7.0 NaN 4 11 NaN NaN 5 15 8.0 NaN ```
This commit is contained in:
@@ -21,7 +21,7 @@ def test_binary_vector():
|
||||
]
|
||||
tbl = db.create_table("my_binary_vectors", data=data)
|
||||
query = np.random.randint(0, 256, size=16)
|
||||
tbl.search(query).to_arrow()
|
||||
tbl.search(query).metric("hamming").to_arrow()
|
||||
# --8<-- [end:sync_binary_vector]
|
||||
db.drop_table("my_binary_vectors")
|
||||
|
||||
@@ -39,6 +39,6 @@ async def test_binary_vector_async():
|
||||
]
|
||||
tbl = await db.create_table("my_binary_vectors", data=data)
|
||||
query = np.random.randint(0, 256, size=16)
|
||||
await tbl.query().nearest_to(query).to_arrow()
|
||||
await tbl.query().nearest_to(query).distance_type("hamming").to_arrow()
|
||||
# --8<-- [end:async_binary_vector]
|
||||
await db.drop_table("my_binary_vectors")
|
||||
|
||||
@@ -118,9 +118,9 @@ def test_scalar_index():
|
||||
# --8<-- [end:search_with_scalar_index]
|
||||
# --8<-- [start:vector_search_with_scalar_index]
|
||||
data = [
|
||||
{"book_id": 1, "vector": [1, 2]},
|
||||
{"book_id": 2, "vector": [3, 4]},
|
||||
{"book_id": 3, "vector": [5, 6]},
|
||||
{"book_id": 1, "vector": [1.0, 2]},
|
||||
{"book_id": 2, "vector": [3.0, 4]},
|
||||
{"book_id": 3, "vector": [5.0, 6]},
|
||||
]
|
||||
|
||||
table = db.create_table("book_with_embeddings", data)
|
||||
@@ -156,9 +156,9 @@ async def test_scalar_index_async():
|
||||
# --8<-- [end:search_with_scalar_index_async]
|
||||
# --8<-- [start:vector_search_with_scalar_index_async]
|
||||
data = [
|
||||
{"book_id": 1, "vector": [1, 2]},
|
||||
{"book_id": 2, "vector": [3, 4]},
|
||||
{"book_id": 3, "vector": [5, 6]},
|
||||
{"book_id": 1, "vector": [1.0, 2]},
|
||||
{"book_id": 2, "vector": [3.0, 4]},
|
||||
{"book_id": 3, "vector": [5.0, 6]},
|
||||
]
|
||||
async_tbl = await async_db.create_table("book_with_embeddings_async", data)
|
||||
(await async_tbl.query().where("book_id != 3").nearest_to([1, 2]).to_pandas())
|
||||
|
||||
@@ -198,7 +198,6 @@ def test_embedding_function_with_pandas(tmp_path):
|
||||
{
|
||||
"text": ["hello world", "goodbye world"],
|
||||
"val": [1, 2],
|
||||
"not-used": ["s1", "s3"],
|
||||
}
|
||||
)
|
||||
db = lancedb.connect(tmp_path)
|
||||
@@ -212,7 +211,6 @@ def test_embedding_function_with_pandas(tmp_path):
|
||||
{
|
||||
"text": ["extra", "more"],
|
||||
"val": [4, 5],
|
||||
"misc-col": ["s1", "s3"],
|
||||
}
|
||||
)
|
||||
tbl.add(df)
|
||||
|
||||
@@ -242,8 +242,8 @@ def test_add_subschema(mem_db: DBConnection):
|
||||
|
||||
data = {"price": 10.0, "item": "foo"}
|
||||
table.add([data])
|
||||
data = {"price": 2.0, "vector": [3.1, 4.1]}
|
||||
table.add([data])
|
||||
data = pd.DataFrame({"price": [2.0], "vector": [[3.1, 4.1]]})
|
||||
table.add(data)
|
||||
data = {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"}
|
||||
table.add([data])
|
||||
|
||||
@@ -259,7 +259,7 @@ def test_add_subschema(mem_db: DBConnection):
|
||||
|
||||
data = {"item": "foo"}
|
||||
# We can't omit a column if it's not nullable
|
||||
with pytest.raises(RuntimeError, match="Invalid user input"):
|
||||
with pytest.raises(RuntimeError, match="Append with different schema"):
|
||||
table.add([data])
|
||||
|
||||
# We can add it if we make the column nullable
|
||||
@@ -292,6 +292,7 @@ def test_add_nullability(mem_db: DBConnection):
|
||||
]
|
||||
)
|
||||
table = mem_db.create_table("test", schema=schema)
|
||||
assert table.schema.field("vector").nullable is False
|
||||
|
||||
nullable_schema = pa.schema(
|
||||
[
|
||||
@@ -320,7 +321,10 @@ def test_add_nullability(mem_db: DBConnection):
|
||||
schema=nullable_schema,
|
||||
)
|
||||
# We can't add nullable schema if it contains nulls
|
||||
with pytest.raises(Exception, match="Vector column vector has NaNs"):
|
||||
with pytest.raises(
|
||||
Exception,
|
||||
match="Casting field 'vector' with null values to non-nullable",
|
||||
):
|
||||
table.add(data)
|
||||
|
||||
# But we can make it nullable
|
||||
@@ -776,6 +780,38 @@ def test_merge_insert(mem_db: DBConnection):
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
|
||||
# We vary the data format because there are slight differences in how
|
||||
# subschemas are handled in different formats
|
||||
@pytest.mark.parametrize(
|
||||
"data_format",
|
||||
[
|
||||
lambda table: table,
|
||||
lambda table: table.to_pandas(),
|
||||
lambda table: table.to_pylist(),
|
||||
],
|
||||
ids=["pa.Table", "pd.DataFrame", "rows"],
|
||||
)
|
||||
def test_merge_insert_subschema(mem_db: DBConnection, data_format):
|
||||
initial_data = pa.table(
|
||||
{"id": range(3), "a": [1.0, 2.0, 3.0], "c": ["x", "x", "x"]}
|
||||
)
|
||||
table = mem_db.create_table("my_table", data=initial_data)
|
||||
|
||||
new_data = pa.table({"id": [2, 3], "c": ["y", "y"]})
|
||||
new_data = data_format(new_data)
|
||||
(
|
||||
table.merge_insert(on="id")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(new_data)
|
||||
)
|
||||
|
||||
expected = pa.table(
|
||||
{"id": [0, 1, 2, 3], "a": [1.0, 2.0, 3.0, None], "c": ["x", "x", "y", "y"]}
|
||||
)
|
||||
assert table.to_arrow().sort_by("id") == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_insert_async(mem_db_async: AsyncConnection):
|
||||
data = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]})
|
||||
|
||||
@@ -13,10 +13,27 @@
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
from typing import Optional
|
||||
|
||||
import lance
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.embeddings.base import EmbeddingFunctionConfig
|
||||
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
|
||||
from lancedb.table import (
|
||||
_append_vector_columns,
|
||||
_cast_to_target_schema,
|
||||
_handle_bad_vectors,
|
||||
_into_pyarrow_table,
|
||||
_sanitize_data,
|
||||
_infer_target_schema,
|
||||
)
|
||||
import pyarrow as pa
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
import pytest
|
||||
import lancedb
|
||||
from lancedb.util import get_uri_scheme, join_uri, value_to_sql
|
||||
from utils import exception_output
|
||||
|
||||
|
||||
def test_normalize_uri():
|
||||
@@ -111,3 +128,460 @@ def test_value_to_sql_string(tmp_path):
|
||||
for value in values:
|
||||
table.update(where=f"search = {value_to_sql(value)}", values={"replace": value})
|
||||
assert table.to_pandas().query("search == @value")["replace"].item() == value
|
||||
|
||||
|
||||
def test_append_vector_columns():
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
registry.register("test")(MockTextEmbeddingFunction)
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text",
|
||||
vector_column="vector",
|
||||
function=MockTextEmbeddingFunction(),
|
||||
)
|
||||
metadata = registry.get_table_metadata([conf])
|
||||
|
||||
schema = pa.schema(
|
||||
{
|
||||
"text": pa.string(),
|
||||
"vector": pa.list_(pa.float64(), 10),
|
||||
}
|
||||
)
|
||||
data = pa.table(
|
||||
{
|
||||
"text": ["hello"],
|
||||
"vector": [None], # Replaces null
|
||||
},
|
||||
schema=schema,
|
||||
)
|
||||
output = _append_vector_columns(
|
||||
data,
|
||||
schema, # metadata passed separate from schema
|
||||
metadata=metadata,
|
||||
)
|
||||
assert output.schema == schema
|
||||
assert output["vector"].null_count == 0
|
||||
|
||||
# Adds if missing
|
||||
data = pa.table({"text": ["hello"]})
|
||||
output = _append_vector_columns(
|
||||
data,
|
||||
schema.with_metadata(metadata),
|
||||
)
|
||||
assert output.schema == schema
|
||||
assert output["vector"].null_count == 0
|
||||
|
||||
# doesn't embed if already there
|
||||
data = pa.table(
|
||||
{
|
||||
"text": ["hello"],
|
||||
"vector": [[42.0] * 10],
|
||||
},
|
||||
schema=schema,
|
||||
)
|
||||
output = _append_vector_columns(
|
||||
data,
|
||||
schema.with_metadata(metadata),
|
||||
)
|
||||
assert output == data # No change
|
||||
|
||||
# No provided schema
|
||||
data = pa.table(
|
||||
{
|
||||
"text": ["hello"],
|
||||
}
|
||||
)
|
||||
output = _append_vector_columns(
|
||||
data,
|
||||
metadata=metadata,
|
||||
)
|
||||
expected_schema = pa.schema(
|
||||
{
|
||||
"text": pa.string(),
|
||||
"vector": pa.list_(pa.float32(), 10),
|
||||
}
|
||||
)
|
||||
assert output.schema == expected_schema
|
||||
assert output["vector"].null_count == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("on_bad_vectors", ["error", "drop", "fill", "null"])
|
||||
def test_handle_bad_vectors_jagged(on_bad_vectors):
|
||||
vector = pa.array([[1.0, 2.0], [3.0], [4.0, 5.0]])
|
||||
schema = pa.schema({"vector": pa.list_(pa.float64())})
|
||||
data = pa.table({"vector": vector}, schema=schema)
|
||||
|
||||
if on_bad_vectors == "error":
|
||||
with pytest.raises(ValueError) as e:
|
||||
output = _handle_bad_vectors(
|
||||
data,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
)
|
||||
output = exception_output(e)
|
||||
assert output == (
|
||||
"ValueError: Vector column 'vector' has variable length vectors. Set "
|
||||
"on_bad_vectors='drop' to remove them, set on_bad_vectors='fill' "
|
||||
"and fill_value=<value> to replace them, or set on_bad_vectors='null' "
|
||||
"to replace them with null."
|
||||
)
|
||||
return
|
||||
else:
|
||||
output = _handle_bad_vectors(
|
||||
data,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=42.0,
|
||||
)
|
||||
|
||||
if on_bad_vectors == "drop":
|
||||
expected = pa.array([[1.0, 2.0], [4.0, 5.0]])
|
||||
elif on_bad_vectors == "fill":
|
||||
expected = pa.array([[1.0, 2.0], [42.0, 42.0], [4.0, 5.0]])
|
||||
elif on_bad_vectors == "null":
|
||||
expected = pa.array([[1.0, 2.0], None, [4.0, 5.0]])
|
||||
|
||||
assert output["vector"].combine_chunks() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("on_bad_vectors", ["error", "drop", "fill", "null"])
|
||||
def test_handle_bad_vectors_nan(on_bad_vectors):
|
||||
vector = pa.array([[1.0, float("nan")], [3.0, 4.0]])
|
||||
data = pa.table({"vector": vector})
|
||||
|
||||
if on_bad_vectors == "error":
|
||||
with pytest.raises(ValueError) as e:
|
||||
output = _handle_bad_vectors(
|
||||
data,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
)
|
||||
output = exception_output(e)
|
||||
assert output == (
|
||||
"ValueError: Vector column 'vector' has NaNs. Set "
|
||||
"on_bad_vectors='drop' to remove them, set on_bad_vectors='fill' "
|
||||
"and fill_value=<value> to replace them, or set on_bad_vectors='null' "
|
||||
"to replace them with null."
|
||||
)
|
||||
return
|
||||
else:
|
||||
output = _handle_bad_vectors(
|
||||
data,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=42.0,
|
||||
)
|
||||
|
||||
if on_bad_vectors == "drop":
|
||||
expected = pa.array([[3.0, 4.0]])
|
||||
elif on_bad_vectors == "fill":
|
||||
expected = pa.array([[42.0, 42.0], [3.0, 4.0]])
|
||||
elif on_bad_vectors == "null":
|
||||
expected = pa.array([None, [3.0, 4.0]])
|
||||
|
||||
assert output["vector"].combine_chunks() == expected
|
||||
|
||||
|
||||
def test_handle_bad_vectors_noop():
|
||||
# ChunkedArray should be preserved as-is
|
||||
vector = pa.chunked_array(
|
||||
[[[1.0, 2.0], [3.0, 4.0]]], type=pa.list_(pa.float64(), 2)
|
||||
)
|
||||
data = pa.table({"vector": vector})
|
||||
output = _handle_bad_vectors(data)
|
||||
assert output["vector"] == vector
|
||||
|
||||
|
||||
class TestModel(lancedb.pydantic.LanceModel):
|
||||
a: Optional[int]
|
||||
b: Optional[int]
|
||||
|
||||
|
||||
# TODO: huggingface,
|
||||
@pytest.mark.parametrize(
|
||||
"data",
|
||||
[
|
||||
lambda: [{"a": 1, "b": 2}],
|
||||
lambda: pa.RecordBatch.from_pylist([{"a": 1, "b": 2}]),
|
||||
lambda: pa.table({"a": [1], "b": [2]}),
|
||||
lambda: pa.table({"a": [1], "b": [2]}).to_reader(),
|
||||
lambda: iter(pa.table({"a": [1], "b": [2]}).to_batches()),
|
||||
lambda: (
|
||||
lance.write_dataset(
|
||||
pa.table({"a": [1], "b": [2]}),
|
||||
"memory://test",
|
||||
)
|
||||
),
|
||||
lambda: (
|
||||
lance.write_dataset(
|
||||
pa.table({"a": [1], "b": [2]}),
|
||||
"memory://test",
|
||||
).scanner()
|
||||
),
|
||||
lambda: pd.DataFrame({"a": [1], "b": [2]}),
|
||||
lambda: pl.DataFrame({"a": [1], "b": [2]}),
|
||||
lambda: pl.LazyFrame({"a": [1], "b": [2]}),
|
||||
lambda: [TestModel(a=1, b=2)],
|
||||
],
|
||||
ids=[
|
||||
"rows",
|
||||
"pa.RecordBatch",
|
||||
"pa.Table",
|
||||
"pa.RecordBatchReader",
|
||||
"batch_iter",
|
||||
"lance.LanceDataset",
|
||||
"lance.LanceScanner",
|
||||
"pd.DataFrame",
|
||||
"pl.DataFrame",
|
||||
"pl.LazyFrame",
|
||||
"pydantic",
|
||||
],
|
||||
)
|
||||
def test_into_pyarrow_table(data):
|
||||
expected = pa.table({"a": [1], "b": [2]})
|
||||
output = _into_pyarrow_table(data())
|
||||
assert output == expected
|
||||
|
||||
|
||||
def test_infer_target_schema():
|
||||
example = pa.schema(
|
||||
{
|
||||
"vec1": pa.list_(pa.float64(), 2),
|
||||
"vector": pa.list_(pa.float64()),
|
||||
}
|
||||
)
|
||||
data = pa.table(
|
||||
{
|
||||
"vec1": [[0.0] * 2],
|
||||
"vector": [[0.0] * 2],
|
||||
},
|
||||
schema=example,
|
||||
)
|
||||
expected = pa.schema(
|
||||
{
|
||||
"vec1": pa.list_(pa.float64(), 2),
|
||||
"vector": pa.list_(pa.float32(), 2),
|
||||
}
|
||||
)
|
||||
output = _infer_target_schema(data)
|
||||
assert output == expected
|
||||
|
||||
# Handle large list and use modal size
|
||||
# Most vectors are of length 2, so we should infer that as the target dimension
|
||||
example = pa.schema(
|
||||
{
|
||||
"vector": pa.large_list(pa.float64()),
|
||||
}
|
||||
)
|
||||
data = pa.table(
|
||||
{
|
||||
"vector": [[0.0] * 2, [0.0], [0.0] * 2],
|
||||
},
|
||||
schema=example,
|
||||
)
|
||||
expected = pa.schema(
|
||||
{
|
||||
"vector": pa.list_(pa.float32(), 2),
|
||||
}
|
||||
)
|
||||
output = _infer_target_schema(data)
|
||||
assert output == expected
|
||||
|
||||
# ignore if not list
|
||||
example = pa.schema(
|
||||
{
|
||||
"vector": pa.float64(),
|
||||
}
|
||||
)
|
||||
data = pa.table(
|
||||
{
|
||||
"vector": [0.0],
|
||||
},
|
||||
schema=example,
|
||||
)
|
||||
expected = example
|
||||
output = _infer_target_schema(data)
|
||||
assert output == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data",
|
||||
[
|
||||
[{"id": 1, "text": "hello"}],
|
||||
pa.RecordBatch.from_pylist([{"id": 1, "text": "hello"}]),
|
||||
pd.DataFrame({"id": [1], "text": ["hello"]}),
|
||||
pl.DataFrame({"id": [1], "text": ["hello"]}),
|
||||
],
|
||||
ids=["rows", "pa.RecordBatch", "pd.DataFrame", "pl.DataFrame"],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"schema",
|
||||
[
|
||||
None,
|
||||
pa.schema(
|
||||
{
|
||||
"id": pa.int32(),
|
||||
"text": pa.string(),
|
||||
"vector": pa.list_(pa.float32(), 10),
|
||||
}
|
||||
),
|
||||
pa.schema(
|
||||
{
|
||||
"id": pa.int64(),
|
||||
"text": pa.string(),
|
||||
"vector": pa.list_(pa.float32(), 10),
|
||||
"extra": pa.int64(),
|
||||
}
|
||||
),
|
||||
],
|
||||
ids=["infer", "explicit", "subschema"],
|
||||
)
|
||||
@pytest.mark.parametrize("with_embedding", [True, False])
|
||||
def test_sanitize_data(
|
||||
data,
|
||||
schema: Optional[pa.Schema],
|
||||
with_embedding: bool,
|
||||
):
|
||||
if with_embedding:
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
registry.register("test")(MockTextEmbeddingFunction)
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text",
|
||||
vector_column="vector",
|
||||
function=MockTextEmbeddingFunction(),
|
||||
)
|
||||
metadata = registry.get_table_metadata([conf])
|
||||
else:
|
||||
metadata = None
|
||||
|
||||
if schema is not None:
|
||||
to_remove = schema.get_field_index("extra")
|
||||
if to_remove >= 0:
|
||||
expected_schema = schema.remove(to_remove)
|
||||
else:
|
||||
expected_schema = schema
|
||||
else:
|
||||
expected_schema = pa.schema(
|
||||
{
|
||||
"id": pa.int64(),
|
||||
"text": pa.large_utf8()
|
||||
if isinstance(data, pl.DataFrame)
|
||||
else pa.string(),
|
||||
"vector": pa.list_(pa.float32(), 10),
|
||||
}
|
||||
)
|
||||
|
||||
if not with_embedding:
|
||||
to_remove = expected_schema.get_field_index("vector")
|
||||
if to_remove >= 0:
|
||||
expected_schema = expected_schema.remove(to_remove)
|
||||
|
||||
expected = pa.table(
|
||||
{
|
||||
"id": [1],
|
||||
"text": ["hello"],
|
||||
"vector": [[0.0] * 10],
|
||||
},
|
||||
schema=expected_schema,
|
||||
)
|
||||
|
||||
output_data = _sanitize_data(
|
||||
data,
|
||||
target_schema=schema,
|
||||
metadata=metadata,
|
||||
allow_subschema=True,
|
||||
)
|
||||
|
||||
assert output_data == expected
|
||||
|
||||
|
||||
def test_cast_to_target_schema():
|
||||
original_schema = pa.schema(
|
||||
{
|
||||
"id": pa.int32(),
|
||||
"struct": pa.struct(
|
||||
[
|
||||
pa.field("a", pa.int32()),
|
||||
]
|
||||
),
|
||||
"vector": pa.list_(pa.float64()),
|
||||
"vec1": pa.list_(pa.float64(), 2),
|
||||
"vec2": pa.list_(pa.float32(), 2),
|
||||
}
|
||||
)
|
||||
data = pa.table(
|
||||
{
|
||||
"id": [1],
|
||||
"struct": [{"a": 1}],
|
||||
"vector": [[0.0] * 2],
|
||||
"vec1": [[0.0] * 2],
|
||||
"vec2": [[0.0] * 2],
|
||||
},
|
||||
schema=original_schema,
|
||||
)
|
||||
|
||||
target = pa.schema(
|
||||
{
|
||||
"id": pa.int64(),
|
||||
"struct": pa.struct(
|
||||
[
|
||||
pa.field("a", pa.int64()),
|
||||
]
|
||||
),
|
||||
"vector": pa.list_(pa.float32(), 2),
|
||||
"vec1": pa.list_(pa.float32(), 2),
|
||||
"vec2": pa.list_(pa.float32(), 2),
|
||||
}
|
||||
)
|
||||
output = _cast_to_target_schema(data, target)
|
||||
expected = pa.table(
|
||||
{
|
||||
"id": [1],
|
||||
"struct": [{"a": 1}],
|
||||
"vector": [[0.0] * 2],
|
||||
"vec1": [[0.0] * 2],
|
||||
"vec2": [[0.0] * 2],
|
||||
},
|
||||
schema=target,
|
||||
)
|
||||
|
||||
# Data can be a subschema of the target
|
||||
target = pa.schema(
|
||||
{
|
||||
"id": pa.int64(),
|
||||
"struct": pa.struct(
|
||||
[
|
||||
pa.field("a", pa.int64()),
|
||||
# Additional nested field
|
||||
pa.field("b", pa.int64()),
|
||||
]
|
||||
),
|
||||
"vector": pa.list_(pa.float32(), 2),
|
||||
"vec1": pa.list_(pa.float32(), 2),
|
||||
"vec2": pa.list_(pa.float32(), 2),
|
||||
# Additional field
|
||||
"extra": pa.int64(),
|
||||
}
|
||||
)
|
||||
with pytest.raises(Exception):
|
||||
_cast_to_target_schema(data, target)
|
||||
output = _cast_to_target_schema(data, target, allow_subschema=True)
|
||||
expected_schema = pa.schema(
|
||||
{
|
||||
"id": pa.int64(),
|
||||
"struct": pa.struct(
|
||||
[
|
||||
pa.field("a", pa.int64()),
|
||||
]
|
||||
),
|
||||
"vector": pa.list_(pa.float32(), 2),
|
||||
"vec1": pa.list_(pa.float32(), 2),
|
||||
"vec2": pa.list_(pa.float32(), 2),
|
||||
}
|
||||
)
|
||||
expected = pa.table(
|
||||
{
|
||||
"id": [1],
|
||||
"struct": [{"a": 1}],
|
||||
"vector": [[0.0] * 2],
|
||||
"vec1": [[0.0] * 2],
|
||||
"vec2": [[0.0] * 2],
|
||||
},
|
||||
schema=expected_schema,
|
||||
)
|
||||
assert output == expected
|
||||
|
||||
Reference in New Issue
Block a user