feat(python)!: support inserting and upserting subschemas (#1965)

BREAKING CHANGE: For a field named "vector", a list of integers is now
converted to a binary (uint8) vector instead of an f32 vector. Use float
values instead if you want f32 vectors.
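
For example, here is a minimal sketch of the new behavior (the table names and printed type strings below are illustrative; exact output may vary by PyArrow version):

```python
import lancedb

db = lancedb.connect("memory://breaking-change-demo")

# Integer lists in a "vector" field are now inferred as binary (uint8) vectors.
binary_tbl = db.create_table("binary_vectors", [{"vector": [1, 2, 3, 4]}])
print(binary_tbl.schema.field("vector").type)  # e.g. fixed_size_list<item: uint8>[4]

# Use float values to keep getting f32 vectors.
float_tbl = db.create_table("float_vectors", [{"vector": [1.0, 2.0, 3.0, 4.0]}])
print(float_tbl.schema.field("vector").type)  # e.g. fixed_size_list<item: float>[4]
```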

* Adds proper support for inserting and upserting subsets of the full
  schema. I thought I had previously implemented this in #1827, but it
  turns out I had not tested it carefully enough.
* Refactors `_sanitize_data` and other utility functions so they are
  simpler and no longer require `numpy` or `combine_chunks()`.
* Adds a new suite of unit tests to validate the sanitization utilities.

## Examples

```python
import pandas as pd
import lancedb

db = lancedb.connect("memory://demo")
initial_data = pd.DataFrame({
    "a": [1, 2, 3],
    "b": [4, 5, 6],
    "c": [7, 8, 9]
})
table = db.create_table("demo", intial_data)

# Insert a subschema
new_data = pd.DataFrame({"a": [10, 11]})
table.add(new_data)
table.to_pandas()
```
```
    a    b    c
0   1  4.0  7.0
1   2  5.0  8.0
2   3  6.0  9.0
3  10  NaN  NaN
4  11  NaN  NaN
```


```python
# Upsert a subschema
upsert_data = pd.DataFrame({
    "a": [3, 10, 15],
    "b": [6, 7, 8],
})
(
    table.merge_insert(on="a")
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .execute(upsert_data)
)
table.to_pandas()
```
```
    a    b    c
0   1  4.0  7.0
1   2  5.0  8.0
2   3  6.0  9.0
3  10  7.0  NaN
4  11  NaN  NaN
5  15  8.0  NaN
```
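
The new merge-insert path is exercised with several input formats (PyArrow tables, pandas DataFrames, and plain lists of dicts). As a minimal sketch, the same upsert can be written with a PyArrow table instead of a DataFrame, continuing from the table above:

```python
import pyarrow as pa

# Equivalent upsert, passing a PyArrow table instead of a pandas DataFrame.
upsert_data = pa.table({"a": [3, 10, 15], "b": [6, 7, 8]})
(
    table.merge_insert(on="a")
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .execute(upsert_data)
)
```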
Author: Will Jones · 2025-01-08 10:11:10 -08:00 (committed by GitHub)
Parent: 3c0a64be8f · Commit: c557e77f09
10 changed files with 874 additions and 292 deletions


@@ -784,10 +784,6 @@ class AsyncConnection(object):
         registry = EmbeddingFunctionRegistry.get_instance()
         metadata = registry.get_table_metadata(embedding_functions)
-        data, schema = sanitize_create_table(
-            data, schema, metadata, on_bad_vectors, fill_value
-        )
         # Defining defaults here and not in function prototype. In the future
         # these defaults will move into rust so better to keep them as None.
         if on_bad_vectors is None:


@@ -108,9 +108,14 @@ class EmbeddingFunctionRegistry:
         An empty dict is returned if input is None or does not
         contain b"embedding_functions".
         """
-        if metadata is None or b"embedding_functions" not in metadata:
+        if metadata is None:
+            return {}
+        # Look at both bytes and string keys, since we might use either
+        serialized = metadata.get(
+            b"embedding_functions", metadata.get("embedding_functions")
+        )
+        if serialized is None:
             return {}
-        serialized = metadata[b"embedding_functions"]
         raw_list = json.loads(serialized.decode("utf-8"))
         return {
             obj["vector_column"]: EmbeddingFunctionConfig(


@@ -472,7 +472,7 @@ class LanceQueryBuilder(ABC):
         --------
        >>> import lancedb
        >>> db = lancedb.connect("./.lancedb")
-        >>> table = db.create_table("my_table", [{"vector": [99, 99]}])
+        >>> table = db.create_table("my_table", [{"vector": [99.0, 99]}])
        >>> query = [100, 100]
        >>> plan = table.search(query).explain_plan(True)
        >>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE


@@ -25,7 +25,6 @@ from urllib.parse import urlparse
import lance import lance
from lancedb.background_loop import LOOP from lancedb.background_loop import LOOP
from .dependencies import _check_for_pandas from .dependencies import _check_for_pandas
import numpy as np
import pyarrow as pa import pyarrow as pa
import pyarrow.compute as pc import pyarrow.compute as pc
import pyarrow.fs as pa_fs import pyarrow.fs as pa_fs
@@ -74,34 +73,17 @@ pl = safe_import_polars()
QueryType = Literal["vector", "fts", "hybrid", "auto"] QueryType = Literal["vector", "fts", "hybrid", "auto"]
def _pd_schema_without_embedding_funcs( def _into_pyarrow_table(data) -> pa.Table:
schema: Optional[pa.Schema], columns: List[str]
) -> Optional[pa.Schema]:
"""Return a schema without any embedding function columns"""
if schema is None:
return None
embedding_functions = EmbeddingFunctionRegistry.get_instance().parse_functions(
schema.metadata
)
if not embedding_functions:
return schema
return pa.schema([field for field in schema if field.name in columns])
def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
if _check_for_hugging_face(data): if _check_for_hugging_face(data):
# Huggingface datasets # Huggingface datasets
from lance.dependencies import datasets from lance.dependencies import datasets
if isinstance(data, datasets.Dataset): if isinstance(data, datasets.Dataset):
if schema is None: schema = data.features.arrow_schema
schema = data.features.arrow_schema
return pa.Table.from_batches(data.data.to_batches(), schema=schema) return pa.Table.from_batches(data.data.to_batches(), schema=schema)
elif isinstance(data, datasets.dataset_dict.DatasetDict): elif isinstance(data, datasets.dataset_dict.DatasetDict):
if schema is None: schema = _schema_from_hf(data, schema)
schema = _schema_from_hf(data, schema)
return pa.Table.from_batches(_to_batches_with_split(data), schema=schema) return pa.Table.from_batches(_to_batches_with_split(data), schema=schema)
if isinstance(data, LanceModel): if isinstance(data, LanceModel):
raise ValueError("Cannot add a single LanceModel to a table. Use a list.") raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
@@ -111,17 +93,15 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
if isinstance(data, list): if isinstance(data, list):
# convert to list of dict if data is a bunch of LanceModels # convert to list of dict if data is a bunch of LanceModels
if isinstance(data[0], LanceModel): if isinstance(data[0], LanceModel):
if schema is None: schema = data[0].__class__.to_arrow_schema()
schema = data[0].__class__.to_arrow_schema()
data = [model_to_dict(d) for d in data] data = [model_to_dict(d) for d in data]
return pa.Table.from_pylist(data, schema=schema) return pa.Table.from_pylist(data, schema=schema)
elif isinstance(data[0], pa.RecordBatch): elif isinstance(data[0], pa.RecordBatch):
return pa.Table.from_batches(data, schema=schema) return pa.Table.from_batches(data)
else: else:
return pa.Table.from_pylist(data, schema=schema) return pa.Table.from_pylist(data)
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame): # type: ignore elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list()) table = pa.Table.from_pandas(data, preserve_index=False)
table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema)
# Do not serialize Pandas metadata # Do not serialize Pandas metadata
meta = table.schema.metadata if table.schema.metadata is not None else {} meta = table.schema.metadata if table.schema.metadata is not None else {}
meta = {k: v for k, v in meta.items() if k != b"pandas"} meta = {k: v for k, v in meta.items() if k != b"pandas"}
@@ -143,8 +123,13 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
and data.__class__.__name__ == "DataFrame" and data.__class__.__name__ == "DataFrame"
): ):
return data.to_arrow() return data.to_arrow()
elif (
type(data).__module__.startswith("polars")
and data.__class__.__name__ == "LazyFrame"
):
return data.collect().to_arrow()
elif isinstance(data, Iterable): elif isinstance(data, Iterable):
return _process_iterator(data, schema) return _iterator_to_table(data)
else: else:
raise TypeError( raise TypeError(
f"Unknown data type {type(data)}. " f"Unknown data type {type(data)}. "
@@ -154,27 +139,172 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
) )
def _iterator_to_table(data: Iterable) -> pa.Table:
batches = []
schema = None # Will get schema from first batch
for batch in data:
batch_table = _into_pyarrow_table(batch)
if schema is not None:
if batch_table.schema != schema:
try:
batch_table = batch_table.cast(schema)
except pa.lib.ArrowInvalid:
raise ValueError(
f"Input iterator yielded a batch with schema that "
f"does not match the schema of other batches.\n"
f"Expected:\n{schema}\nGot:\n{batch_table.schema}"
)
else:
# Use the first schema for the remainder of the batches
schema = batch_table.schema
batches.append(batch_table)
if batches:
return pa.concat_tables(batches)
else:
raise ValueError("Input iterable is empty")
def _sanitize_data( def _sanitize_data(
data: Any, data: "DATA",
schema: Optional[pa.Schema] = None, target_schema: Optional[pa.Schema] = None,
metadata: Optional[dict] = None, # embedding metadata metadata: Optional[dict] = None, # embedding metadata
on_bad_vectors: str = "error", on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
fill_value: float = 0.0, fill_value: float = 0.0,
) -> Tuple[pa.Table, pa.Schema]: *,
data = _coerce_to_table(data, schema) allow_subschema: bool = False,
) -> pa.Table:
"""
Handle input data, applying all standard transformations.
This includes:
* Converting the data to a PyArrow Table
* Adding vector columns defined in the metadata
* Adding embedding metadata into the schema
* Casting the table to the target schema
* Handling bad vectors
Parameters
----------
target_schema : Optional[pa.Schema], default None
The schema to cast the table to. This is typically the schema of the table
if it already exists. Otherwise it might be a user-requested schema.
allow_subschema : bool, default False
If True, the input table is allowed to omit columns from the target schema.
The target schema will be filtered to only include columns that are present
in the input table before casting.
metadata : Optional[dict], default None
The embedding metadata to add to the schema.
on_bad_vectors : Literal["error", "drop", "fill", "null"], default "error"
What to do if any of the vectors are not the same size or contains NaNs.
fill_value : float, default 0.0
The value to use when filling vectors. Only used if on_bad_vectors="fill".
All entries in the vector will be set to this value.
"""
# At this point, the table might not match the schema we are targeting:
# 1. There might be embedding columns missing that will be added
# in the add_embeddings step.
# 2. If `allow_subschemas` is True, there might be columns missing.
table = _into_pyarrow_table(data)
table = _append_vector_columns(table, target_schema, metadata=metadata)
# This happens before the cast so we can fix vector columns with
# incorrect lengths before they are cast to FSL.
table = _handle_bad_vectors(
table,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
if target_schema is None:
target_schema = _infer_target_schema(table)
if metadata: if metadata:
data = _append_vector_col(data, metadata, schema) new_metadata = target_schema.metadata or {}
metadata.update(data.schema.metadata or {}) new_metadata = new_metadata.update(metadata)
data = data.replace_schema_metadata(metadata) target_schema = target_schema.with_metadata(new_metadata)
# TODO improve the logics in _sanitize_schema _validate_schema(target_schema)
data = _sanitize_schema(data, schema, on_bad_vectors, fill_value)
if schema is None:
schema = data.schema
_validate_schema(schema) table = _cast_to_target_schema(table, target_schema, allow_subschema)
return data, schema
return table
def _cast_to_target_schema(
table: pa.Table,
target_schema: pa.Schema,
allow_subschema: bool = False,
) -> pa.Table:
# pa.Table.cast expects field order not to be changed.
# Lance doesn't care about field order, so we don't need to rearrange fields
# to match the target schema. We just need to correctly cast the fields.
if table.schema == target_schema:
# Fast path when the schemas are already the same
return table
fields = []
for field in table.schema:
target_field = target_schema.field(field.name)
if target_field is None:
raise ValueError(f"Field {field.name} not found in target schema")
fields.append(target_field)
reordered_schema = pa.schema(fields, metadata=target_schema.metadata)
if not allow_subschema and len(reordered_schema) != len(target_schema):
raise ValueError(
"Input table has different number of columns than target schema"
)
if allow_subschema and len(reordered_schema) != len(target_schema):
fields = _infer_subschema(
list(iter(table.schema)), list(iter(reordered_schema))
)
subschema = pa.schema(fields, metadata=target_schema.metadata)
return table.cast(subschema)
else:
return table.cast(reordered_schema)
def _infer_subschema(
schema: List[pa.Field],
reference_fields: List[pa.Field],
) -> List[pa.Field]:
"""
Transform the list of fields so the types match the reference_fields.
The order of the fields is preserved.
``schema`` may have fewer fields than `reference_fields`, but it may not have
more fields.
"""
fields = []
lookup = {f.name: f for f in reference_fields}
for field in schema:
reference = lookup.get(field.name)
if reference is None:
raise ValueError("Unexpected field in schema: {}".format(field))
if pa.types.is_struct(reference.type):
new_type = pa.struct(
_infer_subschema(
field.type.fields,
reference.type.fields,
)
)
new_field = pa.field(
field.name,
new_type,
reference.nullable,
)
else:
new_field = reference
fields.append(new_field)
return fields
def sanitize_create_table( def sanitize_create_table(
@@ -193,13 +323,14 @@ def sanitize_create_table(
if data is not None: if data is not None:
if metadata is None and schema is not None: if metadata is None and schema is not None:
metadata = schema.metadata metadata = schema.metadata
data, schema = _sanitize_data( data = _sanitize_data(
data, data,
schema, schema,
metadata=metadata, metadata=metadata,
on_bad_vectors=on_bad_vectors, on_bad_vectors=on_bad_vectors,
fill_value=fill_value, fill_value=fill_value,
) )
schema = data.schema
else: else:
if schema is not None: if schema is not None:
data = pa.Table.from_pylist([], schema) data = pa.Table.from_pylist([], schema)
@@ -211,6 +342,8 @@ def sanitize_create_table(
if metadata: if metadata:
schema = schema.with_metadata(metadata) schema = schema.with_metadata(metadata)
# Need to apply metadata to the data as well
data = data.replace_schema_metadata(metadata)
return data, schema return data, schema
@@ -246,12 +379,22 @@ def _to_batches_with_split(data):
yield b yield b
def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schema]): def _append_vector_columns(
data: pa.Table,
schema: Optional[pa.Schema] = None,
*,
metadata: Optional[dict] = None,
) -> pa.Table:
""" """
Use the embedding function to automatically embed the source column and add the Use the embedding function to automatically embed the source columns and add the
vector column to the table. vector columns to the table.
""" """
if schema is None:
metadata = metadata or {}
else:
metadata = schema.metadata or metadata or {}
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata) functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
for vector_column, conf in functions.items(): for vector_column, conf in functions.items():
func = conf.function func = conf.function
no_vector_column = vector_column not in data.column_names no_vector_column = vector_column not in data.column_names
@@ -790,9 +933,9 @@ class Table(ABC):
-------- --------
>>> import lancedb >>> import lancedb
>>> data = [ >>> data = [
... {"x": 1, "vector": [1, 2]}, ... {"x": 1, "vector": [1.0, 2]},
... {"x": 2, "vector": [3, 4]}, ... {"x": 2, "vector": [3.0, 4]},
... {"x": 3, "vector": [5, 6]} ... {"x": 3, "vector": [5.0, 6]}
... ] ... ]
>>> db = lancedb.connect("./.lancedb") >>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data) >>> table = db.create_table("my_table", data)
@@ -854,7 +997,7 @@ class Table(ABC):
-------- --------
>>> import lancedb >>> import lancedb
>>> import pandas as pd >>> import pandas as pd
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]}) >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1.0, 2], [3, 4], [5, 6]]})
>>> db = lancedb.connect("./.lancedb") >>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data) >>> table = db.create_table("my_table", data)
>>> table.to_pandas() >>> table.to_pandas()
@@ -862,7 +1005,7 @@ class Table(ABC):
0 1 [1.0, 2.0] 0 1 [1.0, 2.0]
1 2 [3.0, 4.0] 1 2 [3.0, 4.0]
2 3 [5.0, 6.0] 2 3 [5.0, 6.0]
>>> table.update(where="x = 2", values={"vector": [10, 10]}) >>> table.update(where="x = 2", values={"vector": [10.0, 10]})
>>> table.to_pandas() >>> table.to_pandas()
x vector x vector
0 1 [1.0, 2.0] 0 1 [1.0, 2.0]
@@ -1880,9 +2023,9 @@ class LanceTable(Table):
-------- --------
>>> import lancedb >>> import lancedb
>>> data = [ >>> data = [
... {"x": 1, "vector": [1, 2]}, ... {"x": 1, "vector": [1.0, 2]},
... {"x": 2, "vector": [3, 4]}, ... {"x": 2, "vector": [3.0, 4]},
... {"x": 3, "vector": [5, 6]} ... {"x": 3, "vector": [5.0, 6]}
... ] ... ]
>>> db = lancedb.connect("./.lancedb") >>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data) >>> table = db.create_table("my_table", data)
@@ -1971,7 +2114,7 @@ class LanceTable(Table):
-------- --------
>>> import lancedb >>> import lancedb
>>> import pandas as pd >>> import pandas as pd
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]}) >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1.0, 2], [3, 4], [5, 6]]})
>>> db = lancedb.connect("./.lancedb") >>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data) >>> table = db.create_table("my_table", data)
>>> table.to_pandas() >>> table.to_pandas()
@@ -1979,7 +2122,7 @@ class LanceTable(Table):
0 1 [1.0, 2.0] 0 1 [1.0, 2.0]
1 2 [3.0, 4.0] 1 2 [3.0, 4.0]
2 3 [5.0, 6.0] 2 3 [5.0, 6.0]
>>> table.update(where="x = 2", values={"vector": [10, 10]}) >>> table.update(where="x = 2", values={"vector": [10.0, 10]})
>>> table.to_pandas() >>> table.to_pandas()
x vector x vector
0 1 [1.0, 2.0] 0 1 [1.0, 2.0]
@@ -2165,74 +2308,49 @@ class LanceTable(Table):
LOOP.run(self._table.migrate_v2_manifest_paths()) LOOP.run(self._table.migrate_v2_manifest_paths())
def _sanitize_schema( def _handle_bad_vectors(
data: pa.Table, table: pa.Table,
schema: pa.Schema = None, on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
on_bad_vectors: str = "error",
fill_value: float = 0.0, fill_value: float = 0.0,
) -> pa.Table: ) -> pa.Table:
"""Ensure that the table has the expected schema. for field in table.schema:
# They can provide a 'vector' column that isn't yet a FSL
Parameters named_vector_col = (
---------- (
data: pa.Table pa.types.is_list(field.type)
The table to sanitize. or pa.types.is_large_list(field.type)
schema: pa.Schema; optional or pa.types.is_fixed_size_list(field.type)
The expected schema. If not provided, this just converts the
vector column to fixed_size_list(float32) if necessary.
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill", "null".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
if schema is not None:
# cast the columns to the expected types
data = data.combine_chunks()
for field in schema:
# TODO: we're making an assumption that fixed size list of 10 or more
# is a vector column. This is definitely a bit hacky.
likely_vector_col = (
pa.types.is_fixed_size_list(field.type)
and pa.types.is_float32(field.type.value_type)
and field.type.list_size >= 10
) )
is_default_vector_col = field.name == VECTOR_COLUMN_NAME and pa.types.is_floating(field.type.value_type)
if field.name in data.column_names and ( and field.name == VECTOR_COLUMN_NAME
likely_vector_col or is_default_vector_col )
): # TODO: we're making an assumption that fixed size list of 10 or more
data = _sanitize_vector_column( # is a vector column. This is definitely a bit hacky.
data, likely_vector_col = (
vector_column_name=field.name, pa.types.is_fixed_size_list(field.type)
on_bad_vectors=on_bad_vectors, and pa.types.is_floating(field.type.value_type)
fill_value=fill_value, and (field.type.list_size >= 10)
table_schema=schema,
)
return pa.Table.from_arrays(
[data[name] for name in schema.names], schema=schema
) )
# just check the vector column if named_vector_col or likely_vector_col:
if VECTOR_COLUMN_NAME in data.column_names: table = _handle_bad_vector_column(
return _sanitize_vector_column( table,
data, vector_column_name=field.name,
vector_column_name=VECTOR_COLUMN_NAME, on_bad_vectors=on_bad_vectors,
on_bad_vectors=on_bad_vectors, fill_value=fill_value,
fill_value=fill_value, )
)
return data return table
def _sanitize_vector_column( def _handle_bad_vector_column(
data: pa.Table, data: pa.Table,
vector_column_name: str, vector_column_name: str,
table_schema: Optional[pa.Schema] = None,
on_bad_vectors: str = "error", on_bad_vectors: str = "error",
fill_value: float = 0.0, fill_value: float = 0.0,
) -> pa.Table: ) -> pa.Table:
""" """
Ensure that the vector column exists and has type fixed_size_list(float32) Ensure that the vector column exists and has type fixed_size_list(float)
Parameters Parameters
---------- ----------
@@ -2246,141 +2364,118 @@ def _sanitize_vector_column(
fill_value: float, default 0.0 fill_value: float, default 0.0
The value to use when filling vectors. Only used if on_bad_vectors="fill". The value to use when filling vectors. Only used if on_bad_vectors="fill".
""" """
# ChunkedArray is annoying to work with, so we combine chunks here vec_arr = data[vector_column_name]
vec_arr = data[vector_column_name].combine_chunks()
if table_schema is not None:
field = table_schema.field(vector_column_name)
else:
field = None
typ = data[vector_column_name].type
if pa.types.is_list(typ) or pa.types.is_large_list(typ):
# if it's a variable size list array,
# we make sure the dimensions are all the same
has_jagged_ndims = len(vec_arr.values) % len(data) != 0
if has_jagged_ndims:
data = _sanitize_jagged(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
)
vec_arr = data[vector_column_name].combine_chunks()
vec_arr = ensure_fixed_size_list(vec_arr)
data = data.set_column(
data.column_names.index(vector_column_name), vector_column_name, vec_arr
)
elif not pa.types.is_fixed_size_list(vec_arr.type):
raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
if pa.types.is_float16(vec_arr.values.type): has_nan = has_nan_values(vec_arr)
# Use numpy to check for NaNs, because as pyarrow does not have `is_nan`
# kernel over f16 types yet.
values_np = vec_arr.values.to_numpy(zero_copy_only=True)
if np.isnan(values_np).any():
data = _sanitize_nans(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
)
else:
if (
field is not None
and not field.nullable
and pc.any(pc.is_null(vec_arr.values)).as_py()
) or (pc.any(pc.is_nan(vec_arr.values)).as_py()):
data = _sanitize_nans(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
)
return data
def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
values = vec_arr.values
if not (pa.types.is_float16(values.type) or pa.types.is_float32(values.type)):
values = values.cast(pa.float32())
if pa.types.is_fixed_size_list(vec_arr.type): if pa.types.is_fixed_size_list(vec_arr.type):
list_size = vec_arr.type.list_size dim = vec_arr.type.list_size
else: else:
list_size = len(values) / len(vec_arr) dim = _modal_list_size(vec_arr)
vec_arr = pa.FixedSizeListArray.from_arrays(values, list_size) has_wrong_dim = pc.not_equal(pc.list_value_length(vec_arr), dim)
return vec_arr
has_bad_vectors = pc.any(has_nan).as_py() or pc.any(has_wrong_dim).as_py()
def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name): if has_bad_vectors:
"""Sanitize jagged vectors.""" is_bad = pc.or_(has_nan, has_wrong_dim)
if on_bad_vectors == "error": if on_bad_vectors == "error":
raise ValueError( if pc.any(has_wrong_dim).as_py():
f"Vector column {vector_column_name} has variable length vectors " raise ValueError(
"Set on_bad_vectors='drop' to remove them, or " f"Vector column '{vector_column_name}' has variable length "
"set on_bad_vectors='fill' and fill_value=<value> to replace them." "vectors. Set on_bad_vectors='drop' to remove them, "
) "set on_bad_vectors='fill' and fill_value=<value> to replace them, "
"or set on_bad_vectors='null' to replace them with null."
lst_lengths = pc.list_value_length(vec_arr) )
ndims = pc.max(lst_lengths).as_py() else:
correct_ndims = pc.equal(lst_lengths, ndims) raise ValueError(
f"Vector column '{vector_column_name}' has NaNs. "
if on_bad_vectors == "fill": "Set on_bad_vectors='drop' to remove them, "
if fill_value is None: "set on_bad_vectors='fill' and fill_value=<value> to replace them, "
raise ValueError( "or set on_bad_vectors='null' to replace them with null."
"`fill_value` must not be None if `on_bad_vectors` is 'fill'" )
elif on_bad_vectors == "null":
vec_arr = pc.if_else(
is_bad,
pa.scalar(None),
vec_arr,
) )
fill_arr = pa.scalar([float(fill_value)] * ndims) elif on_bad_vectors == "drop":
vec_arr = pc.if_else(correct_ndims, vec_arr, fill_arr) data = data.filter(pc.invert(is_bad))
data = data.set_column( vec_arr = data[vector_column_name]
data.column_names.index(vector_column_name), vector_column_name, vec_arr elif on_bad_vectors == "fill":
) if fill_value is None:
elif on_bad_vectors == "drop": raise ValueError(
data = data.filter(correct_ndims) "`fill_value` must not be None if `on_bad_vectors` is 'fill'"
elif on_bad_vectors == "null": )
data = data.set_column( vec_arr = pc.if_else(
data.column_names.index(vector_column_name), is_bad,
vector_column_name, pa.scalar([fill_value] * dim),
pc.if_else(correct_ndims, vec_arr, pa.scalar(None)), vec_arr,
)
return data
def _sanitize_nans(
data,
fill_value,
on_bad_vectors,
vec_arr: pa.FixedSizeListArray,
vector_column_name: str,
):
"""Sanitize NaNs in vectors"""
assert pa.types.is_fixed_size_list(vec_arr.type)
if on_bad_vectors == "error":
raise ValueError(
f"Vector column {vector_column_name} has NaNs. "
"Set on_bad_vectors='drop' to remove them, or "
"set on_bad_vectors='fill' and fill_value=<value> to replace them. "
"Or set on_bad_vectors='null' to replace them with null."
)
elif on_bad_vectors == "fill":
if fill_value is None:
raise ValueError(
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
) )
fill_value = float(fill_value) else:
values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values) raise ValueError(f"Invalid value for on_bad_vectors: {on_bad_vectors}")
ndims = len(vec_arr[0])
vec_arr = pa.FixedSizeListArray.from_arrays(values, ndims) position = data.column_names.index(vector_column_name)
data = data.set_column( return data.set_column(position, vector_column_name, vec_arr)
data.column_names.index(vector_column_name), vector_column_name, vec_arr
)
elif on_bad_vectors == "drop": def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray:
# Drop is very slow to be able to filter out NaNs in a fixed size list array if isinstance(arr, pa.ChunkedArray):
np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False)) values = pa.chunked_array([chunk.flatten() for chunk in arr.chunks])
np_arr = np_arr.reshape(-1, vec_arr.type.list_size) else:
not_nulls = np.any(np_arr, axis=1) values = arr.flatten()
data = data.filter(~not_nulls) if pa.types.is_float16(values.type):
elif on_bad_vectors == "null": # is_nan isn't yet implemented for f16, so we cast to f32
# null = pa.nulls(len(vec_arr)).cast(vec_arr.type) # https://github.com/apache/arrow/issues/45083
# values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values) values_has_nan = pc.is_nan(values.cast(pa.float32()))
np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False)) else:
np_arr = np_arr.reshape(-1, vec_arr.type.list_size) values_has_nan = pc.is_nan(values)
no_nans = np.any(np_arr, axis=1) values_indices = pc.list_parent_indices(arr)
data = data.set_column( has_nan_indices = pc.unique(pc.filter(values_indices, values_has_nan))
data.column_names.index(vector_column_name), indices = pa.array(range(len(arr)), type=pa.uint32())
vector_column_name, return pc.is_in(indices, has_nan_indices)
pc.if_else(no_nans, vec_arr, pa.scalar(None)),
)
return data def _infer_target_schema(table: pa.Table) -> pa.Schema:
schema = table.schema
for i, field in enumerate(schema):
if (
field.name == VECTOR_COLUMN_NAME
and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
and pa.types.is_floating(field.type.value_type)
):
# Use the most common length of the list as the dimensions
dim = _modal_list_size(table.column(i))
new_field = pa.field(
VECTOR_COLUMN_NAME,
pa.list_(pa.float32(), dim),
nullable=field.nullable,
)
schema = schema.set(i, new_field)
elif (
field.name == VECTOR_COLUMN_NAME
and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
and pa.types.is_integer(field.type.value_type)
):
# Use the most common length of the list as the dimensions
dim = _modal_list_size(table.column(i))
new_field = pa.field(
VECTOR_COLUMN_NAME,
pa.list_(pa.uint8(), dim),
nullable=field.nullable,
)
schema = schema.set(i, new_field)
return schema
def _modal_list_size(arr: Union[pa.ListArray, pa.ChunkedArray]) -> int:
# Use the most common length of the list as the dimensions
return pc.mode(pc.list_value_length(arr))[0].as_py()["mode"]
def _validate_schema(schema: pa.Schema): def _validate_schema(schema: pa.Schema):
@@ -2410,28 +2505,6 @@ def _validate_metadata(metadata: dict):
_validate_metadata(v) _validate_metadata(v)
def _process_iterator(data: Iterable, schema: Optional[pa.Schema] = None) -> pa.Table:
batches = []
for batch in data:
batch_table = _coerce_to_table(batch, schema)
if schema is not None:
if batch_table.schema != schema:
try:
batch_table = batch_table.cast(schema)
except pa.lib.ArrowInvalid: # type: ignore
raise ValueError(
f"Input iterator yielded a batch with schema that "
f"does not match the expected schema.\nExpected:\n{schema}\n"
f"Got:\n{batch_table.schema}"
)
batches.append(batch_table)
if batches:
return pa.concat_tables(batches)
else:
raise ValueError("Input iterable is empty")
class AsyncTable: class AsyncTable:
""" """
An AsyncTable is a collection of Records in a LanceDB Database. An AsyncTable is a collection of Records in a LanceDB Database.
@@ -2678,16 +2751,17 @@ class AsyncTable:
on_bad_vectors = "error" on_bad_vectors = "error"
if fill_value is None: if fill_value is None:
fill_value = 0.0 fill_value = 0.0
table_and_schema: Tuple[pa.Table, pa.Schema] = _sanitize_data( data = _sanitize_data(
data, data,
schema, schema,
metadata=schema.metadata, metadata=schema.metadata,
on_bad_vectors=on_bad_vectors, on_bad_vectors=on_bad_vectors,
fill_value=fill_value, fill_value=fill_value,
allow_subschema=True,
) )
tbl, schema = table_and_schema if isinstance(data, pa.Table):
if isinstance(tbl, pa.Table): data = data.to_reader()
data = pa.RecordBatchReader.from_batches(schema, tbl.to_batches())
await self._inner.add(data, mode or "append") await self._inner.add(data, mode or "append")
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder: def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
@@ -2822,12 +2896,13 @@ class AsyncTable:
on_bad_vectors = "error" on_bad_vectors = "error"
if fill_value is None: if fill_value is None:
fill_value = 0.0 fill_value = 0.0
data, _ = _sanitize_data( data = _sanitize_data(
new_data, new_data,
schema, schema,
metadata=schema.metadata, metadata=schema.metadata,
on_bad_vectors=on_bad_vectors, on_bad_vectors=on_bad_vectors,
fill_value=fill_value, fill_value=fill_value,
allow_subschema=True,
) )
if isinstance(data, pa.Table): if isinstance(data, pa.Table):
data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches()) data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
@@ -2862,9 +2937,9 @@ class AsyncTable:
-------- --------
>>> import lancedb >>> import lancedb
>>> data = [ >>> data = [
... {"x": 1, "vector": [1, 2]}, ... {"x": 1, "vector": [1.0, 2]},
... {"x": 2, "vector": [3, 4]}, ... {"x": 2, "vector": [3.0, 4]},
... {"x": 3, "vector": [5, 6]} ... {"x": 3, "vector": [5.0, 6]}
... ] ... ]
>>> db = lancedb.connect("./.lancedb") >>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data) >>> table = db.create_table("my_table", data)


@@ -223,9 +223,7 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
     vector_col_count = 0
     for field_name in schema.names:
         field = schema.field(field_name)
-        if pa.types.is_fixed_size_list(field.type) and pa.types.is_floating(
-            field.type.value_type
-        ):
+        if pa.types.is_fixed_size_list(field.type):
             vector_col_count += 1
             if vector_col_count > 1:
                 raise ValueError(


@@ -21,7 +21,7 @@ def test_binary_vector():
     ]
     tbl = db.create_table("my_binary_vectors", data=data)
     query = np.random.randint(0, 256, size=16)
-    tbl.search(query).to_arrow()
+    tbl.search(query).metric("hamming").to_arrow()
     # --8<-- [end:sync_binary_vector]
     db.drop_table("my_binary_vectors")

@@ -39,6 +39,6 @@ async def test_binary_vector_async():
     ]
     tbl = await db.create_table("my_binary_vectors", data=data)
     query = np.random.randint(0, 256, size=16)
-    await tbl.query().nearest_to(query).to_arrow()
+    await tbl.query().nearest_to(query).distance_type("hamming").to_arrow()
     # --8<-- [end:async_binary_vector]
     await db.drop_table("my_binary_vectors")


@@ -118,9 +118,9 @@ def test_scalar_index():
     # --8<-- [end:search_with_scalar_index]
     # --8<-- [start:vector_search_with_scalar_index]
     data = [
-        {"book_id": 1, "vector": [1, 2]},
-        {"book_id": 2, "vector": [3, 4]},
-        {"book_id": 3, "vector": [5, 6]},
+        {"book_id": 1, "vector": [1.0, 2]},
+        {"book_id": 2, "vector": [3.0, 4]},
+        {"book_id": 3, "vector": [5.0, 6]},
     ]
     table = db.create_table("book_with_embeddings", data)

@@ -156,9 +156,9 @@ async def test_scalar_index_async():
     # --8<-- [end:search_with_scalar_index_async]
     # --8<-- [start:vector_search_with_scalar_index_async]
     data = [
-        {"book_id": 1, "vector": [1, 2]},
-        {"book_id": 2, "vector": [3, 4]},
-        {"book_id": 3, "vector": [5, 6]},
+        {"book_id": 1, "vector": [1.0, 2]},
+        {"book_id": 2, "vector": [3.0, 4]},
+        {"book_id": 3, "vector": [5.0, 6]},
     ]
     async_tbl = await async_db.create_table("book_with_embeddings_async", data)
     (await async_tbl.query().where("book_id != 3").nearest_to([1, 2]).to_pandas())


@@ -198,7 +198,6 @@ def test_embedding_function_with_pandas(tmp_path):
         {
             "text": ["hello world", "goodbye world"],
             "val": [1, 2],
-            "not-used": ["s1", "s3"],
         }
     )
     db = lancedb.connect(tmp_path)

@@ -212,7 +211,6 @@ def test_embedding_function_with_pandas(tmp_path):
         {
             "text": ["extra", "more"],
             "val": [4, 5],
-            "misc-col": ["s1", "s3"],
         }
     )
     tbl.add(df)


@@ -242,8 +242,8 @@ def test_add_subschema(mem_db: DBConnection):
     data = {"price": 10.0, "item": "foo"}
     table.add([data])

-    data = {"price": 2.0, "vector": [3.1, 4.1]}
-    table.add([data])
+    data = pd.DataFrame({"price": [2.0], "vector": [[3.1, 4.1]]})
+    table.add(data)

     data = {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"}
     table.add([data])

@@ -259,7 +259,7 @@ def test_add_subschema(mem_db: DBConnection):
     data = {"item": "foo"}
     # We can't omit a column if it's not nullable
-    with pytest.raises(RuntimeError, match="Invalid user input"):
+    with pytest.raises(RuntimeError, match="Append with different schema"):
         table.add([data])

     # We can add it if we make the column nullable

@@ -292,6 +292,7 @@ def test_add_nullability(mem_db: DBConnection):
         ]
     )
     table = mem_db.create_table("test", schema=schema)
+    assert table.schema.field("vector").nullable is False

     nullable_schema = pa.schema(
         [

@@ -320,7 +321,10 @@ def test_add_nullability(mem_db: DBConnection):
         schema=nullable_schema,
     )
     # We can't add nullable schema if it contains nulls
-    with pytest.raises(Exception, match="Vector column vector has NaNs"):
+    with pytest.raises(
+        Exception,
+        match="Casting field 'vector' with null values to non-nullable",
+    ):
         table.add(data)

     # But we can make it nullable

@@ -776,6 +780,38 @@ def test_merge_insert(mem_db: DBConnection):
     assert table.to_arrow().sort_by("a") == expected


+# We vary the data format because there are slight differences in how
+# subschemas are handled in different formats
+@pytest.mark.parametrize(
+    "data_format",
+    [
+        lambda table: table,
+        lambda table: table.to_pandas(),
+        lambda table: table.to_pylist(),
+    ],
+    ids=["pa.Table", "pd.DataFrame", "rows"],
+)
+def test_merge_insert_subschema(mem_db: DBConnection, data_format):
+    initial_data = pa.table(
+        {"id": range(3), "a": [1.0, 2.0, 3.0], "c": ["x", "x", "x"]}
+    )
+    table = mem_db.create_table("my_table", data=initial_data)
+
+    new_data = pa.table({"id": [2, 3], "c": ["y", "y"]})
+    new_data = data_format(new_data)
+    (
+        table.merge_insert(on="id")
+        .when_matched_update_all()
+        .when_not_matched_insert_all()
+        .execute(new_data)
+    )
+
+    expected = pa.table(
+        {"id": [0, 1, 2, 3], "a": [1.0, 2.0, 3.0, None], "c": ["x", "x", "y", "y"]}
+    )
+    assert table.to_arrow().sort_by("id") == expected
+
+
 @pytest.mark.asyncio
 async def test_merge_insert_async(mem_db_async: AsyncConnection):
     data = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]})


@@ -13,10 +13,27 @@
import os import os
import pathlib import pathlib
from typing import Optional
import lance
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.embeddings.base import EmbeddingFunctionConfig
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
from lancedb.table import (
_append_vector_columns,
_cast_to_target_schema,
_handle_bad_vectors,
_into_pyarrow_table,
_sanitize_data,
_infer_target_schema,
)
import pyarrow as pa
import pandas as pd
import polars as pl
import pytest import pytest
import lancedb import lancedb
from lancedb.util import get_uri_scheme, join_uri, value_to_sql from lancedb.util import get_uri_scheme, join_uri, value_to_sql
from utils import exception_output
def test_normalize_uri(): def test_normalize_uri():
@@ -111,3 +128,460 @@ def test_value_to_sql_string(tmp_path):
for value in values: for value in values:
table.update(where=f"search = {value_to_sql(value)}", values={"replace": value}) table.update(where=f"search = {value_to_sql(value)}", values={"replace": value})
assert table.to_pandas().query("search == @value")["replace"].item() == value assert table.to_pandas().query("search == @value")["replace"].item() == value
def test_append_vector_columns():
registry = EmbeddingFunctionRegistry.get_instance()
registry.register("test")(MockTextEmbeddingFunction)
conf = EmbeddingFunctionConfig(
source_column="text",
vector_column="vector",
function=MockTextEmbeddingFunction(),
)
metadata = registry.get_table_metadata([conf])
schema = pa.schema(
{
"text": pa.string(),
"vector": pa.list_(pa.float64(), 10),
}
)
data = pa.table(
{
"text": ["hello"],
"vector": [None], # Replaces null
},
schema=schema,
)
output = _append_vector_columns(
data,
schema, # metadata passed separate from schema
metadata=metadata,
)
assert output.schema == schema
assert output["vector"].null_count == 0
# Adds if missing
data = pa.table({"text": ["hello"]})
output = _append_vector_columns(
data,
schema.with_metadata(metadata),
)
assert output.schema == schema
assert output["vector"].null_count == 0
# doesn't embed if already there
data = pa.table(
{
"text": ["hello"],
"vector": [[42.0] * 10],
},
schema=schema,
)
output = _append_vector_columns(
data,
schema.with_metadata(metadata),
)
assert output == data # No change
# No provided schema
data = pa.table(
{
"text": ["hello"],
}
)
output = _append_vector_columns(
data,
metadata=metadata,
)
expected_schema = pa.schema(
{
"text": pa.string(),
"vector": pa.list_(pa.float32(), 10),
}
)
assert output.schema == expected_schema
assert output["vector"].null_count == 0
@pytest.mark.parametrize("on_bad_vectors", ["error", "drop", "fill", "null"])
def test_handle_bad_vectors_jagged(on_bad_vectors):
vector = pa.array([[1.0, 2.0], [3.0], [4.0, 5.0]])
schema = pa.schema({"vector": pa.list_(pa.float64())})
data = pa.table({"vector": vector}, schema=schema)
if on_bad_vectors == "error":
with pytest.raises(ValueError) as e:
output = _handle_bad_vectors(
data,
on_bad_vectors=on_bad_vectors,
)
output = exception_output(e)
assert output == (
"ValueError: Vector column 'vector' has variable length vectors. Set "
"on_bad_vectors='drop' to remove them, set on_bad_vectors='fill' "
"and fill_value=<value> to replace them, or set on_bad_vectors='null' "
"to replace them with null."
)
return
else:
output = _handle_bad_vectors(
data,
on_bad_vectors=on_bad_vectors,
fill_value=42.0,
)
if on_bad_vectors == "drop":
expected = pa.array([[1.0, 2.0], [4.0, 5.0]])
elif on_bad_vectors == "fill":
expected = pa.array([[1.0, 2.0], [42.0, 42.0], [4.0, 5.0]])
elif on_bad_vectors == "null":
expected = pa.array([[1.0, 2.0], None, [4.0, 5.0]])
assert output["vector"].combine_chunks() == expected
@pytest.mark.parametrize("on_bad_vectors", ["error", "drop", "fill", "null"])
def test_handle_bad_vectors_nan(on_bad_vectors):
vector = pa.array([[1.0, float("nan")], [3.0, 4.0]])
data = pa.table({"vector": vector})
if on_bad_vectors == "error":
with pytest.raises(ValueError) as e:
output = _handle_bad_vectors(
data,
on_bad_vectors=on_bad_vectors,
)
output = exception_output(e)
assert output == (
"ValueError: Vector column 'vector' has NaNs. Set "
"on_bad_vectors='drop' to remove them, set on_bad_vectors='fill' "
"and fill_value=<value> to replace them, or set on_bad_vectors='null' "
"to replace them with null."
)
return
else:
output = _handle_bad_vectors(
data,
on_bad_vectors=on_bad_vectors,
fill_value=42.0,
)
if on_bad_vectors == "drop":
expected = pa.array([[3.0, 4.0]])
elif on_bad_vectors == "fill":
expected = pa.array([[42.0, 42.0], [3.0, 4.0]])
elif on_bad_vectors == "null":
expected = pa.array([None, [3.0, 4.0]])
assert output["vector"].combine_chunks() == expected
def test_handle_bad_vectors_noop():
# ChunkedArray should be preserved as-is
vector = pa.chunked_array(
[[[1.0, 2.0], [3.0, 4.0]]], type=pa.list_(pa.float64(), 2)
)
data = pa.table({"vector": vector})
output = _handle_bad_vectors(data)
assert output["vector"] == vector
class TestModel(lancedb.pydantic.LanceModel):
a: Optional[int]
b: Optional[int]
# TODO: huggingface,
@pytest.mark.parametrize(
"data",
[
lambda: [{"a": 1, "b": 2}],
lambda: pa.RecordBatch.from_pylist([{"a": 1, "b": 2}]),
lambda: pa.table({"a": [1], "b": [2]}),
lambda: pa.table({"a": [1], "b": [2]}).to_reader(),
lambda: iter(pa.table({"a": [1], "b": [2]}).to_batches()),
lambda: (
lance.write_dataset(
pa.table({"a": [1], "b": [2]}),
"memory://test",
)
),
lambda: (
lance.write_dataset(
pa.table({"a": [1], "b": [2]}),
"memory://test",
).scanner()
),
lambda: pd.DataFrame({"a": [1], "b": [2]}),
lambda: pl.DataFrame({"a": [1], "b": [2]}),
lambda: pl.LazyFrame({"a": [1], "b": [2]}),
lambda: [TestModel(a=1, b=2)],
],
ids=[
"rows",
"pa.RecordBatch",
"pa.Table",
"pa.RecordBatchReader",
"batch_iter",
"lance.LanceDataset",
"lance.LanceScanner",
"pd.DataFrame",
"pl.DataFrame",
"pl.LazyFrame",
"pydantic",
],
)
def test_into_pyarrow_table(data):
expected = pa.table({"a": [1], "b": [2]})
output = _into_pyarrow_table(data())
assert output == expected
def test_infer_target_schema():
example = pa.schema(
{
"vec1": pa.list_(pa.float64(), 2),
"vector": pa.list_(pa.float64()),
}
)
data = pa.table(
{
"vec1": [[0.0] * 2],
"vector": [[0.0] * 2],
},
schema=example,
)
expected = pa.schema(
{
"vec1": pa.list_(pa.float64(), 2),
"vector": pa.list_(pa.float32(), 2),
}
)
output = _infer_target_schema(data)
assert output == expected
# Handle large list and use modal size
# Most vectors are of length 2, so we should infer that as the target dimension
example = pa.schema(
{
"vector": pa.large_list(pa.float64()),
}
)
data = pa.table(
{
"vector": [[0.0] * 2, [0.0], [0.0] * 2],
},
schema=example,
)
expected = pa.schema(
{
"vector": pa.list_(pa.float32(), 2),
}
)
output = _infer_target_schema(data)
assert output == expected
# ignore if not list
example = pa.schema(
{
"vector": pa.float64(),
}
)
data = pa.table(
{
"vector": [0.0],
},
schema=example,
)
expected = example
output = _infer_target_schema(data)
assert output == expected
@pytest.mark.parametrize(
"data",
[
[{"id": 1, "text": "hello"}],
pa.RecordBatch.from_pylist([{"id": 1, "text": "hello"}]),
pd.DataFrame({"id": [1], "text": ["hello"]}),
pl.DataFrame({"id": [1], "text": ["hello"]}),
],
ids=["rows", "pa.RecordBatch", "pd.DataFrame", "pl.DataFrame"],
)
@pytest.mark.parametrize(
"schema",
[
None,
pa.schema(
{
"id": pa.int32(),
"text": pa.string(),
"vector": pa.list_(pa.float32(), 10),
}
),
pa.schema(
{
"id": pa.int64(),
"text": pa.string(),
"vector": pa.list_(pa.float32(), 10),
"extra": pa.int64(),
}
),
],
ids=["infer", "explicit", "subschema"],
)
@pytest.mark.parametrize("with_embedding", [True, False])
def test_sanitize_data(
data,
schema: Optional[pa.Schema],
with_embedding: bool,
):
if with_embedding:
registry = EmbeddingFunctionRegistry.get_instance()
registry.register("test")(MockTextEmbeddingFunction)
conf = EmbeddingFunctionConfig(
source_column="text",
vector_column="vector",
function=MockTextEmbeddingFunction(),
)
metadata = registry.get_table_metadata([conf])
else:
metadata = None
if schema is not None:
to_remove = schema.get_field_index("extra")
if to_remove >= 0:
expected_schema = schema.remove(to_remove)
else:
expected_schema = schema
else:
expected_schema = pa.schema(
{
"id": pa.int64(),
"text": pa.large_utf8()
if isinstance(data, pl.DataFrame)
else pa.string(),
"vector": pa.list_(pa.float32(), 10),
}
)
if not with_embedding:
to_remove = expected_schema.get_field_index("vector")
if to_remove >= 0:
expected_schema = expected_schema.remove(to_remove)
expected = pa.table(
{
"id": [1],
"text": ["hello"],
"vector": [[0.0] * 10],
},
schema=expected_schema,
)
output_data = _sanitize_data(
data,
target_schema=schema,
metadata=metadata,
allow_subschema=True,
)
assert output_data == expected
def test_cast_to_target_schema():
original_schema = pa.schema(
{
"id": pa.int32(),
"struct": pa.struct(
[
pa.field("a", pa.int32()),
]
),
"vector": pa.list_(pa.float64()),
"vec1": pa.list_(pa.float64(), 2),
"vec2": pa.list_(pa.float32(), 2),
}
)
data = pa.table(
{
"id": [1],
"struct": [{"a": 1}],
"vector": [[0.0] * 2],
"vec1": [[0.0] * 2],
"vec2": [[0.0] * 2],
},
schema=original_schema,
)
target = pa.schema(
{
"id": pa.int64(),
"struct": pa.struct(
[
pa.field("a", pa.int64()),
]
),
"vector": pa.list_(pa.float32(), 2),
"vec1": pa.list_(pa.float32(), 2),
"vec2": pa.list_(pa.float32(), 2),
}
)
output = _cast_to_target_schema(data, target)
expected = pa.table(
{
"id": [1],
"struct": [{"a": 1}],
"vector": [[0.0] * 2],
"vec1": [[0.0] * 2],
"vec2": [[0.0] * 2],
},
schema=target,
)
# Data can be a subschema of the target
target = pa.schema(
{
"id": pa.int64(),
"struct": pa.struct(
[
pa.field("a", pa.int64()),
# Additional nested field
pa.field("b", pa.int64()),
]
),
"vector": pa.list_(pa.float32(), 2),
"vec1": pa.list_(pa.float32(), 2),
"vec2": pa.list_(pa.float32(), 2),
# Additional field
"extra": pa.int64(),
}
)
with pytest.raises(Exception):
_cast_to_target_schema(data, target)
output = _cast_to_target_schema(data, target, allow_subschema=True)
expected_schema = pa.schema(
{
"id": pa.int64(),
"struct": pa.struct(
[
pa.field("a", pa.int64()),
]
),
"vector": pa.list_(pa.float32(), 2),
"vec1": pa.list_(pa.float32(), 2),
"vec2": pa.list_(pa.float32(), 2),
}
)
expected = pa.table(
{
"id": [1],
"struct": [{"a": 1}],
"vector": [[0.0] * 2],
"vec1": [[0.0] * 2],
"vec2": [[0.0] * 2],
},
schema=expected_schema,
)
assert output == expected