diff --git a/python/python/lancedb/embeddings/gemini_text.py b/python/python/lancedb/embeddings/gemini_text.py
index 9756115c..7de412d1 100644
--- a/python/python/lancedb/embeddings/gemini_text.py
+++ b/python/python/lancedb/embeddings/gemini_text.py
@@ -4,7 +4,7 @@
 import os
 from functools import cached_property
-from typing import List, Union
+from typing import List, Optional, Union
 
 import numpy as np
 
@@ -46,10 +46,11 @@ class GeminiText(TextEmbeddingFunction):
     Parameters
     ----------
-    name: str, default "models/embedding-001"
+    name: str, default "models/text-embedding-004"
         The name of the model to use. See the Gemini documentation for a list
         of available models.
-
+    dims: int, optional
+        The dimension of the embedding, otherwise it will be inferred.
     query_task_type: str, default "retrieval_query"
         Sets the task type for the queries.
     source_task_type: str, default "retrieval_document"
@@ -77,9 +78,10 @@ class GeminiText(TextEmbeddingFunction):
     """
 
-    name: str = "models/embedding-001"
+    name: str = "models/text-embedding-004"
     query_task_type: str = "retrieval_query"
     source_task_type: str = "retrieval_document"
+    dims: Optional[int] = None
 
     if PYDANTIC_VERSION.major < 2:  # Pydantic 1.x compat
@@ -89,9 +91,18 @@ class GeminiText(TextEmbeddingFunction):
         model_config = dict()
         model_config["ignored_types"] = (cached_property,)
 
-    def ndims(self):
-        # TODO: fix hardcoding
-        return 768
+    @cached_property
+    def _model(self):
+        return self.client.get_model(self.name)
+
+    def ndims(self) -> int:
+        if self.dims:
+            return self.dims
+        if hasattr(self._model, "output_dimensionality"):
+            return self._model.output_dimensionality
+        # Fallback for older versions of the library
+        # or models that don't have the attribute
+        return len(self.generate_embeddings(["lancedb"])[0])
 
     def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
         return self.compute_source_embeddings(query, task_type=self.query_task_type)
@@ -119,6 +130,8 @@ class GeminiText(TextEmbeddingFunction):
         ):  # Provide a title to use existing API design
             title = "Embedding of a document"
             kwargs["title"] = title
+        if self.dims:
+            kwargs["output_dimensionality"] = self.dims
 
         return [
             self.client.embed_content(model=self.name, content=text, **kwargs)[
@@ -131,6 +144,8 @@
     def client(self):
         genai = attempt_import_or_raise("google.generativeai", "google.generativeai")
 
-        if not os.environ.get("GOOGLE_API_KEY"):
+        api_key = os.environ.get("GOOGLE_API_KEY")
+        if not api_key:
             api_key_not_found_help("google")
+        genai.configure(api_key=api_key)
         return genai
diff --git a/python/python/tests/test_embeddings_slow.py b/python/python/tests/test_embeddings_slow.py
index b461c003..234abf5b 100644
--- a/python/python/tests/test_embeddings_slow.py
+++ b/python/python/tests/test_embeddings_slow.py
@@ -308,7 +308,7 @@ def test_instructor_embedding(tmp_path):
     os.environ.get("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY not set"
 )
 def test_gemini_embedding(tmp_path):
-    model = get_registry().get("gemini-text").create(max_retries=0)
+    model = get_registry().get("gemini-text").create(max_retries=0, dims=512)
 
     class TextModel(LanceModel):
         text: str = model.SourceField()
@@ -319,7 +319,7 @@ def test_gemini_embedding(tmp_path):
     tbl = db.create_table("test", schema=TextModel, mode="overwrite")
     tbl.add(df)
 
-    assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
+    assert len(tbl.to_pandas()["vector"][0]) == model.ndims() == 512
     assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
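
For reference, a minimal usage sketch of the new dims option introduced by this diff, mirroring the updated test. It assumes google-generativeai is installed and GOOGLE_API_KEY is set; the database path, table name, and model class name here are illustrative only.

import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector

# "gemini-text" is the registry key exercised in the test above; dims=512
# requests 512-dimensional embeddings, which ndims() now reports instead of
# the previous hardcoded 768.
model = get_registry().get("gemini-text").create(dims=512)

class Doc(LanceModel):
    text: str = model.SourceField()
    vector: Vector(model.ndims()) = model.VectorField()

db = lancedb.connect("/tmp/gemini_dims_demo")  # illustrative path
tbl = db.create_table("docs", schema=Doc, mode="overwrite")
tbl.add([{"text": "hello world"}, {"text": "goodbye world"}])

assert len(tbl.to_pandas()["vector"][0]) == 512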