feat(python): adding VoyageAI v4 models (#2959)

Adding VoyageAI v4 models
 - with these, i added unit tests
 - added example code (tested!)
This commit is contained in:
fzowl
2026-01-31 00:16:03 +01:00
committed by GitHub
parent 9be28448f5
commit 1ee29675b3
4 changed files with 202 additions and 3 deletions

View File

@@ -21,6 +21,9 @@ if TYPE_CHECKING:
# Token limits for different VoyageAI models
VOYAGE_TOTAL_TOKEN_LIMITS = {
"voyage-4": 320_000,
"voyage-4-lite": 1_000_000,
"voyage-4-large": 120_000,
"voyage-context-3": 32_000,
"voyage-3.5-lite": 1_000_000,
"voyage-3.5": 320_000,
@@ -167,6 +170,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
name: str
The name of the model to use. List of acceptable models:
* voyage-4 (1024 dims, general-purpose and multilingual retrieval)
* voyage-4-lite (1024 dims, optimized for latency and cost)
* voyage-4-large (1024 dims, best retrieval quality)
* voyage-context-3
* voyage-3.5
* voyage-3.5-lite
@@ -215,6 +221,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
_FLEXIBLE_DIM_MODELS: ClassVar[list] = ["voyage-multimodal-3.5"]
_VALID_DIMENSIONS: ClassVar[list] = [256, 512, 1024, 2048]
text_embedding_models: list = [
"voyage-4",
"voyage-4-lite",
"voyage-4-large",
"voyage-3.5",
"voyage-3.5-lite",
"voyage-3",
@@ -252,6 +261,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
elif self.name == "voyage-code-2":
return 1536
elif self.name in [
"voyage-4",
"voyage-4-lite",
"voyage-4-large",
"voyage-context-3",
"voyage-3.5",
"voyage-3.5-lite",

View File

@@ -517,19 +517,36 @@ def test_ollama_embedding(tmp_path):
@pytest.mark.skipif(
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
)
def test_voyageai_embedding_function():
voyageai = get_registry().get("voyageai").create(name="voyage-3", max_retries=0)
@pytest.mark.parametrize(
"model_name,expected_dims",
[
("voyage-3", 1024),
("voyage-4", 1024),
("voyage-4-lite", 1024),
("voyage-4-large", 1024),
],
)
def test_voyageai_embedding_function(model_name, expected_dims, tmp_path):
"""Integration test for VoyageAI text embedding models with real API calls."""
voyageai = get_registry().get("voyageai").create(name=model_name, max_retries=0)
class TextModel(LanceModel):
text: str = voyageai.SourceField()
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("~/lancedb")
db = lancedb.connect(tmp_path)
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
assert voyageai.ndims() == expected_dims, (
f"{model_name} should have {expected_dims} dimensions"
)
# Test search functionality
result = tbl.search("hello").limit(1).to_pandas()
assert result["text"][0] == "hello world"
@pytest.mark.slow

View File

@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
"""Unit tests for VoyageAI embedding function.
These tests verify model registration and configuration without requiring API calls.
"""
import pytest
from unittest.mock import MagicMock, patch
from lancedb.embeddings import get_registry
@pytest.fixture(autouse=True)
def reset_voyageai_client():
"""Reset VoyageAI client before and after each test to avoid state pollution."""
from lancedb.embeddings.voyageai import VoyageAIEmbeddingFunction
VoyageAIEmbeddingFunction.client = None
yield
VoyageAIEmbeddingFunction.client = None
class TestVoyageAIModelRegistration:
"""Tests for VoyageAI model registration and configuration."""
@pytest.fixture
def mock_voyageai_client(self):
"""Mock VoyageAI client to avoid API calls."""
with patch.dict("os.environ", {"VOYAGE_API_KEY": "test-key"}):
with patch("lancedb.embeddings.voyageai.attempt_import_or_raise") as mock:
mock_client = MagicMock()
mock_voyageai = MagicMock()
mock_voyageai.Client.return_value = mock_client
mock.return_value = mock_voyageai
yield mock_client
def test_voyageai_registered(self):
"""Test that VoyageAI is registered in the embedding function registry."""
registry = get_registry()
assert registry.get("voyageai") is not None
@pytest.mark.parametrize(
"model_name,expected_dims",
[
# Voyage-4 series (all 1024 dims)
("voyage-4", 1024),
("voyage-4-lite", 1024),
("voyage-4-large", 1024),
# Voyage-3 series
("voyage-3", 1024),
("voyage-3-lite", 512),
# Domain-specific models
("voyage-finance-2", 1024),
("voyage-multilingual-2", 1024),
("voyage-law-2", 1024),
("voyage-code-2", 1536),
# Multimodal
("voyage-multimodal-3", 1024),
],
)
def test_model_dimensions(self, model_name, expected_dims, mock_voyageai_client):
"""Test that each model returns the correct dimensions."""
registry = get_registry()
func = registry.get("voyageai").create(name=model_name)
assert func.ndims() == expected_dims, (
f"Model {model_name} should have {expected_dims} dimensions"
)
def test_unsupported_model_raises_error(self, mock_voyageai_client):
"""Test that unsupported models raise ValueError."""
registry = get_registry()
func = registry.get("voyageai").create(name="unsupported-model")
with pytest.raises(ValueError, match="not supported"):
func.ndims()
@pytest.mark.parametrize(
"model_name",
[
"voyage-4",
"voyage-4-lite",
"voyage-4-large",
],
)
def test_voyage4_models_are_text_models(self, model_name, mock_voyageai_client):
"""Test that voyage-4 models are classified as text models (not multimodal)."""
registry = get_registry()
func = registry.get("voyageai").create(name=model_name)
assert not func._is_multimodal_model(model_name), (
f"{model_name} should be a text model, not multimodal"
)
def test_voyage4_models_in_text_embedding_list(self, mock_voyageai_client):
"""Test that voyage-4 models are in the text_embedding_models list."""
registry = get_registry()
func = registry.get("voyageai").create(name="voyage-4")
assert "voyage-4" in func.text_embedding_models
assert "voyage-4-lite" in func.text_embedding_models
assert "voyage-4-large" in func.text_embedding_models
def test_voyage4_models_not_in_multimodal_list(self, mock_voyageai_client):
"""Test that voyage-4 models are NOT in the multimodal_embedding_models list."""
registry = get_registry()
func = registry.get("voyageai").create(name="voyage-4")
assert "voyage-4" not in func.multimodal_embedding_models
assert "voyage-4-lite" not in func.multimodal_embedding_models
assert "voyage-4-large" not in func.multimodal_embedding_models