mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-10 13:52:58 +00:00
multi-modal embedding-function (#484)
This commit is contained in:
@@ -28,7 +28,8 @@ from lance.dataset import ReaderLike
|
||||
from lance.vector import vec_to_table
|
||||
|
||||
from .common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from .embeddings import EmbeddingFunctionModel, EmbeddingFunctionRegistry
|
||||
from .embeddings import EmbeddingFunctionRegistry
|
||||
from .embeddings.functions import EmbeddingFunctionConfig
|
||||
from .pydantic import LanceModel
|
||||
from .query import LanceQueryBuilder, Query
|
||||
from .util import fs_from_uri, safe_import_pandas
|
||||
@@ -81,15 +82,16 @@ def _append_vector_col(data: pa.Table, metadata: dict, schema: Optional[pa.Schem
|
||||
vector column to the table.
|
||||
"""
|
||||
functions = EmbeddingFunctionRegistry.get_instance().parse_functions(metadata)
|
||||
for vector_col, func in functions.items():
|
||||
if vector_col not in data.column_names:
|
||||
col_data = func(data[func.source_column])
|
||||
for vector_column, conf in functions.items():
|
||||
func = conf.function
|
||||
if vector_column not in data.column_names:
|
||||
col_data = func.compute_source_embeddings(data[conf.source_column])
|
||||
if schema is not None:
|
||||
dtype = schema.field(vector_col).type
|
||||
dtype = schema.field(vector_column).type
|
||||
else:
|
||||
dtype = pa.list_(pa.float32(), len(col_data[0]))
|
||||
data = data.append_column(
|
||||
pa.field(vector_col, type=dtype), pa.array(col_data, type=dtype)
|
||||
pa.field(vector_column, type=dtype), pa.array(col_data, type=dtype)
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -230,7 +232,7 @@ class Table(ABC):
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str]] = None,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
@@ -239,7 +241,7 @@ class Table(ABC):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: str, list, np.ndarray, default None
|
||||
query: str, list, np.ndarray, PIL.Image.Image, default None
|
||||
The query to search for. If None then
|
||||
the select/where/limit clauses are applied to filter
|
||||
the table
|
||||
@@ -249,6 +251,8 @@ class Table(ABC):
|
||||
"vector", "fts", or "auto"
|
||||
If "auto" then the query type is inferred from the query;
|
||||
If `query` is a list/np.ndarray then the query type is "vector";
|
||||
If `query` is a PIL.Image.Image then either do vector search
|
||||
or raise an error if no corresponding embedding function is found.
|
||||
If `query` is a string, then the query type is "vector" if the
|
||||
table has embedding functions else the query type is "fts"
|
||||
|
||||
@@ -524,6 +528,9 @@ class LanceTable(Table):
|
||||
fill_value: float = 0.0,
|
||||
):
|
||||
"""Add data to the table.
|
||||
If vector columns are missing and the table
|
||||
has embedding functions, then the vector columns
|
||||
are automatically computed and added.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -617,12 +624,6 @@ class LanceTable(Table):
|
||||
)
|
||||
self._reset_dataset()
|
||||
|
||||
def _get_embedding_function_for_source_col(self, column_name: str):
|
||||
for k, v in self.embedding_functions.items():
|
||||
if v.source_column == column_name:
|
||||
return v
|
||||
return None
|
||||
|
||||
@cached_property
|
||||
def embedding_functions(self) -> dict:
|
||||
"""
|
||||
@@ -640,7 +641,7 @@ class LanceTable(Table):
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: Optional[Union[VEC, str]] = None,
|
||||
query: Optional[Union[VEC, str, "PIL.Image.Image"]] = None,
|
||||
vector_column_name: str = VECTOR_COLUMN_NAME,
|
||||
query_type: str = "auto",
|
||||
) -> LanceQueryBuilder:
|
||||
@@ -649,7 +650,7 @@ class LanceTable(Table):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
query: str, list, np.ndarray, or None
|
||||
query: str, list, np.ndarray, a PIL Image or None
|
||||
The query to search for. If None then
|
||||
the select/where/limit clauses are applied to filter
|
||||
the table
|
||||
@@ -658,9 +659,11 @@ class LanceTable(Table):
|
||||
query_type: str, default "auto"
|
||||
"vector", "fts", or "auto"
|
||||
If "auto" then the query type is inferred from the query;
|
||||
If the query is a list/np.ndarray then the query type is "vector";
|
||||
If `query` is a list/np.ndarray then the query type is "vector";
|
||||
If `query` is a PIL.Image.Image then either do vector search
|
||||
or raise an error if no corresponding embedding function is found.
|
||||
If the query is a string, then the query type is "vector" if the
|
||||
table has embedding functions else the query type is "fts"
|
||||
table has embedding functions, else the query type is "fts"
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -684,7 +687,7 @@ class LanceTable(Table):
|
||||
mode="create",
|
||||
on_bad_vectors: str = "error",
|
||||
fill_value: float = 0.0,
|
||||
embedding_functions: List[EmbeddingFunctionModel] = None,
|
||||
embedding_functions: List[EmbeddingFunctionConfig] = None,
|
||||
):
|
||||
"""
|
||||
Create a new table.
|
||||
@@ -727,10 +730,16 @@ class LanceTable(Table):
|
||||
"""
|
||||
tbl = LanceTable(db, name)
|
||||
if inspect.isclass(schema) and issubclass(schema, LanceModel):
|
||||
# convert LanceModel to pyarrow schema
|
||||
# note that it's possible this contains
|
||||
# embedding function metadata already
|
||||
schema = schema.to_arrow_schema()
|
||||
|
||||
metadata = None
|
||||
if embedding_functions is not None:
|
||||
# If we passed in embedding functions explicitly
|
||||
# then we'll override any schema metadata that
|
||||
# may have been implicitly specified by the LanceModel schema
|
||||
registry = EmbeddingFunctionRegistry.get_instance()
|
||||
metadata = registry.get_table_metadata(embedding_functions)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user