fix(python)!: handle bad openai embeddings gracefully (#1873)

BREAKING-CHANGE: change Pydantic Vector field to be nullable by default.
Closes #1577
This commit is contained in:
Lei Xu
2024-11-23 13:33:52 -08:00
committed by GitHub
parent dfd9d2ac99
commit 2ded17452b
7 changed files with 102 additions and 63 deletions

View File

@@ -1,15 +1,6 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import importlib
import io
import os
@@ -17,6 +8,7 @@ import os
import lancedb
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
@@ -444,6 +436,30 @@ def test_watsonx_embedding(tmp_path):
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_with_empty_strs(tmp_path):
model = get_registry().get("openai").create(max_retries=0)
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hello world", ""]})
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df, on_bad_vectors="skip")
tb = tbl.to_arrow()
assert tb.schema.field_by_name("vector").type == pa.list_(
pa.float32(), model.ndims()
)
assert len(tb) == 2
assert tb["vector"].is_null().to_pylist() == [False, True]
@pytest.mark.slow
@pytest.mark.skipif(
importlib.util.find_spec("ollama") is None, reason="Ollama not installed"