mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-08 12:52:58 +00:00
[python] Use pydantic for embedding function persistence (#467)
1. Support persistent embedding function so users can just search using query string 2. Add fixed size list conversion for multiple vector columns 3. Add support for empty query (just apply select/where/limit). 4. Refactor and simplify some of the data prep code --------- Co-authored-by: Chang She <chang@lancedb.com> Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Literal, Optional, Type, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -54,7 +55,164 @@ class Query(pydantic.BaseModel):
|
||||
refine_factor: Optional[int] = None
|
||||
|
||||
|
||||
class LanceQueryBuilder:
|
||||
class LanceQueryBuilder(ABC):
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
table: "lancedb.table.Table",
|
||||
query: Optional[Union[np.ndarray, str]],
|
||||
query_type: str,
|
||||
vector_column_name: str,
|
||||
) -> LanceQueryBuilder:
|
||||
if query is None:
|
||||
return LanceEmptyQueryBuilder(table)
|
||||
|
||||
query, query_type = cls._resolve_query(
|
||||
table, query, query_type, vector_column_name
|
||||
)
|
||||
|
||||
if isinstance(query, str):
|
||||
# fts
|
||||
return LanceFtsQueryBuilder(table, query)
|
||||
|
||||
if isinstance(query, list):
|
||||
query = np.array(query, dtype=np.float32)
|
||||
elif isinstance(query, np.ndarray):
|
||||
query = query.astype(np.float32)
|
||||
else:
|
||||
raise TypeError(f"Unsupported query type: {type(query)}")
|
||||
|
||||
return LanceVectorQueryBuilder(table, query, vector_column_name)
|
||||
|
||||
@classmethod
|
||||
def _resolve_query(cls, table, query, query_type, vector_column_name):
|
||||
# If query_type is fts, then query must be a string.
|
||||
# otherwise raise TypeError
|
||||
if query_type == "fts":
|
||||
if not isinstance(query, str):
|
||||
raise TypeError(
|
||||
f"Query type is 'fts' but query is not a string: {type(query)}"
|
||||
)
|
||||
return query, query_type
|
||||
elif query_type == "vector":
|
||||
# If query_type is vector, then query must be a list or np.ndarray.
|
||||
# otherwise raise TypeError
|
||||
if not isinstance(query, (list, np.ndarray)):
|
||||
raise TypeError(
|
||||
f"Query type is 'vector' but query is not a list or np.ndarray: {type(query)}"
|
||||
)
|
||||
return query, query_type
|
||||
elif query_type == "auto":
|
||||
if isinstance(query, (list, np.ndarray)):
|
||||
return query, "vector"
|
||||
elif isinstance(query, str):
|
||||
func = table.embedding_functions.get(vector_column_name, None)
|
||||
if func is not None:
|
||||
query = func(query)[0]
|
||||
return query, "vector"
|
||||
else:
|
||||
return query, "fts"
|
||||
else:
|
||||
raise TypeError("Query must be a list, np.ndarray, or str")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid query_type, must be 'vector', 'fts', or 'auto': {query_type}"
|
||||
)
|
||||
|
||||
def __init__(self, table: "lancedb.table.Table"):
|
||||
self._table = table
|
||||
self._limit = 10
|
||||
self._columns = None
|
||||
self._where = None
|
||||
|
||||
def to_df(self) -> "pd.DataFrame":
|
||||
"""
|
||||
Execute the query and return the results as a pandas DataFrame.
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
return self.to_arrow().to_pandas()
|
||||
|
||||
@abstractmethod
|
||||
def to_arrow(self) -> pa.Table:
|
||||
"""
|
||||
Execute the query and return the results as an
|
||||
[Apache Arrow Table](https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table).
|
||||
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vectors.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
|
||||
"""Return the table as a list of pydantic models.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Type[LanceModel]
|
||||
The pydantic model to use.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[LanceModel]
|
||||
"""
|
||||
return [
|
||||
model(**{k: v for k, v in row.items() if k in model.field_names()})
|
||||
for row in self.to_arrow().to_pylist()
|
||||
]
|
||||
|
||||
def limit(self, limit: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the maximum number of results to return.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
limit: int
|
||||
The maximum number of results to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._limit = limit
|
||||
return self
|
||||
|
||||
def select(self, columns: list) -> LanceVectorQueryBuilder:
|
||||
"""Set the columns to return.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
columns: list
|
||||
The columns to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._columns = columns
|
||||
return self
|
||||
|
||||
def where(self, where: str) -> LanceVectorQueryBuilder:
|
||||
"""Set the where clause.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The where clause.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._where = where
|
||||
return self
|
||||
|
||||
|
||||
class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
"""
|
||||
A builder for nearest neighbor queries for LanceDB.
|
||||
|
||||
@@ -80,68 +238,17 @@ class LanceQueryBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
table: "lancedb.table.Table",
|
||||
query: Union[np.ndarray, str],
|
||||
query: Union[np.ndarray, list],
|
||||
vector_column: str = VECTOR_COLUMN_NAME,
|
||||
):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
self._metric = "L2"
|
||||
self._nprobes = 20
|
||||
self._refine_factor = None
|
||||
self._table = table
|
||||
self._query = query
|
||||
self._limit = 10
|
||||
self._columns = None
|
||||
self._where = None
|
||||
self._vector_column = vector_column
|
||||
|
||||
def limit(self, limit: int) -> LanceQueryBuilder:
|
||||
"""Set the maximum number of results to return.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
limit: int
|
||||
The maximum number of results to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._limit = limit
|
||||
return self
|
||||
|
||||
def select(self, columns: list) -> LanceQueryBuilder:
|
||||
"""Set the columns to return.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
columns: list
|
||||
The columns to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._columns = columns
|
||||
return self
|
||||
|
||||
def where(self, where: str) -> LanceQueryBuilder:
|
||||
"""Set the where clause.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The where clause.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._where = where
|
||||
return self
|
||||
|
||||
def metric(self, metric: Literal["L2", "cosine"]) -> LanceQueryBuilder:
|
||||
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
|
||||
"""Set the distance metric to use.
|
||||
|
||||
Parameters
|
||||
@@ -151,13 +258,13 @@ class LanceQueryBuilder:
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._metric = metric
|
||||
return self
|
||||
|
||||
def nprobes(self, nprobes: int) -> LanceQueryBuilder:
|
||||
def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the number of probes to use.
|
||||
|
||||
Higher values will yield better recall (more likely to find vectors if
|
||||
@@ -173,13 +280,13 @@ class LanceQueryBuilder:
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._nprobes = nprobes
|
||||
return self
|
||||
|
||||
def refine_factor(self, refine_factor: int) -> LanceQueryBuilder:
|
||||
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the refine factor to use, increasing the number of vectors sampled.
|
||||
|
||||
As an example, a refine factor of 2 will sample 2x as many vectors as
|
||||
@@ -195,22 +302,12 @@ class LanceQueryBuilder:
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._refine_factor = refine_factor
|
||||
return self
|
||||
|
||||
def to_df(self) -> "pd.DataFrame":
|
||||
"""
|
||||
Execute the query and return the results as a pandas DataFrame.
|
||||
In addition to the selected columns, LanceDB also returns a vector
|
||||
and also the "_distance" column which is the distance between the query
|
||||
vector and the returned vector.
|
||||
"""
|
||||
|
||||
return self.to_arrow().to_pandas()
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
"""
|
||||
Execute the query and return the results as an
|
||||
@@ -233,25 +330,12 @@ class LanceQueryBuilder:
|
||||
)
|
||||
return self._table._execute_query(query)
|
||||
|
||||
def to_pydantic(self, model: Type[LanceModel]) -> List[LanceModel]:
|
||||
"""Return the table as a list of pydantic models.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Type[LanceModel]
|
||||
The pydantic model to use.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[LanceModel]
|
||||
"""
|
||||
return [
|
||||
model(**{k: v for k, v in row.items() if k in model.field_names()})
|
||||
for row in self.to_arrow().to_pylist()
|
||||
]
|
||||
|
||||
|
||||
class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
def __init__(self, table: "lancedb.table.Table", query: str):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
try:
|
||||
import tantivy
|
||||
@@ -275,3 +359,13 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
output_tbl = self._table.to_lance().take(row_ids, columns=self._columns)
|
||||
output_tbl = output_tbl.append_column("score", scores)
|
||||
return output_tbl
|
||||
|
||||
|
||||
class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
def to_arrow(self) -> pa.Table:
|
||||
ds = self._table.to_lance()
|
||||
return ds.to_table(
|
||||
columns=self._columns,
|
||||
filter=self._where,
|
||||
limit=self._limit,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user