From 4c9bab0d92fe12d5d126391bb5b8ac1bb4f8a45e Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 11 Nov 2024 14:48:56 -0800 Subject: [PATCH] fix: use pandas with pydantic embedding column (#1818) * Make Pandas `DataFrame` works with embedding function + Subset of columns * Make `lancedb.create_table()` work with embedding function --- .github/workflows/python.yml | 2 +- python/python/lancedb/embeddings/registry.py | 17 +-- python/python/lancedb/rerankers/voyageai.py | 2 +- python/python/lancedb/table.py | 23 ++- python/python/tests/test_embeddings.py | 152 +++++++++++++++++-- 5 files changed, 166 insertions(+), 30 deletions(-) diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 6f450192..c0edc098 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -138,7 +138,7 @@ jobs: run: rm -rf target/wheels windows: name: "Windows: ${{ matrix.config.name }}" - timeout-minutes: 30 + timeout-minutes: 60 strategy: matrix: config: diff --git a/python/python/lancedb/embeddings/registry.py b/python/python/lancedb/embeddings/registry.py index d5ab1f35..f442b713 100644 --- a/python/python/lancedb/embeddings/registry.py +++ b/python/python/lancedb/embeddings/registry.py @@ -1,15 +1,6 @@ -# Copyright (c) 2023. LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + import json from typing import Dict, Optional @@ -170,7 +161,7 @@ def register(name): return __REGISTRY__.get_instance().register(name) -def get_registry(): +def get_registry() -> EmbeddingFunctionRegistry: """ Utility function to get the global instance of the registry diff --git a/python/python/lancedb/rerankers/voyageai.py b/python/python/lancedb/rerankers/voyageai.py index d04a5ad4..e47f190e 100644 --- a/python/python/lancedb/rerankers/voyageai.py +++ b/python/python/lancedb/rerankers/voyageai.py @@ -13,7 +13,7 @@ import os from functools import cached_property -from typing import Union, Optional +from typing import Optional import pyarrow as pa diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 6403c88f..684998b6 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -73,6 +73,21 @@ pl = safe_import_polars() QueryType = Literal["vector", "fts", "hybrid", "auto"] +def _pd_schema_without_embedding_funcs( + schema: Optional[pa.Schema], columns: List[str] +) -> Optional[pa.Schema]: + """Return a schema without any embedding function columns""" + if schema is None: + return None + embedding_functions = EmbeddingFunctionRegistry.get_instance().parse_functions( + schema.metadata + ) + if not embedding_functions: + return schema + columns = set(columns) + return pa.schema([field for field in schema if field.name in columns]) + + def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table: if _check_for_hugging_face(data): # Huggingface datasets @@ -103,10 +118,10 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table: elif isinstance(data[0], pa.RecordBatch): return pa.Table.from_batches(data, schema=schema) else: - return pa.Table.from_pylist(data) + return pa.Table.from_pylist(data, schema=schema) elif _check_for_pandas(data) and isinstance(data, pd.DataFrame): - # Do not add schema here, since schema may contains the vector column - table = pa.Table.from_pandas(data, preserve_index=False) + raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list()) + table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema) # Do not serialize Pandas metadata meta = table.schema.metadata if table.schema.metadata is not None else {} meta = {k: v for k, v in meta.items() if k != b"pandas"} @@ -172,6 +187,8 @@ def sanitize_create_table( schema = schema.to_arrow_schema() if data is not None: + if metadata is None and schema is not None: + metadata = schema.metadata data, schema = _sanitize_data( data, schema, diff --git a/python/python/tests/test_embeddings.py b/python/python/tests/test_embeddings.py index a9f939ee..59a9ee4b 100644 --- a/python/python/tests/test_embeddings.py +++ b/python/python/tests/test_embeddings.py @@ -1,15 +1,6 @@ -# Copyright 2023 LanceDB Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The LanceDB Authors + from typing import List, Union from unittest.mock import MagicMock, patch @@ -18,6 +9,7 @@ import lancedb import numpy as np import pyarrow as pa import pytest +import pandas as pd from lancedb.conftest import MockTextEmbeddingFunction from lancedb.embeddings import ( EmbeddingFunctionConfig, @@ -129,6 +121,142 @@ def test_embedding_with_bad_results(tmp_path): # assert tbl["vector"].null_count == 1 +def test_with_existing_vectors(tmp_path): + @register("mock-embedding") + class MockEmbeddingFunction(TextEmbeddingFunction): + def ndims(self): + return 128 + + def generate_embeddings( + self, texts: Union[List[str], np.ndarray] + ) -> List[np.array]: + return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))] + + registry = get_registry() + model = registry.get("mock-embedding").create() + + class Schema(LanceModel): + text: str = model.SourceField() + vector: Vector(model.ndims()) = model.VectorField() + + db = lancedb.connect(tmp_path) + tbl = db.create_table("test", schema=Schema, mode="overwrite") + tbl.add([{"text": "hello world", "vector": np.zeros(128).tolist()}]) + + embeddings = tbl.to_arrow()["vector"].to_pylist() + assert not np.any(embeddings), "all zeros" + + +def test_embedding_function_with_pandas(tmp_path): + @register("mock-embedding") + class _MockEmbeddingFunction(TextEmbeddingFunction): + def ndims(self): + return 128 + + def generate_embeddings( + self, texts: Union[List[str], np.ndarray] + ) -> List[np.array]: + return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))] + + registery = get_registry() + func = registery.get("mock-embedding").create() + + class TestSchema(LanceModel): + text: str = func.SourceField() + val: int + vector: Vector(func.ndims()) = func.VectorField() + + df = pd.DataFrame( + { + "text": ["hello world", "goodbye world"], + "val": [1, 2], + "not-used": ["s1", "s3"], + } + ) + db = lancedb.connect(tmp_path) + tbl = db.create_table("test", schema=TestSchema, mode="overwrite", data=df) + schema = tbl.schema + assert schema.field("text").type == pa.string() + assert schema.field("val").type == pa.int64() + assert schema.field("vector").type == pa.list_(pa.float32(), 128) + + df = pd.DataFrame( + { + "text": ["extra", "more"], + "val": [4, 5], + "misc-col": ["s1", "s3"], + } + ) + tbl.add(df) + + assert tbl.count_rows() == 4 + embeddings = tbl.to_arrow()["vector"] + assert embeddings.null_count == 0 + + df = pd.DataFrame( + { + "text": ["with", "embeddings"], + "val": [6, 7], + "vector": [np.zeros(128).tolist(), np.zeros(128).tolist()], + } + ) + tbl.add(df) + + embeddings = tbl.search().where("val > 5").to_arrow()["vector"].to_pylist() + assert not np.any(embeddings), "all zeros" + + +def test_multiple_embeddings_for_pandas(tmp_path): + @register("mock-embedding") + class MockFunc1(TextEmbeddingFunction): + def ndims(self): + return 128 + + def generate_embeddings( + self, texts: Union[List[str], np.ndarray] + ) -> List[np.array]: + return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))] + + @register("mock-embedding2") + class MockFunc2(TextEmbeddingFunction): + def ndims(self): + return 512 + + def generate_embeddings( + self, texts: Union[List[str], np.ndarray] + ) -> List[np.array]: + return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))] + + registery = get_registry() + func1 = registery.get("mock-embedding").create() + func2 = registery.get("mock-embedding2").create() + + class TestSchema(LanceModel): + text: str = func1.SourceField() + val: int + vec1: Vector(func1.ndims()) = func1.VectorField() + prompt: str = func2.SourceField() + vec2: Vector(func2.ndims()) = func2.VectorField() + + df = pd.DataFrame( + { + "text": ["hello world", "goodbye world"], + "val": [1, 2], + "prompt": ["hello", "goodbye"], + } + ) + db = lancedb.connect(tmp_path) + tbl = db.create_table("test", schema=TestSchema, mode="overwrite", data=df) + + schema = tbl.schema + assert schema.field("text").type == pa.string() + assert schema.field("val").type == pa.int64() + assert schema.field("vec1").type == pa.list_(pa.float32(), 128) + assert schema.field("prompt").type == pa.string() + assert schema.field("vec2").type == pa.list_(pa.float32(), 512) + assert tbl.count_rows() == 2 + + @pytest.mark.slow def test_embedding_function_rate_limit(tmp_path): def _get_schema_from_model(model):