mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-26 16:30:41 +00:00
Add ollama embeddings function (#1263)
Following the docs [here](https://lancedb.github.io/lancedb/python/python/#lancedb.embeddings.openai.OpenAIEmbeddings) I've been trying to use ollama embedding via the OpenAI API interface, but unfortunately I couldn't get it to work (possibly related to https://github.com/ollama/ollama/issues/2416) Given the popularity of ollama I thought it could be helpful to have a dedicated Ollama Embedding function in lancedb. Very much welcome any thought on this or my code etc. Thanks!
This commit is contained in:
@@ -206,6 +206,44 @@ print(actual.text)
|
||||
```
|
||||
|
||||
|
||||
### Ollama embeddings
|
||||
Generate embeddings via the [ollama](https://github.com/ollama/ollama-python) python library. More details:
|
||||
|
||||
- [Ollama docs on embeddings](https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings)
|
||||
- [Ollama blog on embeddings](https://ollama.com/blog/embedding-models)
|
||||
|
||||
| Parameter | Type | Default Value | Description |
|
||||
|------------------------|----------------------------|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `name` | `str` | `nomic-embed-text` | The name of the model. |
|
||||
| `host` | `str` | `http://localhost:11434` | The Ollama host to connect to. |
|
||||
| `options` | `ollama.Options` or `dict` | `None` | Additional model parameters listed in the documentation for the [Modelfile](https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values) such as `temperature`. |
|
||||
| `keep_alive` | `float` or `str` | `None` | Controls how long the model will stay loaded into memory following the request. When left as `None`, the Ollama server's own default applies (documented by Ollama as `"5m"`). |
|
||||
| `ollama_client_kwargs` | `dict` | `{}` | kwargs that can be passed to the `ollama.Client`. |
|
||||
|
||||
```python
|
||||
import lancedb
|
||||
from lancedb.pydantic import LanceModel, Vector
|
||||
from lancedb.embeddings import get_registry
|
||||
|
||||
db = lancedb.connect("/tmp/db")
|
||||
func = get_registry().get("ollama").create(name="nomic-embed-text")
|
||||
|
||||
class Words(LanceModel):
|
||||
text: str = func.SourceField()
|
||||
vector: Vector(func.ndims()) = func.VectorField()
|
||||
|
||||
table = db.create_table("words", schema=Words, mode="overwrite")
|
||||
table.add([
|
||||
{"text": "hello world"},
|
||||
{"text": "goodbye world"}
|
||||
])
|
||||
|
||||
query = "greetings"
|
||||
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
||||
print(actual.text)
|
||||
```
|
||||
|
||||
|
||||
### OpenAI embeddings
|
||||
LanceDB registers the OpenAI embeddings function in the registry by default, as `openai`. Below are the parameters that you can customize when creating the instances:
|
||||
|
||||
|
||||
@@ -80,6 +80,7 @@ embeddings = [
|
||||
"boto3>=1.28.57",
|
||||
"awscli>=1.29.57",
|
||||
"botocore>=1.31.57",
|
||||
"ollama",
|
||||
]
|
||||
azure = ["adlfs>=2024.2.0"]
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from .bedrock import BedRockText
|
||||
from .cohere import CohereEmbeddingFunction
|
||||
from .gemini_text import GeminiText
|
||||
from .instructor import InstructorEmbeddingFunction
|
||||
from .ollama import OllamaEmbeddings
|
||||
from .open_clip import OpenClipEmbeddings
|
||||
from .openai import OpenAIEmbeddings
|
||||
from .registry import EmbeddingFunctionRegistry, get_registry
|
||||
|
||||
69
python/python/lancedb/embeddings/ollama.py
Normal file
69
python/python/lancedb/embeddings/ollama.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# Copyright (c) 2023. LanceDB Developers
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
|
||||
|
||||
@register("ollama")
class OllamaEmbeddings(TextEmbeddingFunction):
    """
    An embedding function that uses Ollama

    https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings
    https://ollama.com/blog/embedding-models

    Parameters
    ----------
    name: str, default "nomic-embed-text"
        The model name forwarded to the Ollama client as ``model``.
    host: str, default "http://localhost:11434"
        The host handed to ``ollama.Client``.
    options: dict, optional
        Forwarded unchanged to the client's ``embeddings`` call
        (an ``ollama.Options`` mapping).
    keep_alive: float or str, optional
        Forwarded unchanged to the client's ``embeddings`` call.
    ollama_client_kwargs: dict, optional
        Extra keyword arguments for the ``ollama.Client`` constructor.
        ``None`` is treated the same as an empty dict.
    """

    name: str = "nomic-embed-text"
    host: str = "http://localhost:11434"
    options: Optional[dict] = None  # type = ollama.Options
    keep_alive: Optional[Union[float, str]] = None
    ollama_client_kwargs: Optional[dict] = {}

    def ndims(self):
        """
        Return the dimensionality of the vectors produced by the model.

        The size is discovered by embedding a short probe string and measuring
        the result, so every call goes through the Ollama client once.
        """
        return len(self.generate_embeddings(["foo"])[0])

    def _compute_embedding(self, text):
        """Embed a single text and return the client's ``"embedding"`` entry."""
        return self._ollama_client.embeddings(
            model=self.name,
            prompt=text,
            options=self.options,
            keep_alive=self.keep_alive,
        )["embedding"]

    def generate_embeddings(
        self, texts: Union[List[str], "np.ndarray"]
    ) -> List["np.array"]:
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed

        Returns
        -------
        list
            One embedding per input text, in input order. The texts are
            embedded one at a time, i.e. one client call per text.
        """
        # TODO retry, rate limit, token limit
        embeddings = [self._compute_embedding(text) for text in texts]
        return embeddings

    @cached_property
    def _ollama_client(self):
        """
        The ``ollama.Client`` used for all requests.

        The ``ollama`` package is imported on first access rather than at
        module import time, and the client is built once and then cached.
        """
        ollama = attempt_import_or_raise("ollama")
        # ToDo explore ollama.AsyncClient
        # The field is annotated Optional, so tolerate an explicit None:
        # unpacking None with ** would raise TypeError.
        client_kwargs = self.ollama_client_kwargs or {}
        return ollama.Client(host=self.host, **client_kwargs)
|
||||
@@ -45,7 +45,9 @@ except Exception:
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("alias", ["sentence-transformers", "openai", "huggingface"])
|
||||
@pytest.mark.parametrize(
|
||||
"alias", ["sentence-transformers", "openai", "huggingface", "ollama"]
|
||||
)
|
||||
def test_basic_text_embeddings(alias, tmp_path):
|
||||
db = lancedb.connect(tmp_path)
|
||||
registry = get_registry()
|
||||
|
||||
Reference in New Issue
Block a user