diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py
index 80ea876e..51c1c610 100644
--- a/python/python/lancedb/db.py
+++ b/python/python/lancedb/db.py
@@ -784,10 +784,6 @@ class AsyncConnection(object):
             registry = EmbeddingFunctionRegistry.get_instance()
             metadata = registry.get_table_metadata(embedding_functions)
 
-        data, schema = sanitize_create_table(
-            data, schema, metadata, on_bad_vectors, fill_value
-        )
-
         # Defining defaults here and not in function prototype. In the future
         # these defaults will move into rust so better to keep them as None.
         if on_bad_vectors is None:
diff --git a/python/python/lancedb/embeddings/registry.py b/python/python/lancedb/embeddings/registry.py
index f442b713..78353e47 100644
--- a/python/python/lancedb/embeddings/registry.py
+++ b/python/python/lancedb/embeddings/registry.py
@@ -108,9 +108,14 @@ class EmbeddingFunctionRegistry:
         An empty dict is returned if input is None or does not contain
         b"embedding_functions".
         """
-        if metadata is None or b"embedding_functions" not in metadata:
+        if metadata is None:
+            return {}
+        # Look at both bytes and string keys, since we might use either
+        serialized = metadata.get(
+            b"embedding_functions", metadata.get("embedding_functions")
+        )
+        if serialized is None:
             return {}
-        serialized = metadata[b"embedding_functions"]
         raw_list = json.loads(serialized.decode("utf-8"))
         return {
             obj["vector_column"]: EmbeddingFunctionConfig(
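Why the registry now checks both key forms: metadata attached to a pa.Schema round-trips through Arrow as bytes keys, while a metadata dict assembled in Python may still hold str keys. A minimal sketch (mine, not part of the patch) of the bytes round-trip:

```python
import pyarrow as pa

# Keys and values are encoded to bytes once they pass through a pa.Schema.
schema = pa.schema(
    [pa.field("text", pa.string())],
    metadata={"embedding_functions": "[]"},
)
print(list(schema.metadata.keys()))  # [b'embedding_functions']
```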
diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py
index 0409e658..81363a1b 100644
--- a/python/python/lancedb/query.py
+++ b/python/python/lancedb/query.py
@@ -472,7 +472,7 @@ class LanceQueryBuilder(ABC):
         --------
         >>> import lancedb
         >>> db = lancedb.connect("./.lancedb")
-        >>> table = db.create_table("my_table", [{"vector": [99, 99]}])
+        >>> table = db.create_table("my_table", [{"vector": [99.0, 99]}])
         >>> query = [100, 100]
         >>> plan = table.search(query).explain_plan(True)
         >>> print(plan)  # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py
index e15280b0..264200d6 100644
--- a/python/python/lancedb/table.py
+++ b/python/python/lancedb/table.py
@@ -25,7 +25,6 @@ from urllib.parse import urlparse
 import lance
 from lancedb.background_loop import LOOP
 from .dependencies import _check_for_pandas
-import numpy as np
 import pyarrow as pa
 import pyarrow.compute as pc
 import pyarrow.fs as pa_fs
@@ -74,34 +73,17 @@ pl = safe_import_polars()
 QueryType = Literal["vector", "fts", "hybrid", "auto"]
 
 
-def _pd_schema_without_embedding_funcs(
-    schema: Optional[pa.Schema], columns: List[str]
-) -> Optional[pa.Schema]:
-    """Return a schema without any embedding function columns"""
-    if schema is None:
-        return None
-    embedding_functions = EmbeddingFunctionRegistry.get_instance().parse_functions(
-        schema.metadata
-    )
-    if not embedding_functions:
-        return schema
-    return pa.schema([field for field in schema if field.name in columns])
-
-
-def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
+def _into_pyarrow_table(data) -> pa.Table:
     if _check_for_hugging_face(data):
         # Huggingface datasets
         from lance.dependencies import datasets
 
         if isinstance(data, datasets.Dataset):
-            if schema is None:
-                schema = data.features.arrow_schema
+            schema = data.features.arrow_schema
             return pa.Table.from_batches(data.data.to_batches(), schema=schema)
         elif isinstance(data, datasets.dataset_dict.DatasetDict):
-            if schema is None:
-                schema = _schema_from_hf(data, schema)
+            schema = _schema_from_hf(data, None)
             return pa.Table.from_batches(_to_batches_with_split(data), schema=schema)
-
     if isinstance(data, LanceModel):
         raise ValueError("Cannot add a single LanceModel to a table. Use a list.")
@@ -111,17 +93,15 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
     if isinstance(data, list):
         # convert to list of dict if data is a bunch of LanceModels
         if isinstance(data[0], LanceModel):
-            if schema is None:
-                schema = data[0].__class__.to_arrow_schema()
+            schema = data[0].__class__.to_arrow_schema()
             data = [model_to_dict(d) for d in data]
             return pa.Table.from_pylist(data, schema=schema)
         elif isinstance(data[0], pa.RecordBatch):
-            return pa.Table.from_batches(data, schema=schema)
+            return pa.Table.from_batches(data)
         else:
-            return pa.Table.from_pylist(data, schema=schema)
-    elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):  # type: ignore
-        raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list())
-        table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema)
+            return pa.Table.from_pylist(data)
+    elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
+        table = pa.Table.from_pandas(data, preserve_index=False)
         # Do not serialize Pandas metadata
         meta = table.schema.metadata if table.schema.metadata is not None else {}
         meta = {k: v for k, v in meta.items() if k != b"pandas"}
@@ -143,8 +123,13 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
         and data.__class__.__name__ == "DataFrame"
     ):
         return data.to_arrow()
+    elif (
+        type(data).__module__.startswith("polars")
+        and data.__class__.__name__ == "LazyFrame"
+    ):
+        return data.collect().to_arrow()
     elif isinstance(data, Iterable):
-        return _process_iterator(data, schema)
+        return _iterator_to_table(data)
     else:
         raise TypeError(
             f"Unknown data type {type(data)}. "
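The new branch above adds polars LazyFrame support by materializing the frame before conversion. A small illustration (assuming polars is installed) of what that branch does:

```python
import polars as pl

# A LazyFrame is a deferred query plan; collect() runs it and returns a
# DataFrame, which converts itself to a pyarrow.Table.
lazy = pl.LazyFrame({"a": [1], "b": [2]})
table = lazy.collect().to_arrow()
print(table.column_names)  # ['a', 'b']
```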
@@ -154,27 +139,172 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
         )
 
 
+def _iterator_to_table(data: Iterable) -> pa.Table:
+    batches = []
+    schema = None  # Will get schema from first batch
+    for batch in data:
+        batch_table = _into_pyarrow_table(batch)
+        if schema is not None:
+            if batch_table.schema != schema:
+                try:
+                    batch_table = batch_table.cast(schema)
+                except pa.lib.ArrowInvalid:
+                    raise ValueError(
+                        f"Input iterator yielded a batch with schema that "
+                        f"does not match the schema of other batches.\n"
+                        f"Expected:\n{schema}\nGot:\n{batch_table.schema}"
+                    )
+        else:
+            # Use the first schema for the remainder of the batches
+            schema = batch_table.schema
+        batches.append(batch_table)
+
+    if batches:
+        return pa.concat_tables(batches)
+    else:
+        raise ValueError("Input iterable is empty")
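A standalone sketch (mine) of the first-batch-wins rule implemented above, using only public pyarrow calls:

```python
import pyarrow as pa

# Later batches are cast to the first batch's schema, so the int64 column in
# the second table is coerced to int32 before concatenation.
first = pa.table({"x": pa.array([1, 2], type=pa.int32())})
second = pa.table({"x": pa.array([3], type=pa.int64())})
combined = pa.concat_tables([first, second.cast(first.schema)])
print(combined.schema.field("x").type)  # int32
```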
 def _sanitize_data(
-    data: Any,
-    schema: Optional[pa.Schema] = None,
+    data: "DATA",
+    target_schema: Optional[pa.Schema] = None,
     metadata: Optional[dict] = None,  # embedding metadata
-    on_bad_vectors: str = "error",
+    on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
     fill_value: float = 0.0,
-) -> Tuple[pa.Table, pa.Schema]:
-    data = _coerce_to_table(data, schema)
+    *,
+    allow_subschema: bool = False,
+) -> pa.Table:
+    """
+    Handle input data, applying all standard transformations.
+
+    This includes:
+
+    * Converting the data to a PyArrow Table
+    * Adding vector columns defined in the metadata
+    * Adding embedding metadata into the schema
+    * Casting the table to the target schema
+    * Handling bad vectors
+
+    Parameters
+    ----------
+    target_schema : Optional[pa.Schema], default None
+        The schema to cast the table to. This is typically the schema of the table
+        if it already exists. Otherwise it might be a user-requested schema.
+    metadata : Optional[dict], default None
+        The embedding metadata to add to the schema.
+    on_bad_vectors : Literal["error", "drop", "fill", "null"], default "error"
+        What to do if any of the vectors are not the same size or contain NaNs.
+    fill_value : float, default 0.0
+        The value to use when filling vectors. Only used if on_bad_vectors="fill".
+        All entries in the vector will be set to this value.
+    allow_subschema : bool, default False
+        If True, the input table is allowed to omit columns from the target schema.
+        The target schema will be filtered to only include columns that are present
+        in the input table before casting.
+    """
+    # At this point, the table might not match the schema we are targeting:
+    # 1. There might be embedding columns missing that will be added
+    #    in the add_embeddings step.
+    # 2. If `allow_subschema` is True, there might be columns missing.
+    table = _into_pyarrow_table(data)
+
+    table = _append_vector_columns(table, target_schema, metadata=metadata)
+
+    # This happens before the cast so we can fix vector columns with
+    # incorrect lengths before they are cast to FSL.
+    table = _handle_bad_vectors(
+        table,
+        on_bad_vectors=on_bad_vectors,
+        fill_value=fill_value,
+    )
+
+    if target_schema is None:
+        target_schema = _infer_target_schema(table)
 
     if metadata:
-        data = _append_vector_col(data, metadata, schema)
-        metadata.update(data.schema.metadata or {})
-        data = data.replace_schema_metadata(metadata)
+        new_metadata = dict(target_schema.metadata or {})
+        new_metadata.update(metadata)
+        target_schema = target_schema.with_metadata(new_metadata)
 
-    # TODO improve the logics in _sanitize_schema
-    data = _sanitize_schema(data, schema, on_bad_vectors, fill_value)
-    if schema is None:
-        schema = data.schema
+    _validate_schema(target_schema)
 
-    _validate_schema(schema)
-    return data, schema
+    table = _cast_to_target_schema(table, target_schema, allow_subschema)
+
+    return table
+
+
+def _cast_to_target_schema(
+    table: pa.Table,
+    target_schema: pa.Schema,
+    allow_subschema: bool = False,
+) -> pa.Table:
+    # pa.Table.cast expects field order not to be changed.
+    # Lance doesn't care about field order, so we don't need to rearrange fields
+    # to match the target schema. We just need to correctly cast the fields.
+    if table.schema == target_schema:
+        # Fast path when the schemas are already the same
+        return table
+
+    fields = []
+    for field in table.schema:
+        if field.name not in target_schema.names:
+            raise ValueError(f"Field {field.name} not found in target schema")
+        fields.append(target_schema.field(field.name))
+    reordered_schema = pa.schema(fields, metadata=target_schema.metadata)
+    if not allow_subschema and len(reordered_schema) != len(target_schema):
+        raise ValueError(
+            "Input table has different number of columns than target schema"
+        )
+
+    if allow_subschema and len(reordered_schema) != len(target_schema):
+        fields = _infer_subschema(
+            list(iter(table.schema)), list(iter(reordered_schema))
+        )
+        subschema = pa.schema(fields, metadata=target_schema.metadata)
+        return table.cast(subschema)
+    else:
+        return table.cast(reordered_schema)
+
+
+def _infer_subschema(
+    schema: List[pa.Field],
+    reference_fields: List[pa.Field],
+) -> List[pa.Field]:
+    """
+    Transform the list of fields so the types match the reference_fields.
+
+    The order of the fields is preserved.
+
+    ``schema`` may have fewer fields than ``reference_fields``, but it may not
+    have more fields.
+    """
+    fields = []
+    lookup = {f.name: f for f in reference_fields}
+    for field in schema:
+        reference = lookup.get(field.name)
+        if reference is None:
+            raise ValueError("Unexpected field in schema: {}".format(field))
+
+        if pa.types.is_struct(reference.type):
+            new_type = pa.struct(
+                _infer_subschema(
+                    field.type.fields,
+                    reference.type.fields,
+                )
+            )
+            new_field = pa.field(
+                field.name,
+                new_type,
+                reference.nullable,
+            )
+        else:
+            new_field = reference
+
+        fields.append(new_field)
+
+    return fields
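A hedged sketch of the subschema rule these helpers implement: the target schema is filtered down to the fields actually present before casting, so missing columns are tolerated (field names here are illustrative):

```python
import pyarrow as pa

target = pa.schema({"id": pa.int64(), "extra": pa.string()})
data = pa.table({"id": pa.array([1], type=pa.int32())})

# Keep only the target fields that exist in the data, then cast.
subset = pa.schema([target.field(name) for name in data.schema.names])
print(data.cast(subset).schema)  # id: int64 -- "extra" is never required
```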
 
 
 def sanitize_create_table(
@@ -193,13 +323,14 @@ def sanitize_create_table(
     if data is not None:
         if metadata is None and schema is not None:
             metadata = schema.metadata
-        data, schema = _sanitize_data(
+        data = _sanitize_data(
             data,
             schema,
             metadata=metadata,
             on_bad_vectors=on_bad_vectors,
             fill_value=fill_value,
         )
+        schema = data.schema
     else:
         if schema is not None:
             data = pa.Table.from_pylist([], schema)
@@ -211,6 +342,8 @@ def sanitize_create_table(
 
     if metadata:
         schema = schema.with_metadata(metadata)
+        # Need to apply metadata to the data as well
+        data = data.replace_schema_metadata(metadata)
     return data, schema
@@ -246,12 +379,22 @@ def _to_batches_with_split(data):
         yield b
 
 
-def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schema]):
+def _append_vector_columns(
+    data: pa.Table,
+    schema: Optional[pa.Schema] = None,
+    *,
+    metadata: Optional[dict] = None,
+) -> pa.Table:
     """
-    Use the embedding function to automatically embed the source column and add the
-    vector column to the table.
+    Use the embedding functions to automatically embed the source columns and add
+    the vector columns to the table.
     """
+    if schema is None:
+        metadata = metadata or {}
+    else:
+        metadata = schema.metadata or metadata or {}
     functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
+
     for vector_column, conf in functions.items():
         func = conf.function
         no_vector_column = vector_column not in data.column_names
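For reference, this is how embedding metadata is produced and attached in the unit tests added to test_util.py below (MockTextEmbeddingFunction is a test helper from lancedb.conftest):

```python
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.embeddings.base import EmbeddingFunctionConfig
from lancedb.embeddings.registry import EmbeddingFunctionRegistry

registry = EmbeddingFunctionRegistry.get_instance()
registry.register("test")(MockTextEmbeddingFunction)
conf = EmbeddingFunctionConfig(
    source_column="text",
    vector_column="vector",
    function=MockTextEmbeddingFunction(),
)
# Serialized config, keyed by b"embedding_functions", ready for a schema.
metadata = registry.get_table_metadata([conf])
```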
@@ -790,9 +933,9 @@ class Table(ABC):
         --------
         >>> import lancedb
         >>> data = [
-        ...    {"x": 1, "vector": [1, 2]},
-        ...    {"x": 2, "vector": [3, 4]},
-        ...    {"x": 3, "vector": [5, 6]}
+        ...    {"x": 1, "vector": [1.0, 2]},
+        ...    {"x": 2, "vector": [3.0, 4]},
+        ...    {"x": 3, "vector": [5.0, 6]}
         ... ]
         >>> db = lancedb.connect("./.lancedb")
         >>> table = db.create_table("my_table", data)
@@ -854,7 +997,7 @@ class Table(ABC):
         --------
         >>> import lancedb
         >>> import pandas as pd
-        >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
+        >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1.0, 2], [3, 4], [5, 6]]})
         >>> db = lancedb.connect("./.lancedb")
         >>> table = db.create_table("my_table", data)
         >>> table.to_pandas()
@@ -862,7 +1005,7 @@ class Table(ABC):
            x      vector
         0  1  [1.0, 2.0]
         1  2  [3.0, 4.0]
         2  3  [5.0, 6.0]
-        >>> table.update(where="x = 2", values={"vector": [10, 10]})
+        >>> table.update(where="x = 2", values={"vector": [10.0, 10]})
         >>> table.to_pandas()
            x      vector
         0  1  [1.0, 2.0]
@@ -1880,9 +2023,9 @@ class LanceTable(Table):
         --------
         >>> import lancedb
         >>> data = [
-        ...    {"x": 1, "vector": [1, 2]},
-        ...    {"x": 2, "vector": [3, 4]},
-        ...    {"x": 3, "vector": [5, 6]}
+        ...    {"x": 1, "vector": [1.0, 2]},
+        ...    {"x": 2, "vector": [3.0, 4]},
+        ...    {"x": 3, "vector": [5.0, 6]}
         ... ]
         >>> db = lancedb.connect("./.lancedb")
         >>> table = db.create_table("my_table", data)
@@ -1971,7 +2114,7 @@ class LanceTable(Table):
         --------
         >>> import lancedb
         >>> import pandas as pd
-        >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]})
+        >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1.0, 2], [3, 4], [5, 6]]})
         >>> db = lancedb.connect("./.lancedb")
         >>> table = db.create_table("my_table", data)
         >>> table.to_pandas()
@@ -1979,7 +2122,7 @@ class LanceTable(Table):
            x      vector
         0  1  [1.0, 2.0]
         1  2  [3.0, 4.0]
         2  3  [5.0, 6.0]
-        >>> table.update(where="x = 2", values={"vector": [10, 10]})
+        >>> table.update(where="x = 2", values={"vector": [10.0, 10]})
         >>> table.to_pandas()
            x      vector
         0  1  [1.0, 2.0]
@@ -2165,74 +2308,49 @@ class LanceTable(Table):
         LOOP.run(self._table.migrate_v2_manifest_paths())
 
 
-def _sanitize_schema(
-    data: pa.Table,
-    schema: pa.Schema = None,
-    on_bad_vectors: str = "error",
+def _handle_bad_vectors(
+    table: pa.Table,
+    on_bad_vectors: Literal["error", "drop", "fill", "null"] = "error",
     fill_value: float = 0.0,
 ) -> pa.Table:
-    """Ensure that the table has the expected schema.
-
-    Parameters
-    ----------
-    data: pa.Table
-        The table to sanitize.
-    schema: pa.Schema; optional
-        The expected schema. If not provided, this just converts the
-        vector column to fixed_size_list(float32) if necessary.
-    on_bad_vectors: str, default "error"
-        What to do if any of the vectors are not the same size or contains NaNs.
-        One of "error", "drop", "fill", "null".
-    fill_value: float, default 0.
-        The value to use when filling vectors. Only used if on_bad_vectors="fill".
-    """
-    if schema is not None:
-        # cast the columns to the expected types
-        data = data.combine_chunks()
-        for field in schema:
-            # TODO: we're making an assumption that fixed size list of 10 or more
-            # is a vector column. This is definitely a bit hacky.
-            likely_vector_col = (
-                pa.types.is_fixed_size_list(field.type)
-                and pa.types.is_float32(field.type.value_type)
-                and field.type.list_size >= 10
-            )
-            is_default_vector_col = field.name == VECTOR_COLUMN_NAME
-            if field.name in data.column_names and (
-                likely_vector_col or is_default_vector_col
-            ):
-                data = _sanitize_vector_column(
-                    data,
-                    vector_column_name=field.name,
-                    on_bad_vectors=on_bad_vectors,
-                    fill_value=fill_value,
-                    table_schema=schema,
-                )
-        return pa.Table.from_arrays(
-            [data[name] for name in schema.names], schema=schema
-        )
-
-    # just check the vector column
-    if VECTOR_COLUMN_NAME in data.column_names:
-        return _sanitize_vector_column(
-            data,
-            vector_column_name=VECTOR_COLUMN_NAME,
-            on_bad_vectors=on_bad_vectors,
-            fill_value=fill_value,
-        )
-
-    return data
+    for field in table.schema:
+        # They can provide a 'vector' column that isn't yet a FSL
+        named_vector_col = (
+            (
+                pa.types.is_list(field.type)
+                or pa.types.is_large_list(field.type)
+                or pa.types.is_fixed_size_list(field.type)
+            )
+            and pa.types.is_floating(field.type.value_type)
+            and field.name == VECTOR_COLUMN_NAME
+        )
+        # TODO: we're making an assumption that fixed size list of 10 or more
+        # is a vector column. This is definitely a bit hacky.
+        likely_vector_col = (
+            pa.types.is_fixed_size_list(field.type)
+            and pa.types.is_floating(field.type.value_type)
+            and (field.type.list_size >= 10)
+        )
+
+        if named_vector_col or likely_vector_col:
+            table = _handle_bad_vector_column(
+                table,
+                vector_column_name=field.name,
+                on_bad_vectors=on_bad_vectors,
+                fill_value=fill_value,
+            )
+
+    return table
 
 
-def _sanitize_vector_column(
+def _handle_bad_vector_column(
     data: pa.Table,
     vector_column_name: str,
-    table_schema: Optional[pa.Schema] = None,
     on_bad_vectors: str = "error",
     fill_value: float = 0.0,
 ) -> pa.Table:
     """
-    Ensure that the vector column exists and has type fixed_size_list(float32)
+    Ensure that the vector column exists and has type fixed_size_list(float)
 
     Parameters
     ----------
@@ -2246,141 +2364,118 @@ def _sanitize_vector_column(
     fill_value: float, default 0.0
         The value to use when filling vectors. Only used if on_bad_vectors="fill".
     """
-    # ChunkedArray is annoying to work with, so we combine chunks here
-    vec_arr = data[vector_column_name].combine_chunks()
-    if table_schema is not None:
-        field = table_schema.field(vector_column_name)
-    else:
-        field = None
-    typ = data[vector_column_name].type
-    if pa.types.is_list(typ) or pa.types.is_large_list(typ):
-        # if it's a variable size list array,
-        # we make sure the dimensions are all the same
-        has_jagged_ndims = len(vec_arr.values) % len(data) != 0
-        if has_jagged_ndims:
-            data = _sanitize_jagged(
-                data, fill_value, on_bad_vectors, vec_arr, vector_column_name
-            )
-            vec_arr = data[vector_column_name].combine_chunks()
-        vec_arr = ensure_fixed_size_list(vec_arr)
-        data = data.set_column(
-            data.column_names.index(vector_column_name), vector_column_name, vec_arr
-        )
-    elif not pa.types.is_fixed_size_list(vec_arr.type):
-        raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
-
-    if pa.types.is_float16(vec_arr.values.type):
-        # Use numpy to check for NaNs, because as pyarrow does not have `is_nan`
-        # kernel over f16 types yet.
-        values_np = vec_arr.values.to_numpy(zero_copy_only=True)
-        if np.isnan(values_np).any():
-            data = _sanitize_nans(
-                data, fill_value, on_bad_vectors, vec_arr, vector_column_name
-            )
-    else:
-        if (
-            field is not None
-            and not field.nullable
-            and pc.any(pc.is_null(vec_arr.values)).as_py()
-        ) or (pc.any(pc.is_nan(vec_arr.values)).as_py()):
-            data = _sanitize_nans(
-                data, fill_value, on_bad_vectors, vec_arr, vector_column_name
-            )
-    return data
-
-
-def ensure_fixed_size_list(vec_arr) -> pa.FixedSizeListArray:
-    values = vec_arr.values
-    if not (pa.types.is_float16(values.type) or pa.types.is_float32(values.type)):
-        values = values.cast(pa.float32())
-    if pa.types.is_fixed_size_list(vec_arr.type):
-        list_size = vec_arr.type.list_size
-    else:
-        list_size = len(values) / len(vec_arr)
-    vec_arr = pa.FixedSizeListArray.from_arrays(values, list_size)
-    return vec_arr
-
-
-def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
-    """Sanitize jagged vectors."""
-    if on_bad_vectors == "error":
-        raise ValueError(
-            f"Vector column {vector_column_name} has variable length vectors "
-            "Set on_bad_vectors='drop' to remove them, or "
-            "set on_bad_vectors='fill' and fill_value= to replace them."
-        )
-
-    lst_lengths = pc.list_value_length(vec_arr)
-    ndims = pc.max(lst_lengths).as_py()
-    correct_ndims = pc.equal(lst_lengths, ndims)
-
-    if on_bad_vectors == "fill":
-        if fill_value is None:
-            raise ValueError(
-                "`fill_value` must not be None if `on_bad_vectors` is 'fill'"
-            )
-        fill_arr = pa.scalar([float(fill_value)] * ndims)
-        vec_arr = pc.if_else(correct_ndims, vec_arr, fill_arr)
-        data = data.set_column(
-            data.column_names.index(vector_column_name), vector_column_name, vec_arr
-        )
-    elif on_bad_vectors == "drop":
-        data = data.filter(correct_ndims)
-    elif on_bad_vectors == "null":
-        data = data.set_column(
-            data.column_names.index(vector_column_name),
-            vector_column_name,
-            pc.if_else(correct_ndims, vec_arr, pa.scalar(None)),
-        )
-    return data
-
-
-def _sanitize_nans(
-    data,
-    fill_value,
-    on_bad_vectors,
-    vec_arr: pa.FixedSizeListArray,
-    vector_column_name: str,
-):
-    """Sanitize NaNs in vectors"""
-    assert pa.types.is_fixed_size_list(vec_arr.type)
-    if on_bad_vectors == "error":
-        raise ValueError(
-            f"Vector column {vector_column_name} has NaNs. "
-            "Set on_bad_vectors='drop' to remove them, or "
-            "set on_bad_vectors='fill' and fill_value= to replace them. "
-            "Or set on_bad_vectors='null' to replace them with null."
-        )
-    elif on_bad_vectors == "fill":
-        if fill_value is None:
-            raise ValueError(
-                "`fill_value` must not be None if `on_bad_vectors` is 'fill'"
-            )
-        fill_value = float(fill_value)
-        values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values)
-        ndims = len(vec_arr[0])
-        vec_arr = pa.FixedSizeListArray.from_arrays(values, ndims)
-        data = data.set_column(
-            data.column_names.index(vector_column_name), vector_column_name, vec_arr
-        )
-    elif on_bad_vectors == "drop":
-        # Drop is very slow to be able to filter out NaNs in a fixed size list array
-        np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False))
-        np_arr = np_arr.reshape(-1, vec_arr.type.list_size)
-        not_nulls = np.any(np_arr, axis=1)
-        data = data.filter(~not_nulls)
-    elif on_bad_vectors == "null":
-        # null = pa.nulls(len(vec_arr)).cast(vec_arr.type)
-        # values = pc.if_else(pc.is_nan(vec_arr.values), fill_value, vec_arr.values)
-        np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False))
-        np_arr = np_arr.reshape(-1, vec_arr.type.list_size)
-        no_nans = np.any(np_arr, axis=1)
-        data = data.set_column(
-            data.column_names.index(vector_column_name),
-            vector_column_name,
-            pc.if_else(no_nans, vec_arr, pa.scalar(None)),
-        )
-    return data
+    vec_arr = data[vector_column_name]
+
+    has_nan = has_nan_values(vec_arr)
+
+    if pa.types.is_fixed_size_list(vec_arr.type):
+        dim = vec_arr.type.list_size
+    else:
+        dim = _modal_list_size(vec_arr)
+
+    has_wrong_dim = pc.not_equal(pc.list_value_length(vec_arr), dim)
+
+    has_bad_vectors = pc.any(has_nan).as_py() or pc.any(has_wrong_dim).as_py()
+    if has_bad_vectors:
+        is_bad = pc.or_(has_nan, has_wrong_dim)
+        if on_bad_vectors == "error":
+            if pc.any(has_wrong_dim).as_py():
+                raise ValueError(
+                    f"Vector column '{vector_column_name}' has variable length "
+                    "vectors. Set on_bad_vectors='drop' to remove them, "
+                    "set on_bad_vectors='fill' and fill_value= to replace them, "
+                    "or set on_bad_vectors='null' to replace them with null."
+                )
+            else:
+                raise ValueError(
+                    f"Vector column '{vector_column_name}' has NaNs. "
+                    "Set on_bad_vectors='drop' to remove them, "
+                    "set on_bad_vectors='fill' and fill_value= to replace them, "
+                    "or set on_bad_vectors='null' to replace them with null."
+                )
+        elif on_bad_vectors == "null":
+            vec_arr = pc.if_else(
+                is_bad,
+                pa.scalar(None),
+                vec_arr,
+            )
+        elif on_bad_vectors == "drop":
+            data = data.filter(pc.invert(is_bad))
+            vec_arr = data[vector_column_name]
+        elif on_bad_vectors == "fill":
+            if fill_value is None:
+                raise ValueError(
+                    "`fill_value` must not be None if `on_bad_vectors` is 'fill'"
+                )
+            vec_arr = pc.if_else(
+                is_bad,
+                pa.scalar([fill_value] * dim),
+                vec_arr,
+            )
+        else:
+            raise ValueError(f"Invalid value for on_bad_vectors: {on_bad_vectors}")
+
+    position = data.column_names.index(vector_column_name)
+    return data.set_column(position, vector_column_name, vec_arr)
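The resulting behavior in brief, mirroring test_handle_bad_vectors_jagged below (_handle_bad_vectors is a private helper, so this is illustrative only):

```python
import pyarrow as pa
from lancedb.table import _handle_bad_vectors

# Modal length is 2, so the length-1 vector [3.0] is "bad" and gets filled.
data = pa.table({"vector": pa.array([[1.0, 2.0], [3.0], [4.0, 5.0]])})
fixed = _handle_bad_vectors(data, on_bad_vectors="fill", fill_value=42.0)
print(fixed["vector"].to_pylist())  # [[1.0, 2.0], [42.0, 42.0], [4.0, 5.0]]
```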
+
+
+def has_nan_values(arr: Union[pa.ListArray, pa.ChunkedArray]) -> pa.BooleanArray:
+    if isinstance(arr, pa.ChunkedArray):
+        values = pa.chunked_array([chunk.flatten() for chunk in arr.chunks])
+    else:
+        values = arr.flatten()
+    if pa.types.is_float16(values.type):
+        # is_nan isn't yet implemented for f16, so we cast to f32
+        # https://github.com/apache/arrow/issues/45083
+        values_has_nan = pc.is_nan(values.cast(pa.float32()))
+    else:
+        values_has_nan = pc.is_nan(values)
+    values_indices = pc.list_parent_indices(arr)
+    has_nan_indices = pc.unique(pc.filter(values_indices, values_has_nan))
+    indices = pa.array(range(len(arr)), type=pa.uint32())
+    return pc.is_in(indices, has_nan_indices)
+
+
+def _infer_target_schema(table: pa.Table) -> pa.Schema:
+    schema = table.schema
+
+    for i, field in enumerate(schema):
+        if (
+            field.name == VECTOR_COLUMN_NAME
+            and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
+            and pa.types.is_floating(field.type.value_type)
+        ):
+            # Use the most common length of the list as the dimensions
+            dim = _modal_list_size(table.column(i))
+
+            new_field = pa.field(
+                VECTOR_COLUMN_NAME,
+                pa.list_(pa.float32(), dim),
+                nullable=field.nullable,
+            )
+
+            schema = schema.set(i, new_field)
+        elif (
+            field.name == VECTOR_COLUMN_NAME
+            and (pa.types.is_list(field.type) or pa.types.is_large_list(field.type))
+            and pa.types.is_integer(field.type.value_type)
+        ):
+            # Use the most common length of the list as the dimensions
+            dim = _modal_list_size(table.column(i))
+            new_field = pa.field(
+                VECTOR_COLUMN_NAME,
+                pa.list_(pa.uint8(), dim),
+                nullable=field.nullable,
+            )
+
+            schema = schema.set(i, new_field)
+
+    return schema
+
+
+def _modal_list_size(arr: Union[pa.ListArray, pa.ChunkedArray]) -> int:
+    # Use the most common length of the list as the dimensions
+    return pc.mode(pc.list_value_length(arr))[0].as_py()["mode"]
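How the modal dimension is computed: pc.mode returns {mode, count} structs ordered by descending frequency, and the helper reads the first one. A quick check:

```python
import pyarrow as pa
import pyarrow.compute as pc

lengths = pc.list_value_length(pa.array([[0.0, 0.0], [0.0], [0.0, 0.0]]))
print(pc.mode(lengths)[0].as_py())  # {'mode': 2, 'count': 2}
```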
 
 
 def _validate_schema(schema: pa.Schema):
@@ -2410,28 +2505,6 @@
         _validate_metadata(v)
 
 
-def _process_iterator(data: Iterable, schema: Optional[pa.Schema] = None) -> pa.Table:
-    batches = []
-    for batch in data:
-        batch_table = _coerce_to_table(batch, schema)
-        if schema is not None:
-            if batch_table.schema != schema:
-                try:
-                    batch_table = batch_table.cast(schema)
-                except pa.lib.ArrowInvalid:  # type: ignore
-                    raise ValueError(
-                        f"Input iterator yielded a batch with schema that "
-                        f"does not match the expected schema.\nExpected:\n{schema}\n"
-                        f"Got:\n{batch_table.schema}"
-                    )
-        batches.append(batch_table)
-
-    if batches:
-        return pa.concat_tables(batches)
-    else:
-        raise ValueError("Input iterable is empty")
-
-
 class AsyncTable:
     """
     An AsyncTable is a collection of Records in a LanceDB Database.
@@ -2678,16 +2751,17 @@ class AsyncTable:
             on_bad_vectors = "error"
         if fill_value is None:
             fill_value = 0.0
-        table_and_schema: Tuple[pa.Table, pa.Schema] = _sanitize_data(
+        data = _sanitize_data(
             data,
             schema,
             metadata=schema.metadata,
             on_bad_vectors=on_bad_vectors,
             fill_value=fill_value,
+            allow_subschema=True,
         )
-        tbl, schema = table_and_schema
-        if isinstance(tbl, pa.Table):
-            data = pa.RecordBatchReader.from_batches(schema, tbl.to_batches())
+        if isinstance(data, pa.Table):
+            data = data.to_reader()
+
         await self._inner.add(data, mode or "append")
 
     def merge_insert(self, on: Union[str, Iterable[str]]) -> LanceMergeInsertBuilder:
@@ -2822,12 +2896,13 @@ class AsyncTable:
             on_bad_vectors = "error"
         if fill_value is None:
             fill_value = 0.0
-        data, _ = _sanitize_data(
+        data = _sanitize_data(
             new_data,
             schema,
             metadata=schema.metadata,
             on_bad_vectors=on_bad_vectors,
             fill_value=fill_value,
+            allow_subschema=True,
         )
         if isinstance(data, pa.Table):
             data = pa.RecordBatchReader.from_batches(data.schema, data.to_batches())
@@ -2862,9 +2937,9 @@ class AsyncTable:
         --------
         >>> import lancedb
        >>> data = [
-        ...    {"x": 1, "vector": [1, 2]},
-        ...    {"x": 2, "vector": [3, 4]},
-        ...    {"x": 3, "vector": [5, 6]}
+        ...    {"x": 1, "vector": [1.0, 2]},
+        ...    {"x": 2, "vector": [3.0, 4]},
+        ...    {"x": 3, "vector": [5.0, 6]}
         ... ]
         >>> db = lancedb.connect("./.lancedb")
         >>> table = db.create_table("my_table", data)
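End to end, the allow_subschema plumbing means callers may omit nullable columns when adding rows. A sketch with the synchronous API (table name illustrative), matching test_add_subschema in the test changes below:

```python
import lancedb
import pyarrow as pa

db = lancedb.connect("./.lancedb")
schema = pa.schema(
    {
        "price": pa.float64(),
        "item": pa.string(),
        "vector": pa.list_(pa.float32(), 2),  # nullable by default
    }
)
table = db.create_table("subschema_demo", schema=schema, mode="overwrite")
table.add([{"price": 10.0, "item": "foo"}])  # "vector" omitted -> null
```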
diff --git a/python/python/lancedb/util.py b/python/python/lancedb/util.py
index dd248de0..96337dd8 100644
--- a/python/python/lancedb/util.py
+++ b/python/python/lancedb/util.py
@@ -223,9 +223,7 @@ def inf_vector_column_query(schema: pa.Schema) -> str:
     vector_col_count = 0
     for field_name in schema.names:
         field = schema.field(field_name)
-        if pa.types.is_fixed_size_list(field.type) and pa.types.is_floating(
-            field.type.value_type
-        ):
+        if pa.types.is_fixed_size_list(field.type):
             vector_col_count += 1
             if vector_col_count > 1:
                 raise ValueError(
diff --git a/python/python/tests/docs/test_binary_vector.py b/python/python/tests/docs/test_binary_vector.py
index 69b466e8..0bc8030d 100644
--- a/python/python/tests/docs/test_binary_vector.py
+++ b/python/python/tests/docs/test_binary_vector.py
@@ -21,7 +21,7 @@ def test_binary_vector():
     ]
     tbl = db.create_table("my_binary_vectors", data=data)
     query = np.random.randint(0, 256, size=16)
-    tbl.search(query).to_arrow()
+    tbl.search(query).metric("hamming").to_arrow()
     # --8<-- [end:sync_binary_vector]
     db.drop_table("my_binary_vectors")
@@ -39,6 +39,6 @@ async def test_binary_vector_async():
     ]
     tbl = await db.create_table("my_binary_vectors", data=data)
     query = np.random.randint(0, 256, size=16)
-    await tbl.query().nearest_to(query).to_arrow()
+    await tbl.query().nearest_to(query).distance_type("hamming").to_arrow()
     # --8<-- [end:async_binary_vector]
     await db.drop_table("my_binary_vectors")
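Context for pinning the metric: these binary vectors are bit-packed, eight dimensions per uint8 value, so hamming distance (the count of differing bits) is the meaningful comparison rather than the default L2. A sketch of the packing (numpy only, values illustrative):

```python
import numpy as np

bits = np.random.randint(0, 2, size=128)  # a 128-dimensional binary vector
packed = np.packbits(bits)                # 16 uint8 values, 8 bits each
```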
diff --git a/python/python/tests/docs/test_guide_index.py b/python/python/tests/docs/test_guide_index.py
index 5e99edaf..482dd794 100644
--- a/python/python/tests/docs/test_guide_index.py
+++ b/python/python/tests/docs/test_guide_index.py
@@ -118,9 +118,9 @@ def test_scalar_index():
     # --8<-- [end:search_with_scalar_index]
     # --8<-- [start:vector_search_with_scalar_index]
     data = [
-        {"book_id": 1, "vector": [1, 2]},
-        {"book_id": 2, "vector": [3, 4]},
-        {"book_id": 3, "vector": [5, 6]},
+        {"book_id": 1, "vector": [1.0, 2]},
+        {"book_id": 2, "vector": [3.0, 4]},
+        {"book_id": 3, "vector": [5.0, 6]},
     ]
     table = db.create_table("book_with_embeddings", data)
@@ -156,9 +156,9 @@ async def test_scalar_index_async():
     # --8<-- [end:search_with_scalar_index_async]
     # --8<-- [start:vector_search_with_scalar_index_async]
     data = [
-        {"book_id": 1, "vector": [1, 2]},
-        {"book_id": 2, "vector": [3, 4]},
-        {"book_id": 3, "vector": [5, 6]},
+        {"book_id": 1, "vector": [1.0, 2]},
+        {"book_id": 2, "vector": [3.0, 4]},
+        {"book_id": 3, "vector": [5.0, 6]},
     ]
     async_tbl = await async_db.create_table("book_with_embeddings_async", data)
     (await async_tbl.query().where("book_id != 3").nearest_to([1, 2]).to_pandas())
diff --git a/python/python/tests/test_embeddings.py b/python/python/tests/test_embeddings.py
index 4e5fac2c..8778d069 100644
--- a/python/python/tests/test_embeddings.py
+++ b/python/python/tests/test_embeddings.py
@@ -198,7 +198,6 @@ def test_embedding_function_with_pandas(tmp_path):
         {
             "text": ["hello world", "goodbye world"],
             "val": [1, 2],
-            "not-used": ["s1", "s3"],
         }
     )
     db = lancedb.connect(tmp_path)
@@ -212,7 +211,6 @@ def test_embedding_function_with_pandas(tmp_path):
         {
             "text": ["extra", "more"],
             "val": [4, 5],
-            "misc-col": ["s1", "s3"],
         }
     )
     tbl.add(df)
diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py
index 1809a5c5..2fdf73da 100644
--- a/python/python/tests/test_table.py
+++ b/python/python/tests/test_table.py
@@ -242,8 +242,8 @@ def test_add_subschema(mem_db: DBConnection):
     data = {"price": 10.0, "item": "foo"}
     table.add([data])
 
-    data = {"price": 2.0, "vector": [3.1, 4.1]}
-    table.add([data])
+    data = pd.DataFrame({"price": [2.0], "vector": [[3.1, 4.1]]})
+    table.add(data)
 
     data = {"price": 3.0, "vector": [5.9, 26.5], "item": "bar"}
     table.add([data])
@@ -259,7 +259,7 @@ def test_add_subschema(mem_db: DBConnection):
     data = {"item": "foo"}
 
     # We can't omit a column if it's not nullable
-    with pytest.raises(RuntimeError, match="Invalid user input"):
+    with pytest.raises(RuntimeError, match="Append with different schema"):
         table.add([data])
 
     # We can add it if we make the column nullable
@@ -292,6 +292,7 @@ def test_add_nullability(mem_db: DBConnection):
         ]
     )
     table = mem_db.create_table("test", schema=schema)
+    assert table.schema.field("vector").nullable is False
 
     nullable_schema = pa.schema(
         [
@@ -320,7 +321,10 @@ def test_add_nullability(mem_db: DBConnection):
         schema=nullable_schema,
     )
     # We can't add nullable schema if it contains nulls
-    with pytest.raises(Exception, match="Vector column vector has NaNs"):
+    with pytest.raises(
+        Exception,
+        match="Casting field 'vector' with null values to non-nullable",
+    ):
         table.add(data)
 
     # But we can make it nullable
@@ -776,6 +780,38 @@ def test_merge_insert(mem_db: DBConnection):
     assert table.to_arrow().sort_by("a") == expected
 
 
+# We vary the data format because there are slight differences in how
+# subschemas are handled in different formats
+@pytest.mark.parametrize(
+    "data_format",
+    [
+        lambda table: table,
+        lambda table: table.to_pandas(),
+        lambda table: table.to_pylist(),
+    ],
+    ids=["pa.Table", "pd.DataFrame", "rows"],
+)
+def test_merge_insert_subschema(mem_db: DBConnection, data_format):
+    initial_data = pa.table(
+        {"id": range(3), "a": [1.0, 2.0, 3.0], "c": ["x", "x", "x"]}
+    )
+    table = mem_db.create_table("my_table", data=initial_data)
+
+    new_data = pa.table({"id": [2, 3], "c": ["y", "y"]})
+    new_data = data_format(new_data)
+    (
+        table.merge_insert(on="id")
+        .when_matched_update_all()
+        .when_not_matched_insert_all()
+        .execute(new_data)
+    )
+
+    expected = pa.table(
+        {"id": [0, 1, 2, 3], "a": [1.0, 2.0, 3.0, None], "c": ["x", "x", "y", "y"]}
+    )
+    assert table.to_arrow().sort_by("id") == expected
+
+
 @pytest.mark.asyncio
 async def test_merge_insert_async(mem_db_async: AsyncConnection):
     data = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]})
diff --git a/python/python/tests/test_util.py b/python/python/tests/test_util.py
index 2681505f..596b4649 100644
--- a/python/python/tests/test_util.py
+++ b/python/python/tests/test_util.py
@@ -13,10 +13,27 @@
 import os
 import pathlib
+from typing import Optional
 
+import lance
+from lancedb.conftest import MockTextEmbeddingFunction
+from lancedb.embeddings.base import EmbeddingFunctionConfig
+from lancedb.embeddings.registry import EmbeddingFunctionRegistry
+from lancedb.table import (
+    _append_vector_columns,
+    _cast_to_target_schema,
+    _handle_bad_vectors,
+    _into_pyarrow_table,
+    _sanitize_data,
+    _infer_target_schema,
+)
+import pyarrow as pa
+import pandas as pd
+import polars as pl
 import pytest
 
 import lancedb
 from lancedb.util import get_uri_scheme, join_uri, value_to_sql
+from utils import exception_output
 
 
 def test_normalize_uri():
@@ -111,3 +128,460 @@ def test_value_to_sql_string(tmp_path):
     for value in values:
         table.update(where=f"search = {value_to_sql(value)}", values={"replace": value})
         assert table.to_pandas().query("search == @value")["replace"].item() == value
+
+
+def test_append_vector_columns():
+    registry = EmbeddingFunctionRegistry.get_instance()
+    registry.register("test")(MockTextEmbeddingFunction)
+    conf = EmbeddingFunctionConfig(
+        source_column="text",
+        vector_column="vector",
+        function=MockTextEmbeddingFunction(),
+    )
+    metadata = registry.get_table_metadata([conf])
+
+    schema = pa.schema(
+        {
+            "text": pa.string(),
+            "vector": pa.list_(pa.float64(), 10),
+        }
+    )
+    data = pa.table(
+        {
+            "text": ["hello"],
+            "vector": [None],  # Replaces null
+        },
+        schema=schema,
+    )
+    output = _append_vector_columns(
+        data,
+        schema,  # metadata passed separate from schema
+        metadata=metadata,
+    )
+    assert output.schema == schema
+    assert output["vector"].null_count == 0
+
+    # Adds if missing
+    data = pa.table({"text": ["hello"]})
+    output = _append_vector_columns(
+        data,
+        schema.with_metadata(metadata),
+    )
+    assert output.schema == schema
+    assert output["vector"].null_count == 0
+
+    # doesn't embed if already there
+    data = pa.table(
+        {
+            "text": ["hello"],
+            "vector": [[42.0] * 10],
+        },
+        schema=schema,
+    )
+    output = _append_vector_columns(
+        data,
+        schema.with_metadata(metadata),
+    )
+    assert output == data  # No change
+
+    # No provided schema
+    data = pa.table(
+        {
+            "text": ["hello"],
+        }
+    )
+    output = _append_vector_columns(
+        data,
+        metadata=metadata,
+    )
+    expected_schema = pa.schema(
+        {
+            "text": pa.string(),
+            "vector": pa.list_(pa.float32(), 10),
+        }
+    )
+    assert output.schema == expected_schema
+    assert output["vector"].null_count == 0
+
+
+@pytest.mark.parametrize("on_bad_vectors", ["error", "drop", "fill", "null"])
+def test_handle_bad_vectors_jagged(on_bad_vectors):
+    vector = pa.array([[1.0, 2.0], [3.0], [4.0, 5.0]])
+    schema = pa.schema({"vector": pa.list_(pa.float64())})
+    data = pa.table({"vector": vector}, schema=schema)
+
+    if on_bad_vectors == "error":
+        with pytest.raises(ValueError) as e:
+            output = _handle_bad_vectors(
+                data,
+                on_bad_vectors=on_bad_vectors,
+            )
+        output = exception_output(e)
+        assert output == (
+            "ValueError: Vector column 'vector' has variable length vectors. Set "
+            "on_bad_vectors='drop' to remove them, set on_bad_vectors='fill' "
+            "and fill_value= to replace them, or set on_bad_vectors='null' "
+            "to replace them with null."
+        )
+        return
+    else:
+        output = _handle_bad_vectors(
+            data,
+            on_bad_vectors=on_bad_vectors,
+            fill_value=42.0,
+        )
+
+    if on_bad_vectors == "drop":
+        expected = pa.array([[1.0, 2.0], [4.0, 5.0]])
+    elif on_bad_vectors == "fill":
+        expected = pa.array([[1.0, 2.0], [42.0, 42.0], [4.0, 5.0]])
+    elif on_bad_vectors == "null":
+        expected = pa.array([[1.0, 2.0], None, [4.0, 5.0]])
+
+    assert output["vector"].combine_chunks() == expected
+
+
+@pytest.mark.parametrize("on_bad_vectors", ["error", "drop", "fill", "null"])
+def test_handle_bad_vectors_nan(on_bad_vectors):
+    vector = pa.array([[1.0, float("nan")], [3.0, 4.0]])
+    data = pa.table({"vector": vector})
+
+    if on_bad_vectors == "error":
+        with pytest.raises(ValueError) as e:
+            output = _handle_bad_vectors(
+                data,
+                on_bad_vectors=on_bad_vectors,
+            )
+        output = exception_output(e)
+        assert output == (
+            "ValueError: Vector column 'vector' has NaNs. Set "
+            "on_bad_vectors='drop' to remove them, set on_bad_vectors='fill' "
+            "and fill_value= to replace them, or set on_bad_vectors='null' "
+            "to replace them with null."
+        )
+        return
+    else:
+        output = _handle_bad_vectors(
+            data,
+            on_bad_vectors=on_bad_vectors,
+            fill_value=42.0,
+        )
+
+    if on_bad_vectors == "drop":
+        expected = pa.array([[3.0, 4.0]])
+    elif on_bad_vectors == "fill":
+        expected = pa.array([[42.0, 42.0], [3.0, 4.0]])
+    elif on_bad_vectors == "null":
+        expected = pa.array([None, [3.0, 4.0]])
+
+    assert output["vector"].combine_chunks() == expected
+
+
+def test_handle_bad_vectors_noop():
+    # ChunkedArray should be preserved as-is
+    vector = pa.chunked_array(
+        [[[1.0, 2.0], [3.0, 4.0]]], type=pa.list_(pa.float64(), 2)
+    )
+    data = pa.table({"vector": vector})
+    output = _handle_bad_vectors(data)
+    assert output["vector"] == vector
+
+
+class TestModel(lancedb.pydantic.LanceModel):
+    a: Optional[int]
+    b: Optional[int]
+
+
+# TODO: huggingface,
+@pytest.mark.parametrize(
+    "data",
+    [
+        lambda: [{"a": 1, "b": 2}],
+        lambda: pa.RecordBatch.from_pylist([{"a": 1, "b": 2}]),
+        lambda: pa.table({"a": [1], "b": [2]}),
+        lambda: pa.table({"a": [1], "b": [2]}).to_reader(),
+        lambda: iter(pa.table({"a": [1], "b": [2]}).to_batches()),
+        lambda: (
+            lance.write_dataset(
+                pa.table({"a": [1], "b": [2]}),
+                "memory://test",
+            )
+        ),
+        lambda: (
+            lance.write_dataset(
+                pa.table({"a": [1], "b": [2]}),
+                "memory://test",
+            ).scanner()
+        ),
+        lambda: pd.DataFrame({"a": [1], "b": [2]}),
+        lambda: pl.DataFrame({"a": [1], "b": [2]}),
+        lambda: pl.LazyFrame({"a": [1], "b": [2]}),
+        lambda: [TestModel(a=1, b=2)],
+    ],
+    ids=[
+        "rows",
+        "pa.RecordBatch",
+        "pa.Table",
+        "pa.RecordBatchReader",
+        "batch_iter",
+        "lance.LanceDataset",
+        "lance.LanceScanner",
+        "pd.DataFrame",
+        "pl.DataFrame",
+        "pl.LazyFrame",
+        "pydantic",
+    ],
+)
+def test_into_pyarrow_table(data):
+    expected = pa.table({"a": [1], "b": [2]})
+    output = _into_pyarrow_table(data())
+    assert output == expected
+
+
+def test_infer_target_schema():
+    example = pa.schema(
+        {
+            "vec1": pa.list_(pa.float64(), 2),
+            "vector": pa.list_(pa.float64()),
+        }
+    )
+    data = pa.table(
+        {
+            "vec1": [[0.0] * 2],
+            "vector": [[0.0] * 2],
+        },
+        schema=example,
+    )
+    expected = pa.schema(
+        {
+            "vec1": pa.list_(pa.float64(), 2),
+            "vector": pa.list_(pa.float32(), 2),
+        }
+    )
+    output = _infer_target_schema(data)
+    assert output == expected
+
+    # Handle large list and use modal size
+    # Most vectors are of length 2, so we should infer that as the target dimension
+    example = pa.schema(
+        {
+            "vector": pa.large_list(pa.float64()),
+        }
+    )
+    data = pa.table(
+        {
+            "vector": [[0.0] * 2, [0.0], [0.0] * 2],
+        },
+        schema=example,
+    )
+    expected = pa.schema(
+        {
+            "vector": pa.list_(pa.float32(), 2),
+        }
+    )
+    output = _infer_target_schema(data)
+    assert output == expected
+
+    # ignore if not list
+    example = pa.schema(
+        {
+            "vector": pa.float64(),
+        }
+    )
+    data = pa.table(
+        {
+            "vector": [0.0],
+        },
+        schema=example,
+    )
+    expected = example
+    output = _infer_target_schema(data)
+    assert output == expected
"explicit", "subschema"], +) +@pytest.mark.parametrize("with_embedding", [True, False]) +def test_sanitize_data( + data, + schema: Optional[pa.Schema], + with_embedding: bool, +): + if with_embedding: + registry = EmbeddingFunctionRegistry.get_instance() + registry.register("test")(MockTextEmbeddingFunction) + conf = EmbeddingFunctionConfig( + source_column="text", + vector_column="vector", + function=MockTextEmbeddingFunction(), + ) + metadata = registry.get_table_metadata([conf]) + else: + metadata = None + + if schema is not None: + to_remove = schema.get_field_index("extra") + if to_remove >= 0: + expected_schema = schema.remove(to_remove) + else: + expected_schema = schema + else: + expected_schema = pa.schema( + { + "id": pa.int64(), + "text": pa.large_utf8() + if isinstance(data, pl.DataFrame) + else pa.string(), + "vector": pa.list_(pa.float32(), 10), + } + ) + + if not with_embedding: + to_remove = expected_schema.get_field_index("vector") + if to_remove >= 0: + expected_schema = expected_schema.remove(to_remove) + + expected = pa.table( + { + "id": [1], + "text": ["hello"], + "vector": [[0.0] * 10], + }, + schema=expected_schema, + ) + + output_data = _sanitize_data( + data, + target_schema=schema, + metadata=metadata, + allow_subschema=True, + ) + + assert output_data == expected + + +def test_cast_to_target_schema(): + original_schema = pa.schema( + { + "id": pa.int32(), + "struct": pa.struct( + [ + pa.field("a", pa.int32()), + ] + ), + "vector": pa.list_(pa.float64()), + "vec1": pa.list_(pa.float64(), 2), + "vec2": pa.list_(pa.float32(), 2), + } + ) + data = pa.table( + { + "id": [1], + "struct": [{"a": 1}], + "vector": [[0.0] * 2], + "vec1": [[0.0] * 2], + "vec2": [[0.0] * 2], + }, + schema=original_schema, + ) + + target = pa.schema( + { + "id": pa.int64(), + "struct": pa.struct( + [ + pa.field("a", pa.int64()), + ] + ), + "vector": pa.list_(pa.float32(), 2), + "vec1": pa.list_(pa.float32(), 2), + "vec2": pa.list_(pa.float32(), 2), + } + ) + output = _cast_to_target_schema(data, target) + expected = pa.table( + { + "id": [1], + "struct": [{"a": 1}], + "vector": [[0.0] * 2], + "vec1": [[0.0] * 2], + "vec2": [[0.0] * 2], + }, + schema=target, + ) + + # Data can be a subschema of the target + target = pa.schema( + { + "id": pa.int64(), + "struct": pa.struct( + [ + pa.field("a", pa.int64()), + # Additional nested field + pa.field("b", pa.int64()), + ] + ), + "vector": pa.list_(pa.float32(), 2), + "vec1": pa.list_(pa.float32(), 2), + "vec2": pa.list_(pa.float32(), 2), + # Additional field + "extra": pa.int64(), + } + ) + with pytest.raises(Exception): + _cast_to_target_schema(data, target) + output = _cast_to_target_schema(data, target, allow_subschema=True) + expected_schema = pa.schema( + { + "id": pa.int64(), + "struct": pa.struct( + [ + pa.field("a", pa.int64()), + ] + ), + "vector": pa.list_(pa.float32(), 2), + "vec1": pa.list_(pa.float32(), 2), + "vec2": pa.list_(pa.float32(), 2), + } + ) + expected = pa.table( + { + "id": [1], + "struct": [{"a": 1}], + "vector": [[0.0] * 2], + "vec1": [[0.0] * 2], + "vec2": [[0.0] * 2], + }, + schema=expected_schema, + ) + assert output == expected