mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 15:12:53 +00:00
feat: voyageai support (#1799)
Adding VoyageAI embedding and rerank support
This commit is contained in:
@@ -27,3 +27,4 @@ from .imagebind import ImageBindEmbeddings
|
||||
from .utils import with_embeddings
|
||||
from .jinaai import JinaEmbeddings
|
||||
from .watsonx import WatsonxEmbeddings
|
||||
from .voyageai import VoyageAIEmbeddingFunction
|
||||
|
||||
127
python/python/lancedb/embeddings/voyageai.py
Normal file
127
python/python/lancedb/embeddings/voyageai.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import ClassVar, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
from .utils import api_key_not_found_help, TEXT
|
||||
|
||||
|
||||
@register("voyageai")
|
||||
class VoyageAIEmbeddingFunction(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses the VoyageAI API
|
||||
|
||||
https://docs.voyageai.com/docs/embeddings
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str
|
||||
The name of the model to use. List of acceptable models:
|
||||
|
||||
* voyage-3
|
||||
* voyage-3-lite
|
||||
* voyage-finance-2
|
||||
* voyage-multilingual-2
|
||||
* voyage-law-2
|
||||
* voyage-code-2
|
||||
|
||||
|
||||
Examples
|
||||
--------
|
||||
import lancedb
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.embeddings import EmbeddingFunctionRegistry
|
||||
|
||||
voyageai = EmbeddingFunctionRegistry
|
||||
.get_instance()
|
||||
.get("voyageai")
|
||||
.create(name="voyage-3")
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = voyageai.SourceField()
|
||||
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||
|
||||
data = [ { "text": "hello world" },
|
||||
{ "text": "goodbye world" }]
|
||||
|
||||
db = lancedb.connect("~/.lancedb")
|
||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(data)
|
||||
|
||||
"""
|
||||
|
||||
name: str
|
||||
client: ClassVar = None
|
||||
|
||||
def ndims(self):
|
||||
if self.name == "voyage-3-lite":
|
||||
return 512
|
||||
elif self.name == "voyage-code-2":
|
||||
return 1536
|
||||
elif self.name in [
|
||||
"voyage-3",
|
||||
"voyage-finance-2",
|
||||
"voyage-multilingual-2",
|
||||
"voyage-law-2",
|
||||
]:
|
||||
return 1024
|
||||
else:
|
||||
raise ValueError(f"Model {self.name} not supported")
|
||||
|
||||
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
return self.compute_source_embeddings(query, input_type="query")
|
||||
|
||||
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
|
||||
texts = self.sanitize_input(texts)
|
||||
input_type = (
|
||||
kwargs.get("input_type") or "document"
|
||||
) # assume source input type if not passed by `compute_query_embeddings`
|
||||
return self.generate_embeddings(texts, input_type=input_type)
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray], *args, **kwargs
|
||||
) -> List[np.array]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
input_type: Optional[str]
|
||||
|
||||
truncation: Optional[bool]
|
||||
"""
|
||||
VoyageAIEmbeddingFunction._init_client()
|
||||
rs = VoyageAIEmbeddingFunction.client.embed(
|
||||
texts=texts, model=self.name, **kwargs
|
||||
)
|
||||
|
||||
return [emb for emb in rs.embeddings]
|
||||
|
||||
@staticmethod
|
||||
def _init_client():
|
||||
if VoyageAIEmbeddingFunction.client is None:
|
||||
voyageai = attempt_import_or_raise("voyageai")
|
||||
if os.environ.get("VOYAGE_API_KEY") is None:
|
||||
api_key_not_found_help("voyageai")
|
||||
VoyageAIEmbeddingFunction.client = voyageai.Client(
|
||||
os.environ["VOYAGE_API_KEY"]
|
||||
)
|
||||
@@ -7,6 +7,7 @@ from .openai import OpenaiReranker
|
||||
from .jinaai import JinaReranker
|
||||
from .rrf import RRFReranker
|
||||
from .answerdotai import AnswerdotaiRerankers
|
||||
from .voyageai import VoyageAIReranker
|
||||
|
||||
__all__ = [
|
||||
"Reranker",
|
||||
@@ -18,4 +19,5 @@ __all__ = [
|
||||
"JinaReranker",
|
||||
"RRFReranker",
|
||||
"AnswerdotaiRerankers",
|
||||
"VoyageAIReranker",
|
||||
]
|
||||
|
||||
133
python/python/lancedb/rerankers/voyageai.py
Normal file
133
python/python/lancedb/rerankers/voyageai.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import Union, Optional
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import Reranker
|
||||
|
||||
|
||||
class VoyageAIReranker(Reranker):
|
||||
"""
|
||||
Reranks the results using the VoyageAI Rerank API.
|
||||
https://docs.voyageai.com/docs/reranker
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str, default "rerank-english-v2.0"
|
||||
The name of the cross encoder model to use. Available voyageai models are:
|
||||
- rerank-2
|
||||
- rerank-2-lite
|
||||
column : str, default "text"
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
top_n : int, default None
|
||||
The number of results to return. If None, will return all results.
|
||||
return_score : str, default "relevance"
|
||||
options are "relevance" or "all". Only "relevance" is supported for now.
|
||||
api_key : str, default None
|
||||
The API key to use. If None, will use the OPENAI_API_KEY environment variable.
|
||||
truncation : Optional[bool], default None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
column: str = "text",
|
||||
top_n: Optional[int] = None,
|
||||
return_score="relevance",
|
||||
api_key: Optional[str] = None,
|
||||
truncation: Optional[bool] = True,
|
||||
):
|
||||
super().__init__(return_score)
|
||||
self.model_name = model_name
|
||||
self.column = column
|
||||
self.top_n = top_n
|
||||
self.api_key = api_key
|
||||
self.truncation = truncation
|
||||
|
||||
@cached_property
|
||||
def _client(self):
|
||||
voyageai = attempt_import_or_raise("voyageai")
|
||||
if os.environ.get("VOYAGE_API_KEY") is None and self.api_key is None:
|
||||
raise ValueError(
|
||||
"VOYAGE_API_KEY not set. Either set it in your environment or \
|
||||
pass it as `api_key` argument to the VoyageAIReranker."
|
||||
)
|
||||
return voyageai.Client(
|
||||
api_key=os.environ.get("VOYAGE_API_KEY") or self.api_key,
|
||||
)
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
docs = result_set[self.column].to_pylist()
|
||||
response = self._client.rerank(
|
||||
query=query,
|
||||
documents=docs,
|
||||
top_k=self.top_n,
|
||||
model=self.model_name,
|
||||
truncation=self.truncation,
|
||||
)
|
||||
results = (
|
||||
response.results
|
||||
) # returns list (text, idx, relevance) attributes sorted descending by score
|
||||
indices, scores = list(
|
||||
zip(*[(result.index, result.relevance_score) for result in results])
|
||||
) # tuples
|
||||
result_set = result_set.take(list(indices))
|
||||
# add the scores
|
||||
result_set = result_set.append_column(
|
||||
"_relevance_score", pa.array(scores, type=pa.float32())
|
||||
)
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
if self.score == "relevance":
|
||||
combined_results = self._keep_relevance_score(combined_results)
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"return_score='all' not implemented for voyageai reranker"
|
||||
)
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_distance"])
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_fts(
|
||||
self,
|
||||
query: str,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
result_set = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
result_set = result_set.drop_columns(["_score"])
|
||||
|
||||
return result_set
|
||||
@@ -196,6 +196,7 @@ def test_add_optional_vector(tmp_path):
|
||||
"ollama",
|
||||
"cohere",
|
||||
"instructor",
|
||||
"voyageai",
|
||||
],
|
||||
)
|
||||
def test_embedding_function_safe_model_dump(embedding_type):
|
||||
|
||||
@@ -481,3 +481,22 @@ def test_ollama_embedding(tmp_path):
|
||||
json.dumps(dumped_model)
|
||||
except TypeError:
|
||||
pytest.fail("Failed to JSON serialize the dumped model")
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
def test_voyageai_embedding_function():
|
||||
voyageai = get_registry().get("voyageai").create(name="voyage-3", max_retries=0)
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = voyageai.SourceField()
|
||||
vector: Vector(voyageai.ndims()) = voyageai.VectorField()
|
||||
|
||||
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
|
||||
db = lancedb.connect("~/lancedb")
|
||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims()
|
||||
|
||||
@@ -16,6 +16,7 @@ from lancedb.rerankers import (
|
||||
OpenaiReranker,
|
||||
JinaReranker,
|
||||
AnswerdotaiRerankers,
|
||||
VoyageAIReranker,
|
||||
)
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
@@ -344,3 +345,14 @@ def test_jina_reranker(tmp_path, use_tantivy):
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
reranker = JinaReranker()
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set"
|
||||
)
|
||||
@pytest.mark.parametrize("use_tantivy", [True, False])
|
||||
def test_voyageai_reranker(tmp_path, use_tantivy):
|
||||
pytest.importorskip("voyageai")
|
||||
reranker = VoyageAIReranker(model_name="rerank-2")
|
||||
table, schema = get_test_table(tmp_path, use_tantivy)
|
||||
_run_test_reranker(reranker, table, "single player experience", None, schema)
|
||||
|
||||
Reference in New Issue
Block a user