diff --git a/python/python/lancedb/embeddings/registry.py b/python/python/lancedb/embeddings/registry.py index 91424253..10978486 100644 --- a/python/python/lancedb/embeddings/registry.py +++ b/python/python/lancedb/embeddings/registry.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright The LanceDB Authors import json -from typing import Dict, Optional +from typing import Dict, Optional, Type from .base import EmbeddingFunction, EmbeddingFunctionConfig @@ -43,7 +43,7 @@ class EmbeddingFunctionRegistry: self._functions = {} self._variables = {} - def register(self, alias: str = None): + def register(self, alias: Optional[str] = None): """ This creates a decorator that can be used to register an EmbeddingFunction. @@ -75,7 +75,7 @@ class EmbeddingFunctionRegistry: """ self._functions = {} - def get(self, name: str): + def get(self, name: str) -> Type[EmbeddingFunction]: """ Fetch an embedding function class by name