Compare commits


3 Commits

Author  SHA1        Message                                          Date
Xuanwo  7c37ba216a  style(python): format permutation pickle tests  2026-04-09 14:47:06 +08:00
Xuanwo  768d84845c  feat(python): support pickling permutations     2026-04-09 00:32:48 +08:00
Xuanwo  2d380d1669  Track permutation reopen metadata                2026-04-08 17:34:05 +08:00
5 changed files with 396 additions and 674 deletions

View File

@@ -1,21 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import json
import pickle
from datetime import timedelta
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
import pyarrow as pa
from deprecation import deprecated
from lancedb import AsyncConnection, DBConnection
import pyarrow as pa
import json
from ._lancedb import async_permutation_builder, PermutationReader
from .table import LanceTable
from .background_loop import LOOP
from .table import LanceTable
from .util import batch_to_tensor, batch_to_tensor_rows
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:
from lancedb.dependencies import pandas as pd, numpy as np, polars as pl
def _builtin_transform(format: str) -> Callable[[pa.RecordBatch], Any]:
if format == "python":
return Transforms.arrow2python
if format == "python_col":
return Transforms.arrow2pythoncol
if format == "numpy":
return Transforms.arrow2numpy
if format == "pandas":
return Transforms.arrow2pandas
if format == "arrow":
return Transforms.arrow2arrow
if format == "torch":
return batch_to_tensor_rows
if format == "torch_col":
return batch_to_tensor
if format == "polars":
return Transforms.arrow2polars()
raise ValueError(f"Invalid format: {format}")
def _table_to_state(
table: Union[LanceTable, dict[str, Any]],
) -> dict[str, Any]:
if isinstance(table, dict):
return table
if not isinstance(table, LanceTable):
raise pickle.PicklingError(
"Permutation pickling only supports LanceTable-backed permutations"
)
if table._namespace_client is not None:
raise pickle.PicklingError(
"Permutation pickling does not yet support namespace-backed tables"
)
if table._conn.uri.startswith("memory://"):
raise pickle.PicklingError(
"Permutation pickling does not support in-memory databases"
)
try:
read_consistency_interval = table._conn.read_consistency_interval
except Exception:
read_consistency_interval = None
return {
"uri": table._conn.uri,
"name": table.name,
"version": table.version,
"storage_options": table.initial_storage_options(),
"read_consistency_interval_secs": (
read_consistency_interval.total_seconds()
if read_consistency_interval is not None
else None
),
"namespace_path": list(table.namespace),
}
def _table_from_state(state: dict[str, Any]) -> LanceTable:
from . import connect
read_consistency_interval = (
timedelta(seconds=state["read_consistency_interval_secs"])
if state["read_consistency_interval_secs"] is not None
else None
)
db = connect(
state["uri"],
read_consistency_interval=read_consistency_interval,
storage_options=state["storage_options"],
)
table = db.open_table(state["name"], namespace_path=state["namespace_path"])
table.checkout(state["version"])
return table
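Together, _table_to_state and _table_from_state reduce a LanceTable to a plain, picklable dict and back. A minimal sketch of the round trip, assuming a local dataset at ./example_db with a table named "images" (both hypothetical):

    from lancedb import connect

    db = connect("./example_db")
    tbl = db.open_table("images")
    state = _table_to_state(tbl)       # dict of uri, name, version, storage options, namespace
    same = _table_from_state(state)    # reconnects and checks out the recorded version
    assert same.version == tbl.version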
class PermutationBuilder:
"""
A utility for creating a "permutation table" which is a table that defines an
@@ -284,8 +361,9 @@ class Permutations:
self.permutation_table = permutation_table
if permutation_table.schema.metadata is not None:
raw = permutation_table.schema.metadata.get(b"split_names")
split_names = raw.decode("utf-8") if raw is not None else None
split_names = permutation_table.schema.metadata.get(
b"split_names", None
).decode("utf-8")
if split_names is not None:
self.split_names = json.loads(split_names)
self.split_dict = {
@@ -384,6 +462,13 @@ class Permutation:
selection: dict[str, str],
batch_size: int,
transform_fn: Callable[[pa.RecordBatch], Any],
*,
base_table: Union[LanceTable, dict[str, Any]],
permutation_table: Optional[Union[LanceTable, dict[str, Any]]],
split: int,
offset: Optional[int] = None,
limit: Optional[int] = None,
transform_spec: Optional[str] = None,
):
"""
Internal constructor. Use [from_tables](#from_tables) instead.
@@ -394,6 +479,93 @@ class Permutation:
self.selection = selection
self.transform_fn = transform_fn
self.batch_size = batch_size
self._transform_spec = transform_spec
# These fields are used to reconstruct the permutation in a new process.
self._base_table = base_table
self._permutation_table = permutation_table
self._split = split
self._offset = offset
self._limit = limit
def _reopen_metadata(self) -> dict[str, Any]:
return {
"base_table": self._base_table,
"permutation_table": self._permutation_table,
"split": self._split,
"offset": self._offset,
"limit": self._limit,
"transform_spec": self._transform_spec,
}
def __getstate__(self) -> dict[str, Any]:
if self._transform_spec is not None:
transform_state = {
"kind": "builtin",
"format": self._transform_spec,
}
else:
transform_state = {
"kind": "callable",
"transform_fn": self.transform_fn,
}
return {
"selection": self.selection,
"batch_size": self.batch_size,
"transform": transform_state,
"reopen": {
**self._reopen_metadata(),
# Store reopen state instead of live LanceTable handles.
"base_table": _table_to_state(self._base_table),
"permutation_table": (
_table_to_state(self._permutation_table)
if self._permutation_table is not None
else None
),
},
}
def __setstate__(self, state: dict[str, Any]) -> None:
reopen = state["reopen"]
base_table = _table_from_state(reopen["base_table"])
permutation_table_state = reopen["permutation_table"]
permutation_table = (
_table_from_state(permutation_table_state)
if permutation_table_state is not None
else None
)
split = reopen["split"]
offset = reopen["offset"]
limit = reopen["limit"]
async def do_reopen():
reader = await PermutationReader.from_tables(
base_table, permutation_table, split
)
if offset is not None:
reader = await reader.with_offset(offset)
if limit is not None:
reader = await reader.with_limit(limit)
return reader
transform = state["transform"]
if transform["kind"] == "builtin":
transform_spec = transform["format"]
transform_fn = _builtin_transform(transform_spec)
else:
transform_spec = None
transform_fn = transform["transform_fn"]
self.reader = LOOP.run(do_reopen())
self.selection = state["selection"]
self.batch_size = state["batch_size"]
self.transform_fn = transform_fn
self._transform_spec = transform_spec
self._base_table = reopen["base_table"]
self._permutation_table = permutation_table_state
self._split = split
self._offset = offset
self._limit = limit
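Live PermutationReader and LanceTable handles are not picklable, so __getstate__ stores reopen metadata and __setstate__ rebuilds the reader on load. One plausible consumer is a multi-process torch DataLoader, which pickles the dataset into each worker; a hedged sketch (import path, database path, and table name are assumptions):

    import pickle
    from lancedb import connect
    from lancedb.permutation import Permutation

    db = connect("./example_db")
    perm = Permutation.identity(db.open_table("images")).with_format("torch")
    restored = pickle.loads(pickle.dumps(perm))   # reopens tables at the pinned version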
def _with_selection(self, selection: dict[str, str]) -> "Permutation":
"""
@@ -402,7 +574,13 @@ class Permutation:
Does no validation of the selection; it replaces it entirely. This is not
intended for public use.
"""
return Permutation(self.reader, selection, self.batch_size, self.transform_fn)
return Permutation(
self.reader,
selection,
self.batch_size,
self.transform_fn,
**self._reopen_metadata(),
)
def _with_reader(self, reader: PermutationReader) -> "Permutation":
"""
@@ -410,13 +588,25 @@ class Permutation:
This is an internal method and should not be used directly.
"""
return Permutation(reader, self.selection, self.batch_size, self.transform_fn)
return Permutation(
reader,
self.selection,
self.batch_size,
self.transform_fn,
**self._reopen_metadata(),
)
def with_batch_size(self, batch_size: int) -> "Permutation":
"""
Creates a new permutation with the given batch size
"""
return Permutation(self.reader, self.selection, batch_size, self.transform_fn)
return Permutation(
self.reader,
self.selection,
batch_size,
self.transform_fn,
**self._reopen_metadata(),
)
@classmethod
def identity(cls, table: LanceTable) -> "Permutation":
@@ -459,8 +649,9 @@ class Permutation:
f"Cannot create a permutation on split `{split}`"
" because no split names are defined in the permutation table"
)
raw = permutation_table.schema.metadata.get(b"split_names")
split_names = raw.decode("utf-8") if raw is not None else None
split_names = permutation_table.schema.metadata.get(
b"split_names", None
).decode("utf-8")
if split_names is None:
raise ValueError(
f"Cannot create a permutation on split `{split}`"
@@ -489,7 +680,14 @@ class Permutation:
schema = await reader.output_schema(None)
initial_selection = {name: name for name in schema.names}
return cls(
reader, initial_selection, DEFAULT_BATCH_SIZE, Transforms.arrow2python
reader,
initial_selection,
DEFAULT_BATCH_SIZE,
Transforms.arrow2python,
base_table=base_table,
permutation_table=permutation_table,
split=split,
transform_spec="python",
)
return LOOP.run(do_from_tables())
@@ -730,24 +928,16 @@ class Permutation:
this method.
"""
assert format is not None, "format is required"
if format == "python":
return self.with_transform(Transforms.arrow2python)
if format == "python_col":
return self.with_transform(Transforms.arrow2pythoncol)
elif format == "numpy":
return self.with_transform(Transforms.arrow2numpy)
elif format == "pandas":
return self.with_transform(Transforms.arrow2pandas)
elif format == "arrow":
return self.with_transform(Transforms.arrow2arrow)
elif format == "torch":
return self.with_transform(batch_to_tensor_rows)
elif format == "torch_col":
return self.with_transform(batch_to_tensor)
elif format == "polars":
return self.with_transform(Transforms.arrow2polars())
else:
raise ValueError(f"Invalid format: {format}")
return Permutation(
self.reader,
self.selection,
self.batch_size,
_builtin_transform(format),
**{
**self._reopen_metadata(),
"transform_spec": format,
},
)
def with_transform(self, transform: Callable[[pa.RecordBatch], Any]) -> "Permutation":
"""
@@ -760,7 +950,16 @@ class Permutation:
for expensive operations such as image decoding.
"""
assert transform is not None, "transform is required"
return Permutation(self.reader, self.selection, self.batch_size, transform)
return Permutation(
self.reader,
self.selection,
self.batch_size,
transform,
**{
**self._reopen_metadata(),
"transform_spec": None,
},
)
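Note the asymmetry these two constructors encode for pickling: with_format records a builtin transform_spec, so __getstate__ serializes only the format string, while with_transform clears transform_spec and pickles the callable itself. In practice a custom transform must therefore be picklable, e.g. a module-level function rather than a lambda (decode_image below is hypothetical):

    perm.with_format("numpy")           # pickles as {"kind": "builtin", "format": "numpy"}
    perm.with_transform(decode_image)   # pickles the callable; it must be importable
                                        # in the unpickling process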
def __getitem__(self, index: int) -> Any:
"""
@@ -798,7 +997,16 @@ class Permutation:
async def do_with_skip():
reader = await self.reader.with_offset(skip)
return self._with_reader(reader)
return Permutation(
reader,
self.selection,
self.batch_size,
self.transform_fn,
**{
**self._reopen_metadata(),
"offset": skip,
},
)
return LOOP.run(do_with_skip())
@@ -821,7 +1029,16 @@ class Permutation:
async def do_with_take():
reader = await self.reader.with_limit(limit)
return self._with_reader(reader)
return Permutation(
reader,
self.selection,
self.batch_size,
self.transform_fn,
**{
**self._reopen_metadata(),
"limit": limit,
},
)
return LOOP.run(do_with_take())

View File

@@ -270,17 +270,15 @@ def _sanitize_data(
reader,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
target_schema=target_schema,
metadata=metadata,
)
if target_schema is None:
target_schema, reader = _infer_target_schema(reader)
if metadata:
target_schema = target_schema.with_metadata(
_merge_metadata(target_schema.metadata, metadata)
)
new_metadata = target_schema.metadata or {}
new_metadata.update(metadata)
target_schema = target_schema.with_metadata(new_metadata)
_validate_schema(target_schema)
reader = _cast_to_target_schema(reader, target_schema, allow_subschema)
@@ -296,7 +294,7 @@ def _cast_to_target_schema(
# pa.Table.cast expects field order not to be changed.
# Lance doesn't care about field order, so we don't need to rearrange fields
# to match the target schema. We just need to correctly cast the fields.
if reader.schema.equals(target_schema, check_metadata=True):
if reader.schema == target_schema:
# Fast path when the schemas are already the same
return reader
@@ -316,13 +314,7 @@ def _cast_to_target_schema(
def gen():
for batch in reader:
# Table but not RecordBatch has cast.
cast_batches = (
pa.Table.from_batches([batch]).cast(reordered_schema).to_batches()
)
if cast_batches:
yield pa.RecordBatch.from_arrays(
cast_batches[0].columns, schema=reordered_schema
)
yield pa.Table.from_batches([batch]).cast(reordered_schema).to_batches()[0]
return pa.RecordBatchReader.from_batches(reordered_schema, gen())
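The Table round trip in gen() exists because, in the pyarrow versions targeted here, cast lives on Table but not on RecordBatch. A standalone equivalent of the one-liner:

    import pyarrow as pa

    batch = pa.record_batch({"x": [1, 2, 3]})
    target = pa.schema([pa.field("x", pa.float32())])
    cast = pa.Table.from_batches([batch]).cast(target).to_batches()[0]
    assert cast.schema == target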
@@ -340,51 +332,37 @@ def _align_field_types(
if target_field is None:
raise ValueError(f"Field '{field.name}' not found in target schema")
if pa.types.is_struct(target_field.type):
if pa.types.is_struct(field.type):
new_type = pa.struct(
_align_field_types(
field.type.fields,
target_field.type.fields,
)
new_type = pa.struct(
_align_field_types(
field.type.fields,
target_field.type.fields,
)
else:
new_type = target_field.type
)
elif pa.types.is_list(target_field.type):
if _is_list_like(field.type):
new_type = pa.list_(
_align_field_types(
[field.type.value_field],
[target_field.type.value_field],
)[0]
)
else:
new_type = target_field.type
new_type = pa.list_(
_align_field_types(
[field.type.value_field],
[target_field.type.value_field],
)[0]
)
elif pa.types.is_large_list(target_field.type):
if _is_list_like(field.type):
new_type = pa.large_list(
_align_field_types(
[field.type.value_field],
[target_field.type.value_field],
)[0]
)
else:
new_type = target_field.type
new_type = pa.large_list(
_align_field_types(
[field.type.value_field],
[target_field.type.value_field],
)[0]
)
elif pa.types.is_fixed_size_list(target_field.type):
if _is_list_like(field.type):
new_type = pa.list_(
_align_field_types(
[field.type.value_field],
[target_field.type.value_field],
)[0],
target_field.type.list_size,
)
else:
new_type = target_field.type
new_type = pa.list_(
_align_field_types(
[field.type.value_field],
[target_field.type.value_field],
)[0],
target_field.type.list_size,
)
else:
new_type = target_field.type
new_fields.append(
pa.field(field.name, new_type, field.nullable, target_field.metadata)
)
new_fields.append(pa.field(field.name, new_type, field.nullable))
return new_fields
@@ -462,7 +440,6 @@ def sanitize_create_table(
schema = data.schema
if metadata:
metadata = _merge_metadata(schema.metadata, metadata)
schema = schema.with_metadata(metadata)
# Need to apply metadata to the data as well
if isinstance(data, pa.Table):
@@ -515,9 +492,9 @@ def _append_vector_columns(
vector columns to the table.
"""
if schema is None:
metadata = _merge_metadata(metadata)
metadata = metadata or {}
else:
metadata = _merge_metadata(schema.metadata, metadata)
metadata = schema.metadata or metadata or {}
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
if not functions:
@@ -3234,157 +3211,43 @@ def _handle_bad_vectors(
reader: pa.RecordBatchReader,
on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
fill_value: float = 0.0,
target_schema: Optional[pa.Schema] = None,
metadata: Optional[dict] = None,
) -> pa.RecordBatchReader:
vector_columns = _find_vector_columns(reader.schema, target_schema, metadata)
if not vector_columns:
return reader
vector_columns = []
output_schema = _vector_output_schema(reader.schema, vector_columns)
for field in reader.schema:
# They can provide a 'vector' column that isn't yet a FSL
named_vector_col = (
(
pa.types.is_list(field.type)
or pa.types.is_large_list(field.type)
or pa.types.is_fixed_size_list(field.type)
)
and pa.types.is_floating(field.type.value_type)
and field.name == VECTOR_COLUMN_NAME
)
# TODO: we're making an assumption that fixed size list of 10 or more
# is a vector column. This is definitely a bit hacky.
likely_vector_col = (
pa.types.is_fixed_size_list(field.type)
and pa.types.is_floating(field.type.value_type)
and (field.type.list_size >= 10)
)
if named_vector_col or likely_vector_col:
vector_columns.append(field.name)
def gen():
for batch in reader:
pending_dims = []
for vector_column in vector_columns:
dim = vector_column["expected_dim"]
if target_schema is not None and dim is None:
dim = _infer_vector_dim(batch[vector_column["name"]])
pending_dims.append(vector_column)
for name in vector_columns:
batch = _handle_bad_vector_column(
batch,
vector_column_name=vector_column["name"],
vector_column_name=name,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
expected_dim=dim,
expected_value_type=vector_column["expected_value_type"],
)
for vector_column in pending_dims:
if vector_column["expected_dim"] is None:
vector_column["expected_dim"] = _infer_vector_dim(
batch[vector_column["name"]]
)
if batch.schema.equals(output_schema, check_metadata=True):
yield batch
continue
yield batch
cast_batches = (
pa.Table.from_batches([batch]).cast(output_schema).to_batches()
)
if cast_batches:
yield pa.RecordBatch.from_arrays(
cast_batches[0].columns,
schema=output_schema,
)
return pa.RecordBatchReader.from_batches(output_schema, gen())
def _find_vector_columns(
reader_schema: pa.Schema,
target_schema: Optional[pa.Schema],
metadata: Optional[dict],
) -> List[dict]:
if target_schema is None:
vector_columns = []
for field in reader_schema:
named_vector_col = (
_is_list_like(field.type)
and pa.types.is_floating(field.type.value_type)
and field.name == VECTOR_COLUMN_NAME
)
likely_vector_col = (
pa.types.is_fixed_size_list(field.type)
and pa.types.is_floating(field.type.value_type)
and (field.type.list_size >= 10)
)
if named_vector_col or likely_vector_col:
vector_columns.append(
{
"name": field.name,
"expected_dim": None,
"expected_value_type": None,
}
)
return vector_columns
reader_column_names = set(reader_schema.names)
active_metadata = _merge_metadata(target_schema.metadata, metadata)
embedding_function_columns = set(
EmbeddingFunctionRegistry.get_instance().parse_functions(active_metadata).keys()
)
vector_columns = []
for field in target_schema:
if field.name not in reader_column_names:
continue
if not _is_list_like(field.type) or not pa.types.is_floating(
field.type.value_type
):
continue
reader_field = reader_schema.field(field.name)
named_vector_col = (
field.name in embedding_function_columns
or field.name == VECTOR_COLUMN_NAME
or (field.name == "embedding" and pa.types.is_fixed_size_list(field.type))
)
typed_fixed_vector_col = (
pa.types.is_fixed_size_list(reader_field.type)
and pa.types.is_floating(reader_field.type.value_type)
and reader_field.type.list_size >= 10
)
if named_vector_col or typed_fixed_vector_col:
vector_columns.append(
{
"name": field.name,
"expected_dim": (
field.type.list_size
if pa.types.is_fixed_size_list(field.type)
else None
),
"expected_value_type": field.type.value_type,
}
)
return vector_columns
def _vector_output_schema(
reader_schema: pa.Schema,
vector_columns: List[dict],
) -> pa.Schema:
columns_by_name = {column["name"]: column for column in vector_columns}
fields = []
for field in reader_schema:
column = columns_by_name.get(field.name)
if column is None:
output_type = field.type
else:
output_type = _vector_output_type(field, column)
fields.append(pa.field(field.name, output_type, field.nullable, field.metadata))
return pa.schema(fields, metadata=reader_schema.metadata)
def _vector_output_type(field: pa.Field, vector_column: dict) -> pa.DataType:
if not _is_list_like(field.type):
return field.type
if vector_column["expected_value_type"] is not None and (
pa.types.is_null(field.type.value_type)
or pa.types.is_integer(field.type.value_type)
or pa.types.is_unsigned_integer(field.type.value_type)
):
return pa.list_(vector_column["expected_value_type"])
if (
vector_column["expected_dim"] is not None
and pa.types.is_fixed_size_list(field.type)
and field.type.list_size != vector_column["expected_dim"]
):
return pa.list_(field.type.value_type)
return field.type
return pa.RecordBatchReader.from_batches(reader.schema, gen())
def _handle_bad_vector_column(
@@ -3392,8 +3255,6 @@ def _handle_bad_vector_column(
vector_column_name: str,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
expected_dim: Optional[int] = None,
expected_value_type: Optional[pa.DataType] = None,
) -> pa.RecordBatch:
"""
Ensure that the vector column exists and has type fixed_size_list(float)
@@ -3410,39 +3271,14 @@ def _handle_bad_vector_column(
fill_value: float, default 0.0
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
position = data.column_names.index(vector_column_name)
vec_arr = data[vector_column_name]
if not _is_list_like(vec_arr.type):
return data
if (
expected_dim is not None
and pa.types.is_fixed_size_list(vec_arr.type)
and vec_arr.type.list_size != expected_dim
):
vec_arr = pa.array(vec_arr.to_pylist(), type=pa.list_(vec_arr.type.value_type))
data = data.set_column(position, vector_column_name, vec_arr)
has_nan = has_nan_values(vec_arr)
if expected_value_type is not None and (
pa.types.is_integer(vec_arr.type.value_type)
or pa.types.is_unsigned_integer(vec_arr.type.value_type)
):
vec_arr = pa.array(vec_arr.to_pylist(), type=pa.list_(expected_value_type))
data = data.set_column(position, vector_column_name, vec_arr)
if pa.types.is_floating(vec_arr.type.value_type):
has_nan = has_nan_values(vec_arr)
else:
has_nan = pa.array([False] * len(vec_arr))
if expected_dim is not None:
dim = expected_dim
elif pa.types.is_fixed_size_list(vec_arr.type):
if pa.types.is_fixed_size_list(vec_arr.type):
dim = vec_arr.type.list_size
else:
dim = _infer_vector_dim(vec_arr)
if dim is None:
return data
dim = _modal_list_size(vec_arr)
has_wrong_dim = pc.not_equal(pc.list_value_length(vec_arr), dim)
has_bad_vectors = pc.any(has_nan).as_py() or pc.any(has_wrong_dim).as_py()
@@ -3480,12 +3316,13 @@ def _handle_bad_vector_column(
)
vec_arr = pc.if_else(
is_bad,
pa.scalar([fill_value] * dim, type=vec_arr.type),
pa.scalar([fill_value] * dim),
vec_arr,
)
else:
raise ValueError(f"Invalid value for on_bad_vectors: {on_bad_vectors}")
position = data.column_names.index(vector_column_name)
return data.set_column(position, vector_column_name, vec_arr)
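At the public API level this machinery backs the on_bad_vectors and fill_value options on add(); a minimal sketch of the observable behavior (database path and table name hypothetical):

    import pyarrow as pa
    import lancedb

    db = lancedb.connect("./example_db")
    schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
    tbl = db.create_table("vectors", schema=schema)
    # NaN makes this a "bad" vector; fill replaces the whole vector with fill_value
    tbl.add([{"vector": [1.0, float("nan")]}], on_bad_vectors="fill", fill_value=0.0)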
@@ -3506,28 +3343,6 @@ def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray
return pc.is_in(indices, has_nan_indices)
def _is_list_like(data_type: pa.DataType) -> bool:
return (
pa.types.is_list(data_type)
or pa.types.is_large_list(data_type)
or pa.types.is_fixed_size_list(data_type)
)
def _merge_metadata(*metadata_dicts: Optional[dict]) -> dict:
merged = {}
for metadata in metadata_dicts:
if metadata is None:
continue
for key, value in metadata.items():
if isinstance(key, str):
key = key.encode("utf-8")
if isinstance(value, str):
value = value.encode("utf-8")
merged[key] = value
return merged
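For reference, the helper's merge semantics in one line: later dicts win, str keys and values are utf-8 encoded, and None arguments are skipped:

    _merge_metadata({"note": "a"}, {b"note": b"b"}, None)   # -> {b"note": b"b"}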
def _name_suggests_vector_column(field_name: str) -> bool:
"""Check if a field name indicates a vector column."""
name_lower = field_name.lower()
@@ -3595,16 +3410,6 @@ def _modal_list_size(arr: Union[pa.ListArray, pa.ChunkedArray]) -> int:
return pc.mode(pc.list_value_length(arr))[0].as_py()["mode"]
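_modal_list_size leans on pc.mode, which returns a struct array of {mode, count} pairs; for example:

    import pyarrow as pa
    import pyarrow.compute as pc

    arr = pa.array([[1.0, 2.0], [3.0], [4.0, 5.0]])
    pc.mode(pc.list_value_length(arr))[0].as_py()   # -> {'mode': 2, 'count': 2}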
def _infer_vector_dim(arr: Union[pa.Array, pa.ChunkedArray]) -> Optional[int]:
if not _is_list_like(arr.type):
return None
lengths = pc.list_value_length(arr)
lengths = pc.filter(lengths, pc.greater(lengths, 0))
if len(lengths) == 0:
return None
return pc.mode(lengths)[0].as_py()["mode"]
def _validate_schema(schema: pa.Schema):
"""
Make sure the metadata is valid utf8

View File

@@ -3,6 +3,7 @@
import pyarrow as pa
import math
import pickle
import pytest
from lancedb import DBConnection, Table, connect
@@ -522,50 +523,6 @@ def test_no_split_names(some_table: Table):
assert permutations[1].num_rows == 500
def test_permutations_metadata_without_split_names_key(mem_db: DBConnection):
"""Regression: schema metadata present but missing split_names key must not crash.
Previously, `.get(b"split_names", None).decode()` was called unconditionally,
so any permutation table whose metadata dict had other keys but no split_names
raised AttributeError: 'NoneType' has no attribute 'decode'.
"""
base = mem_db.create_table("base_nosplit", pa.table({"x": range(10)}))
# Build a permutation-like table that carries some metadata but NOT split_names.
raw = pa.table(
{
"row_id": pa.array(range(10), type=pa.uint64()),
"split_id": pa.array([0] * 10, type=pa.uint32()),
}
).replace_schema_metadata({b"other_key": b"other_value"})
perm_tbl = mem_db.create_table("perm_nosplit", raw)
permutations = Permutations(base, perm_tbl)
assert permutations.split_names == []
assert permutations.split_dict == {}
def test_from_tables_string_split_missing_names_key(mem_db: DBConnection):
"""Regression: from_tables() with a string split must raise ValueError, not
AttributeError.
Previously, `.get(b"split_names", None).decode()` crashed with AttributeError
when the metadata dict existed but had no split_names key.
"""
base = mem_db.create_table("base_strsplit", pa.table({"x": range(10)}))
raw = pa.table(
{
"row_id": pa.array(range(10), type=pa.uint64()),
"split_id": pa.array([0] * 10, type=pa.uint32()),
}
).replace_schema_metadata({b"other_key": b"other_value"})
perm_tbl = mem_db.create_table("perm_strsplit", raw)
with pytest.raises(ValueError, match="no split names are defined"):
Permutation.from_tables(base, perm_tbl, split="train")
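For reference, the failure mode these tests pinned down, with a plain dict standing in for schema metadata:

    meta = {b"other_key": b"other_value"}
    meta.get(b"split_names", None).decode("utf-8")
    # AttributeError: 'NoneType' object has no attribute 'decode'
    raw = meta.get(b"split_names")
    split_names = raw.decode("utf-8") if raw is not None else None   # guarded variant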
@pytest.fixture
def some_perm_table(some_table: Table) -> Table:
return (
@@ -643,6 +600,87 @@ def test_limit_offset(some_permutation: Permutation):
some_permutation.with_skip(500).with_take(500).num_rows
def test_permutation_pickle_rejects_in_memory_tables(mem_db: DBConnection):
table = mem_db.create_table("identity_table", pa.table({"id": range(10)}))
permutation = Permutation.identity(table)
with pytest.raises(
pickle.PicklingError,
match="in-memory databases",
):
pickle.dumps(permutation)
def test_identity_permutation_pickle_roundtrip_preserves_table_version(tmp_path):
db = connect(tmp_path)
table = db.create_table(
"identity_table",
pa.table({"id": range(10), "value": range(10)}),
)
permutation = (
Permutation.identity(table).with_skip(2).with_take(3).with_format("python_col")
)
payload = pickle.dumps(permutation)
table.add(pa.table({"id": [10], "value": [10]}))
restored = pickle.loads(payload)
assert restored.num_rows == 3
batches = list(restored.iter(10, skip_last_batch=False))
assert batches == [{"id": [2, 3, 4], "value": [2, 3, 4]}]
def test_permutation_pickle_roundtrip_with_persisted_permutation_table(tmp_path):
db = connect(tmp_path)
table = db.create_table(
"base_table",
pa.table({"id": range(1000), "value": range(1000)}),
)
permutation_table = (
permutation_builder(table)
.split_random(ratios=[0.95, 0.05], seed=42, split_names=["train", "test"])
.shuffle(seed=42)
.persist(db, "persisted_permutation")
.execute()
)
permutation = (
Permutation.from_tables(table, permutation_table, "test")
.select_columns(["id"])
.rename_column("id", "row_id")
.with_batch_size(32)
.with_skip(5)
.with_take(10)
.with_format("arrow")
)
restored = pickle.loads(pickle.dumps(permutation))
assert restored.batch_size == 32
assert restored.column_names == ["row_id"]
assert restored.num_rows == 10
assert (
restored.__getitems__([0, 1, 2]).to_pylist()
== permutation.__getitems__([0, 1, 2]).to_pylist()
)
def test_permutation_pickle_roundtrip_preserves_builtin_polars_format(tmp_path):
pl = pytest.importorskip("polars")
db = connect(tmp_path)
table = db.create_table(
"polars_table",
pa.table({"id": range(5), "value": range(5)}),
)
permutation = Permutation.identity(table).with_take(2).with_format("polars")
restored = pickle.loads(pickle.dumps(permutation))
batch = restored.__getitems__([0, 1])
assert isinstance(batch, pl.DataFrame)
assert batch.to_dict(as_series=False) == {"id": [0, 1], "value": [0, 1]}
def test_remove_columns(some_permutation: Permutation):
assert some_permutation.remove_columns(["value"]).schema == pa.schema(
[("id", pa.int64())]

View File

@@ -1049,231 +1049,6 @@ def test_add_with_nans(mem_db: DBConnection):
assert np.allclose(v, np.array([0.0, 0.0]))
def test_add_with_empty_fixed_size_list_drops_bad_rows(mem_db: DBConnection):
class Schema(LanceModel):
text: str
embedding: Vector(16)
table = mem_db.create_table("test_empty_embeddings", schema=Schema)
table.add(
[
{"text": "hello", "embedding": []},
{"text": "bar", "embedding": [0.1] * 16},
],
on_bad_vectors="drop",
)
data = table.to_arrow()
assert data["text"].to_pylist() == ["bar"]
assert np.allclose(data["embedding"].to_pylist()[0], np.array([0.1] * 16))
def test_add_with_integer_embeddings_preserves_casting(mem_db: DBConnection):
class Schema(LanceModel):
text: str
embedding: Vector(4)
table = mem_db.create_table("test_integer_embeddings", schema=Schema)
table.add(
[{"text": "foo", "embedding": [1, 2, 3, 4]}],
on_bad_vectors="drop",
)
assert table.to_arrow()["embedding"].to_pylist() == [[1.0, 2.0, 3.0, 4.0]]
def test_on_bad_vectors_does_not_handle_non_vector_fixed_size_lists(
mem_db: DBConnection,
):
schema = pa.schema(
[
pa.field("vector", pa.list_(pa.float32(), 4)),
pa.field("bbox", pa.list_(pa.float32(), 4)),
]
)
table = mem_db.create_table("test_bbox_schema", schema=schema)
with pytest.raises(RuntimeError, match="FixedSizeListType"):
table.add(
[{"vector": [1.0, 2.0, 3.0, 4.0], "bbox": [0.0, 1.0]}],
on_bad_vectors="drop",
)
def test_on_bad_vectors_does_not_handle_custom_named_fixed_size_lists(
mem_db: DBConnection,
):
schema = pa.schema([pa.field("features", pa.list_(pa.float32(), 16))])
table = mem_db.create_table("test_custom_named_fixed_size_vector", schema=schema)
with pytest.raises(RuntimeError, match="FixedSizeListType"):
table.add(
[
{"features": []},
{"features": [0.1] * 16},
],
on_bad_vectors="drop",
)
def test_on_bad_vectors_with_schema_list_vector_still_sanitizes(mem_db: DBConnection):
schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))])
table = mem_db.create_table("test_schema_list_vector", schema=schema)
table.add(
[
{"vector": [1.0, 2.0]},
{"vector": [3.0]},
{"vector": [4.0, 5.0]},
],
on_bad_vectors="drop",
)
assert table.to_arrow()["vector"].to_pylist() == [[1.0, 2.0], [4.0, 5.0]]
def test_on_bad_vectors_handles_typed_custom_fixed_vectors_for_list_schema(
mem_db: DBConnection,
):
schema = pa.schema([pa.field("vec", pa.list_(pa.float32()))])
table = mem_db.create_table("test_typed_custom_fixed_vector", schema=schema)
data = pa.table(
{
"vec": pa.array(
[[float("nan")] * 16, [1.0] * 16],
type=pa.list_(pa.float32(), 16),
)
}
)
table.add(data, on_bad_vectors="drop")
assert table.to_arrow()["vec"].to_pylist() == [[1.0] * 16]
def test_on_bad_vectors_fill_preserves_arrow_nested_vector_type(mem_db: DBConnection):
schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))])
table = mem_db.create_table("test_fill_arrow_nested_type", schema=schema)
data = pa.table(
{
"vector": pa.array(
[[1.0, 2.0], [float("nan"), 3.0]],
type=pa.list_(pa.float32(), 2),
)
}
)
table.add(
data,
on_bad_vectors="fill",
fill_value=0.0,
)
assert table.to_arrow()["vector"].to_pylist() == [[1.0, 2.0], [0.0, 0.0]]
@pytest.mark.parametrize(
("table_name", "batch1", "expected"),
[
(
"test_schema_list_vector_empty_prefix",
pa.record_batch({"vector": [[], []]}),
[[], [], [1.0, 2.0], [3.0, 4.0]],
),
(
"test_schema_list_vector_all_bad_prefix",
pa.record_batch({"vector": [[float("nan")] * 3, [float("nan")] * 3]}),
[[1.0, 2.0], [3.0, 4.0]],
),
],
)
def test_on_bad_vectors_with_schema_list_vector_ignores_invalid_prefix_batches(
mem_db: DBConnection,
table_name: str,
batch1: pa.RecordBatch,
expected: list,
):
schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))])
table = mem_db.create_table(table_name, schema=schema)
batch2 = pa.record_batch({"vector": [[1.0, 2.0], [3.0, 4.0]]})
reader = pa.RecordBatchReader.from_batches(batch1.schema, [batch1, batch2])
table.add(reader, on_bad_vectors="drop")
assert table.to_arrow()["vector"].to_pylist() == expected
def test_on_bad_vectors_with_multiple_vectors_locks_dim_after_final_drop(
mem_db: DBConnection,
):
registry = EmbeddingFunctionRegistry.get_instance()
func = MockTextEmbeddingFunction.create()
metadata = registry.get_table_metadata(
[
EmbeddingFunctionConfig(
source_column="text1", vector_column="vec1", function=func
),
EmbeddingFunctionConfig(
source_column="text2", vector_column="vec2", function=func
),
]
)
schema = pa.schema(
[
pa.field("vec1", pa.list_(pa.float32())),
pa.field("vec2", pa.list_(pa.float32())),
],
metadata=metadata,
)
table = mem_db.create_table("test_multi_vector_dim_lock", schema=schema)
batch1 = pa.record_batch(
{
"vec1": [[1.0, 2.0, 3.0], [10.0, 11.0]],
"vec2": [[float("nan"), 0.0], [5.0, 6.0]],
}
)
batch2 = pa.record_batch(
{
"vec1": [[20.0, 21.0], [30.0, 31.0]],
"vec2": [[7.0, 8.0], [9.0, 10.0]],
}
)
reader = pa.RecordBatchReader.from_batches(batch1.schema, [batch1, batch2])
table.add(reader, on_bad_vectors="drop")
data = table.to_arrow()
assert data["vec1"].to_pylist() == [[10.0, 11.0], [20.0, 21.0], [30.0, 31.0]]
assert data["vec2"].to_pylist() == [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
def test_on_bad_vectors_does_not_handle_non_vector_list_columns(mem_db: DBConnection):
schema = pa.schema([pa.field("embedding_history", pa.list_(pa.float32()))])
table = mem_db.create_table("test_non_vector_list_schema", schema=schema)
table.add(
[
{"embedding_history": [1.0, 2.0]},
{"embedding_history": [3.0]},
],
on_bad_vectors="drop",
)
assert table.to_arrow()["embedding_history"].to_pylist() == [
[1.0, 2.0],
[3.0],
]
def test_on_bad_vectors_all_null_schema_vector_batches_do_not_crash(
mem_db: DBConnection,
):
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2), nullable=True)])
table = mem_db.create_table("test_all_null_vector_batch", schema=schema)
table.add([{"vector": None}], on_bad_vectors="drop")
assert table.to_arrow()["vector"].to_pylist() == [None]
def test_restore(mem_db: DBConnection):
table = mem_db.create_table(
"my_table",

View File

@@ -15,10 +15,8 @@ from lancedb.table import (
_cast_to_target_schema,
_handle_bad_vectors,
_into_pyarrow_reader,
_infer_target_schema,
_merge_metadata,
_sanitize_data,
sanitize_create_table,
_infer_target_schema,
)
import pyarrow as pa
import pandas as pd
@@ -306,117 +304,6 @@ def test_handle_bad_vectors_noop():
assert output["vector"] == vector
def test_handle_bad_vectors_updates_reader_schema_for_target_schema():
data = pa.table({"vector": [[1, 2, 3, 4]]})
target_schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 4))])
output = _handle_bad_vectors(
data.to_reader(),
on_bad_vectors="drop",
target_schema=target_schema,
)
assert output.schema == pa.schema([pa.field("vector", pa.list_(pa.float32()))])
assert output.read_all()["vector"].to_pylist() == [[1.0, 2.0, 3.0, 4.0]]
def test_sanitize_data_keeps_target_field_metadata():
source_field = pa.field(
"vector",
pa.list_(pa.float32(), 2),
metadata={b"source": b"drop-me"},
)
target_field = pa.field(
"vector",
pa.list_(pa.float32(), 2),
metadata={b"target": b"keep-me"},
)
data = pa.table(
{"vector": pa.array([[1.0, 2.0]], type=pa.list_(pa.float32(), 2))},
schema=pa.schema([source_field]),
)
output = _sanitize_data(
data,
target_schema=pa.schema([target_field]),
on_bad_vectors="drop",
).read_all()
assert output.schema.field("vector").metadata == {b"target": b"keep-me"}
def test_sanitize_data_uses_separate_embedding_metadata_for_bad_vectors():
registry = EmbeddingFunctionRegistry.get_instance()
conf = EmbeddingFunctionConfig(
source_column="text",
vector_column="custom_vector",
function=MockTextEmbeddingFunction.create(),
)
metadata = registry.get_table_metadata([conf])
schema = pa.schema(
{
"text": pa.string(),
"custom_vector": pa.list_(pa.float32(), 10),
},
metadata={b"note": b"keep-me"},
)
data = pa.table(
{
"text": ["bad", "good"],
"custom_vector": [[1.0] * 9, [2.0] * 10],
}
)
output = _sanitize_data(
data,
target_schema=schema,
metadata=metadata,
on_bad_vectors="drop",
).read_all()
assert output["text"].to_pylist() == ["good"]
assert output.schema.metadata[b"note"] == b"keep-me"
assert b"embedding_functions" in output.schema.metadata
def test_sanitize_create_table_merges_and_overrides_embedding_metadata():
registry = EmbeddingFunctionRegistry.get_instance()
old_conf = EmbeddingFunctionConfig(
source_column="text",
vector_column="old_vector",
function=MockTextEmbeddingFunction.create(),
)
new_conf = EmbeddingFunctionConfig(
source_column="text",
vector_column="custom_vector",
function=MockTextEmbeddingFunction.create(),
)
metadata = registry.get_table_metadata([new_conf])
schema = pa.schema(
{
"text": pa.string(),
"custom_vector": pa.list_(pa.float32(), 10),
},
metadata=_merge_metadata(
{b"note": b"keep-me"},
registry.get_table_metadata([old_conf]),
),
)
data, schema = sanitize_create_table(
pa.table({"text": ["good"]}),
schema,
metadata=metadata,
on_bad_vectors="drop",
)
assert schema.metadata[b"note"] == b"keep-me"
assert b"embedding_functions" in schema.metadata
assert data.schema.metadata[b"note"] == b"keep-me"
funcs = EmbeddingFunctionRegistry.get_instance().parse_functions(schema.metadata)
assert set(funcs.keys()) == {"custom_vector"}
class TestModel(lancedb.pydantic.LanceModel):
a: Optional[int]
b: Optional[int]