mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-14 15:52:57 +00:00
feat(python): add support new openai embedding functions (#912)
@PrashantDixit0 --------- Co-authored-by: Chang She <759245+changhiskhan@users.noreply.github.com>
This commit is contained in:
committed by
Weston Pace
parent
84edf56995
commit
0f00cd0097
@@ -12,7 +12,7 @@
|
||||
# limitations under the License.
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -30,10 +30,21 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
"""
|
||||
|
||||
name: str = "text-embedding-ada-002"
|
||||
dim: Optional[int] = None
|
||||
|
||||
def ndims(self):
|
||||
# TODO don't hardcode this
|
||||
return 1536
|
||||
return self._ndims
|
||||
|
||||
@cached_property
|
||||
def _ndims(self):
|
||||
if self.name == "text-embedding-ada-002":
|
||||
return 1536
|
||||
elif self.name == "text-embedding-3-large":
|
||||
return self.dim or 3072
|
||||
elif self.name == "text-embedding-3-small":
|
||||
return self.dim or 1536
|
||||
else:
|
||||
raise ValueError(f"Unknown model name {self.name}")
|
||||
|
||||
def generate_embeddings(
|
||||
self, texts: Union[List[str], np.ndarray]
|
||||
@@ -47,7 +58,12 @@ class OpenAIEmbeddings(TextEmbeddingFunction):
|
||||
The texts to embed
|
||||
"""
|
||||
# TODO retry, rate limit, token limit
|
||||
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
|
||||
if self.name == "text-embedding-ada-002":
|
||||
rs = self._openai_client.embeddings.create(input=texts, model=self.name)
|
||||
else:
|
||||
rs = self._openai_client.embeddings.create(
|
||||
input=texts, model=self.name, dimensions=self.ndims()
|
||||
)
|
||||
return [v.embedding for v in rs.data]
|
||||
|
||||
@cached_property
|
||||
|
||||
Reference in New Issue
Block a user