docs: add jina reranker to index (#1427)

PR to add JinaReranker documentation page to the rerankers index
This commit is contained in:
Joan Fontanals
2024-07-09 11:09:35 +02:00
committed by GitHub
parent b4436e0804
commit cef24801f4
4 changed files with 209 additions and 53 deletions

View File

@@ -102,6 +102,7 @@ nav:
- Linear Combination Reranker: reranking/linear_combination.md
- Cross Encoder Reranker: reranking/cross_encoder.md
- ColBERT Reranker: reranking/colbert.md
- Jina Reranker: reranking/jina.md
- OpenAI Reranker: reranking/openai.md
- Building Custom Rerankers: reranking/custom_reranker.md
- Filtering: sql.md
@@ -184,6 +185,7 @@ nav:
- Linear Combination Reranker: reranking/linear_combination.md
- Cross Encoder Reranker: reranking/cross_encoder.md
- ColBERT Reranker: reranking/colbert.md
- Jina Reranker: reranking/jina.md
- OpenAI Reranker: reranking/openai.md
- Building Custom Rerankers: reranking/custom_reranker.md
- Filtering: sql.md

View File

@@ -427,6 +427,45 @@ Usage Example:
tbl.add(data)
```
### Jina Embeddings
Jina embeddings are used to generate embeddings for text and image data.
You also need to set the `JINA_API_KEY` environment variable to use the Jina API.
You can find a list of supported models under [https://jina.ai/embeddings/](https://jina.ai/embeddings/)
Supported parameters (to be passed in `create` method) are:
| Parameter | Type | Default Value | Description |
|---|---|---|---|
| `name` | `str` | `"jina-clip-v1"` | The model ID of the jina model to use |
Usage Example:
```python
import os
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import EmbeddingFunctionRegistry
os.environ['JINA_API_KEY'] = 'jina_*'
jina_embed = EmbeddingFunctionRegistry.get_instance().get("jina").create(name="jina-embeddings-v2-base-en")
class TextModel(LanceModel):
text: str = jina_embed.SourceField()
vector: Vector(jina_embed.ndims()) = jina_embed.VectorField()
data = [{"text": "hello world"},
{"text": "goodbye world"}]
db = lancedb.connect("~/.lancedb-2")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(data)
```
### AWS Bedrock Text Embedding Functions
AWS Bedrock supports multiple base models for generating text embeddings. You need to setup the AWS credentials to use this embedding function.
You can do so by using `awscli` and also add your session_token:
@@ -630,3 +669,54 @@ print(actual.text == "bird")
```
If you have any questions about the embeddings API, supported models, or see a relevant model missing, please raise an issue [on GitHub](https://github.com/lancedb/lancedb/issues).
### Jina Embeddings
Jina embeddings can also be used to embed both text and image data, only some of the models support image data and you can check the list
under [https://jina.ai/embeddings/](https://jina.ai/embeddings/)
Supported parameters (to be passed in `create` method) are:
| Parameter | Type | Default Value | Description |
|---|---|---|---|
| `name` | `str` | `"jina-clip-v1"` | The model ID of the jina model to use |
Usage Example:
```python
import os
import requests
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
import pandas as pd
os.environ['JINA_API_KEY'] = 'jina_*'
db = lancedb.connect("~/.lancedb")
func = get_registry().get("jina").create()
class Images(LanceModel):
label: str
image_uri: str = func.SourceField() # image uri as the source
image_bytes: bytes = func.SourceField() # image bytes as the source
vector: Vector(func.ndims()) = func.VectorField() # vector column
vec_from_bytes: Vector(func.ndims()) = func.VectorField() # Another vector column
table = db.create_table("images", schema=Images)
labels = ["cat", "cat", "dog", "dog", "horse", "horse"]
uris = [
"http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg",
"http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg",
"http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg",
"http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg",
"http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg",
"http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg",
]
# get each uri as bytes
image_bytes = [requests.get(uri).content for uri in uris]
table.add(
pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes})
)
```

View File

@@ -15,7 +15,6 @@ LanceDB comes with some built-in rerankers. Some of the rerankers that are avail
Using rerankers is optional for vector and FTS. However, for hybrid search, rerankers are required. To use a reranker, you need to create an instance of the reranker and pass it to the `rerank` method of the query builder.
```python
import numpy
import lancedb
from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector
@@ -54,6 +53,7 @@ LanceDB comes with some built-in rerankers. Here are some of the rerankers that
- [ColBERT Reranker](./colbert.md)
- [OpenAI Reranker](./openai.md)
- [Linear Combination Reranker](./linear_combination.md)
- [Jina Reranker](./jina.md)
## Creating Custom Rerankers

View File

@@ -15,8 +15,9 @@ import os
import io
import requests
import base64
import urllib.parse as urlparse
from typing import ClassVar, List, Union, Optional, TYPE_CHECKING
from urllib.parse import urlparse
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar, List, Union, Optional, Any, Dict
import numpy as np
import pyarrow as pa
@@ -32,6 +33,14 @@ if TYPE_CHECKING:
API_URL = "https://api.jina.ai/v1/embeddings"
def is_valid_url(text):
try:
parsed = urlparse(text)
return bool(parsed.scheme) and bool(parsed.netloc)
except Exception:
return False
@register("jina")
class JinaEmbeddings(EmbeddingFunction):
"""
@@ -58,67 +67,35 @@ class JinaEmbeddings(EmbeddingFunction):
# TODO: fix hardcoding
return 768
def sanitize_input(self, inputs: IMAGES) -> Union[List[bytes], np.ndarray]:
def sanitize_input(
self, inputs: Union[TEXT, IMAGES]
) -> Union[List[Any], np.ndarray]:
"""
Sanitize the input to the embedding function.
"""
if isinstance(inputs, (str, bytes)):
if isinstance(inputs, (str, bytes, Path)):
inputs = [inputs]
elif isinstance(inputs, pa.Array):
inputs = inputs.to_pylist()
elif isinstance(inputs, pa.ChunkedArray):
inputs = inputs.combine_chunks().to_pylist()
else:
if isinstance(inputs, list):
inputs = inputs
else:
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(inputs, PIL.Image.Image):
inputs = [inputs]
return inputs
def compute_query_embeddings(
self, query: Union[str, "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]:
"""
Compute the embeddings for a given user query
Parameters
----------
query : Union[str, PIL.Image.Image]
The query to embed. A query can be either text or an image.
"""
if isinstance(query, str):
return self.generate_text_embeddings([query])
else:
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError(
"JinaEmbeddingFunction supports str or PIL Image as query"
)
def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
texts = self.sanitize_input(texts)
return self.generate_text_embeddings(texts)
def generate_image_embedding(
self, image: Union[str, bytes, "PIL.Image.Image"]
) -> np.ndarray:
"""
Generate the embedding for a single image
Parameters
----------
image : Union[str, bytes, PIL.Image.Image]
The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes.
"""
PIL = attempt_import_or_raise("PIL", "pillow")
@staticmethod
def _generate_image_input_dict(image: Union[str, bytes, "PIL.Image.Image"]) -> Dict:
if isinstance(image, bytes):
image = {"image": base64.b64encode(image).decode("utf-8")}
if isinstance(image, PIL.Image.Image):
buffered = io.BytesIO()
image.save(buffered, format="PNG")
image_bytes = buffered.getvalue()
image = {"image": base64.b64encode(image_bytes).decode("utf-8")}
elif isinstance(image, str):
image_dict = {"image": base64.b64encode(image).decode("utf-8")}
elif isinstance(image, (str, Path)):
parsed = urlparse.urlparse(image)
# TODO handle drive letter on windows.
PIL = attempt_import_or_raise("PIL", "pillow")
if parsed.scheme == "file":
pil_image = PIL.Image.open(parsed.path)
elif parsed.scheme == "":
@@ -130,8 +107,95 @@ class JinaEmbeddings(EmbeddingFunction):
buffered = io.BytesIO()
pil_image.save(buffered, format="PNG")
image_bytes = buffered.getvalue()
image = {"image": base64.b64encode(image_bytes).decode("utf-8")}
return self._generate_embeddings(input=[image])[0]
image_dict = {"image": base64.b64encode(image_bytes).decode("utf-8")}
else:
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(image, PIL.Image.Image):
buffered = io.BytesIO()
image.save(buffered, format="PNG")
image_bytes = buffered.getvalue()
image_dict = {"image": base64.b64encode(image_bytes).decode("utf-8")}
else:
raise TypeError(
f"JinaEmbeddingFunction supports str, Path, bytes or PIL Image"
f" as query, but {type(image)} is given"
)
return image_dict
def compute_query_embeddings(
self, query: Union[str, bytes, "Path", "PIL.Image.Image"], *args, **kwargs
) -> List[np.ndarray]:
"""
Compute the embeddings for a given user query
Parameters
----------
query : Union[str, PIL.Image.Image]
The query to embed. A query can be either text or an image.
"""
if isinstance(query, str):
if not is_valid_url(query):
return self.generate_text_embeddings([query])
else:
return [self.generate_image_embedding(query)]
elif isinstance(query, (Path, bytes)):
return [self.generate_image_embedding(query)]
else:
PIL = attempt_import_or_raise("PIL", "pillow")
if isinstance(query, PIL.Image.Image):
return [self.generate_image_embedding(query)]
else:
raise TypeError(
f"JinaEmbeddingFunction supports str, Path, bytes or PIL Image"
f" as query, but {type(query)} is given"
)
def compute_source_embeddings(
self, inputs: Union[TEXT, IMAGES], *args, **kwargs
) -> List[np.array]:
inputs = self.sanitize_input(inputs)
model_inputs = []
image_inputs = 0
def process_input(input, model_inputs, image_inputs):
if isinstance(input, str):
if not is_valid_url(input):
model_inputs.append({"text": input})
else:
image_inputs += 1
model_inputs.append(self._generate_image_input_dict(input))
elif isinstance(input, list):
for _input in input:
image_inputs = process_input(_input, model_inputs, image_inputs)
else:
image_inputs += 1
model_inputs.append(self._generate_image_input_dict(input))
return image_inputs
for input in inputs:
image_inputs = process_input(input, model_inputs, image_inputs)
if image_inputs > 0:
return self._generate_embeddings(model_inputs)
else:
return self.generate_text_embeddings(inputs)
def generate_image_embedding(
self, image: Union[str, bytes, Path, "PIL.Image.Image"]
) -> np.ndarray:
"""
Generate the embedding for a single image
Parameters
----------
image : Union[str, bytes, PIL.Image.Image]
The image to embed. If the image is a str, it is treated as a uri.
If the image is bytes, it is treated as the raw image bytes.
"""
image_dict = self._generate_image_input_dict(image)
return self._generate_embeddings(input=[image_dict])[0]
def generate_text_embeddings(
self, texts: Union[List[str], np.ndarray], *args, **kwargs