mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 20:02:58 +00:00
fix: added support for trust_remote_code (#1454)
Closes #1285 Added trust_remote_code to the SentenceTransformerEmbeddings class. Defaults to `False`
This commit is contained in:
@@ -17,6 +17,7 @@ Allows you to set parameters when registering a `sentence-transformers` object.
|
|||||||
| `name` | `str` | `all-MiniLM-L6-v2` | The name of the model |
|
| `name` | `str` | `all-MiniLM-L6-v2` | The name of the model |
|
||||||
| `device` | `str` | `cpu` | The device to run the model on (can be `cpu` or `gpu`) |
|
| `device` | `str` | `cpu` | The device to run the model on (can be `cpu` or `gpu`) |
|
||||||
| `normalize` | `bool` | `True` | Whether to normalize the input text before feeding it to the model |
|
| `normalize` | `bool` | `True` | Whether to normalize the input text before feeding it to the model |
|
||||||
|
| `trust_remote_code` | `bool` | `False` | Whether to trust and execute remote code from the model's Huggingface repository |
|
||||||
|
|
||||||
|
|
||||||
??? "Check out available sentence-transformer models here!"
|
??? "Check out available sentence-transformer models here!"
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
|||||||
name: str = "all-MiniLM-L6-v2"
|
name: str = "all-MiniLM-L6-v2"
|
||||||
device: str = "cpu"
|
device: str = "cpu"
|
||||||
normalize: bool = True
|
normalize: bool = True
|
||||||
|
trust_remote_code: bool = False
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@@ -40,8 +41,8 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
|||||||
def embedding_model(self):
|
def embedding_model(self):
|
||||||
"""
|
"""
|
||||||
Get the sentence-transformers embedding model specified by the
|
Get the sentence-transformers embedding model specified by the
|
||||||
name and device. This is cached so that the model is only loaded
|
name, device, and trust_remote_code. This is cached so that the
|
||||||
once per process.
|
model is only loaded once per process.
|
||||||
"""
|
"""
|
||||||
return self.get_embedding_model()
|
return self.get_embedding_model()
|
||||||
|
|
||||||
@@ -71,12 +72,14 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
|
|||||||
def get_embedding_model(self):
|
def get_embedding_model(self):
|
||||||
"""
|
"""
|
||||||
Get the sentence-transformers embedding model specified by the
|
Get the sentence-transformers embedding model specified by the
|
||||||
name and device. This is cached so that the model is only loaded
|
name, device, and trust_remote_code. This is cached so that the
|
||||||
once per process.
|
model is only loaded once per process.
|
||||||
|
|
||||||
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
TODO: use lru_cache instead with a reasonable/configurable maxsize
|
||||||
"""
|
"""
|
||||||
sentence_transformers = attempt_import_or_raise(
|
sentence_transformers = attempt_import_or_raise(
|
||||||
"sentence_transformers", "sentence-transformers"
|
"sentence_transformers", "sentence-transformers"
|
||||||
)
|
)
|
||||||
return sentence_transformers.SentenceTransformer(self.name, device=self.device)
|
return sentence_transformers.SentenceTransformer(
|
||||||
|
self.name, device=self.device, trust_remote_code=self.trust_remote_code
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user