diff --git a/python/Cargo.toml b/python/Cargo.toml index 876385a4..b3d71aef 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -17,11 +17,17 @@ crate-type = ["cdylib"] arrow = { version = "52.1", features = ["pyarrow"] } lancedb = { path = "../rust/lancedb", default-features = false } env_logger.workspace = true -pyo3 = { version = "0.21", features = ["extension-module", "abi3-py38", "gil-refs"] } +pyo3 = { version = "0.21", features = [ + "extension-module", + "abi3-py39", + "gil-refs" +] } # Using this fork for now: https://github.com/awestlake87/pyo3-asyncio/issues/119 # pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] } -pyo3-asyncio-0-21 = { version = "0.21.0", features = ["attributes", "tokio-runtime"] } - +pyo3-asyncio-0-21 = { version = "0.21.0", features = [ + "attributes", + "tokio-runtime" +] } pin-project = "1.1.5" futures.workspace = true tokio = { version = "1.36.0", features = ["sync"] } @@ -29,14 +35,13 @@ tokio = { version = "1.36.0", features = ["sync"] } [build-dependencies] pyo3-build-config = { version = "0.20.3", features = [ "extension-module", - "abi3-py38", + "abi3-py39", ] } [features] default = ["default-tls", "remote"] fp16kernels = ["lancedb/fp16kernels"] remote = ["lancedb/remote"] - # TLS default-tls = ["lancedb/default-tls"] native-tls = ["lancedb/native-tls"] diff --git a/python/pyproject.toml b/python/pyproject.toml index 4ab18808..a60f5baa 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -31,7 +31,6 @@ classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/python/python/lancedb/embeddings/openai.py b/python/python/lancedb/embeddings/openai.py index 2fca549d..d126e2a2 100644 --- a/python/python/lancedb/embeddings/openai.py +++ b/python/python/lancedb/embeddings/openai.py @@ -83,25 +83,33 @@ class OpenAIEmbeddings(TextEmbeddingFunction): """ openai = attempt_import_or_raise("openai") + valid_texts = [] + valid_indices = [] + for idx, text in enumerate(texts): + if text: + valid_texts.append(text) + valid_indices.append(idx) + # TODO retry, rate limit, token limit try: - if self.name == "text-embedding-ada-002": - rs = self._openai_client.embeddings.create(input=texts, model=self.name) - else: - kwargs = { - "input": texts, - "model": self.name, - } - if self.dim: - kwargs["dimensions"] = self.dim - rs = self._openai_client.embeddings.create(**kwargs) + kwargs = { + "input": valid_texts, + "model": self.name, + } + if self.name != "text-embedding-ada-002": + kwargs["dimensions"] = self.dim + + rs = self._openai_client.embeddings.create(**kwargs) + valid_embeddings = { + idx: v.embedding for v, idx in zip(rs.data, valid_indices) + } except openai.BadRequestError: logging.exception("Bad request: %s", texts) return [None] * len(texts) except Exception: logging.exception("OpenAI embeddings error") raise - return [v.embedding for v in rs.data] + return [valid_embeddings.get(idx, None) for idx in range(len(texts))] @cached_property def _openai_client(self): diff --git a/python/python/lancedb/pydantic.py b/python/python/lancedb/pydantic.py index 2d72acad..6838ccf3 100644 --- a/python/python/lancedb/pydantic.py +++ b/python/python/lancedb/pydantic.py @@ -1,15 +1,5 @@ -# Copyright 2023 LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors """Pydantic (v1 / v2) adapter for LanceDB""" @@ -30,6 +20,7 @@ from typing import ( Type, Union, _GenericAlias, + GenericAlias, ) import numpy as np @@ -75,7 +66,7 @@ def vector(dim: int, value_type: pa.DataType = pa.float32()): def Vector( - dim: int, value_type: pa.DataType = pa.float32() + dim: int, value_type: pa.DataType = pa.float32(), nullable: bool = True ) -> Type[FixedSizeListMixin]: """Pydantic Vector Type. @@ -88,6 +79,8 @@ def Vector( The dimension of the vector. value_type : pyarrow.DataType, optional The value type of the vector, by default pa.float32() + nullable : bool, optional + Whether the vector is nullable, by default it is True. Examples -------- @@ -103,7 +96,7 @@ def Vector( >>> assert schema == pa.schema([ ... pa.field("id", pa.int64(), False), ... pa.field("url", pa.utf8(), False), - ... pa.field("embeddings", pa.list_(pa.float32(), 768), False) + ... pa.field("embeddings", pa.list_(pa.float32(), 768)) ... ]) """ @@ -112,6 +105,10 @@ def Vector( def __repr__(self): return f"FixedSizeList(dim={dim})" + @staticmethod + def nullable() -> bool: + return nullable + @staticmethod def dim() -> int: return dim @@ -205,9 +202,7 @@ else: def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType: """Convert a Pydantic FieldInfo to Arrow DataType""" - if isinstance(field.annotation, _GenericAlias) or ( - sys.version_info > (3, 9) and isinstance(field.annotation, types.GenericAlias) - ): + if isinstance(field.annotation, (_GenericAlias, GenericAlias)): origin = field.annotation.__origin__ args = field.annotation.__args__ if origin is list: @@ -235,7 +230,7 @@ def _pydantic_to_arrow_type(field: FieldInfo) -> pa.DataType: def is_nullable(field: FieldInfo) -> bool: """Check if a Pydantic FieldInfo is nullable.""" - if isinstance(field.annotation, _GenericAlias): + if isinstance(field.annotation, (_GenericAlias, GenericAlias)): origin = field.annotation.__origin__ args = field.annotation.__args__ if origin == Union: @@ -246,6 +241,10 @@ def is_nullable(field: FieldInfo) -> bool: for typ in args: if typ is type(None): return True + elif inspect.isclass(field.annotation) and issubclass( + field.annotation, FixedSizeListMixin + ): + return field.annotation.nullable() return False diff --git a/python/python/tests/test_embeddings.py b/python/python/tests/test_embeddings.py index 32394009..4e5fac2c 100644 --- a/python/python/tests/test_embeddings.py +++ b/python/python/tests/test_embeddings.py @@ -90,10 +90,13 @@ def test_embedding_with_bad_results(tmp_path): self, texts: Union[List[str], np.ndarray] ) -> list[Union[np.array, None]]: # Return None, which is bad if field is non-nullable - return [ - None if i % 2 == 0 else np.random.randn(self.ndims()) + a = [ + np.full(self.ndims(), np.nan) + if i % 2 == 0 + else np.random.randn(self.ndims()) for i in range(len(texts)) ] + return a db = lancedb.connect(tmp_path) registry = EmbeddingFunctionRegistry.get_instance() diff --git a/python/python/tests/test_embeddings_slow.py b/python/python/tests/test_embeddings_slow.py index 58f9ff98..ba9f8fea 100644 --- a/python/python/tests/test_embeddings_slow.py +++ b/python/python/tests/test_embeddings_slow.py @@ -1,15 +1,6 @@ -# Copyright (c) 2023. LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + import importlib import io import os @@ -17,6 +8,7 @@ import os import lancedb import numpy as np import pandas as pd +import pyarrow as pa import pytest from lancedb.embeddings import get_registry from lancedb.pydantic import LanceModel, Vector @@ -444,6 +436,30 @@ def test_watsonx_embedding(tmp_path): assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world" +@pytest.mark.slow +@pytest.mark.skipif( + os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set" +) +def test_openai_with_empty_strs(tmp_path): + model = get_registry().get("openai").create(max_retries=0) + + class TextModel(LanceModel): + text: str = model.SourceField() + vector: Vector(model.ndims()) = model.VectorField() + + df = pd.DataFrame({"text": ["hello world", ""]}) + db = lancedb.connect(tmp_path) + tbl = db.create_table("test", schema=TextModel, mode="overwrite") + + tbl.add(df, on_bad_vectors="skip") + tb = tbl.to_arrow() + assert tb.schema.field_by_name("vector").type == pa.list_( + pa.float32(), model.ndims() + ) + assert len(tb) == 2 + assert tb["vector"].is_null().to_pylist() == [False, True] + + @pytest.mark.slow @pytest.mark.skipif( importlib.util.find_spec("ollama") is None, reason="Ollama not installed" diff --git a/python/python/tests/test_pydantic.py b/python/python/tests/test_pydantic.py index 5b401334..0e76c3ad 100644 --- a/python/python/tests/test_pydantic.py +++ b/python/python/tests/test_pydantic.py @@ -1,16 +1,5 @@ -# Copyright 2023 LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors import json import sys @@ -172,6 +161,26 @@ def test_pydantic_to_arrow_py38(): assert schema == expect_schema +def test_nullable_vector(): + class NullableModel(pydantic.BaseModel): + vec: Vector(16, nullable=False) + + schema = pydantic_to_schema(NullableModel) + assert schema == pa.schema([pa.field("vec", pa.list_(pa.float32(), 16), False)]) + + class DefaultModel(pydantic.BaseModel): + vec: Vector(16) + + schema = pydantic_to_schema(DefaultModel) + assert schema == pa.schema([pa.field("vec", pa.list_(pa.float32(), 16), True)]) + + class NotNullableModel(pydantic.BaseModel): + vec: Vector(16) + + schema = pydantic_to_schema(NotNullableModel) + assert schema == pa.schema([pa.field("vec", pa.list_(pa.float32(), 16), True)]) + + def test_fixed_size_list_field(): class TestModel(pydantic.BaseModel): vec: Vector(16) @@ -192,7 +201,7 @@ def test_fixed_size_list_field(): schema = pydantic_to_schema(TestModel) assert schema == pa.schema( [ - pa.field("vec", pa.list_(pa.float32(), 16), False), + pa.field("vec", pa.list_(pa.float32(), 16)), pa.field("li", pa.list_(pa.int64()), False), ] )