feat: add Jina integration in Python for Embedding and Reranker (#1424)

Integration of Jina Embeddings and Rerankers through its API
This commit is contained in:
Joan Fontanals
2024-07-04 22:04:43 +02:00
committed by GitHub
parent a5ff623443
commit 08d25c5a80
6 changed files with 408 additions and 0 deletions

View File

@@ -68,6 +68,39 @@ table.add(
]
)
query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
print(actual.text)
```
### Jina Embeddings
LanceDB registers the JinaAI embeddings function in the registry as `jina`. You can pass any supported model name to the `create`. By default it uses `"jina-clip-v1"`.
`jina-clip-v1` can handle both text and images and other models only support `text`.
You need to pass `JINA_API_KEY` in the environment variable or pass it as `api_key` to `create` method.
```python
import os
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
os.environ['JINA_API_KEY'] = "jina_*"
db = lancedb.connect("/tmp/db")
func = get_registry().get("jina").create(name="jina-clip-v1")
class Words(LanceModel):
text: str = func.SourceField()
vector: Vector(func.ndims()) = func.VectorField()
table = db.create_table("words", schema=Words, mode="overwrite")
table.add(
[
{"text": "hello world"},
{"text": "goodbye world"}
]
)
query = "greetings"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
print(actual.text)

View File

@@ -0,0 +1,78 @@
# Jina Reranker
This re-ranker uses the [Jina](https://jina.ai/reranker/) API to rerank the search results. You can use this re-ranker by passing `JinaReranker()` to the `rerank()` method. Note that you'll either need to set the `JINA_API_KEY` environment variable or pass the `api_key` argument to use this re-ranker.
!!! note
Supported Query Types: Hybrid, Vector, FTS
```python
import os
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import JinaReranker
os.environ['JINA_API_KEY'] = "jina_*"
embedder = get_registry().get("jina").create()
db = lancedb.connect("~/.lancedb")
class Schema(LanceModel):
text: str = embedder.SourceField()
vector: Vector(embedder.ndims()) = embedder.VectorField()
data = [
{"text": "hello world"},
{"text": "goodbye world"}
]
tbl = db.create_table("test", schema=Schema, mode="overwrite")
tbl.add(data)
reranker = JinaReranker(api_key="key")
# Run vector search with a reranker
result = tbl.search("hello").rerank(reranker=reranker).to_list()
# Run FTS search with a reranker
result = tbl.search("hello", query_type="fts").rerank(reranker=reranker).to_list()
# Run hybrid search with a reranker
tbl.create_fts_index("text", replace=True)
result = tbl.search("hello", query_type="hybrid").rerank(reranker=reranker).to_list()
```
Accepted Arguments
----------------
| Argument | Type | Default | Description |
| --- | --- | --- | --- |
| `model_name` | `str` | `"jina-reranker-v2-base-multilingual"` | The name of the reranker model to use. You can find the list of available models in https://jina.ai/reranker/|
| `column` | `str` | `"text"` | The name of the column to use as input to the cross encoder model. |
| `top_n` | `str` | `None` | The number of results to return. If None, will return all results. |
| `api_key` | `str` | `None` | The API key for the Jina API. If not provided, the `JINA_API_KEY` environment variable is used. |
| `return_score` | str | `"relevance"` | Options are "relevance" or "all". The type of score to return. If "relevance", will return only the `_relevance_score. If "all" is supported, will return relevance score along with the vector and/or fts scores depending on query type |
## Supported Scores for each query type
You can specify the type of scores you want the reranker to return. The following are the supported scores for each query type:
### Hybrid Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only have the `_relevance_score` column |
| `all` | ❌ Not Supported | Returns have vector(`_distance`) and FTS(`score`) along with Hybrid Search score(`_relevance_score`) |
### Vector Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only have the `_relevance_score` column |
| `all` | ✅ Supported | Returns have vector(`_distance`) along with Hybrid Search score(`_relevance_score`) |
### FTS Search
|`return_score`| Status | Description |
| --- | --- | --- |
| `relevance` | ✅ Supported | Returns only have the `_relevance_score` column |
| `all` | ✅ Supported | Returns have FTS(`score`) along with Hybrid Search score(`_relevance_score`) |

View File

@@ -25,3 +25,4 @@ from .gte import GteEmbeddings
from .transformers import TransformersEmbeddingFunction, ColbertEmbeddings
from .imagebind import ImageBindEmbeddings
from .utils import with_embeddings
from .jinaai import JinaEmbeddings

View File

@@ -0,0 +1,172 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import io
import requests
import base64
import urllib.parse as urlparse
from typing import ClassVar, List, Union, Optional, TYPE_CHECKING
import numpy as np
import pyarrow as pa
from ..util import attempt_import_or_raise
from .base import EmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help, TEXT, IMAGES, url_retrieve
if TYPE_CHECKING:
import PIL
API_URL = "https://api.jina.ai/v1/embeddings"
@register("jina")
class JinaEmbeddings(EmbeddingFunction):
"""
An embedding function that uses the Jina API
https://jina.ai/embeddings/
Parameters
----------
name: str, default "jina-clip-v1". Note that some models support both image
and text embeddings and some just text embedding
api_key: str, default None
The api key to access Jina API. If you pass None, you can set JINA_API_KEY
environment variable
"""
name: str = "jina-clip-v1"
api_key: Optional[str] = None
_session: ClassVar = None
def ndims(self):
# TODO: fix hardcoding
return 768
def sanitize_input(self, inputs: IMAGES) -> Union[List[bytes], np.ndarray]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(inputs, (str, bytes)):
inputs = [inputs]
elif isinstance(inputs, pa.Array):
inputs = inputs.to_pylist()
elif isinstance(inputs, pa.ChunkedArray):
inputs = inputs.combine_chunks().to_pylist()
return inputs
def compute_query_embeddings(
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]:
"""
Compute the embeddings for a given user query
Parameters
----------
query : Union[str, PIL.Image.Image]
The query to embed. A query can be either text or an image.
"""
if isinstance(query, str):
return self.generate_text_embeddings([query])
else:
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError(
"JinaEmbeddingFunction supports str or PIL Image as query"
)
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
return self.generate_text_embeddings(texts)
def generate_image_embedding(
self, image: Union[str, bytes, "PIL.Image.Image"]
) -> np.ndarray:
"""
Generate the embedding for a single image
Parameters
----------
image : Union[str, bytes, PIL.Image.Image]
The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes.
"""
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(image, bytes):
image = {"image": base64.b64encode(image).decode("utf-8")}
if isinstance(image, PIL.Image.Image):
buffered = io.BytesIO()
image.save(buffered, format="PNG")
image_bytes = buffered.getvalue()
image = {"image": base64.b64encode(image_bytes).decode("utf-8")}
elif isinstance(image, str):
parsed = urlparse.urlparse(image)
# TODO handle drive letter on windows.
if parsed.scheme == "file":
pil_image = PIL.Image.open(parsed.path)
elif parsed.scheme == "":
pil_image = PIL.Image.open(image if os.name == "nt" else parsed.path)
elif parsed.scheme.startswith("http"):
pil_image = PIL.Image.open(io.BytesIO(url_retrieve(image)))
else:
raise NotImplementedError("Only local and http(s) urls are supported")
buffered = io.BytesIO()
pil_image.save(buffered, format="PNG")
image_bytes = buffered.getvalue()
image = {"image": base64.b64encode(image_bytes).decode("utf-8")}
return self._generate_embeddings(input=[image])[0]
def generate_text_embeddings(
self, texts: Union[List[str], np.ndarray], *args, **kwargs
) -> List[np.array]:
return self._generate_embeddings(input=texts)
def _generate_embeddings(self, input: List, *args, **kwargs) -> List[np.array]:
"""
Get the embeddings for the given texts
Parameters
----------
texts: list[str] or np.ndarray (of str)
The texts to embed
"""
self._init_client()
resp = JinaEmbeddings._session.post( # type: ignore
API_URL, json={"input": input, "model": self.name}
).json()
if "data" not in resp:
raise RuntimeError(resp["detail"])
embeddings = resp["data"]
# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) # type: ignore
return [result["embedding"] for result in sorted_embeddings]
def _init_client(self):
if JinaEmbeddings._session is None:
if self.api_key is None and os.environ.get("JINA_API_KEY") is None:
api_key_not_found_help("jina")
api_key = self.api_key or os.environ.get("JINA_API_KEY")
JinaEmbeddings._session = requests.Session()
JinaEmbeddings._session.headers.update(
{"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"}
)

View File

@@ -4,6 +4,7 @@ from .colbert import ColbertReranker
from .cross_encoder import CrossEncoderReranker
from .linear_combination import LinearCombinationReranker
from .openai import OpenaiReranker
from .jinaai import JinaReranker
__all__ = [
"Reranker",
@@ -12,4 +13,5 @@ __all__ = [
"LinearCombinationReranker",
"OpenaiReranker",
"ColbertReranker",
"JinaReranker",
]

View File

@@ -0,0 +1,122 @@
import os
import requests
from functools import cached_property
from typing import Union
import pyarrow as pa
from .base import Reranker
API_URL = "https://api.jina.ai/v1/rerank"
class JinaReranker(Reranker):
"""
Reranks the results using the Jina Rerank API.
https://jina.ai/rerank
Parameters
----------
model_name : str, default "jina-reranker-v2-base-multilingual"
The name of the cross reanker model to use
column : str, default "text"
The name of the column to use as input to the cross encoder model.
top_n : str, default None
The number of results to return. If None, will return all results.
api_key : str, default None
The api key to access Jina API. If you pass None, you can set JINA_API_KEY
environment variable
"""
def __init__(
self,
model_name: str = "jina-reranker-v2-base-multilingual",
column: str = "text",
top_n: Union[int, None] = None,
return_score="relevance",
api_key: Union[str, None] = None,
):
super().__init__(return_score)
self.model_name = model_name
self.column = column
self.top_n = top_n
self.api_key = api_key
@cached_property
def _client(self):
if os.environ.get("JINA_API_KEY") is None and self.api_key is None:
raise ValueError(
"JINA_API_KEY not set. Either set it in your environment or \
pass it as `api_key` argument to the JinaReranker."
)
self.api_key = self.api_key or os.environ.get("JINA_API_KEY")
self._session = requests.Session()
self._session.headers.update(
{"Authorization": f"Bearer {self.api_key}", "Accept-Encoding": "identity"}
)
return self._session
def _rerank(self, result_set: pa.Table, query: str):
docs = result_set[self.column].to_pylist()
response = self._client.post( # type: ignore
API_URL,
json={
"query": query,
"documents": docs,
"model": self.model_name,
"top_n": self.top_n,
},
).json()
if "results" not in response:
raise RuntimeError(response["detail"])
results = response["results"]
indices, scores = list(
zip(*[(result["index"], result["relevance_score"]) for result in results])
) # tuples
result_set = result_set.take(list(indices))
# add the scores
result_set = result_set.append_column(
"_relevance_score", pa.array(scores, type=pa.float32())
)
return result_set
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all":
raise NotImplementedError(
"return_score='all' not implemented for JinaReranker"
)
return combined_results
def rerank_vector(
self,
query: str,
vector_results: pa.Table,
):
result_set = self._rerank(vector_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["_distance"])
return result_set
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
result_set = self._rerank(fts_results, query)
if self.score == "relevance":
result_set = result_set.drop_columns(["score"])
return result_set