diff --git a/docs/src/embeddings/available_embedding_models/text_embedding_functions/voyageai_embedding.md b/docs/src/embeddings/available_embedding_models/text_embedding_functions/voyageai_embedding.md new file mode 100644 index 00000000..41a6be31 --- /dev/null +++ b/docs/src/embeddings/available_embedding_models/text_embedding_functions/voyageai_embedding.md @@ -0,0 +1,51 @@ +# VoyageAI Embeddings + +Voyage AI provides cutting-edge embedding and rerankers. + + +Using voyageai API requires voyageai package, which can be installed using `pip install voyageai`. Voyage AI embeddings are used to generate embeddings for text data. The embeddings can be used for various tasks like semantic search, clustering, and classification. +You also need to set the `VOYAGE_API_KEY` environment variable to use the VoyageAI API. + +Supported models are: + +- voyage-3 +- voyage-3-lite +- voyage-finance-2 +- voyage-multilingual-2 +- voyage-law-2 +- voyage-code-2 + + +Supported parameters (to be passed in `create` method) are: + +| Parameter | Type | Default Value | Description | +|---|---|--------|---------| +| `name` | `str` | `"voyage-3"` | The model ID of the model to use. Supported base models for Text Embeddings: voyage-3, voyage-3-lite, voyage-finance-2, voyage-multilingual-2, voyage-law-2, voyage-code-2 | +| `input_type` | `str` | `None` | Type of the input text. Default to None. Other options: query, document. | +| `truncation` | `bool` | `True` | Whether to truncate the input texts to fit within the context length. | + + +Usage Example: + +```python + import lancedb + from lancedb.pydantic import LanceModel, Vector + from lancedb.embeddings import EmbeddingFunctionRegistry + + voyageai = EmbeddingFunctionRegistry + .get_instance() + .get("voyageai") + .create(name="voyage-3") + + class TextModel(LanceModel): + text: str = voyageai.SourceField() + vector: Vector(voyageai.ndims()) = voyageai.VectorField() + + data = [ { "text": "hello world" }, + { "text": "goodbye world" }] + + db = lancedb.connect("~/.lancedb") + tbl = db.create_table("test", schema=TextModel, mode="overwrite") + + tbl.add(data) +``` \ No newline at end of file diff --git a/docs/src/reranking/voyageai.md b/docs/src/reranking/voyageai.md new file mode 100644 index 00000000..4021729a --- /dev/null +++ b/docs/src/reranking/voyageai.md @@ -0,0 +1,77 @@ +# Voyage AI Reranker + +Voyage AI provides cutting-edge embedding and rerankers. + +This re-ranker uses the [VoyageAI](https://docs.voyageai.com/docs/) API to rerank the search results. You can use this re-ranker by passing `VoyageAIReranker()` to the `rerank()` method. Note that you'll either need to set the `VOYAGE_API_KEY` environment variable or pass the `api_key` argument to use this re-ranker. + + +!!! note + Supported Query Types: Hybrid, Vector, FTS + + +```python +import numpy +import lancedb +from lancedb.embeddings import get_registry +from lancedb.pydantic import LanceModel, Vector +from lancedb.rerankers import VoyageAIReranker + +embedder = get_registry().get("sentence-transformers").create() +db = lancedb.connect("~/.lancedb") + +class Schema(LanceModel): + text: str = embedder.SourceField() + vector: Vector(embedder.ndims()) = embedder.VectorField() + +data = [ + {"text": "hello world"}, + {"text": "goodbye world"} + ] +tbl = db.create_table("test", schema=Schema, mode="overwrite") +tbl.add(data) +reranker = VoyageAIReranker(model_name="rerank-2") + +# Run vector search with a reranker +result = tbl.search("hello").rerank(reranker=reranker).to_list() + +# Run FTS search with a reranker +result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list() + +# Run hybrid search with a reranker +tbl.create_fts_index("text", replace=True) +result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list() + +``` + +Accepted Arguments +---------------- +| Argument | Type | Default | Description | +| --- | --- | --- | --- | +| `model_name` | `str` | `None` | The name of the reranker model to use. Available models are: rerank-2, rerank-2-lite | +| `column` | `str` | `"text"` | The name of the column to use as input to the cross encoder model. | +| `top_n` | `str` | `None` | The number of results to return. If None, will return all results. | +| `api_key` | `str` | `None` | The API key for the Voyage AI API. If not provided, the `VOYAGE_API_KEY` environment variable is used. | +| `return_score` | str | `"relevance"` | Options are "relevance" or "all". The type of score to return. If "relevance", will return only the `_relevance_score. If "all" is supported, will return relevance score along with the vector and/or fts scores depending on query type | +| `truncation` | `bool` | `None` | Whether to truncate the input to satisfy the "context length limit" on the query and the documents. | + + +## Supported Scores for each query type +You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type: + +### Hybrid Search +|`return_score`| Status | Description | +| --- | --- | --- | +| `relevance` | ✅ Supported | Returns only have the `_relevance_score` column | +| `all` | ❌ Not Supported | Returns have vector(`_distance`) and FTS(`score`) along with Hybrid Search score(`_relevance_score`) | + +### Vector Search +|`return_score`| Status | Description | +| --- | --- | --- | +| `relevance` | ✅ Supported | Returns only have the `_relevance_score` column | +| `all` | ✅ Supported | Returns have vector(`_distance`) along with Hybrid Search score(`_relevance_score`) | + +### FTS Search +|`return_score`| Status | Description | +| --- | --- | --- | +| `relevance` | ✅ Supported | Returns only have the `_relevance_score` column | +| `all` | ✅ Supported | Returns have FTS(`score`) along with Hybrid Search score(`_relevance_score`) | \ No newline at end of file diff --git a/python/python/lancedb/embeddings/__init__.py b/python/python/lancedb/embeddings/__init__.py index 76da3ab4..afa127d7 100644 --- a/python/python/lancedb/embeddings/__init__.py +++ b/python/python/lancedb/embeddings/__init__.py @@ -27,3 +27,4 @@ from .imagebind import ImageBindEmbeddings from .utils import with_embeddings from .jinaai import JinaEmbeddings from .watsonx import WatsonxEmbeddings +from .voyageai import VoyageAIEmbeddingFunction diff --git a/python/python/lancedb/embeddings/voyageai.py b/python/python/lancedb/embeddings/voyageai.py new file mode 100644 index 00000000..161c5e43 --- /dev/null +++ b/python/python/lancedb/embeddings/voyageai.py @@ -0,0 +1,127 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import ClassVar, List, Union + +import numpy as np + +from ..util import attempt_import_or_raise +from .base import TextEmbeddingFunction +from .registry import register +from .utils import api_key_not_found_help, TEXT + + +@register("voyageai") +class VoyageAIEmbeddingFunction(TextEmbeddingFunction): + """ + An embedding function that uses the VoyageAI API + + https://docs.voyageai.com/docs/embeddings + + Parameters + ---------- + name: str + The name of the model to use. List of acceptable models: + + * voyage-3 + * voyage-3-lite + * voyage-finance-2 + * voyage-multilingual-2 + * voyage-law-2 + * voyage-code-2 + + + Examples + -------- + import lancedb + from lancedb.pydantic import LanceModel, Vector + from lancedb.embeddings import EmbeddingFunctionRegistry + + voyageai = EmbeddingFunctionRegistry + .get_instance() + .get("voyageai") + .create(name="voyage-3") + + class TextModel(LanceModel): + text: str = voyageai.SourceField() + vector: Vector(voyageai.ndims()) = voyageai.VectorField() + + data = [ { "text": "hello world" }, + { "text": "goodbye world" }] + + db = lancedb.connect("~/.lancedb") + tbl = db.create_table("test", schema=TextModel, mode="overwrite") + + tbl.add(data) + + """ + + name: str + client: ClassVar = None + + def ndims(self): + if self.name == "voyage-3-lite": + return 512 + elif self.name == "voyage-code-2": + return 1536 + elif self.name in [ + "voyage-3", + "voyage-finance-2", + "voyage-multilingual-2", + "voyage-law-2", + ]: + return 1024 + else: + raise ValueError(f"Model {self.name} not supported") + + def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]: + return self.compute_source_embeddings(query, input_type="query") + + def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]: + texts = self.sanitize_input(texts) + input_type = ( + kwargs.get("input_type") or "document" + ) # assume source input type if not passed by `compute_query_embeddings` + return self.generate_embeddings(texts, input_type=input_type) + + def generate_embeddings( + self, texts: Union[List[str], np.ndarray], *args, **kwargs + ) -> List[np.array]: + """ + Get the embeddings for the given texts + + Parameters + ---------- + texts: list[str] or np.ndarray (of str) + The texts to embed + input_type: Optional[str] + + truncation: Optional[bool] + """ + VoyageAIEmbeddingFunction._init_client() + rs = VoyageAIEmbeddingFunction.client.embed( + texts=texts, model=self.name, **kwargs + ) + + return [emb for emb in rs.embeddings] + + @staticmethod + def _init_client(): + if VoyageAIEmbeddingFunction.client is None: + voyageai = attempt_import_or_raise("voyageai") + if os.environ.get("VOYAGE_API_KEY") is None: + api_key_not_found_help("voyageai") + VoyageAIEmbeddingFunction.client = voyageai.Client( + os.environ["VOYAGE_API_KEY"] + ) diff --git a/python/python/lancedb/rerankers/__init__.py b/python/python/lancedb/rerankers/__init__.py index 93903a16..c3e27331 100644 --- a/python/python/lancedb/rerankers/__init__.py +++ b/python/python/lancedb/rerankers/__init__.py @@ -7,6 +7,7 @@ from .openai import OpenaiReranker from .jinaai import JinaReranker from .rrf import RRFReranker from .answerdotai import AnswerdotaiRerankers +from .voyageai import VoyageAIReranker __all__ = [ "Reranker", @@ -18,4 +19,5 @@ __all__ = [ "JinaReranker", "RRFReranker", "AnswerdotaiRerankers", + "VoyageAIReranker", ] diff --git a/python/python/lancedb/rerankers/voyageai.py b/python/python/lancedb/rerankers/voyageai.py new file mode 100644 index 00000000..d04a5ad4 --- /dev/null +++ b/python/python/lancedb/rerankers/voyageai.py @@ -0,0 +1,133 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from functools import cached_property +from typing import Union, Optional + +import pyarrow as pa + +from ..util import attempt_import_or_raise +from .base import Reranker + + +class VoyageAIReranker(Reranker): + """ + Reranks the results using the VoyageAI Rerank API. + https://docs.voyageai.com/docs/reranker + + Parameters + ---------- + model_name : str, default "rerank-english-v2.0" + The name of the cross encoder model to use. Available voyageai models are: + - rerank-2 + - rerank-2-lite + column : str, default "text" + The name of the column to use as input to the cross encoder model. + top_n : int, default None + The number of results to return. If None, will return all results. + return_score : str, default "relevance" + options are "relevance" or "all". Only "relevance" is supported for now. + api_key : str, default None + The API key to use. If None, will use the OPENAI_API_KEY environment variable. + truncation : Optional[bool], default None + """ + + def __init__( + self, + model_name: str, + column: str = "text", + top_n: Optional[int] = None, + return_score="relevance", + api_key: Optional[str] = None, + truncation: Optional[bool] = True, + ): + super().__init__(return_score) + self.model_name = model_name + self.column = column + self.top_n = top_n + self.api_key = api_key + self.truncation = truncation + + @cached_property + def _client(self): + voyageai = attempt_import_or_raise("voyageai") + if os.environ.get("VOYAGE_API_KEY") is None and self.api_key is None: + raise ValueError( + "VOYAGE_API_KEY not set. Either set it in your environment or \ + pass it as `api_key` argument to the VoyageAIReranker." + ) + return voyageai.Client( + api_key=os.environ.get("VOYAGE_API_KEY") or self.api_key, + ) + + def _rerank(self, result_set: pa.Table, query: str): + docs = result_set[self.column].to_pylist() + response = self._client.rerank( + query=query, + documents=docs, + top_k=self.top_n, + model=self.model_name, + truncation=self.truncation, + ) + results = ( + response.results + ) # returns list (text, idx, relevance) attributes sorted descending by score + indices, scores = list( + zip(*[(result.index, result.relevance_score) for result in results]) + ) # tuples + result_set = result_set.take(list(indices)) + # add the scores + result_set = result_set.append_column( + "_relevance_score", pa.array(scores, type=pa.float32()) + ) + + return result_set + + def rerank_hybrid( + self, + query: str, + vector_results: pa.Table, + fts_results: pa.Table, + ): + combined_results = self.merge_results(vector_results, fts_results) + combined_results = self._rerank(combined_results, query) + if self.score == "relevance": + combined_results = self._keep_relevance_score(combined_results) + elif self.score == "all": + raise NotImplementedError( + "return_score='all' not implemented for voyageai reranker" + ) + return combined_results + + def rerank_vector( + self, + query: str, + vector_results: pa.Table, + ): + result_set = self._rerank(vector_results, query) + if self.score == "relevance": + result_set = result_set.drop_columns(["_distance"]) + + return result_set + + def rerank_fts( + self, + query: str, + fts_results: pa.Table, + ): + result_set = self._rerank(fts_results, query) + if self.score == "relevance": + result_set = result_set.drop_columns(["_score"]) + + return result_set diff --git a/python/python/tests/test_embeddings.py b/python/python/tests/test_embeddings.py index e48fb209..a9f939ee 100644 --- a/python/python/tests/test_embeddings.py +++ b/python/python/tests/test_embeddings.py @@ -196,6 +196,7 @@ def test_add_optional_vector(tmp_path): "ollama", "cohere", "instructor", + "voyageai", ], ) def test_embedding_function_safe_model_dump(embedding_type): diff --git a/python/python/tests/test_embeddings_slow.py b/python/python/tests/test_embeddings_slow.py index 9e17ca66..58f9ff98 100644 --- a/python/python/tests/test_embeddings_slow.py +++ b/python/python/tests/test_embeddings_slow.py @@ -481,3 +481,22 @@ def test_ollama_embedding(tmp_path): json.dumps(dumped_model) except TypeError: pytest.fail("Failed to JSON serialize the dumped model") + + +@pytest.mark.slow +@pytest.mark.skipif( + os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set" +) +def test_voyageai_embedding_function(): + voyageai = get_registry().get("voyageai").create(name="voyage-3", max_retries=0) + + class TextModel(LanceModel): + text: str = voyageai.SourceField() + vector: Vector(voyageai.ndims()) = voyageai.VectorField() + + df = pd.DataFrame({"text": ["hello world", "goodbye world"]}) + db = lancedb.connect("~/lancedb") + tbl = db.create_table("test", schema=TextModel, mode="overwrite") + + tbl.add(df) + assert len(tbl.to_pandas()["vector"][0]) == voyageai.ndims() diff --git a/python/python/tests/test_rerankers.py b/python/python/tests/test_rerankers.py index f2f7c6cc..4e1c6898 100644 --- a/python/python/tests/test_rerankers.py +++ b/python/python/tests/test_rerankers.py @@ -16,6 +16,7 @@ from lancedb.rerankers import ( OpenaiReranker, JinaReranker, AnswerdotaiRerankers, + VoyageAIReranker, ) from lancedb.table import LanceTable @@ -344,3 +345,14 @@ def test_jina_reranker(tmp_path, use_tantivy): table, schema = get_test_table(tmp_path, use_tantivy) reranker = JinaReranker() _run_test_reranker(reranker, table, "single player experience", None, schema) + + +@pytest.mark.skipif( + os.environ.get("VOYAGE_API_KEY") is None, reason="VOYAGE_API_KEY not set" +) +@pytest.mark.parametrize("use_tantivy", [True, False]) +def test_voyageai_reranker(tmp_path, use_tantivy): + pytest.importorskip("voyageai") + reranker = VoyageAIReranker(model_name="rerank-2") + table, schema = get_test_table(tmp_path, use_tantivy) + _run_test_reranker(reranker, table, "single player experience", None, schema)