From a6aa67baed47b497ef1988556a68a67a8dcc76c2 Mon Sep 17 00:00:00 2001 From: Raghav Dixit <34462078+raghavdixit99@users.noreply.github.com> Date: Wed, 10 Apr 2024 13:17:14 -0400 Subject: [PATCH] python: Bug fixes / tests (#1210) closes #1194 #1172 #1124 #1208 @wjones127 : `if query_type != "fts":` is needed because both fts and vector search create `LanceQueryBuilder` which has `vector_column_name` as a required attribute. --- python/pyproject.toml | 1 + python/python/lancedb/embeddings/bedrock.py | 3 +++ python/python/lancedb/embeddings/gemini_text.py | 3 +++ python/python/lancedb/embeddings/imagebind.py | 11 +++++++++-- .../python/lancedb/embeddings/transformers.py | 10 ++++++++-- python/python/lancedb/table.py | 12 +++++++++++- python/python/tests/test_db.py | 17 +++++++++++++++++ 7 files changed, 52 insertions(+), 5 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index f457f262..f598f08b 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -57,6 +57,7 @@ tests = [ "duckdb", "pytz", "polars>=0.19", + "tantivy" ] dev = ["ruff", "pre-commit"] docs = [ diff --git a/python/python/lancedb/embeddings/bedrock.py b/python/python/lancedb/embeddings/bedrock.py index 8b0ccbc2..dab926a9 100644 --- a/python/python/lancedb/embeddings/bedrock.py +++ b/python/python/lancedb/embeddings/bedrock.py @@ -78,6 +78,9 @@ class BedRockText(TextEmbeddingFunction): class Config: keep_untouched = (cached_property,) + else: + model_config = dict() + model_config["ignored_types"] = (cached_property,) def ndims(self): # return len(self._generate_embedding("test")) diff --git a/python/python/lancedb/embeddings/gemini_text.py b/python/python/lancedb/embeddings/gemini_text.py index bdbd304c..e3a9b96d 100644 --- a/python/python/lancedb/embeddings/gemini_text.py +++ b/python/python/lancedb/embeddings/gemini_text.py @@ -94,6 +94,9 @@ class GeminiText(TextEmbeddingFunction): class Config: keep_untouched = (cached_property,) + else: + model_config = dict() + model_config["ignored_types"] = (cached_property,) def ndims(self): # TODO: fix hardcoding diff --git a/python/python/lancedb/embeddings/imagebind.py b/python/python/lancedb/embeddings/imagebind.py index 209a134b..634b1487 100644 --- a/python/python/lancedb/embeddings/imagebind.py +++ b/python/python/lancedb/embeddings/imagebind.py @@ -22,6 +22,8 @@ from .base import EmbeddingFunction from .registry import register from .utils import AUDIO, IMAGES, TEXT +from lancedb.pydantic import PYDANTIC_VERSION + @register("imagebind") class ImageBindEmbeddings(EmbeddingFunction): @@ -38,8 +40,13 @@ class ImageBindEmbeddings(EmbeddingFunction): device: str = "cpu" normalize: bool = False - class Config: - keep_untouched = (cached_property,) + if PYDANTIC_VERSION < (2, 0): # Pydantic 1.x compat + + class Config: + keep_untouched = (cached_property,) + else: + model_config = dict() + model_config["ignored_types"] = (cached_property,) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/python/python/lancedb/embeddings/transformers.py b/python/python/lancedb/embeddings/transformers.py index f796bd2e..02696c4f 100644 --- a/python/python/lancedb/embeddings/transformers.py +++ b/python/python/lancedb/embeddings/transformers.py @@ -17,6 +17,7 @@ from typing import List, Any import numpy as np from pydantic import PrivateAttr +from lancedb.pydantic import PYDANTIC_VERSION from ..util import attempt_import_or_raise from .base import EmbeddingFunction @@ -53,8 +54,13 @@ class TransformersEmbeddingFunction(EmbeddingFunction): self._tokenizer = transformers.AutoTokenizer.from_pretrained(self.name) self._model = transformers.AutoModel.from_pretrained(self.name) - class Config: - keep_untouched = (cached_property,) + if PYDANTIC_VERSION < (2, 0): # Pydantic 1.x compat + + class Config: + keep_untouched = (cached_property,) + else: + model_config = dict() + model_config["ignored_types"] = (cached_property,) def ndims(self): self._ndims = self._model.config.hidden_size diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index 1928a3a9..58d0d0bf 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -95,6 +95,9 @@ def _sanitize_data( data.data.to_batches(), schema, metadata, on_bad_vectors, fill_value ) + if isinstance(data, LanceModel): + raise ValueError("Cannot add a single LanceModel to a table. Use a list.") + if isinstance(data, list): # convert to list of dict if data is a bunch of LanceModels if isinstance(data[0], LanceModel): @@ -1403,7 +1406,14 @@ class LanceTable(Table): vector and the returned vector. """ if vector_column_name is None and query is not None: - vector_column_name = inf_vector_column_query(self.schema) + try: + vector_column_name = inf_vector_column_query(self.schema) + except Exception as e: + if query_type == "fts": + vector_column_name = "" + else: + raise e + return LanceQueryBuilder.create( self, query, diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py index fc4420ba..82b90c0a 100644 --- a/python/python/tests/test_db.py +++ b/python/python/tests/test_db.py @@ -28,13 +28,25 @@ def test_basic(tmp_path): assert db.uri == str(tmp_path) assert db.table_names() == [] + class SimpleModel(LanceModel): + item: str + price: float + vector: Vector(2) + table = db.create_table( "test", data=[ {"vector": [3.1, 4.1], "item": "foo", "price": 10.0}, {"vector": [5.9, 26.5], "item": "bar", "price": 20.0}, ], + schema=SimpleModel, ) + + with pytest.raises( + ValueError, match="Cannot add a single LanceModel to a table. Use a list." + ): + table.add(SimpleModel(item="baz", price=30.0, vector=[1.0, 2.0])) + rs = table.search([100, 100]).limit(1).to_pandas() assert len(rs) == 1 assert rs["item"].iloc[0] == "bar" @@ -43,6 +55,11 @@ def test_basic(tmp_path): assert len(rs) == 1 assert rs["item"].iloc[0] == "foo" + table.create_fts_index(["item"]) + rs = table.search("bar", query_type="fts").to_pandas() + assert len(rs) == 1 + assert rs["item"].iloc[0] == "bar" + assert db.table_names() == ["test"] assert "test" in db assert len(db) == 1