[python] Use pydantic for embedding function persistence (#467)

1. Support persistent embedding function so users can just search using
query string
2. Add fixed size list conversion for multiple vector columns
3. Add support for empty query (just apply select/where/limit).
4. Refactor and simplify some of the data prep code

---------

Co-authored-by: Chang She <chang@lancedb.com>
Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
Chang She
2023-09-05 21:30:45 -07:00
committed by GitHub
parent 52fa7f5577
commit 9a9a73a65d
13 changed files with 815 additions and 192 deletions

View File

@@ -13,6 +13,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Type, Union
import numpy as np
@@ -54,7 +55,164 @@ class Query(pydantic.BaseModel):
refine_factor: Optional[int] = None
class LanceQueryBuilder:
class LanceQueryBuilder(ABC):
@classmethod
def create(
cls,
table: "lancedb.table.Table",
query: Optional[Union[np.ndarray, str]],
query_type: str,
vector_column_name: str,
) -> LanceQueryBuilder:
if query is None:
return LanceEmptyQueryBuilder(table)
query, query_type = cls._resolve_query(
table, query, query_type, vector_column_name
)
if isinstance(query, str):
# fts
return LanceFtsQueryBuilder(table, query)
if isinstance(query, list):
query = np.array(query, dtype=np.float32)
elif isinstance(query, np.ndarray):
query = query.astype(np.float32)
else:
raise TypeError(f"Unsupported query type: {type(query)}")
return LanceVectorQueryBuilder(table, query, vector_column_name)
@classmethod
def _resolve_query(cls, table, query, query_type, vector_column_name):
# If query_type is fts, then query must be a string.
# otherwise raise TypeError
if query_type == "fts":
if not isinstance(query, str):
raise TypeError(
f"Query type is 'fts' but query is not a string: {type(query)}"
)
return query, query_type
elif query_type == "vector":
# If query_type is vector, then query must be a list or np.ndarray.
# otherwise raise TypeError
if not isinstance(query, (list, np.ndarray)):
raise TypeError(
f"Query type is 'vector' but query is not a list or np.ndarray: {type(query)}"
)
return query, query_type
elif query_type == "auto":
if isinstance(query, (list, np.ndarray)):
return query, "vector"
elif isinstance(query, str):
func = table.embedding_functions.get(vector_column_name, None)
if func is not None:
query = func(query)[0]
return query, "vector"
else:
return query, "fts"
else:
raise TypeError("Query must be a list, np.ndarray, or str")
else:
raise ValueError(
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
)
def __init__(self, table: "lancedb.table.Table"):
self._table = table
self._limit = 10
self._columns = None
self._where = None
def to_df(self) -> "pd.DataFrame":
"""
Execute the query and return the results as a pandas DataFrame.
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
return self.to_arrow().to_pandas()
@abstractmethod
def to_arrow(self) -> pa.Table:
"""
Execute the query and return the results as an
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vectors.
"""
raise NotImplementedError
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
"""Return the table as a list of pydantic models.
Parameters
----------
model: Type[LanceModel]
The pydantic model to use.
Returns
-------
List[LanceModel]
"""
return [
model(**{k: v for k, v in row.items() if k in model.field_names()})
for row in self.to_arrow().to_pylist()
]
def limit(self, limit: int) -> LanceVectorQueryBuilder:
"""Set the maximum number of results to return.
Parameters
----------
limit: int
The maximum number of results to return.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._limit = limit
return self
def select(self, columns: list) -> LanceVectorQueryBuilder:
"""Set the columns to return.
Parameters
----------
columns: list
The columns to return.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._columns = columns
return self
def where(self, where: str) -> LanceVectorQueryBuilder:
"""Set the where clause.
Parameters
----------
where: str
The where clause.
Returns
-------
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._where = where
return self
class LanceVectorQueryBuilder(LanceQueryBuilder):
"""
A builder for nearest neighbor queries for LanceDB.
@@ -80,68 +238,17 @@ class LanceQueryBuilder:
def __init__(
self,
table: "lancedb.table.Table",
query: Union[np.ndarray, str],
query: Union[np.ndarray, list],
vector_column: str = VECTOR_COLUMN_NAME,
):
super().__init__(table)
self._query = query
self._metric = "L2"
self._nprobes = 20
self._refine_factor = None
self._table = table
self._query = query
self._limit = 10
self._columns = None
self._where = None
self._vector_column = vector_column
def limit(self, limit: int) -> LanceQueryBuilder:
"""Set the maximum number of results to return.
Parameters
----------
limit: int
The maximum number of results to return.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._limit = limit
return self
def select(self, columns: list) -> LanceQueryBuilder:
"""Set the columns to return.
Parameters
----------
columns: list
The columns to return.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._columns = columns
return self
def where(self, where: str) -> LanceQueryBuilder:
"""Set the where clause.
Parameters
----------
where: str
The where clause.
Returns
-------
LanceQueryBuilder
The LanceQueryBuilder object.
"""
self._where = where
return self
def metric(self, metric: Literal["L2", "cosine"]) -> LanceQueryBuilder:
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
"""Set the distance metric to use.
Parameters
@@ -151,13 +258,13 @@ class LanceQueryBuilder:
Returns
-------
LanceQueryBuilder
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._metric = metric
return self
def nprobes(self, nprobes: int) -> LanceQueryBuilder:
def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
"""Set the number of probes to use.
Higher values will yield better recall (more likely to find vectors if
@@ -173,13 +280,13 @@ class LanceQueryBuilder:
Returns
-------
LanceQueryBuilder
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._nprobes = nprobes
return self
def refine_factor(self, refine_factor: int) -> LanceQueryBuilder:
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
"""Set the refine factor to use, increasing the number of vectors sampled.
As an example, a refine factor of 2 will sample 2x as many vectors as
@@ -195,22 +302,12 @@ class LanceQueryBuilder:
Returns
-------
LanceQueryBuilder
LanceVectorQueryBuilder
The LanceQueryBuilder object.
"""
self._refine_factor = refine_factor
return self
def to_df(self) -> "pd.DataFrame":
"""
Execute the query and return the results as a pandas DataFrame.
In addition to the selected columns, LanceDB also returns a vector
and also the "_distance" column which is the distance between the query
vector and the returned vector.
"""
return self.to_arrow().to_pandas()
def to_arrow(self) -> pa.Table:
"""
Execute the query and return the results as an
@@ -233,25 +330,12 @@ class LanceQueryBuilder:
)
return self._table._execute_query(query)
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
"""Return the table as a list of pydantic models.
Parameters
----------
model: Type[LanceModel]
The pydantic model to use.
Returns
-------
List[LanceModel]
"""
return [
model(**{k: v for k, v in row.items() if k in model.field_names()})
for row in self.to_arrow().to_pylist()
]
class LanceFtsQueryBuilder(LanceQueryBuilder):
def __init__(self, table: "lancedb.table.Table", query: str):
super().__init__(table)
self._query = query
def to_arrow(self) -> pa.Table:
try:
import tantivy
@@ -275,3 +359,13 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
output_tbl = output_tbl.append_column("score", scores)
return output_tbl
class LanceEmptyQueryBuilder(LanceQueryBuilder):
def to_arrow(self) -> pa.Table:
ds = self._table.to_lance()
return ds.to_table(
columns=self._columns,
filter=self._where,
limit=self._limit,
)