mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-16 19:40:40 +00:00
feat(python): add type-safe expression builder API (#3150)
Introduces col(), lit(), func(), and Expr class as alternatives to raw SQL strings in .where() and .select(). Expressions are backed by DataFusion's Expr AST and serialized to SQL for remote table compat. Resolves: - https://github.com/lancedb/lancedb/issues/3044 (python api's) - https://github.com/lancedb/lancedb/issues/3043 (support for filter) - https://github.com/lancedb/lancedb/issues/3045 (support for projection) --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2
python/.gitignore
vendored
2
python/.gitignore
vendored
@@ -1,3 +1,5 @@
|
||||
# Test data created by some example tests
|
||||
data/
|
||||
_lancedb.pyd
|
||||
# macOS debug symbols bundle generated during build
|
||||
*.dSYM/
|
||||
|
||||
@@ -18,6 +18,7 @@ from .db import AsyncConnection, DBConnection, LanceDBConnection
|
||||
from .io import StorageOptionsProvider
|
||||
from .remote import ClientConfig
|
||||
from .remote.db import RemoteDBConnection
|
||||
from .expr import Expr, col, lit, func
|
||||
from .schema import vector
|
||||
from .table import AsyncTable, Table
|
||||
from ._lancedb import Session
|
||||
@@ -271,6 +272,10 @@ __all__ = [
|
||||
"AsyncConnection",
|
||||
"AsyncLanceNamespaceDBConnection",
|
||||
"AsyncTable",
|
||||
"col",
|
||||
"Expr",
|
||||
"func",
|
||||
"lit",
|
||||
"URI",
|
||||
"sanitize_uri",
|
||||
"vector",
|
||||
|
||||
@@ -27,6 +27,32 @@ from .remote import ClientConfig
|
||||
IvfHnswPq: type[HnswPq] = HnswPq
|
||||
IvfHnswSq: type[HnswSq] = HnswSq
|
||||
|
||||
class PyExpr:
|
||||
"""A type-safe DataFusion expression node (Rust-side handle)."""
|
||||
|
||||
def eq(self, other: "PyExpr") -> "PyExpr": ...
|
||||
def ne(self, other: "PyExpr") -> "PyExpr": ...
|
||||
def lt(self, other: "PyExpr") -> "PyExpr": ...
|
||||
def lte(self, other: "PyExpr") -> "PyExpr": ...
|
||||
def gt(self, other: "PyExpr") -> "PyExpr": ...
|
||||
def gte(self, other: "PyExpr") -> "PyExpr": ...
|
||||
def and_(self, other: "PyExpr") -> "PyExpr": ...
|
||||
def or_(self, other: "PyExpr") -> "PyExpr": ...
|
||||
def not_(self) -> "PyExpr": ...
|
||||
def add(self, other: "PyExpr") -> "PyExpr": ...
|
||||
def sub(self, other: "PyExpr") -> "PyExpr": ...
|
||||
def mul(self, other: "PyExpr") -> "PyExpr": ...
|
||||
def div(self, other: "PyExpr") -> "PyExpr": ...
|
||||
def lower(self) -> "PyExpr": ...
|
||||
def upper(self) -> "PyExpr": ...
|
||||
def contains(self, substr: "PyExpr") -> "PyExpr": ...
|
||||
def cast(self, data_type: pa.DataType) -> "PyExpr": ...
|
||||
def to_sql(self) -> str: ...
|
||||
|
||||
def expr_col(name: str) -> PyExpr: ...
|
||||
def expr_lit(value: Union[bool, int, float, str]) -> PyExpr: ...
|
||||
def expr_func(name: str, args: List[PyExpr]) -> PyExpr: ...
|
||||
|
||||
class Session:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -225,7 +251,9 @@ class RecordBatchStream:
|
||||
|
||||
class Query:
|
||||
def where(self, filter: str): ...
|
||||
def select(self, columns: Tuple[str, str]): ...
|
||||
def where_expr(self, expr: PyExpr): ...
|
||||
def select(self, columns: List[Tuple[str, str]]): ...
|
||||
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
|
||||
def select_columns(self, columns: List[str]): ...
|
||||
def limit(self, limit: int): ...
|
||||
def offset(self, offset: int): ...
|
||||
@@ -251,7 +279,9 @@ class TakeQuery:
|
||||
|
||||
class FTSQuery:
|
||||
def where(self, filter: str): ...
|
||||
def select(self, columns: List[str]): ...
|
||||
def where_expr(self, expr: PyExpr): ...
|
||||
def select(self, columns: List[Tuple[str, str]]): ...
|
||||
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
|
||||
def limit(self, limit: int): ...
|
||||
def offset(self, offset: int): ...
|
||||
def fast_search(self): ...
|
||||
@@ -270,7 +300,9 @@ class VectorQuery:
|
||||
async def output_schema(self) -> pa.Schema: ...
|
||||
async def execute(self) -> RecordBatchStream: ...
|
||||
def where(self, filter: str): ...
|
||||
def select(self, columns: List[str]): ...
|
||||
def where_expr(self, expr: PyExpr): ...
|
||||
def select(self, columns: List[Tuple[str, str]]): ...
|
||||
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
|
||||
def select_with_projection(self, columns: Tuple[str, str]): ...
|
||||
def limit(self, limit: int): ...
|
||||
def offset(self, offset: int): ...
|
||||
@@ -287,7 +319,9 @@ class VectorQuery:
|
||||
|
||||
class HybridQuery:
|
||||
def where(self, filter: str): ...
|
||||
def select(self, columns: List[str]): ...
|
||||
def where_expr(self, expr: PyExpr): ...
|
||||
def select(self, columns: List[Tuple[str, str]]): ...
|
||||
def select_expr(self, columns: List[Tuple[str, PyExpr]]): ...
|
||||
def limit(self, limit: int): ...
|
||||
def offset(self, offset: int): ...
|
||||
def fast_search(self): ...
|
||||
|
||||
298
python/python/lancedb/expr.py
Normal file
298
python/python/lancedb/expr.py
Normal file
@@ -0,0 +1,298 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
"""Type-safe expression builder for filters and projections.
|
||||
|
||||
Instead of writing raw SQL strings you can build expressions with Python
|
||||
operators::
|
||||
|
||||
from lancedb.expr import col, lit
|
||||
|
||||
# filter: age > 18 AND status = 'active'
|
||||
filt = (col("age") > lit(18)) & (col("status") == lit("active"))
|
||||
|
||||
# projection: compute a derived column
|
||||
proj = {"score": col("raw_score") * lit(1.5)}
|
||||
|
||||
table.search().where(filt).select(proj).to_list()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from lancedb._lancedb import PyExpr, expr_col, expr_lit, expr_func
|
||||
|
||||
__all__ = ["Expr", "col", "lit", "func"]
|
||||
|
||||
_STR_TO_PA_TYPE: dict = {
|
||||
"bool": pa.bool_(),
|
||||
"boolean": pa.bool_(),
|
||||
"int8": pa.int8(),
|
||||
"int16": pa.int16(),
|
||||
"int32": pa.int32(),
|
||||
"int64": pa.int64(),
|
||||
"uint8": pa.uint8(),
|
||||
"uint16": pa.uint16(),
|
||||
"uint32": pa.uint32(),
|
||||
"uint64": pa.uint64(),
|
||||
"float16": pa.float16(),
|
||||
"float32": pa.float32(),
|
||||
"float": pa.float32(),
|
||||
"float64": pa.float64(),
|
||||
"double": pa.float64(),
|
||||
"string": pa.string(),
|
||||
"utf8": pa.string(),
|
||||
"str": pa.string(),
|
||||
"large_string": pa.large_utf8(),
|
||||
"large_utf8": pa.large_utf8(),
|
||||
"date32": pa.date32(),
|
||||
"date": pa.date32(),
|
||||
"date64": pa.date64(),
|
||||
}
|
||||
|
||||
|
||||
def _coerce(value: "ExprLike") -> "Expr":
|
||||
"""Return *value* as an :class:`Expr`, wrapping plain Python values via
|
||||
:func:`lit` if needed."""
|
||||
if isinstance(value, Expr):
|
||||
return value
|
||||
return lit(value)
|
||||
|
||||
|
||||
# Type alias used in annotations.
|
||||
ExprLike = Union["Expr", bool, int, float, str]
|
||||
|
||||
|
||||
class Expr:
|
||||
"""A type-safe expression node.
|
||||
|
||||
Construct instances with :func:`col` and :func:`lit`, then combine them
|
||||
using Python operators or the named methods below.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from lancedb.expr import col, lit
|
||||
>>> filt = (col("age") > lit(18)) & (col("name").lower() == lit("alice"))
|
||||
>>> proj = {"double": col("x") * lit(2)}
|
||||
"""
|
||||
|
||||
# Make Expr unhashable so that == returns an Expr rather than being used
|
||||
# for dict keys / set membership.
|
||||
__hash__ = None # type: ignore[assignment]
|
||||
|
||||
def __init__(self, inner: PyExpr) -> None:
|
||||
self._inner = inner
|
||||
|
||||
# ── comparisons ──────────────────────────────────────────────────────────
|
||||
|
||||
def __eq__(self, other: ExprLike) -> "Expr": # type: ignore[override]
|
||||
"""Equal to (``col("x") == 1``)."""
|
||||
return Expr(self._inner.eq(_coerce(other)._inner))
|
||||
|
||||
def __ne__(self, other: ExprLike) -> "Expr": # type: ignore[override]
|
||||
"""Not equal to (``col("x") != 1``)."""
|
||||
return Expr(self._inner.ne(_coerce(other)._inner))
|
||||
|
||||
def __lt__(self, other: ExprLike) -> "Expr":
|
||||
"""Less than (``col("x") < 1``)."""
|
||||
return Expr(self._inner.lt(_coerce(other)._inner))
|
||||
|
||||
def __le__(self, other: ExprLike) -> "Expr":
|
||||
"""Less than or equal to (``col("x") <= 1``)."""
|
||||
return Expr(self._inner.lte(_coerce(other)._inner))
|
||||
|
||||
def __gt__(self, other: ExprLike) -> "Expr":
|
||||
"""Greater than (``col("x") > 1``)."""
|
||||
return Expr(self._inner.gt(_coerce(other)._inner))
|
||||
|
||||
def __ge__(self, other: ExprLike) -> "Expr":
|
||||
"""Greater than or equal to (``col("x") >= 1``)."""
|
||||
return Expr(self._inner.gte(_coerce(other)._inner))
|
||||
|
||||
# ── logical ──────────────────────────────────────────────────────────────
|
||||
|
||||
def __and__(self, other: "Expr") -> "Expr":
|
||||
"""Logical AND (``expr_a & expr_b``)."""
|
||||
return Expr(self._inner.and_(_coerce(other)._inner))
|
||||
|
||||
def __or__(self, other: "Expr") -> "Expr":
|
||||
"""Logical OR (``expr_a | expr_b``)."""
|
||||
return Expr(self._inner.or_(_coerce(other)._inner))
|
||||
|
||||
def __invert__(self) -> "Expr":
|
||||
"""Logical NOT (``~expr``)."""
|
||||
return Expr(self._inner.not_())
|
||||
|
||||
# ── arithmetic ───────────────────────────────────────────────────────────
|
||||
|
||||
def __add__(self, other: ExprLike) -> "Expr":
|
||||
"""Add (``col("x") + 1``)."""
|
||||
return Expr(self._inner.add(_coerce(other)._inner))
|
||||
|
||||
def __radd__(self, other: ExprLike) -> "Expr":
|
||||
"""Right-hand add (``1 + col("x")``)."""
|
||||
return Expr(_coerce(other)._inner.add(self._inner))
|
||||
|
||||
def __sub__(self, other: ExprLike) -> "Expr":
|
||||
"""Subtract (``col("x") - 1``)."""
|
||||
return Expr(self._inner.sub(_coerce(other)._inner))
|
||||
|
||||
def __rsub__(self, other: ExprLike) -> "Expr":
|
||||
"""Right-hand subtract (``1 - col("x")``)."""
|
||||
return Expr(_coerce(other)._inner.sub(self._inner))
|
||||
|
||||
def __mul__(self, other: ExprLike) -> "Expr":
|
||||
"""Multiply (``col("x") * 2``)."""
|
||||
return Expr(self._inner.mul(_coerce(other)._inner))
|
||||
|
||||
def __rmul__(self, other: ExprLike) -> "Expr":
|
||||
"""Right-hand multiply (``2 * col("x")``)."""
|
||||
return Expr(_coerce(other)._inner.mul(self._inner))
|
||||
|
||||
def __truediv__(self, other: ExprLike) -> "Expr":
|
||||
"""Divide (``col("x") / 2``)."""
|
||||
return Expr(self._inner.div(_coerce(other)._inner))
|
||||
|
||||
def __rtruediv__(self, other: ExprLike) -> "Expr":
|
||||
"""Right-hand divide (``1 / col("x")``)."""
|
||||
return Expr(_coerce(other)._inner.div(self._inner))
|
||||
|
||||
# ── string methods ───────────────────────────────────────────────────────
|
||||
|
||||
def lower(self) -> "Expr":
|
||||
"""Convert string column values to lowercase."""
|
||||
return Expr(self._inner.lower())
|
||||
|
||||
def upper(self) -> "Expr":
|
||||
"""Convert string column values to uppercase."""
|
||||
return Expr(self._inner.upper())
|
||||
|
||||
def contains(self, substr: "ExprLike") -> "Expr":
|
||||
"""Return True where the string contains *substr*."""
|
||||
return Expr(self._inner.contains(_coerce(substr)._inner))
|
||||
|
||||
# ── type cast ────────────────────────────────────────────────────────────
|
||||
|
||||
def cast(self, data_type: Union[str, "pa.DataType"]) -> "Expr":
|
||||
"""Cast values to *data_type*.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data_type:
|
||||
A PyArrow ``DataType`` (e.g. ``pa.int32()``) or one of the type
|
||||
name strings: ``"bool"``, ``"int8"``, ``"int16"``, ``"int32"``,
|
||||
``"int64"``, ``"uint8"``–``"uint64"``, ``"float32"``,
|
||||
``"float64"``, ``"string"``, ``"date32"``, ``"date64"``.
|
||||
"""
|
||||
if isinstance(data_type, str):
|
||||
try:
|
||||
data_type = _STR_TO_PA_TYPE[data_type]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"unsupported data type: '{data_type}'. Supported: "
|
||||
f"{', '.join(_STR_TO_PA_TYPE)}"
|
||||
)
|
||||
return Expr(self._inner.cast(data_type))
|
||||
|
||||
# ── named comparison helpers (alternative to operators) ──────────────────
|
||||
|
||||
def eq(self, other: ExprLike) -> "Expr":
|
||||
"""Equal to."""
|
||||
return self.__eq__(other)
|
||||
|
||||
def ne(self, other: ExprLike) -> "Expr":
|
||||
"""Not equal to."""
|
||||
return self.__ne__(other)
|
||||
|
||||
def lt(self, other: ExprLike) -> "Expr":
|
||||
"""Less than."""
|
||||
return self.__lt__(other)
|
||||
|
||||
def lte(self, other: ExprLike) -> "Expr":
|
||||
"""Less than or equal to."""
|
||||
return self.__le__(other)
|
||||
|
||||
def gt(self, other: ExprLike) -> "Expr":
|
||||
"""Greater than."""
|
||||
return self.__gt__(other)
|
||||
|
||||
def gte(self, other: ExprLike) -> "Expr":
|
||||
"""Greater than or equal to."""
|
||||
return self.__ge__(other)
|
||||
|
||||
def and_(self, other: "Expr") -> "Expr":
|
||||
"""Logical AND."""
|
||||
return self.__and__(other)
|
||||
|
||||
def or_(self, other: "Expr") -> "Expr":
|
||||
"""Logical OR."""
|
||||
return self.__or__(other)
|
||||
|
||||
# ── utilities ────────────────────────────────────────────────────────────
|
||||
|
||||
def to_sql(self) -> str:
|
||||
"""Render the expression as a SQL string (useful for debugging)."""
|
||||
return self._inner.to_sql()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Expr({self._inner.to_sql()})"
|
||||
|
||||
|
||||
# ── free functions ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def col(name: str) -> Expr:
|
||||
"""Reference a table column by name.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name:
|
||||
The column name.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from lancedb.expr import col, lit
|
||||
>>> col("age") > lit(18)
|
||||
Expr((age > 18))
|
||||
"""
|
||||
return Expr(expr_col(name))
|
||||
|
||||
|
||||
def lit(value: Union[bool, int, float, str]) -> Expr:
|
||||
"""Create a literal (constant) value expression.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value:
|
||||
A Python ``bool``, ``int``, ``float``, or ``str``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from lancedb.expr import col, lit
|
||||
>>> col("price") * lit(1.1)
|
||||
Expr((price * 1.1))
|
||||
"""
|
||||
return Expr(expr_lit(value))
|
||||
|
||||
|
||||
def func(name: str, *args: ExprLike) -> Expr:
|
||||
"""Call an arbitrary SQL function by name.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name:
|
||||
The SQL function name (e.g. ``"lower"``, ``"upper"``).
|
||||
*args:
|
||||
The function arguments as :class:`Expr` or plain Python literals.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from lancedb.expr import col, func
|
||||
>>> func("lower", col("name"))
|
||||
Expr(lower(name))
|
||||
"""
|
||||
inner_args = [_coerce(a)._inner for a in args]
|
||||
return Expr(expr_func(name, inner_args))
|
||||
@@ -38,6 +38,7 @@ from .rerankers.base import Reranker
|
||||
from .rerankers.rrf import RRFReranker
|
||||
from .rerankers.util import check_reranker_result
|
||||
from .util import flatten_columns
|
||||
from .expr import Expr
|
||||
from lancedb._lancedb import fts_query_to_json
|
||||
from typing_extensions import Annotated
|
||||
|
||||
@@ -449,8 +450,8 @@ class Query(pydantic.BaseModel):
|
||||
ensure_vector_query,
|
||||
] = None
|
||||
|
||||
# sql filter to refine the query with
|
||||
filter: Optional[str] = None
|
||||
# sql filter or type-safe Expr to refine the query with
|
||||
filter: Optional[Union[str, Expr]] = None
|
||||
|
||||
# if True then apply the filter after vector search
|
||||
postfilter: Optional[bool] = None
|
||||
@@ -464,8 +465,8 @@ class Query(pydantic.BaseModel):
|
||||
# distance type to use for vector search
|
||||
distance_type: Optional[str] = None
|
||||
|
||||
# which columns to return in the results
|
||||
columns: Optional[Union[List[str], Dict[str, str]]] = None
|
||||
# which columns to return in the results (dict values may be str or Expr)
|
||||
columns: Optional[Union[List[str], Dict[str, Union[str, Expr]]]] = None
|
||||
|
||||
# minimum number of IVF partitions to search
|
||||
#
|
||||
@@ -856,14 +857,15 @@ class LanceQueryBuilder(ABC):
|
||||
self._offset = offset
|
||||
return self
|
||||
|
||||
def select(self, columns: Union[list[str], dict[str, str]]) -> Self:
|
||||
def select(self, columns: Union[list[str], dict[str, Union[str, Expr]]]) -> Self:
|
||||
"""Set the columns to return.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
columns: list of str, or dict of str to str default None
|
||||
columns: list of str, or dict of str to str or Expr
|
||||
List of column names to be fetched.
|
||||
Or a dictionary of column names to SQL expressions.
|
||||
Or a dictionary of column names to SQL expressions or
|
||||
:class:`~lancedb.expr.Expr` objects.
|
||||
All columns are fetched if None or unspecified.
|
||||
|
||||
Returns
|
||||
@@ -877,15 +879,15 @@ class LanceQueryBuilder(ABC):
|
||||
raise ValueError("columns must be a list or a dictionary")
|
||||
return self
|
||||
|
||||
def where(self, where: str, prefilter: bool = True) -> Self:
|
||||
def where(self, where: Union[str, Expr], prefilter: bool = True) -> Self:
|
||||
"""Set the where clause.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The where clause which is a valid SQL where clause. See
|
||||
`Lance filter pushdown <https://lance.org/guide/read_and_write#filter-push-down>`_
|
||||
for valid SQL expressions.
|
||||
where: str or :class:`~lancedb.expr.Expr`
|
||||
The filter condition. Can be a SQL string or a type-safe
|
||||
:class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col`
|
||||
and :func:`~lancedb.expr.lit`.
|
||||
prefilter: bool, default True
|
||||
If True, apply the filter before vector search, otherwise the
|
||||
filter is applied on the result of vector search.
|
||||
@@ -1355,15 +1357,17 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
return result_set
|
||||
|
||||
def where(self, where: str, prefilter: bool = None) -> LanceVectorQueryBuilder:
|
||||
def where(
|
||||
self, where: Union[str, Expr], prefilter: bool = None
|
||||
) -> LanceVectorQueryBuilder:
|
||||
"""Set the where clause.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
where: str
|
||||
The where clause which is a valid SQL where clause. See
|
||||
`Lance filter pushdown <https://lance.org/guide/read_and_write#filter-push-down>`_
|
||||
for valid SQL expressions.
|
||||
where: str or :class:`~lancedb.expr.Expr`
|
||||
The filter condition. Can be a SQL string or a type-safe
|
||||
:class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col`
|
||||
and :func:`~lancedb.expr.lit`.
|
||||
prefilter: bool, default True
|
||||
If True, apply the filter before vector search, otherwise the
|
||||
filter is applied on the result of vector search.
|
||||
@@ -2286,10 +2290,20 @@ class AsyncQueryBase(object):
|
||||
"""
|
||||
if isinstance(columns, list) and all(isinstance(c, str) for c in columns):
|
||||
self._inner.select_columns(columns)
|
||||
elif isinstance(columns, dict) and all(
|
||||
isinstance(k, str) and isinstance(v, str) for k, v in columns.items()
|
||||
):
|
||||
self._inner.select(list(columns.items()))
|
||||
elif isinstance(columns, dict) and all(isinstance(k, str) for k in columns):
|
||||
if any(isinstance(v, Expr) for v in columns.values()):
|
||||
# At least one value is an Expr — use the type-safe path.
|
||||
from .expr import _coerce
|
||||
|
||||
pairs = [(k, _coerce(v)._inner) for k, v in columns.items()]
|
||||
self._inner.select_expr(pairs)
|
||||
elif all(isinstance(v, str) for v in columns.values()):
|
||||
self._inner.select(list(columns.items()))
|
||||
else:
|
||||
raise TypeError(
|
||||
"dict values must be str or Expr, got "
|
||||
+ str({k: type(v) for k, v in columns.items()})
|
||||
)
|
||||
else:
|
||||
raise TypeError("columns must be a list of column names or a dict")
|
||||
return self
|
||||
@@ -2529,11 +2543,13 @@ class AsyncStandardQuery(AsyncQueryBase):
|
||||
"""
|
||||
super().__init__(inner)
|
||||
|
||||
def where(self, predicate: str) -> Self:
|
||||
def where(self, predicate: Union[str, Expr]) -> Self:
|
||||
"""
|
||||
Only return rows matching the given predicate
|
||||
|
||||
The predicate should be supplied as an SQL query string.
|
||||
The predicate can be a SQL string or a type-safe
|
||||
:class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col`
|
||||
and :func:`~lancedb.expr.lit`.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -2545,7 +2561,10 @@ class AsyncStandardQuery(AsyncQueryBase):
|
||||
Filtering performance can often be improved by creating a scalar index
|
||||
on the filter column(s).
|
||||
"""
|
||||
self._inner.where(predicate)
|
||||
if isinstance(predicate, Expr):
|
||||
self._inner.where_expr(predicate._inner)
|
||||
else:
|
||||
self._inner.where(predicate)
|
||||
return self
|
||||
|
||||
def limit(self, limit: int) -> Self:
|
||||
|
||||
@@ -4211,7 +4211,7 @@ class AsyncTable:
|
||||
async_query = async_query.offset(query.offset)
|
||||
if query.columns:
|
||||
async_query = async_query.select(query.columns)
|
||||
if query.filter:
|
||||
if query.filter is not None:
|
||||
async_query = async_query.where(query.filter)
|
||||
if query.fast_search:
|
||||
async_query = async_query.fast_search()
|
||||
|
||||
175
python/src/expr.rs
Normal file
175
python/src/expr.rs
Normal file
@@ -0,0 +1,175 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
//! PyO3 bindings for the LanceDB expression builder API.
|
||||
//!
|
||||
//! This module exposes [`PyExpr`] and helper free functions so Python can
|
||||
//! build type-safe filter / projection expressions that map directly to
|
||||
//! DataFusion [`Expr`] nodes, bypassing SQL string parsing.
|
||||
|
||||
use arrow::{datatypes::DataType, pyarrow::PyArrowType};
|
||||
use lancedb::expr::{DfExpr, col as ldb_col, contains, expr_cast, lit as df_lit, lower, upper};
|
||||
use pyo3::{Bound, PyAny, PyResult, exceptions::PyValueError, prelude::*, pyfunction};
|
||||
|
||||
/// A type-safe DataFusion expression.
|
||||
///
|
||||
/// Instances are constructed via the free functions [`expr_col`] and
|
||||
/// [`expr_lit`] and combined with the methods on this struct. On the Python
|
||||
/// side a thin wrapper class (`lancedb.expr.Expr`) delegates to these methods
|
||||
/// and adds Python operator overloads.
|
||||
#[pyclass(name = "PyExpr")]
|
||||
#[derive(Clone)]
|
||||
pub struct PyExpr(pub DfExpr);
|
||||
|
||||
#[pymethods]
|
||||
impl PyExpr {
|
||||
// ── comparisons ──────────────────────────────────────────────────────────
|
||||
|
||||
fn eq(&self, other: &Self) -> Self {
|
||||
Self(self.0.clone().eq(other.0.clone()))
|
||||
}
|
||||
|
||||
fn ne(&self, other: &Self) -> Self {
|
||||
Self(self.0.clone().not_eq(other.0.clone()))
|
||||
}
|
||||
|
||||
fn lt(&self, other: &Self) -> Self {
|
||||
Self(self.0.clone().lt(other.0.clone()))
|
||||
}
|
||||
|
||||
fn lte(&self, other: &Self) -> Self {
|
||||
Self(self.0.clone().lt_eq(other.0.clone()))
|
||||
}
|
||||
|
||||
fn gt(&self, other: &Self) -> Self {
|
||||
Self(self.0.clone().gt(other.0.clone()))
|
||||
}
|
||||
|
||||
fn gte(&self, other: &Self) -> Self {
|
||||
Self(self.0.clone().gt_eq(other.0.clone()))
|
||||
}
|
||||
|
||||
// ── logical ──────────────────────────────────────────────────────────────
|
||||
|
||||
fn and_(&self, other: &Self) -> Self {
|
||||
Self(self.0.clone().and(other.0.clone()))
|
||||
}
|
||||
|
||||
fn or_(&self, other: &Self) -> Self {
|
||||
Self(self.0.clone().or(other.0.clone()))
|
||||
}
|
||||
|
||||
fn not_(&self) -> Self {
|
||||
use std::ops::Not;
|
||||
Self(self.0.clone().not())
|
||||
}
|
||||
|
||||
// ── arithmetic ───────────────────────────────────────────────────────────
|
||||
|
||||
fn add(&self, other: &Self) -> Self {
|
||||
use std::ops::Add;
|
||||
Self(self.0.clone().add(other.0.clone()))
|
||||
}
|
||||
|
||||
fn sub(&self, other: &Self) -> Self {
|
||||
use std::ops::Sub;
|
||||
Self(self.0.clone().sub(other.0.clone()))
|
||||
}
|
||||
|
||||
fn mul(&self, other: &Self) -> Self {
|
||||
use std::ops::Mul;
|
||||
Self(self.0.clone().mul(other.0.clone()))
|
||||
}
|
||||
|
||||
fn div(&self, other: &Self) -> Self {
|
||||
use std::ops::Div;
|
||||
Self(self.0.clone().div(other.0.clone()))
|
||||
}
|
||||
|
||||
// ── string functions ─────────────────────────────────────────────────────
|
||||
|
||||
/// Convert string column to lowercase.
|
||||
fn lower(&self) -> Self {
|
||||
Self(lower(self.0.clone()))
|
||||
}
|
||||
|
||||
/// Convert string column to uppercase.
|
||||
fn upper(&self) -> Self {
|
||||
Self(upper(self.0.clone()))
|
||||
}
|
||||
|
||||
/// Test whether the string contains `substr`.
|
||||
fn contains(&self, substr: &Self) -> Self {
|
||||
Self(contains(self.0.clone(), substr.0.clone()))
|
||||
}
|
||||
|
||||
// ── type cast ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Cast the expression to `data_type`.
|
||||
///
|
||||
/// `data_type` must be a PyArrow `DataType` (e.g. `pa.int32()`).
|
||||
/// On the Python side, `lancedb.expr.Expr.cast` also accepts type name
|
||||
/// strings via `pa.lib.ensure_type` before forwarding here.
|
||||
fn cast(&self, data_type: PyArrowType<DataType>) -> Self {
|
||||
Self(expr_cast(self.0.clone(), data_type.0))
|
||||
}
|
||||
|
||||
// ── utilities ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Render the expression as a SQL string (useful for debugging).
|
||||
fn to_sql(&self) -> PyResult<String> {
|
||||
lancedb::expr::expr_to_sql_string(&self.0).map_err(|e| PyValueError::new_err(e.to_string()))
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> PyResult<String> {
|
||||
let sql =
|
||||
lancedb::expr::expr_to_sql_string(&self.0).unwrap_or_else(|_| "<expr>".to_string());
|
||||
Ok(format!("PyExpr({})", sql))
|
||||
}
|
||||
}
|
||||
|
||||
// ── free functions ────────────────────────────────────────────────────────────
|
||||
|
||||
/// Create a column reference expression.
|
||||
///
|
||||
/// The column name is preserved exactly as given (case-sensitive), so
|
||||
/// `col("firstName")` correctly references a field named `firstName`.
|
||||
#[pyfunction]
|
||||
pub fn expr_col(name: &str) -> PyExpr {
|
||||
PyExpr(ldb_col(name))
|
||||
}
|
||||
|
||||
/// Create a literal value expression.
|
||||
///
|
||||
/// Supported Python types: `bool`, `int`, `float`, `str`.
|
||||
#[pyfunction]
|
||||
pub fn expr_lit(value: Bound<'_, PyAny>) -> PyResult<PyExpr> {
|
||||
// bool must be checked before int because bool is a subclass of int in Python
|
||||
if let Ok(b) = value.extract::<bool>() {
|
||||
return Ok(PyExpr(df_lit(b)));
|
||||
}
|
||||
if let Ok(i) = value.extract::<i64>() {
|
||||
return Ok(PyExpr(df_lit(i)));
|
||||
}
|
||||
if let Ok(f) = value.extract::<f64>() {
|
||||
return Ok(PyExpr(df_lit(f)));
|
||||
}
|
||||
if let Ok(s) = value.extract::<String>() {
|
||||
return Ok(PyExpr(df_lit(s)));
|
||||
}
|
||||
Err(PyValueError::new_err(format!(
|
||||
"unsupported literal type: {}. Supported: bool, int, float, str",
|
||||
value.get_type().name()?
|
||||
)))
|
||||
}
|
||||
|
||||
/// Call an arbitrary registered SQL function by name.
|
||||
///
|
||||
/// See `lancedb::expr::func` for the list of supported function names.
|
||||
#[pyfunction]
|
||||
pub fn expr_func(name: &str, args: Vec<PyExpr>) -> PyResult<PyExpr> {
|
||||
let df_args: Vec<DfExpr> = args.into_iter().map(|e| e.0).collect();
|
||||
lancedb::expr::func(name, df_args)
|
||||
.map(PyExpr)
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))
|
||||
}
|
||||
@@ -4,6 +4,7 @@
|
||||
use arrow::RecordBatchStream;
|
||||
use connection::{Connection, connect};
|
||||
use env_logger::Env;
|
||||
use expr::{PyExpr, expr_col, expr_func, expr_lit};
|
||||
use index::IndexConfig;
|
||||
use permutation::{PyAsyncPermutationBuilder, PyPermutationReader};
|
||||
use pyo3::{
|
||||
@@ -21,6 +22,7 @@ use table::{
|
||||
pub mod arrow;
|
||||
pub mod connection;
|
||||
pub mod error;
|
||||
pub mod expr;
|
||||
pub mod header;
|
||||
pub mod index;
|
||||
pub mod namespace;
|
||||
@@ -55,10 +57,14 @@ pub fn _lancedb(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<UpdateResult>()?;
|
||||
m.add_class::<PyAsyncPermutationBuilder>()?;
|
||||
m.add_class::<PyPermutationReader>()?;
|
||||
m.add_class::<PyExpr>()?;
|
||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(permutation::async_permutation_builder, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(query::fts_query_to_json, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(expr_col, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(expr_lit, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(expr_func, m)?)?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -35,12 +35,10 @@ use pyo3::types::PyList;
|
||||
use pyo3::types::{PyDict, PyString};
|
||||
use pyo3::{FromPyObject, exceptions::PyRuntimeError};
|
||||
use pyo3::{PyErr, pyclass};
|
||||
use pyo3::{
|
||||
exceptions::{PyNotImplementedError, PyValueError},
|
||||
intern,
|
||||
};
|
||||
use pyo3::{exceptions::PyValueError, intern};
|
||||
use pyo3_async_runtimes::tokio::future_into_py;
|
||||
|
||||
use crate::expr::PyExpr;
|
||||
use crate::util::parse_distance_type;
|
||||
use crate::{arrow::RecordBatchStream, util::PyLanceDB};
|
||||
use crate::{error::PythonErrorExt, index::class_name};
|
||||
@@ -344,9 +342,13 @@ impl<'py> IntoPyObject<'py> for PyQueryFilter {
|
||||
|
||||
fn into_pyobject(self, py: pyo3::Python<'py>) -> PyResult<Self::Output> {
|
||||
match self.0 {
|
||||
QueryFilter::Datafusion(_) => Err(PyNotImplementedError::new_err(
|
||||
"Datafusion filter has no conversion to Python",
|
||||
)),
|
||||
QueryFilter::Datafusion(expr) => {
|
||||
// Serialize the DataFusion expression to a SQL string so that
|
||||
// callers (e.g. remote tables) see the same format as Sql.
|
||||
let sql = lancedb::expr::expr_to_sql_string(&expr)
|
||||
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
|
||||
Ok(sql.into_pyobject(py)?.into_any())
|
||||
}
|
||||
QueryFilter::Sql(sql) => Ok(sql.into_pyobject(py)?.into_any()),
|
||||
QueryFilter::Substrait(substrait) => Ok(substrait.into_pyobject(py)?.into_any()),
|
||||
}
|
||||
@@ -370,10 +372,20 @@ impl Query {
|
||||
self.inner = self.inner.clone().only_if(predicate);
|
||||
}
|
||||
|
||||
pub fn where_expr(&mut self, expr: PyExpr) {
|
||||
self.inner = self.inner.clone().only_if_expr(expr.0);
|
||||
}
|
||||
|
||||
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
||||
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
||||
}
|
||||
|
||||
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
|
||||
let pairs: Vec<(String, lancedb::expr::DfExpr)> =
|
||||
columns.into_iter().map(|(name, e)| (name, e.0)).collect();
|
||||
self.inner = self.inner.clone().select(Select::Expr(pairs));
|
||||
}
|
||||
|
||||
pub fn select_columns(&mut self, columns: Vec<String>) {
|
||||
self.inner = self.inner.clone().select(Select::columns(&columns));
|
||||
}
|
||||
@@ -607,10 +619,20 @@ impl FTSQuery {
|
||||
self.inner = self.inner.clone().only_if(predicate);
|
||||
}
|
||||
|
||||
pub fn where_expr(&mut self, expr: PyExpr) {
|
||||
self.inner = self.inner.clone().only_if_expr(expr.0);
|
||||
}
|
||||
|
||||
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
||||
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
||||
}
|
||||
|
||||
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
|
||||
let pairs: Vec<(String, lancedb::expr::DfExpr)> =
|
||||
columns.into_iter().map(|(name, e)| (name, e.0)).collect();
|
||||
self.inner = self.inner.clone().select(Select::Expr(pairs));
|
||||
}
|
||||
|
||||
pub fn select_columns(&mut self, columns: Vec<String>) {
|
||||
self.inner = self.inner.clone().select(Select::columns(&columns));
|
||||
}
|
||||
@@ -725,6 +747,10 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().only_if(predicate);
|
||||
}
|
||||
|
||||
pub fn where_expr(&mut self, expr: PyExpr) {
|
||||
self.inner = self.inner.clone().only_if_expr(expr.0);
|
||||
}
|
||||
|
||||
pub fn add_query_vector(&mut self, vector: Bound<'_, PyAny>) -> PyResult<()> {
|
||||
let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?;
|
||||
let array = make_array(data);
|
||||
@@ -736,6 +762,12 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().select(Select::dynamic(&columns));
|
||||
}
|
||||
|
||||
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
|
||||
let pairs: Vec<(String, lancedb::expr::DfExpr)> =
|
||||
columns.into_iter().map(|(name, e)| (name, e.0)).collect();
|
||||
self.inner = self.inner.clone().select(Select::Expr(pairs));
|
||||
}
|
||||
|
||||
pub fn select_columns(&mut self, columns: Vec<String>) {
|
||||
self.inner = self.inner.clone().select(Select::columns(&columns));
|
||||
}
|
||||
@@ -890,11 +922,21 @@ impl HybridQuery {
|
||||
self.inner_fts.r#where(predicate);
|
||||
}
|
||||
|
||||
pub fn where_expr(&mut self, expr: PyExpr) {
|
||||
self.inner_vec.where_expr(expr.clone());
|
||||
self.inner_fts.where_expr(expr);
|
||||
}
|
||||
|
||||
pub fn select(&mut self, columns: Vec<(String, String)>) {
|
||||
self.inner_vec.select(columns.clone());
|
||||
self.inner_fts.select(columns);
|
||||
}
|
||||
|
||||
pub fn select_expr(&mut self, columns: Vec<(String, PyExpr)>) {
|
||||
self.inner_vec.select_expr(columns.clone());
|
||||
self.inner_fts.select_expr(columns);
|
||||
}
|
||||
|
||||
pub fn select_columns(&mut self, columns: Vec<String>) {
|
||||
self.inner_vec.select_columns(columns.clone());
|
||||
self.inner_fts.select_columns(columns);
|
||||
|
||||
387
python/tests/test_expr.py
Normal file
387
python/tests/test_expr.py
Normal file
@@ -0,0 +1,387 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
"""Tests for the type-safe expression builder API."""
|
||||
|
||||
import pytest
|
||||
import pyarrow as pa
|
||||
import lancedb
|
||||
from lancedb.expr import Expr, col, lit, func
|
||||
|
||||
|
||||
# ── unit tests for Expr construction ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestExprConstruction:
|
||||
def test_col_returns_expr(self):
|
||||
e = col("age")
|
||||
assert isinstance(e, Expr)
|
||||
|
||||
def test_lit_int(self):
|
||||
e = lit(42)
|
||||
assert isinstance(e, Expr)
|
||||
|
||||
def test_lit_float(self):
|
||||
e = lit(3.14)
|
||||
assert isinstance(e, Expr)
|
||||
|
||||
def test_lit_str(self):
|
||||
e = lit("hello")
|
||||
assert isinstance(e, Expr)
|
||||
|
||||
def test_lit_bool(self):
|
||||
e = lit(True)
|
||||
assert isinstance(e, Expr)
|
||||
|
||||
def test_lit_unsupported_type_raises(self):
|
||||
with pytest.raises(Exception):
|
||||
lit([1, 2, 3])
|
||||
|
||||
def test_func(self):
|
||||
e = func("lower", col("name"))
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "lower(name)"
|
||||
|
||||
def test_func_unknown_raises(self):
|
||||
with pytest.raises(Exception):
|
||||
func("not_a_real_function", col("x"))
|
||||
|
||||
|
||||
class TestExprOperators:
|
||||
def test_eq_operator(self):
|
||||
e = col("x") == lit(1)
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(x = 1)"
|
||||
|
||||
def test_ne_operator(self):
|
||||
e = col("x") != lit(1)
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(x <> 1)"
|
||||
|
||||
def test_lt_operator(self):
|
||||
e = col("age") < lit(18)
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(age < 18)"
|
||||
|
||||
def test_le_operator(self):
|
||||
e = col("age") <= lit(18)
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(age <= 18)"
|
||||
|
||||
def test_gt_operator(self):
|
||||
e = col("age") > lit(18)
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(age > 18)"
|
||||
|
||||
def test_ge_operator(self):
|
||||
e = col("age") >= lit(18)
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(age >= 18)"
|
||||
|
||||
def test_and_operator(self):
|
||||
e = (col("age") > lit(18)) & (col("status") == lit("active"))
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "((age > 18) AND (status = 'active'))"
|
||||
|
||||
def test_or_operator(self):
|
||||
e = (col("a") == lit(1)) | (col("b") == lit(2))
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "((a = 1) OR (b = 2))"
|
||||
|
||||
def test_invert_operator(self):
|
||||
e = ~(col("active") == lit(True))
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "NOT (active = true)"
|
||||
|
||||
def test_add_operator(self):
|
||||
e = col("x") + lit(1)
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(x + 1)"
|
||||
|
||||
def test_sub_operator(self):
|
||||
e = col("x") - lit(1)
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(x - 1)"
|
||||
|
||||
def test_mul_operator(self):
|
||||
e = col("price") * lit(1.1)
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(price * 1.1)"
|
||||
|
||||
def test_div_operator(self):
|
||||
e = col("total") / lit(2)
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(total / 2)"
|
||||
|
||||
def test_radd(self):
|
||||
e = lit(1) + col("x")
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(1 + x)"
|
||||
|
||||
def test_rmul(self):
|
||||
e = lit(2) * col("x")
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(2 * x)"
|
||||
|
||||
def test_coerce_plain_int(self):
|
||||
# Operators should auto-wrap plain Python values via lit()
|
||||
e = col("age") > 18
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(age > 18)"
|
||||
|
||||
def test_coerce_plain_str(self):
|
||||
e = col("name") == "alice"
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(name = 'alice')"
|
||||
|
||||
|
||||
class TestExprStringMethods:
|
||||
def test_lower(self):
|
||||
e = col("name").lower()
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "lower(name)"
|
||||
|
||||
def test_upper(self):
|
||||
e = col("name").upper()
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "upper(name)"
|
||||
|
||||
def test_contains(self):
|
||||
e = col("text").contains(lit("hello"))
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "contains(text, 'hello')"
|
||||
|
||||
def test_contains_with_str_coerce(self):
|
||||
e = col("text").contains("hello")
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "contains(text, 'hello')"
|
||||
|
||||
def test_chained_lower_eq(self):
|
||||
e = col("name").lower() == lit("alice")
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(lower(name) = 'alice')"
|
||||
|
||||
|
||||
class TestExprCast:
|
||||
def test_cast_string(self):
|
||||
e = col("id").cast("string")
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "CAST(id AS VARCHAR)"
|
||||
|
||||
def test_cast_int32(self):
|
||||
e = col("score").cast("int32")
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "CAST(score AS INTEGER)"
|
||||
|
||||
def test_cast_float64(self):
|
||||
e = col("val").cast("float64")
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "CAST(val AS DOUBLE)"
|
||||
|
||||
def test_cast_pyarrow_type(self):
|
||||
e = col("score").cast(pa.int32())
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "CAST(score AS INTEGER)"
|
||||
|
||||
def test_cast_pyarrow_float64(self):
|
||||
e = col("val").cast(pa.float64())
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "CAST(val AS DOUBLE)"
|
||||
|
||||
def test_cast_pyarrow_string(self):
|
||||
e = col("id").cast(pa.string())
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "CAST(id AS VARCHAR)"
|
||||
|
||||
def test_cast_pyarrow_and_string_equivalent(self):
|
||||
# pa.int32() and "int32" should produce equivalent SQL
|
||||
sql_str = col("x").cast("int32").to_sql()
|
||||
sql_pa = col("x").cast(pa.int32()).to_sql()
|
||||
assert sql_str == sql_pa
|
||||
|
||||
|
||||
class TestExprNamedMethods:
|
||||
def test_eq_method(self):
|
||||
e = col("x").eq(lit(1))
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(x = 1)"
|
||||
|
||||
def test_gt_method(self):
|
||||
e = col("x").gt(lit(0))
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "(x > 0)"
|
||||
|
||||
def test_and_method(self):
|
||||
e = col("x").gt(lit(0)).and_(col("y").lt(lit(10)))
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "((x > 0) AND (y < 10))"
|
||||
|
||||
def test_or_method(self):
|
||||
e = col("x").eq(lit(1)).or_(col("x").eq(lit(2)))
|
||||
assert isinstance(e, Expr)
|
||||
assert e.to_sql() == "((x = 1) OR (x = 2))"
|
||||
|
||||
|
||||
class TestExprRepr:
|
||||
def test_repr(self):
|
||||
e = col("age") > lit(18)
|
||||
assert repr(e) == "Expr((age > 18))"
|
||||
|
||||
def test_to_sql(self):
|
||||
e = col("age") > 18
|
||||
assert e.to_sql() == "(age > 18)"
|
||||
|
||||
def test_unhashable(self):
|
||||
e = col("x")
|
||||
with pytest.raises(TypeError):
|
||||
{e: 1}
|
||||
|
||||
|
||||
# ── integration tests: end-to-end query against a real table ─────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_table(tmp_path):
|
||||
db = lancedb.connect(str(tmp_path))
|
||||
data = pa.table(
|
||||
{
|
||||
"id": [1, 2, 3, 4, 5],
|
||||
"name": ["Alice", "Bob", "Charlie", "alice", "BOB"],
|
||||
"age": [25, 17, 30, 22, 15],
|
||||
"score": [1.5, 2.0, 3.5, 4.0, 0.5],
|
||||
}
|
||||
)
|
||||
return db.create_table("test", data)
|
||||
|
||||
|
||||
class TestExprFilter:
|
||||
def test_simple_gt_filter(self, simple_table):
|
||||
result = simple_table.search().where(col("age") > lit(20)).to_arrow()
|
||||
assert result.num_rows == 3 # ages 25, 30, 22
|
||||
|
||||
def test_compound_and_filter(self, simple_table):
|
||||
result = (
|
||||
simple_table.search()
|
||||
.where((col("age") > lit(18)) & (col("score") > lit(2.0)))
|
||||
.to_arrow()
|
||||
)
|
||||
assert result.num_rows == 2 # (30, 3.5) and (22, 4.0)
|
||||
|
||||
def test_string_equality_filter(self, simple_table):
|
||||
result = simple_table.search().where(col("name") == lit("Bob")).to_arrow()
|
||||
assert result.num_rows == 1
|
||||
|
||||
def test_or_filter(self, simple_table):
|
||||
result = (
|
||||
simple_table.search()
|
||||
.where((col("age") < lit(18)) | (col("age") > lit(28)))
|
||||
.to_arrow()
|
||||
)
|
||||
assert result.num_rows == 3 # ages 17, 30, 15
|
||||
|
||||
def test_coercion_no_lit(self, simple_table):
|
||||
# Python values should be auto-coerced
|
||||
result = simple_table.search().where(col("age") > 20).to_arrow()
|
||||
assert result.num_rows == 3
|
||||
|
||||
def test_string_sql_still_works(self, simple_table):
|
||||
# Backwards compatibility: plain strings still accepted
|
||||
result = simple_table.search().where("age > 20").to_arrow()
|
||||
assert result.num_rows == 3
|
||||
|
||||
|
||||
class TestExprProjection:
|
||||
def test_select_with_expr(self, simple_table):
|
||||
result = (
|
||||
simple_table.search()
|
||||
.select({"double_score": col("score") * lit(2)})
|
||||
.to_arrow()
|
||||
)
|
||||
assert "double_score" in result.schema.names
|
||||
|
||||
def test_select_mixed_str_and_expr(self, simple_table):
|
||||
result = (
|
||||
simple_table.search()
|
||||
.select({"id": "id", "double_score": col("score") * lit(2)})
|
||||
.to_arrow()
|
||||
)
|
||||
assert "id" in result.schema.names
|
||||
assert "double_score" in result.schema.names
|
||||
|
||||
def test_select_list_of_columns(self, simple_table):
|
||||
# Plain list of str still works
|
||||
result = simple_table.search().select(["id", "name"]).to_arrow()
|
||||
assert result.schema.names == ["id", "name"]
|
||||
|
||||
|
||||
# ── column name edge cases ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestColNaming:
|
||||
"""Unit tests verifying that col() preserves identifiers exactly.
|
||||
|
||||
Identifiers that need quoting (camelCase, spaces, leading digits, unicode)
|
||||
are wrapped in backticks to match the lance SQL parser's dialect.
|
||||
"""
|
||||
|
||||
def test_camel_case_preserved_in_sql(self):
|
||||
# camelCase is quoted with backticks so the case round-trips correctly.
|
||||
assert col("firstName").to_sql() == "`firstName`"
|
||||
|
||||
def test_camel_case_in_expression(self):
|
||||
assert (col("firstName") > lit(18)).to_sql() == "(`firstName` > 18)"
|
||||
|
||||
def test_space_in_name_quoted(self):
|
||||
assert col("first name").to_sql() == "`first name`"
|
||||
|
||||
def test_space_in_expression(self):
|
||||
assert (col("first name") == lit("A")).to_sql() == "(`first name` = 'A')"
|
||||
|
||||
def test_leading_digit_quoted(self):
|
||||
assert col("2fast").to_sql() == "`2fast`"
|
||||
|
||||
def test_unicode_quoted(self):
|
||||
assert col("名前").to_sql() == "`名前`"
|
||||
|
||||
def test_snake_case_unquoted(self):
|
||||
# Plain snake_case needs no quoting.
|
||||
assert col("first_name").to_sql() == "first_name"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def special_col_table(tmp_path):
|
||||
db = lancedb.connect(str(tmp_path))
|
||||
data = pa.table(
|
||||
{
|
||||
"firstName": ["Alice", "Bob", "Charlie"],
|
||||
"first name": ["A", "B", "C"],
|
||||
"score": [10, 20, 30],
|
||||
}
|
||||
)
|
||||
return db.create_table("special", data)
|
||||
|
||||
|
||||
class TestColNamingIntegration:
|
||||
def test_camel_case_filter(self, special_col_table):
|
||||
result = (
|
||||
special_col_table.search()
|
||||
.where(col("firstName") == lit("Alice"))
|
||||
.to_arrow()
|
||||
)
|
||||
assert result.num_rows == 1
|
||||
assert result["firstName"][0].as_py() == "Alice"
|
||||
|
||||
def test_space_in_col_filter(self, special_col_table):
|
||||
result = (
|
||||
special_col_table.search().where(col("first name") == lit("B")).to_arrow()
|
||||
)
|
||||
assert result.num_rows == 1
|
||||
|
||||
def test_camel_case_projection(self, special_col_table):
|
||||
result = (
|
||||
special_col_table.search()
|
||||
.select({"upper_name": col("firstName").upper()})
|
||||
.to_arrow()
|
||||
)
|
||||
assert "upper_name" in result.schema.names
|
||||
assert sorted(result["upper_name"].to_pylist()) == ["ALICE", "BOB", "CHARLIE"]
|
||||
Reference in New Issue
Block a user