From eb9784d7f2f65c41b86f628deeb611194762ac22 Mon Sep 17 00:00:00 2001 From: Haoyu Weng Date: Mon, 30 Jun 2025 11:28:14 -0400 Subject: [PATCH] feat(python): batch Ollama embed calls (#2453) Other embedding integrations such as Cohere and OpenAI already send requests in batches. We should do that for Ollama too to improve throughput. The Ollama [`.embed` API](https://github.com/ollama/ollama-python/blob/63ca74762284100b2f0ad207bc00fa3d32720fbd/ollama/_client.py#L359-L378) was added in version 0.3.0 (almost a year ago) so I updated the version requirement in pyproject. ## Summary by CodeRabbit - **Bug Fixes** - Improved compatibility with newer versions of the "ollama" package by requiring version 0.3.0 or higher. - Enhanced embedding generation to process batches of texts more efficiently and reliably. - **Refactor** - Improved type consistency and clarity for embedding-related methods. --- python/pyproject.toml | 2 +- python/python/lancedb/embeddings/ollama.py | 31 +++++++++++----------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index ccc9115b..60b76e4d 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -85,7 +85,7 @@ embeddings = [ "boto3>=1.28.57", "awscli>=1.29.57", "botocore>=1.31.57", - "ollama", + "ollama>=0.3.0", "ibm-watsonx-ai>=1.1.2", ] azure = ["adlfs>=2024.2.0"] diff --git a/python/python/lancedb/embeddings/ollama.py b/python/python/lancedb/embeddings/ollama.py index 1dbc3305..0f4a9cdc 100644 --- a/python/python/lancedb/embeddings/ollama.py +++ b/python/python/lancedb/embeddings/ollama.py @@ -2,14 +2,15 @@ # SPDX-FileCopyrightText: Copyright The LanceDB Authors from functools import cached_property -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Sequence, Union + +import numpy as np from ..util import attempt_import_or_raise from .base import TextEmbeddingFunction from .registry import register if TYPE_CHECKING: - import numpy as np import ollama @@ -28,23 +29,21 @@ class OllamaEmbeddings(TextEmbeddingFunction): keep_alive: Optional[Union[float, str]] = None ollama_client_kwargs: Optional[dict] = {} - def ndims(self): + def ndims(self) -> int: return len(self.generate_embeddings(["foo"])[0]) - def _compute_embedding(self, text) -> Union["np.array", None]: - return ( - self._ollama_client.embeddings( - model=self.name, - prompt=text, - options=self.options, - keep_alive=self.keep_alive, - )["embedding"] - or None + def _compute_embedding(self, text: Sequence[str]) -> Sequence[Sequence[float]]: + response = self._ollama_client.embed( + model=self.name, + input=text, + options=self.options, + keep_alive=self.keep_alive, ) + return response.embeddings def generate_embeddings( - self, texts: Union[List[str], "np.ndarray"] - ) -> list[Union["np.array", None]]: + self, texts: Union[List[str], np.ndarray] + ) -> list[Union[np.array, None]]: """ Get the embeddings for the given texts @@ -54,8 +53,8 @@ class OllamaEmbeddings(TextEmbeddingFunction): The texts to embed """ # TODO retry, rate limit, token limit - embeddings = [self._compute_embedding(text) for text in texts] - return embeddings + embeddings = self._compute_embedding(texts) + return list(embeddings) @cached_property def _ollama_client(self) -> "ollama.Client":