fix: use pandas with pydantic embedding column (#1818)

* Make Pandas `DataFrame` works with embedding function + Subset of
columns
* Make `lancedb.create_table()` work with embedding function
This commit is contained in:
Lei Xu
2024-11-11 14:48:56 -08:00
committed by GitHub
parent 5117aecc38
commit 4c9bab0d92
5 changed files with 166 additions and 30 deletions

View File

@@ -1,15 +1,6 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import json
from typing import Dict, Optional
@@ -170,7 +161,7 @@ def register(name):
return __REGISTRY__.get_instance().register(name)
def get_registry():
def get_registry() -> EmbeddingFunctionRegistry:
"""
Utility function to get the global instance of the registry

View File

@@ -13,7 +13,7 @@
import os
from functools import cached_property
from typing import Union, Optional
from typing import Optional
import pyarrow as pa

View File

@@ -73,6 +73,21 @@ pl = safe_import_polars()
QueryType = Literal["vector", "fts", "hybrid", "auto"]
def _pd_schema_without_embedding_funcs(
schema: Optional[pa.Schema], columns: List[str]
) -> Optional[pa.Schema]:
"""Return a schema without any embedding function columns"""
if schema is None:
return None
embedding_functions = EmbeddingFunctionRegistry.get_instance().parse_functions(
schema.metadata
)
if not embedding_functions:
return schema
columns = set(columns)
return pa.schema([field for field in schema if field.name in columns])
def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
if _check_for_hugging_face(data):
# Huggingface datasets
@@ -103,10 +118,10 @@ def _coerce_to_table(data, schema: Optional[pa.Schema] = None) -> pa.Table:
elif isinstance(data[0], pa.RecordBatch):
return pa.Table.from_batches(data, schema=schema)
else:
return pa.Table.from_pylist(data)
return pa.Table.from_pylist(data, schema=schema)
elif _check_for_pandas(data) and isinstance(data, pd.DataFrame):
# Do not add schema here, since schema may contains the vector column
table = pa.Table.from_pandas(data, preserve_index=False)
raw_schema = _pd_schema_without_embedding_funcs(schema, data.columns.to_list())
table = pa.Table.from_pandas(data, preserve_index=False, schema=raw_schema)
# Do not serialize Pandas metadata
meta = table.schema.metadata if table.schema.metadata is not None else {}
meta = {k: v for k, v in meta.items() if k != b"pandas"}
@@ -172,6 +187,8 @@ def sanitize_create_table(
schema = schema.to_arrow_schema()
if data is not None:
if metadata is None and schema is not None:
metadata = schema.metadata
data, schema = _sanitize_data(
data,
schema,

View File

@@ -1,15 +1,6 @@
# Copyright 2023 LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
from typing import List, Union
from unittest.mock import MagicMock, patch
@@ -18,6 +9,7 @@ import lancedb
import numpy as np
import pyarrow as pa
import pytest
import pandas as pd
from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.embeddings import (
EmbeddingFunctionConfig,
@@ -129,6 +121,142 @@ def test_embedding_with_bad_results(tmp_path):
# assert tbl["vector"].null_count == 1
def test_with_existing_vectors(tmp_path):
@register("mock-embedding")
class MockEmbeddingFunction(TextEmbeddingFunction):
def ndims(self):
return 128
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))]
registry = get_registry()
model = registry.get("mock-embedding").create()
class Schema(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=Schema, mode="overwrite")
tbl.add([{"text": "hello world", "vector": np.zeros(128).tolist()}])
embeddings = tbl.to_arrow()["vector"].to_pylist()
assert not np.any(embeddings), "all zeros"
def test_embedding_function_with_pandas(tmp_path):
@register("mock-embedding")
class _MockEmbeddingFunction(TextEmbeddingFunction):
def ndims(self):
return 128
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))]
registery = get_registry()
func = registery.get("mock-embedding").create()
class TestSchema(LanceModel):
text: str = func.SourceField()
val: int
vector: Vector(func.ndims()) = func.VectorField()
df = pd.DataFrame(
{
"text": ["hello world", "goodbye world"],
"val": [1, 2],
"not-used": ["s1", "s3"],
}
)
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TestSchema, mode="overwrite", data=df)
schema = tbl.schema
assert schema.field("text").type == pa.string()
assert schema.field("val").type == pa.int64()
assert schema.field("vector").type == pa.list_(pa.float32(), 128)
df = pd.DataFrame(
{
"text": ["extra", "more"],
"val": [4, 5],
"misc-col": ["s1", "s3"],
}
)
tbl.add(df)
assert tbl.count_rows() == 4
embeddings = tbl.to_arrow()["vector"]
assert embeddings.null_count == 0
df = pd.DataFrame(
{
"text": ["with", "embeddings"],
"val": [6, 7],
"vector": [np.zeros(128).tolist(), np.zeros(128).tolist()],
}
)
tbl.add(df)
embeddings = tbl.search().where("val > 5").to_arrow()["vector"].to_pylist()
assert not np.any(embeddings), "all zeros"
def test_multiple_embeddings_for_pandas(tmp_path):
@register("mock-embedding")
class MockFunc1(TextEmbeddingFunction):
def ndims(self):
return 128
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))]
@register("mock-embedding2")
class MockFunc2(TextEmbeddingFunction):
def ndims(self):
return 512
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))]
registery = get_registry()
func1 = registery.get("mock-embedding").create()
func2 = registery.get("mock-embedding2").create()
class TestSchema(LanceModel):
text: str = func1.SourceField()
val: int
vec1: Vector(func1.ndims()) = func1.VectorField()
prompt: str = func2.SourceField()
vec2: Vector(func2.ndims()) = func2.VectorField()
df = pd.DataFrame(
{
"text": ["hello world", "goodbye world"],
"val": [1, 2],
"prompt": ["hello", "goodbye"],
}
)
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TestSchema, mode="overwrite", data=df)
schema = tbl.schema
assert schema.field("text").type == pa.string()
assert schema.field("val").type == pa.int64()
assert schema.field("vec1").type == pa.list_(pa.float32(), 128)
assert schema.field("prompt").type == pa.string()
assert schema.field("vec2").type == pa.list_(pa.float32(), 512)
assert tbl.count_rows() == 2
@pytest.mark.slow
def test_embedding_function_rate_limit(tmp_path):
def _get_schema_from_model(model):