multi-modal embedding-function (#484)

This commit is contained in:
Chang She
2023-09-16 21:23:51 -04:00
committed by GitHub
parent 9585f550b3
commit 31dad71c94
13 changed files with 645 additions and 143 deletions

View File

@@ -28,7 +28,8 @@ from lance.dataset import ReaderLike
from lance.vector import vec_to_table
from .common import DATA, VEC, VECTOR_COLUMN_NAME
from .embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry
from .embeddings import EmbeddingFunctionRegistry
from .embeddings.functions import EmbeddingFunctionConfig
from .pydantic import LanceModel
from .query import LanceQueryBuilder, Query
from .util import fs_from_uri, safe_import_pandas
@@ -81,15 +82,16 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
vector column to the table.
"""
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
for vector_col, func in functions.items():
if vector_col not in data.column_names:
col_data = func(data[func.source_column])
for vector_column, conf in functions.items():
func = conf.function
if vector_column not in data.column_names:
col_data = func.compute_source_embeddings(data[conf.source_column])
if schema is not None:
dtype = schema.field(vector_col).type
dtype = schema.field(vector_column).type
else:
dtype = pa.list_(pa.float32(), len(col_data[0]))
data = data.append_column(
pa.field(vector_col, type=dtype), pa.array(col_data, type=dtype)
pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
)
return data
@@ -230,7 +232,7 @@ class Table(ABC):
@abstractmethod
def search(
self,
query: Optional[Union[VEC, str]] = None,
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto",
) -> LanceQueryBuilder:
@@ -239,7 +241,7 @@ class Table(ABC):
Parameters
----------
query: str, list, np.ndarray, default None
query: str, list, np.ndarray, PIL.Image.Image, default None
The query to search for. If None then
the select/where/limit clauses are applied to filter
the table
@@ -249,6 +251,8 @@ class Table(ABC):
"vector", "fts", or "auto"
If "auto" then the query type is inferred from the query;
If `query` is a list/np.ndarray then the query type is "vector";
If `query` is a PIL.Image.Image then either do vector search
or raise an error if no corresponding embedding function is found.
If `query` is a string, then the query type is "vector" if the
table has embedding functions else the query type is "fts"
@@ -524,6 +528,9 @@ class LanceTable(Table):
fill_value: float = 0.0,
):
"""Add data to the table.
If vector columns are missing and the table
has embedding functions, then the vector columns
are automatically computed and added.
Parameters
----------
@@ -617,12 +624,6 @@ class LanceTable(Table):
)
self._reset_dataset()
def _get_embedding_function_for_source_col(self, column_name: str):
for k, v in self.embedding_functions.items():
if v.source_column == column_name:
return v
return None
@cached_property
def embedding_functions(self) -> dict:
"""
@@ -640,7 +641,7 @@ class LanceTable(Table):
def search(
self,
query: Optional[Union[VEC, str]] = None,
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
vector_column_name: str = VECTOR_COLUMN_NAME,
query_type: str = "auto",
) -> LanceQueryBuilder:
@@ -649,7 +650,7 @@ class LanceTable(Table):
Parameters
----------
query: str, list, np.ndarray, or None
query: str, list, np.ndarray, a PIL Image or None
The query to search for. If None then
the select/where/limit clauses are applied to filter
the table
@@ -658,9 +659,11 @@ class LanceTable(Table):
query_type: str, default "auto"
"vector", "fts", or "auto"
If "auto" then the query type is inferred from the query;
If the query is a list/np.ndarray then the query type is "vector";
If `query` is a list/np.ndarray then the query type is "vector";
If `query` is a PIL.Image.Image then either do vector search
or raise an error if no corresponding embedding function is found.
If the query is a string, then the query type is "vector" if the
table has embedding functions else the query type is "fts"
table has embedding functions, else the query type is "fts"
Returns
-------
@@ -684,7 +687,7 @@ class LanceTable(Table):
mode="create",
on_bad_vectors: str = "error",
fill_value: float = 0.0,
embedding_functions: List[EmbeddingFunctionModel] = None,
embedding_functions: List[EmbeddingFunctionConfig] = None,
):
"""
Create a new table.
@@ -727,10 +730,16 @@ class LanceTable(Table):
"""
tbl = LanceTable(db, name)
if inspect.isclass(schema) and issubclass(schema, LanceModel):
# convert LanceModel to pyarrow schema
# note that it's possible this contains
# embedding function metadata already
schema = schema.to_arrow_schema()
metadata = None
if embedding_functions is not None:
# If we passed in embedding functions explicitly
# then we'll override any schema metadata that
# may was implicitly specified by the LanceModel schema
registry = EmbeddingFunctionRegistry.get_instance()
metadata = registry.get_table_metadata(embedding_functions)