feat(python)!: support inserting and upserting subschemas (#1965)

BREAKING CHANGE: For a field named "vector", lists of integers are now
converted to binary (uint8) vectors instead of f32 vectors. Use float
values to get f32 vectors.
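
For example, a minimal sketch of the new type inference (the `memory://` URI and the printed types are illustrative, not exact output):

```python
import lancedb

db = lancedb.connect("memory://breaking-change")

# Integer lists in a "vector" column are now inferred as binary (uint8) vectors.
binary_tbl = db.create_table("binary_vectors", [{"vector": [1, 2, 3, 4]}])
print(binary_tbl.schema.field("vector").type)  # roughly: fixed_size_list<item: uint8>[4]

# Use float values to keep f32 vectors.
float_tbl = db.create_table("float_vectors", [{"vector": [1.0, 2.0, 3.0, 4.0]}])
print(float_tbl.schema.field("vector").type)   # roughly: fixed_size_list<item: float>[4]
```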

* Adds proper support for inserting and upserting subsets of the full
schema. I thought I had previously implemented this in #1827, but it
turns out I had not tested it carefully enough.
* Refactors `_sanitize_data` and other utility functions to be simpler
and to no longer require `numpy` or `combine_chunks()`.
* Adds a new suite of unit tests to validate the sanitization utilities.

## Examples

```python
import pandas as pd
import lancedb

db = lancedb.connect("memory://demo")
initial_data = pd.DataFrame({
    "a": [1, 2, 3],
    "b": [4, 5, 6],
    "c": [7, 8, 9]
})
table = db.create_table("demo", intial_data)

# Insert a subschema
new_data = pd.DataFrame({"a": [10, 11]})
table.add(new_data)
table.to_pandas()
```
```
    a    b    c
0   1  4.0  7.0
1   2  5.0  8.0
2   3  6.0  9.0
3  10  NaN  NaN
4  11  NaN  NaN
```


```python
# Upsert a subschema
upsert_data = pd.DataFrame({
    "a": [3, 10, 15],
    "b": [6, 7, 8],
})
table.merge_insert(on="a").when_matched_update_all().when_not_matched_insert_all().execute(upsert_data)
table.to_pandas()
```
```
    a    b    c
0   1  4.0  7.0
1   2  5.0  8.0
2   3  6.0  9.0
3  10  7.0  NaN
4  11  NaN  NaN
5  15  8.0  NaN
```
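
Columns can only be omitted when they are nullable in the target schema; omitting a non-nullable column is rejected. A minimal sketch based on the updated tests (the table name and exact error message are illustrative):

```python
import pyarrow as pa

schema = pa.schema([
    pa.field("a", pa.int64(), nullable=False),
    pa.field("b", pa.int64(), nullable=True),
])
strict = db.create_table("strict_demo", schema=schema)

# "b" is nullable, so it may be omitted.
strict.add([{"a": 1}])

# Omitting the non-nullable "a" raises an error
# (the new tests expect a message like "Append with different schema").
try:
    strict.add([{"b": 2}])
except Exception as exc:
    print(exc)
```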
Will Jones, 2025-01-08 10:11:10 -08:00 (committed by GitHub)
parent 3c0a64be8f, commit c557e77f09
10 changed files with 874 additions and 292 deletions

View File

@@ -784,10 +784,6 @@ class AsyncConnection(object):
registry = EmbeddingFunctionRegistry.get_instance()
metadata = registry.get_table_metadata(embedding_functions)
data, schema = sanitize_create_table(
data, schema, metadata, on_bad_vectors, fill_value
)
# Defining defaults here and not in function prototype. In the future
# these defaults will move into rust so better to keep them as None.
if on_bad_vectors is None:

View File

@@ -108,9 +108,14 @@ class EmbeddingFunctionRegistry:
An empty dict is returned if input is None or does not
contain b"embedding_functions".
"""
if metadata is None or b"embedding_functions" not in metadata:
if metadata is None:
return {}
# Look at both bytes and string keys, since we might use either
serialized = metadata.get(
b"embedding_functions", metadata.get("embedding_functions")
)
if serialized is None:
return {}
serialized = metadata[b"embedding_functions"]
raw_list = json.loads(serialized.decode("utf-8"))
return {
obj["vector_column"]: EmbeddingFunctionConfig(

View File

@@ -472,7 +472,7 @@ class LanceQueryBuilder(ABC):
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", [{"vector": [99, 99]}])
>>> table = db.create_table("my_table", [{"vector": [99.0, 99]}])
>>> query = [100, 100]
>>> plan = table.search(query).explain_plan(True)
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE

View File

@@ -25,7 +25,6 @@ from urllib.parse import urlparse
import lance
from lancedb.background_loop import LOOP
from .dependencies import _check_for_pandas
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.fs as pa_fs
@@ -74,34 +73,17 @@ pl = safe_import_polars()
QueryType = Literal["vector", "fts", "hybrid", "auto"]
def _pd_schema_without_embedding_funcs(
schema: Optional[pa.Schema], columns: List[str]
) -> Optional[pa.Schema]:
"""Return a schema without any embedding function columns"""
if schema is None:
return None
embedding_functions = EmbeddingFunctionRegistry.get_instance().parse_functions(
schema.metadata
)
if not embedding_functions:
return schema
return pa.schema([field for field in schema if field.name in columns])
def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
def _into_pyarrow_table(data) -> pa.Table:
if _check_for_hugging_face(data):
# Huggingface datasets
from lance.dependencies import datasets
if isinstance(data, datasets.Dataset):
if schema is None:
schema = data.features.arrow_schema
schema = data.features.arrow_schema
return pa.Table.from_batches(data.data.to_batches(), schema=schema)
elif isinstance(data, datasets.dataset_dict.DatasetDict):
if schema is None:
schema = _schema_from_hf(data, schema)
schema = _schema_from_hf(data, schema)
return pa.Table.from_batches(_to_batches_with_split(data), schema=schema)
if isinstance(data, LanceModel):
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
@@ -111,17 +93,15 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
if isinstance(data, list):
# convert to list of dict if data is a bunch of LanceModels
if isinstance(data[0], LanceModel):
if schema is None:
schema = data[0].__class__.to_arrow_schema()
schema = data[0].__class__.to_arrow_schema()
data = [model_to_dict(d) for d in data]
return pa.Table.from_pylist(data, schema=schema)
elif isinstance(data[0], pa.RecordBatch):
return pa.Table.from_batches(data, schema=schema)
return pa.Table.from_batches(data)
else:
return pa.Table.from_pylist(data, schema=schema)
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame): # type: ignore
raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list())
table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema)
return pa.Table.from_pylist(data)
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
table = pa.Table.from_pandas(data, preserve_index=False)
# Do not serialize Pandas metadata
meta = table.schema.metadata if table.schema.metadata is not None else {}
meta = {k: v for k, v in meta.items() if k != b"pandas"}
@@ -143,8 +123,13 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
and data.__class__.__name__ == "DataFrame"
):
return data.to_arrow()
elif (
type(data).__module__.startswith("polars")
and data.__class__.__name__ == "LazyFrame"
):
return data.collect().to_arrow()
elif isinstance(data, Iterable):
return _process_iterator(data, schema)
return _iterator_to_table(data)
else:
raise TypeError(
f"Unknown data type {type(data)}. "
@@ -154,27 +139,172 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
)
def _iterator_to_table(data: Iterable) -> pa.Table:
batches = []
schema = None # Will get schema from first batch
for batch in data:
batch_table = _into_pyarrow_table(batch)
if schema is not None:
if batch_table.schema != schema:
try:
batch_table = batch_table.cast(schema)
except pa.lib.ArrowInvalid:
raise ValueError(
f"Input iterator yielded a batch with schema that "
f"does not match the schema of other batches.\n"
f"Expected:\n{schema}\nGot:\n{batch_table.schema}"
)
else:
# Use the first schema for the remainder of the batches
schema = batch_table.schema
batches.append(batch_table)
if batches:
return pa.concat_tables(batches)
else:
raise ValueError("Input iterable is empty")
def _sanitize_data(
data: Any,
schema: Optional[pa.Schema] = None,
data: "DATA",
target_schema: Optional[pa.Schema] = None,
metadata: Optional[dict] = None, # embedding metadata
on_bad_vectors: str = "error",
on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
fill_value: float = 0.0,
) -> Tuple[pa.Table, pa.Schema]:
data = _coerce_to_table(data, schema)
*,
allow_subschema: bool = False,
) -> pa.Table:
"""
Handle input data, applying all standard transformations.
This includes:
* Converting the data to a PyArrow Table
* Adding vector columns defined in the metadata
* Adding embedding metadata into the schema
* Casting the table to the target schema
* Handling bad vectors
Parameters
----------
target_schema : Optional[pa.Schema], default None
The schema to cast the table to. This is typically the schema of the table
if it already exists. Otherwise it might be a user-requested schema.
allow_subschema : bool, default False
If True, the input table is allowed to omit columns from the target schema.
The target schema will be filtered to only include columns that are present
in the input table before casting.
metadata : Optional[dict], default None
The embedding metadata to add to the schema.
on_bad_vectors : Literal["error", "drop", "fill", "null"], default "error"
What to do if any of the vectors are not the same size or contains NaNs.
fill_value : float, default 0.0
The value to use when filling vectors. Only used if on_bad_vectors="fill".
All entries in the vector will be set to this value.
"""
# At this point, the table might not match the schema we are targeting:
# 1. There might be embedding columns missing that will be added
# in the add_embeddings step.
# 2. If `allow_subschemas` is True, there might be columns missing.
table = _into_pyarrow_table(data)
table = _append_vector_columns(table, target_schema, metadata=metadata)
# This happens before the cast so we can fix vector columns with
# incorrect lengths before they are cast to FSL.
table = _handle_bad_vectors(
table,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
if target_schema is None:
target_schema = _infer_target_schema(table)
if metadata:
data = _append_vector_col(data, metadata, schema)
metadata.update(data.schema.metadata or {})
data = data.replace_schema_metadata(metadata)
new_metadata = target_schema.metadata or {}
new_metadata = new_metadata.update(metadata)
target_schema = target_schema.with_metadata(new_metadata)
# TODO improve the logics in _sanitize_schema
data = _sanitize_schema(data, schema, on_bad_vectors, fill_value)
if schema is None:
schema = data.schema
_validate_schema(target_schema)
_validate_schema(schema)
return data, schema
table = _cast_to_target_schema(table, target_schema, allow_subschema)
return table
def _cast_to_target_schema(
table: pa.Table,
target_schema: pa.Schema,
allow_subschema: bool = False,
) -> pa.Table:
# pa.Table.cast expects field order not to be changed.
# Lance doesn't care about field order, so we don't need to rearrange fields
# to match the target schema. We just need to correctly cast the fields.
if table.schema == target_schema:
# Fast path when the schemas are already the same
return table
fields = []
for field in table.schema:
target_field = target_schema.field(field.name)
if target_field is None:
raise ValueError(f"Field {field.name} not found in target schema")
fields.append(target_field)
reordered_schema = pa.schema(fields, metadata=target_schema.metadata)
if not allow_subschema and len(reordered_schema) != len(target_schema):
raise ValueError(
"Input table has different number of columns than target schema"
)
if allow_subschema and len(reordered_schema) != len(target_schema):
fields = _infer_subschema(
list(iter(table.schema)), list(iter(reordered_schema))
)
subschema = pa.schema(fields, metadata=target_schema.metadata)
return table.cast(subschema)
else:
return table.cast(reordered_schema)
def _infer_subschema(
schema: List[pa.Field],
reference_fields: List[pa.Field],
) -> List[pa.Field]:
"""
Transform the list of fields so the types match the reference_fields.
The order of the fields is preserved.
``schema`` may have fewer fields than `reference_fields`, but it may not have
more fields.
"""
fields = []
lookup = {f.name: f for f in reference_fields}
for field in schema:
reference = lookup.get(field.name)
if reference is None:
raise ValueError("Unexpected field in schema: {}".format(field))
if pa.types.is_struct(reference.type):
new_type = pa.struct(
_infer_subschema(
field.type.fields,
reference.type.fields,
)
)
new_field = pa.field(
field.name,
new_type,
reference.nullable,
)
else:
new_field = reference
fields.append(new_field)
return fields
def sanitize_create_table(
@@ -193,13 +323,14 @@ def sanitize_create_table(
if data is not None:
if metadata is None and schema is not None:
metadata = schema.metadata
data, schema = _sanitize_data(
data = _sanitize_data(
data,
schema,
metadata=metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
schema = data.schema
else:
if schema is not None:
data = pa.Table.from_pylist([], schema)
@@ -211,6 +342,8 @@ def sanitize_create_table(
if metadata:
schema = schema.with_metadata(metadata)
# Need to apply metadata to the data as well
data = data.replace_schema_metadata(metadata)
return data, schema
@@ -246,12 +379,22 @@ def _to_batches_with_split(data):
yield b
def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schema]):
def _append_vector_columns(
data: pa.Table,
schema: Optional[pa.Schema] = None,
*,
metadata: Optional[dict] = None,
) -> pa.Table:
"""
Use the embedding function to automatically embed the source column and add the
vector column to the table.
Use the embedding function to automatically embed the source columns and add the
vector columns to the table.
"""
if schema is None:
metadata = metadata or {}
else:
metadata = schema.metadata or metadata or {}
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
for vector_column, conf in functions.items():
func = conf.function
no_vector_column = vector_column not in data.column_names
@@ -790,9 +933,9 @@ class Table(ABC):
--------
>>> import lancedb
>>> data = [
... {"x": 1, "vector": [1, 2]},
... {"x": 2, "vector": [3, 4]},
... {"x": 3, "vector": [5, 6]}
... {"x": 1, "vector": [1.0, 2]},
... {"x": 2, "vector": [3.0, 4]},
... {"x": 3, "vector": [5.0, 6]}
... ]
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data)
@@ -854,7 +997,7 @@ class Table(ABC):
--------
>>> import lancedb
>>> import pandas as pd
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1.0, 2], [3, 4], [5, 6]]})
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data)
>>> table.to_pandas()
@@ -862,7 +1005,7 @@ class Table(ABC):
0 1 [1.0, 2.0]
1 2 [3.0, 4.0]
2 3 [5.0, 6.0]
>>> table.update(where="x = 2", values={"vector": [10, 10]})
>>> table.update(where="x = 2", values={"vector": [10.0, 10]})
>>> table.to_pandas()
x vector
0 1 [1.0, 2.0]
@@ -1880,9 +2023,9 @@ class LanceTable(Table):
--------
>>> import lancedb
>>> data = [
... {"x": 1, "vector": [1, 2]},
... {"x": 2, "vector": [3, 4]},
... {"x": 3, "vector": [5, 6]}
... {"x": 1, "vector": [1.0, 2]},
... {"x": 2, "vector": [3.0, 4]},
... {"x": 3, "vector": [5.0, 6]}
... ]
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data)
@@ -1971,7 +2114,7 @@ class LanceTable(Table):
--------
>>> import lancedb
>>> import pandas as pd
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1.0, 2], [3, 4], [5, 6]]})
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data)
>>> table.to_pandas()
@@ -1979,7 +2122,7 @@ class LanceTable(Table):
0 1 [1.0, 2.0]
1 2 [3.0, 4.0]
2 3 [5.0, 6.0]
>>> table.update(where="x = 2", values={"vector": [10, 10]})
>>> table.update(where="x = 2", values={"vector": [10.0, 10]})
>>> table.to_pandas()
x vector
0 1 [1.0, 2.0]
@@ -2165,74 +2308,49 @@ class LanceTable(Table):
LOOP.run(self._table.migrate_v2_manifest_paths())
def _sanitize_schema(
data: pa.Table,
schema: pa.Schema = None,
on_bad_vectors: str = "error",
def _handle_bad_vectors(
table: pa.Table,
on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
fill_value: float = 0.0,
) -> pa.Table:
"""Ensure that the table has the expected schema.
Parameters
----------
data: pa.Table
The table to sanitize.
schema: pa.Schema; optional
The expected schema. If not provided, this just converts the
vector column to fixed_size_list(float32) if necessary.
on_bad_vectors: str, default "error"
What to do if any of the vectors are not the same size or contains NaNs.
One of "error", "drop", "fill", "null".
fill_value: float, default 0.
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
if schema is not None:
# cast the columns to the expected types
data = data.combine_chunks()
for field in schema:
# TODO: we're making an assumption that fixed size list of 10 or more
# is a vector column. This is definitely a bit hacky.
likely_vector_col = (
pa.types.is_fixed_size_list(field.type)
and pa.types.is_float32(field.type.value_type)
and field.type.list_size >= 10
for field in table.schema:
# They can provide a 'vector' column that isn't yet a FSL
named_vector_col = (
(
pa.types.is_list(field.type)
or pa.types.is_large_list(field.type)
or pa.types.is_fixed_size_list(field.type)
)
is_default_vector_col = field.name == VECTOR_COLUMN_NAME
if field.name in data.column_names and (
likely_vector_col or is_default_vector_col
):
data = _sanitize_vector_column(
data,
vector_column_name=field.name,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
table_schema=schema,
)
return pa.Table.from_arrays(
[data[name] for name in schema.names], schema=schema
and pa.types.is_floating(field.type.value_type)
and field.name == VECTOR_COLUMN_NAME
)
# TODO: we're making an assumption that fixed size list of 10 or more
# is a vector column. This is definitely a bit hacky.
likely_vector_col = (
pa.types.is_fixed_size_list(field.type)
and pa.types.is_floating(field.type.value_type)
and (field.type.list_size >= 10)
)
# just check the vector column
if VECTOR_COLUMN_NAME in data.column_names:
return _sanitize_vector_column(
data,
vector_column_name=VECTOR_COLUMN_NAME,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
if named_vector_col or likely_vector_col:
table = _handle_bad_vector_column(
table,
vector_column_name=field.name,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
)
return data
return table
def _sanitize_vector_column(
def _handle_bad_vector_column(
data: pa.Table,
vector_column_name: str,
table_schema: Optional[pa.Schema] = None,
on_bad_vectors: str = "error",
fill_value: float = 0.0,
) -> pa.Table:
"""
Ensure that the vector column exists and has type fixed_size_list(float32)
Ensure that the vector column exists and has type fixed_size_list(float)
Parameters
----------
@@ -2246,141 +2364,118 @@ def _sanitize_vector_column(
fill_value: float, default 0.0
The value to use when filling vectors. Only used if on_bad_vectors="fill".
"""
# ChunkedArray is annoying to work with, so we combine chunks here
vec_arr = data[vector_column_name].combine_chunks()
if table_schema is not None:
field = table_schema.field(vector_column_name)
else:
field = None
typ = data[vector_column_name].type
if pa.types.is_list(typ) or pa.types.is_large_list(typ):
# if it's a variable size list array,
# we make sure the dimensions are all the same
has_jagged_ndims = len(vec_arr.values) % len(data) != 0
if has_jagged_ndims:
data = _sanitize_jagged(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
)
vec_arr = data[vector_column_name].combine_chunks()
vec_arr = ensure_fixed_size_list(vec_arr)
data = data.set_column(
data.column_names.index(vector_column_name), vector_column_name, vec_arr
)
elif not pa.types.is_fixed_size_list(vec_arr.type):
raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
vec_arr = data[vector_column_name]
if pa.types.is_float16(vec_arr.values.type):
# Use numpy to check for NaNs, because as pyarrow does not have `is_nan`
# kernel over f16 types yet.
values_np = vec_arr.values.to_numpy(zero_copy_only=True)
if np.isnan(values_np).any():
data = _sanitize_nans(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
)
else:
if (
field is not None
and not field.nullable
and pc.any(pc.is_null(vec_arr.values)).as_py()
) or (pc.any(pc.is_nan(vec_arr.values)).as_py()):
data = _sanitize_nans(
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
)
return data
has_nan = has_nan_values(vec_arr)
def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
values = vec_arr.values
if not (pa.types.is_float16(values.type) or pa.types.is_float32(values.type)):
values = values.cast(pa.float32())
if pa.types.is_fixed_size_list(vec_arr.type):
list_size = vec_arr.type.list_size
dim = vec_arr.type.list_size
else:
list_size = len(values) / len(vec_arr)
vec_arr = pa.FixedSizeListArray.from_arrays(values, list_size)
return vec_arr
dim = _modal_list_size(vec_arr)
has_wrong_dim = pc.not_equal(pc.list_value_length(vec_arr), dim)
has_bad_vectors = pc.any(has_nan).as_py() or pc.any(has_wrong_dim).as_py()
def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
"""Sanitize jagged vectors."""
if on_bad_vectors == "error":
raise ValueError(
f"Vector column {vector_column_name} has variable length vectors "
"Set on_bad_vectors='drop' to remove them, or "
"set on_bad_vectors='fill' and fill_value=<value> to replace them."
)
lst_lengths = pc.list_value_length(vec_arr)
ndims = pc.max(lst_lengths).as_py()
correct_ndims = pc.equal(lst_lengths, ndims)
if on_bad_vectors == "fill":
if fill_value is None:
raise ValueError(
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
if has_bad_vectors:
is_bad = pc.or_(has_nan, has_wrong_dim)
if on_bad_vectors == "error":
if pc.any(has_wrong_dim).as_py():
raise ValueError(
f"Vector column '{vector_column_name}' has variable length "
"vectors. Set on_bad_vectors='drop' to remove them, "
"set on_bad_vectors='fill' and fill_value=<value> to replace them, "
"or set on_bad_vectors='null' to replace them with null."
)
else:
raise ValueError(
f"Vector column '{vector_column_name}' has NaNs. "
"Set on_bad_vectors='drop' to remove them, "
"set on_bad_vectors='fill' and fill_value=<value> to replace them, "
"or set on_bad_vectors='null' to replace them with null."
)
elif on_bad_vectors == "null":
vec_arr = pc.if_else(
is_bad,
pa.scalar(None),
vec_arr,
)
fill_arr = pa.scalar([float(fill_value)] * ndims)
vec_arr = pc.if_else(correct_ndims, vec_arr, fill_arr)
data = data.set_column(
data.column_names.index(vector_column_name), vector_column_name, vec_arr
)
elif on_bad_vectors == "drop":
data = data.filter(correct_ndims)
elif on_bad_vectors == "null":
data = data.set_column(
data.column_names.index(vector_column_name),
vector_column_name,
pc.if_else(correct_ndims, vec_arr, pa.scalar(None)),
)
return data
def _sanitize_nans(
data,
fill_value,
on_bad_vectors,
vec_arr: pa.FixedSizeListArray,
vector_column_name: str,
):
"""Sanitize NaNs in vectors"""
assert pa.types.is_fixed_size_list(vec_arr.type)
if on_bad_vectors == "error":
raise ValueError(
f"Vector column {vector_column_name} has NaNs. "
"Set on_bad_vectors='drop' to remove them, or "
"set on_bad_vectors='fill' and fill_value=<value> to replace them. "
"Or set on_bad_vectors='null' to replace them with null."
)
elif on_bad_vectors == "fill":
if fill_value is None:
raise ValueError(
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
elif on_bad_vectors == "drop":
data = data.filter(pc.invert(is_bad))
vec_arr = data[vector_column_name]
elif on_bad_vectors == "fill":
if fill_value is None:
raise ValueError(
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
)
vec_arr = pc.if_else(
is_bad,
pa.scalar([fill_value] * dim),
vec_arr,
)
fill_value = float(fill_value)
values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values)
ndims = len(vec_arr[0])
vec_arr = pa.FixedSizeListArray.from_arrays(values, ndims)
data = data.set_column(
data.column_names.index(vector_column_name), vector_column_name, vec_arr
)
elif on_bad_vectors == "drop":
# Drop is very slow to be able to filter out NaNs in a fixed size list array
np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False))
np_arr = np_arr.reshape(-1, vec_arr.type.list_size)
not_nulls = np.any(np_arr, axis=1)
data = data.filter(~not_nulls)
elif on_bad_vectors == "null":
# null = pa.nulls(len(vec_arr)).cast(vec_arr.type)
# values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values)
np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False))
np_arr = np_arr.reshape(-1, vec_arr.type.list_size)
no_nans = np.any(np_arr, axis=1)
data = data.set_column(
data.column_names.index(vector_column_name),
vector_column_name,
pc.if_else(no_nans, vec_arr, pa.scalar(None)),
)
return data
else:
raise ValueError(f"Invalid value for on_bad_vectors: {on_bad_vectors}")
position = data.column_names.index(vector_column_name)
return data.set_column(position, vector_column_name, vec_arr)
def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray:
if isinstance(arr, pa.ChunkedArray):
values = pa.chunked_array([chunk.flatten() for chunk in arr.chunks])
else:
values = arr.flatten()
if pa.types.is_float16(values.type):
# is_nan isn't yet implemented for f16, so we cast to f32
# https://github.com/apache/arrow/issues/45083
values_has_nan = pc.is_nan(values.cast(pa.float32()))
else:
values_has_nan = pc.is_nan(values)
values_indices = pc.list_parent_indices(arr)
has_nan_indices = pc.unique(pc.filter(values_indices, values_has_nan))
indices = pa.array(range(len(arr)), type=pa.uint32())
return pc.is_in(indices, has_nan_indices)
def _infer_target_schema(table: pa.Table) -> pa.Schema:
schema = table.schema
for i, field in enumerate(schema):
if (
field.name == VECTOR_COLUMN_NAME
and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
and pa.types.is_floating(field.type.value_type)
):
# Use the most common length of the list as the dimensions
dim = _modal_list_size(table.column(i))
new_field = pa.field(
VECTOR_COLUMN_NAME,
pa.list_(pa.float32(), dim),
nullable=field.nullable,
)
schema = schema.set(i, new_field)
elif (
field.name == VECTOR_COLUMN_NAME
and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
and pa.types.is_integer(field.type.value_type)
):
# Use the most common length of the list as the dimensions
dim = _modal_list_size(table.column(i))
new_field = pa.field(
VECTOR_COLUMN_NAME,
pa.list_(pa.uint8(), dim),
nullable=field.nullable,
)
schema = schema.set(i, new_field)
return schema
def _modal_list_size(arr: Union[pa.ListArray, pa.ChunkedArray]) -> int:
# Use the most common length of the list as the dimensions
return pc.mode(pc.list_value_length(arr))[0].as_py()["mode"]
def _validate_schema(schema: pa.Schema):
@@ -2410,28 +2505,6 @@ def _validate_metadata(metadata: dict):
_validate_metadata(v)
def _process_iterator(data: Iterable, schema: Optional[pa.Schema] = None) -> pa.Table:
batches = []
for batch in data:
batch_table = _coerce_to_table(batch, schema)
if schema is not None:
if batch_table.schema != schema:
try:
batch_table = batch_table.cast(schema)
except pa.lib.ArrowInvalid: # type: ignore
raise ValueError(
f"Input iterator yielded a batch with schema that "
f"does not match the expected schema.\nExpected:\n{schema}\n"
f"Got:\n{batch_table.schema}"
)
batches.append(batch_table)
if batches:
return pa.concat_tables(batches)
else:
raise ValueError("Input iterable is empty")
class AsyncTable:
"""
An AsyncTable is a collection of Records in a LanceDB Database.
@@ -2678,16 +2751,17 @@ class AsyncTable:
on_bad_vectors = "error"
if fill_value is None:
fill_value = 0.0
table_and_schema: Tuple[pa.Table, pa.Schema] = _sanitize_data(
data = _sanitize_data(
data,
schema,
metadata=schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
allow_subschema=True,
)
tbl, schema = table_and_schema
if isinstance(tbl, pa.Table):
data = pa.RecordBatchReader.from_batches(schema, tbl.to_batches())
if isinstance(data, pa.Table):
data = data.to_reader()
await self._inner.add(data, mode or "append")
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
@@ -2822,12 +2896,13 @@ class AsyncTable:
on_bad_vectors = "error"
if fill_value is None:
fill_value = 0.0
data, _ = _sanitize_data(
data = _sanitize_data(
new_data,
schema,
metadata=schema.metadata,
on_bad_vectors=on_bad_vectors,
fill_value=fill_value,
allow_subschema=True,
)
if isinstance(data, pa.Table):
data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
@@ -2862,9 +2937,9 @@ class AsyncTable:
--------
>>> import lancedb
>>> data = [
... {"x": 1, "vector": [1, 2]},
... {"x": 2, "vector": [3, 4]},
... {"x": 3, "vector": [5, 6]}
... {"x": 1, "vector": [1.0, 2]},
... {"x": 2, "vector": [3.0, 4]},
... {"x": 3, "vector": [5.0, 6]}
... ]
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", data)

View File

@@ -223,9 +223,7 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
vector_col_count = 0
for field_name in schema.names:
field = schema.field(field_name)
if pa.types.is_fixed_size_list(field.type) and pa.types.is_floating(
field.type.value_type
):
if pa.types.is_fixed_size_list(field.type):
vector_col_count += 1
if vector_col_count > 1:
raise ValueError(

View File

@@ -21,7 +21,7 @@ def test_binary_vector():
]
tbl = db.create_table("my_binary_vectors", data=data)
query = np.random.randint(0, 256, size=16)
tbl.search(query).to_arrow()
tbl.search(query).metric("hamming").to_arrow()
# --8<-- [end:sync_binary_vector]
db.drop_table("my_binary_vectors")
@@ -39,6 +39,6 @@ async def test_binary_vector_async():
]
tbl = await db.create_table("my_binary_vectors", data=data)
query = np.random.randint(0, 256, size=16)
await tbl.query().nearest_to(query).to_arrow()
await tbl.query().nearest_to(query).distance_type("hamming").to_arrow()
# --8<-- [end:async_binary_vector]
await db.drop_table("my_binary_vectors")

View File

@@ -118,9 +118,9 @@ def test_scalar_index():
# --8<-- [end:search_with_scalar_index]
# --8<-- [start:vector_search_with_scalar_index]
data = [
{"book_id": 1, "vector": [1, 2]},
{"book_id": 2, "vector": [3, 4]},
{"book_id": 3, "vector": [5, 6]},
{"book_id": 1, "vector": [1.0, 2]},
{"book_id": 2, "vector": [3.0, 4]},
{"book_id": 3, "vector": [5.0, 6]},
]
table = db.create_table("book_with_embeddings", data)
@@ -156,9 +156,9 @@ async def test_scalar_index_async():
# --8<-- [end:search_with_scalar_index_async]
# --8<-- [start:vector_search_with_scalar_index_async]
data = [
{"book_id": 1, "vector": [1, 2]},
{"book_id": 2, "vector": [3, 4]},
{"book_id": 3, "vector": [5, 6]},
{"book_id": 1, "vector": [1.0, 2]},
{"book_id": 2, "vector": [3.0, 4]},
{"book_id": 3, "vector": [5.0, 6]},
]
async_tbl = await async_db.create_table("book_with_embeddings_async", data)
(await async_tbl.query().where("book_id != 3").nearest_to([1, 2]).to_pandas())

View File

@@ -198,7 +198,6 @@ def test_embedding_function_with_pandas(tmp_path):
{
"text": ["hello world", "goodbye world"],
"val": [1, 2],
"not-used": ["s1", "s3"],
}
)
db = lancedb.connect(tmp_path)
@@ -212,7 +211,6 @@ def test_embedding_function_with_pandas(tmp_path):
{
"text": ["extra", "more"],
"val": [4, 5],
"misc-col": ["s1", "s3"],
}
)
tbl.add(df)

View File

@@ -242,8 +242,8 @@ def test_add_subschema(mem_db: DBConnection):
data = {"price": 10.0, "item": "foo"}
table.add([data])
data = {"price": 2.0, "vector": [3.1, 4.1]}
table.add([data])
data = pd.DataFrame({"price": [2.0], "vector": [[3.1, 4.1]]})
table.add(data)
data = {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"}
table.add([data])
@@ -259,7 +259,7 @@ def test_add_subschema(mem_db: DBConnection):
data = {"item": "foo"}
# We can't omit a column if it's not nullable
with pytest.raises(RuntimeError, match="Invalid user input"):
with pytest.raises(RuntimeError, match="Append with different schema"):
table.add([data])
# We can add it if we make the column nullable
@@ -292,6 +292,7 @@ def test_add_nullability(mem_db: DBConnection):
]
)
table = mem_db.create_table("test", schema=schema)
assert table.schema.field("vector").nullable is False
nullable_schema = pa.schema(
[
@@ -320,7 +321,10 @@ def test_add_nullability(mem_db: DBConnection):
schema=nullable_schema,
)
# We can't add nullable schema if it contains nulls
with pytest.raises(Exception, match="Vector column vector has NaNs"):
with pytest.raises(
Exception,
match="Casting field 'vector' with null values to non-nullable",
):
table.add(data)
# But we can make it nullable
@@ -776,6 +780,38 @@ def test_merge_insert(mem_db: DBConnection):
assert table.to_arrow().sort_by("a") == expected
# We vary the data format because there are slight differences in how
# subschemas are handled in different formats
@pytest.mark.parametrize(
"data_format",
[
lambda table: table,
lambda table: table.to_pandas(),
lambda table: table.to_pylist(),
],
ids=["pa.Table", "pd.DataFrame", "rows"],
)
def test_merge_insert_subschema(mem_db: DBConnection, data_format):
initial_data = pa.table(
{"id": range(3), "a": [1.0, 2.0, 3.0], "c": ["x", "x", "x"]}
)
table = mem_db.create_table("my_table", data=initial_data)
new_data = pa.table({"id": [2, 3], "c": ["y", "y"]})
new_data = data_format(new_data)
(
table.merge_insert(on="id")
.when_matched_update_all()
.when_not_matched_insert_all()
.execute(new_data)
)
expected = pa.table(
{"id": [0, 1, 2, 3], "a": [1.0, 2.0, 3.0, None], "c": ["x", "x", "y", "y"]}
)
assert table.to_arrow().sort_by("id") == expected
@pytest.mark.asyncio
async def test_merge_insert_async(mem_db_async: AsyncConnection):
data = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]})

View File

@@ -13,10 +13,27 @@
import os
import pathlib
from typing import Optional
import lance
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.embeddings.base import EmbeddingFunctionConfig
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
from lancedb.table import (
_append_vector_columns,
_cast_to_target_schema,
_handle_bad_vectors,
_into_pyarrow_table,
_sanitize_data,
_infer_target_schema,
)
import pyarrow as pa
import pandas as pd
import polars as pl
import pytest
import lancedb
from lancedb.util import get_uri_scheme, join_uri, value_to_sql
from utils import exception_output
def test_normalize_uri():
@@ -111,3 +128,460 @@ def test_value_to_sql_string(tmp_path):
for value in values:
table.update(where=f"search = {value_to_sql(value)}", values={"replace": value})
assert table.to_pandas().query("search == @value")["replace"].item() == value
def test_append_vector_columns():
registry = EmbeddingFunctionRegistry.get_instance()
registry.register("test")(MockTextEmbeddingFunction)
conf = EmbeddingFunctionConfig(
source_column="text",
vector_column="vector",
function=MockTextEmbeddingFunction(),
)
metadata = registry.get_table_metadata([conf])
schema = pa.schema(
{
"text": pa.string(),
"vector": pa.list_(pa.float64(), 10),
}
)
data = pa.table(
{
"text": ["hello"],
"vector": [None], # Replaces null
},
schema=schema,
)
output = _append_vector_columns(
data,
schema, # metadata passed separate from schema
metadata=metadata,
)
assert output.schema == schema
assert output["vector"].null_count == 0
# Adds if missing
data = pa.table({"text": ["hello"]})
output = _append_vector_columns(
data,
schema.with_metadata(metadata),
)
assert output.schema == schema
assert output["vector"].null_count == 0
# doesn't embed if already there
data = pa.table(
{
"text": ["hello"],
"vector": [[42.0] * 10],
},
schema=schema,
)
output = _append_vector_columns(
data,
schema.with_metadata(metadata),
)
assert output == data # No change
# No provided schema
data = pa.table(
{
"text": ["hello"],
}
)
output = _append_vector_columns(
data,
metadata=metadata,
)
expected_schema = pa.schema(
{
"text": pa.string(),
"vector": pa.list_(pa.float32(), 10),
}
)
assert output.schema == expected_schema
assert output["vector"].null_count == 0
@pytest.mark.parametrize("on_bad_vectors", ["error", "drop", "fill", "null"])
def test_handle_bad_vectors_jagged(on_bad_vectors):
vector = pa.array([[1.0, 2.0], [3.0], [4.0, 5.0]])
schema = pa.schema({"vector": pa.list_(pa.float64())})
data = pa.table({"vector": vector}, schema=schema)
if on_bad_vectors == "error":
with pytest.raises(ValueError) as e:
output = _handle_bad_vectors(
data,
on_bad_vectors=on_bad_vectors,
)
output = exception_output(e)
assert output == (
"ValueError: Vector column 'vector' has variable length vectors. Set "
"on_bad_vectors='drop' to remove them, set on_bad_vectors='fill' "
"and fill_value=<value> to replace them, or set on_bad_vectors='null' "
"to replace them with null."
)
return
else:
output = _handle_bad_vectors(
data,
on_bad_vectors=on_bad_vectors,
fill_value=42.0,
)
if on_bad_vectors == "drop":
expected = pa.array([[1.0, 2.0], [4.0, 5.0]])
elif on_bad_vectors == "fill":
expected = pa.array([[1.0, 2.0], [42.0, 42.0], [4.0, 5.0]])
elif on_bad_vectors == "null":
expected = pa.array([[1.0, 2.0], None, [4.0, 5.0]])
assert output["vector"].combine_chunks() == expected
@pytest.mark.parametrize("on_bad_vectors", ["error", "drop", "fill", "null"])
def test_handle_bad_vectors_nan(on_bad_vectors):
vector = pa.array([[1.0, float("nan")], [3.0, 4.0]])
data = pa.table({"vector": vector})
if on_bad_vectors == "error":
with pytest.raises(ValueError) as e:
output = _handle_bad_vectors(
data,
on_bad_vectors=on_bad_vectors,
)
output = exception_output(e)
assert output == (
"ValueError: Vector column 'vector' has NaNs. Set "
"on_bad_vectors='drop' to remove them, set on_bad_vectors='fill' "
"and fill_value=<value> to replace them, or set on_bad_vectors='null' "
"to replace them with null."
)
return
else:
output = _handle_bad_vectors(
data,
on_bad_vectors=on_bad_vectors,
fill_value=42.0,
)
if on_bad_vectors == "drop":
expected = pa.array([[3.0, 4.0]])
elif on_bad_vectors == "fill":
expected = pa.array([[42.0, 42.0], [3.0, 4.0]])
elif on_bad_vectors == "null":
expected = pa.array([None, [3.0, 4.0]])
assert output["vector"].combine_chunks() == expected
def test_handle_bad_vectors_noop():
# ChunkedArray should be preserved as-is
vector = pa.chunked_array(
[[[1.0, 2.0], [3.0, 4.0]]], type=pa.list_(pa.float64(), 2)
)
data = pa.table({"vector": vector})
output = _handle_bad_vectors(data)
assert output["vector"] == vector
class TestModel(lancedb.pydantic.LanceModel):
a: Optional[int]
b: Optional[int]
# TODO: huggingface,
@pytest.mark.parametrize(
"data",
[
lambda: [{"a": 1, "b": 2}],
lambda: pa.RecordBatch.from_pylist([{"a": 1, "b": 2}]),
lambda: pa.table({"a": [1], "b": [2]}),
lambda: pa.table({"a": [1], "b": [2]}).to_reader(),
lambda: iter(pa.table({"a": [1], "b": [2]}).to_batches()),
lambda: (
lance.write_dataset(
pa.table({"a": [1], "b": [2]}),
"memory://test",
)
),
lambda: (
lance.write_dataset(
pa.table({"a": [1], "b": [2]}),
"memory://test",
).scanner()
),
lambda: pd.DataFrame({"a": [1], "b": [2]}),
lambda: pl.DataFrame({"a": [1], "b": [2]}),
lambda: pl.LazyFrame({"a": [1], "b": [2]}),
lambda: [TestModel(a=1, b=2)],
],
ids=[
"rows",
"pa.RecordBatch",
"pa.Table",
"pa.RecordBatchReader",
"batch_iter",
"lance.LanceDataset",
"lance.LanceScanner",
"pd.DataFrame",
"pl.DataFrame",
"pl.LazyFrame",
"pydantic",
],
)
def test_into_pyarrow_table(data):
expected = pa.table({"a": [1], "b": [2]})
output = _into_pyarrow_table(data())
assert output == expected
def test_infer_target_schema():
example = pa.schema(
{
"vec1": pa.list_(pa.float64(), 2),
"vector": pa.list_(pa.float64()),
}
)
data = pa.table(
{
"vec1": [[0.0] * 2],
"vector": [[0.0] * 2],
},
schema=example,
)
expected = pa.schema(
{
"vec1": pa.list_(pa.float64(), 2),
"vector": pa.list_(pa.float32(), 2),
}
)
output = _infer_target_schema(data)
assert output == expected
# Handle large list and use modal size
# Most vectors are of length 2, so we should infer that as the target dimension
example = pa.schema(
{
"vector": pa.large_list(pa.float64()),
}
)
data = pa.table(
{
"vector": [[0.0] * 2, [0.0], [0.0] * 2],
},
schema=example,
)
expected = pa.schema(
{
"vector": pa.list_(pa.float32(), 2),
}
)
output = _infer_target_schema(data)
assert output == expected
# ignore if not list
example = pa.schema(
{
"vector": pa.float64(),
}
)
data = pa.table(
{
"vector": [0.0],
},
schema=example,
)
expected = example
output = _infer_target_schema(data)
assert output == expected
@pytest.mark.parametrize(
"data",
[
[{"id": 1, "text": "hello"}],
pa.RecordBatch.from_pylist([{"id": 1, "text": "hello"}]),
pd.DataFrame({"id": [1], "text": ["hello"]}),
pl.DataFrame({"id": [1], "text": ["hello"]}),
],
ids=["rows", "pa.RecordBatch", "pd.DataFrame", "pl.DataFrame"],
)
@pytest.mark.parametrize(
"schema",
[
None,
pa.schema(
{
"id": pa.int32(),
"text": pa.string(),
"vector": pa.list_(pa.float32(), 10),
}
),
pa.schema(
{
"id": pa.int64(),
"text": pa.string(),
"vector": pa.list_(pa.float32(), 10),
"extra": pa.int64(),
}
),
],
ids=["infer", "explicit", "subschema"],
)
@pytest.mark.parametrize("with_embedding", [True, False])
def test_sanitize_data(
data,
schema: Optional[pa.Schema],
with_embedding: bool,
):
if with_embedding:
registry = EmbeddingFunctionRegistry.get_instance()
registry.register("test")(MockTextEmbeddingFunction)
conf = EmbeddingFunctionConfig(
source_column="text",
vector_column="vector",
function=MockTextEmbeddingFunction(),
)
metadata = registry.get_table_metadata([conf])
else:
metadata = None
if schema is not None:
to_remove = schema.get_field_index("extra")
if to_remove >= 0:
expected_schema = schema.remove(to_remove)
else:
expected_schema = schema
else:
expected_schema = pa.schema(
{
"id": pa.int64(),
"text": pa.large_utf8()
if isinstance(data, pl.DataFrame)
else pa.string(),
"vector": pa.list_(pa.float32(), 10),
}
)
if not with_embedding:
to_remove = expected_schema.get_field_index("vector")
if to_remove >= 0:
expected_schema = expected_schema.remove(to_remove)
expected = pa.table(
{
"id": [1],
"text": ["hello"],
"vector": [[0.0] * 10],
},
schema=expected_schema,
)
output_data = _sanitize_data(
data,
target_schema=schema,
metadata=metadata,
allow_subschema=True,
)
assert output_data == expected
def test_cast_to_target_schema():
original_schema = pa.schema(
{
"id": pa.int32(),
"struct": pa.struct(
[
pa.field("a", pa.int32()),
]
),
"vector": pa.list_(pa.float64()),
"vec1": pa.list_(pa.float64(), 2),
"vec2": pa.list_(pa.float32(), 2),
}
)
data = pa.table(
{
"id": [1],
"struct": [{"a": 1}],
"vector": [[0.0] * 2],
"vec1": [[0.0] * 2],
"vec2": [[0.0] * 2],
},
schema=original_schema,
)
target = pa.schema(
{
"id": pa.int64(),
"struct": pa.struct(
[
pa.field("a", pa.int64()),
]
),
"vector": pa.list_(pa.float32(), 2),
"vec1": pa.list_(pa.float32(), 2),
"vec2": pa.list_(pa.float32(), 2),
}
)
output = _cast_to_target_schema(data, target)
expected = pa.table(
{
"id": [1],
"struct": [{"a": 1}],
"vector": [[0.0] * 2],
"vec1": [[0.0] * 2],
"vec2": [[0.0] * 2],
},
schema=target,
)
# Data can be a subschema of the target
target = pa.schema(
{
"id": pa.int64(),
"struct": pa.struct(
[
pa.field("a", pa.int64()),
# Additional nested field
pa.field("b", pa.int64()),
]
),
"vector": pa.list_(pa.float32(), 2),
"vec1": pa.list_(pa.float32(), 2),
"vec2": pa.list_(pa.float32(), 2),
# Additional field
"extra": pa.int64(),
}
)
with pytest.raises(Exception):
_cast_to_target_schema(data, target)
output = _cast_to_target_schema(data, target, allow_subschema=True)
expected_schema = pa.schema(
{
"id": pa.int64(),
"struct": pa.struct(
[
pa.field("a", pa.int64()),
]
),
"vector": pa.list_(pa.float32(), 2),
"vec1": pa.list_(pa.float32(), 2),
"vec2": pa.list_(pa.float32(), 2),
}
)
expected = pa.table(
{
"id": [1],
"struct": [{"a": 1}],
"vector": [[0.0] * 2],
"vec1": [[0.0] * 2],
"vec2": [[0.0] * 2],
},
schema=expected_schema,
)
assert output == expected