mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 23:12:58 +00:00
docs: add jina reranker to index (#1427)
PR to add JinaReranker documentation page to the rerankers index
This commit is contained in:
@@ -15,8 +15,9 @@ import os
|
||||
import io
|
||||
import requests
|
||||
import base64
|
||||
import urllib.parse as urlparse
|
||||
from typing import ClassVar, List, Union, Optional, TYPE_CHECKING
|
||||
from urllib.parse import urlparse
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Union, Optional, Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
@@ -32,6 +33,14 @@ if TYPE_CHECKING:
|
||||
API_URL = "https://api.jina.ai/v1/embeddings"
|
||||
|
||||
|
||||
def is_valid_url(text):
    """
    Report whether ``text`` parses as a URL with a scheme and a network location.

    Parameters
    ----------
    text
        The candidate value. Anything ``urlparse`` cannot handle is treated
        as "not a URL" rather than raising.

    Returns
    -------
    bool
        True only when the parsed result has both a non-empty scheme and a
        non-empty netloc; False otherwise, including when parsing fails.
    """
    try:
        parts = urlparse(text)
    except Exception:
        # Best-effort check: any parsing failure simply means "not a URL".
        return False
    if not parts.scheme:
        return False
    return bool(parts.netloc)
|
||||
|
||||
|
||||
@register("jina")
|
||||
class JinaEmbeddings(EmbeddingFunction):
|
||||
"""
|
||||
@@ -58,67 +67,35 @@ class JinaEmbeddings(EmbeddingFunction):
|
||||
# TODO: fix hardcoding
|
||||
return 768
|
||||
|
||||
def sanitize_input(self, inputs: IMAGES) -> Union[List[bytes], np.ndarray]:
|
||||
def sanitize_input(
|
||||
self, inputs: Union[TEXT, IMAGES]
|
||||
) -> Union[List[Any], np.ndarray]:
|
||||
"""
|
||||
Sanitize the input to the embedding function.
|
||||
"""
|
||||
if isinstance(inputs, (str, bytes)):
|
||||
if isinstance(inputs, (str, bytes, Path)):
|
||||
inputs = [inputs]
|
||||
elif isinstance(inputs, pa.Array):
|
||||
inputs = inputs.to_pylist()
|
||||
elif isinstance(inputs, pa.ChunkedArray):
|
||||
inputs = inputs.combine_chunks().to_pylist()
|
||||
else:
|
||||
if isinstance(inputs, list):
|
||||
inputs = inputs
|
||||
else:
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if isinstance(inputs, PIL.Image.Image):
|
||||
inputs = [inputs]
|
||||
return inputs
|
||||
|
||||
def compute_query_embeddings(
    self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]:
    """
    Embed a single user query, which is either text or an image.

    Parameters
    ----------
    query : Union[str, PIL.Image.Image]
        A str is embedded as text via ``generate_text_embeddings``; a PIL
        image is embedded via ``generate_image_embedding``.

    Raises
    ------
    TypeError
        If the query is neither a str nor a PIL image.
    """
    if isinstance(query, str):
        return self.generate_text_embeddings([query])

    # Pillow is only imported once we know the query is not plain text.
    PIL = attempt_import_or_raise("PIL", "pillow")
    if not isinstance(query, PIL.Image.Image):
        raise TypeError("JinaEmbeddingFunction supports str or PIL Image as query")
    return [self.generate_image_embedding(query)]
|
||||
|
||||
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
    """
    Embed source texts.

    The input is first normalised by ``sanitize_input`` and the result is
    passed to ``generate_text_embeddings``, whose return value is returned
    unchanged.
    """
    return self.generate_text_embeddings(self.sanitize_input(texts))
|
||||
|
||||
def generate_image_embedding(
|
||||
self, image: Union[str, bytes, "PIL.Image.Image"]
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Generate the embedding for a single image
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image : Union[str, bytes, PIL.Image.Image]
|
||||
The image to embed. If the image is a str, it is treated as a uri.
|
||||
If the image is bytes, it is treated as the raw image bytes.
|
||||
"""
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
@staticmethod
|
||||
def _generate_image_input_dict(image: Union[str, bytes, "PIL.Image.Image"]) -> Dict:
|
||||
if isinstance(image, bytes):
|
||||
image = {"image": base64.b64encode(image).decode("utf-8")}
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
image_bytes = buffered.getvalue()
|
||||
image = {"image": base64.b64encode(image_bytes).decode("utf-8")}
|
||||
elif isinstance(image, str):
|
||||
image_dict = {"image": base64.b64encode(image).decode("utf-8")}
|
||||
elif isinstance(image, (str, Path)):
|
||||
parsed = urlparse.urlparse(image)
|
||||
# TODO handle drive letter on windows.
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
if parsed.scheme == "file":
|
||||
pil_image = PIL.Image.open(parsed.path)
|
||||
elif parsed.scheme == "":
|
||||
@@ -130,8 +107,95 @@ class JinaEmbeddings(EmbeddingFunction):
|
||||
buffered = io.BytesIO()
|
||||
pil_image.save(buffered, format="PNG")
|
||||
image_bytes = buffered.getvalue()
|
||||
image = {"image": base64.b64encode(image_bytes).decode("utf-8")}
|
||||
return self._generate_embeddings(input=[image])[0]
|
||||
image_dict = {"image": base64.b64encode(image_bytes).decode("utf-8")}
|
||||
else:
|
||||
PIL = attempt_import_or_raise("PIL", "pillow")
|
||||
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
image_bytes = buffered.getvalue()
|
||||
image_dict = {"image": base64.b64encode(image_bytes).decode("utf-8")}
|
||||
else:
|
||||
raise TypeError(
|
||||
f"JinaEmbeddingFunction supports str, Path, bytes or PIL Image"
|
||||
f" as query, but {type(image)} is given"
|
||||
)
|
||||
return image_dict
|
||||
|
||||
def compute_query_embeddings(
    self, query: Union[str, bytes, "Path", "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]:
    """
    Embed a single user query, which is either text or an image.

    Parameters
    ----------
    query : Union[str, bytes, Path, PIL.Image.Image]
        The query to embed. A str that ``is_valid_url`` rejects is embedded
        as text. A str that it accepts, a Path, bytes, or a PIL image is
        embedded as an image via ``generate_image_embedding``.

    Raises
    ------
    TypeError
        If the query is none of the supported types.
    """
    # Plain (non-URL) strings are the only text case.
    if isinstance(query, str) and not is_valid_url(query):
        return self.generate_text_embeddings([query])

    # URL strings, paths and raw bytes are all image references.
    if isinstance(query, (str, Path, bytes)):
        return [self.generate_image_embedding(query)]

    # Pillow is only imported when none of the cheaper checks matched.
    PIL = attempt_import_or_raise("PIL", "pillow")
    if not isinstance(query, PIL.Image.Image):
        raise TypeError(
            f"JinaEmbeddingFunction supports str, Path, bytes or PIL Image"
            f" as query, but {type(query)} is given"
        )
    return [self.generate_image_embedding(query)]
|
||||
|
||||
def compute_source_embeddings(
    self, inputs: Union[TEXT, IMAGES], *args, **kwargs
) -> List[np.array]:
    """
    Embed a batch of source inputs that may mix text and images.

    The inputs are normalised by ``sanitize_input`` and then walked in
    order, descending into nested lists. Each non-URL str becomes a
    ``{"text": ...}`` entry; every other item (URL str or non-str object) is
    converted with ``_generate_image_input_dict`` and counted as an image.

    If at least one image was seen, the assembled entries are sent through
    ``_generate_embeddings``. Otherwise the sanitized inputs themselves are
    passed to ``generate_text_embeddings``.
    """
    inputs = self.sanitize_input(inputs)
    payload = []

    def collect(item) -> int:
        # Appends the entry (or entries) for `item` to `payload` and
        # returns how many of them were images.
        if isinstance(item, list):
            return sum(collect(child) for child in item)
        if isinstance(item, str) and not is_valid_url(item):
            payload.append({"text": item})
            return 0
        payload.append(self._generate_image_input_dict(item))
        return 1

    image_count = sum(collect(item) for item in inputs)

    if image_count > 0:
        return self._generate_embeddings(payload)
    # Text-only batch: hand the sanitized inputs over as-is.
    return self.generate_text_embeddings(inputs)
|
||||
|
||||
def generate_image_embedding(
    self, image: Union[str, bytes, Path, "PIL.Image.Image"]
) -> np.ndarray:
    """
    Generate the embedding for a single image.

    Parameters
    ----------
    image : Union[str, bytes, Path, PIL.Image.Image]
        The image to embed. It is converted to a request entry by
        ``_generate_image_input_dict``, which decides how each supported
        type is interpreted.

    Returns
    -------
    The first element of what ``_generate_embeddings`` returns for a
    one-entry request.
    """
    payload = [self._generate_image_input_dict(image)]
    embeddings = self._generate_embeddings(input=payload)
    return embeddings[0]
|
||||
|
||||
def generate_text_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray], *args, **kwargs
|
||||
|
||||
Reference in New Issue
Block a user