mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
feat!: set embeddings to Null if embedding function return invalid results (#1674)
This commit is contained in:
@@ -1,15 +1,6 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
|
||||
@@ -34,7 +25,7 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
|
||||
__slots__ = ("__weakref__",) # pydantic 1.x compatibility
|
||||
max_retries: int = (
|
||||
7 # Setitng 0 disables retires. Maybe this should not be enabled by default,
|
||||
7 # Setting 0 disables retires. Maybe this should not be enabled by default,
|
||||
)
|
||||
_ndims: int = PrivateAttr()
|
||||
|
||||
@@ -46,22 +37,37 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
return cls(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def compute_query_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||
def compute_query_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]:
|
||||
"""
|
||||
Compute the embeddings for a given user query
|
||||
|
||||
Returns
|
||||
-------
|
||||
A list of embeddings for each input. The embedding of each input can be None
|
||||
when the embedding is not valid.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute_source_embeddings(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for the source column in the database
|
||||
def compute_source_embeddings(self, *args, **kwargs) -> list[Union[np.array, None]]:
|
||||
"""Compute the embeddings for the source column in the database
|
||||
|
||||
Returns
|
||||
-------
|
||||
A list of embeddings for each input. The embedding of each input can be None
|
||||
when the embedding is not valid.
|
||||
"""
|
||||
pass
|
||||
|
||||
def compute_query_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for a given user query with retries
|
||||
def compute_query_embeddings_with_retry(
|
||||
self, *args, **kwargs
|
||||
) -> list[Union[np.array, None]]:
|
||||
"""Compute the embeddings for a given user query with retries
|
||||
|
||||
Returns
|
||||
-------
|
||||
A list of embeddings for each input. The embedding of each input can be None
|
||||
when the embedding is not valid.
|
||||
"""
|
||||
return retry_with_exponential_backoff(
|
||||
self.compute_query_embeddings, max_retries=self.max_retries
|
||||
@@ -70,9 +76,15 @@ class EmbeddingFunction(BaseModel, ABC):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def compute_source_embeddings_with_retry(self, *args, **kwargs) -> List[np.array]:
|
||||
"""
|
||||
Compute the embeddings for the source column in the database with retries
|
||||
def compute_source_embeddings_with_retry(
|
||||
self, *args, **kwargs
|
||||
) -> list[Union[np.array, None]]:
|
||||
"""Compute the embeddings for the source column in the database with retries.
|
||||
|
||||
Returns
|
||||
-------
|
||||
A list of embeddings for each input. The embedding of each input can be None
|
||||
when the embedding is not valid.
|
||||
"""
|
||||
return retry_with_exponential_backoff(
|
||||
self.compute_source_embeddings, max_retries=self.max_retries
|
||||
@@ -144,18 +156,20 @@ class TextEmbeddingFunction(EmbeddingFunction):
|
||||
A callable ABC for embedding functions that take text as input
|
||||
"""
|
||||
|
||||
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
def compute_query_embeddings(
|
||||
self, query: str, *args, **kwargs
|
||||
) -> list[Union[np.array, None]]:
|
||||
return self.compute_source_embeddings(query, *args, **kwargs)
|
||||
|
||||
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
def compute_source_embeddings(
|
||||
self, texts: TEXT, *args, **kwargs
|
||||
) -> list[Union[np.array, None]]:
|
||||
texts = self.sanitize_input(texts)
|
||||
return self.generate_embeddings(texts)
|
||||
|
||||
@abstractmethod
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray], *args, **kwargs
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Generate the embeddings for the given texts
|
||||
"""
|
||||
) -> list[Union[np.array, None]]:
|
||||
"""Generate the embeddings for the given texts"""
|
||||
pass
|
||||
|
||||
@@ -1,15 +1,6 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
@@ -19,6 +10,7 @@ from .registry import register
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import ollama
|
||||
|
||||
|
||||
@register("ollama")
|
||||
@@ -39,17 +31,20 @@ class OllamaEmbeddings(TextEmbeddingFunction):
|
||||
def ndims(self):
|
||||
return len(self.generate_embeddings(["foo"])[0])
|
||||
|
||||
def _compute_embedding(self, text):
|
||||
return self._ollama_client.embeddings(
|
||||
model=self.name,
|
||||
prompt=text,
|
||||
options=self.options,
|
||||
keep_alive=self.keep_alive,
|
||||
)["embedding"]
|
||||
def _compute_embedding(self, text) -> Union["np.array", None]:
|
||||
return (
|
||||
self._ollama_client.embeddings(
|
||||
model=self.name,
|
||||
prompt=text,
|
||||
options=self.options,
|
||||
keep_alive=self.keep_alive,
|
||||
)["embedding"]
|
||||
or None
|
||||
)
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], "np.ndarray"]
|
||||
) -> List["np.array"]:
|
||||
) -> list[Union["np.array", None]]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
@@ -63,7 +58,7 @@ class OllamaEmbeddings(TextEmbeddingFunction):
|
||||
return embeddings
|
||||
|
||||
@cached_property
|
||||
def _ollama_client(self):
|
||||
def _ollama_client(self) -> "ollama.Client":
|
||||
ollama = attempt_import_or_raise("ollama")
|
||||
# ToDo explore ollama.AsyncClient
|
||||
return ollama.Client(host=self.host, **self.ollama_client_kwargs)
|
||||
|
||||
@@ -1998,22 +1998,26 @@ def _sanitize_vector_column(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
vec_arr = data[vector_column_name].combine_chunks()
|
||||
vec_arr = ensure_fixed_size_list(vec_arr)
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
elif not pa.types.is_fixed_size_list(vec_arr.type):
|
||||
raise TypeError(f"Unsupported vector column type: {vec_arr.type}")
|
||||
|
||||
vec_arr = ensure_fixed_size_list(vec_arr)
|
||||
data = data.set_column(
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
|
||||
# Use numpy to check for NaNs, because as pyarrow 14.0.2 does not have `is_nan`
|
||||
# kernel over f16 types.
|
||||
values_np = vec_arr.values.to_numpy(zero_copy_only=False)
|
||||
if np.isnan(values_np).any():
|
||||
data = _sanitize_nans(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
|
||||
if pa.types.is_float16(vec_arr.values.type):
|
||||
# Use numpy to check for NaNs, because as pyarrow does not have `is_nan`
|
||||
# kernel over f16 types yet.
|
||||
values_np = vec_arr.values.to_numpy(zero_copy_only=True)
|
||||
if np.isnan(values_np).any():
|
||||
data = _sanitize_nans(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
else:
|
||||
if pc.any(pc.is_null(vec_arr.values, nan_is_null=True)).as_py():
|
||||
data = _sanitize_nans(
|
||||
data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@@ -2057,8 +2061,15 @@ def _sanitize_jagged(data, fill_value, on_bad_vectors, vec_arr, vector_column_na
|
||||
return data
|
||||
|
||||
|
||||
def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name):
|
||||
def _sanitize_nans(
|
||||
data,
|
||||
fill_value,
|
||||
on_bad_vectors,
|
||||
vec_arr: pa.FixedSizeListArray,
|
||||
vector_column_name: str,
|
||||
):
|
||||
"""Sanitize NaNs in vectors"""
|
||||
assert pa.types.is_fixed_size_list(vec_arr.type)
|
||||
if on_bad_vectors == "error":
|
||||
raise ValueError(
|
||||
f"Vector column {vector_column_name} has NaNs. "
|
||||
@@ -2078,9 +2089,11 @@ def _sanitize_nans(data, fill_value, on_bad_vectors, vec_arr, vector_column_name
|
||||
data.column_names.index(vector_column_name), vector_column_name, vec_arr
|
||||
)
|
||||
elif on_bad_vectors == "drop":
|
||||
is_value_nan = pc.is_nan(vec_arr.values).to_numpy(zero_copy_only=False)
|
||||
is_full = np.any(~is_value_nan.reshape(-1, vec_arr.type.list_size), axis=1)
|
||||
data = data.filter(is_full)
|
||||
# Drop is very slow to be able to filter out NaNs in a fixed size list array
|
||||
np_arr = np.isnan(vec_arr.values.to_numpy(zero_copy_only=False))
|
||||
np_arr = np_arr.reshape(-1, vec_arr.type.list_size)
|
||||
not_nulls = np.any(np_arr, axis=1)
|
||||
data = data.filter(~not_nulls)
|
||||
return data
|
||||
|
||||
|
||||
|
||||
@@ -86,6 +86,47 @@ def test_embedding_function(tmp_path):
|
||||
assert np.allclose(actual, expected)
|
||||
|
||||
|
||||
def test_embedding_with_bad_results(tmp_path):
|
||||
@register("mock-embedding")
|
||||
class MockEmbeddingFunction(TextEmbeddingFunction):
|
||||
def ndims(self):
|
||||
return 128
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> list[Union[np.array, None]]:
|
||||
return [
|
||||
None if i % 2 == 0 else np.random.randn(self.ndims())
|
||||
for i in range(len(texts))
|
||||
]
|
||||
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
model = registry.get("mock-embedding").create()
|
||||
|
||||
class Schema(LanceModel):
|
||||
text: str = model.SourceField()
|
||||
vector: Vector(model.ndims()) = model.VectorField()
|
||||
|
||||
table = db.create_table("test", schema=Schema, mode="overwrite")
|
||||
table.add(
|
||||
[{"text": "hello world"}, {"text": "bar"}],
|
||||
on_bad_vectors="drop",
|
||||
)
|
||||
|
||||
df = table.to_pandas()
|
||||
assert len(table) == 1
|
||||
assert df.iloc[0]["text"] == "bar"
|
||||
|
||||
# table = db.create_table("test2", schema=Schema, mode="overwrite")
|
||||
# table.add(
|
||||
# [{"text": "hello world"}, {"text": "bar"}],
|
||||
# )
|
||||
# assert len(table) == 2
|
||||
# tbl = table.to_arrow()
|
||||
# assert tbl["vector"].null_count == 1
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_embedding_function_rate_limit(tmp_path):
|
||||
def _get_schema_from_model(model):
|
||||
|
||||
Reference in New Issue
Block a user