From b66cd943a7c2d8db1d587a52b61b7e722602c0bc Mon Sep 17 00:00:00 2001 From: Prashant Dixit <54981696+PrashantDixit0@users.noreply.github.com> Date: Mon, 13 Jan 2025 22:22:38 +0530 Subject: [PATCH] fix: broken voyageai embedding API (#2013) This PR fixes the broken Embedding API for Voyageai. --- python/python/lancedb/embeddings/voyageai.py | 23 ++++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/python/python/lancedb/embeddings/voyageai.py b/python/python/lancedb/embeddings/voyageai.py index ec2c5b6d..417b5fc4 100644 --- a/python/python/lancedb/embeddings/voyageai.py +++ b/python/python/lancedb/embeddings/voyageai.py @@ -59,7 +59,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction): .create(name="voyage-3") class TextModel(LanceModel): - data: str = voyageai.SourceField() + text: str = voyageai.SourceField() vector: Vector(voyageai.ndims()) = voyageai.VectorField() data = [ { "text": "hello world" }, @@ -74,6 +74,14 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction): name: str client: ClassVar = None + text_embedding_models: list = [ + "voyage-3", + "voyage-3-lite", + "voyage-finance-2", + "voyage-law-2", + "voyage-code-2", + ] + multimodal_embedding_models: list = ["voyage-multimodal-3"] def ndims(self): if self.name == "voyage-3-lite": @@ -115,13 +123,14 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction): truncation: Optional[bool] """ - if self.name in ["voyage-multimodal-3"]: - rs = VoyageAIEmbeddingFunction._get_client().multimodal_embed( - inputs=[[text]], model=self.name, **kwargs - ) + client = VoyageAIEmbeddingFunction._get_client() + if self.name in self.text_embedding_models: + rs = client.embed(texts=[text], model=self.name, **kwargs) + elif self.name in self.multimodal_embedding_models: + rs = client.multimodal_embed(inputs=[[text]], model=self.name, **kwargs) else: - rs = VoyageAIEmbeddingFunction._get_client().embed( - texts=[text], model=self.name, **kwargs + raise ValueError( + f"Model {self.name} not supported to generate text embeddings" ) return rs.embeddings[0]