Add ollama embeddings function (#1263)

Following the docs
[here](https://lancedb.github.io/lancedb/python/python/#lancedb.embeddings.openai.OpenAIEmbeddings),
I've been trying to use Ollama embeddings through the OpenAI API,
but unfortunately I couldn't get it to work (possibly related to
https://github.com/ollama/ollama/issues/2416).

Given the popularity of ollama I thought it could be helpful to have a
dedicated Ollama Embedding function in lancedb.

I'd very much welcome any thoughts on this or on my code. Thanks!
This commit is contained in:
asmith26
2024-05-13 08:39:19 +01:00
committed by GitHub
parent b37c58342e
commit 3850d5fb35
5 changed files with 112 additions and 1 deletions

View File

@@ -80,6 +80,7 @@ embeddings = [
"boto3>=1.28.57",
"awscli>=1.29.57",
"botocore>=1.31.57",
"ollama",
]
azure = ["adlfs>=2024.2.0"]

View File

@@ -16,6 +16,7 @@ from .bedrock import BedRockText
from .cohere import CohereEmbeddingFunction
from .gemini_text import GeminiText
from .instructor import InstructorEmbeddingFunction
from .ollama import OllamaEmbeddings
from .open_clip import OpenClipEmbeddings
from .openai import OpenAIEmbeddings
from .registry import EmbeddingFunctionRegistry, get_registry

View File

@@ -0,0 +1,69 @@
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import cached_property
from typing import TYPE_CHECKING, List, Optional, Union
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
if TYPE_CHECKING:
import numpy as np
@register("ollama")
class OllamaEmbeddings(TextEmbeddingFunction):
    """
    An embedding function that uses Ollama

    https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings
    https://ollama.com/blog/embedding-models

    Every text is embedded with its own call to the Ollama client's
    ``embeddings`` method. The client itself is created lazily, the first
    time it is needed.

    Parameters
    ----------
    name: str
        Model name, passed to the client as ``model``.
    host: str
        Passed as ``host`` when constructing ``ollama.Client``.
    options: dict, optional
        Forwarded unchanged as ``options`` on every embeddings call.
    keep_alive: float or str, optional
        Forwarded unchanged as ``keep_alive`` on every embeddings call.
    ollama_client_kwargs: dict, optional
        Extra keyword arguments for ``ollama.Client``. ``None`` is treated
        the same as an empty dict.
    """

    name: str = "nomic-embed-text"
    host: str = "http://localhost:11434"
    options: Optional[dict] = None  # type = ollama.Options
    keep_alive: Optional[Union[float, str]] = None
    # NOTE(review): a literal ``{}`` default is only safe if the base model
    # copies field defaults per instance -- confirm against the base class.
    ollama_client_kwargs: Optional[dict] = {}

    def ndims(self):
        """
        Return the length of the embedding produced for a probe text.

        This embeds the string "foo" on every call, so it goes through the
        Ollama client each time it is invoked.
        """
        return len(self.generate_embeddings(["foo"])[0])

    def _compute_embedding(self, text):
        """
        Embed a single text and return the "embedding" entry of the response.
        """
        response = self._ollama_client.embeddings(
            model=self.name,
            prompt=text,
            options=self.options,
            keep_alive=self.keep_alive,
        )
        return response["embedding"]

    def generate_embeddings(
        self, texts: Union[List[str], "np.ndarray"]
    ) -> List["np.array"]:
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed

        Returns
        -------
        list
            One embedding per input text, in input order. An empty input
            yields an empty list.
        """
        # TODO retry, rate limit, token limit
        return [self._compute_embedding(text) for text in texts]

    @cached_property
    def _ollama_client(self):
        """
        Build the Ollama client on first access and reuse it afterwards.
        """
        ollama = attempt_import_or_raise("ollama")
        # ToDo explore ollama.AsyncClient
        # The field is declared Optional, so an explicit None must not reach
        # the ``**`` unpacking below (it would raise TypeError).
        client_kwargs = self.ollama_client_kwargs or {}
        return ollama.Client(host=self.host, **client_kwargs)

View File

@@ -45,7 +45,9 @@ except Exception:
@pytest.mark.slow
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai", "huggingface"])
@pytest.mark.parametrize(
"alias", ["sentence-transformers", "openai", "huggingface", "ollama"]
)
def test_basic_text_embeddings(alias, tmp_path):
db = lancedb.connect(tmp_path)
registry = get_registry()