From 915d828cee1c669ddd8c368df293c54a7205498e Mon Sep 17 00:00:00 2001
From: Lei Xu
Date: Thu, 19 Sep 2024 23:16:20 -0700
Subject: [PATCH] feat!: set embeddings to Null if the embedding function
 returns invalid results (#1674)

---
 python/python/lancedb/embeddings/base.py   | 72 +++++++++++++---------
 python/python/lancedb/embeddings/ollama.py | 37 +++++------
 python/python/lancedb/table.py             | 47 +++++++++-----
 python/python/tests/test_embeddings.py     | 41 ++++++++++++
 4 files changed, 130 insertions(+), 67 deletions(-)

diff --git a/python/python/lancedb/embeddings/base.py b/python/python/lancedb/embeddings/base.py
index bcd6d2cd..07ef17ae 100644
--- a/python/python/lancedb/embeddings/base.py
+++ b/python/python/lancedb/embeddings/base.py
@@ -1,15 +1,6 @@
-# Copyright (c) 2023. LanceDB Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
 from abc import ABC, abstractmethod
 from typing import List, Union
 
@@ -34,7 +25,7 @@ class EmbeddingFunction(BaseModel, ABC):
 
     __slots__ = ("__weakref__",)  # pydantic 1.x compatibility
     max_retries: int = (
-        7  # Setitng 0 disables retires. Maybe this should not be enabled by default,
+        7  # Setting 0 disables retries. Maybe this should not be enabled by default,
     )
     _ndims: int = PrivateAttr()
 
@@ -46,22 +37,37 @@ class EmbeddingFunction(BaseModel, ABC):
         return cls(**kwargs)
 
     @abstractmethod
-    def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
+    def compute_query_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]:
         """
         Compute the embeddings for a given user query
+
+        Returns
+        -------
+        A list of embeddings for each input. The embedding of each input can be None
+        when the embedding is not valid.
         """
         pass
 
     @abstractmethod
-    def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
-        """
-        Compute the embeddings for the source column in the database
+    def compute_source_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]:
+        """Compute the embeddings for the source column in the database
+
+        Returns
+        -------
+        A list of embeddings for each input. The embedding of each input can be None
+        when the embedding is not valid.
         """
         pass
 
-    def compute_query_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
-        """
-        Compute the embeddings for a given user query with retries
+    def compute_query_embeddings_with_retry(
+        self, *args, **kwargs
+    ) -> list[Union[np.array, None]]:
+        """Compute the embeddings for a given user query with retries
+
+        Returns
+        -------
+        A list of embeddings for each input. The embedding of each input can be None
+        when the embedding is not valid.
""" return retry_with_exponential_backoff( self.compute_query_embeddings, max_retries=self.max_retries @@ -70,9 +76,15 @@ class EmbeddingFunction(BaseModel, ABC): **kwargs, ) - def compute_source_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]: - """ - Compute the embeddings for the source column in the database with retries + def compute_source_embeddings_with_retry( + self, *args, **kwargs + ) -> list[Union[np.array, None]]: + """Compute the embeddings for the source column in the database with retries. + + Returns + ------- + A list of embeddings for each input. The embedding of each input can be None + when the embedding is not valid. """ return retry_with_exponential_backoff( self.compute_source_embeddings, max_retries=self.max_retries @@ -144,18 +156,20 @@ class TextEmbeddingFunction(EmbeddingFunction): A callable ABC for embedding functions that take text as input """ - def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]: + def compute_query_embeddings( + self, query: str, *args, **kwargs + ) -> list[Union[np.array, None]]: return self.compute_source_embeddings(query, *args, **kwargs) - def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]: + def compute_source_embeddings( + self, texts: TEXT, *args, **kwargs + ) -> list[Union[np.array, None]]: texts = self.sanitize_input(texts) return self.generate_embeddings(texts) @abstractmethod def generate_embeddings( self, texts: Union[List[str], np.ndarray], *args, **kwargs - ) -> List[np.array]: - """ - Generate the embeddings for the given texts - """ + ) -> list[Union[np.array, None]]: + """Generate the embeddings for the given texts""" pass diff --git a/python/python/lancedb/embeddings/ollama.py b/python/python/lancedb/embeddings/ollama.py index 6e1be917..1dbc3305 100644 --- a/python/python/lancedb/embeddings/ollama.py +++ b/python/python/lancedb/embeddings/ollama.py @@ -1,15 +1,6 @@ -# Copyright (c) 2023. LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The LanceDB Authors
+
 from functools import cached_property
 from typing import TYPE_CHECKING, List, Optional, Union
 
@@ -19,6 +10,7 @@ from .registry import register
 
 if TYPE_CHECKING:
     import numpy as np
+    import ollama
 
 
 @register("ollama")
@@ -39,17 +31,20 @@ class OllamaEmbeddings(TextEmbeddingFunction):
     def ndims(self):
         return len(self.generate_embeddings(["foo"])[0])
 
-    def _compute_embedding(self, text):
-        return self._ollama_client.embeddings(
-            model=self.name,
-            prompt=text,
-            options=self.options,
-            keep_alive=self.keep_alive,
-        )["embedding"]
+    def _compute_embedding(self, text) -> Union["np.array", None]:
+        return (
+            self._ollama_client.embeddings(
+                model=self.name,
+                prompt=text,
+                options=self.options,
+                keep_alive=self.keep_alive,
+            )["embedding"]
+            or None
+        )
 
     def generate_embeddings(
         self, texts: Union[List[str], "np.ndarray"]
-    ) -> List["np.array"]:
+    ) -> list[Union["np.array", None]]:
         """
         Get the embeddings for the given texts
 
@@ -63,7 +58,7 @@ class OllamaEmbeddings(TextEmbeddingFunction):
         return embeddings
 
     @cached_property
-    def _ollama_client(self):
+    def _ollama_client(self) -> "ollama.Client":
         ollama = attempt_import_or_raise("ollama")
         # ToDo explore ollama.AsyncClient
         return ollama.Client(host=self.host, **self.ollama_client_kwargs)
diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py
index d63c2084..615dd27e 100644
--- a/python/python/lancedb/table.py
+++ b/python/python/lancedb/table.py
@@ -1998,22 +1998,26 @@ def _sanitize_vector_column(
             data, fill_value, on_bad_vectors, vec_arr, vector_column_name
         )
         vec_arr = data[vector_column_name].combine_chunks()
+        vec_arr = ensure_fixed_size_list(vec_arr)
+        data = data.set_column(
+            data.column_names.index(vector_column_name), vector_column_name, vec_arr
+        )
     elif not pa.types.is_fixed_size_list(vec_arr.type):
         raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
-    vec_arr = ensure_fixed_size_list(vec_arr)
-    data = data.set_column(
-        data.column_names.index(vector_column_name), vector_column_name, vec_arr
-    )
-
-    # Use numpy to check for NaNs, because as pyarrow 14.0.2 does not have `is_nan`
-    # kernel over f16 types.
-    values_np = vec_arr.values.to_numpy(zero_copy_only=False)
-    if np.isnan(values_np).any():
-        data = _sanitize_nans(
-            data, fill_value, on_bad_vectors, vec_arr, vector_column_name
-        )
-
+    if pa.types.is_float16(vec_arr.values.type):
+        # Use numpy to check for NaNs, because pyarrow does not have an `is_nan`
+        # kernel over f16 types yet.
+        values_np = vec_arr.values.to_numpy(zero_copy_only=True)
+        if np.isnan(values_np).any():
+            data = _sanitize_nans(
+                data, fill_value, on_bad_vectors, vec_arr, vector_column_name
+            )
+    else:
+        if pc.any(pc.is_null(vec_arr.values, nan_is_null=True)).as_py():
+            data = _sanitize_nans(
+                data, fill_value, on_bad_vectors, vec_arr, vector_column_name
+            )
     return data
 
 
@@ -2057,8 +2061,15 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na
     return data
 
 
-def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
+def _sanitize_nans(
+    data,
+    fill_value,
+    on_bad_vectors,
+    vec_arr: pa.FixedSizeListArray,
+    vector_column_name: str,
+):
     """Sanitize NaNs in vectors"""
+    assert pa.types.is_fixed_size_list(vec_arr.type)
     if on_bad_vectors == "error":
         raise ValueError(
             f"Vector column {vector_column_name} has NaNs. "
" @@ -2078,9 +2089,11 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name data.column_names.index(vector_column_name), vector_column_name, vec_arr ) elif on_bad_vectors == "drop": - is_value_nan = pc.is_nan(vec_arr.values).to_numpy(zero_copy_only=False) - is_full = np.any(~is_value_nan.reshape(-1, vec_arr.type.list_size), axis=1) - data = data.filter(is_full) + # Drop is very slow to be able to filter out NaNs in a fixed size list array + np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False)) + np_arr = np_arr.reshape(-1, vec_arr.type.list_size) + not_nulls = np.any(np_arr, axis=1) + data = data.filter(~not_nulls) return data diff --git a/python/python/tests/test_embeddings.py b/python/python/tests/test_embeddings.py index 05886699..f858e8fc 100644 --- a/python/python/tests/test_embeddings.py +++ b/python/python/tests/test_embeddings.py @@ -86,6 +86,47 @@ def test_embedding_function(tmp_path): assert np.allclose(actual, expected) +def test_embedding_with_bad_results(tmp_path): + @register("mock-embedding") + class MockEmbeddingFunction(TextEmbeddingFunction): + def ndims(self): + return 128 + + def generate_embeddings( + self, texts: Union[List[str], np.ndarray] + ) -> list[Union[np.array, None]]: + return [ + None if i % 2 == 0 else np.random.randn(self.ndims()) + for i in range(len(texts)) + ] + + db = lancedb.connect(tmp_path) + registry = EmbeddingFunctionRegistry.get_instance() + model = registry.get("mock-embedding").create() + + class Schema(LanceModel): + text: str = model.SourceField() + vector: Vector(model.ndims()) = model.VectorField() + + table = db.create_table("test", schema=Schema, mode="overwrite") + table.add( + [{"text": "hello world"}, {"text": "bar"}], + on_bad_vectors="drop", + ) + + df = table.to_pandas() + assert len(table) == 1 + assert df.iloc[0]["text"] == "bar" + + # table = db.create_table("test2", schema=Schema, mode="overwrite") + # table.add( + # [{"text": "hello world"}, {"text": "bar"}], + # ) + # assert len(table) == 2 + # tbl = table.to_arrow() + # assert tbl["vector"].null_count == 1 + + @pytest.mark.slow def test_embedding_function_rate_limit(tmp_path): def _get_schema_from_model(model):