diff --git a/python/python/lancedb/embeddings/base.py b/python/python/lancedb/embeddings/base.py index 07ef17ae..d7ab93ee 100644 --- a/python/python/lancedb/embeddings/base.py +++ b/python/python/lancedb/embeddings/base.py @@ -106,8 +106,14 @@ class EmbeddingFunction(BaseModel, ABC): from ..pydantic import PYDANTIC_VERSION if PYDANTIC_VERSION.major < 2: - return dict(self) - return self.model_dump() + return {k: v for k, v in self.__dict__.items() if not k.startswith("_")} + return self.model_dump( + exclude={ + field_name + for field_name in self.model_fields + if field_name.startswith("_") + } + ) @abstractmethod def ndims(self): diff --git a/python/python/tests/test_embeddings.py b/python/python/tests/test_embeddings.py index f858e8fc..9611f0ec 100644 --- a/python/python/tests/test_embeddings.py +++ b/python/python/tests/test_embeddings.py @@ -183,3 +183,45 @@ def test_add_optional_vector(tmp_path): expected = LanceSchema(id="id", text="text") tbl.add([expected]) assert not (np.abs(tbl.to_pandas()["vector"][0]) < 1e-6).all() + + +@pytest.mark.parametrize( + "embedding_type", + [ + "openai", + "sentence-transformers", + "huggingface", + "ollama", + "cohere", + "instructor", + ], +) +def test_embedding_function_safe_model_dump(embedding_type): + registry = get_registry() + + # Note: Some embedding types might require specific parameters + try: + model = registry.get(embedding_type).create() + except Exception as e: + pytest.skip(f"Skipping {embedding_type} due to error: {str(e)}") + + dumped_model = model.safe_model_dump() + + assert all( + not k.startswith("_") for k in dumped_model.keys() + ), f"{embedding_type}: Dumped model contains keys starting with underscore" + + assert ( + "max_retries" in dumped_model + ), f"{embedding_type}: Essential field 'max_retries' is missing from dumped model" + + assert isinstance( + dumped_model, dict + ), f"{embedding_type}: Dumped model is not a dictionary" + + for key in model.__dict__: + if key.startswith("_"): + assert key not in dumped_model, ( + f"{embedding_type}: Private attribute '{key}' " + f"is present in dumped model" + ) diff --git a/python/python/tests/test_embeddings_slow.py b/python/python/tests/test_embeddings_slow.py index 6762ae10..87b2e249 100644 --- a/python/python/tests/test_embeddings_slow.py +++ b/python/python/tests/test_embeddings_slow.py @@ -442,3 +442,42 @@ def test_watsonx_embedding(tmp_path): tbl.add(df) assert len(tbl.to_pandas()["vector"][0]) == model.ndims() assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world" + + +@pytest.mark.slow +@pytest.mark.skipif( + importlib.util.find_spec("ollama") is None, reason="Ollama not installed" +) +def test_ollama_embedding(tmp_path): + model = get_registry().get("ollama").create(max_retries=0) + + class TextModel(LanceModel): + text: str = model.SourceField() + vector: Vector(model.ndims()) = model.VectorField() + + df = pd.DataFrame({"text": ["hello world", "goodbye world"]}) + db = lancedb.connect(tmp_path) + tbl = db.create_table("test", schema=TextModel, mode="overwrite") + + tbl.add(df) + + assert len(tbl.to_pandas()["vector"][0]) == model.ndims() + + result = tbl.search("hello").limit(1).to_pandas() + assert result["text"][0] == "hello world" + + # Test safe_model_dump + dumped_model = model.safe_model_dump() + assert isinstance(dumped_model, dict) + assert "name" in dumped_model + assert "max_retries" in dumped_model + assert dumped_model["max_retries"] == 0 + assert all(not k.startswith("_") for k in dumped_model.keys()) + + # Test serialization of the dumped model + import json + + try: + json.dumps(dumped_model) + except TypeError: + pytest.fail("Failed to JSON serialize the dumped model")