feat(python)!: support inserting and upserting subschemas (#1965)
BREAKING CHANGE: For a field "vector", a list of integers will now be converted to a binary (uint8) vector instead of an f32 vector. Use float values to get f32 vectors.

* Adds proper support for inserting and upserting subsets of the full schema. I thought I had previously implemented this in #1827, but it turns out I had not tested it carefully enough.
* Refactors `_sanitize_data` and other utility functions to be simpler and not require `numpy` or `combine_chunks()`.
* Adds a new suite of unit tests to validate the sanitization utilities.

## Examples

```python
import pandas as pd
import lancedb

db = lancedb.connect("memory://demo")
initial_data = pd.DataFrame({
    "a": [1, 2, 3],
    "b": [4, 5, 6],
    "c": [7, 8, 9]
})
table = db.create_table("demo", initial_data)

# Insert a subschema
new_data = pd.DataFrame({"a": [10, 11]})
table.add(new_data)
table.to_pandas()
```

```
    a    b    c
0   1  4.0  7.0
1   2  5.0  8.0
2   3  6.0  9.0
3  10  NaN  NaN
4  11  NaN  NaN
```

```python
# Upsert a subschema
upsert_data = pd.DataFrame({
    "a": [3, 10, 15],
    "b": [6, 7, 8],
})
table.merge_insert(on="a").when_matched_update_all().when_not_matched_insert_all().execute(upsert_data)
table.to_pandas()
```

```
    a    b    c
0   1  4.0  7.0
1   2  5.0  8.0
2   3  6.0  9.0
3  10  7.0  NaN
4  11  NaN  NaN
5  15  8.0  NaN
```
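The breaking change above is easiest to see with a small sketch. This assumes the same in-memory connection style used in the examples; the table names and the printed type strings are illustrative, not taken from the test suite.

```python
import lancedb

db = lancedb.connect("memory://uint8-demo")

# A list of Python ints in the default "vector" column is now inferred as a
# binary (uint8) vector...
int_tbl = db.create_table("int_vectors", [{"vector": [1, 2, 3, 4, 5, 6, 7, 8]}])
print(int_tbl.schema.field("vector").type)    # expected: fixed_size_list<item: uint8>[8]

# ...while float values keep producing float32 vectors, as before.
float_tbl = db.create_table("float_vectors", [{"vector": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]}])
print(float_tbl.schema.field("vector").type)  # expected: fixed_size_list<item: float>[8]
```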
@@ -784,10 +784,6 @@ class AsyncConnection(object):
registry = EmbeddingFunctionRegistry.get_instance()
metadata = registry.get_table_metadata(embedding_functions)

data, schema = sanitize_create_table(
data, schema, metadata, on_bad_vectors, fill_value
)

# Defining defaults here and not in function prototype. In the future
# these defaults will move into rust so better to keep them as None.
if on_bad_vectors is None:

@@ -108,9 +108,14 @@ class EmbeddingFunctionRegistry:
An empty dict is returned if input is None or does not
contain b"embedding_functions".
"""
if metadata is None or b"embedding_functions" not in metadata:
if metadata is None:
return {}
# Look at both bytes and string keys, since we might use either
serialized = metadata.get(
b"embedding_functions", metadata.get("embedding_functions")
)
if serialized is None:
return {}
serialized = metadata[b"embedding_functions"]
raw_list = json.loads(serialized.decode("utf-8"))
return {
obj["vector_column"]: EmbeddingFunctionConfig(

@@ -472,7 +472,7 @@ class LanceQueryBuilder(ABC):
--------
>>> import lancedb
>>> db = lancedb.connect("./.lancedb")
>>> table = db.create_table("my_table", [{"vector": [99, 99]}])
>>> table = db.create_table("my_table", [{"vector": [99.0, 99]}])
>>> query = [100, 100]
>>> plan = table.search(query).explain_plan(True)
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE

@@ -25,7 +25,6 @@ from urllib.parse import urlparse
import lance
from lancedb.background_loop import LOOP
from .dependencies import _check_for_pandas
import numpy as np
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.fs as pa_fs

@@ -74,34 +73,17 @@ pl = safe_import_polars()
QueryType = Literal["vector", "fts", "hybrid", "auto"]


def _pd_schema_without_embedding_funcs(
schema: Optional[pa.Schema], columns: List[str]
) -> Optional[pa.Schema]:
"""Return a schema without any embedding function columns"""
if schema is None:
return None
embedding_functions = EmbeddingFunctionRegistry.get_instance().parse_functions(
schema.metadata
)
if not embedding_functions:
return schema
return pa.schema([field for field in schema if field.name in columns])


def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
def _into_pyarrow_table(data) -> pa.Table:
if _check_for_hugging_face(data):
# Huggingface datasets
from lance.dependencies import datasets

if isinstance(data, datasets.Dataset):
if schema is None:
schema = data.features.arrow_schema
schema = data.features.arrow_schema
return pa.Table.from_batches(data.data.to_batches(), schema=schema)
elif isinstance(data, datasets.dataset_dict.DatasetDict):
if schema is None:
schema = _schema_from_hf(data, schema)
schema = _schema_from_hf(data, schema)
return pa.Table.from_batches(_to_batches_with_split(data), schema=schema)

if isinstance(data, LanceModel):
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")

@@ -111,17 +93,15 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
|
||||
if isinstance(data, list):
|
||||
# convert to list of dict if data is a bunch of LanceModels
|
||||
if isinstance(data[0], LanceModel):
|
||||
if schema is None:
|
||||
schema = data[0].__class__.to_arrow_schema()
|
||||
schema = data[0].__class__.to_arrow_schema()
|
||||
data = [model_to_dict(d) for d in data]
|
||||
return pa.Table.from_pylist(data, schema=schema)
|
||||
elif isinstance(data[0], pa.RecordBatch):
|
||||
return pa.Table.from_batches(data, schema=schema)
|
||||
return pa.Table.from_batches(data)
|
||||
else:
|
||||
return pa.Table.from_pylist(data, schema=schema)
|
||||
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame): # type: ignore
|
||||
raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list())
|
||||
table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema)
|
||||
return pa.Table.from_pylist(data)
|
||||
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
|
||||
table = pa.Table.from_pandas(data, preserve_index=False)
|
||||
# Do not serialize Pandas metadata
|
||||
meta = table.schema.metadata if table.schema.metadata is not None else {}
|
||||
meta = {k: v for k, v in meta.items() if k != b"pandas"}
|
||||
@@ -143,8 +123,13 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
|
||||
and data.__class__.__name__ == "DataFrame"
|
||||
):
|
||||
return data.to_arrow()
|
||||
elif (
|
||||
type(data).__module__.startswith("polars")
|
||||
and data.__class__.__name__ == "LazyFrame"
|
||||
):
|
||||
return data.collect().to_arrow()
|
||||
elif isinstance(data, Iterable):
|
||||
return _process_iterator(data, schema)
|
||||
return _iterator_to_table(data)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unknown data type {type(data)}. "
|
||||
@@ -154,27 +139,172 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
|
||||
)
|
||||
|
||||
|
||||
def _iterator_to_table(data: Iterable) -> pa.Table:
|
||||
batches = []
|
||||
schema = None # Will get schema from first batch
|
||||
for batch in data:
|
||||
batch_table = _into_pyarrow_table(batch)
|
||||
if schema is not None:
|
||||
if batch_table.schema != schema:
|
||||
try:
|
||||
batch_table = batch_table.cast(schema)
|
||||
except pa.lib.ArrowInvalid:
|
||||
raise ValueError(
|
||||
f"Input iterator yielded a batch with schema that "
|
||||
f"does not match the schema of other batches.\n"
|
||||
f"Expected:\n{schema}\nGot:\n{batch_table.schema}"
|
||||
)
|
||||
else:
|
||||
# Use the first schema for the remainder of the batches
|
||||
schema = batch_table.schema
|
||||
batches.append(batch_table)
|
||||
|
||||
if batches:
|
||||
return pa.concat_tables(batches)
|
||||
else:
|
||||
raise ValueError("Input iterable is empty")
|
||||
|
||||
|
||||
def _sanitize_data(
|
||||
data: Any,
|
||||
schema: Optional[pa.Schema] = None,
|
||||
data: "DATA",
|
||||
target_schema: Optional[pa.Schema] = None,
|
||||
metadata: Optional[dict] = None, # embedding metadata
|
||||
on_bad_vectors: str = "error",
|
||||
on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> Tuple[pa.Table, pa.Schema]:
|
||||
data = _coerce_to_table(data, schema)
|
||||
*,
|
||||
allow_subschema: bool = False,
|
||||
) -> pa.Table:
|
||||
"""
|
||||
Handle input data, applying all standard transformations.
|
||||
|
||||
This includes:
|
||||
|
||||
* Converting the data to a PyArrow Table
|
||||
* Adding vector columns defined in the metadata
|
||||
* Adding embedding metadata into the schema
|
||||
* Casting the table to the target schema
|
||||
* Handling bad vectors
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target_schema : Optional[pa.Schema], default None
|
||||
The schema to cast the table to. This is typically the schema of the table
|
||||
if it already exists. Otherwise it might be a user-requested schema.
|
||||
allow_subschema : bool, default False
|
||||
If True, the input table is allowed to omit columns from the target schema.
|
||||
The target schema will be filtered to only include columns that are present
|
||||
in the input table before casting.
|
||||
metadata : Optional[dict], default None
|
||||
The embedding metadata to add to the schema.
|
||||
on_bad_vectors : Literal["error", "drop", "fill", "null"], default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
fill_value : float, default 0.0
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
All entries in the vector will be set to this value.
|
||||
"""
|
||||
# At this point, the table might not match the schema we are targeting:
|
||||
# 1. There might be embedding columns missing that will be added
|
||||
# in the add_embeddings step.
|
||||
# 2. If `allow_subschemas` is True, there might be columns missing.
|
||||
table = _into_pyarrow_table(data)
|
||||
|
||||
table = _append_vector_columns(table, target_schema, metadata=metadata)
|
||||
|
||||
# This happens before the cast so we can fix vector columns with
|
||||
# incorrect lengths before they are cast to FSL.
|
||||
table = _handle_bad_vectors(
|
||||
table,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
|
||||
if target_schema is None:
|
||||
target_schema = _infer_target_schema(table)
|
||||
|
||||
if metadata:
|
||||
data = _append_vector_col(data, metadata, schema)
|
||||
metadata.update(data.schema.metadata or {})
|
||||
data = data.replace_schema_metadata(metadata)
|
||||
new_metadata = target_schema.metadata or {}
|
||||
new_metadata = new_metadata.update(metadata)
|
||||
target_schema = target_schema.with_metadata(new_metadata)
|
||||
|
||||
# TODO improve the logics in _sanitize_schema
|
||||
data = _sanitize_schema(data, schema, on_bad_vectors, fill_value)
|
||||
if schema is None:
|
||||
schema = data.schema
|
||||
_validate_schema(target_schema)
|
||||
|
||||
_validate_schema(schema)
|
||||
return data, schema
|
||||
table = _cast_to_target_schema(table, target_schema, allow_subschema)
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def _cast_to_target_schema(
|
||||
table: pa.Table,
|
||||
target_schema: pa.Schema,
|
||||
allow_subschema: bool = False,
|
||||
) -> pa.Table:
|
||||
# pa.Table.cast expects field order not to be changed.
|
||||
# Lance doesn't care about field order, so we don't need to rearrange fields
|
||||
# to match the target schema. We just need to correctly cast the fields.
|
||||
if table.schema == target_schema:
|
||||
# Fast path when the schemas are already the same
|
||||
return table
|
||||
|
||||
fields = []
|
||||
for field in table.schema:
|
||||
target_field = target_schema.field(field.name)
|
||||
if target_field is None:
|
||||
raise ValueError(f"Field {field.name} not found in target schema")
|
||||
fields.append(target_field)
|
||||
reordered_schema = pa.schema(fields, metadata=target_schema.metadata)
|
||||
if not allow_subschema and len(reordered_schema) != len(target_schema):
|
||||
raise ValueError(
|
||||
"Input table has different number of columns than target schema"
|
||||
)
|
||||
|
||||
if allow_subschema and len(reordered_schema) != len(target_schema):
|
||||
fields = _infer_subschema(
|
||||
list(iter(table.schema)), list(iter(reordered_schema))
|
||||
)
|
||||
subschema = pa.schema(fields, metadata=target_schema.metadata)
|
||||
return table.cast(subschema)
|
||||
else:
|
||||
return table.cast(reordered_schema)
|
||||
|
||||
|
||||
def _infer_subschema(
|
||||
schema: List[pa.Field],
|
||||
reference_fields: List[pa.Field],
|
||||
) -> List[pa.Field]:
|
||||
"""
|
||||
Transform the list of fields so the types match the reference_fields.
|
||||
|
||||
The order of the fields is preserved.
|
||||
|
||||
``schema`` may have fewer fields than `reference_fields`, but it may not have
|
||||
more fields.
|
||||
|
||||
"""
|
||||
fields = []
|
||||
lookup = {f.name: f for f in reference_fields}
|
||||
for field in schema:
|
||||
reference = lookup.get(field.name)
|
||||
if reference is None:
|
||||
raise ValueError("Unexpected field in schema: {}".format(field))
|
||||
|
||||
if pa.types.is_struct(reference.type):
|
||||
new_type = pa.struct(
|
||||
_infer_subschema(
|
||||
field.type.fields,
|
||||
reference.type.fields,
|
||||
)
|
||||
)
|
||||
new_field = pa.field(
|
||||
field.name,
|
||||
new_type,
|
||||
reference.nullable,
|
||||
)
|
||||
else:
|
||||
new_field = reference
|
||||
|
||||
fields.append(new_field)
|
||||
|
||||
return fields
|
||||
|
||||
|
||||
def sanitize_create_table(
|
||||
@@ -193,13 +323,14 @@ def sanitize_create_table(
|
||||
if data is not None:
|
||||
if metadata is None and schema is not None:
|
||||
metadata = schema.metadata
|
||||
data, schema = _sanitize_data(
|
||||
data = _sanitize_data(
|
||||
data,
|
||||
schema,
|
||||
metadata=metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
schema = data.schema
|
||||
else:
|
||||
if schema is not None:
|
||||
data = pa.Table.from_pylist([], schema)
|
||||
@@ -211,6 +342,8 @@ def sanitize_create_table(
|
||||
|
||||
if metadata:
|
||||
schema = schema.with_metadata(metadata)
|
||||
# Need to apply metadata to the data as well
|
||||
data = data.replace_schema_metadata(metadata)
|
||||
|
||||
return data, schema
|
||||
|
||||
@@ -246,12 +379,22 @@ def _to_batches_with_split(data):
|
||||
yield b
|
||||
|
||||
|
||||
def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schema]):
|
||||
def _append_vector_columns(
|
||||
data: pa.Table,
|
||||
schema: Optional[pa.Schema] = None,
|
||||
*,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> pa.Table:
|
||||
"""
|
||||
Use the embedding function to automatically embed the source column and add the
|
||||
vector column to the table.
|
||||
Use the embedding function to automatically embed the source columns and add the
|
||||
vector columns to the table.
|
||||
"""
|
||||
if schema is None:
|
||||
metadata = metadata or {}
|
||||
else:
|
||||
metadata = schema.metadata or metadata or {}
|
||||
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
|
||||
|
||||
for vector_column, conf in functions.items():
|
||||
func = conf.function
|
||||
no_vector_column = vector_column not in data.column_names
|
||||
@@ -790,9 +933,9 @@ class Table(ABC):
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> data = [
|
||||
... {"x": 1, "vector": [1, 2]},
|
||||
... {"x": 2, "vector": [3, 4]},
|
||||
... {"x": 3, "vector": [5, 6]}
|
||||
... {"x": 1, "vector": [1.0, 2]},
|
||||
... {"x": 2, "vector": [3.0, 4]},
|
||||
... {"x": 3, "vector": [5.0, 6]}
|
||||
... ]
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
@@ -854,7 +997,7 @@ class Table(ABC):
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> import pandas as pd
|
||||
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
||||
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1.0, 2], [3, 4], [5, 6]]})
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> table.to_pandas()
|
||||
@@ -862,7 +1005,7 @@ class Table(ABC):
|
||||
0 1 [1.0, 2.0]
|
||||
1 2 [3.0, 4.0]
|
||||
2 3 [5.0, 6.0]
|
||||
>>> table.update(where="x = 2", values={"vector": [10, 10]})
|
||||
>>> table.update(where="x = 2", values={"vector": [10.0, 10]})
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
@@ -1880,9 +2023,9 @@ class LanceTable(Table):
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> data = [
|
||||
... {"x": 1, "vector": [1, 2]},
|
||||
... {"x": 2, "vector": [3, 4]},
|
||||
... {"x": 3, "vector": [5, 6]}
|
||||
... {"x": 1, "vector": [1.0, 2]},
|
||||
... {"x": 2, "vector": [3.0, 4]},
|
||||
... {"x": 3, "vector": [5.0, 6]}
|
||||
... ]
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
@@ -1971,7 +2114,7 @@ class LanceTable(Table):
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> import pandas as pd
|
||||
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
||||
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1.0, 2], [3, 4], [5, 6]]})
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
>>> table.to_pandas()
|
||||
@@ -1979,7 +2122,7 @@ class LanceTable(Table):
|
||||
0 1 [1.0, 2.0]
|
||||
1 2 [3.0, 4.0]
|
||||
2 3 [5.0, 6.0]
|
||||
>>> table.update(where="x = 2", values={"vector": [10, 10]})
|
||||
>>> table.update(where="x = 2", values={"vector": [10.0, 10]})
|
||||
>>> table.to_pandas()
|
||||
x vector
|
||||
0 1 [1.0, 2.0]
|
||||
@@ -2165,74 +2308,49 @@ class LanceTable(Table):
|
||||
LOOP.run(self._table.migrate_v2_manifest_paths())
|
||||
|
||||
|
||||
def _sanitize_schema(
|
||||
data: pa.Table,
|
||||
schema: pa.Schema = None,
|
||||
on_bad_vectors: str = "error",
|
||||
def _handle_bad_vectors(
|
||||
table: pa.Table,
|
||||
on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> pa.Table:
|
||||
"""Ensure that the table has the expected schema.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data: pa.Table
|
||||
The table to sanitize.
|
||||
schema: pa.Schema; optional
|
||||
The expected schema. If not provided, this just converts the
|
||||
vector column to fixed_size_list(float32) if necessary.
|
||||
on_bad_vectors: str, default "error"
|
||||
What to do if any of the vectors are not the same size or contains NaNs.
|
||||
One of "error", "drop", "fill", "null".
|
||||
fill_value: float, default 0.
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
if schema is not None:
|
||||
# cast the columns to the expected types
|
||||
data = data.combine_chunks()
|
||||
for field in schema:
|
||||
# TODO: we're making an assumption that fixed size list of 10 or more
|
||||
# is a vector column. This is definitely a bit hacky.
|
||||
likely_vector_col = (
|
||||
pa.types.is_fixed_size_list(field.type)
|
||||
and pa.types.is_float32(field.type.value_type)
|
||||
and field.type.list_size >= 10
|
||||
for field in table.schema:
|
||||
# They can provide a 'vector' column that isn't yet a FSL
|
||||
named_vector_col = (
|
||||
(
|
||||
pa.types.is_list(field.type)
|
||||
or pa.types.is_large_list(field.type)
|
||||
or pa.types.is_fixed_size_list(field.type)
|
||||
)
|
||||
is_default_vector_col = field.name == VECTOR_COLUMN_NAME
|
||||
if field.name in data.column_names and (
|
||||
likely_vector_col or is_default_vector_col
|
||||
):
|
||||
data = _sanitize_vector_column(
|
||||
data,
|
||||
vector_column_name=field.name,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
table_schema=schema,
|
||||
)
|
||||
return pa.Table.from_arrays(
|
||||
[data[name] for name in schema.names], schema=schema
|
||||
and pa.types.is_floating(field.type.value_type)
|
||||
and field.name == VECTOR_COLUMN_NAME
|
||||
)
|
||||
# TODO: we're making an assumption that fixed size list of 10 or more
|
||||
# is a vector column. This is definitely a bit hacky.
|
||||
likely_vector_col = (
|
||||
pa.types.is_fixed_size_list(field.type)
|
||||
and pa.types.is_floating(field.type.value_type)
|
||||
and (field.type.list_size >= 10)
|
||||
)
|
||||
|
||||
# just check the vector column
|
||||
if VECTOR_COLUMN_NAME in data.column_names:
|
||||
return _sanitize_vector_column(
|
||||
data,
|
||||
vector_column_name=VECTOR_COLUMN_NAME,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
if named_vector_col or likely_vector_col:
|
||||
table = _handle_bad_vector_column(
|
||||
table,
|
||||
vector_column_name=field.name,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
)
|
||||
|
||||
return data
|
||||
return table
|
||||
|
||||
|
||||
def _sanitize_vector_column(
|
||||
def _handle_bad_vector_column(
|
||||
data: pa.Table,
|
||||
vector_column_name: str,
|
||||
table_schema: Optional[pa.Schema] = None,
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
) -> pa.Table:
|
||||
"""
|
||||
Ensure that the vector column exists and has type fixed_size_list(float32)
|
||||
Ensure that the vector column exists and has type fixed_size_list(float)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -2246,141 +2364,118 @@ def _sanitize_vector_column(
|
||||
fill_value: float, default 0.0
|
||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||
"""
|
||||
# ChunkedArray is annoying to work with, so we combine chunks here
|
||||
vec_arr = data[vector_column_name].combine_chunks()
|
||||
if table_schema is not None:
|
||||
field = table_schema.field(vector_column_name)
|
||||
else:
|
||||
field = None
|
||||
typ = data[vector_column_name].type
|
||||
if pa.types.is_list(typ) or pa.types.is_large_list(typ):
|
||||
# if it's a variable size list array,
|
||||
# we make sure the dimensions are all the same
|
||||
has_jagged_ndims = len(vec_arr.values) % len(data) != 0
|
||||
if has_jagged_ndims:
|
||||
data = _sanitize_jagged(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
vec_arr = data[vector_column_name].combine_chunks()
|
||||
vec_arr = ensure_fixed_size_list(vec_arr)
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
elif not pa.types.is_fixed_size_list(vec_arr.type):
|
||||
raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
|
||||
vec_arr = data[vector_column_name]
|
||||
|
||||
if pa.types.is_float16(vec_arr.values.type):
|
||||
# Use numpy to check for NaNs, because as pyarrow does not have `is_nan`
|
||||
# kernel over f16 types yet.
|
||||
values_np = vec_arr.values.to_numpy(zero_copy_only=True)
|
||||
if np.isnan(values_np).any():
|
||||
data = _sanitize_nans(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
else:
|
||||
if (
|
||||
field is not None
|
||||
and not field.nullable
|
||||
and pc.any(pc.is_null(vec_arr.values)).as_py()
|
||||
) or (pc.any(pc.is_nan(vec_arr.values)).as_py()):
|
||||
data = _sanitize_nans(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
return data
|
||||
has_nan = has_nan_values(vec_arr)
|
||||
|
||||
|
||||
def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
|
||||
values = vec_arr.values
|
||||
if not (pa.types.is_float16(values.type) or pa.types.is_float32(values.type)):
|
||||
values = values.cast(pa.float32())
|
||||
if pa.types.is_fixed_size_list(vec_arr.type):
|
||||
list_size = vec_arr.type.list_size
|
||||
dim = vec_arr.type.list_size
|
||||
else:
|
||||
list_size = len(values) / len(vec_arr)
|
||||
vec_arr = pa.FixedSizeListArray.from_arrays(values, list_size)
|
||||
return vec_arr
|
||||
dim = _modal_list_size(vec_arr)
|
||||
has_wrong_dim = pc.not_equal(pc.list_value_length(vec_arr), dim)
|
||||
|
||||
has_bad_vectors = pc.any(has_nan).as_py() or pc.any(has_wrong_dim).as_py()
|
||||
|
||||
def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
|
||||
"""Sanitize jagged vectors."""
|
||||
if on_bad_vectors == "error":
|
||||
raise ValueError(
|
||||
f"Vector column {vector_column_name} has variable length vectors "
|
||||
"Set on_bad_vectors='drop' to remove them, or "
|
||||
"set on_bad_vectors='fill' and fill_value=<value> to replace them."
|
||||
)
|
||||
|
||||
lst_lengths = pc.list_value_length(vec_arr)
|
||||
ndims = pc.max(lst_lengths).as_py()
|
||||
correct_ndims = pc.equal(lst_lengths, ndims)
|
||||
|
||||
if on_bad_vectors == "fill":
|
||||
if fill_value is None:
|
||||
raise ValueError(
|
||||
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
||||
if has_bad_vectors:
|
||||
is_bad = pc.or_(has_nan, has_wrong_dim)
|
||||
if on_bad_vectors == "error":
|
||||
if pc.any(has_wrong_dim).as_py():
|
||||
raise ValueError(
|
||||
f"Vector column '{vector_column_name}' has variable length "
|
||||
"vectors. Set on_bad_vectors='drop' to remove them, "
|
||||
"set on_bad_vectors='fill' and fill_value=<value> to replace them, "
|
||||
"or set on_bad_vectors='null' to replace them with null."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Vector column '{vector_column_name}' has NaNs. "
|
||||
"Set on_bad_vectors='drop' to remove them, "
|
||||
"set on_bad_vectors='fill' and fill_value=<value> to replace them, "
|
||||
"or set on_bad_vectors='null' to replace them with null."
|
||||
)
|
||||
elif on_bad_vectors == "null":
|
||||
vec_arr = pc.if_else(
|
||||
is_bad,
|
||||
pa.scalar(None),
|
||||
vec_arr,
|
||||
)
|
||||
fill_arr = pa.scalar([float(fill_value)] * ndims)
|
||||
vec_arr = pc.if_else(correct_ndims, vec_arr, fill_arr)
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
elif on_bad_vectors == "drop":
|
||||
data = data.filter(correct_ndims)
|
||||
elif on_bad_vectors == "null":
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column_name),
|
||||
vector_column_name,
|
||||
pc.if_else(correct_ndims, vec_arr, pa.scalar(None)),
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
def _sanitize_nans(
|
||||
data,
|
||||
fill_value,
|
||||
on_bad_vectors,
|
||||
vec_arr: pa.FixedSizeListArray,
|
||||
vector_column_name: str,
|
||||
):
|
||||
"""Sanitize NaNs in vectors"""
|
||||
assert pa.types.is_fixed_size_list(vec_arr.type)
|
||||
if on_bad_vectors == "error":
|
||||
raise ValueError(
|
||||
f"Vector column {vector_column_name} has NaNs. "
|
||||
"Set on_bad_vectors='drop' to remove them, or "
|
||||
"set on_bad_vectors='fill' and fill_value=<value> to replace them. "
|
||||
"Or set on_bad_vectors='null' to replace them with null."
|
||||
)
|
||||
elif on_bad_vectors == "fill":
|
||||
if fill_value is None:
|
||||
raise ValueError(
|
||||
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
||||
elif on_bad_vectors == "drop":
|
||||
data = data.filter(pc.invert(is_bad))
|
||||
vec_arr = data[vector_column_name]
|
||||
elif on_bad_vectors == "fill":
|
||||
if fill_value is None:
|
||||
raise ValueError(
|
||||
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
||||
)
|
||||
vec_arr = pc.if_else(
|
||||
is_bad,
|
||||
pa.scalar([fill_value] * dim),
|
||||
vec_arr,
|
||||
)
|
||||
fill_value = float(fill_value)
|
||||
values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values)
|
||||
ndims = len(vec_arr[0])
|
||||
vec_arr = pa.FixedSizeListArray.from_arrays(values, ndims)
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
elif on_bad_vectors == "drop":
|
||||
# Drop is very slow to be able to filter out NaNs in a fixed size list array
|
||||
np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False))
|
||||
np_arr = np_arr.reshape(-1, vec_arr.type.list_size)
|
||||
not_nulls = np.any(np_arr, axis=1)
|
||||
data = data.filter(~not_nulls)
|
||||
elif on_bad_vectors == "null":
|
||||
# null = pa.nulls(len(vec_arr)).cast(vec_arr.type)
|
||||
# values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values)
|
||||
np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False))
|
||||
np_arr = np_arr.reshape(-1, vec_arr.type.list_size)
|
||||
no_nans = np.any(np_arr, axis=1)
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column_name),
|
||||
vector_column_name,
|
||||
pc.if_else(no_nans, vec_arr, pa.scalar(None)),
|
||||
)
|
||||
return data
|
||||
else:
|
||||
raise ValueError(f"Invalid value for on_bad_vectors: {on_bad_vectors}")
|
||||
|
||||
position = data.column_names.index(vector_column_name)
|
||||
return data.set_column(position, vector_column_name, vec_arr)
|
||||
|
||||
|
||||
def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray:
|
||||
if isinstance(arr, pa.ChunkedArray):
|
||||
values = pa.chunked_array([chunk.flatten() for chunk in arr.chunks])
|
||||
else:
|
||||
values = arr.flatten()
|
||||
if pa.types.is_float16(values.type):
|
||||
# is_nan isn't yet implemented for f16, so we cast to f32
|
||||
# https://github.com/apache/arrow/issues/45083
|
||||
values_has_nan = pc.is_nan(values.cast(pa.float32()))
|
||||
else:
|
||||
values_has_nan = pc.is_nan(values)
|
||||
values_indices = pc.list_parent_indices(arr)
|
||||
has_nan_indices = pc.unique(pc.filter(values_indices, values_has_nan))
|
||||
indices = pa.array(range(len(arr)), type=pa.uint32())
|
||||
return pc.is_in(indices, has_nan_indices)
|
||||
|
||||
|
||||
def _infer_target_schema(table: pa.Table) -> pa.Schema:
|
||||
schema = table.schema
|
||||
|
||||
for i, field in enumerate(schema):
|
||||
if (
|
||||
field.name == VECTOR_COLUMN_NAME
|
||||
and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
|
||||
and pa.types.is_floating(field.type.value_type)
|
||||
):
|
||||
# Use the most common length of the list as the dimensions
|
||||
dim = _modal_list_size(table.column(i))
|
||||
|
||||
new_field = pa.field(
|
||||
VECTOR_COLUMN_NAME,
|
||||
pa.list_(pa.float32(), dim),
|
||||
nullable=field.nullable,
|
||||
)
|
||||
|
||||
schema = schema.set(i, new_field)
|
||||
elif (
|
||||
field.name == VECTOR_COLUMN_NAME
|
||||
and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
|
||||
and pa.types.is_integer(field.type.value_type)
|
||||
):
|
||||
# Use the most common length of the list as the dimensions
|
||||
dim = _modal_list_size(table.column(i))
|
||||
new_field = pa.field(
|
||||
VECTOR_COLUMN_NAME,
|
||||
pa.list_(pa.uint8(), dim),
|
||||
nullable=field.nullable,
|
||||
)
|
||||
|
||||
schema = schema.set(i, new_field)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def _modal_list_size(arr: Union[pa.ListArray, pa.ChunkedArray]) -> int:
|
||||
# Use the most common length of the list as the dimensions
|
||||
return pc.mode(pc.list_value_length(arr))[0].as_py()["mode"]
|
||||
|
||||
|
||||
def _validate_schema(schema: pa.Schema):
|
||||
@@ -2410,28 +2505,6 @@ def _validate_metadata(metadata: dict):
|
||||
_validate_metadata(v)
|
||||
|
||||
|
||||
def _process_iterator(data: Iterable, schema: Optional[pa.Schema] = None) -> pa.Table:
|
||||
batches = []
|
||||
for batch in data:
|
||||
batch_table = _coerce_to_table(batch, schema)
|
||||
if schema is not None:
|
||||
if batch_table.schema != schema:
|
||||
try:
|
||||
batch_table = batch_table.cast(schema)
|
||||
except pa.lib.ArrowInvalid: # type: ignore
|
||||
raise ValueError(
|
||||
f"Input iterator yielded a batch with schema that "
|
||||
f"does not match the expected schema.\nExpected:\n{schema}\n"
|
||||
f"Got:\n{batch_table.schema}"
|
||||
)
|
||||
batches.append(batch_table)
|
||||
|
||||
if batches:
|
||||
return pa.concat_tables(batches)
|
||||
else:
|
||||
raise ValueError("Input iterable is empty")
|
||||
|
||||
|
||||
class AsyncTable:
|
||||
"""
|
||||
An AsyncTable is a collection of Records in a LanceDB Database.
|
||||
@@ -2678,16 +2751,17 @@ class AsyncTable:
|
||||
on_bad_vectors = "error"
|
||||
if fill_value is None:
|
||||
fill_value = 0.0
|
||||
table_and_schema: Tuple[pa.Table, pa.Schema] = _sanitize_data(
|
||||
data = _sanitize_data(
|
||||
data,
|
||||
schema,
|
||||
metadata=schema.metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
allow_subschema=True,
|
||||
)
|
||||
tbl, schema = table_and_schema
|
||||
if isinstance(tbl, pa.Table):
|
||||
data = pa.RecordBatchReader.from_batches(schema, tbl.to_batches())
|
||||
if isinstance(data, pa.Table):
|
||||
data = data.to_reader()
|
||||
|
||||
await self._inner.add(data, mode or "append")
|
||||
|
||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||
@@ -2822,12 +2896,13 @@ class AsyncTable:
|
||||
on_bad_vectors = "error"
|
||||
if fill_value is None:
|
||||
fill_value = 0.0
|
||||
data, _ = _sanitize_data(
|
||||
data = _sanitize_data(
|
||||
new_data,
|
||||
schema,
|
||||
metadata=schema.metadata,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=fill_value,
|
||||
allow_subschema=True,
|
||||
)
|
||||
if isinstance(data, pa.Table):
|
||||
data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
|
||||
@@ -2862,9 +2937,9 @@ class AsyncTable:
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> data = [
|
||||
... {"x": 1, "vector": [1, 2]},
|
||||
... {"x": 2, "vector": [3, 4]},
|
||||
... {"x": 3, "vector": [5, 6]}
|
||||
... {"x": 1, "vector": [1.0, 2]},
|
||||
... {"x": 2, "vector": [3.0, 4]},
|
||||
... {"x": 3, "vector": [5.0, 6]}
|
||||
... ]
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", data)
|
||||
|
||||
@@ -223,9 +223,7 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
|
||||
vector_col_count = 0
|
||||
for field_name in schema.names:
|
||||
field = schema.field(field_name)
|
||||
if pa.types.is_fixed_size_list(field.type) and pa.types.is_floating(
|
||||
field.type.value_type
|
||||
):
|
||||
if pa.types.is_fixed_size_list(field.type):
|
||||
vector_col_count += 1
|
||||
if vector_col_count > 1:
|
||||
raise ValueError(
|
||||
|
||||
@@ -21,7 +21,7 @@ def test_binary_vector():
|
||||
]
|
||||
tbl = db.create_table("my_binary_vectors", data=data)
|
||||
query = np.random.randint(0, 256, size=16)
|
||||
tbl.search(query).to_arrow()
|
||||
tbl.search(query).metric("hamming").to_arrow()
|
||||
# --8<-- [end:sync_binary_vector]
|
||||
db.drop_table("my_binary_vectors")
|
||||
|
||||
@@ -39,6 +39,6 @@ async def test_binary_vector_async():
|
||||
]
|
||||
tbl = await db.create_table("my_binary_vectors", data=data)
|
||||
query = np.random.randint(0, 256, size=16)
|
||||
await tbl.query().nearest_to(query).to_arrow()
|
||||
await tbl.query().nearest_to(query).distance_type("hamming").to_arrow()
|
||||
# --8<-- [end:async_binary_vector]
|
||||
await db.drop_table("my_binary_vectors")
|
||||
|
||||
@@ -118,9 +118,9 @@ def test_scalar_index():
|
||||
# --8<-- [end:search_with_scalar_index]
|
||||
# --8<-- [start:vector_search_with_scalar_index]
|
||||
data = [
|
||||
{"book_id": 1, "vector": [1, 2]},
|
||||
{"book_id": 2, "vector": [3, 4]},
|
||||
{"book_id": 3, "vector": [5, 6]},
|
||||
{"book_id": 1, "vector": [1.0, 2]},
|
||||
{"book_id": 2, "vector": [3.0, 4]},
|
||||
{"book_id": 3, "vector": [5.0, 6]},
|
||||
]
|
||||
|
||||
table = db.create_table("book_with_embeddings", data)
|
||||
@@ -156,9 +156,9 @@ async def test_scalar_index_async():
|
||||
# --8<-- [end:search_with_scalar_index_async]
|
||||
# --8<-- [start:vector_search_with_scalar_index_async]
|
||||
data = [
|
||||
{"book_id": 1, "vector": [1, 2]},
|
||||
{"book_id": 2, "vector": [3, 4]},
|
||||
{"book_id": 3, "vector": [5, 6]},
|
||||
{"book_id": 1, "vector": [1.0, 2]},
|
||||
{"book_id": 2, "vector": [3.0, 4]},
|
||||
{"book_id": 3, "vector": [5.0, 6]},
|
||||
]
|
||||
async_tbl = await async_db.create_table("book_with_embeddings_async", data)
|
||||
(await async_tbl.query().where("book_id != 3").nearest_to([1, 2]).to_pandas())
|
||||
|
||||
@@ -198,7 +198,6 @@ def test_embedding_function_with_pandas(tmp_path):
|
||||
{
|
||||
"text": ["hello world", "goodbye world"],
|
||||
"val": [1, 2],
|
||||
"not-used": ["s1", "s3"],
|
||||
}
|
||||
)
|
||||
db = lancedb.connect(tmp_path)
|
||||
@@ -212,7 +211,6 @@ def test_embedding_function_with_pandas(tmp_path):
|
||||
{
|
||||
"text": ["extra", "more"],
|
||||
"val": [4, 5],
|
||||
"misc-col": ["s1", "s3"],
|
||||
}
|
||||
)
|
||||
tbl.add(df)
|
||||
|
||||
@@ -242,8 +242,8 @@ def test_add_subschema(mem_db: DBConnection):
|
||||
|
||||
data = {"price": 10.0, "item": "foo"}
|
||||
table.add([data])
|
||||
data = {"price": 2.0, "vector": [3.1, 4.1]}
|
||||
table.add([data])
|
||||
data = pd.DataFrame({"price": [2.0], "vector": [[3.1, 4.1]]})
|
||||
table.add(data)
|
||||
data = {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"}
|
||||
table.add([data])
|
||||
|
||||
@@ -259,7 +259,7 @@ def test_add_subschema(mem_db: DBConnection):
|
||||
|
||||
data = {"item": "foo"}
|
||||
# We can't omit a column if it's not nullable
|
||||
with pytest.raises(RuntimeError, match="Invalid user input"):
|
||||
with pytest.raises(RuntimeError, match="Append with different schema"):
|
||||
table.add([data])
|
||||
|
||||
# We can add it if we make the column nullable
|
||||
@@ -292,6 +292,7 @@ def test_add_nullability(mem_db: DBConnection):
|
||||
]
|
||||
)
|
||||
table = mem_db.create_table("test", schema=schema)
|
||||
assert table.schema.field("vector").nullable is False
|
||||
|
||||
nullable_schema = pa.schema(
|
||||
[
|
||||
@@ -320,7 +321,10 @@ def test_add_nullability(mem_db: DBConnection):
|
||||
schema=nullable_schema,
|
||||
)
|
||||
# We can't add nullable schema if it contains nulls
|
||||
with pytest.raises(Exception, match="Vector column vector has NaNs"):
|
||||
with pytest.raises(
|
||||
Exception,
|
||||
match="Casting field 'vector' with null values to non-nullable",
|
||||
):
|
||||
table.add(data)
|
||||
|
||||
# But we can make it nullable
|
||||
@@ -776,6 +780,38 @@ def test_merge_insert(mem_db: DBConnection):
|
||||
assert table.to_arrow().sort_by("a") == expected
|
||||
|
||||
|
||||
# We vary the data format because there are slight differences in how
|
||||
# subschemas are handled in different formats
|
||||
@pytest.mark.parametrize(
|
||||
"data_format",
|
||||
[
|
||||
lambda table: table,
|
||||
lambda table: table.to_pandas(),
|
||||
lambda table: table.to_pylist(),
|
||||
],
|
||||
ids=["pa.Table", "pd.DataFrame", "rows"],
|
||||
)
|
||||
def test_merge_insert_subschema(mem_db: DBConnection, data_format):
|
||||
initial_data = pa.table(
|
||||
{"id": range(3), "a": [1.0, 2.0, 3.0], "c": ["x", "x", "x"]}
|
||||
)
|
||||
table = mem_db.create_table("my_table", data=initial_data)
|
||||
|
||||
new_data = pa.table({"id": [2, 3], "c": ["y", "y"]})
|
||||
new_data = data_format(new_data)
|
||||
(
|
||||
table.merge_insert(on="id")
|
||||
.when_matched_update_all()
|
||||
.when_not_matched_insert_all()
|
||||
.execute(new_data)
|
||||
)
|
||||
|
||||
expected = pa.table(
|
||||
{"id": [0, 1, 2, 3], "a": [1.0, 2.0, 3.0, None], "c": ["x", "x", "y", "y"]}
|
||||
)
|
||||
assert table.to_arrow().sort_by("id") == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merge_insert_async(mem_db_async: AsyncConnection):
|
||||
data = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]})
|
||||
|
||||
@@ -13,10 +13,27 @@
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
from typing import Optional
|
||||
|
||||
import lance
|
||||
from lancedb.conftest import MockTextEmbeddingFunction
|
||||
from lancedb.embeddings.base import EmbeddingFunctionConfig
|
||||
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
|
||||
from lancedb.table import (
|
||||
_append_vector_columns,
|
||||
_cast_to_target_schema,
|
||||
_handle_bad_vectors,
|
||||
_into_pyarrow_table,
|
||||
_sanitize_data,
|
||||
_infer_target_schema,
|
||||
)
|
||||
import pyarrow as pa
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
import pytest
|
||||
import lancedb
|
||||
from lancedb.util import get_uri_scheme, join_uri, value_to_sql
|
||||
from utils import exception_output
|
||||
|
||||
|
||||
def test_normalize_uri():
|
||||
@@ -111,3 +128,460 @@ def test_value_to_sql_string(tmp_path):
|
||||
for value in values:
|
||||
table.update(where=f"search = {value_to_sql(value)}", values={"replace": value})
|
||||
assert table.to_pandas().query("search == @value")["replace"].item() == value
|
||||
|
||||
|
||||
def test_append_vector_columns():
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
registry.register("test")(MockTextEmbeddingFunction)
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text",
|
||||
vector_column="vector",
|
||||
function=MockTextEmbeddingFunction(),
|
||||
)
|
||||
metadata = registry.get_table_metadata([conf])
|
||||
|
||||
schema = pa.schema(
|
||||
{
|
||||
"text": pa.string(),
|
||||
"vector": pa.list_(pa.float64(), 10),
|
||||
}
|
||||
)
|
||||
data = pa.table(
|
||||
{
|
||||
"text": ["hello"],
|
||||
"vector": [None], # Replaces null
|
||||
},
|
||||
schema=schema,
|
||||
)
|
||||
output = _append_vector_columns(
|
||||
data,
|
||||
schema, # metadata passed separate from schema
|
||||
metadata=metadata,
|
||||
)
|
||||
assert output.schema == schema
|
||||
assert output["vector"].null_count == 0
|
||||
|
||||
# Adds if missing
|
||||
data = pa.table({"text": ["hello"]})
|
||||
output = _append_vector_columns(
|
||||
data,
|
||||
schema.with_metadata(metadata),
|
||||
)
|
||||
assert output.schema == schema
|
||||
assert output["vector"].null_count == 0
|
||||
|
||||
# doesn't embed if already there
|
||||
data = pa.table(
|
||||
{
|
||||
"text": ["hello"],
|
||||
"vector": [[42.0] * 10],
|
||||
},
|
||||
schema=schema,
|
||||
)
|
||||
output = _append_vector_columns(
|
||||
data,
|
||||
schema.with_metadata(metadata),
|
||||
)
|
||||
assert output == data # No change
|
||||
|
||||
# No provided schema
|
||||
data = pa.table(
|
||||
{
|
||||
"text": ["hello"],
|
||||
}
|
||||
)
|
||||
output = _append_vector_columns(
|
||||
data,
|
||||
metadata=metadata,
|
||||
)
|
||||
expected_schema = pa.schema(
|
||||
{
|
||||
"text": pa.string(),
|
||||
"vector": pa.list_(pa.float32(), 10),
|
||||
}
|
||||
)
|
||||
assert output.schema == expected_schema
|
||||
assert output["vector"].null_count == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("on_bad_vectors", ["error", "drop", "fill", "null"])
|
||||
def test_handle_bad_vectors_jagged(on_bad_vectors):
|
||||
vector = pa.array([[1.0, 2.0], [3.0], [4.0, 5.0]])
|
||||
schema = pa.schema({"vector": pa.list_(pa.float64())})
|
||||
data = pa.table({"vector": vector}, schema=schema)
|
||||
|
||||
if on_bad_vectors == "error":
|
||||
with pytest.raises(ValueError) as e:
|
||||
output = _handle_bad_vectors(
|
||||
data,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
)
|
||||
output = exception_output(e)
|
||||
assert output == (
|
||||
"ValueError: Vector column 'vector' has variable length vectors. Set "
|
||||
"on_bad_vectors='drop' to remove them, set on_bad_vectors='fill' "
|
||||
"and fill_value=<value> to replace them, or set on_bad_vectors='null' "
|
||||
"to replace them with null."
|
||||
)
|
||||
return
|
||||
else:
|
||||
output = _handle_bad_vectors(
|
||||
data,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=42.0,
|
||||
)
|
||||
|
||||
if on_bad_vectors == "drop":
|
||||
expected = pa.array([[1.0, 2.0], [4.0, 5.0]])
|
||||
elif on_bad_vectors == "fill":
|
||||
expected = pa.array([[1.0, 2.0], [42.0, 42.0], [4.0, 5.0]])
|
||||
elif on_bad_vectors == "null":
|
||||
expected = pa.array([[1.0, 2.0], None, [4.0, 5.0]])
|
||||
|
||||
assert output["vector"].combine_chunks() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("on_bad_vectors", ["error", "drop", "fill", "null"])
|
||||
def test_handle_bad_vectors_nan(on_bad_vectors):
|
||||
vector = pa.array([[1.0, float("nan")], [3.0, 4.0]])
|
||||
data = pa.table({"vector": vector})
|
||||
|
||||
if on_bad_vectors == "error":
|
||||
with pytest.raises(ValueError) as e:
|
||||
output = _handle_bad_vectors(
|
||||
data,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
)
|
||||
output = exception_output(e)
|
||||
assert output == (
|
||||
"ValueError: Vector column 'vector' has NaNs. Set "
|
||||
"on_bad_vectors='drop' to remove them, set on_bad_vectors='fill' "
|
||||
"and fill_value=<value> to replace them, or set on_bad_vectors='null' "
|
||||
"to replace them with null."
|
||||
)
|
||||
return
|
||||
else:
|
||||
output = _handle_bad_vectors(
|
||||
data,
|
||||
on_bad_vectors=on_bad_vectors,
|
||||
fill_value=42.0,
|
||||
)
|
||||
|
||||
if on_bad_vectors == "drop":
|
||||
expected = pa.array([[3.0, 4.0]])
|
||||
elif on_bad_vectors == "fill":
|
||||
expected = pa.array([[42.0, 42.0], [3.0, 4.0]])
|
||||
elif on_bad_vectors == "null":
|
||||
expected = pa.array([None, [3.0, 4.0]])
|
||||
|
||||
assert output["vector"].combine_chunks() == expected
|
||||
|
||||
|
||||
def test_handle_bad_vectors_noop():
|
||||
# ChunkedArray should be preserved as-is
|
||||
vector = pa.chunked_array(
|
||||
[[[1.0, 2.0], [3.0, 4.0]]], type=pa.list_(pa.float64(), 2)
|
||||
)
|
||||
data = pa.table({"vector": vector})
|
||||
output = _handle_bad_vectors(data)
|
||||
assert output["vector"] == vector
|
||||
|
||||
|
||||
class TestModel(lancedb.pydantic.LanceModel):
|
||||
a: Optional[int]
|
||||
b: Optional[int]
|
||||
|
||||
|
||||
# TODO: huggingface,
|
||||
@pytest.mark.parametrize(
|
||||
"data",
|
||||
[
|
||||
lambda: [{"a": 1, "b": 2}],
|
||||
lambda: pa.RecordBatch.from_pylist([{"a": 1, "b": 2}]),
|
||||
lambda: pa.table({"a": [1], "b": [2]}),
|
||||
lambda: pa.table({"a": [1], "b": [2]}).to_reader(),
|
||||
lambda: iter(pa.table({"a": [1], "b": [2]}).to_batches()),
|
||||
lambda: (
|
||||
lance.write_dataset(
|
||||
pa.table({"a": [1], "b": [2]}),
|
||||
"memory://test",
|
||||
)
|
||||
),
|
||||
lambda: (
|
||||
lance.write_dataset(
|
||||
pa.table({"a": [1], "b": [2]}),
|
||||
"memory://test",
|
||||
).scanner()
|
||||
),
|
||||
lambda: pd.DataFrame({"a": [1], "b": [2]}),
|
||||
lambda: pl.DataFrame({"a": [1], "b": [2]}),
|
||||
lambda: pl.LazyFrame({"a": [1], "b": [2]}),
|
||||
lambda: [TestModel(a=1, b=2)],
|
||||
],
|
||||
ids=[
|
||||
"rows",
|
||||
"pa.RecordBatch",
|
||||
"pa.Table",
|
||||
"pa.RecordBatchReader",
|
||||
"batch_iter",
|
||||
"lance.LanceDataset",
|
||||
"lance.LanceScanner",
|
||||
"pd.DataFrame",
|
||||
"pl.DataFrame",
|
||||
"pl.LazyFrame",
|
||||
"pydantic",
|
||||
],
|
||||
)
|
||||
def test_into_pyarrow_table(data):
|
||||
expected = pa.table({"a": [1], "b": [2]})
|
||||
output = _into_pyarrow_table(data())
|
||||
assert output == expected
|
||||
|
||||
|
||||
def test_infer_target_schema():
|
||||
example = pa.schema(
|
||||
{
|
||||
"vec1": pa.list_(pa.float64(), 2),
|
||||
"vector": pa.list_(pa.float64()),
|
||||
}
|
||||
)
|
||||
data = pa.table(
|
||||
{
|
||||
"vec1": [[0.0] * 2],
|
||||
"vector": [[0.0] * 2],
|
||||
},
|
||||
schema=example,
|
||||
)
|
||||
expected = pa.schema(
|
||||
{
|
||||
"vec1": pa.list_(pa.float64(), 2),
|
||||
"vector": pa.list_(pa.float32(), 2),
|
||||
}
|
||||
)
|
||||
output = _infer_target_schema(data)
|
||||
assert output == expected
|
||||
|
||||
# Handle large list and use modal size
|
||||
# Most vectors are of length 2, so we should infer that as the target dimension
|
||||
example = pa.schema(
|
||||
{
|
||||
"vector": pa.large_list(pa.float64()),
|
||||
}
|
||||
)
|
||||
data = pa.table(
|
||||
{
|
||||
"vector": [[0.0] * 2, [0.0], [0.0] * 2],
|
||||
},
|
||||
schema=example,
|
||||
)
|
||||
expected = pa.schema(
|
||||
{
|
||||
"vector": pa.list_(pa.float32(), 2),
|
||||
}
|
||||
)
|
||||
output = _infer_target_schema(data)
|
||||
assert output == expected
|
||||
|
||||
# ignore if not list
|
||||
example = pa.schema(
|
||||
{
|
||||
"vector": pa.float64(),
|
||||
}
|
||||
)
|
||||
data = pa.table(
|
||||
{
|
||||
"vector": [0.0],
|
||||
},
|
||||
schema=example,
|
||||
)
|
||||
expected = example
|
||||
output = _infer_target_schema(data)
|
||||
assert output == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data",
|
||||
[
|
||||
[{"id": 1, "text": "hello"}],
|
||||
pa.RecordBatch.from_pylist([{"id": 1, "text": "hello"}]),
|
||||
pd.DataFrame({"id": [1], "text": ["hello"]}),
|
||||
pl.DataFrame({"id": [1], "text": ["hello"]}),
|
||||
],
|
||||
ids=["rows", "pa.RecordBatch", "pd.DataFrame", "pl.DataFrame"],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"schema",
|
||||
[
|
||||
None,
|
||||
pa.schema(
|
||||
{
|
||||
"id": pa.int32(),
|
||||
"text": pa.string(),
|
||||
"vector": pa.list_(pa.float32(), 10),
|
||||
}
|
||||
),
|
||||
pa.schema(
|
||||
{
|
||||
"id": pa.int64(),
|
||||
"text": pa.string(),
|
||||
"vector": pa.list_(pa.float32(), 10),
|
||||
"extra": pa.int64(),
|
||||
}
|
||||
),
|
||||
],
|
||||
ids=["infer", "explicit", "subschema"],
|
||||
)
|
||||
@pytest.mark.parametrize("with_embedding", [True, False])
|
||||
def test_sanitize_data(
|
||||
data,
|
||||
schema: Optional[pa.Schema],
|
||||
with_embedding: bool,
|
||||
):
|
||||
if with_embedding:
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
registry.register("test")(MockTextEmbeddingFunction)
|
||||
conf = EmbeddingFunctionConfig(
|
||||
source_column="text",
|
||||
vector_column="vector",
|
||||
function=MockTextEmbeddingFunction(),
|
||||
)
|
||||
metadata = registry.get_table_metadata([conf])
|
||||
else:
|
||||
metadata = None
|
||||
|
||||
if schema is not None:
|
||||
to_remove = schema.get_field_index("extra")
|
||||
if to_remove >= 0:
|
||||
expected_schema = schema.remove(to_remove)
|
||||
else:
|
||||
expected_schema = schema
|
||||
else:
|
||||
expected_schema = pa.schema(
|
||||
{
|
||||
"id": pa.int64(),
|
||||
"text": pa.large_utf8()
|
||||
if isinstance(data, pl.DataFrame)
|
||||
else pa.string(),
|
||||
"vector": pa.list_(pa.float32(), 10),
|
||||
}
|
||||
)
|
||||
|
||||
if not with_embedding:
|
||||
to_remove = expected_schema.get_field_index("vector")
|
||||
if to_remove >= 0:
|
||||
expected_schema = expected_schema.remove(to_remove)
|
||||
|
||||
expected = pa.table(
|
||||
{
|
||||
"id": [1],
|
||||
"text": ["hello"],
|
||||
"vector": [[0.0] * 10],
|
||||
},
|
||||
schema=expected_schema,
|
||||
)
|
||||
|
||||
output_data = _sanitize_data(
|
||||
data,
|
||||
target_schema=schema,
|
||||
metadata=metadata,
|
||||
allow_subschema=True,
|
||||
)
|
||||
|
||||
assert output_data == expected
|
||||
|
||||
|
||||
def test_cast_to_target_schema():
|
||||
original_schema = pa.schema(
|
||||
{
|
||||
"id": pa.int32(),
|
||||
"struct": pa.struct(
|
||||
[
|
||||
pa.field("a", pa.int32()),
|
||||
]
|
||||
),
|
||||
"vector": pa.list_(pa.float64()),
|
||||
"vec1": pa.list_(pa.float64(), 2),
|
||||
"vec2": pa.list_(pa.float32(), 2),
|
||||
}
|
||||
)
|
||||
data = pa.table(
|
||||
{
|
||||
"id": [1],
|
||||
"struct": [{"a": 1}],
|
||||
"vector": [[0.0] * 2],
|
||||
"vec1": [[0.0] * 2],
|
||||
"vec2": [[0.0] * 2],
|
||||
},
|
||||
schema=original_schema,
|
||||
)
|
||||
|
||||
target = pa.schema(
|
||||
{
|
||||
"id": pa.int64(),
|
||||
"struct": pa.struct(
|
||||
[
|
||||
pa.field("a", pa.int64()),
|
||||
]
|
||||
),
|
||||
"vector": pa.list_(pa.float32(), 2),
|
||||
"vec1": pa.list_(pa.float32(), 2),
|
||||
"vec2": pa.list_(pa.float32(), 2),
|
||||
}
|
||||
)
|
||||
output = _cast_to_target_schema(data, target)
|
||||
expected = pa.table(
|
||||
{
|
||||
"id": [1],
|
||||
"struct": [{"a": 1}],
|
||||
"vector": [[0.0] * 2],
|
||||
"vec1": [[0.0] * 2],
|
||||
"vec2": [[0.0] * 2],
|
||||
},
|
||||
schema=target,
|
||||
)
|
||||
|
||||
# Data can be a subschema of the target
|
||||
target = pa.schema(
|
||||
{
|
||||
"id": pa.int64(),
|
||||
"struct": pa.struct(
|
||||
[
|
||||
pa.field("a", pa.int64()),
|
||||
# Additional nested field
|
||||
pa.field("b", pa.int64()),
|
||||
]
|
||||
),
|
||||
"vector": pa.list_(pa.float32(), 2),
|
||||
"vec1": pa.list_(pa.float32(), 2),
|
||||
"vec2": pa.list_(pa.float32(), 2),
|
||||
# Additional field
|
||||
"extra": pa.int64(),
|
||||
}
|
||||
)
|
||||
with pytest.raises(Exception):
|
||||
_cast_to_target_schema(data, target)
|
||||
output = _cast_to_target_schema(data, target, allow_subschema=True)
|
||||
expected_schema = pa.schema(
|
||||
{
|
||||
"id": pa.int64(),
|
||||
"struct": pa.struct(
|
||||
[
|
||||
pa.field("a", pa.int64()),
|
||||
]
|
||||
),
|
||||
"vector": pa.list_(pa.float32(), 2),
|
||||
"vec1": pa.list_(pa.float32(), 2),
|
||||
"vec2": pa.list_(pa.float32(), 2),
|
||||
}
|
||||
)
|
||||
expected = pa.table(
|
||||
{
|
||||
"id": [1],
|
||||
"struct": [{"a": 1}],
|
||||
"vector": [[0.0] * 2],
|
||||
"vec1": [[0.0] * 2],
|
||||
"vec2": [[0.0] * 2],
|
||||
},
|
||||
schema=expected_schema,
|
||||
)
|
||||
assert output == expected
|
||||
|
||||