fix(python)!: handle bad openai embeddings gracefully (#1873)

BREAKING-CHANGE: the Pydantic Vector field is now nullable by default; pass `nullable=False` to keep the previous non-nullable Arrow field.
Closes #1577
This commit is contained in:
Lei Xu
2024-11-23 13:33:52 -08:00
committed by GitHub
parent dfd9d2ac99
commit 2ded17452b
7 changed files with 102 additions and 63 deletions

View File

@@ -83,25 +83,33 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
"""
openai = attempt_import_or_raise("openai")
valid_texts = []
valid_indices = []
for idx, text in enumerate(texts):
if text:
valid_texts.append(text)
valid_indices.append(idx)
# TODO retry, rate limit, token limit
try:
if self.name == "text-embedding-ada-002":
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
else:
kwargs = {
"input": texts,
"model": self.name,
}
if self.dim:
kwargs["dimensions"] = self.dim
rs = self._openai_client.embeddings.create(**kwargs)
kwargs = {
"input": valid_texts,
"model": self.name,
}
if self.name != "text-embedding-ada-002":
kwargs["dimensions"] = self.dim
rs = self._openai_client.embeddings.create(**kwargs)
valid_embeddings = {
idx: v.embedding for v, idx in zip(rs.data, valid_indices)
}
except openai.BadRequestError:
logging.exception("Bad request: %s", texts)
return [None] * len(texts)
except Exception:
logging.exception("OpenAI embeddings error")
raise
return [v.embedding for v in rs.data]
return [valid_embeddings.get(idx, None) for idx in range(len(texts))]
@cached_property
def _openai_client(self):

View File

@@ -1,15 +1,5 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Pydantic (v1 / v2) adapter for LanceDB"""
@@ -30,6 +20,7 @@ from typing import (
Type,
Union,
_GenericAlias,
GenericAlias,
)
import numpy as np
@@ -75,7 +66,7 @@ def vector(dim: int, value_type: pa.DataType = pa.float32()):
def Vector(
dim: int, value_type: pa.DataType = pa.float32()
dim: int, value_type: pa.DataType = pa.float32(), nullable: bool = True
) -> Type[FixedSizeListMixin]:
"""Pydantic Vector Type.
@@ -88,6 +79,8 @@ def Vector(
The dimension of the vector.
value_type : pyarrow.DataType, optional
The value type of the vector, by default pa.float32()
nullable : bool, optional
Whether the vector is nullable, by default it is True.
Examples
--------
@@ -103,7 +96,7 @@ def Vector(
>>> assert schema == pa.schema([
... pa.field("id", pa.int64(), False),
... pa.field("url", pa.utf8(), False),
... pa.field("embeddings", pa.list_(pa.float32(), 768), False)
... pa.field("embeddings", pa.list_(pa.float32(), 768))
... ])
"""
@@ -112,6 +105,10 @@ def Vector(
def __repr__(self):
return f"FixedSizeList(dim={dim})"
@staticmethod
def nullable() -> bool:
return nullable
@staticmethod
def dim() -> int:
return dim
@@ -205,9 +202,7 @@ else:
def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
"""Convert a Pydantic FieldInfo to Arrow DataType"""
if isinstance(field.annotation, _GenericAlias) or (
sys.version_info > (3, 9) and isinstance(field.annotation, types.GenericAlias)
):
if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
origin = field.annotation.__origin__
args = field.annotation.__args__
if origin is list:
@@ -235,7 +230,7 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType:
def is_nullable(field: FieldInfo) -> bool:
"""Check if a Pydantic FieldInfo is nullable."""
if isinstance(field.annotation, _GenericAlias):
if isinstance(field.annotation, (_GenericAlias, GenericAlias)):
origin = field.annotation.__origin__
args = field.annotation.__args__
if origin == Union:
@@ -246,6 +241,10 @@ def is_nullable(field: FieldInfo) -> bool:
for typ in args:
if typ is type(None):
return True
elif inspect.isclass(field.annotation) and issubclass(
field.annotation, FixedSizeListMixin
):
return field.annotation.nullable()
return False

View File

@@ -90,10 +90,13 @@ def test_embedding_with_bad_results(tmp_path):
self, texts: Union[List[str], np.ndarray]
) -> list[Union[np.array, None]]:
# Return None, which is bad if field is non-nullable
return [
None if i % 2 == 0 else np.random.randn(self.ndims())
a = [
np.full(self.ndims(), np.nan)
if i % 2 == 0
else np.random.randn(self.ndims())
for i in range(len(texts))
]
return a
db = lancedb.connect(tmp_path)
registry = EmbeddingFunctionRegistry.get_instance()

View File

@@ -1,15 +1,6 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import importlib
import io
import os
@@ -17,6 +8,7 @@ import os
import lancedb
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
@@ -444,6 +436,30 @@ def test_watsonx_embedding(tmp_path):
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
@pytest.mark.slow
@pytest.mark.skipif(
    os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_with_empty_strs(tmp_path):
    """Embed one non-empty and one empty string with the OpenAI function.

    Both rows must be stored: the non-empty text gets a vector and the
    empty text gets a null vector. The vector column must keep its
    fixed-size float32 list type.
    """
    model = get_registry().get("openai").create(max_retries=0)

    class TextModel(LanceModel):
        text: str = model.SourceField()
        vector: Vector(model.ndims()) = model.VectorField()

    df = pd.DataFrame({"text": ["hello world", ""]})
    db = lancedb.connect(tmp_path)
    tbl = db.create_table("test", schema=TextModel, mode="overwrite")
    tbl.add(df, on_bad_vectors="skip")

    tb = tbl.to_arrow()
    # Schema.field() is the supported lookup; Schema.field_by_name() is
    # deprecated in pyarrow.
    assert tb.schema.field("vector").type == pa.list_(pa.float32(), model.ndims())
    assert len(tb) == 2
    assert tb["vector"].is_null().to_pylist() == [False, True]
@pytest.mark.slow
@pytest.mark.skipif(
importlib.util.find_spec("ollama") is None, reason="Ollama not installed"

View File

@@ -1,16 +1,5 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import json
import sys
@@ -172,6 +161,26 @@ def test_pydantic_to_arrow_py38():
assert schema == expect_schema
def test_nullable_vector():
    """Check that the ``nullable`` argument of ``Vector`` sets Arrow nullability.

    Covers the three ways of declaring the field: explicitly non-nullable,
    the default (no argument), and explicitly nullable.
    """

    class NotNullableModel(pydantic.BaseModel):
        vec: Vector(16, nullable=False)

    schema = pydantic_to_schema(NotNullableModel)
    assert schema == pa.schema([pa.field("vec", pa.list_(pa.float32(), 16), False)])

    # Omitting ``nullable`` must produce a nullable field.
    class DefaultModel(pydantic.BaseModel):
        vec: Vector(16)

    schema = pydantic_to_schema(DefaultModel)
    assert schema == pa.schema([pa.field("vec", pa.list_(pa.float32(), 16), True)])

    # Passing ``nullable=True`` explicitly must match the default.
    class NullableModel(pydantic.BaseModel):
        vec: Vector(16, nullable=True)

    schema = pydantic_to_schema(NullableModel)
    assert schema == pa.schema([pa.field("vec", pa.list_(pa.float32(), 16), True)])
def test_fixed_size_list_field():
class TestModel(pydantic.BaseModel):
vec: Vector(16)
@@ -192,7 +201,7 @@ def test_fixed_size_list_field():
schema = pydantic_to_schema(TestModel)
assert schema == pa.schema(
[
pa.field("vec", pa.list_(pa.float32(), 16), False),
pa.field("vec", pa.list_(pa.float32(), 16)),
pa.field("li", pa.list_(pa.int64()), False),
]
)