docs: add jina reranker to index (#1427)

PR to add JinaReranker documentation page to the rerankers index
This commit is contained in:
Joan Fontanals
2024-07-09 11:09:35 +02:00
committed by GitHub
parent b4436e0804
commit cef24801f4
4 changed files with 209 additions and 53 deletions

View File

@@ -15,8 +15,9 @@ import os
import io
import requests
import base64
import urllib.parse as urlparse
from typing import ClassVar, List, Union, Optional, TYPE_CHECKING
from urllib.parse import urlparse
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar, List, Union, Optional, Any, Dict
import numpy as np
import pyarrow as pa
@@ -32,6 +33,14 @@ if TYPE_CHECKING:
API_URL = "https://api.jina.ai/v1/embeddings"
def is_valid_url(text):
try:
parsed = urlparse(text)
return bool(parsed.scheme) and bool(parsed.netloc)
except Exception:
return False
@register("jina")
class JinaEmbeddings(EmbeddingFunction):
"""
@@ -58,67 +67,35 @@ class JinaEmbeddings(EmbeddingFunction):
# TODO: fix hardcoding
return 768
def sanitize_input(self, inputs: IMAGES) -> Union[List[bytes], np.ndarray]:
def sanitize_input(
self, inputs: Union[TEXT, IMAGES]
) -> Union[List[Any], np.ndarray]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(inputs, (str, bytes)):
if isinstance(inputs, (str, bytes, Path)):
inputs = [inputs]
elif isinstance(inputs, pa.Array):
inputs = inputs.to_pylist()
elif isinstance(inputs, pa.ChunkedArray):
inputs = inputs.combine_chunks().to_pylist()
else:
if isinstance(inputs, list):
inputs = inputs
else:
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(inputs, PIL.Image.Image):
inputs = [inputs]
return inputs
def compute_query_embeddings(
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]:
"""
Compute the embeddings for a given user query
Parameters
----------
query : Union[str, PIL.Image.Image]
The query to embed. A query can be either text or an image.
"""
if isinstance(query, str):
return self.generate_text_embeddings([query])
else:
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError(
"JinaEmbeddingFunction supports str or PIL Image as query"
)
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
return self.generate_text_embeddings(texts)
def generate_image_embedding(
self, image: Union[str, bytes, "PIL.Image.Image"]
) -> np.ndarray:
"""
Generate the embedding for a single image
Parameters
----------
image : Union[str, bytes, PIL.Image.Image]
The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes.
"""
PIL = attempt_import_or_raise("PIL", "pillow")
@staticmethod
def _generate_image_input_dict(image: Union[str, bytes, "PIL.Image.Image"]) -> Dict:
if isinstance(image, bytes):
image = {"image": base64.b64encode(image).decode("utf-8")}
if isinstance(image, PIL.Image.Image):
buffered = io.BytesIO()
image.save(buffered, format="PNG")
image_bytes = buffered.getvalue()
image = {"image": base64.b64encode(image_bytes).decode("utf-8")}
elif isinstance(image, str):
image_dict = {"image": base64.b64encode(image).decode("utf-8")}
elif isinstance(image, (str, Path)):
parsed = urlparse.urlparse(image)
# TODO handle drive letter on windows.
PIL = attempt_import_or_raise("PIL", "pillow")
if parsed.scheme == "file":
pil_image = PIL.Image.open(parsed.path)
elif parsed.scheme == "":
@@ -130,8 +107,95 @@ class JinaEmbeddings(EmbeddingFunction):
buffered = io.BytesIO()
pil_image.save(buffered, format="PNG")
image_bytes = buffered.getvalue()
image = {"image": base64.b64encode(image_bytes).decode("utf-8")}
return self._generate_embeddings(input=[image])[0]
image_dict = {"image": base64.b64encode(image_bytes).decode("utf-8")}
else:
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(image, PIL.Image.Image):
buffered = io.BytesIO()
image.save(buffered, format="PNG")
image_bytes = buffered.getvalue()
image_dict = {"image": base64.b64encode(image_bytes).decode("utf-8")}
else:
raise TypeError(
f"JinaEmbeddingFunction supports str, Path, bytes or PIL Image"
f" as query, but {type(image)} is given"
)
return image_dict
def compute_query_embeddings(
self, query: Union[str, bytes, "Path", "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]:
"""
Compute the embeddings for a given user query
Parameters
----------
query : Union[str, PIL.Image.Image]
The query to embed. A query can be either text or an image.
"""
if isinstance(query, str):
if not is_valid_url(query):
return self.generate_text_embeddings([query])
else:
return [self.generate_image_embedding(query)]
elif isinstance(query, (Path, bytes)):
return [self.generate_image_embedding(query)]
else:
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError(
f"JinaEmbeddingFunction supports str, Path, bytes or PIL Image"
f" as query, but {type(query)} is given"
)
def compute_source_embeddings(
self, inputs: Union[TEXT, IMAGES], *args, **kwargs
) -> List[np.array]:
inputs = self.sanitize_input(inputs)
model_inputs = []
image_inputs = 0
def process_input(input, model_inputs, image_inputs):
if isinstance(input, str):
if not is_valid_url(input):
model_inputs.append({"text": input})
else:
image_inputs += 1
model_inputs.append(self._generate_image_input_dict(input))
elif isinstance(input, list):
for _input in input:
image_inputs = process_input(_input, model_inputs, image_inputs)
else:
image_inputs += 1
model_inputs.append(self._generate_image_input_dict(input))
return image_inputs
for input in inputs:
image_inputs = process_input(input, model_inputs, image_inputs)
if image_inputs > 0:
return self._generate_embeddings(model_inputs)
else:
return self.generate_text_embeddings(inputs)
def generate_image_embedding(
self, image: Union[str, bytes, Path, "PIL.Image.Image"]
) -> np.ndarray:
"""
Generate the embedding for a single image
Parameters
----------
image : Union[str, bytes, PIL.Image.Image]
The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes.
"""
image_dict = self._generate_image_input_dict(image)
return self._generate_embeddings(input=[image_dict])[0]
def generate_text_embeddings(
self, texts: Union[List[str], np.ndarray], *args, **kwargs