diff --git a/python/python/lancedb/embeddings/openai.py b/python/python/lancedb/embeddings/openai.py index 4af0efce..b87faa73 100644 --- a/python/python/lancedb/embeddings/openai.py +++ b/python/python/lancedb/embeddings/openai.py @@ -10,16 +10,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from functools import cached_property -from typing import List, Optional, Union - -import numpy as np +from typing import TYPE_CHECKING, List, Optional, Union from ..util import attempt_import_or_raise from .base import TextEmbeddingFunction from .registry import register -from .utils import api_key_not_found_help + +if TYPE_CHECKING: + import numpy as np @register("openai") @@ -28,10 +27,34 @@ class OpenAIEmbeddings(TextEmbeddingFunction): An embedding function that uses the OpenAI API https://platform.openai.com/docs/guides/embeddings + + This can also be used for open source models that + are compatible with the OpenAI API. + + Notes + ----- + If you're running an Ollama server locally, + you can just override the `base_url` parameter + and provide the Ollama embedding model you want + to use (https://ollama.com/library): + + ```python + from lancedb.embeddings import get_registry + openai = get_registry().get("openai") + embedding_function = openai.create( + name="", + base_url="http://localhost:11434", + ) + ``` + """ name: str = "text-embedding-ada-002" dim: Optional[int] = None + base_url: Optional[str] = None + default_headers: Optional[dict] = None + organization: Optional[str] = None + api_key: Optional[str] = None def ndims(self): return self._ndims @@ -56,8 +79,8 @@ class OpenAIEmbeddings(TextEmbeddingFunction): raise ValueError(f"Unknown model name {self.name}") def generate_embeddings( - self, texts: Union[List[str], np.ndarray] - ) -> List[np.array]: + self, texts: Union[List[str], "np.ndarray"] + ) -> List["np.array"]: """ Get the embeddings for the given texts @@ -70,15 +93,25 @@ class OpenAIEmbeddings(TextEmbeddingFunction): if self.name == "text-embedding-ada-002": rs = self._openai_client.embeddings.create(input=texts, model=self.name) else: - rs = self._openai_client.embeddings.create( - input=texts, model=self.name, dimensions=self.ndims() - ) + kwargs = { + "input": texts, + "model": self.name, + } + if self.dim: + kwargs["dimensions"] = self.dim + rs = self._openai_client.embeddings.create(**kwargs) return [v.embedding for v in rs.data] @cached_property def _openai_client(self): openai = attempt_import_or_raise("openai") - - if not os.environ.get("OPENAI_API_KEY"): - api_key_not_found_help("openai") - return openai.OpenAI() + kwargs = {} + if self.base_url: + kwargs["base_url"] = self.base_url + if self.default_headers: + kwargs["default_headers"] = self.default_headers + if self.organization: + kwargs["organization"] = self.organization + if self.api_key: + kwargs["api_key"] = self + return openai.OpenAI(**kwargs)