mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 20:02:58 +00:00
feat(python): allow user to override api url (#1054)
This commit is contained in:
@@ -10,16 +10,15 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import List, Optional, Union
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from ..util import attempt_import_or_raise
|
from ..util import attempt_import_or_raise
|
||||||
from .base import TextEmbeddingFunction
|
from .base import TextEmbeddingFunction
|
||||||
from .registry import register
|
from .registry import register
|
||||||
from .utils import api_key_not_found_help
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
@register("openai")
|
@register("openai")
|
||||||
@@ -28,10 +27,34 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
|||||||
An embedding function that uses the OpenAI API
|
An embedding function that uses the OpenAI API
|
||||||
|
|
||||||
https://platform.openai.com/docs/guides/embeddings
|
https://platform.openai.com/docs/guides/embeddings
|
||||||
|
|
||||||
|
This can also be used for open source models that
|
||||||
|
are compatible with the OpenAI API.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
If you're running an Ollama server locally,
|
||||||
|
you can just override the `base_url` parameter
|
||||||
|
and provide the Ollama embedding model you want
|
||||||
|
to use (https://ollama.com/library):
|
||||||
|
|
||||||
|
```python
|
||||||
|
from lancedb.embeddings import get_registry
|
||||||
|
openai = get_registry().get("openai")
|
||||||
|
embedding_function = openai.create(
|
||||||
|
name="<ollama-embedding-model-name>",
|
||||||
|
base_url="http://localhost:11434/v1",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
name: str = "text-embedding-ada-002"
|
name: str = "text-embedding-ada-002"
|
||||||
dim: Optional[int] = None
|
dim: Optional[int] = None
|
||||||
|
base_url: Optional[str] = None
|
||||||
|
default_headers: Optional[dict] = None
|
||||||
|
organization: Optional[str] = None
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
|
||||||
def ndims(self):
|
def ndims(self):
|
||||||
return self._ndims
|
return self._ndims
|
||||||
@@ -56,8 +79,8 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
|||||||
raise ValueError(f"Unknown model name {self.name}")
|
raise ValueError(f"Unknown model name {self.name}")
|
||||||
|
|
||||||
def generate_embeddings(
|
def generate_embeddings(
|
||||||
self, texts: Union[List[str], np.ndarray]
|
self, texts: Union[List[str], "np.ndarray"]
|
||||||
) -> List[np.array]:
|
) -> List["np.array"]:
|
||||||
"""
|
"""
|
||||||
Get the embeddings for the given texts
|
Get the embeddings for the given texts
|
||||||
|
|
||||||
@@ -70,15 +93,25 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
|||||||
if self.name == "text-embedding-ada-002":
|
if self.name == "text-embedding-ada-002":
|
||||||
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
|
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
|
||||||
else:
|
else:
|
||||||
rs = self._openai_client.embeddings.create(
|
kwargs = {
|
||||||
input=texts, model=self.name, dimensions=self.ndims()
|
"input": texts,
|
||||||
)
|
"model": self.name,
|
||||||
|
}
|
||||||
|
if self.dim:
|
||||||
|
kwargs["dimensions"] = self.dim
|
||||||
|
rs = self._openai_client.embeddings.create(**kwargs)
|
||||||
return [v.embedding for v in rs.data]
|
return [v.embedding for v in rs.data]
|
||||||
|
|
||||||
@cached_property
def _openai_client(self):
    """
    Build the OpenAI client used to request embeddings.

    Only the optional settings that are set (truthy) on this instance --
    `base_url`, `default_headers`, `organization` and `api_key` -- are
    forwarded to `openai.OpenAI`. Anything left unset is omitted so the
    client falls back to its own defaults.

    The result is stored by `cached_property`, so the client is built
    once per instance and reused on later accesses.

    Returns
    -------
    The object returned by `openai.OpenAI(**kwargs)`.
    """
    # NOTE(review): presumably raises when the `openai` package is not
    # installed (per the helper's name) -- confirm against `..util`.
    openai = attempt_import_or_raise("openai")

    kwargs = {}
    if self.base_url:
        kwargs["base_url"] = self.base_url
    if self.default_headers:
        kwargs["default_headers"] = self.default_headers
    if self.organization:
        kwargs["organization"] = self.organization
    if self.api_key:
        # Forward the key string itself, not the embedding-function instance.
        kwargs["api_key"] = self.api_key
    return openai.OpenAI(**kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user