diff --git a/docs/src/embeddings/default_embedding_functions.md b/docs/src/embeddings/default_embedding_functions.md
index 14910485..00293360 100644
--- a/docs/src/embeddings/default_embedding_functions.md
+++ b/docs/src/embeddings/default_embedding_functions.md
@@ -17,6 +17,7 @@ Allows you to set parameters when registering a `sentence-transformers` object.
 | `name` | `str` | `all-MiniLM-L6-v2` | The name of the model |
 | `device` | `str` | `cpu` | The device to run the model on (can be `cpu` or `gpu`) |
 | `normalize` | `bool` | `True` | Whether to normalize the input text before feeding it to the model |
+| `trust_remote_code` | `bool` | `False` | Whether to trust and execute remote code from the model's Huggingface repository |
 
 ??? "Check out available sentence-transformer models here!"
 
diff --git a/python/python/lancedb/embeddings/sentence_transformers.py b/python/python/lancedb/embeddings/sentence_transformers.py
index 97fe1318..fe8e997d 100644
--- a/python/python/lancedb/embeddings/sentence_transformers.py
+++ b/python/python/lancedb/embeddings/sentence_transformers.py
@@ -31,6 +31,7 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
     name: str = "all-MiniLM-L6-v2"
     device: str = "cpu"
     normalize: bool = True
+    trust_remote_code: bool = False
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -40,8 +41,8 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
     def embedding_model(self):
         """
         Get the sentence-transformers embedding model specified by the
-        name and device. This is cached so that the model is only loaded
-        once per process.
+        name, device, and trust_remote_code. This is cached so that the
+        model is only loaded once per process.
         """
         return self.get_embedding_model()
 
@@ -71,12 +72,14 @@ class SentenceTransformerEmbeddings(TextEmbeddingFunction):
     def get_embedding_model(self):
         """
         Get the sentence-transformers embedding model specified by the
-        name and device. This is cached so that the model is only loaded
-        once per process.
+        name, device, and trust_remote_code. This is cached so that the
+        model is only loaded once per process.
 
         TODO: use lru_cache instead with a reasonable/configurable maxsize
         """
         sentence_transformers = attempt_import_or_raise(
             "sentence_transformers", "sentence-transformers"
         )
-        return sentence_transformers.SentenceTransformer(self.name, device=self.device)
+        return sentence_transformers.SentenceTransformer(
+            self.name, device=self.device, trust_remote_code=self.trust_remote_code
+        )