fix: added support for trust_remote_code (#1454)

Closes #1285 

Added trust_remote_code to the SentenceTransformerEmbeddings class.
Defaults to `False`
This commit is contained in:
Magnus
2024-07-18 16:07:52 +02:00
committed by GitHub
parent d564f6eacb
commit dc609a337d
2 changed files with 9 additions and 5 deletions

View File

@@ -31,6 +31,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
name: str = "all-MiniLM-L6-v2"
device: str = "cpu"
normalize: bool = True
trust_remote_code: bool = False
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -40,8 +41,8 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
def embedding_model(self):
"""
Get the sentence-transformers embedding model specified by the
name and device. This is cached so that the model is only loaded
once per process.
name, device, and trust_remote_code. This is cached so that the
model is only loaded once per process.
"""
return self.get_embedding_model()
@@ -71,12 +72,14 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
def get_embedding_model(self):
"""
Get the sentence-transformers embedding model specified by the
name and device. This is cached so that the model is only loaded
once per process.
name, device, and trust_remote_code. This is cached so that the
model is only loaded once per process.
TODO: use lru_cache instead with a reasonable/configurable maxsize
"""
sentence_transformers = attempt_import_or_raise(
"sentence_transformers", "sentence-transformers"
)
return sentence_transformers.SentenceTransformer(self.name, device=self.device)
return sentence_transformers.SentenceTransformer(
self.name, device=self.device, trust_remote_code=self.trust_remote_code
)