diff --git a/python/python/lancedb/embeddings/voyageai.py b/python/python/lancedb/embeddings/voyageai.py index ec2c5b6d..417b5fc4 100644 --- a/python/python/lancedb/embeddings/voyageai.py +++ b/python/python/lancedb/embeddings/voyageai.py @@ -59,7 +59,7 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction): .create(name="voyage-3") class TextModel(LanceModel): - data: str = voyageai.SourceField() + text: str = voyageai.SourceField() vector: Vector(voyageai.ndims()) = voyageai.VectorField() data = [ { "text": "hello world" }, @@ -74,6 +74,14 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction): name: str client: ClassVar = None + text_embedding_models: list = [ + "voyage-3", + "voyage-3-lite", + "voyage-finance-2", + "voyage-law-2", + "voyage-code-2", + ] + multimodal_embedding_models: list = ["voyage-multimodal-3"] def ndims(self): if self.name == "voyage-3-lite": @@ -115,13 +123,14 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction): truncation: Optional[bool] """ - if self.name in ["voyage-multimodal-3"]: - rs = VoyageAIEmbeddingFunction._get_client().multimodal_embed( - inputs=[[text]], model=self.name, **kwargs - ) + client = VoyageAIEmbeddingFunction._get_client() + if self.name in self.text_embedding_models: + rs = client.embed(texts=[text], model=self.name, **kwargs) + elif self.name in self.multimodal_embedding_models: + rs = client.multimodal_embed(inputs=[[text]], model=self.name, **kwargs) else: - rs = VoyageAIEmbeddingFunction._get_client().embed( - texts=[text], model=self.name, **kwargs + raise ValueError( + f"Model {self.name} not supported to generate text embeddings" ) return rs.embeddings[0]