feat(python): allow user to override api url (#1054)

This commit is contained in:
Chang She
2024-03-03 18:29:47 -08:00
committed by Weston Pace
parent a7dbe933dc
commit e60fde73ba

View File

@@ -10,16 +10,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import cached_property
from typing import List, Optional, Union
import numpy as np
from typing import TYPE_CHECKING, List, Optional, Union
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help
if TYPE_CHECKING:
import numpy as np
@register("openai")
@@ -28,10 +27,34 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
An embedding function that uses the OpenAI API
https://platform.openai.com/docs/guides/embeddings
This can also be used for open source models that
are compatible with the OpenAI API.
Notes
-----
If you're running an Ollama server locally,
you can just override the `base_url` parameter
and provide the Ollama embedding model you want
to use (https://ollama.com/library):
```python
from lancedb.embeddings import get_registry
openai = get_registry().get("openai")
embedding_function = openai.create(
name="<ollama-embedding-model-name>",
base_url="http://localhost:11434",
)
```
"""
name: str = "text-embedding-ada-002"
dim: Optional[int] = None
base_url: Optional[str] = None
default_headers: Optional[dict] = None
organization: Optional[str] = None
api_key: Optional[str] = None
def ndims(self):
return self._ndims
@@ -56,8 +79,8 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
raise ValueError(f"Unknown model name {self.name}")
def generate_embeddings(
self, texts: Union[List[str], np.ndarray]
) -> List[np.array]:
self, texts: Union[List[str], "np.ndarray"]
) -> List["np.array"]:
"""
Get the embeddings for the given texts
@@ -70,15 +93,25 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
if self.name == "text-embedding-ada-002":
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
else:
rs = self._openai_client.embeddings.create(
input=texts, model=self.name, dimensions=self.ndims()
)
kwargs = {
"input": texts,
"model": self.name,
}
if self.dim:
kwargs["dimensions"] = self.dim
rs = self._openai_client.embeddings.create(**kwargs)
return [v.embedding for v in rs.data]
@cached_property
def _openai_client(self):
openai = attempt_import_or_raise("openai")
if not os.environ.get("OPENAI_API_KEY"):
api_key_not_found_help("openai")
return openai.OpenAI()
kwargs = {}
if self.base_url:
kwargs["base_url"] = self.base_url
if self.default_headers:
kwargs["default_headers"] = self.default_headers
if self.organization:
kwargs["organization"] = self.organization
if self.api_key:
kwargs["api_key"] = self
return openai.OpenAI(**kwargs)