mirror of
https://github.com/lancedb/lancedb.git
synced 2026-03-19 15:10:40 +00:00
feat(python): adding VoyageAI v4 models (#2959)
Adding VoyageAI v4 models - with these, i added unit tests - added example code (tested!)
This commit is contained in:
@@ -21,6 +21,9 @@ if TYPE_CHECKING:
|
||||
|
||||
# Token limits for different VoyageAI models
|
||||
VOYAGE_TOTAL_TOKEN_LIMITS = {
|
||||
"voyage-4": 320_000,
|
||||
"voyage-4-lite": 1_000_000,
|
||||
"voyage-4-large": 120_000,
|
||||
"voyage-context-3": 32_000,
|
||||
"voyage-3.5-lite": 1_000_000,
|
||||
"voyage-3.5": 320_000,
|
||||
@@ -167,6 +170,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
name: str
|
||||
The name of the model to use. List of acceptable models:
|
||||
|
||||
* voyage-4 (1024 dims, general-purpose and multilingual retrieval)
|
||||
* voyage-4-lite (1024 dims, optimized for latency and cost)
|
||||
* voyage-4-large (1024 dims, best retrieval quality)
|
||||
* voyage-context-3
|
||||
* voyage-3.5
|
||||
* voyage-3.5-lite
|
||||
@@ -215,6 +221,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
_FLEXIBLE_DIM_MODELS: ClassVar[list] = ["voyage-multimodal-3.5"]
|
||||
_VALID_DIMENSIONS: ClassVar[list] = [256, 512, 1024, 2048]
|
||||
text_embedding_models: list = [
|
||||
"voyage-4",
|
||||
"voyage-4-lite",
|
||||
"voyage-4-large",
|
||||
"voyage-3.5",
|
||||
"voyage-3.5-lite",
|
||||
"voyage-3",
|
||||
@@ -252,6 +261,9 @@ class VoyageAIEmbeddingFunction(EmbeddingFunction):
|
||||
elif self.name == "voyage-code-2":
|
||||
return 1536
|
||||
elif self.name in [
|
||||
"voyage-4",
|
||||
"voyage-4-lite",
|
||||
"voyage-4-large",
|
||||
"voyage-context-3",
|
||||
"voyage-3.5",
|
||||
"voyage-3.5-lite",
|
||||
|
||||
@@ -517,19 +517,36 @@ def test_ollama_embedding(tmp_path):
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
def test_voyageai_embedding_function():
|
||||
voyageai = get_registry().get("voyageai").create(name="voyage-3", max_retries=0)
|
||||
@pytest.mark.parametrize(
|
||||
"model_name,expected_dims",
|
||||
[
|
||||
("voyage-3", 1024),
|
||||
("voyage-4", 1024),
|
||||
("voyage-4-lite", 1024),
|
||||
("voyage-4-large", 1024),
|
||||
],
|
||||
)
|
||||
def test_voyageai_embedding_function(model_name, expected_dims, tmp_path):
|
||||
"""Integration test for VoyageAI text embedding models with real API calls."""
|
||||
voyageai = get_registry().get("voyageai").create(name=model_name, max_retries=0)
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = voyageai.SourceField()
|
||||
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||
|
||||
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||
db = lancedb.connect("~/lancedb")
|
||||
db = lancedb.connect(tmp_path)
|
||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
|
||||
assert voyageai.ndims() == expected_dims, (
|
||||
f"{model_name} should have {expected_dims} dimensions"
|
||||
)
|
||||
|
||||
# Test search functionality
|
||||
result = tbl.search("hello").limit(1).to_pandas()
|
||||
assert result["text"][0] == "hello world"
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
|
||||
108
python/python/tests/test_voyageai_embeddings.py
Normal file
108
python/python/tests/test_voyageai_embeddings.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
"""Unit tests for VoyageAI embedding function.
|
||||
|
||||
These tests verify model registration and configuration without requiring API calls.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from lancedb.embeddings import get_registry
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_voyageai_client():
|
||||
"""Reset VoyageAI client before and after each test to avoid state pollution."""
|
||||
from lancedb.embeddings.voyageai import VoyageAIEmbeddingFunction
|
||||
|
||||
VoyageAIEmbeddingFunction.client = None
|
||||
yield
|
||||
VoyageAIEmbeddingFunction.client = None
|
||||
|
||||
|
||||
class TestVoyageAIModelRegistration:
|
||||
"""Tests for VoyageAI model registration and configuration."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_voyageai_client(self):
|
||||
"""Mock VoyageAI client to avoid API calls."""
|
||||
with patch.dict("os.environ", {"VOYAGE_API_KEY": "test-key"}):
|
||||
with patch("lancedb.embeddings.voyageai.attempt_import_or_raise") as mock:
|
||||
mock_client = MagicMock()
|
||||
mock_voyageai = MagicMock()
|
||||
mock_voyageai.Client.return_value = mock_client
|
||||
mock.return_value = mock_voyageai
|
||||
yield mock_client
|
||||
|
||||
def test_voyageai_registered(self):
|
||||
"""Test that VoyageAI is registered in the embedding function registry."""
|
||||
registry = get_registry()
|
||||
assert registry.get("voyageai") is not None
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name,expected_dims",
|
||||
[
|
||||
# Voyage-4 series (all 1024 dims)
|
||||
("voyage-4", 1024),
|
||||
("voyage-4-lite", 1024),
|
||||
("voyage-4-large", 1024),
|
||||
# Voyage-3 series
|
||||
("voyage-3", 1024),
|
||||
("voyage-3-lite", 512),
|
||||
# Domain-specific models
|
||||
("voyage-finance-2", 1024),
|
||||
("voyage-multilingual-2", 1024),
|
||||
("voyage-law-2", 1024),
|
||||
("voyage-code-2", 1536),
|
||||
# Multimodal
|
||||
("voyage-multimodal-3", 1024),
|
||||
],
|
||||
)
|
||||
def test_model_dimensions(self, model_name, expected_dims, mock_voyageai_client):
|
||||
"""Test that each model returns the correct dimensions."""
|
||||
registry = get_registry()
|
||||
func = registry.get("voyageai").create(name=model_name)
|
||||
assert func.ndims() == expected_dims, (
|
||||
f"Model {model_name} should have {expected_dims} dimensions"
|
||||
)
|
||||
|
||||
def test_unsupported_model_raises_error(self, mock_voyageai_client):
|
||||
"""Test that unsupported models raise ValueError."""
|
||||
registry = get_registry()
|
||||
func = registry.get("voyageai").create(name="unsupported-model")
|
||||
with pytest.raises(ValueError, match="not supported"):
|
||||
func.ndims()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[
|
||||
"voyage-4",
|
||||
"voyage-4-lite",
|
||||
"voyage-4-large",
|
||||
],
|
||||
)
|
||||
def test_voyage4_models_are_text_models(self, model_name, mock_voyageai_client):
|
||||
"""Test that voyage-4 models are classified as text models (not multimodal)."""
|
||||
registry = get_registry()
|
||||
func = registry.get("voyageai").create(name=model_name)
|
||||
assert not func._is_multimodal_model(model_name), (
|
||||
f"{model_name} should be a text model, not multimodal"
|
||||
)
|
||||
|
||||
def test_voyage4_models_in_text_embedding_list(self, mock_voyageai_client):
|
||||
"""Test that voyage-4 models are in the text_embedding_models list."""
|
||||
registry = get_registry()
|
||||
func = registry.get("voyageai").create(name="voyage-4")
|
||||
assert "voyage-4" in func.text_embedding_models
|
||||
assert "voyage-4-lite" in func.text_embedding_models
|
||||
assert "voyage-4-large" in func.text_embedding_models
|
||||
|
||||
def test_voyage4_models_not_in_multimodal_list(self, mock_voyageai_client):
|
||||
"""Test that voyage-4 models are NOT in the multimodal_embedding_models list."""
|
||||
registry = get_registry()
|
||||
func = registry.get("voyageai").create(name="voyage-4")
|
||||
assert "voyage-4" not in func.multimodal_embedding_models
|
||||
assert "voyage-4-lite" not in func.multimodal_embedding_models
|
||||
assert "voyage-4-large" not in func.multimodal_embedding_models
|
||||
Reference in New Issue
Block a user