mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-24 05:49:57 +00:00
Compare commits
1 Commits
python-v0.
...
ayush/gemi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5904aec34b |
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import List, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -46,10 +46,11 @@ class GeminiText(TextEmbeddingFunction):
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
name: str, default "models/embedding-001"
|
name: str, default "models/text-embedding-004"
|
||||||
The name of the model to use. See the Gemini documentation for a list of
|
The name of the model to use. See the Gemini documentation for a list of
|
||||||
available models.
|
available models.
|
||||||
|
dims: int, optional
|
||||||
|
The dimension of the embedding, otherwise it will be inferred.
|
||||||
query_task_type: str, default "retrieval_query"
|
query_task_type: str, default "retrieval_query"
|
||||||
Sets the task type for the queries.
|
Sets the task type for the queries.
|
||||||
source_task_type: str, default "retrieval_document"
|
source_task_type: str, default "retrieval_document"
|
||||||
@@ -77,9 +78,10 @@ class GeminiText(TextEmbeddingFunction):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str = "models/embedding-001"
|
name: str = "models/text-embedding-004"
|
||||||
query_task_type: str = "retrieval_query"
|
query_task_type: str = "retrieval_query"
|
||||||
source_task_type: str = "retrieval_document"
|
source_task_type: str = "retrieval_document"
|
||||||
|
dims: Optional[int] = None
|
||||||
|
|
||||||
if PYDANTIC_VERSION.major < 2: # Pydantic 1.x compat
|
if PYDANTIC_VERSION.major < 2: # Pydantic 1.x compat
|
||||||
|
|
||||||
@@ -89,9 +91,18 @@ class GeminiText(TextEmbeddingFunction):
|
|||||||
model_config = dict()
|
model_config = dict()
|
||||||
model_config["ignored_types"] = (cached_property,)
|
model_config["ignored_types"] = (cached_property,)
|
||||||
|
|
||||||
def ndims(self):
|
@cached_property
|
||||||
# TODO: fix hardcoding
|
def _model(self):
|
||||||
return 768
|
return self.client.get_model(self.name)
|
||||||
|
|
||||||
|
def ndims(self) -> int:
|
||||||
|
if self.dims:
|
||||||
|
return self.dims
|
||||||
|
if hasattr(self._model, "output_dimensionality"):
|
||||||
|
return self._model.output_dimensionality
|
||||||
|
# Fallback for older versions of the library
|
||||||
|
# or models that don't have the attribute
|
||||||
|
return len(self.generate_embeddings(["lancedb"])[0])
|
||||||
|
|
||||||
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||||
return self.compute_source_embeddings(query, task_type=self.query_task_type)
|
return self.compute_source_embeddings(query, task_type=self.query_task_type)
|
||||||
@@ -119,6 +130,8 @@ class GeminiText(TextEmbeddingFunction):
|
|||||||
): # Provide a title to use existing API design
|
): # Provide a title to use existing API design
|
||||||
title = "Embedding of a document"
|
title = "Embedding of a document"
|
||||||
kwargs["title"] = title
|
kwargs["title"] = title
|
||||||
|
if self.dims:
|
||||||
|
kwargs["output_dimensionality"] = self.dims
|
||||||
|
|
||||||
return [
|
return [
|
||||||
self.client.embed_content(model=self.name, content=text, **kwargs)[
|
self.client.embed_content(model=self.name, content=text, **kwargs)[
|
||||||
@@ -131,6 +144,8 @@ class GeminiText(TextEmbeddingFunction):
|
|||||||
def client(self):
|
def client(self):
|
||||||
genai = attempt_import_or_raise("google.generativeai", "google.generativeai")
|
genai = attempt_import_or_raise("google.generativeai", "google.generativeai")
|
||||||
|
|
||||||
if not os.environ.get("GOOGLE_API_KEY"):
|
api_key = os.environ.get("GOOGLE_API_KEY")
|
||||||
|
if not api_key:
|
||||||
api_key_not_found_help("google")
|
api_key_not_found_help("google")
|
||||||
|
genai.configure(api_key=api_key)
|
||||||
return genai
|
return genai
|
||||||
|
|||||||
@@ -308,7 +308,7 @@ def test_instructor_embedding(tmp_path):
|
|||||||
os.environ.get("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY not set"
|
os.environ.get("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY not set"
|
||||||
)
|
)
|
||||||
def test_gemini_embedding(tmp_path):
|
def test_gemini_embedding(tmp_path):
|
||||||
model = get_registry().get("gemini-text").create(max_retries=0)
|
model = get_registry().get("gemini-text").create(max_retries=0, dims=512)
|
||||||
|
|
||||||
class TextModel(LanceModel):
|
class TextModel(LanceModel):
|
||||||
text: str = model.SourceField()
|
text: str = model.SourceField()
|
||||||
@@ -319,7 +319,7 @@ def test_gemini_embedding(tmp_path):
|
|||||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||||
|
|
||||||
tbl.add(df)
|
tbl.add(df)
|
||||||
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
assert len(tbl.to_pandas()["vector"][0]) == model.ndims() == 512
|
||||||
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user