From cef24801f44c211c9ad87b8052ca1bc9b0d3586d Mon Sep 17 00:00:00 2001 From: Joan Fontanals Date: Tue, 9 Jul 2024 11:09:35 +0200 Subject: [PATCH] docs: add jina reranker to index (#1427) PR to add JinaReranker documentation page to the rerankers index --- docs/mkdocs.yml | 2 + .../embeddings/default_embedding_functions.md | 90 ++++++++++ docs/src/reranking/index.md | 2 +- python/python/lancedb/embeddings/jinaai.py | 168 ++++++++++++------ 4 files changed, 209 insertions(+), 53 deletions(-) diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 54f1ac40..0413704b 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -102,6 +102,7 @@ nav: - Linear Combination Reranker: reranking/linear_combination.md - Cross Encoder Reranker: reranking/cross_encoder.md - ColBERT Reranker: reranking/colbert.md + - Jina Reranker: reranking/jina.md - OpenAI Reranker: reranking/openai.md - Building Custom Rerankers: reranking/custom_reranker.md - Filtering: sql.md @@ -184,6 +185,7 @@ nav: - Linear Combination Reranker: reranking/linear_combination.md - Cross Encoder Reranker: reranking/cross_encoder.md - ColBERT Reranker: reranking/colbert.md + - Jina Reranker: reranking/jina.md - OpenAI Reranker: reranking/openai.md - Building Custom Rerankers: reranking/custom_reranker.md - Filtering: sql.md diff --git a/docs/src/embeddings/default_embedding_functions.md b/docs/src/embeddings/default_embedding_functions.md index e6a09c1f..ae026acf 100644 --- a/docs/src/embeddings/default_embedding_functions.md +++ b/docs/src/embeddings/default_embedding_functions.md @@ -427,6 +427,45 @@ Usage Example: tbl.add(data) ``` +### Jina Embeddings +Jina embeddings are used to generate embeddings for text and image data. +You also need to set the `JINA_API_KEY` environment variable to use the Jina API. + +You can find a list of supported models under [https://jina.ai/embeddings/](https://jina.ai/embeddings/) + +Supported parameters (to be passed in `create` method) are: + +| Parameter | Type | Default Value | Description | +|---|---|---|---| +| `name` | `str` | `"jina-clip-v1"` | The model ID of the jina model to use | + +Usage Example: + +```python + import os + import lancedb + from lancedb.pydantic import LanceModel, Vector + from lancedb.embeddings import EmbeddingFunctionRegistry + + os.environ['JINA_API_KEY'] = 'jina_*' + + jina_embed = EmbeddingFunctionRegistry.get_instance().get("jina").create(name="jina-embeddings-v2-base-en") + + + class TextModel(LanceModel): + text: str = jina_embed.SourceField() + vector: Vector(jina_embed.ndims()) = jina_embed.VectorField() + + + data = [{"text": "hello world"}, + {"text": "goodbye world"}] + + db = lancedb.connect("~/.lancedb-2") + tbl = db.create_table("test", schema=TextModel, mode="overwrite") + + tbl.add(data) +``` + ### AWS Bedrock Text Embedding Functions AWS Bedrock supports multiple base models for generating text embeddings. You need to setup the AWS credentials to use this embedding function. You can do so by using `awscli` and also add your session_token: @@ -630,3 +669,54 @@ print(actual.text == "bird") ``` If you have any questions about the embeddings API, supported models, or see a relevant model missing, please raise an issue [on GitHub](https://github.com/lancedb/lancedb/issues). + +### Jina Embeddings +Jina embeddings can also be used to embed both text and image data, only some of the models support image data and you can check the list +under [https://jina.ai/embeddings/](https://jina.ai/embeddings/) + +Supported parameters (to be passed in `create` method) are: + +| Parameter | Type | Default Value | Description | +|---|---|---|---| +| `name` | `str` | `"jina-clip-v1"` | The model ID of the jina model to use | + +Usage Example: + +```python + import os + import requests + import lancedb + from lancedb.pydantic import LanceModel, Vector + from lancedb.embeddings import get_registry + import pandas as pd + + os.environ['JINA_API_KEY'] = 'jina_*' + + db = lancedb.connect("~/.lancedb") + func = get_registry().get("jina").create() + + + class Images(LanceModel): + label: str + image_uri: str = func.SourceField() # image uri as the source + image_bytes: bytes = func.SourceField() # image bytes as the source + vector: Vector(func.ndims()) = func.VectorField() # vector column + vec_from_bytes: Vector(func.ndims()) = func.VectorField() # Another vector column + + + table = db.create_table("images", schema=Images) + labels = ["cat", "cat", "dog", "dog", "horse", "horse"] + uris = [ + "http://farm1.staticflickr.com/53/167798175_7c7845bbbd_z.jpg", + "http://farm1.staticflickr.com/134/332220238_da527d8140_z.jpg", + "http://farm9.staticflickr.com/8387/8602747737_2e5c2a45d4_z.jpg", + "http://farm5.staticflickr.com/4092/5017326486_1f46057f5f_z.jpg", + "http://farm9.staticflickr.com/8216/8434969557_d37882c42d_z.jpg", + "http://farm6.staticflickr.com/5142/5835678453_4f3a4edb45_z.jpg", + ] + # get each uri as bytes + image_bytes = [requests.get(uri).content for uri in uris] + table.add( + pd.DataFrame({"label": labels, "image_uri": uris, "image_bytes": image_bytes}) + ) +``` \ No newline at end of file diff --git a/docs/src/reranking/index.md b/docs/src/reranking/index.md index 20199524..d2a15d6b 100644 --- a/docs/src/reranking/index.md +++ b/docs/src/reranking/index.md @@ -15,7 +15,6 @@ LanceDB comes with some built-in rerankers. Some of the rerankers that are avail Using rerankers is optional for vector and FTS. However, for hybrid search, rerankers are required. To use a reranker, you need to create an instance of the reranker and pass it to the `rerank` method of the query builder. ```python -import numpy import lancedb from lancedb.embeddings import get_registry from lancedb.pydantic import LanceModel, Vector @@ -54,6 +53,7 @@ LanceDB comes with some built-in rerankers. Here are some of the rerankers that - [ColBERT Reranker](./colbert.md) - [OpenAI Reranker](./openai.md) - [Linear Combination Reranker](./linear_combination.md) +- [Jina Reranker](./jina.md) ## Creating Custom Rerankers diff --git a/python/python/lancedb/embeddings/jinaai.py b/python/python/lancedb/embeddings/jinaai.py index 8f6d2369..6619627d 100644 --- a/python/python/lancedb/embeddings/jinaai.py +++ b/python/python/lancedb/embeddings/jinaai.py @@ -15,8 +15,9 @@ import os import io import requests import base64 -import urllib.parse as urlparse -from typing import ClassVar, List, Union, Optional, TYPE_CHECKING +from urllib.parse import urlparse +from pathlib import Path +from typing import TYPE_CHECKING, ClassVar, List, Union, Optional, Any, Dict import numpy as np import pyarrow as pa @@ -32,6 +33,14 @@ if TYPE_CHECKING: API_URL = "https://api.jina.ai/v1/embeddings" +def is_valid_url(text): + try: + parsed = urlparse(text) + return bool(parsed.scheme) and bool(parsed.netloc) + except Exception: + return False + + @register("jina") class JinaEmbeddings(EmbeddingFunction): """ @@ -58,67 +67,35 @@ class JinaEmbeddings(EmbeddingFunction): # TODO: fix hardcoding return 768 - def sanitize_input(self, inputs: IMAGES) -> Union[List[bytes], np.ndarray]: + def sanitize_input( + self, inputs: Union[TEXT, IMAGES] + ) -> Union[List[Any], np.ndarray]: """ Sanitize the input to the embedding function. """ - if isinstance(inputs, (str, bytes)): + if isinstance(inputs, (str, bytes, Path)): inputs = [inputs] elif isinstance(inputs, pa.Array): inputs = inputs.to_pylist() elif isinstance(inputs, pa.ChunkedArray): inputs = inputs.combine_chunks().to_pylist() + else: + if isinstance(inputs, list): + inputs = inputs + else: + PIL = attempt_import_or_raise("PIL", "pillow") + if isinstance(inputs, PIL.Image.Image): + inputs = [inputs] return inputs - def compute_query_embeddings( - self, query: Union[str, "PIL.Image.Image"], *args, **kwargs - ) -> List[np.ndarray]: - """ - Compute the embeddings for a given user query - - Parameters - ---------- - query : Union[str, PIL.Image.Image] - The query to embed. A query can be either text or an image. - """ - if isinstance(query, str): - return self.generate_text_embeddings([query]) - else: - PIL = attempt_import_or_raise("PIL", "pillow") - if isinstance(query, PIL.Image.Image): - return [self.generate_image_embedding(query)] - else: - raise TypeError( - "JinaEmbeddingFunction supports str or PIL Image as query" - ) - - def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]: - texts = self.sanitize_input(texts) - return self.generate_text_embeddings(texts) - - def generate_image_embedding( - self, image: Union[str, bytes, "PIL.Image.Image"] - ) -> np.ndarray: - """ - Generate the embedding for a single image - - Parameters - ---------- - image : Union[str, bytes, PIL.Image.Image] - The image to embed. If the image is a str, it is treated as a uri. - If the image is bytes, it is treated as the raw image bytes. - """ - PIL = attempt_import_or_raise("PIL", "pillow") + @staticmethod + def _generate_image_input_dict(image: Union[str, bytes, "PIL.Image.Image"]) -> Dict: if isinstance(image, bytes): - image = {"image": base64.b64encode(image).decode("utf-8")} - if isinstance(image, PIL.Image.Image): - buffered = io.BytesIO() - image.save(buffered, format="PNG") - image_bytes = buffered.getvalue() - image = {"image": base64.b64encode(image_bytes).decode("utf-8")} - elif isinstance(image, str): + image_dict = {"image": base64.b64encode(image).decode("utf-8")} + elif isinstance(image, (str, Path)): parsed = urlparse.urlparse(image) # TODO handle drive letter on windows. + PIL = attempt_import_or_raise("PIL", "pillow") if parsed.scheme == "file": pil_image = PIL.Image.open(parsed.path) elif parsed.scheme == "": @@ -130,8 +107,95 @@ class JinaEmbeddings(EmbeddingFunction): buffered = io.BytesIO() pil_image.save(buffered, format="PNG") image_bytes = buffered.getvalue() - image = {"image": base64.b64encode(image_bytes).decode("utf-8")} - return self._generate_embeddings(input=[image])[0] + image_dict = {"image": base64.b64encode(image_bytes).decode("utf-8")} + else: + PIL = attempt_import_or_raise("PIL", "pillow") + + if isinstance(image, PIL.Image.Image): + buffered = io.BytesIO() + image.save(buffered, format="PNG") + image_bytes = buffered.getvalue() + image_dict = {"image": base64.b64encode(image_bytes).decode("utf-8")} + else: + raise TypeError( + f"JinaEmbeddingFunction supports str, Path, bytes or PIL Image" + f" as query, but {type(image)} is given" + ) + return image_dict + + def compute_query_embeddings( + self, query: Union[str, bytes, "Path", "PIL.Image.Image"], *args, **kwargs + ) -> List[np.ndarray]: + """ + Compute the embeddings for a given user query + + Parameters + ---------- + query : Union[str, PIL.Image.Image] + The query to embed. A query can be either text or an image. + """ + if isinstance(query, str): + if not is_valid_url(query): + return self.generate_text_embeddings([query]) + else: + return [self.generate_image_embedding(query)] + elif isinstance(query, (Path, bytes)): + return [self.generate_image_embedding(query)] + else: + PIL = attempt_import_or_raise("PIL", "pillow") + + if isinstance(query, PIL.Image.Image): + return [self.generate_image_embedding(query)] + else: + raise TypeError( + f"JinaEmbeddingFunction supports str, Path, bytes or PIL Image" + f" as query, but {type(query)} is given" + ) + + def compute_source_embeddings( + self, inputs: Union[TEXT, IMAGES], *args, **kwargs + ) -> List[np.array]: + inputs = self.sanitize_input(inputs) + model_inputs = [] + image_inputs = 0 + + def process_input(input, model_inputs, image_inputs): + if isinstance(input, str): + if not is_valid_url(input): + model_inputs.append({"text": input}) + else: + image_inputs += 1 + model_inputs.append(self._generate_image_input_dict(input)) + elif isinstance(input, list): + for _input in input: + image_inputs = process_input(_input, model_inputs, image_inputs) + else: + image_inputs += 1 + model_inputs.append(self._generate_image_input_dict(input)) + return image_inputs + + for input in inputs: + image_inputs = process_input(input, model_inputs, image_inputs) + + if image_inputs > 0: + return self._generate_embeddings(model_inputs) + else: + return self.generate_text_embeddings(inputs) + + def generate_image_embedding( + self, image: Union[str, bytes, Path, "PIL.Image.Image"] + ) -> np.ndarray: + """ + Generate the embedding for a single image + + Parameters + ---------- + image : Union[str, bytes, PIL.Image.Image] + The image to embed. If the image is a str, it is treated as a uri. + If the image is bytes, it is treated as the raw image bytes. + """ + image_dict = self._generate_image_input_dict(image) + return self._generate_embeddings(input=[image_dict])[0] def generate_text_embeddings( self, texts: Union[List[str], np.ndarray], *args, **kwargs