mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-27 07:09:57 +00:00
feat(python): batch Ollama embed calls (#2453)
Other embedding integrations such as Cohere and OpenAI already send
requests in batches. We should do that for Ollama too to improve
throughput.
The Ollama [`.embed`
API](63ca747622/ollama/_client.py (L359-L378))
was added in version 0.3.0 (almost a year ago) so I updated the version
requirement in pyproject.
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit
- **Bug Fixes**
- Improved compatibility with newer versions of the "ollama" package by
requiring version 0.3.0 or higher.
- Enhanced embedding generation to process batches of texts more
efficiently and reliably.
- **Refactor**
- Improved type consistency and clarity for embedding-related methods.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -2,14 +2,15 @@
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, List, Optional, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import TextEmbeddingFunction
|
||||
from .registry import register
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import ollama
|
||||
|
||||
|
||||
@@ -28,23 +29,21 @@ class OllamaEmbeddings(TextEmbeddingFunction):
|
||||
keep_alive: Optional[Union[float, str]] = None
|
||||
ollama_client_kwargs: Optional[dict] = {}
|
||||
|
||||
def ndims(self):
|
||||
def ndims(self) -> int:
|
||||
return len(self.generate_embeddings(["foo"])[0])
|
||||
|
||||
def _compute_embedding(self, text) -> Union["np.array", None]:
|
||||
return (
|
||||
self._ollama_client.embeddings(
|
||||
model=self.name,
|
||||
prompt=text,
|
||||
options=self.options,
|
||||
keep_alive=self.keep_alive,
|
||||
)["embedding"]
|
||||
or None
|
||||
def _compute_embedding(self, text: Sequence[str]) -> Sequence[Sequence[float]]:
|
||||
response = self._ollama_client.embed(
|
||||
model=self.name,
|
||||
input=text,
|
||||
options=self.options,
|
||||
keep_alive=self.keep_alive,
|
||||
)
|
||||
return response.embeddings
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], "np.ndarray"]
|
||||
) -> list[Union["np.array", None]]:
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
) -> list[Union[np.array, None]]:
|
||||
"""
|
||||
Get the embeddings for the given texts
|
||||
|
||||
@@ -54,8 +53,8 @@ class OllamaEmbeddings(TextEmbeddingFunction):
|
||||
The texts to embed
|
||||
"""
|
||||
# TODO retry, rate limit, token limit
|
||||
embeddings = [self._compute_embedding(text) for text in texts]
|
||||
return embeddings
|
||||
embeddings = self._compute_embedding(texts)
|
||||
return list(embeddings)
|
||||
|
||||
@cached_property
|
||||
def _ollama_client(self) -> "ollama.Client":
|
||||
|
||||
Reference in New Issue
Block a user