diff --git a/.gitignore b/.gitignore index 5d428df1..92421b13 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,4 @@ node/examples/**/dist ## Rust target +Cargo.lock \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 834faae6..b8f409a0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,4 +8,14 @@ repos: - repo: https://github.com/psf/black rev: 22.12.0 hooks: - - id: black \ No newline at end of file + - id: black +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.0.277 + hooks: + - id: ruff +- repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + name: isort (python) \ No newline at end of file diff --git a/python/lancedb/db.py b/python/lancedb/db.py index a9a43eb8..50b2147d 100644 --- a/python/lancedb/db.py +++ b/python/lancedb/db.py @@ -158,6 +158,8 @@ class LanceDBConnection: data: DATA = None, schema: pa.Schema = None, mode: str = "create", + on_bad_vectors: str = "drop", + fill_value: float = 0.0, ) -> LanceTable: """Create a table in the database. @@ -173,6 +175,11 @@ class LanceDBConnection: The mode to use when creating the table. Can be either "create" or "overwrite". By default, if the table already exists, an exception is raised. If you want to overwrite the table, use mode="overwrite". + on_bad_vectors: str + What to do if any of the vectors are not the same size or contains NaNs. + One of "raise", "drop", "fill". + fill_value: float + The value to use when filling vectors. Only used if on_bad_vectors="fill". Note ---- @@ -253,7 +260,15 @@ class LanceDBConnection: raise ValueError("mode must be either 'create' or 'overwrite'") if data is not None: - tbl = LanceTable.create(self, name, data, schema, mode=mode) + tbl = LanceTable.create( + self, + name, + data, + schema, + mode=mode, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, + ) else: tbl = LanceTable.open(self, name) return tbl diff --git a/python/lancedb/table.py b/python/lancedb/table.py index 07c0c61f..111f0bae 100644 --- a/python/lancedb/table.py +++ b/python/lancedb/table.py @@ -15,12 +15,13 @@ from __future__ import annotations import os from functools import cached_property -from typing import List, Union +from typing import Any, List, Union import lance import numpy as np import pandas as pd import pyarrow as pa +import pyarrow.compute as pc from lance import LanceDataset from lance.vector import vec_to_table @@ -28,15 +29,19 @@ from .common import DATA, VEC, VECTOR_COLUMN_NAME from .query import LanceFtsQueryBuilder, LanceQueryBuilder -def _sanitize_data(data, schema): +def _sanitize_data(data, schema, on_bad_vectors, fill_value): if isinstance(data, list): data = pa.Table.from_pylist(data) - data = _sanitize_schema(data, schema=schema) + data = _sanitize_schema( + data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value + ) if isinstance(data, dict): data = vec_to_table(data) if isinstance(data, pd.DataFrame): data = pa.Table.from_pandas(data) - data = _sanitize_schema(data, schema=schema) + data = _sanitize_schema( + data, schema=schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value + ) if not isinstance(data, pa.Table): raise TypeError(f"Unsupported data type: {type(data)}") return data @@ -249,7 +254,13 @@ class LanceTable: """Return the LanceDataset backing this table.""" return self._dataset - def add(self, data: DATA, mode: str = "append") -> int: + def add( + self, + data: DATA, + mode: str = "append", + on_bad_vectors: str = "drop", + fill_value: float = 0.0, + ) -> int: """Add data to the table. Parameters @@ -259,13 +270,20 @@ class LanceTable: mode: str The mode to use when writing the data. Valid values are "append" and "overwrite". + on_bad_vectors: str + What to do if any of the vectors are not the same size or contains NaNs. + One of "raise", "drop", "fill". + fill_value: float, default 0. + The value to use when filling vectors. Only used if on_bad_vectors="fill". Returns ------- int The number of vectors in the table. """ - data = _sanitize_data(data, self.schema) + data = _sanitize_data( + data, self.schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value + ) lance.write_dataset(data, self._dataset_uri, mode=mode) self._reset_dataset() return len(self) @@ -280,6 +298,8 @@ class LanceTable: ---------- query: list, np.ndarray The query vector. + vector_column_name: str, default "vector" + The name of the vector column to search. Returns ------- @@ -302,9 +322,55 @@ class LanceTable: return LanceQueryBuilder(self, query, vector_column_name) @classmethod - def create(cls, db, name, data, schema=None, mode="create"): + def create( + cls, + db, + name, + data, + schema=None, + mode="create", + on_bad_vectors: str = "drop", + fill_value: float = 0.0, + ): + """ + Create a new table. + + Examples + -------- + >>> import lancedb + >>> import pandas as pd + >>> data = pd.DataFrame({"x": [1, 2, 3], "vector": [[1, 2], [3, 4], [5, 6]]}) + >>> db = lancedb.connect("./.lancedb") + >>> table = db.create_table("my_table", data) + >>> table.to_pandas() + x vector + 0 1 [1.0, 2.0] + 1 2 [3.0, 4.0] + 2 3 [5.0, 6.0] + + Parameters + ---------- + db: LanceDB + The LanceDB instance to create the table in. + name: str + The name of the table to create. + data: list-of-dict, dict, pd.DataFrame + The data to insert into the table. + schema: dict, optional + The schema of the table. If not provided, the schema is inferred from the data. + mode: str, default "create" + The mode to use when writing the data. Valid values are + "create", "overwrite", and "append". + on_bad_vectors: str + What to do if any of the vectors are not the same size or contains NaNs. + One of "raise", "drop", "fill". + fill_value: float, default 0. + The value to use when filling vectors. Only used if on_bad_vectors="fill". + """ tbl = LanceTable(db, name) - data = _sanitize_data(data, schema) + data = _sanitize_data( + data, schema, on_bad_vectors=on_bad_vectors, fill_value=fill_value + ) lance.write_dataset(data, tbl._dataset_uri, mode=mode) return tbl @@ -350,7 +416,12 @@ class LanceTable: self._dataset.delete(where) -def _sanitize_schema(data: pa.Table, schema: pa.Schema = None) -> pa.Table: +def _sanitize_schema( + data: pa.Table, + schema: pa.Schema = None, + on_bad_vectors: str = "drop", + fill_value: float = 0.0, +) -> pa.Table: """Ensure that the table has the expected schema. Parameters @@ -360,21 +431,41 @@ def _sanitize_schema(data: pa.Table, schema: pa.Schema = None) -> pa.Table: schema: pa.Schema; optional The expected schema. If not provided, this just converts the vector column to fixed_size_list(float32) if necessary. + on_bad_vectors: str + What to do if any of the vectors are not the same size or contains NaNs. + One of "raise", "drop", "fill". + fill_value: float + The value to use when filling vectors. Only used if on_bad_vectors="fill". """ if schema is not None: if data.schema == schema: return data # cast the columns to the expected types data = data.combine_chunks() - data = _sanitize_vector_column(data, vector_column_name=VECTOR_COLUMN_NAME) + data = _sanitize_vector_column( + data, + vector_column_name=VECTOR_COLUMN_NAME, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, + ) return pa.Table.from_arrays( [data[name] for name in schema.names], schema=schema ) # just check the vector column - return _sanitize_vector_column(data, vector_column_name=VECTOR_COLUMN_NAME) + return _sanitize_vector_column( + data, + vector_column_name=VECTOR_COLUMN_NAME, + on_bad_vectors=on_bad_vectors, + fill_value=fill_value, + ) -def _sanitize_vector_column(data: pa.Table, vector_column_name: str) -> pa.Table: +def _sanitize_vector_column( + data: pa.Table, + vector_column_name: str, + on_bad_vectors: str = "drop", + fill_value: float = 0.0, +) -> pa.Table: """ Ensure that the vector column exists and has type fixed_size_list(float32) @@ -384,19 +475,103 @@ def _sanitize_vector_column(data: pa.Table, vector_column_name: str) -> pa.Table The table to sanitize. vector_column_name: str The name of the vector column. + on_bad_vectors: str + What to do if any of the vectors are not the same size or contains NaNs. + One of "raise", "drop", "fill". + fill_value: float + The value to use when filling vectors. Only used if on_bad_vectors="fill". """ if vector_column_name not in data.column_names: raise ValueError(f"Missing vector column: {vector_column_name}") + # ChunkedArray is annoying to work with, so we combine chunks here vec_arr = data[vector_column_name].combine_chunks() - if pa.types.is_fixed_size_list(vec_arr.type): - return data - if not pa.types.is_list(vec_arr.type): + if pa.types.is_list(data[vector_column_name].type): + # if it's a variable size list array we make sure the dimensions are all the same + has_jagged_ndims = len(vec_arr.values) % len(data) != 0 + if has_jagged_ndims: + data = _sanitize_jagged( + data, fill_value, on_bad_vectors, vec_arr, vector_column_name + ) + vec_arr = data[vector_column_name].combine_chunks() + elif not pa.types.is_fixed_size_list(vec_arr.type): raise TypeError(f"Unsupported vector column type: {vec_arr.type}") + + vec_arr = ensure_fixed_size_list_of_f32(vec_arr) + data = data.set_column( + data.column_names.index(vector_column_name), vector_column_name, vec_arr + ) + + has_nans = pc.any(vec_arr.values.is_nan()).as_py() + if has_nans: + data = _sanitize_nans( + data, fill_value, on_bad_vectors, vec_arr, vector_column_name + ) + + return data + + +def ensure_fixed_size_list_of_f32(vec_arr): values = vec_arr.values if not pa.types.is_float32(values.type): values = values.cast(pa.float32()) - list_size = len(values) / len(data) + if pa.types.is_fixed_size_list(vec_arr.type): + list_size = vec_arr.type.list_size + else: + list_size = len(values) / len(vec_arr) vec_arr = pa.FixedSizeListArray.from_arrays(values, list_size) - return data.set_column( - data.column_names.index(vector_column_name), vector_column_name, vec_arr - ) + return vec_arr + + +def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_name): + """Sanitize jagged vectors.""" + if on_bad_vectors == "raise": + raise ValueError( + f"Vector column {vector_column_name} has variable length vectors " + "Set on_bad_vectors='drop' to remove them, or " + "set on_bad_vectors='fill' and fill_value= to replace them." + ) + + lst_lengths = pc.list_value_length(vec_arr) + ndims = pc.max(lst_lengths).as_py() + correct_ndims = pc.equal(lst_lengths, ndims) + + if on_bad_vectors == "fill": + if fill_value is None: + raise ValueError( + f"`fill_value` must not be None if `on_bad_vectors` is 'fill'" + ) + fill_arr = pa.scalar([float(fill_value)] * ndims) + vec_arr = pc.if_else(correct_ndims, vec_arr, fill_arr) + data = data.set_column( + data.column_names.index(vector_column_name), vector_column_name, vec_arr + ) + elif on_bad_vectors == "drop": + data = data.filter(correct_ndims) + return data + + +def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name): + """Sanitize NaNs in vectors""" + if on_bad_vectors == "raise": + raise ValueError( + f"Vector column {vector_column_name} has NaNs. " + "Set on_bad_vectors='drop' to remove them, or " + "set on_bad_vectors='fill' and fill_value= to replace them." + ) + elif on_bad_vectors == "fill": + if fill_value is None: + raise ValueError( + f"`fill_value` must not be None if `on_bad_vectors` is 'fill'" + ) + fill_value = float(fill_value) + values = pc.if_else(vec_arr.values.is_nan(), fill_value, vec_arr.values) + ndims = len(vec_arr[0]) + vec_arr = pa.FixedSizeListArray.from_arrays(values, ndims) + data = data.set_column( + data.column_names.index(vector_column_name), vector_column_name, vec_arr + ) + elif on_bad_vectors == "drop": + is_value_nan = pc.is_nan(vec_arr.values).to_numpy(zero_copy_only=False) + is_full = np.any(~is_value_nan.reshape(-1, vec_arr.type.list_size), axis=1) + data = data.filter(is_full) + return data diff --git a/python/tests/test_query.py b/python/tests/test_query.py index 1af20f2b..69eda338 100644 --- a/python/tests/test_query.py +++ b/python/tests/test_query.py @@ -15,7 +15,6 @@ import unittest.mock as mock import lance import numpy as np -import pandas as pd import pandas.testing as tm import pyarrow as pa import pytest diff --git a/python/tests/test_remote_client.py b/python/tests/test_remote_client.py index e9fd309a..ee90f28a 100644 --- a/python/tests/test_remote_client.py +++ b/python/tests/test_remote_client.py @@ -30,7 +30,7 @@ class MockLanceDBServer: table_name = request.match_info["table_name"] assert table_name == "test_table" - request_json = await request.json() + await request.json() # TODO: do some matching vecs = pd.Series([np.random.rand(128) for x in range(10)], name="vector") diff --git a/python/tests/test_table.py b/python/tests/test_table.py index e2e9b64b..1dc10d65 100644 --- a/python/tests/test_table.py +++ b/python/tests/test_table.py @@ -15,6 +15,7 @@ import functools from pathlib import Path from unittest.mock import PropertyMock, patch +import numpy as np import pandas as pd import pyarrow as pa import pytest @@ -167,7 +168,8 @@ def test_create_index_method(): replace=True, ) - # Check that the _dataset.create_index method was called with the right parameters + # Check that the _dataset.create_index method was called + # with the right parameters mock_dataset.return_value.create_index.assert_called_once_with( column="vector", index_type="IVF_PQ", @@ -176,3 +178,50 @@ def test_create_index_method(): num_sub_vectors=96, replace=True, ) + + +def test_add_with_nans(db): + # By default we drop bad input vectors + table = LanceTable.create( + db, + "drop_test", + data=[ + {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, + {"vector": [np.nan], "item": "bar", "price": 20.0}, + {"vector": [5], "item": "bar", "price": 20.0}, + {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0}, + ], + ) + assert len(table) == 1 + + # We can fill bad input with some value + table = LanceTable.create( + db, + "fill_test", + data=[ + {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, + {"vector": [np.nan], "item": "bar", "price": 20.0}, + {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0}, + ], + on_bad_vectors="fill", + fill_value=0.0, + ) + assert len(table) == 3 + arrow_tbl = table.to_lance().to_table(filter="item == 'bar'") + v = arrow_tbl["vector"].to_pylist()[0] + assert np.allclose(v, np.array([0.0, 0.0])) + + bad_data = [ + {"vector": [np.nan], "item": "bar", "price": 20.0}, + {"vector": [5], "item": "bar", "price": 20.0}, + {"vector": [np.nan, np.nan], "item": "bar", "price": 20.0}, + {"vector": [np.nan, 5.0], "item": "bar", "price": 20.0}, + ] + for row in bad_data: + with pytest.raises(ValueError): + LanceTable.create( + db, + "raise_test", + data=[{"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, row], + on_bad_vectors="raise", + )