From 3850d5fb3554d04f9045de0148a766380e8c1866 Mon Sep 17 00:00:00 2001 From: asmith26 Date: Mon, 13 May 2024 08:39:19 +0100 Subject: [PATCH] Add ollama embeddings function (#1263) Following the docs [here](https://lancedb.github.io/lancedb/python/python/#lancedb.embeddings.openai.OpenAIEmbeddings) I've been trying to use ollama embedding via the OpenAI API interface, but unfortunately I couldn't get it to work (possibly related to https://github.com/ollama/ollama/issues/2416) Given the popularity of ollama I thought it could be helpful to have a dedicated Ollama Embedding function in lancedb. Very much welcome any thought on this or my code etc. Thanks! --- .../embeddings/default_embedding_functions.md | 38 ++++++++++ python/pyproject.toml | 1 + python/python/lancedb/embeddings/__init__.py | 1 + python/python/lancedb/embeddings/ollama.py | 69 +++++++++++++++++++ python/python/tests/test_embeddings_slow.py | 4 +- 5 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 python/python/lancedb/embeddings/ollama.py diff --git a/docs/src/embeddings/default_embedding_functions.md b/docs/src/embeddings/default_embedding_functions.md index be05d9fa..06b343b5 100644 --- a/docs/src/embeddings/default_embedding_functions.md +++ b/docs/src/embeddings/default_embedding_functions.md @@ -206,6 +206,44 @@ print(actual.text) ``` +### Ollama embeddings +Generate embeddings via the [ollama](https://github.com/ollama/ollama-python) python library. 
More details: + +- [Ollama docs on embeddings](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings) +- [Ollama blog on embeddings](https://ollama.com/blog/embedding-models) + +| Parameter              | Type                       | Default Value            | Description                                                                                                                                    | |------------------------|----------------------------|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------| +| `name`                 | `str`                      | `nomic-embed-text`       | The name of the model.                                                                                                                         | +| `host`                 | `str`                      | `http://localhost:11434` | The Ollama host to connect to.                                                                                                                 | +| `options`              | `ollama.Options` or `dict` | `None`                   | Additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`. | +| `keep_alive`           | `float` or `str`           | `None`                   | Controls how long the model will stay loaded into memory following the request (Ollama's server-side default of `"5m"` applies when unset).    | +| `ollama_client_kwargs` | `dict`                     | `{}`                     | kwargs that can be passed to the `ollama.Client`.                                                                                              | + +```python +import lancedb +from lancedb.pydantic import LanceModel, Vector +from lancedb.embeddings import get_registry + +db = lancedb.connect("/tmp/db") +func = get_registry().get("ollama").create(name="nomic-embed-text") + +class Words(LanceModel): +    text: str = func.SourceField() +    vector: Vector(func.ndims()) = func.VectorField() + +table = db.create_table("words", schema=Words, mode="overwrite") +table.add([ +    {"text": "hello world"}, +    {"text": "goodbye world"} +]) + +query = "greetings" +actual = table.search(query).limit(1).to_pydantic(Words)[0] +print(actual.text) +``` + + ### OpenAI embeddings LanceDB registers the OpenAI embeddings function in the registry by default, as `openai`. 
Below are the parameters that you can customize when creating the instances: diff --git a/python/pyproject.toml b/python/pyproject.toml index 285d09dd..6ad974c2 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -80,6 +80,7 @@ embeddings = [ "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57", + "ollama", ] azure = ["adlfs>=2024.2.0"] diff --git a/python/python/lancedb/embeddings/__init__.py b/python/python/lancedb/embeddings/__init__.py index 1b88dd24..9567d071 100644 --- a/python/python/lancedb/embeddings/__init__.py +++ b/python/python/lancedb/embeddings/__init__.py @@ -16,6 +16,7 @@ from .bedrock import BedRockText from .cohere import CohereEmbeddingFunction from .gemini_text import GeminiText from .instructor import InstructorEmbeddingFunction +from .ollama import OllamaEmbeddings from .open_clip import OpenClipEmbeddings from .openai import OpenAIEmbeddings from .registry import EmbeddingFunctionRegistry, get_registry diff --git a/python/python/lancedb/embeddings/ollama.py b/python/python/lancedb/embeddings/ollama.py new file mode 100644 index 00000000..6e1be917 --- /dev/null +++ b/python/python/lancedb/embeddings/ollama.py @@ -0,0 +1,69 @@ +# Copyright (c) 2023. LanceDB Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from functools import cached_property +from typing import TYPE_CHECKING, List, Optional, Union + +from ..util import attempt_import_or_raise +from .base import TextEmbeddingFunction +from .registry import register + +if TYPE_CHECKING: + import numpy as np + + +@register("ollama") +class OllamaEmbeddings(TextEmbeddingFunction): + """ + An embedding function that uses Ollama + + https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings + https://ollama.com/blog/embedding-models + """ + + name: str = "nomic-embed-text" + host: str = "http://localhost:11434" + options: Optional[dict] = None # type = ollama.Options + keep_alive: Optional[Union[float, str]] = None + ollama_client_kwargs: Optional[dict] = {} + + def ndims(self): + return len(self.generate_embeddings(["foo"])[0]) + + def _compute_embedding(self, text): + return self._ollama_client.embeddings( + model=self.name, + prompt=text, + options=self.options, + keep_alive=self.keep_alive, + )["embedding"] + + def generate_embeddings( + self, texts: Union[List[str], "np.ndarray"] + ) -> List["np.array"]: + """ + Get the embeddings for the given texts + + Parameters + ---------- + texts: list[str] or np.ndarray (of str) + The texts to embed + """ + # TODO retry, rate limit, token limit + embeddings = [self._compute_embedding(text) for text in texts] + return embeddings + + @cached_property + def _ollama_client(self): + ollama = attempt_import_or_raise("ollama") + # ToDo explore ollama.AsyncClient + return ollama.Client(host=self.host, **self.ollama_client_kwargs) diff --git a/python/python/tests/test_embeddings_slow.py b/python/python/tests/test_embeddings_slow.py index 14139307..5be77ed7 100644 --- a/python/python/tests/test_embeddings_slow.py +++ b/python/python/tests/test_embeddings_slow.py @@ -45,7 +45,9 @@ except Exception: @pytest.mark.slow -@pytest.mark.parametrize("alias", ["sentence-transformers", "openai", "huggingface"]) +@pytest.mark.parametrize( + "alias", ["sentence-transformers", 
"openai", "huggingface", "ollama"] +) def test_basic_text_embeddings(alias, tmp_path): db = lancedb.connect(tmp_path) registry = get_registry()