From dbec598610b151e7a32cd9d07bab8b346dfd11ad Mon Sep 17 00:00:00 2001
From: Chang She <759245+changhiskhan@users.noreply.github.com>
Date: Wed, 13 Mar 2024 14:35:08 -0700
Subject: [PATCH] feat(python): support optional vector field in pydantic model (#1097)

The LanceDB embeddings registry allows users to annotate the pydantic
model used as the table schema with the desired embedding function, e.g.:

```python
class Schema(LanceModel):
    id: str
    vector: Vector(openai.ndims()) = openai.VectorField()
    text: str = openai.SourceField()
```

Tables created like this do not require embeddings to be calculated by
the user explicitly, e.g. this works:

```python
table.add([{"id": "foo", "text": "rust all the things"}])
```

However, trying to construct pydantic model instances without a vector
doesn't work, because `vector` is a required field. Instead, you need to
add a default value:

```python
class Schema(LanceModel):
    id: str
    vector: Vector(openai.ndims()) = openai.VectorField(default=None)
    text: str = openai.SourceField()
```

Then this completes without errors:

```python
table.add([Schema(id="foo", text="rust all the things")])
```

However, all of the vectors are filled with zeros. To fix this,
`_append_vector_col` now performs an additional check: when the vector
column is present but entirely null, embedding generation is still
called and the computed embeddings replace the null column.
--- python/python/lancedb/table.py | 16 +++++++++--- python/python/tests/test_embeddings.py | 34 ++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index b07cef2f..e050749b 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -118,7 +118,8 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata) for vector_column, conf in functions.items(): func = conf.function - if vector_column not in data.column_names: + no_vector_column = vector_column not in data.column_names + if no_vector_column or pc.all(pc.is_null(data[vector_column])).as_py(): col_data = func.compute_source_embeddings_with_retry( data[conf.source_column] ) @@ -126,9 +127,16 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem dtype = schema.field(vector_column).type else: dtype = pa.list_(pa.float32(), len(col_data[0])) - data = data.append_column( - pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype) - ) + if no_vector_column: + data = data.append_column( + pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype) + ) + else: + data = data.set_column( + data.column_names.index(vector_column), + pa.field(vector_column, type=dtype), + pa.array(col_data, type=dtype), + ) return data diff --git a/python/python/tests/test_embeddings.py b/python/python/tests/test_embeddings.py index af442a16..ed7b105a 100644 --- a/python/python/tests/test_embeddings.py +++ b/python/python/tests/test_embeddings.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import sys +from typing import List, Union import lance import lancedb @@ -23,6 +24,8 @@ from lancedb.embeddings import ( EmbeddingFunctionRegistry, with_embeddings, ) +from lancedb.embeddings.base import TextEmbeddingFunction +from lancedb.embeddings.registry import get_registry, register from lancedb.pydantic import LanceModel, Vector @@ -112,3 +115,34 @@ def test_embedding_function_rate_limit(tmp_path): table.add([{"text": "hello world"}]) table.add([{"text": "hello world"}]) assert len(table) == 2 + + +def test_add_optional_vector(tmp_path): + @register("mock-embedding") + class MockEmbeddingFunction(TextEmbeddingFunction): + def ndims(self): + return 128 + + def generate_embeddings( + self, texts: Union[List[str], np.ndarray] + ) -> List[np.array]: + """ + Generate the embeddings for the given texts + """ + return [np.random.randn(self.ndims()).tolist() for _ in range(len(texts))] + + registry = get_registry() + model = registry.get("mock-embedding").create() + + class LanceSchema(LanceModel): + id: str + vector: Vector(model.ndims()) = model.VectorField(default=None) + text: str = model.SourceField() + + db = lancedb.connect(tmp_path) + tbl = db.create_table("optional_vector", schema=LanceSchema) + + # add works + expected = LanceSchema(id="id", text="text") + tbl.add([expected]) + assert not (np.abs(tbl.to_pandas()["vector"][0]) < 1e-6).all()