mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-08 12:52:58 +00:00
Add ollama embeddings function (#1263)
Following the docs [here](https://lancedb.github.io/lancedb/python/python/#lancedb.embeddings.openai.OpenAIEmbeddings) I've been trying to use ollama embedding via the OpenAI API interface, but unfortunately I couldn't get it to work (possibly related to https://github.com/ollama/ollama/issues/2416) Given the popularity of ollama I thought it could be helpful to have a dedicated Ollama Embedding function in lancedb. Very much welcome any thought on this or my code etc. Thanks!
This commit is contained in:
@@ -80,6 +80,7 @@ embeddings = [
|
||||
"boto3>=1.28.57",
|
||||
"awscli>=1.29.57",
|
||||
"botocore>=1.31.57",
|
||||
"ollama",
|
||||
]
|
||||
azure = ["adlfs>=2024.2.0"]
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from .bedrock import BedRockText
|
||||
from .cohere import CohereEmbeddingFunction
|
||||
from .gemini_text import GeminiText
|
||||
from .instructor import InstructorEmbeddingFunction
|
||||
from .ollama import OllamaEmbeddings
|
||||
from .open_clip import OpenClipEmbeddings
|
||||
from .openai import OpenAIEmbeddings
|
||||
from .registry import EmbeddingFunctionRegistry, get_registry
|
||||
|
||||
69
python/python/lancedb/embeddings/ollama.py
Normal file
69
python/python/lancedb/embeddings/ollama.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
|
||||
|
||||
@register("ollama")
|
||||
class OllamaEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
An embedding function that uses Ollama
|
||||
|
||||
https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings
|
||||
https://ollama.com/blog/embedding-models
|
||||
"""
|
||||
|
||||
name: str = "nomic-embed-text"
|
||||
host: str = "http://localhost:11434"
|
||||
options: Optional[dict] = None # type = ollama.Options
|
||||
keep_alive: Optional[Union[float, str]] = None
|
||||
ollama_client_kwargs: Optional[dict] = {}
|
||||
|
||||
def ndims(self):
|
||||
return len(self.generate_embeddings(["foo"])[0])
|
||||
|
||||
def _compute_embedding(self, text):
|
||||
return self._ollama_client.embeddings(
|
||||
model=self.name,
|
||||
prompt=text,
|
||||
options=self.options,
|
||||
keep_alive=self.keep_alive,
|
||||
)["embedding"]
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], "np.ndarray"]
|
||||
) -> List["np.array"]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
Parameters
|
||||
----------
|
||||
texts: list[str] or np.ndarray (of str)
|
||||
The texts to embed
|
||||
"""
|
||||
# TODO retry, rate limit, token limit
|
||||
embeddings = [self._compute_embedding(text) for text in texts]
|
||||
return embeddings
|
||||
|
||||
@cached_property
|
||||
def _ollama_client(self):
|
||||
ollama = attempt_import_or_raise("ollama")
|
||||
# ToDo explore ollama.AsyncClient
|
||||
return ollama.Client(host=self.host, **self.ollama_client_kwargs)
|
||||
@@ -45,7 +45,9 @@ except Exception:
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai", "huggingface"])
|
||||
@pytest.mark.parametrize(
|
||||
"alias", ["sentence-transformers", "openai", "huggingface", "ollama"]
|
||||
)
|
||||
def test_basic_text_embeddings(alias, tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = get_registry()
|
||||
|
||||
Reference in New Issue
Block a user