multi-modal embedding-function (#484)

This commit is contained in:
Chang She
2023-09-16 21:23:51 -04:00
committed by GitHub
parent 9585f550b3
commit 31dad71c94
13 changed files with 645 additions and 143 deletions

View File

@@ -60,13 +60,15 @@ class LanceQueryBuilder(ABC):
def create(
cls,
table: "lancedb.table.Table",
query: Optional[Union[np.ndarray, str]],
query: Optional[Union[np.ndarray, str, "PIL.Image.Image"]],
query_type: str,
vector_column_name: str,
) -> LanceQueryBuilder:
if query is None:
return LanceEmptyQueryBuilder(table)
# convert "auto" query_type to "vector" or "fts"
# and convert the query to vector if needed
query, query_type = cls._resolve_query(
table, query, query_type, vector_column_name
)
@@ -90,30 +92,27 @@ class LanceQueryBuilder(ABC):
# otherwise raise TypeError
if query_type == "fts":
if not isinstance(query, str):
raise TypeError(
f"Query type is 'fts' but query is not a string: {type(query)}"
)
raise TypeError(f"'fts' queries must be a string: {type(query)}")
return query, query_type
elif query_type == "vector":
# If query_type is vector, then query must be a list or np.ndarray.
# otherwise raise TypeError
if not isinstance(query, (list, np.ndarray)):
raise TypeError(
f"Query type is 'vector' but query is not a list or np.ndarray: {type(query)}"
)
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
query = conf.function.compute_query_embeddings(query)[0]
else:
msg = f"No embedding function for {vector_column_name}"
raise ValueError(msg)
return query, query_type
elif query_type == "auto":
if isinstance(query, (list, np.ndarray)):
return query, "vector"
elif isinstance(query, str):
func = table.embedding_functions.get(vector_column_name, None)
if func is not None:
query = func(query)[0]
else:
conf = table.embedding_functions.get(vector_column_name)
if conf is not None:
query = conf.function.compute_query_embeddings(query)[0]
return query, "vector"
else:
return query, "fts"
else:
raise TypeError("Query must be a list, np.ndarray, or str")
else:
raise ValueError(
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
@@ -238,7 +237,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
def __init__(
self,
table: "lancedb.table.Table",
query: Union[np.ndarray, list],
query: Union[np.ndarray, list, "PIL.Image.Image"],
vector_column: str = VECTOR_COLUMN_NAME,
):
super().__init__(table)