mirror of
https://github.com/lancedb/lancedb.git
synced 2026-04-09 17:30:41 +00:00
Compare commits
3 Commits
main
...
xuanwo/per
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7c37ba216a | ||
|
|
768d84845c | ||
|
|
2d380d1669 |
@@ -1,21 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
import json
|
||||
import pickle
|
||||
from datetime import timedelta
|
||||
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import pyarrow as pa
|
||||
from deprecation import deprecated
|
||||
from lancedb import AsyncConnection, DBConnection
|
||||
import pyarrow as pa
|
||||
import json
|
||||
|
||||
from ._lancedb import async_permutation_builder, PermutationReader
|
||||
from .table import LanceTable
|
||||
from .background_loop import LOOP
|
||||
from .table import LanceTable
|
||||
from .util import batch_to_tensor, batch_to_tensor_rows
|
||||
from typing import Any, Callable, Iterator, Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lancedb.dependencies import pandas as pd, numpy as np, polars as pl
|
||||
|
||||
|
||||
def _builtin_transform(format: str) -> Callable[[pa.RecordBatch], Any]:
|
||||
if format == "python":
|
||||
return Transforms.arrow2python
|
||||
if format == "python_col":
|
||||
return Transforms.arrow2pythoncol
|
||||
if format == "numpy":
|
||||
return Transforms.arrow2numpy
|
||||
if format == "pandas":
|
||||
return Transforms.arrow2pandas
|
||||
if format == "arrow":
|
||||
return Transforms.arrow2arrow
|
||||
if format == "torch":
|
||||
return batch_to_tensor_rows
|
||||
if format == "torch_col":
|
||||
return batch_to_tensor
|
||||
if format == "polars":
|
||||
return Transforms.arrow2polars()
|
||||
raise ValueError(f"Invalid format: {format}")
|
||||
|
||||
|
||||
def _table_to_state(
|
||||
table: Union[LanceTable, dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
if isinstance(table, dict):
|
||||
return table
|
||||
if not isinstance(table, LanceTable):
|
||||
raise pickle.PicklingError(
|
||||
"Permutation pickling only supports LanceTable-backed permutations"
|
||||
)
|
||||
if table._namespace_client is not None:
|
||||
raise pickle.PicklingError(
|
||||
"Permutation pickling does not yet support namespace-backed tables"
|
||||
)
|
||||
if table._conn.uri.startswith("memory://"):
|
||||
raise pickle.PicklingError(
|
||||
"Permutation pickling does not support in-memory databases"
|
||||
)
|
||||
|
||||
try:
|
||||
read_consistency_interval = table._conn.read_consistency_interval
|
||||
except Exception:
|
||||
read_consistency_interval = None
|
||||
return {
|
||||
"uri": table._conn.uri,
|
||||
"name": table.name,
|
||||
"version": table.version,
|
||||
"storage_options": table.initial_storage_options(),
|
||||
"read_consistency_interval_secs": (
|
||||
read_consistency_interval.total_seconds()
|
||||
if read_consistency_interval is not None
|
||||
else None
|
||||
),
|
||||
"namespace_path": list(table.namespace),
|
||||
}
|
||||
|
||||
|
||||
def _table_from_state(state: dict[str, Any]) -> LanceTable:
|
||||
from . import connect
|
||||
|
||||
read_consistency_interval = (
|
||||
timedelta(seconds=state["read_consistency_interval_secs"])
|
||||
if state["read_consistency_interval_secs"] is not None
|
||||
else None
|
||||
)
|
||||
db = connect(
|
||||
state["uri"],
|
||||
read_consistency_interval=read_consistency_interval,
|
||||
storage_options=state["storage_options"],
|
||||
)
|
||||
table = db.open_table(state["name"], namespace_path=state["namespace_path"])
|
||||
table.checkout(state["version"])
|
||||
return table
|
||||
|
||||
|
||||
class PermutationBuilder:
|
||||
"""
|
||||
A utility for creating a "permutation table" which is a table that defines an
|
||||
@@ -284,8 +361,9 @@ class Permutations:
|
||||
self.permutation_table = permutation_table
|
||||
|
||||
if permutation_table.schema.metadata is not None:
|
||||
raw = permutation_table.schema.metadata.get(b"split_names")
|
||||
split_names = raw.decode("utf-8") if raw is not None else None
|
||||
split_names = permutation_table.schema.metadata.get(
|
||||
b"split_names", None
|
||||
).decode("utf-8")
|
||||
if split_names is not None:
|
||||
self.split_names = json.loads(split_names)
|
||||
self.split_dict = {
|
||||
@@ -384,6 +462,13 @@ class Permutation:
|
||||
selection: dict[str, str],
|
||||
batch_size: int,
|
||||
transform_fn: Callable[pa.RecordBatch, Any],
|
||||
*,
|
||||
base_table: Union[LanceTable, dict[str, Any]],
|
||||
permutation_table: Optional[Union[LanceTable, dict[str, Any]]],
|
||||
split: int,
|
||||
offset: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
transform_spec: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Internal constructor. Use [from_tables](#from_tables) instead.
|
||||
@@ -394,6 +479,93 @@ class Permutation:
|
||||
self.selection = selection
|
||||
self.transform_fn = transform_fn
|
||||
self.batch_size = batch_size
|
||||
self._transform_spec = transform_spec
|
||||
# These fields are used to reconstruct the permutation in a new process.
|
||||
self._base_table = base_table
|
||||
self._permutation_table = permutation_table
|
||||
self._split = split
|
||||
self._offset = offset
|
||||
self._limit = limit
|
||||
|
||||
def _reopen_metadata(self) -> dict[str, Any]:
|
||||
return {
|
||||
"base_table": self._base_table,
|
||||
"permutation_table": self._permutation_table,
|
||||
"split": self._split,
|
||||
"offset": self._offset,
|
||||
"limit": self._limit,
|
||||
"transform_spec": self._transform_spec,
|
||||
}
|
||||
|
||||
def __getstate__(self) -> dict[str, Any]:
|
||||
if self._transform_spec is not None:
|
||||
transform_state = {
|
||||
"kind": "builtin",
|
||||
"format": self._transform_spec,
|
||||
}
|
||||
else:
|
||||
transform_state = {
|
||||
"kind": "callable",
|
||||
"transform_fn": self.transform_fn,
|
||||
}
|
||||
|
||||
return {
|
||||
"selection": self.selection,
|
||||
"batch_size": self.batch_size,
|
||||
"transform": transform_state,
|
||||
"reopen": {
|
||||
**self._reopen_metadata(),
|
||||
# Store reopen state instead of live LanceTable handles.
|
||||
"base_table": _table_to_state(self._base_table),
|
||||
"permutation_table": (
|
||||
_table_to_state(self._permutation_table)
|
||||
if self._permutation_table is not None
|
||||
else None
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict[str, Any]) -> None:
|
||||
reopen = state["reopen"]
|
||||
base_table = _table_from_state(reopen["base_table"])
|
||||
permutation_table_state = reopen["permutation_table"]
|
||||
permutation_table = (
|
||||
_table_from_state(permutation_table_state)
|
||||
if permutation_table_state is not None
|
||||
else None
|
||||
)
|
||||
split = reopen["split"]
|
||||
offset = reopen["offset"]
|
||||
limit = reopen["limit"]
|
||||
|
||||
async def do_reopen():
|
||||
reader = await PermutationReader.from_tables(
|
||||
base_table, permutation_table, split
|
||||
)
|
||||
if offset is not None:
|
||||
reader = await reader.with_offset(offset)
|
||||
if limit is not None:
|
||||
reader = await reader.with_limit(limit)
|
||||
return reader
|
||||
|
||||
transform = state["transform"]
|
||||
if transform["kind"] == "builtin":
|
||||
transform_spec = transform["format"]
|
||||
transform_fn = _builtin_transform(transform_spec)
|
||||
else:
|
||||
transform_spec = None
|
||||
transform_fn = transform["transform_fn"]
|
||||
|
||||
self.reader = LOOP.run(do_reopen())
|
||||
self.selection = state["selection"]
|
||||
self.batch_size = state["batch_size"]
|
||||
self.transform_fn = transform_fn
|
||||
self._transform_spec = transform_spec
|
||||
self._base_table = reopen["base_table"]
|
||||
self._permutation_table = permutation_table_state
|
||||
self._split = split
|
||||
self._offset = offset
|
||||
self._limit = limit
|
||||
|
||||
def _with_selection(self, selection: dict[str, str]) -> "Permutation":
|
||||
"""
|
||||
@@ -402,7 +574,13 @@ class Permutation:
|
||||
Does not validation of the selection and it replaces it entirely. This is not
|
||||
intended for public use.
|
||||
"""
|
||||
return Permutation(self.reader, selection, self.batch_size, self.transform_fn)
|
||||
return Permutation(
|
||||
self.reader,
|
||||
selection,
|
||||
self.batch_size,
|
||||
self.transform_fn,
|
||||
**self._reopen_metadata(),
|
||||
)
|
||||
|
||||
def _with_reader(self, reader: PermutationReader) -> "Permutation":
|
||||
"""
|
||||
@@ -410,13 +588,25 @@ class Permutation:
|
||||
|
||||
This is an internal method and should not be used directly.
|
||||
"""
|
||||
return Permutation(reader, self.selection, self.batch_size, self.transform_fn)
|
||||
return Permutation(
|
||||
reader,
|
||||
self.selection,
|
||||
self.batch_size,
|
||||
self.transform_fn,
|
||||
**self._reopen_metadata(),
|
||||
)
|
||||
|
||||
def with_batch_size(self, batch_size: int) -> "Permutation":
|
||||
"""
|
||||
Creates a new permutation with the given batch size
|
||||
"""
|
||||
return Permutation(self.reader, self.selection, batch_size, self.transform_fn)
|
||||
return Permutation(
|
||||
self.reader,
|
||||
self.selection,
|
||||
batch_size,
|
||||
self.transform_fn,
|
||||
**self._reopen_metadata(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def identity(cls, table: LanceTable) -> "Permutation":
|
||||
@@ -459,8 +649,9 @@ class Permutation:
|
||||
f"Cannot create a permutation on split `{split}`"
|
||||
" because no split names are defined in the permutation table"
|
||||
)
|
||||
raw = permutation_table.schema.metadata.get(b"split_names")
|
||||
split_names = raw.decode("utf-8") if raw is not None else None
|
||||
split_names = permutation_table.schema.metadata.get(
|
||||
b"split_names", None
|
||||
).decode("utf-8")
|
||||
if split_names is None:
|
||||
raise ValueError(
|
||||
f"Cannot create a permutation on split `{split}`"
|
||||
@@ -489,7 +680,14 @@ class Permutation:
|
||||
schema = await reader.output_schema(None)
|
||||
initial_selection = {name: name for name in schema.names}
|
||||
return cls(
|
||||
reader, initial_selection, DEFAULT_BATCH_SIZE, Transforms.arrow2python
|
||||
reader,
|
||||
initial_selection,
|
||||
DEFAULT_BATCH_SIZE,
|
||||
Transforms.arrow2python,
|
||||
base_table=base_table,
|
||||
permutation_table=permutation_table,
|
||||
split=split,
|
||||
transform_spec="python",
|
||||
)
|
||||
|
||||
return LOOP.run(do_from_tables())
|
||||
@@ -730,24 +928,16 @@ class Permutation:
|
||||
this method.
|
||||
"""
|
||||
assert format is not None, "format is required"
|
||||
if format == "python":
|
||||
return self.with_transform(Transforms.arrow2python)
|
||||
if format == "python_col":
|
||||
return self.with_transform(Transforms.arrow2pythoncol)
|
||||
elif format == "numpy":
|
||||
return self.with_transform(Transforms.arrow2numpy)
|
||||
elif format == "pandas":
|
||||
return self.with_transform(Transforms.arrow2pandas)
|
||||
elif format == "arrow":
|
||||
return self.with_transform(Transforms.arrow2arrow)
|
||||
elif format == "torch":
|
||||
return self.with_transform(batch_to_tensor_rows)
|
||||
elif format == "torch_col":
|
||||
return self.with_transform(batch_to_tensor)
|
||||
elif format == "polars":
|
||||
return self.with_transform(Transforms.arrow2polars())
|
||||
else:
|
||||
raise ValueError(f"Invalid format: {format}")
|
||||
return Permutation(
|
||||
self.reader,
|
||||
self.selection,
|
||||
self.batch_size,
|
||||
_builtin_transform(format),
|
||||
**{
|
||||
**self._reopen_metadata(),
|
||||
"transform_spec": format,
|
||||
},
|
||||
)
|
||||
|
||||
def with_transform(self, transform: Callable[pa.RecordBatch, Any]) -> "Permutation":
|
||||
"""
|
||||
@@ -760,7 +950,16 @@ class Permutation:
|
||||
for expensive operations such as image decoding.
|
||||
"""
|
||||
assert transform is not None, "transform is required"
|
||||
return Permutation(self.reader, self.selection, self.batch_size, transform)
|
||||
return Permutation(
|
||||
self.reader,
|
||||
self.selection,
|
||||
self.batch_size,
|
||||
transform,
|
||||
**{
|
||||
**self._reopen_metadata(),
|
||||
"transform_spec": None,
|
||||
},
|
||||
)
|
||||
|
||||
def __getitem__(self, index: int) -> Any:
|
||||
"""
|
||||
@@ -798,7 +997,16 @@ class Permutation:
|
||||
|
||||
async def do_with_skip():
|
||||
reader = await self.reader.with_offset(skip)
|
||||
return self._with_reader(reader)
|
||||
return Permutation(
|
||||
reader,
|
||||
self.selection,
|
||||
self.batch_size,
|
||||
self.transform_fn,
|
||||
**{
|
||||
**self._reopen_metadata(),
|
||||
"offset": skip,
|
||||
},
|
||||
)
|
||||
|
||||
return LOOP.run(do_with_skip())
|
||||
|
||||
@@ -821,7 +1029,16 @@ class Permutation:
|
||||
|
||||
async def do_with_take():
|
||||
reader = await self.reader.with_limit(limit)
|
||||
return self._with_reader(reader)
|
||||
return Permutation(
|
||||
reader,
|
||||
self.selection,
|
||||
self.batch_size,
|
||||
self.transform_fn,
|
||||
**{
|
||||
**self._reopen_metadata(),
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
|
||||
return LOOP.run(do_with_take())
|
||||
|
||||
|
||||
@@ -270,17 +270,15 @@ def _sanitize_data(
|
||||
reader,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
target_schema=target_schema,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
if target_schema is None:
|
||||
target_schema, reader = _infer_target_schema(reader)
|
||||
|
||||
if metadata:
|
||||
target_schema = target_schema.with_metadata(
|
||||
_merge_metadata(target_schema.metadata, metadata)
|
||||
)
|
||||
new_metadata = target_schema.metadata or {}
|
||||
new_metadata.update(metadata)
|
||||
target_schema = target_schema.with_metadata(new_metadata)
|
||||
|
||||
_validate_schema(target_schema)
|
||||
reader = _cast_to_target_schema(reader, target_schema, allow_subschema)
|
||||
@@ -296,7 +294,7 @@ def _cast_to_target_schema(
|
||||
# pa.Table.cast expects field order not to be changed.
|
||||
# Lance doesn't care about field order, so we don't need to rearrange fields
|
||||
# to match the target schema. We just need to correctly cast the fields.
|
||||
if reader.schema.equals(target_schema, check_metadata=True):
|
||||
if reader.schema == target_schema:
|
||||
# Fast path when the schemas are already the same
|
||||
return reader
|
||||
|
||||
@@ -316,13 +314,7 @@ def _cast_to_target_schema(
|
||||
def gen():
|
||||
for batch in reader:
|
||||
# Table but not RecordBatch has cast.
|
||||
cast_batches = (
|
||||
pa.Table.from_batches([batch]).cast(reordered_schema).to_batches()
|
||||
)
|
||||
if cast_batches:
|
||||
yield pa.RecordBatch.from_arrays(
|
||||
cast_batches[0].columns, schema=reordered_schema
|
||||
)
|
||||
yield pa.Table.from_batches([batch]).cast(reordered_schema).to_batches()[0]
|
||||
|
||||
return pa.RecordBatchReader.from_batches(reordered_schema, gen())
|
||||
|
||||
@@ -340,51 +332,37 @@ def _align_field_types(
|
||||
if target_field is None:
|
||||
raise ValueError(f"Field '{field.name}' not found in target schema")
|
||||
if pa.types.is_struct(target_field.type):
|
||||
if pa.types.is_struct(field.type):
|
||||
new_type = pa.struct(
|
||||
_align_field_types(
|
||||
field.type.fields,
|
||||
target_field.type.fields,
|
||||
)
|
||||
new_type = pa.struct(
|
||||
_align_field_types(
|
||||
field.type.fields,
|
||||
target_field.type.fields,
|
||||
)
|
||||
else:
|
||||
new_type = target_field.type
|
||||
)
|
||||
elif pa.types.is_list(target_field.type):
|
||||
if _is_list_like(field.type):
|
||||
new_type = pa.list_(
|
||||
_align_field_types(
|
||||
[field.type.value_field],
|
||||
[target_field.type.value_field],
|
||||
)[0]
|
||||
)
|
||||
else:
|
||||
new_type = target_field.type
|
||||
new_type = pa.list_(
|
||||
_align_field_types(
|
||||
[field.type.value_field],
|
||||
[target_field.type.value_field],
|
||||
)[0]
|
||||
)
|
||||
elif pa.types.is_large_list(target_field.type):
|
||||
if _is_list_like(field.type):
|
||||
new_type = pa.large_list(
|
||||
_align_field_types(
|
||||
[field.type.value_field],
|
||||
[target_field.type.value_field],
|
||||
)[0]
|
||||
)
|
||||
else:
|
||||
new_type = target_field.type
|
||||
new_type = pa.large_list(
|
||||
_align_field_types(
|
||||
[field.type.value_field],
|
||||
[target_field.type.value_field],
|
||||
)[0]
|
||||
)
|
||||
elif pa.types.is_fixed_size_list(target_field.type):
|
||||
if _is_list_like(field.type):
|
||||
new_type = pa.list_(
|
||||
_align_field_types(
|
||||
[field.type.value_field],
|
||||
[target_field.type.value_field],
|
||||
)[0],
|
||||
target_field.type.list_size,
|
||||
)
|
||||
else:
|
||||
new_type = target_field.type
|
||||
new_type = pa.list_(
|
||||
_align_field_types(
|
||||
[field.type.value_field],
|
||||
[target_field.type.value_field],
|
||||
)[0],
|
||||
target_field.type.list_size,
|
||||
)
|
||||
else:
|
||||
new_type = target_field.type
|
||||
new_fields.append(
|
||||
pa.field(field.name, new_type, field.nullable, target_field.metadata)
|
||||
)
|
||||
new_fields.append(pa.field(field.name, new_type, field.nullable))
|
||||
return new_fields
|
||||
|
||||
|
||||
@@ -462,7 +440,6 @@ def sanitize_create_table(
|
||||
schema = data.schema
|
||||
|
||||
if metadata:
|
||||
metadata = _merge_metadata(schema.metadata, metadata)
|
||||
schema = schema.with_metadata(metadata)
|
||||
# Need to apply metadata to the data as well
|
||||
if isinstance(data, pa.Table):
|
||||
@@ -515,9 +492,9 @@ def _append_vector_columns(
|
||||
vector columns to the table.
|
||||
"""
|
||||
if schema is None:
|
||||
metadata = _merge_metadata(metadata)
|
||||
metadata = metadata or {}
|
||||
else:
|
||||
metadata = _merge_metadata(schema.metadata, metadata)
|
||||
metadata = schema.metadata or metadata or {}
|
||||
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
|
||||
|
||||
if not functions:
|
||||
@@ -3234,157 +3211,43 @@ def _handle_bad_vectors(
|
||||
reader: pa.RecordBatchReader,
|
||||
on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
|
||||
fill_value: float = 0.0,
|
||||
target_schema: Optional[pa.Schema] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> pa.RecordBatchReader:
|
||||
vector_columns = _find_vector_columns(reader.schema, target_schema, metadata)
|
||||
if not vector_columns:
|
||||
return reader
|
||||
vector_columns = []
|
||||
|
||||
output_schema = _vector_output_schema(reader.schema, vector_columns)
|
||||
for field in reader.schema:
|
||||
# They can provide a 'vector' column that isn't yet a FSL
|
||||
named_vector_col = (
|
||||
(
|
||||
pa.types.is_list(field.type)
|
||||
or pa.types.is_large_list(field.type)
|
||||
or pa.types.is_fixed_size_list(field.type)
|
||||
)
|
||||
and pa.types.is_floating(field.type.value_type)
|
||||
and field.name == VECTOR_COLUMN_NAME
|
||||
)
|
||||
# TODO: we're making an assumption that fixed size list of 10 or more
|
||||
# is a vector column. This is definitely a bit hacky.
|
||||
likely_vector_col = (
|
||||
pa.types.is_fixed_size_list(field.type)
|
||||
and pa.types.is_floating(field.type.value_type)
|
||||
and (field.type.list_size >= 10)
|
||||
)
|
||||
|
||||
if named_vector_col or likely_vector_col:
|
||||
vector_columns.append(field.name)
|
||||
|
||||
def gen():
|
||||
for batch in reader:
|
||||
pending_dims = []
|
||||
for vector_column in vector_columns:
|
||||
dim = vector_column["expected_dim"]
|
||||
if target_schema is not None and dim is None:
|
||||
dim = _infer_vector_dim(batch[vector_column["name"]])
|
||||
pending_dims.append(vector_column)
|
||||
for name in vector_columns:
|
||||
batch = _handle_bad_vector_column(
|
||||
batch,
|
||||
vector_column_name=vector_column["name"],
|
||||
vector_column_name=name,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
expected_dim=dim,
|
||||
expected_value_type=vector_column["expected_value_type"],
|
||||
)
|
||||
for vector_column in pending_dims:
|
||||
if vector_column["expected_dim"] is None:
|
||||
vector_column["expected_dim"] = _infer_vector_dim(
|
||||
batch[vector_column["name"]]
|
||||
)
|
||||
if batch.schema.equals(output_schema, check_metadata=True):
|
||||
yield batch
|
||||
continue
|
||||
yield batch
|
||||
|
||||
cast_batches = (
|
||||
pa.Table.from_batches([batch]).cast(output_schema).to_batches()
|
||||
)
|
||||
if cast_batches:
|
||||
yield pa.RecordBatch.from_arrays(
|
||||
cast_batches[0].columns,
|
||||
schema=output_schema,
|
||||
)
|
||||
|
||||
return pa.RecordBatchReader.from_batches(output_schema, gen())
|
||||
|
||||
|
||||
def _find_vector_columns(
|
||||
reader_schema: pa.Schema,
|
||||
target_schema: Optional[pa.Schema],
|
||||
metadata: Optional[dict],
|
||||
) -> List[dict]:
|
||||
if target_schema is None:
|
||||
vector_columns = []
|
||||
for field in reader_schema:
|
||||
named_vector_col = (
|
||||
_is_list_like(field.type)
|
||||
and pa.types.is_floating(field.type.value_type)
|
||||
and field.name == VECTOR_COLUMN_NAME
|
||||
)
|
||||
likely_vector_col = (
|
||||
pa.types.is_fixed_size_list(field.type)
|
||||
and pa.types.is_floating(field.type.value_type)
|
||||
and (field.type.list_size >= 10)
|
||||
)
|
||||
if named_vector_col or likely_vector_col:
|
||||
vector_columns.append(
|
||||
{
|
||||
"name": field.name,
|
||||
"expected_dim": None,
|
||||
"expected_value_type": None,
|
||||
}
|
||||
)
|
||||
return vector_columns
|
||||
|
||||
reader_column_names = set(reader_schema.names)
|
||||
active_metadata = _merge_metadata(target_schema.metadata, metadata)
|
||||
embedding_function_columns = set(
|
||||
EmbeddingFunctionRegistry.get_instance().parse_functions(active_metadata).keys()
|
||||
)
|
||||
vector_columns = []
|
||||
for field in target_schema:
|
||||
if field.name not in reader_column_names:
|
||||
continue
|
||||
if not _is_list_like(field.type) or not pa.types.is_floating(
|
||||
field.type.value_type
|
||||
):
|
||||
continue
|
||||
|
||||
reader_field = reader_schema.field(field.name)
|
||||
named_vector_col = (
|
||||
field.name in embedding_function_columns
|
||||
or field.name == VECTOR_COLUMN_NAME
|
||||
or (field.name == "embedding" and pa.types.is_fixed_size_list(field.type))
|
||||
)
|
||||
typed_fixed_vector_col = (
|
||||
pa.types.is_fixed_size_list(reader_field.type)
|
||||
and pa.types.is_floating(reader_field.type.value_type)
|
||||
and reader_field.type.list_size >= 10
|
||||
)
|
||||
|
||||
if named_vector_col or typed_fixed_vector_col:
|
||||
vector_columns.append(
|
||||
{
|
||||
"name": field.name,
|
||||
"expected_dim": (
|
||||
field.type.list_size
|
||||
if pa.types.is_fixed_size_list(field.type)
|
||||
else None
|
||||
),
|
||||
"expected_value_type": field.type.value_type,
|
||||
}
|
||||
)
|
||||
|
||||
return vector_columns
|
||||
|
||||
|
||||
def _vector_output_schema(
|
||||
reader_schema: pa.Schema,
|
||||
vector_columns: List[dict],
|
||||
) -> pa.Schema:
|
||||
columns_by_name = {column["name"]: column for column in vector_columns}
|
||||
fields = []
|
||||
for field in reader_schema:
|
||||
column = columns_by_name.get(field.name)
|
||||
if column is None:
|
||||
output_type = field.type
|
||||
else:
|
||||
output_type = _vector_output_type(field, column)
|
||||
fields.append(pa.field(field.name, output_type, field.nullable, field.metadata))
|
||||
return pa.schema(fields, metadata=reader_schema.metadata)
|
||||
|
||||
|
||||
def _vector_output_type(field: pa.Field, vector_column: dict) -> pa.DataType:
|
||||
if not _is_list_like(field.type):
|
||||
return field.type
|
||||
|
||||
if vector_column["expected_value_type"] is not None and (
|
||||
pa.types.is_null(field.type.value_type)
|
||||
or pa.types.is_integer(field.type.value_type)
|
||||
or pa.types.is_unsigned_integer(field.type.value_type)
|
||||
):
|
||||
return pa.list_(vector_column["expected_value_type"])
|
||||
|
||||
if (
|
||||
vector_column["expected_dim"] is not None
|
||||
and pa.types.is_fixed_size_list(field.type)
|
||||
and field.type.list_size != vector_column["expected_dim"]
|
||||
):
|
||||
return pa.list_(field.type.value_type)
|
||||
|
||||
return field.type
|
||||
return pa.RecordBatchReader.from_batches(reader.schema, gen())
|
||||
|
||||
|
||||
def _handle_bad_vector_column(
|
||||
@@ -3392,8 +3255,6 @@ def _handle_bad_vector_column(
|
||||
vector_column_name: str,
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
expected_dim: Optional[int] = None,
|
||||
expected_value_type: Optional[pa.DataType] = None,
|
||||
) -> pa.RecordBatch:
|
||||
"""
|
||||
Ensure that the vector column exists and has type fixed_size_list(float)
|
||||
@@ -3410,39 +3271,14 @@ def _handle_bad_vector_column(
|
||||
fill_value: float, default 0.0
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
position = data.column_names.index(vector_column_name)
|
||||
vec_arr = data[vector_column_name]
|
||||
if not _is_list_like(vec_arr.type):
|
||||
return data
|
||||
|
||||
if (
|
||||
expected_dim is not None
|
||||
and pa.types.is_fixed_size_list(vec_arr.type)
|
||||
and vec_arr.type.list_size != expected_dim
|
||||
):
|
||||
vec_arr = pa.array(vec_arr.to_pylist(), type=pa.list_(vec_arr.type.value_type))
|
||||
data = data.set_column(position, vector_column_name, vec_arr)
|
||||
has_nan = has_nan_values(vec_arr)
|
||||
|
||||
if expected_value_type is not None and (
|
||||
pa.types.is_integer(vec_arr.type.value_type)
|
||||
or pa.types.is_unsigned_integer(vec_arr.type.value_type)
|
||||
):
|
||||
vec_arr = pa.array(vec_arr.to_pylist(), type=pa.list_(expected_value_type))
|
||||
data = data.set_column(position, vector_column_name, vec_arr)
|
||||
|
||||
if pa.types.is_floating(vec_arr.type.value_type):
|
||||
has_nan = has_nan_values(vec_arr)
|
||||
else:
|
||||
has_nan = pa.array([False] * len(vec_arr))
|
||||
|
||||
if expected_dim is not None:
|
||||
dim = expected_dim
|
||||
elif pa.types.is_fixed_size_list(vec_arr.type):
|
||||
if pa.types.is_fixed_size_list(vec_arr.type):
|
||||
dim = vec_arr.type.list_size
|
||||
else:
|
||||
dim = _infer_vector_dim(vec_arr)
|
||||
if dim is None:
|
||||
return data
|
||||
dim = _modal_list_size(vec_arr)
|
||||
has_wrong_dim = pc.not_equal(pc.list_value_length(vec_arr), dim)
|
||||
|
||||
has_bad_vectors = pc.any(has_nan).as_py() or pc.any(has_wrong_dim).as_py()
|
||||
@@ -3480,12 +3316,13 @@ def _handle_bad_vector_column(
|
||||
)
|
||||
vec_arr = pc.if_else(
|
||||
is_bad,
|
||||
pa.scalar([fill_value] * dim, type=vec_arr.type),
|
||||
pa.scalar([fill_value] * dim),
|
||||
vec_arr,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid value for on_bad_vectors: {on_bad_vectors}")
|
||||
|
||||
position = data.column_names.index(vector_column_name)
|
||||
return data.set_column(position, vector_column_name, vec_arr)
|
||||
|
||||
|
||||
@@ -3506,28 +3343,6 @@ def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray
|
||||
return pc.is_in(indices, has_nan_indices)
|
||||
|
||||
|
||||
def _is_list_like(data_type: pa.DataType) -> bool:
|
||||
return (
|
||||
pa.types.is_list(data_type)
|
||||
or pa.types.is_large_list(data_type)
|
||||
or pa.types.is_fixed_size_list(data_type)
|
||||
)
|
||||
|
||||
|
||||
def _merge_metadata(*metadata_dicts: Optional[dict]) -> dict:
|
||||
merged = {}
|
||||
for metadata in metadata_dicts:
|
||||
if metadata is None:
|
||||
continue
|
||||
for key, value in metadata.items():
|
||||
if isinstance(key, str):
|
||||
key = key.encode("utf-8")
|
||||
if isinstance(value, str):
|
||||
value = value.encode("utf-8")
|
||||
merged[key] = value
|
||||
return merged
|
||||
|
||||
|
||||
def _name_suggests_vector_column(field_name: str) -> bool:
|
||||
"""Check if a field name indicates a vector column."""
|
||||
name_lower = field_name.lower()
|
||||
@@ -3595,16 +3410,6 @@ def _modal_list_size(arr: Union[pa.ListArray, pa.ChunkedArray]) -> int:
|
||||
return pc.mode(pc.list_value_length(arr))[0].as_py()["mode"]
|
||||
|
||||
|
||||
def _infer_vector_dim(arr: Union[pa.Array, pa.ChunkedArray]) -> Optional[int]:
|
||||
if not _is_list_like(arr.type):
|
||||
return None
|
||||
lengths = pc.list_value_length(arr)
|
||||
lengths = pc.filter(lengths, pc.greater(lengths, 0))
|
||||
if len(lengths) == 0:
|
||||
return None
|
||||
return pc.mode(lengths)[0].as_py()["mode"]
|
||||
|
||||
|
||||
def _validate_schema(schema: pa.Schema):
|
||||
"""
|
||||
Make sure the metadata is valid utf8
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import pyarrow as pa
|
||||
import math
|
||||
import pickle
|
||||
import pytest
|
||||
|
||||
from lancedb import DBConnection, Table, connect
|
||||
@@ -522,50 +523,6 @@ def test_no_split_names(some_table: Table):
|
||||
assert permutations[1].num_rows == 500
|
||||
|
||||
|
||||
def test_permutations_metadata_without_split_names_key(mem_db: DBConnection):
|
||||
"""Regression: schema metadata present but missing split_names key must not crash.
|
||||
|
||||
Previously, `.get(b"split_names", None).decode()` was called unconditionally,
|
||||
so any permutation table whose metadata dict had other keys but no split_names
|
||||
raised AttributeError: 'NoneType' has no attribute 'decode'.
|
||||
"""
|
||||
base = mem_db.create_table("base_nosplit", pa.table({"x": range(10)}))
|
||||
|
||||
# Build a permutation-like table that carries some metadata but NOT split_names.
|
||||
raw = pa.table(
|
||||
{
|
||||
"row_id": pa.array(range(10), type=pa.uint64()),
|
||||
"split_id": pa.array([0] * 10, type=pa.uint32()),
|
||||
}
|
||||
).replace_schema_metadata({b"other_key": b"other_value"})
|
||||
perm_tbl = mem_db.create_table("perm_nosplit", raw)
|
||||
|
||||
permutations = Permutations(base, perm_tbl)
|
||||
assert permutations.split_names == []
|
||||
assert permutations.split_dict == {}
|
||||
|
||||
|
||||
def test_from_tables_string_split_missing_names_key(mem_db: DBConnection):
|
||||
"""Regression: from_tables() with a string split must raise ValueError, not
|
||||
AttributeError.
|
||||
|
||||
Previously, `.get(b"split_names", None).decode()` crashed with AttributeError
|
||||
when the metadata dict existed but had no split_names key.
|
||||
"""
|
||||
base = mem_db.create_table("base_strsplit", pa.table({"x": range(10)}))
|
||||
|
||||
raw = pa.table(
|
||||
{
|
||||
"row_id": pa.array(range(10), type=pa.uint64()),
|
||||
"split_id": pa.array([0] * 10, type=pa.uint32()),
|
||||
}
|
||||
).replace_schema_metadata({b"other_key": b"other_value"})
|
||||
perm_tbl = mem_db.create_table("perm_strsplit", raw)
|
||||
|
||||
with pytest.raises(ValueError, match="no split names are defined"):
|
||||
Permutation.from_tables(base, perm_tbl, split="train")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def some_perm_table(some_table: Table) -> Table:
|
||||
return (
|
||||
@@ -643,6 +600,87 @@ def test_limit_offset(some_permutation: Permutation):
|
||||
some_permutation.with_skip(500).with_take(500).num_rows
|
||||
|
||||
|
||||
def test_permutation_pickle_rejects_in_memory_tables(mem_db: DBConnection):
|
||||
table = mem_db.create_table("identity_table", pa.table({"id": range(10)}))
|
||||
permutation = Permutation.identity(table)
|
||||
|
||||
with pytest.raises(
|
||||
pickle.PicklingError,
|
||||
match="in-memory databases",
|
||||
):
|
||||
pickle.dumps(permutation)
|
||||
|
||||
|
||||
def test_identity_permutation_pickle_roundtrip_preserves_table_version(tmp_path):
|
||||
db = connect(tmp_path)
|
||||
table = db.create_table(
|
||||
"identity_table",
|
||||
pa.table({"id": range(10), "value": range(10)}),
|
||||
)
|
||||
permutation = (
|
||||
Permutation.identity(table).with_skip(2).with_take(3).with_format("python_col")
|
||||
)
|
||||
|
||||
payload = pickle.dumps(permutation)
|
||||
table.add(pa.table({"id": [10], "value": [10]}))
|
||||
|
||||
restored = pickle.loads(payload)
|
||||
assert restored.num_rows == 3
|
||||
batches = list(restored.iter(10, skip_last_batch=False))
|
||||
assert batches == [{"id": [2, 3, 4], "value": [2, 3, 4]}]
|
||||
|
||||
|
||||
def test_permutation_pickle_roundtrip_with_persisted_permutation_table(tmp_path):
|
||||
db = connect(tmp_path)
|
||||
table = db.create_table(
|
||||
"base_table",
|
||||
pa.table({"id": range(1000), "value": range(1000)}),
|
||||
)
|
||||
permutation_table = (
|
||||
permutation_builder(table)
|
||||
.split_random(ratios=[0.95, 0.05], seed=42, split_names=["train", "test"])
|
||||
.shuffle(seed=42)
|
||||
.persist(db, "persisted_permutation")
|
||||
.execute()
|
||||
)
|
||||
permutation = (
|
||||
Permutation.from_tables(table, permutation_table, "test")
|
||||
.select_columns(["id"])
|
||||
.rename_column("id", "row_id")
|
||||
.with_batch_size(32)
|
||||
.with_skip(5)
|
||||
.with_take(10)
|
||||
.with_format("arrow")
|
||||
)
|
||||
|
||||
restored = pickle.loads(pickle.dumps(permutation))
|
||||
|
||||
assert restored.batch_size == 32
|
||||
assert restored.column_names == ["row_id"]
|
||||
assert restored.num_rows == 10
|
||||
assert (
|
||||
restored.__getitems__([0, 1, 2]).to_pylist()
|
||||
== permutation.__getitems__([0, 1, 2]).to_pylist()
|
||||
)
|
||||
|
||||
|
||||
def test_permutation_pickle_roundtrip_preserves_builtin_polars_format(tmp_path):
|
||||
pl = pytest.importorskip("polars")
|
||||
|
||||
db = connect(tmp_path)
|
||||
table = db.create_table(
|
||||
"polars_table",
|
||||
pa.table({"id": range(5), "value": range(5)}),
|
||||
)
|
||||
permutation = Permutation.identity(table).with_take(2).with_format("polars")
|
||||
|
||||
restored = pickle.loads(pickle.dumps(permutation))
|
||||
batch = restored.__getitems__([0, 1])
|
||||
|
||||
assert isinstance(batch, pl.DataFrame)
|
||||
assert batch.to_dict(as_series=False) == {"id": [0, 1], "value": [0, 1]}
|
||||
|
||||
|
||||
def test_remove_columns(some_permutation: Permutation):
|
||||
assert some_permutation.remove_columns(["value"]).schema == pa.schema(
|
||||
[("id", pa.int64())]
|
||||
|
||||
@@ -1049,231 +1049,6 @@ def test_add_with_nans(mem_db: DBConnection):
|
||||
assert np.allclose(v, np.array([0.0, 0.0]))
|
||||
|
||||
|
||||
def test_add_with_empty_fixed_size_list_drops_bad_rows(mem_db: DBConnection):
|
||||
class Schema(LanceModel):
|
||||
text: str
|
||||
embedding: Vector(16)
|
||||
|
||||
table = mem_db.create_table("test_empty_embeddings", schema=Schema)
|
||||
table.add(
|
||||
[
|
||||
{"text": "hello", "embedding": []},
|
||||
{"text": "bar", "embedding": [0.1] * 16},
|
||||
],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
|
||||
data = table.to_arrow()
|
||||
assert data["text"].to_pylist() == ["bar"]
|
||||
assert np.allclose(data["embedding"].to_pylist()[0], np.array([0.1] * 16))
|
||||
|
||||
|
||||
def test_add_with_integer_embeddings_preserves_casting(mem_db: DBConnection):
|
||||
class Schema(LanceModel):
|
||||
text: str
|
||||
embedding: Vector(4)
|
||||
|
||||
table = mem_db.create_table("test_integer_embeddings", schema=Schema)
|
||||
table.add(
|
||||
[{"text": "foo", "embedding": [1, 2, 3, 4]}],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
|
||||
assert table.to_arrow()["embedding"].to_pylist() == [[1.0, 2.0, 3.0, 4.0]]
|
||||
|
||||
|
||||
def test_on_bad_vectors_does_not_handle_non_vector_fixed_size_lists(
|
||||
mem_db: DBConnection,
|
||||
):
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vector", pa.list_(pa.float32(), 4)),
|
||||
pa.field("bbox", pa.list_(pa.float32(), 4)),
|
||||
]
|
||||
)
|
||||
table = mem_db.create_table("test_bbox_schema", schema=schema)
|
||||
|
||||
with pytest.raises(RuntimeError, match="FixedSizeListType"):
|
||||
table.add(
|
||||
[{"vector": [1.0, 2.0, 3.0, 4.0], "bbox": [0.0, 1.0]}],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
|
||||
|
||||
def test_on_bad_vectors_does_not_handle_custom_named_fixed_size_lists(
|
||||
mem_db: DBConnection,
|
||||
):
|
||||
schema = pa.schema([pa.field("features", pa.list_(pa.float32(), 16))])
|
||||
table = mem_db.create_table("test_custom_named_fixed_size_vector", schema=schema)
|
||||
|
||||
with pytest.raises(RuntimeError, match="FixedSizeListType"):
|
||||
table.add(
|
||||
[
|
||||
{"features": []},
|
||||
{"features": [0.1] * 16},
|
||||
],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
|
||||
|
||||
def test_on_bad_vectors_with_schema_list_vector_still_sanitizes(mem_db: DBConnection):
|
||||
schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))])
|
||||
table = mem_db.create_table("test_schema_list_vector", schema=schema)
|
||||
table.add(
|
||||
[
|
||||
{"vector": [1.0, 2.0]},
|
||||
{"vector": [3.0]},
|
||||
{"vector": [4.0, 5.0]},
|
||||
],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
|
||||
assert table.to_arrow()["vector"].to_pylist() == [[1.0, 2.0], [4.0, 5.0]]
|
||||
|
||||
|
||||
def test_on_bad_vectors_handles_typed_custom_fixed_vectors_for_list_schema(
|
||||
mem_db: DBConnection,
|
||||
):
|
||||
schema = pa.schema([pa.field("vec", pa.list_(pa.float32()))])
|
||||
table = mem_db.create_table("test_typed_custom_fixed_vector", schema=schema)
|
||||
data = pa.table(
|
||||
{
|
||||
"vec": pa.array(
|
||||
[[float("nan")] * 16, [1.0] * 16],
|
||||
type=pa.list_(pa.float32(), 16),
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
table.add(data, on_bad_vectors="drop")
|
||||
|
||||
assert table.to_arrow()["vec"].to_pylist() == [[1.0] * 16]
|
||||
|
||||
|
||||
def test_on_bad_vectors_fill_preserves_arrow_nested_vector_type(mem_db: DBConnection):
|
||||
schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))])
|
||||
table = mem_db.create_table("test_fill_arrow_nested_type", schema=schema)
|
||||
data = pa.table(
|
||||
{
|
||||
"vector": pa.array(
|
||||
[[1.0, 2.0], [float("nan"), 3.0]],
|
||||
type=pa.list_(pa.float32(), 2),
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
table.add(
|
||||
data,
|
||||
on_bad_vectors="fill",
|
||||
fill_value=0.0,
|
||||
)
|
||||
|
||||
assert table.to_arrow()["vector"].to_pylist() == [[1.0, 2.0], [0.0, 0.0]]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("table_name", "batch1", "expected"),
|
||||
[
|
||||
(
|
||||
"test_schema_list_vector_empty_prefix",
|
||||
pa.record_batch({"vector": [[], []]}),
|
||||
[[], [], [1.0, 2.0], [3.0, 4.0]],
|
||||
),
|
||||
(
|
||||
"test_schema_list_vector_all_bad_prefix",
|
||||
pa.record_batch({"vector": [[float("nan")] * 3, [float("nan")] * 3]}),
|
||||
[[1.0, 2.0], [3.0, 4.0]],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_on_bad_vectors_with_schema_list_vector_ignores_invalid_prefix_batches(
|
||||
mem_db: DBConnection,
|
||||
table_name: str,
|
||||
batch1: pa.RecordBatch,
|
||||
expected: list,
|
||||
):
|
||||
schema = pa.schema([pa.field("vector", pa.list_(pa.float32()))])
|
||||
table = mem_db.create_table(table_name, schema=schema)
|
||||
batch2 = pa.record_batch({"vector": [[1.0, 2.0], [3.0, 4.0]]})
|
||||
reader = pa.RecordBatchReader.from_batches(batch1.schema, [batch1, batch2])
|
||||
|
||||
table.add(reader, on_bad_vectors="drop")
|
||||
|
||||
assert table.to_arrow()["vector"].to_pylist() == expected
|
||||
|
||||
|
||||
def test_on_bad_vectors_with_multiple_vectors_locks_dim_after_final_drop(
|
||||
mem_db: DBConnection,
|
||||
):
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
func = MockTextEmbeddingFunction.create()
|
||||
metadata = registry.get_table_metadata(
|
||||
[
|
||||
EmbeddingFunctionConfig(
|
||||
source_column="text1", vector_column="vec1", function=func
|
||||
),
|
||||
EmbeddingFunctionConfig(
|
||||
source_column="text2", vector_column="vec2", function=func
|
||||
),
|
||||
]
|
||||
)
|
||||
schema = pa.schema(
|
||||
[
|
||||
pa.field("vec1", pa.list_(pa.float32())),
|
||||
pa.field("vec2", pa.list_(pa.float32())),
|
||||
],
|
||||
metadata=metadata,
|
||||
)
|
||||
table = mem_db.create_table("test_multi_vector_dim_lock", schema=schema)
|
||||
batch1 = pa.record_batch(
|
||||
{
|
||||
"vec1": [[1.0, 2.0, 3.0], [10.0, 11.0]],
|
||||
"vec2": [[float("nan"), 0.0], [5.0, 6.0]],
|
||||
}
|
||||
)
|
||||
batch2 = pa.record_batch(
|
||||
{
|
||||
"vec1": [[20.0, 21.0], [30.0, 31.0]],
|
||||
"vec2": [[7.0, 8.0], [9.0, 10.0]],
|
||||
}
|
||||
)
|
||||
reader = pa.RecordBatchReader.from_batches(batch1.schema, [batch1, batch2])
|
||||
|
||||
table.add(reader, on_bad_vectors="drop")
|
||||
|
||||
data = table.to_arrow()
|
||||
assert data["vec1"].to_pylist() == [[10.0, 11.0], [20.0, 21.0], [30.0, 31.0]]
|
||||
assert data["vec2"].to_pylist() == [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
|
||||
|
||||
|
||||
def test_on_bad_vectors_does_not_handle_non_vector_list_columns(mem_db: DBConnection):
|
||||
schema = pa.schema([pa.field("embedding_history", pa.list_(pa.float32()))])
|
||||
table = mem_db.create_table("test_non_vector_list_schema", schema=schema)
|
||||
table.add(
|
||||
[
|
||||
{"embedding_history": [1.0, 2.0]},
|
||||
{"embedding_history": [3.0]},
|
||||
],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
|
||||
assert table.to_arrow()["embedding_history"].to_pylist() == [
|
||||
[1.0, 2.0],
|
||||
[3.0],
|
||||
]
|
||||
|
||||
|
||||
def test_on_bad_vectors_all_null_schema_vector_batches_do_not_crash(
|
||||
mem_db: DBConnection,
|
||||
):
|
||||
schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2), nullable=True)])
|
||||
table = mem_db.create_table("test_all_null_vector_batch", schema=schema)
|
||||
|
||||
table.add([{"vector": None}], on_bad_vectors="drop")
|
||||
|
||||
assert table.to_arrow()["vector"].to_pylist() == [None]
|
||||
|
||||
|
||||
def test_restore(mem_db: DBConnection):
|
||||
table = mem_db.create_table(
|
||||
"my_table",
|
||||
|
||||
@@ -15,10 +15,8 @@ from lancedb.table import (
|
||||
_cast_to_target_schema,
|
||||
_handle_bad_vectors,
|
||||
_into_pyarrow_reader,
|
||||
_infer_target_schema,
|
||||
_merge_metadata,
|
||||
_sanitize_data,
|
||||
sanitize_create_table,
|
||||
_infer_target_schema,
|
||||
)
|
||||
import pyarrow as pa
|
||||
import pandas as pd
|
||||
@@ -306,117 +304,6 @@ def test_handle_bad_vectors_noop():
|
||||
assert output["vector"] == vector
|
||||
|
||||
|
||||
def test_handle_bad_vectors_updates_reader_schema_for_target_schema():
|
||||
data = pa.table({"vector": [[1, 2, 3, 4]]})
|
||||
target_schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 4))])
|
||||
|
||||
output = _handle_bad_vectors(
|
||||
data.to_reader(),
|
||||
on_bad_vectors="drop",
|
||||
target_schema=target_schema,
|
||||
)
|
||||
|
||||
assert output.schema == pa.schema([pa.field("vector", pa.list_(pa.float32()))])
|
||||
assert output.read_all()["vector"].to_pylist() == [[1.0, 2.0, 3.0, 4.0]]
|
||||
|
||||
|
||||
def test_sanitize_data_keeps_target_field_metadata():
|
||||
source_field = pa.field(
|
||||
"vector",
|
||||
pa.list_(pa.float32(), 2),
|
||||
metadata={b"source": b"drop-me"},
|
||||
)
|
||||
target_field = pa.field(
|
||||
"vector",
|
||||
pa.list_(pa.float32(), 2),
|
||||
metadata={b"target": b"keep-me"},
|
||||
)
|
||||
data = pa.table(
|
||||
{"vector": pa.array([[1.0, 2.0]], type=pa.list_(pa.float32(), 2))},
|
||||
schema=pa.schema([source_field]),
|
||||
)
|
||||
|
||||
output = _sanitize_data(
|
||||
data,
|
||||
target_schema=pa.schema([target_field]),
|
||||
on_bad_vectors="drop",
|
||||
).read_all()
|
||||
|
||||
assert output.schema.field("vector").metadata == {b"target": b"keep-me"}
|
||||
|
||||
|
||||
def test_sanitize_data_uses_separate_embedding_metadata_for_bad_vectors():
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text",
|
||||
vector_column="custom_vector",
|
||||
function=MockTextEmbeddingFunction.create(),
|
||||
)
|
||||
metadata = registry.get_table_metadata([conf])
|
||||
schema = pa.schema(
|
||||
{
|
||||
"text": pa.string(),
|
||||
"custom_vector": pa.list_(pa.float32(), 10),
|
||||
},
|
||||
metadata={b"note": b"keep-me"},
|
||||
)
|
||||
data = pa.table(
|
||||
{
|
||||
"text": ["bad", "good"],
|
||||
"custom_vector": [[1.0] * 9, [2.0] * 10],
|
||||
}
|
||||
)
|
||||
|
||||
output = _sanitize_data(
|
||||
data,
|
||||
target_schema=schema,
|
||||
metadata=metadata,
|
||||
on_bad_vectors="drop",
|
||||
).read_all()
|
||||
|
||||
assert output["text"].to_pylist() == ["good"]
|
||||
assert output.schema.metadata[b"note"] == b"keep-me"
|
||||
assert b"embedding_functions" in output.schema.metadata
|
||||
|
||||
|
||||
def test_sanitize_create_table_merges_and_overrides_embedding_metadata():
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
old_conf = EmbeddingFunctionConfig(
|
||||
source_column="text",
|
||||
vector_column="old_vector",
|
||||
function=MockTextEmbeddingFunction.create(),
|
||||
)
|
||||
new_conf = EmbeddingFunctionConfig(
|
||||
source_column="text",
|
||||
vector_column="custom_vector",
|
||||
function=MockTextEmbeddingFunction.create(),
|
||||
)
|
||||
metadata = registry.get_table_metadata([new_conf])
|
||||
schema = pa.schema(
|
||||
{
|
||||
"text": pa.string(),
|
||||
"custom_vector": pa.list_(pa.float32(), 10),
|
||||
},
|
||||
metadata=_merge_metadata(
|
||||
{b"note": b"keep-me"},
|
||||
registry.get_table_metadata([old_conf]),
|
||||
),
|
||||
)
|
||||
|
||||
data, schema = sanitize_create_table(
|
||||
pa.table({"text": ["good"]}),
|
||||
schema,
|
||||
metadata=metadata,
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
|
||||
assert schema.metadata[b"note"] == b"keep-me"
|
||||
assert b"embedding_functions" in schema.metadata
|
||||
assert data.schema.metadata[b"note"] == b"keep-me"
|
||||
funcs = EmbeddingFunctionRegistry.get_instance().parse_functions(schema.metadata)
|
||||
assert set(funcs.keys()) == {"custom_vector"}
|
||||
|
||||
|
||||
class TestModel(lancedb.pydantic.LanceModel):
|
||||
a: Optional[int]
|
||||
b: Optional[int]
|
||||
|
||||
Reference in New Issue
Block a user