mirror of https://github.com/lancedb/lancedb.git (synced 2025-12-27 23:12:58 +00:00)
feat(python): allow user to override api url (#1054)
@@ -10,16 +10,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
 from functools import cached_property
-from typing import List, Optional, Union
-
-import numpy as np
+from typing import TYPE_CHECKING, List, Optional, Union
 
 from ..util import attempt_import_or_raise
 from .base import TextEmbeddingFunction
 from .registry import register
 from .utils import api_key_not_found_help
 
+if TYPE_CHECKING:
+    import numpy as np
 
 
 @register("openai")
@@ -28,10 +27,34 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
     An embedding function that uses the OpenAI API
 
     https://platform.openai.com/docs/guides/embeddings
+
+    This can also be used for open source models that
+    are compatible with the OpenAI API.
+
+    Notes
+    -----
+    If you're running an Ollama server locally,
+    you can just override the `base_url` parameter
+    and provide the Ollama embedding model you want
+    to use (https://ollama.com/library):
+
+    ```python
+    from lancedb.embeddings import get_registry
+    openai = get_registry().get("openai")
+    embedding_function = openai.create(
+        name="<ollama-embedding-model-name>",
+        base_url="http://localhost:11434",
+    )
+    ```
+
     """
 
     name: str = "text-embedding-ada-002"
     dim: Optional[int] = None
+    base_url: Optional[str] = None
+    default_headers: Optional[dict] = None
+    organization: Optional[str] = None
+    api_key: Optional[str] = None
 
     def ndims(self):
         return self._ndims
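The four new fields (`base_url`, `default_headers`, `organization`, `api_key`) are ordinary model attributes, so they can be supplied through the registry's `create()` call instead of relying on environment variables. A minimal sketch of that; every value below is a placeholder, not something this diff prescribes:

```python
from lancedb.embeddings import get_registry

openai = get_registry().get("openai")

# All values here are placeholders; they only exercise the new fields.
embedding_function = openai.create(
    name="text-embedding-ada-002",                 # default model from the class definition
    base_url="https://my-gateway.example.com/v1",  # hypothetical OpenAI-compatible proxy
    api_key="sk-placeholder",                      # used instead of OPENAI_API_KEY
    organization="org-placeholder",
    default_headers={"X-Request-Source": "lancedb"},
)
```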
@@ -56,8 +79,8 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
             raise ValueError(f"Unknown model name {self.name}")
 
     def generate_embeddings(
-        self, texts: Union[List[str], np.ndarray]
-    ) -> List[np.array]:
+        self, texts: Union[List[str], "np.ndarray"]
+    ) -> List["np.array"]:
         """
         Get the embeddings for the given texts
 
@@ -70,15 +93,25 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
-        if self.name == "text-embedding-ada-002":
-            rs = self._openai_client.embeddings.create(input=texts, model=self.name)
-        else:
-            rs = self._openai_client.embeddings.create(
-                input=texts, model=self.name, dimensions=self.ndims()
-            )
+        kwargs = {
+            "input": texts,
+            "model": self.name,
+        }
+        if self.dim:
+            kwargs["dimensions"] = self.dim
+        rs = self._openai_client.embeddings.create(**kwargs)
         return [v.embedding for v in rs.data]
 
     @cached_property
     def _openai_client(self):
         openai = attempt_import_or_raise("openai")
-
-        if not os.environ.get("OPENAI_API_KEY"):
-            api_key_not_found_help("openai")
-        return openai.OpenAI()
+        kwargs = {}
+        if self.base_url:
+            kwargs["base_url"] = self.base_url
+        if self.default_headers:
+            kwargs["default_headers"] = self.default_headers
+        if self.organization:
+            kwargs["organization"] = self.organization
+        if self.api_key:
kwargs["api_key"] = self
|
||||
return openai.OpenAI(**kwargs)
|
||||
|
||||
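With the client construction above, only options that were explicitly set are forwarded to `openai.OpenAI()`, so an unset `api_key` still falls back to the `OPENAI_API_KEY` environment variable handled inside the openai SDK. A rough end-to-end sketch under that assumption, reusing the Ollama-style setup from the docstring; the model name is a placeholder the diff does not pin down:

```python
from lancedb.embeddings import get_registry

# Mirrors the docstring example: a local Ollama server speaking the OpenAI API.
func = get_registry().get("openai").create(
    name="nomic-embed-text",            # placeholder Ollama embedding model
    base_url="http://localhost:11434",  # Ollama's default local endpoint
)

# generate_embeddings() builds the request kwargs itself; `dim` was left unset,
# so no `dimensions` parameter is sent.
vectors = func.generate_embeddings(["hello world", "vector databases"])
print(len(vectors), len(vectors[0]))
```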