Mirror of https://github.com/lancedb/lancedb.git (synced 2026-01-06 20:02:58 +00:00)
feat(python)!: support inserting and upserting subschemas (#1965)
BREAKING CHANGE: For a field "vector", a list of integers will now be converted to a binary (uint8) vector instead of an f32 vector. Use float values instead for f32 vectors (a short sketch follows the examples below).

* Adds proper support for inserting and upserting subsets of the full schema. I thought I had previously implemented this in #1827, but it turns out I had not tested carefully enough.
* Refactors `_sanitize_data` and other utility functions to be simpler and not require `numpy` or `combine_chunks()`.
* Adds a new suite of unit tests to validate the sanitization utilities.

## Examples

```python
import pandas as pd
import lancedb

db = lancedb.connect("memory://demo")
initial_data = pd.DataFrame({
    "a": [1, 2, 3],
    "b": [4, 5, 6],
    "c": [7, 8, 9]
})
table = db.create_table("demo", initial_data)

# Insert a subschema
new_data = pd.DataFrame({"a": [10, 11]})
table.add(new_data)
table.to_pandas()
```

```
    a    b    c
0   1  4.0  7.0
1   2  5.0  8.0
2   3  6.0  9.0
3  10  NaN  NaN
4  11  NaN  NaN
```

```python
# Upsert a subschema
upsert_data = pd.DataFrame({
    "a": [3, 10, 15],
    "b": [6, 7, 8],
})
table.merge_insert(on="a").when_matched_update_all().when_not_matched_insert_all().execute(upsert_data)
table.to_pandas()
```

```
    a    b    c
0   1  4.0  7.0
1   2  5.0  8.0
2   3  6.0  9.0
3  10  7.0  NaN
4  11  NaN  NaN
5  15  8.0  NaN
```
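A minimal sketch of the new type inference described in the breaking change above. This assumes the release containing this commit; the connection URI, table names, and the exact printed Arrow types are illustrative rather than taken from the diff:

```python
import lancedb

db = lancedb.connect("memory://vector-types-demo")

# An integer list under the "vector" field is now inferred as a binary (uint8) vector.
int_table = db.create_table("binary_vectors", [{"vector": [1, 2, 3, 4]}])
print(int_table.schema.field("vector").type)    # expected: fixed_size_list<item: uint8>[4]

# Use float values to keep getting f32 vectors.
float_table = db.create_table("float_vectors", [{"vector": [1.0, 2.0, 3.0, 4.0]}])
print(float_table.schema.field("vector").type)  # expected: fixed_size_list<item: float>[4]
```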
@@ -784,10 +784,6 @@ class AsyncConnection(object):
|
|||||||
registry = EmbeddingFunctionRegistry.get_instance()
|
registry = EmbeddingFunctionRegistry.get_instance()
|
||||||
metadata = registry.get_table_metadata(embedding_functions)
|
metadata = registry.get_table_metadata(embedding_functions)
|
||||||
|
|
||||||
data, schema = sanitize_create_table(
|
|
||||||
data, schema, metadata, on_bad_vectors, fill_value
|
|
||||||
)
|
|
||||||
|
|
||||||
# Defining defaults here and not in function prototype. In the future
|
# Defining defaults here and not in function prototype. In the future
|
||||||
# these defaults will move into rust so better to keep them as None.
|
# these defaults will move into rust so better to keep them as None.
|
||||||
if on_bad_vectors is None:
|
if on_bad_vectors is None:
|
||||||
|
|||||||
@@ -108,9 +108,14 @@ class EmbeddingFunctionRegistry:
|
|||||||
An empty dict is returned if input is None or does not
|
An empty dict is returned if input is None or does not
|
||||||
contain b"embedding_functions".
|
contain b"embedding_functions".
|
||||||
"""
|
"""
|
||||||
if metadata is None or b"embedding_functions" not in metadata:
|
if metadata is None:
|
||||||
|
return {}
|
||||||
|
# Look at both bytes and string keys, since we might use either
|
||||||
|
serialized = metadata.get(
|
||||||
|
b"embedding_functions", metadata.get("embedding_functions")
|
||||||
|
)
|
||||||
|
if serialized is None:
|
||||||
return {}
|
return {}
|
||||||
serialized = metadata[b"embedding_functions"]
|
|
||||||
raw_list = json.loads(serialized.decode("utf-8"))
|
raw_list = json.loads(serialized.decode("utf-8"))
|
||||||
return {
|
return {
|
||||||
obj["vector_column"]: EmbeddingFunctionConfig(
|
obj["vector_column"]: EmbeddingFunctionConfig(
|
||||||
|
|||||||
@@ -472,7 +472,7 @@ class LanceQueryBuilder(ABC):
|
|||||||
--------
|
--------
|
||||||
>>> import lancedb
|
>>> import lancedb
|
||||||
>>> db = lancedb.connect("./.lancedb")
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
>>> table = db.create_table("my_table", [{"vector": [99, 99]}])
|
>>> table = db.create_table("my_table", [{"vector": [99.0, 99]}])
|
||||||
>>> query = [100, 100]
|
>>> query = [100, 100]
|
||||||
>>> plan = table.search(query).explain_plan(True)
|
>>> plan = table.search(query).explain_plan(True)
|
||||||
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
|
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ from urllib.parse import urlparse
|
|||||||
import lance
|
import lance
|
||||||
from lancedb.background_loop import LOOP
|
from lancedb.background_loop import LOOP
|
||||||
from .dependencies import _check_for_pandas
|
from .dependencies import _check_for_pandas
|
||||||
import numpy as np
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
import pyarrow.compute as pc
|
import pyarrow.compute as pc
|
||||||
import pyarrow.fs as pa_fs
|
import pyarrow.fs as pa_fs
|
||||||
@@ -74,34 +73,17 @@ pl = safe_import_polars()
|
|||||||
QueryType = Literal["vector", "fts", "hybrid", "auto"]
|
QueryType = Literal["vector", "fts", "hybrid", "auto"]
|
||||||
|
|
||||||
|
|
||||||
def _pd_schema_without_embedding_funcs(
|
def _into_pyarrow_table(data) -> pa.Table:
|
||||||
schema: Optional[pa.Schema], columns: List[str]
|
|
||||||
) -> Optional[pa.Schema]:
|
|
||||||
"""Return a schema without any embedding function columns"""
|
|
||||||
if schema is None:
|
|
||||||
return None
|
|
||||||
embedding_functions = EmbeddingFunctionRegistry.get_instance().parse_functions(
|
|
||||||
schema.metadata
|
|
||||||
)
|
|
||||||
if not embedding_functions:
|
|
||||||
return schema
|
|
||||||
return pa.schema([field for field in schema if field.name in columns])
|
|
||||||
|
|
||||||
|
|
||||||
def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
|
|
||||||
if _check_for_hugging_face(data):
|
if _check_for_hugging_face(data):
|
||||||
# Huggingface datasets
|
# Huggingface datasets
|
||||||
from lance.dependencies import datasets
|
from lance.dependencies import datasets
|
||||||
|
|
||||||
if isinstance(data, datasets.Dataset):
|
if isinstance(data, datasets.Dataset):
|
||||||
if schema is None:
|
schema = data.features.arrow_schema
|
||||||
schema = data.features.arrow_schema
|
|
||||||
return pa.Table.from_batches(data.data.to_batches(), schema=schema)
|
return pa.Table.from_batches(data.data.to_batches(), schema=schema)
|
||||||
elif isinstance(data, datasets.dataset_dict.DatasetDict):
|
elif isinstance(data, datasets.dataset_dict.DatasetDict):
|
||||||
if schema is None:
|
schema = _schema_from_hf(data, schema)
|
||||||
schema = _schema_from_hf(data, schema)
|
|
||||||
return pa.Table.from_batches(_to_batches_with_split(data), schema=schema)
|
return pa.Table.from_batches(_to_batches_with_split(data), schema=schema)
|
||||||
|
|
||||||
if isinstance(data, LanceModel):
|
if isinstance(data, LanceModel):
|
||||||
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
|
raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
|
||||||
|
|
||||||
@@ -111,17 +93,15 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
|
|||||||
if isinstance(data, list):
|
if isinstance(data, list):
|
||||||
# convert to list of dict if data is a bunch of LanceModels
|
# convert to list of dict if data is a bunch of LanceModels
|
||||||
if isinstance(data[0], LanceModel):
|
if isinstance(data[0], LanceModel):
|
||||||
if schema is None:
|
schema = data[0].__class__.to_arrow_schema()
|
||||||
schema = data[0].__class__.to_arrow_schema()
|
|
||||||
data = [model_to_dict(d) for d in data]
|
data = [model_to_dict(d) for d in data]
|
||||||
return pa.Table.from_pylist(data, schema=schema)
|
return pa.Table.from_pylist(data, schema=schema)
|
||||||
elif isinstance(data[0], pa.RecordBatch):
|
elif isinstance(data[0], pa.RecordBatch):
|
||||||
return pa.Table.from_batches(data, schema=schema)
|
return pa.Table.from_batches(data)
|
||||||
else:
|
else:
|
||||||
return pa.Table.from_pylist(data, schema=schema)
|
return pa.Table.from_pylist(data)
|
||||||
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame): # type: ignore
|
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
|
||||||
raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list())
|
table = pa.Table.from_pandas(data, preserve_index=False)
|
||||||
table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema)
|
|
||||||
# Do not serialize Pandas metadata
|
# Do not serialize Pandas metadata
|
||||||
meta = table.schema.metadata if table.schema.metadata is not None else {}
|
meta = table.schema.metadata if table.schema.metadata is not None else {}
|
||||||
meta = {k: v for k, v in meta.items() if k != b"pandas"}
|
meta = {k: v for k, v in meta.items() if k != b"pandas"}
|
||||||
@@ -143,8 +123,13 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
|
|||||||
and data.__class__.__name__ == "DataFrame"
|
and data.__class__.__name__ == "DataFrame"
|
||||||
):
|
):
|
||||||
return data.to_arrow()
|
return data.to_arrow()
|
||||||
|
elif (
|
||||||
|
type(data).__module__.startswith("polars")
|
||||||
|
and data.__class__.__name__ == "LazyFrame"
|
||||||
|
):
|
||||||
|
return data.collect().to_arrow()
|
||||||
elif isinstance(data, Iterable):
|
elif isinstance(data, Iterable):
|
||||||
return _process_iterator(data, schema)
|
return _iterator_to_table(data)
|
||||||
else:
|
else:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Unknown data type {type(data)}. "
|
f"Unknown data type {type(data)}. "
|
||||||
@@ -154,27 +139,172 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _iterator_to_table(data: Iterable) -> pa.Table:
|
||||||
|
batches = []
|
||||||
|
schema = None # Will get schema from first batch
|
||||||
|
for batch in data:
|
||||||
|
batch_table = _into_pyarrow_table(batch)
|
||||||
|
if schema is not None:
|
||||||
|
if batch_table.schema != schema:
|
||||||
|
try:
|
||||||
|
batch_table = batch_table.cast(schema)
|
||||||
|
except pa.lib.ArrowInvalid:
|
||||||
|
raise ValueError(
|
||||||
|
f"Input iterator yielded a batch with schema that "
|
||||||
|
f"does not match the schema of other batches.\n"
|
||||||
|
f"Expected:\n{schema}\nGot:\n{batch_table.schema}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use the first schema for the remainder of the batches
|
||||||
|
schema = batch_table.schema
|
||||||
|
batches.append(batch_table)
|
||||||
|
|
||||||
|
if batches:
|
||||||
|
return pa.concat_tables(batches)
|
||||||
|
else:
|
||||||
|
raise ValueError("Input iterable is empty")
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_data(
|
def _sanitize_data(
|
||||||
data: Any,
|
data: "DATA",
|
||||||
schema: Optional[pa.Schema] = None,
|
target_schema: Optional[pa.Schema] = None,
|
||||||
metadata: Optional[dict] = None, # embedding metadata
|
metadata: Optional[dict] = None, # embedding metadata
|
||||||
on_bad_vectors: str = "error",
|
on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
|
||||||
fill_value: float = 0.0,
|
fill_value: float = 0.0,
|
||||||
) -> Tuple[pa.Table, pa.Schema]:
|
*,
|
||||||
data = _coerce_to_table(data, schema)
|
allow_subschema: bool = False,
|
||||||
|
) -> pa.Table:
|
||||||
|
"""
|
||||||
|
Handle input data, applying all standard transformations.
|
||||||
|
|
||||||
|
This includes:
|
||||||
|
|
||||||
|
* Converting the data to a PyArrow Table
|
||||||
|
* Adding vector columns defined in the metadata
|
||||||
|
* Adding embedding metadata into the schema
|
||||||
|
* Casting the table to the target schema
|
||||||
|
* Handling bad vectors
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
target_schema : Optional[pa.Schema], default None
|
||||||
|
The schema to cast the table to. This is typically the schema of the table
|
||||||
|
if it already exists. Otherwise it might be a user-requested schema.
|
||||||
|
allow_subschema : bool, default False
|
||||||
|
If True, the input table is allowed to omit columns from the target schema.
|
||||||
|
The target schema will be filtered to only include columns that are present
|
||||||
|
in the input table before casting.
|
||||||
|
metadata : Optional[dict], default None
|
||||||
|
The embedding metadata to add to the schema.
|
||||||
|
on_bad_vectors : Literal["error", "drop", "fill", "null"], default "error"
|
||||||
|
What to do if any of the vectors are not the same size or contains NaNs.
|
||||||
|
fill_value : float, default 0.0
|
||||||
|
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||||
|
All entries in the vector will be set to this value.
|
||||||
|
"""
|
||||||
|
# At this point, the table might not match the schema we are targeting:
|
||||||
|
# 1. There might be embedding columns missing that will be added
|
||||||
|
# in the add_embeddings step.
|
||||||
|
# 2. If `allow_subschemas` is True, there might be columns missing.
|
||||||
|
table = _into_pyarrow_table(data)
|
||||||
|
|
||||||
|
table = _append_vector_columns(table, target_schema, metadata=metadata)
|
||||||
|
|
||||||
|
# This happens before the cast so we can fix vector columns with
|
||||||
|
# incorrect lengths before they are cast to FSL.
|
||||||
|
table = _handle_bad_vectors(
|
||||||
|
table,
|
||||||
|
on_bad_vectors=on_bad_vectors,
|
||||||
|
fill_value=fill_value,
|
||||||
|
)
|
||||||
|
|
||||||
|
if target_schema is None:
|
||||||
|
target_schema = _infer_target_schema(table)
|
||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
data = _append_vector_col(data, metadata, schema)
|
new_metadata = target_schema.metadata or {}
|
||||||
metadata.update(data.schema.metadata or {})
|
new_metadata = new_metadata.update(metadata)
|
||||||
data = data.replace_schema_metadata(metadata)
|
target_schema = target_schema.with_metadata(new_metadata)
|
||||||
|
|
||||||
# TODO improve the logics in _sanitize_schema
|
_validate_schema(target_schema)
|
||||||
data = _sanitize_schema(data, schema, on_bad_vectors, fill_value)
|
|
||||||
if schema is None:
|
|
||||||
schema = data.schema
|
|
||||||
|
|
||||||
_validate_schema(schema)
|
table = _cast_to_target_schema(table, target_schema, allow_subschema)
|
||||||
return data, schema
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def _cast_to_target_schema(
|
||||||
|
table: pa.Table,
|
||||||
|
target_schema: pa.Schema,
|
||||||
|
allow_subschema: bool = False,
|
||||||
|
) -> pa.Table:
|
||||||
|
# pa.Table.cast expects field order not to be changed.
|
||||||
|
# Lance doesn't care about field order, so we don't need to rearrange fields
|
||||||
|
# to match the target schema. We just need to correctly cast the fields.
|
||||||
|
if table.schema == target_schema:
|
||||||
|
# Fast path when the schemas are already the same
|
||||||
|
return table
|
||||||
|
|
||||||
|
fields = []
|
||||||
|
for field in table.schema:
|
||||||
|
target_field = target_schema.field(field.name)
|
||||||
|
if target_field is None:
|
||||||
|
raise ValueError(f"Field {field.name} not found in target schema")
|
||||||
|
fields.append(target_field)
|
||||||
|
reordered_schema = pa.schema(fields, metadata=target_schema.metadata)
|
||||||
|
if not allow_subschema and len(reordered_schema) != len(target_schema):
|
||||||
|
raise ValueError(
|
||||||
|
"Input table has different number of columns than target schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
if allow_subschema and len(reordered_schema) != len(target_schema):
|
||||||
|
fields = _infer_subschema(
|
||||||
|
list(iter(table.schema)), list(iter(reordered_schema))
|
||||||
|
)
|
||||||
|
subschema = pa.schema(fields, metadata=target_schema.metadata)
|
||||||
|
return table.cast(subschema)
|
||||||
|
else:
|
||||||
|
return table.cast(reordered_schema)
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_subschema(
|
||||||
|
schema: List[pa.Field],
|
||||||
|
reference_fields: List[pa.Field],
|
||||||
|
) -> List[pa.Field]:
|
||||||
|
"""
|
||||||
|
Transform the list of fields so the types match the reference_fields.
|
||||||
|
|
||||||
|
The order of the fields is preserved.
|
||||||
|
|
||||||
|
``schema`` may have fewer fields than `reference_fields`, but it may not have
|
||||||
|
more fields.
|
||||||
|
|
||||||
|
"""
|
||||||
|
fields = []
|
||||||
|
lookup = {f.name: f for f in reference_fields}
|
||||||
|
for field in schema:
|
||||||
|
reference = lookup.get(field.name)
|
||||||
|
if reference is None:
|
||||||
|
raise ValueError("Unexpected field in schema: {}".format(field))
|
||||||
|
|
||||||
|
if pa.types.is_struct(reference.type):
|
||||||
|
new_type = pa.struct(
|
||||||
|
_infer_subschema(
|
||||||
|
field.type.fields,
|
||||||
|
reference.type.fields,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
new_field = pa.field(
|
||||||
|
field.name,
|
||||||
|
new_type,
|
||||||
|
reference.nullable,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_field = reference
|
||||||
|
|
||||||
|
fields.append(new_field)
|
||||||
|
|
||||||
|
return fields
|
||||||
|
|
||||||
|
|
||||||
def sanitize_create_table(
|
def sanitize_create_table(
|
||||||
@@ -193,13 +323,14 @@ def sanitize_create_table(
|
|||||||
if data is not None:
|
if data is not None:
|
||||||
if metadata is None and schema is not None:
|
if metadata is None and schema is not None:
|
||||||
metadata = schema.metadata
|
metadata = schema.metadata
|
||||||
data, schema = _sanitize_data(
|
data = _sanitize_data(
|
||||||
data,
|
data,
|
||||||
schema,
|
schema,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
on_bad_vectors=on_bad_vectors,
|
on_bad_vectors=on_bad_vectors,
|
||||||
fill_value=fill_value,
|
fill_value=fill_value,
|
||||||
)
|
)
|
||||||
|
schema = data.schema
|
||||||
else:
|
else:
|
||||||
if schema is not None:
|
if schema is not None:
|
||||||
data = pa.Table.from_pylist([], schema)
|
data = pa.Table.from_pylist([], schema)
|
||||||
@@ -211,6 +342,8 @@ def sanitize_create_table(
|
|||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
schema = schema.with_metadata(metadata)
|
schema = schema.with_metadata(metadata)
|
||||||
|
# Need to apply metadata to the data as well
|
||||||
|
data = data.replace_schema_metadata(metadata)
|
||||||
|
|
||||||
return data, schema
|
return data, schema
|
||||||
|
|
||||||
@@ -246,12 +379,22 @@ def _to_batches_with_split(data):
|
|||||||
yield b
|
yield b
|
||||||
|
|
||||||
|
|
||||||
def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schema]):
|
def _append_vector_columns(
|
||||||
|
data: pa.Table,
|
||||||
|
schema: Optional[pa.Schema] = None,
|
||||||
|
*,
|
||||||
|
metadata: Optional[dict] = None,
|
||||||
|
) -> pa.Table:
|
||||||
"""
|
"""
|
||||||
Use the embedding function to automatically embed the source column and add the
|
Use the embedding function to automatically embed the source columns and add the
|
||||||
vector column to the table.
|
vector columns to the table.
|
||||||
"""
|
"""
|
||||||
|
if schema is None:
|
||||||
|
metadata = metadata or {}
|
||||||
|
else:
|
||||||
|
metadata = schema.metadata or metadata or {}
|
||||||
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
|
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
|
||||||
|
|
||||||
for vector_column, conf in functions.items():
|
for vector_column, conf in functions.items():
|
||||||
func = conf.function
|
func = conf.function
|
||||||
no_vector_column = vector_column not in data.column_names
|
no_vector_column = vector_column not in data.column_names
|
||||||
@@ -790,9 +933,9 @@ class Table(ABC):
|
|||||||
--------
|
--------
|
||||||
>>> import lancedb
|
>>> import lancedb
|
||||||
>>> data = [
|
>>> data = [
|
||||||
... {"x": 1, "vector": [1, 2]},
|
... {"x": 1, "vector": [1.0, 2]},
|
||||||
... {"x": 2, "vector": [3, 4]},
|
... {"x": 2, "vector": [3.0, 4]},
|
||||||
... {"x": 3, "vector": [5, 6]}
|
... {"x": 3, "vector": [5.0, 6]}
|
||||||
... ]
|
... ]
|
||||||
>>> db = lancedb.connect("./.lancedb")
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
>>> table = db.create_table("my_table", data)
|
>>> table = db.create_table("my_table", data)
|
||||||
@@ -854,7 +997,7 @@ class Table(ABC):
|
|||||||
--------
|
--------
|
||||||
>>> import lancedb
|
>>> import lancedb
|
||||||
>>> import pandas as pd
|
>>> import pandas as pd
|
||||||
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1.0, 2], [3, 4], [5, 6]]})
|
||||||
>>> db = lancedb.connect("./.lancedb")
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
>>> table = db.create_table("my_table", data)
|
>>> table = db.create_table("my_table", data)
|
||||||
>>> table.to_pandas()
|
>>> table.to_pandas()
|
||||||
@@ -862,7 +1005,7 @@ class Table(ABC):
|
|||||||
0 1 [1.0, 2.0]
|
0 1 [1.0, 2.0]
|
||||||
1 2 [3.0, 4.0]
|
1 2 [3.0, 4.0]
|
||||||
2 3 [5.0, 6.0]
|
2 3 [5.0, 6.0]
|
||||||
>>> table.update(where="x = 2", values={"vector": [10, 10]})
|
>>> table.update(where="x = 2", values={"vector": [10.0, 10]})
|
||||||
>>> table.to_pandas()
|
>>> table.to_pandas()
|
||||||
x vector
|
x vector
|
||||||
0 1 [1.0, 2.0]
|
0 1 [1.0, 2.0]
|
||||||
@@ -1880,9 +2023,9 @@ class LanceTable(Table):
|
|||||||
--------
|
--------
|
||||||
>>> import lancedb
|
>>> import lancedb
|
||||||
>>> data = [
|
>>> data = [
|
||||||
... {"x": 1, "vector": [1, 2]},
|
... {"x": 1, "vector": [1.0, 2]},
|
||||||
... {"x": 2, "vector": [3, 4]},
|
... {"x": 2, "vector": [3.0, 4]},
|
||||||
... {"x": 3, "vector": [5, 6]}
|
... {"x": 3, "vector": [5.0, 6]}
|
||||||
... ]
|
... ]
|
||||||
>>> db = lancedb.connect("./.lancedb")
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
>>> table = db.create_table("my_table", data)
|
>>> table = db.create_table("my_table", data)
|
||||||
@@ -1971,7 +2114,7 @@ class LanceTable(Table):
|
|||||||
--------
|
--------
|
||||||
>>> import lancedb
|
>>> import lancedb
|
||||||
>>> import pandas as pd
|
>>> import pandas as pd
|
||||||
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
|
>>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1.0, 2], [3, 4], [5, 6]]})
|
||||||
>>> db = lancedb.connect("./.lancedb")
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
>>> table = db.create_table("my_table", data)
|
>>> table = db.create_table("my_table", data)
|
||||||
>>> table.to_pandas()
|
>>> table.to_pandas()
|
||||||
@@ -1979,7 +2122,7 @@ class LanceTable(Table):
|
|||||||
0 1 [1.0, 2.0]
|
0 1 [1.0, 2.0]
|
||||||
1 2 [3.0, 4.0]
|
1 2 [3.0, 4.0]
|
||||||
2 3 [5.0, 6.0]
|
2 3 [5.0, 6.0]
|
||||||
>>> table.update(where="x = 2", values={"vector": [10, 10]})
|
>>> table.update(where="x = 2", values={"vector": [10.0, 10]})
|
||||||
>>> table.to_pandas()
|
>>> table.to_pandas()
|
||||||
x vector
|
x vector
|
||||||
0 1 [1.0, 2.0]
|
0 1 [1.0, 2.0]
|
||||||
@@ -2165,74 +2308,49 @@ class LanceTable(Table):
|
|||||||
LOOP.run(self._table.migrate_v2_manifest_paths())
|
LOOP.run(self._table.migrate_v2_manifest_paths())
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_schema(
|
def _handle_bad_vectors(
|
||||||
data: pa.Table,
|
table: pa.Table,
|
||||||
schema: pa.Schema = None,
|
on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
|
||||||
on_bad_vectors: str = "error",
|
|
||||||
fill_value: float = 0.0,
|
fill_value: float = 0.0,
|
||||||
) -> pa.Table:
|
) -> pa.Table:
|
||||||
"""Ensure that the table has the expected schema.
|
for field in table.schema:
|
||||||
|
# They can provide a 'vector' column that isn't yet a FSL
|
||||||
Parameters
|
named_vector_col = (
|
||||||
----------
|
(
|
||||||
data: pa.Table
|
pa.types.is_list(field.type)
|
||||||
The table to sanitize.
|
or pa.types.is_large_list(field.type)
|
||||||
schema: pa.Schema; optional
|
or pa.types.is_fixed_size_list(field.type)
|
||||||
The expected schema. If not provided, this just converts the
|
|
||||||
vector column to fixed_size_list(float32) if necessary.
|
|
||||||
on_bad_vectors: str, default "error"
|
|
||||||
What to do if any of the vectors are not the same size or contains NaNs.
|
|
||||||
One of "error", "drop", "fill", "null".
|
|
||||||
fill_value: float, default 0.
|
|
||||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
|
||||||
"""
|
|
||||||
if schema is not None:
|
|
||||||
# cast the columns to the expected types
|
|
||||||
data = data.combine_chunks()
|
|
||||||
for field in schema:
|
|
||||||
# TODO: we're making an assumption that fixed size list of 10 or more
|
|
||||||
# is a vector column. This is definitely a bit hacky.
|
|
||||||
likely_vector_col = (
|
|
||||||
pa.types.is_fixed_size_list(field.type)
|
|
||||||
and pa.types.is_float32(field.type.value_type)
|
|
||||||
and field.type.list_size >= 10
|
|
||||||
)
|
)
|
||||||
is_default_vector_col = field.name == VECTOR_COLUMN_NAME
|
and pa.types.is_floating(field.type.value_type)
|
||||||
if field.name in data.column_names and (
|
and field.name == VECTOR_COLUMN_NAME
|
||||||
likely_vector_col or is_default_vector_col
|
)
|
||||||
):
|
# TODO: we're making an assumption that fixed size list of 10 or more
|
||||||
data = _sanitize_vector_column(
|
# is a vector column. This is definitely a bit hacky.
|
||||||
data,
|
likely_vector_col = (
|
||||||
vector_column_name=field.name,
|
pa.types.is_fixed_size_list(field.type)
|
||||||
on_bad_vectors=on_bad_vectors,
|
and pa.types.is_floating(field.type.value_type)
|
||||||
fill_value=fill_value,
|
and (field.type.list_size >= 10)
|
||||||
table_schema=schema,
|
|
||||||
)
|
|
||||||
return pa.Table.from_arrays(
|
|
||||||
[data[name] for name in schema.names], schema=schema
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# just check the vector column
|
if named_vector_col or likely_vector_col:
|
||||||
if VECTOR_COLUMN_NAME in data.column_names:
|
table = _handle_bad_vector_column(
|
||||||
return _sanitize_vector_column(
|
table,
|
||||||
data,
|
vector_column_name=field.name,
|
||||||
vector_column_name=VECTOR_COLUMN_NAME,
|
on_bad_vectors=on_bad_vectors,
|
||||||
on_bad_vectors=on_bad_vectors,
|
fill_value=fill_value,
|
||||||
fill_value=fill_value,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
return table
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_vector_column(
|
def _handle_bad_vector_column(
|
||||||
data: pa.Table,
|
data: pa.Table,
|
||||||
vector_column_name: str,
|
vector_column_name: str,
|
||||||
table_schema: Optional[pa.Schema] = None,
|
|
||||||
on_bad_vectors: str = "error",
|
on_bad_vectors: str = "error",
|
||||||
fill_value: float = 0.0,
|
fill_value: float = 0.0,
|
||||||
) -> pa.Table:
|
) -> pa.Table:
|
||||||
"""
|
"""
|
||||||
Ensure that the vector column exists and has type fixed_size_list(float32)
|
Ensure that the vector column exists and has type fixed_size_list(float)
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@@ -2246,141 +2364,118 @@ def _sanitize_vector_column(
|
|||||||
fill_value: float, default 0.0
|
fill_value: float, default 0.0
|
||||||
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
The value to use when filling vectors. Only used if on_bad_vectors="fill".
|
||||||
"""
|
"""
|
||||||
# ChunkedArray is annoying to work with, so we combine chunks here
|
vec_arr = data[vector_column_name]
|
||||||
vec_arr = data[vector_column_name].combine_chunks()
|
|
||||||
if table_schema is not None:
|
|
||||||
field = table_schema.field(vector_column_name)
|
|
||||||
else:
|
|
||||||
field = None
|
|
||||||
typ = data[vector_column_name].type
|
|
||||||
if pa.types.is_list(typ) or pa.types.is_large_list(typ):
|
|
||||||
# if it's a variable size list array,
|
|
||||||
# we make sure the dimensions are all the same
|
|
||||||
has_jagged_ndims = len(vec_arr.values) % len(data) != 0
|
|
||||||
if has_jagged_ndims:
|
|
||||||
data = _sanitize_jagged(
|
|
||||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
|
||||||
)
|
|
||||||
vec_arr = data[vector_column_name].combine_chunks()
|
|
||||||
vec_arr = ensure_fixed_size_list(vec_arr)
|
|
||||||
data = data.set_column(
|
|
||||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
|
||||||
)
|
|
||||||
elif not pa.types.is_fixed_size_list(vec_arr.type):
|
|
||||||
raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
|
|
||||||
|
|
||||||
if pa.types.is_float16(vec_arr.values.type):
|
has_nan = has_nan_values(vec_arr)
|
||||||
# Use numpy to check for NaNs, because as pyarrow does not have `is_nan`
|
|
||||||
# kernel over f16 types yet.
|
|
||||||
values_np = vec_arr.values.to_numpy(zero_copy_only=True)
|
|
||||||
if np.isnan(values_np).any():
|
|
||||||
data = _sanitize_nans(
|
|
||||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if (
|
|
||||||
field is not None
|
|
||||||
and not field.nullable
|
|
||||||
and pc.any(pc.is_null(vec_arr.values)).as_py()
|
|
||||||
) or (pc.any(pc.is_nan(vec_arr.values)).as_py()):
|
|
||||||
data = _sanitize_nans(
|
|
||||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
|
|
||||||
values = vec_arr.values
|
|
||||||
if not (pa.types.is_float16(values.type) or pa.types.is_float32(values.type)):
|
|
||||||
values = values.cast(pa.float32())
|
|
||||||
if pa.types.is_fixed_size_list(vec_arr.type):
|
if pa.types.is_fixed_size_list(vec_arr.type):
|
||||||
list_size = vec_arr.type.list_size
|
dim = vec_arr.type.list_size
|
||||||
else:
|
else:
|
||||||
list_size = len(values) / len(vec_arr)
|
dim = _modal_list_size(vec_arr)
|
||||||
vec_arr = pa.FixedSizeListArray.from_arrays(values, list_size)
|
has_wrong_dim = pc.not_equal(pc.list_value_length(vec_arr), dim)
|
||||||
return vec_arr
|
|
||||||
|
|
||||||
|
has_bad_vectors = pc.any(has_nan).as_py() or pc.any(has_wrong_dim).as_py()
|
||||||
|
|
||||||
def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
|
if has_bad_vectors:
|
||||||
"""Sanitize jagged vectors."""
|
is_bad = pc.or_(has_nan, has_wrong_dim)
|
||||||
if on_bad_vectors == "error":
|
if on_bad_vectors == "error":
|
||||||
raise ValueError(
|
if pc.any(has_wrong_dim).as_py():
|
||||||
f"Vector column {vector_column_name} has variable length vectors "
|
raise ValueError(
|
||||||
"Set on_bad_vectors='drop' to remove them, or "
|
f"Vector column '{vector_column_name}' has variable length "
|
||||||
"set on_bad_vectors='fill' and fill_value=<value> to replace them."
|
"vectors. Set on_bad_vectors='drop' to remove them, "
|
||||||
)
|
"set on_bad_vectors='fill' and fill_value=<value> to replace them, "
|
||||||
|
"or set on_bad_vectors='null' to replace them with null."
|
||||||
lst_lengths = pc.list_value_length(vec_arr)
|
)
|
||||||
ndims = pc.max(lst_lengths).as_py()
|
else:
|
||||||
correct_ndims = pc.equal(lst_lengths, ndims)
|
raise ValueError(
|
||||||
|
f"Vector column '{vector_column_name}' has NaNs. "
|
||||||
if on_bad_vectors == "fill":
|
"Set on_bad_vectors='drop' to remove them, "
|
||||||
if fill_value is None:
|
"set on_bad_vectors='fill' and fill_value=<value> to replace them, "
|
||||||
raise ValueError(
|
"or set on_bad_vectors='null' to replace them with null."
|
||||||
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
)
|
||||||
|
elif on_bad_vectors == "null":
|
||||||
|
vec_arr = pc.if_else(
|
||||||
|
is_bad,
|
||||||
|
pa.scalar(None),
|
||||||
|
vec_arr,
|
||||||
)
|
)
|
||||||
fill_arr = pa.scalar([float(fill_value)] * ndims)
|
elif on_bad_vectors == "drop":
|
||||||
vec_arr = pc.if_else(correct_ndims, vec_arr, fill_arr)
|
data = data.filter(pc.invert(is_bad))
|
||||||
data = data.set_column(
|
vec_arr = data[vector_column_name]
|
||||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
elif on_bad_vectors == "fill":
|
||||||
)
|
if fill_value is None:
|
||||||
elif on_bad_vectors == "drop":
|
raise ValueError(
|
||||||
data = data.filter(correct_ndims)
|
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
||||||
elif on_bad_vectors == "null":
|
)
|
||||||
data = data.set_column(
|
vec_arr = pc.if_else(
|
||||||
data.column_names.index(vector_column_name),
|
is_bad,
|
||||||
vector_column_name,
|
pa.scalar([fill_value] * dim),
|
||||||
pc.if_else(correct_ndims, vec_arr, pa.scalar(None)),
|
vec_arr,
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_nans(
|
|
||||||
data,
|
|
||||||
fill_value,
|
|
||||||
on_bad_vectors,
|
|
||||||
vec_arr: pa.FixedSizeListArray,
|
|
||||||
vector_column_name: str,
|
|
||||||
):
|
|
||||||
"""Sanitize NaNs in vectors"""
|
|
||||||
assert pa.types.is_fixed_size_list(vec_arr.type)
|
|
||||||
if on_bad_vectors == "error":
|
|
||||||
raise ValueError(
|
|
||||||
f"Vector column {vector_column_name} has NaNs. "
|
|
||||||
"Set on_bad_vectors='drop' to remove them, or "
|
|
||||||
"set on_bad_vectors='fill' and fill_value=<value> to replace them. "
|
|
||||||
"Or set on_bad_vectors='null' to replace them with null."
|
|
||||||
)
|
|
||||||
elif on_bad_vectors == "fill":
|
|
||||||
if fill_value is None:
|
|
||||||
raise ValueError(
|
|
||||||
"`fill_value` must not be None if `on_bad_vectors` is 'fill'"
|
|
||||||
)
|
)
|
||||||
fill_value = float(fill_value)
|
else:
|
||||||
values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values)
|
raise ValueError(f"Invalid value for on_bad_vectors: {on_bad_vectors}")
|
||||||
ndims = len(vec_arr[0])
|
|
||||||
vec_arr = pa.FixedSizeListArray.from_arrays(values, ndims)
|
position = data.column_names.index(vector_column_name)
|
||||||
data = data.set_column(
|
return data.set_column(position, vector_column_name, vec_arr)
|
||||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
|
||||||
)
|
|
||||||
elif on_bad_vectors == "drop":
|
def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray:
|
||||||
# Drop is very slow to be able to filter out NaNs in a fixed size list array
|
if isinstance(arr, pa.ChunkedArray):
|
||||||
np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False))
|
values = pa.chunked_array([chunk.flatten() for chunk in arr.chunks])
|
||||||
np_arr = np_arr.reshape(-1, vec_arr.type.list_size)
|
else:
|
||||||
not_nulls = np.any(np_arr, axis=1)
|
values = arr.flatten()
|
||||||
data = data.filter(~not_nulls)
|
if pa.types.is_float16(values.type):
|
||||||
elif on_bad_vectors == "null":
|
# is_nan isn't yet implemented for f16, so we cast to f32
|
||||||
# null = pa.nulls(len(vec_arr)).cast(vec_arr.type)
|
# https://github.com/apache/arrow/issues/45083
|
||||||
# values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values)
|
values_has_nan = pc.is_nan(values.cast(pa.float32()))
|
||||||
np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False))
|
else:
|
||||||
np_arr = np_arr.reshape(-1, vec_arr.type.list_size)
|
values_has_nan = pc.is_nan(values)
|
||||||
no_nans = np.any(np_arr, axis=1)
|
values_indices = pc.list_parent_indices(arr)
|
||||||
data = data.set_column(
|
has_nan_indices = pc.unique(pc.filter(values_indices, values_has_nan))
|
||||||
data.column_names.index(vector_column_name),
|
indices = pa.array(range(len(arr)), type=pa.uint32())
|
||||||
vector_column_name,
|
return pc.is_in(indices, has_nan_indices)
|
||||||
pc.if_else(no_nans, vec_arr, pa.scalar(None)),
|
|
||||||
)
|
|
||||||
return data
|
def _infer_target_schema(table: pa.Table) -> pa.Schema:
|
||||||
|
schema = table.schema
|
||||||
|
|
||||||
|
for i, field in enumerate(schema):
|
||||||
|
if (
|
||||||
|
field.name == VECTOR_COLUMN_NAME
|
||||||
|
and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
|
||||||
|
and pa.types.is_floating(field.type.value_type)
|
||||||
|
):
|
||||||
|
# Use the most common length of the list as the dimensions
|
||||||
|
dim = _modal_list_size(table.column(i))
|
||||||
|
|
||||||
|
new_field = pa.field(
|
||||||
|
VECTOR_COLUMN_NAME,
|
||||||
|
pa.list_(pa.float32(), dim),
|
||||||
|
nullable=field.nullable,
|
||||||
|
)
|
||||||
|
|
||||||
|
schema = schema.set(i, new_field)
|
||||||
|
elif (
|
||||||
|
field.name == VECTOR_COLUMN_NAME
|
||||||
|
and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
|
||||||
|
and pa.types.is_integer(field.type.value_type)
|
||||||
|
):
|
||||||
|
# Use the most common length of the list as the dimensions
|
||||||
|
dim = _modal_list_size(table.column(i))
|
||||||
|
new_field = pa.field(
|
||||||
|
VECTOR_COLUMN_NAME,
|
||||||
|
pa.list_(pa.uint8(), dim),
|
||||||
|
nullable=field.nullable,
|
||||||
|
)
|
||||||
|
|
||||||
|
schema = schema.set(i, new_field)
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
|
|
||||||
|
def _modal_list_size(arr: Union[pa.ListArray, pa.ChunkedArray]) -> int:
|
||||||
|
# Use the most common length of the list as the dimensions
|
||||||
|
return pc.mode(pc.list_value_length(arr))[0].as_py()["mode"]
|
||||||
|
|
||||||
|
|
||||||
def _validate_schema(schema: pa.Schema):
|
def _validate_schema(schema: pa.Schema):
|
||||||
@@ -2410,28 +2505,6 @@ def _validate_metadata(metadata: dict):
|
|||||||
_validate_metadata(v)
|
_validate_metadata(v)
|
||||||
|
|
||||||
|
|
||||||
def _process_iterator(data: Iterable, schema: Optional[pa.Schema] = None) -> pa.Table:
|
|
||||||
batches = []
|
|
||||||
for batch in data:
|
|
||||||
batch_table = _coerce_to_table(batch, schema)
|
|
||||||
if schema is not None:
|
|
||||||
if batch_table.schema != schema:
|
|
||||||
try:
|
|
||||||
batch_table = batch_table.cast(schema)
|
|
||||||
except pa.lib.ArrowInvalid: # type: ignore
|
|
||||||
raise ValueError(
|
|
||||||
f"Input iterator yielded a batch with schema that "
|
|
||||||
f"does not match the expected schema.\nExpected:\n{schema}\n"
|
|
||||||
f"Got:\n{batch_table.schema}"
|
|
||||||
)
|
|
||||||
batches.append(batch_table)
|
|
||||||
|
|
||||||
if batches:
|
|
||||||
return pa.concat_tables(batches)
|
|
||||||
else:
|
|
||||||
raise ValueError("Input iterable is empty")
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncTable:
|
class AsyncTable:
|
||||||
"""
|
"""
|
||||||
An AsyncTable is a collection of Records in a LanceDB Database.
|
An AsyncTable is a collection of Records in a LanceDB Database.
|
||||||
@@ -2678,16 +2751,17 @@ class AsyncTable:
|
|||||||
on_bad_vectors = "error"
|
on_bad_vectors = "error"
|
||||||
if fill_value is None:
|
if fill_value is None:
|
||||||
fill_value = 0.0
|
fill_value = 0.0
|
||||||
table_and_schema: Tuple[pa.Table, pa.Schema] = _sanitize_data(
|
data = _sanitize_data(
|
||||||
data,
|
data,
|
||||||
schema,
|
schema,
|
||||||
metadata=schema.metadata,
|
metadata=schema.metadata,
|
||||||
on_bad_vectors=on_bad_vectors,
|
on_bad_vectors=on_bad_vectors,
|
||||||
fill_value=fill_value,
|
fill_value=fill_value,
|
||||||
|
allow_subschema=True,
|
||||||
)
|
)
|
||||||
tbl, schema = table_and_schema
|
if isinstance(data, pa.Table):
|
||||||
if isinstance(tbl, pa.Table):
|
data = data.to_reader()
|
||||||
data = pa.RecordBatchReader.from_batches(schema, tbl.to_batches())
|
|
||||||
await self._inner.add(data, mode or "append")
|
await self._inner.add(data, mode or "append")
|
||||||
|
|
||||||
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
|
||||||
@@ -2822,12 +2896,13 @@ class AsyncTable:
|
|||||||
on_bad_vectors = "error"
|
on_bad_vectors = "error"
|
||||||
if fill_value is None:
|
if fill_value is None:
|
||||||
fill_value = 0.0
|
fill_value = 0.0
|
||||||
data, _ = _sanitize_data(
|
data = _sanitize_data(
|
||||||
new_data,
|
new_data,
|
||||||
schema,
|
schema,
|
||||||
metadata=schema.metadata,
|
metadata=schema.metadata,
|
||||||
on_bad_vectors=on_bad_vectors,
|
on_bad_vectors=on_bad_vectors,
|
||||||
fill_value=fill_value,
|
fill_value=fill_value,
|
||||||
|
allow_subschema=True,
|
||||||
)
|
)
|
||||||
if isinstance(data, pa.Table):
|
if isinstance(data, pa.Table):
|
||||||
data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
|
data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
|
||||||
@@ -2862,9 +2937,9 @@ class AsyncTable:
|
|||||||
--------
|
--------
|
||||||
>>> import lancedb
|
>>> import lancedb
|
||||||
>>> data = [
|
>>> data = [
|
||||||
... {"x": 1, "vector": [1, 2]},
|
... {"x": 1, "vector": [1.0, 2]},
|
||||||
... {"x": 2, "vector": [3, 4]},
|
... {"x": 2, "vector": [3.0, 4]},
|
||||||
... {"x": 3, "vector": [5, 6]}
|
... {"x": 3, "vector": [5.0, 6]}
|
||||||
... ]
|
... ]
|
||||||
>>> db = lancedb.connect("./.lancedb")
|
>>> db = lancedb.connect("./.lancedb")
|
||||||
>>> table = db.create_table("my_table", data)
|
>>> table = db.create_table("my_table", data)
|
||||||
|
|||||||
@@ -223,9 +223,7 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
|
|||||||
vector_col_count = 0
|
vector_col_count = 0
|
||||||
for field_name in schema.names:
|
for field_name in schema.names:
|
||||||
field = schema.field(field_name)
|
field = schema.field(field_name)
|
||||||
if pa.types.is_fixed_size_list(field.type) and pa.types.is_floating(
|
if pa.types.is_fixed_size_list(field.type):
|
||||||
field.type.value_type
|
|
||||||
):
|
|
||||||
vector_col_count += 1
|
vector_col_count += 1
|
||||||
if vector_col_count > 1:
|
if vector_col_count > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ def test_binary_vector():
|
|||||||
]
|
]
|
||||||
tbl = db.create_table("my_binary_vectors", data=data)
|
tbl = db.create_table("my_binary_vectors", data=data)
|
||||||
query = np.random.randint(0, 256, size=16)
|
query = np.random.randint(0, 256, size=16)
|
||||||
tbl.search(query).to_arrow()
|
tbl.search(query).metric("hamming").to_arrow()
|
||||||
# --8<-- [end:sync_binary_vector]
|
# --8<-- [end:sync_binary_vector]
|
||||||
db.drop_table("my_binary_vectors")
|
db.drop_table("my_binary_vectors")
|
||||||
|
|
||||||
@@ -39,6 +39,6 @@ async def test_binary_vector_async():
|
|||||||
]
|
]
|
||||||
tbl = await db.create_table("my_binary_vectors", data=data)
|
tbl = await db.create_table("my_binary_vectors", data=data)
|
||||||
query = np.random.randint(0, 256, size=16)
|
query = np.random.randint(0, 256, size=16)
|
||||||
await tbl.query().nearest_to(query).to_arrow()
|
await tbl.query().nearest_to(query).distance_type("hamming").to_arrow()
|
||||||
# --8<-- [end:async_binary_vector]
|
# --8<-- [end:async_binary_vector]
|
||||||
await db.drop_table("my_binary_vectors")
|
await db.drop_table("my_binary_vectors")
|
||||||
|
|||||||
@@ -118,9 +118,9 @@ def test_scalar_index():
|
|||||||
# --8<-- [end:search_with_scalar_index]
|
# --8<-- [end:search_with_scalar_index]
|
||||||
# --8<-- [start:vector_search_with_scalar_index]
|
# --8<-- [start:vector_search_with_scalar_index]
|
||||||
data = [
|
data = [
|
||||||
{"book_id": 1, "vector": [1, 2]},
|
{"book_id": 1, "vector": [1.0, 2]},
|
||||||
{"book_id": 2, "vector": [3, 4]},
|
{"book_id": 2, "vector": [3.0, 4]},
|
||||||
{"book_id": 3, "vector": [5, 6]},
|
{"book_id": 3, "vector": [5.0, 6]},
|
||||||
]
|
]
|
||||||
|
|
||||||
table = db.create_table("book_with_embeddings", data)
|
table = db.create_table("book_with_embeddings", data)
|
||||||
@@ -156,9 +156,9 @@ async def test_scalar_index_async():
|
|||||||
# --8<-- [end:search_with_scalar_index_async]
|
# --8<-- [end:search_with_scalar_index_async]
|
||||||
# --8<-- [start:vector_search_with_scalar_index_async]
|
# --8<-- [start:vector_search_with_scalar_index_async]
|
||||||
data = [
|
data = [
|
||||||
{"book_id": 1, "vector": [1, 2]},
|
{"book_id": 1, "vector": [1.0, 2]},
|
||||||
{"book_id": 2, "vector": [3, 4]},
|
{"book_id": 2, "vector": [3.0, 4]},
|
||||||
{"book_id": 3, "vector": [5, 6]},
|
{"book_id": 3, "vector": [5.0, 6]},
|
||||||
]
|
]
|
||||||
async_tbl = await async_db.create_table("book_with_embeddings_async", data)
|
async_tbl = await async_db.create_table("book_with_embeddings_async", data)
|
||||||
(await async_tbl.query().where("book_id != 3").nearest_to([1, 2]).to_pandas())
|
(await async_tbl.query().where("book_id != 3").nearest_to([1, 2]).to_pandas())
|
||||||
|
|||||||
@@ -198,7 +198,6 @@ def test_embedding_function_with_pandas(tmp_path):
|
|||||||
{
|
{
|
||||||
"text": ["hello world", "goodbye world"],
|
"text": ["hello world", "goodbye world"],
|
||||||
"val": [1, 2],
|
"val": [1, 2],
|
||||||
"not-used": ["s1", "s3"],
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
db = lancedb.connect(tmp_path)
|
db = lancedb.connect(tmp_path)
|
||||||
@@ -212,7 +211,6 @@ def test_embedding_function_with_pandas(tmp_path):
|
|||||||
{
|
{
|
||||||
"text": ["extra", "more"],
|
"text": ["extra", "more"],
|
||||||
"val": [4, 5],
|
"val": [4, 5],
|
||||||
"misc-col": ["s1", "s3"],
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
tbl.add(df)
|
tbl.add(df)
|
||||||
|
|||||||
@@ -242,8 +242,8 @@ def test_add_subschema(mem_db: DBConnection):
|
|||||||
|
|
||||||
data = {"price": 10.0, "item": "foo"}
|
data = {"price": 10.0, "item": "foo"}
|
||||||
table.add([data])
|
table.add([data])
|
||||||
data = {"price": 2.0, "vector": [3.1, 4.1]}
|
data = pd.DataFrame({"price": [2.0], "vector": [[3.1, 4.1]]})
|
||||||
table.add([data])
|
table.add(data)
|
||||||
data = {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"}
|
data = {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"}
|
||||||
table.add([data])
|
table.add([data])
|
||||||
|
|
||||||
@@ -259,7 +259,7 @@ def test_add_subschema(mem_db: DBConnection):
|
|||||||
|
|
||||||
data = {"item": "foo"}
|
data = {"item": "foo"}
|
||||||
# We can't omit a column if it's not nullable
|
# We can't omit a column if it's not nullable
|
||||||
with pytest.raises(RuntimeError, match="Invalid user input"):
|
with pytest.raises(RuntimeError, match="Append with different schema"):
|
||||||
table.add([data])
|
table.add([data])
|
||||||
|
|
||||||
# We can add it if we make the column nullable
|
# We can add it if we make the column nullable
|
||||||
@@ -292,6 +292,7 @@ def test_add_nullability(mem_db: DBConnection):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
table = mem_db.create_table("test", schema=schema)
|
table = mem_db.create_table("test", schema=schema)
|
||||||
|
assert table.schema.field("vector").nullable is False
|
||||||
|
|
||||||
nullable_schema = pa.schema(
|
nullable_schema = pa.schema(
|
||||||
[
|
[
|
||||||
@@ -320,7 +321,10 @@ def test_add_nullability(mem_db: DBConnection):
|
|||||||
schema=nullable_schema,
|
schema=nullable_schema,
|
||||||
)
|
)
|
||||||
# We can't add nullable schema if it contains nulls
|
# We can't add nullable schema if it contains nulls
|
||||||
with pytest.raises(Exception, match="Vector column vector has NaNs"):
|
with pytest.raises(
|
||||||
|
Exception,
|
||||||
|
match="Casting field 'vector' with null values to non-nullable",
|
||||||
|
):
|
||||||
table.add(data)
|
table.add(data)
|
||||||
|
|
||||||
# But we can make it nullable
|
# But we can make it nullable
|
||||||
@@ -776,6 +780,38 @@ def test_merge_insert(mem_db: DBConnection):
|
|||||||
assert table.to_arrow().sort_by("a") == expected
|
assert table.to_arrow().sort_by("a") == expected
|
||||||
|
|
||||||
|
|
||||||
|
# We vary the data format because there are slight differences in how
|
||||||
|
# subschemas are handled in different formats
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"data_format",
|
||||||
|
[
|
||||||
|
lambda table: table,
|
||||||
|
lambda table: table.to_pandas(),
|
||||||
|
lambda table: table.to_pylist(),
|
||||||
|
],
|
||||||
|
ids=["pa.Table", "pd.DataFrame", "rows"],
|
||||||
|
)
|
||||||
|
def test_merge_insert_subschema(mem_db: DBConnection, data_format):
|
||||||
|
initial_data = pa.table(
|
||||||
|
{"id": range(3), "a": [1.0, 2.0, 3.0], "c": ["x", "x", "x"]}
|
||||||
|
)
|
||||||
|
table = mem_db.create_table("my_table", data=initial_data)
|
||||||
|
|
||||||
|
new_data = pa.table({"id": [2, 3], "c": ["y", "y"]})
|
||||||
|
new_data = data_format(new_data)
|
||||||
|
(
|
||||||
|
table.merge_insert(on="id")
|
||||||
|
.when_matched_update_all()
|
||||||
|
.when_not_matched_insert_all()
|
||||||
|
.execute(new_data)
|
||||||
|
)
|
||||||
|
|
||||||
|
expected = pa.table(
|
||||||
|
{"id": [0, 1, 2, 3], "a": [1.0, 2.0, 3.0, None], "c": ["x", "x", "y", "y"]}
|
||||||
|
)
|
||||||
|
assert table.to_arrow().sort_by("id") == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_merge_insert_async(mem_db_async: AsyncConnection):
|
async def test_merge_insert_async(mem_db_async: AsyncConnection):
|
||||||
data = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]})
|
data = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]})
|
||||||
|
|||||||
@@ -13,10 +13,27 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import lance
|
||||||
|
from lancedb.conftest import MockTextEmbeddingFunction
|
||||||
|
from lancedb.embeddings.base import EmbeddingFunctionConfig
|
||||||
|
from lancedb.embeddings.registry import EmbeddingFunctionRegistry
|
||||||
|
from lancedb.table import (
|
||||||
|
_append_vector_columns,
|
||||||
|
_cast_to_target_schema,
|
||||||
|
_handle_bad_vectors,
|
||||||
|
_into_pyarrow_table,
|
||||||
|
_sanitize_data,
|
||||||
|
_infer_target_schema,
|
||||||
|
)
|
||||||
|
import pyarrow as pa
|
||||||
|
import pandas as pd
|
||||||
|
import polars as pl
|
||||||
import pytest
|
import pytest
|
||||||
import lancedb
|
import lancedb
|
||||||
from lancedb.util import get_uri_scheme, join_uri, value_to_sql
|
from lancedb.util import get_uri_scheme, join_uri, value_to_sql
|
||||||
|
from utils import exception_output
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_uri():
|
def test_normalize_uri():
|
||||||
@@ -111,3 +128,460 @@ def test_value_to_sql_string(tmp_path):
|
|||||||
for value in values:
|
for value in values:
|
||||||
table.update(where=f"search = {value_to_sql(value)}", values={"replace": value})
|
table.update(where=f"search = {value_to_sql(value)}", values={"replace": value})
|
||||||
assert table.to_pandas().query("search == @value")["replace"].item() == value
|
assert table.to_pandas().query("search == @value")["replace"].item() == value
|
||||||
|
|
||||||
|
|
||||||
|
def test_append_vector_columns():
|
||||||
|
registry = EmbeddingFunctionRegistry.get_instance()
|
||||||
|
registry.register("test")(MockTextEmbeddingFunction)
|
||||||
|
conf = EmbeddingFunctionConfig(
|
||||||
|
source_column="text",
|
||||||
|
vector_column="vector",
|
||||||
|
function=MockTextEmbeddingFunction(),
|
||||||
|
)
|
||||||
|
metadata = registry.get_table_metadata([conf])
|
||||||
|
|
||||||
|
schema = pa.schema(
|
||||||
|
{
|
||||||
|
"text": pa.string(),
|
||||||
|
"vector": pa.list_(pa.float64(), 10),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
data = pa.table(
|
||||||
|
{
|
||||||
|
"text": ["hello"],
|
||||||
|
"vector": [None], # Replaces null
|
||||||
|
},
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
output = _append_vector_columns(
|
||||||
|
data,
|
||||||
|
schema, # metadata passed separate from schema
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
assert output.schema == schema
|
||||||
|
assert output["vector"].null_count == 0
|
||||||
|
|
||||||
|
# Adds if missing
|
||||||
|
data = pa.table({"text": ["hello"]})
|
||||||
|
output = _append_vector_columns(
|
||||||
|
data,
|
||||||
|
schema.with_metadata(metadata),
|
||||||
|
)
|
||||||
|
assert output.schema == schema
|
||||||
|
assert output["vector"].null_count == 0
|
||||||
|
|
||||||
|
# doesn't embed if already there
|
||||||
|
data = pa.table(
|
||||||
|
{
|
||||||
|
"text": ["hello"],
|
||||||
|
"vector": [[42.0] * 10],
|
||||||
|
},
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
output = _append_vector_columns(
|
||||||
|
data,
|
||||||
|
schema.with_metadata(metadata),
|
||||||
|
)
|
||||||
|
assert output == data # No change
|
||||||
|
|
||||||
|
# No provided schema
|
||||||
|
data = pa.table(
|
||||||
|
{
|
||||||
|
"text": ["hello"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
output = _append_vector_columns(
|
||||||
|
data,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
expected_schema = pa.schema(
|
||||||
|
{
|
||||||
|
"text": pa.string(),
|
||||||
|
"vector": pa.list_(pa.float32(), 10),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert output.schema == expected_schema
|
||||||
|
assert output["vector"].null_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("on_bad_vectors", ["error", "drop", "fill", "null"])
|
||||||
|
def test_handle_bad_vectors_jagged(on_bad_vectors):
|
||||||
|
vector = pa.array([[1.0, 2.0], [3.0], [4.0, 5.0]])
|
||||||
|
schema = pa.schema({"vector": pa.list_(pa.float64())})
|
||||||
|
data = pa.table({"vector": vector}, schema=schema)
|
||||||
|
|
||||||
|
if on_bad_vectors == "error":
|
||||||
|
with pytest.raises(ValueError) as e:
|
||||||
|
output = _handle_bad_vectors(
|
||||||
|
data,
|
||||||
|
on_bad_vectors=on_bad_vectors,
|
||||||
|
)
|
||||||
|
output = exception_output(e)
|
||||||
|
assert output == (
|
||||||
|
"ValueError: Vector column 'vector' has variable length vectors. Set "
|
||||||
|
"on_bad_vectors='drop' to remove them, set on_bad_vectors='fill' "
|
||||||
|
"and fill_value=<value> to replace them, or set on_bad_vectors='null' "
|
||||||
|
"to replace them with null."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
output = _handle_bad_vectors(
|
||||||
|
data,
|
||||||
|
on_bad_vectors=on_bad_vectors,
|
||||||
|
fill_value=42.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if on_bad_vectors == "drop":
|
||||||
|
expected = pa.array([[1.0, 2.0], [4.0, 5.0]])
|
||||||
|
elif on_bad_vectors == "fill":
|
||||||
|
expected = pa.array([[1.0, 2.0], [42.0, 42.0], [4.0, 5.0]])
|
||||||
|
elif on_bad_vectors == "null":
|
||||||
|
expected = pa.array([[1.0, 2.0], None, [4.0, 5.0]])
|
||||||
|
|
||||||
|
assert output["vector"].combine_chunks() == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("on_bad_vectors", ["error", "drop", "fill", "null"])
|
||||||
|
def test_handle_bad_vectors_nan(on_bad_vectors):
|
||||||
|
vector = pa.array([[1.0, float("nan")], [3.0, 4.0]])
|
||||||
|
data = pa.table({"vector": vector})
|
||||||
|
|
||||||
|
if on_bad_vectors == "error":
|
||||||
|
with pytest.raises(ValueError) as e:
|
||||||
|
output = _handle_bad_vectors(
|
||||||
|
data,
|
||||||
|
on_bad_vectors=on_bad_vectors,
|
||||||
|
)
|
||||||
|
output = exception_output(e)
|
||||||
|
assert output == (
|
||||||
|
"ValueError: Vector column 'vector' has NaNs. Set "
|
||||||
|
"on_bad_vectors='drop' to remove them, set on_bad_vectors='fill' "
|
||||||
|
"and fill_value=<value> to replace them, or set on_bad_vectors='null' "
|
||||||
|
"to replace them with null."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
output = _handle_bad_vectors(
|
||||||
|
data,
|
||||||
|
on_bad_vectors=on_bad_vectors,
|
||||||
|
fill_value=42.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if on_bad_vectors == "drop":
|
||||||
|
expected = pa.array([[3.0, 4.0]])
|
||||||
|
elif on_bad_vectors == "fill":
|
||||||
|
expected = pa.array([[42.0, 42.0], [3.0, 4.0]])
|
||||||
|
elif on_bad_vectors == "null":
|
||||||
|
expected = pa.array([None, [3.0, 4.0]])
|
||||||
|
|
||||||
|
assert output["vector"].combine_chunks() == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_bad_vectors_noop():
|
||||||
|
# ChunkedArray should be preserved as-is
|
||||||
|
vector = pa.chunked_array(
|
||||||
|
[[[1.0, 2.0], [3.0, 4.0]]], type=pa.list_(pa.float64(), 2)
|
||||||
|
)
|
||||||
|
data = pa.table({"vector": vector})
|
||||||
|
output = _handle_bad_vectors(data)
|
||||||
|
assert output["vector"] == vector
|
||||||
|
|
||||||
|
|
||||||
|
class TestModel(lancedb.pydantic.LanceModel):
|
||||||
|
a: Optional[int]
|
||||||
|
b: Optional[int]


# TODO: huggingface,
@pytest.mark.parametrize(
    "data",
    [
        lambda: [{"a": 1, "b": 2}],
        lambda: pa.RecordBatch.from_pylist([{"a": 1, "b": 2}]),
        lambda: pa.table({"a": [1], "b": [2]}),
        lambda: pa.table({"a": [1], "b": [2]}).to_reader(),
        lambda: iter(pa.table({"a": [1], "b": [2]}).to_batches()),
        lambda: (
            lance.write_dataset(
                pa.table({"a": [1], "b": [2]}),
                "memory://test",
            )
        ),
        lambda: (
            lance.write_dataset(
                pa.table({"a": [1], "b": [2]}),
                "memory://test",
            ).scanner()
        ),
        lambda: pd.DataFrame({"a": [1], "b": [2]}),
        lambda: pl.DataFrame({"a": [1], "b": [2]}),
        lambda: pl.LazyFrame({"a": [1], "b": [2]}),
        lambda: [TestModel(a=1, b=2)],
    ],
    ids=[
        "rows",
        "pa.RecordBatch",
        "pa.Table",
        "pa.RecordBatchReader",
        "batch_iter",
        "lance.LanceDataset",
        "lance.LanceScanner",
        "pd.DataFrame",
        "pl.DataFrame",
        "pl.LazyFrame",
        "pydantic",
    ],
)
def test_into_pyarrow_table(data):
    expected = pa.table({"a": [1], "b": [2]})
    output = _into_pyarrow_table(data())
    assert output == expected
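
# Each of the parametrized inputs above (plain rows, Arrow record batches,
# tables, readers and batch iterators, Lance datasets and scanners,
# pandas/polars frames, and pydantic models) is normalized by
# `_into_pyarrow_table` into the same in-memory Arrow table.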


def test_infer_target_schema():
    example = pa.schema(
        {
            "vec1": pa.list_(pa.float64(), 2),
            "vector": pa.list_(pa.float64()),
        }
    )
    data = pa.table(
        {
            "vec1": [[0.0] * 2],
            "vector": [[0.0] * 2],
        },
        schema=example,
    )
    expected = pa.schema(
        {
            "vec1": pa.list_(pa.float64(), 2),
            "vector": pa.list_(pa.float32(), 2),
        }
    )
    output = _infer_target_schema(data)
    assert output == expected

    # Handle large list and use modal size
    # Most vectors are of length 2, so we should infer that as the target dimension
    example = pa.schema(
        {
            "vector": pa.large_list(pa.float64()),
        }
    )
    data = pa.table(
        {
            "vector": [[0.0] * 2, [0.0], [0.0] * 2],
        },
        schema=example,
    )
    expected = pa.schema(
        {
            "vector": pa.list_(pa.float32(), 2),
        }
    )
    output = _infer_target_schema(data)
    assert output == expected

    # ignore if not list
    example = pa.schema(
        {
            "vector": pa.float64(),
        }
    )
    data = pa.table(
        {
            "vector": [0.0],
        },
        schema=example,
    )
    expected = example
    output = _infer_target_schema(data)
    assert output == expected
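
# Summarizing the cases above: variable-length float list columns are inferred
# as fixed-size-list<float32> vectors, with the dimension taken from the most
# common list length in the data; columns that are already fixed-size lists or
# are not lists at all are left unchanged. For example (a sketch mirroring the
# first case above):
#
#     _infer_target_schema(pa.table({"vector": [[0.0, 0.0]]}))
#     # -> schema with "vector": fixed_size_list<float32>[2]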


@pytest.mark.parametrize(
    "data",
    [
        [{"id": 1, "text": "hello"}],
        pa.RecordBatch.from_pylist([{"id": 1, "text": "hello"}]),
        pd.DataFrame({"id": [1], "text": ["hello"]}),
        pl.DataFrame({"id": [1], "text": ["hello"]}),
    ],
    ids=["rows", "pa.RecordBatch", "pd.DataFrame", "pl.DataFrame"],
)
@pytest.mark.parametrize(
    "schema",
    [
        None,
        pa.schema(
            {
                "id": pa.int32(),
                "text": pa.string(),
                "vector": pa.list_(pa.float32(), 10),
            }
        ),
        pa.schema(
            {
                "id": pa.int64(),
                "text": pa.string(),
                "vector": pa.list_(pa.float32(), 10),
                "extra": pa.int64(),
            }
        ),
    ],
    ids=["infer", "explicit", "subschema"],
)
@pytest.mark.parametrize("with_embedding", [True, False])
def test_sanitize_data(
    data,
    schema: Optional[pa.Schema],
    with_embedding: bool,
):
    if with_embedding:
        registry = EmbeddingFunctionRegistry.get_instance()
        registry.register("test")(MockTextEmbeddingFunction)
        conf = EmbeddingFunctionConfig(
            source_column="text",
            vector_column="vector",
            function=MockTextEmbeddingFunction(),
        )
        metadata = registry.get_table_metadata([conf])
    else:
        metadata = None

    if schema is not None:
        to_remove = schema.get_field_index("extra")
        if to_remove >= 0:
            expected_schema = schema.remove(to_remove)
        else:
            expected_schema = schema
    else:
        expected_schema = pa.schema(
            {
                "id": pa.int64(),
                "text": pa.large_utf8()
                if isinstance(data, pl.DataFrame)
                else pa.string(),
                "vector": pa.list_(pa.float32(), 10),
            }
        )

    if not with_embedding:
        to_remove = expected_schema.get_field_index("vector")
        if to_remove >= 0:
            expected_schema = expected_schema.remove(to_remove)

    expected = pa.table(
        {
            "id": [1],
            "text": ["hello"],
            "vector": [[0.0] * 10],
        },
        schema=expected_schema,
    )

    output_data = _sanitize_data(
        data,
        target_schema=schema,
        metadata=metadata,
        allow_subschema=True,
    )

    assert output_data == expected
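
# Taken together, the parametrization above checks that `_sanitize_data`
# accepts rows, Arrow, pandas, and polars inputs, leaves out target-only
# fields such as "extra" when the data covers only a subschema, and only
# produces a "vector" column when an embedding function is configured in the
# table metadata.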


def test_cast_to_target_schema():
    original_schema = pa.schema(
        {
            "id": pa.int32(),
            "struct": pa.struct(
                [
                    pa.field("a", pa.int32()),
                ]
            ),
            "vector": pa.list_(pa.float64()),
            "vec1": pa.list_(pa.float64(), 2),
            "vec2": pa.list_(pa.float32(), 2),
        }
    )
    data = pa.table(
        {
            "id": [1],
            "struct": [{"a": 1}],
            "vector": [[0.0] * 2],
            "vec1": [[0.0] * 2],
            "vec2": [[0.0] * 2],
        },
        schema=original_schema,
    )

    target = pa.schema(
        {
            "id": pa.int64(),
            "struct": pa.struct(
                [
                    pa.field("a", pa.int64()),
                ]
            ),
            "vector": pa.list_(pa.float32(), 2),
            "vec1": pa.list_(pa.float32(), 2),
            "vec2": pa.list_(pa.float32(), 2),
        }
    )
    output = _cast_to_target_schema(data, target)
    expected = pa.table(
        {
            "id": [1],
            "struct": [{"a": 1}],
            "vector": [[0.0] * 2],
            "vec1": [[0.0] * 2],
            "vec2": [[0.0] * 2],
        },
        schema=target,
    )
    assert output == expected

    # Data can be a subschema of the target
    target = pa.schema(
        {
            "id": pa.int64(),
            "struct": pa.struct(
                [
                    pa.field("a", pa.int64()),
                    # Additional nested field
                    pa.field("b", pa.int64()),
                ]
            ),
            "vector": pa.list_(pa.float32(), 2),
            "vec1": pa.list_(pa.float32(), 2),
            "vec2": pa.list_(pa.float32(), 2),
            # Additional field
            "extra": pa.int64(),
        }
    )
    with pytest.raises(Exception):
        _cast_to_target_schema(data, target)
    output = _cast_to_target_schema(data, target, allow_subschema=True)
    expected_schema = pa.schema(
        {
            "id": pa.int64(),
            "struct": pa.struct(
                [
                    pa.field("a", pa.int64()),
                ]
            ),
            "vector": pa.list_(pa.float32(), 2),
            "vec1": pa.list_(pa.float32(), 2),
            "vec2": pa.list_(pa.float32(), 2),
        }
    )
    expected = pa.table(
        {
            "id": [1],
            "struct": [{"a": 1}],
            "vector": [[0.0] * 2],
            "vec1": [[0.0] * 2],
            "vec2": [[0.0] * 2],
        },
        schema=expected_schema,
    )
    assert output == expected
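
# As exercised above, a plain `_cast_to_target_schema` call requires the data
# to cover every target field and raises otherwise; passing
# `allow_subschema=True` lets the data omit top-level or nested fields, and
# the result keeps only the fields actually present, cast to the target types.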