fix: use pandas with pydantic embedding column (#1818)

* Make Pandas `DataFrame` works with embedding function + Subset of
columns
* Make `lancedb.create_table()` work with embedding function
This commit is contained in:
Lei Xu
2024-11-11 14:48:56 -08:00
committed by GitHub
parent 5117aecc38
commit 4c9bab0d92
5 changed files with 166 additions and 30 deletions

View File

@@ -138,7 +138,7 @@ jobs:
run: rm -rf target/wheels run: rm -rf target/wheels
windows: windows:
name: "Windows: ${{ matrix.config.name }}" name: "Windows: ${{ matrix.config.name }}"
timeout-minutes: 30 timeout-minutes: 60
strategy: strategy:
matrix: matrix:
config: config:

View File

@@ -1,15 +1,6 @@
# Copyright (c) 2023. LanceDB Developers # SPDX-License-Identifier: Apache-2.0
# # SPDX-FileCopyrightText: Copyright The LanceDB Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json import json
from typing import Dict, Optional from typing import Dict, Optional
@@ -170,7 +161,7 @@ def register(name):
return __REGISTRY__.get_instance().register(name) return __REGISTRY__.get_instance().register(name)
def get_registry(): def get_registry() -> EmbeddingFunctionRegistry:
""" """
Utility function to get the global instance of the registry Utility function to get the global instance of the registry

View File

@@ -13,7 +13,7 @@
import os import os
from functools import cached_property from functools import cached_property
from typing import Union, Optional from typing import Optional
import pyarrow as pa import pyarrow as pa

View File

@@ -73,6 +73,21 @@ pl = safe_import_polars()
QueryType = Literal["vector", "fts", "hybrid", "auto"] QueryType = Literal["vector", "fts", "hybrid", "auto"]
def _pd_schema_without_embedding_funcs(
schema: Optional[pa.Schema], columns: List[str]
) -> Optional[pa.Schema]:
"""Return a schema without any embedding function columns"""
if schema is None:
return None
embedding_functions = EmbeddingFunctionRegistry.get_instance().parse_functions(
schema.metadata
)
if not embedding_functions:
return schema
columns = set(columns)
return pa.schema([field for field in schema if field.name in columns])
def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table: def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
if _check_for_hugging_face(data): if _check_for_hugging_face(data):
# Huggingface datasets # Huggingface datasets
@@ -103,10 +118,10 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
elif isinstance(data[0], pa.RecordBatch): elif isinstance(data[0], pa.RecordBatch):
return pa.Table.from_batches(data, schema=schema) return pa.Table.from_batches(data, schema=schema)
else: else:
return pa.Table.from_pylist(data) return pa.Table.from_pylist(data, schema=schema)
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame): elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
# Do not add schema here, since schema may contains the vector column raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list())
table = pa.Table.from_pandas(data, preserve_index=False) table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema)
# Do not serialize Pandas metadata # Do not serialize Pandas metadata
meta = table.schema.metadata if table.schema.metadata is not None else {} meta = table.schema.metadata if table.schema.metadata is not None else {}
meta = {k: v for k, v in meta.items() if k != b"pandas"} meta = {k: v for k, v in meta.items() if k != b"pandas"}
@@ -172,6 +187,8 @@ def sanitize_create_table(
schema = schema.to_arrow_schema() schema = schema.to_arrow_schema()
if data is not None: if data is not None:
if metadata is None and schema is not None:
metadata = schema.metadata
data, schema = _sanitize_data( data, schema = _sanitize_data(
data, data,
schema, schema,

View File

@@ -1,15 +1,6 @@
# Copyright 2023 LanceDB Developers # SPDX-License-Identifier: Apache-2.0
# # SPDX-FileCopyrightText: Copyright The LanceDB Authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union from typing import List, Union
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@@ -18,6 +9,7 @@ import lancedb
import numpy as np import numpy as np
import pyarrow as pa import pyarrow as pa
import pytest import pytest
import pandas as pd
from lancedb.conftest import MockTextEmbeddingFunction from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.embeddings import ( from lancedb.embeddings import (
EmbeddingFunctionConfig, EmbeddingFunctionConfig,
@@ -129,6 +121,142 @@ def test_embedding_with_bad_results(tmp_path):
# assert tbl["vector"].null_count == 1 # assert tbl["vector"].null_count == 1
def test_with_existing_vectors(tmp_path):
@register("mock-embedding")
class MockEmbeddingFunction(TextEmbeddingFunction):
def ndims(self):
return 128
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))]
registry = get_registry()
model = registry.get("mock-embedding").create()
class Schema(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=Schema, mode="overwrite")
tbl.add([{"text": "hello world", "vector": np.zeros(128).tolist()}])
embeddings = tbl.to_arrow()["vector"].to_pylist()
assert not np.any(embeddings), "all zeros"
def test_embedding_function_with_pandas(tmp_path):
@register("mock-embedding")
class _MockEmbeddingFunction(TextEmbeddingFunction):
def ndims(self):
return 128
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))]
registery = get_registry()
func = registery.get("mock-embedding").create()
class TestSchema(LanceModel):
text: str = func.SourceField()
val: int
vector: Vector(func.ndims()) = func.VectorField()
df = pd.DataFrame(
{
"text": ["hello world", "goodbye world"],
"val": [1, 2],
"not-used": ["s1", "s3"],
}
)
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TestSchema, mode="overwrite", data=df)
schema = tbl.schema
assert schema.field("text").type == pa.string()
assert schema.field("val").type == pa.int64()
assert schema.field("vector").type == pa.list_(pa.float32(), 128)
df = pd.DataFrame(
{
"text": ["extra", "more"],
"val": [4, 5],
"misc-col": ["s1", "s3"],
}
)
tbl.add(df)
assert tbl.count_rows() == 4
embeddings = tbl.to_arrow()["vector"]
assert embeddings.null_count == 0
df = pd.DataFrame(
{
"text": ["with", "embeddings"],
"val": [6, 7],
"vector": [np.zeros(128).tolist(), np.zeros(128).tolist()],
}
)
tbl.add(df)
embeddings = tbl.search().where("val > 5").to_arrow()["vector"].to_pylist()
assert not np.any(embeddings), "all zeros"
def test_multiple_embeddings_for_pandas(tmp_path):
@register("mock-embedding")
class MockFunc1(TextEmbeddingFunction):
def ndims(self):
return 128
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))]
@register("mock-embedding2")
class MockFunc2(TextEmbeddingFunction):
def ndims(self):
return 512
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))]
registery = get_registry()
func1 = registery.get("mock-embedding").create()
func2 = registery.get("mock-embedding2").create()
class TestSchema(LanceModel):
text: str = func1.SourceField()
val: int
vec1: Vector(func1.ndims()) = func1.VectorField()
prompt: str = func2.SourceField()
vec2: Vector(func2.ndims()) = func2.VectorField()
df = pd.DataFrame(
{
"text": ["hello world", "goodbye world"],
"val": [1, 2],
"prompt": ["hello", "goodbye"],
}
)
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TestSchema, mode="overwrite", data=df)
schema = tbl.schema
assert schema.field("text").type == pa.string()
assert schema.field("val").type == pa.int64()
assert schema.field("vec1").type == pa.list_(pa.float32(), 128)
assert schema.field("prompt").type == pa.string()
assert schema.field("vec2").type == pa.list_(pa.float32(), 512)
assert tbl.count_rows() == 2
@pytest.mark.slow @pytest.mark.slow
def test_embedding_function_rate_limit(tmp_path): def test_embedding_function_rate_limit(tmp_path):
def _get_schema_from_model(model): def _get_schema_from_model(model):