feat(python): support bytes in lit() expressions (#3387)

Closes #3261.

## Summary

Adds `bytes` to the accepted types of `lancedb.expr.lit()` so that
binary scalars can be used in filter / projection expressions. The
previous attempt in #3235 had to be reverted because DataFusion's SQL
unparser does not support `Binary` / `LargeBinary` scalars, so any
expression containing such a literal would fail in both `to_sql()` and
`__repr__`.

## How

`expr_to_sql_string` now has two paths:

- **Fast path** (no binary literals): delegate to DataFusion's unparser
unchanged.
- **Slow path**: rewrite each `Binary(Some(bytes))` literal in the tree
to a unique string-literal placeholder, run the unparser, then
substitute `'<placeholder>'` with `X'<HEX>'` in the resulting SQL.
`Binary(None)` / `LargeBinary(None)` are rewritten to
`ScalarValue::Null` so the unparser emits plain `NULL`.

This keeps DataFusion as the single source of truth for operator and
function serialization, so binary literals work in every expression node
type the unparser already supports — including nested cases like
`contains(col("data"), lit(b"\xff"))`, `NOT (col == lit(b"..."))`, and
`col.cast(...) == lit(b"...")`.

## Changes

- `rust/lancedb/src/expr/sql.rs`: placeholder-substitution
implementation.
- `rust/lancedb/src/expr.rs`: 4 new unit tests covering binary literals
in equality, compound predicates, scalar function calls, negation, and
`NULL` binary literals.
- `python/src/expr.rs`: `expr_lit` accepts `PyBytes` and produces
`ScalarValue::Binary`.
- `python/Cargo.toml` + `Cargo.lock`: pull in `datafusion-common` for
`ScalarValue`.
- `python/python/lancedb/expr.py`: extend `ExprLike` and `lit()` type
annotations / docstrings with `bytes`.
- `python/python/lancedb/_lancedb.pyi`: update `expr_lit` stub.
- `python/tests/test_expr.py`: unit tests for `to_sql` / `repr` of
binary literals and an integration test against a real `pa.binary()`
column for equality / inequality / compound filters.

## Example

```python
from lancedb.expr import col, lit, func

# Equality against a binary column
col("payload") == lit(b"\xca\xfe")
# Expr((payload = X'CAFE'))

# Nested inside a function call (previously failed)
func("contains", col("data"), lit(b"\xff"))
# Expr(contains(data, X'FF'))

# repr() no longer crashes
repr(lit(b"\xde\xad\xbe\xef"))
# "Expr(X'DEADBEEF')"
```

## Verification

- [x] `cargo test -p lancedb --lib expr::` — 12/12 pass (was 9; +3 new
tests)
- [x] `cargo check --features remote --tests --examples` — clean
- [x] `cargo clippy --features remote --tests --examples` — no warnings
- [x] `cargo fmt --all -- --check` — clean
- [x] `pytest python/tests/test_expr.py` — 76/76 pass (was 74; +2 new
tests)
- [x] `ruff check python` / `ruff format --check python` — clean

## Follow-ups (not in this PR)

Issue #3261 also raises the possibility of a *truncated* `__repr__` for
very large binary literals. This PR keeps `__repr__` exact (it forwards
to `to_sql()`), since truncating display output would diverge from the
SQL that actually gets executed. A display-only truncation could be
added in a follow-up by giving `__repr__` its own renderer.

Made with [Cursor](https://cursor.com)

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Shengan Zhang
2026-05-14 15:24:52 -07:00
committed by GitHub
parent 5b45e44ce3
commit 64aeee84a8
8 changed files with 243 additions and 7 deletions

View File

@@ -19,6 +19,7 @@ arrow = { version = "58.0.0", features = ["pyarrow"] }
async-trait = "0.1"
bytes = "1"
lancedb = { path = "../rust/lancedb", default-features = false }
datafusion-common.workspace = true
lance-core.workspace = true
lance-namespace.workspace = true
lance-namespace-impls.workspace = true

View File

@@ -51,7 +51,7 @@ class PyExpr:
def to_sql(self) -> str: ...
def expr_col(name: str) -> PyExpr: ...
def expr_lit(value: Union[bool, int, float, str]) -> PyExpr: ...
def expr_lit(value: Union[bool, int, float, str, bytes]) -> PyExpr: ...
def expr_func(name: str, args: List[PyExpr]) -> PyExpr: ...
class Session:

View File

@@ -63,7 +63,7 @@ def _coerce(value: "ExprLike") -> "Expr":
# Type alias used in annotations.
ExprLike = Union["Expr", bool, int, float, str]
ExprLike = Union["Expr", bool, int, float, str, bytes]
class Expr:
@@ -261,13 +261,13 @@ def col(name: str) -> Expr:
return Expr(expr_col(name))
def lit(value: Union[bool, int, float, str]) -> Expr:
def lit(value: Union[bool, int, float, str, bytes]) -> Expr:
"""Create a literal (constant) value expression.
Parameters
----------
value:
A Python ``bool``, ``int``, ``float``, or ``str``.
A Python ``bool``, ``int``, ``float``, ``str``, or ``bytes``.
Examples
--------

View File

@@ -8,7 +8,9 @@
//! DataFusion [`Expr`] nodes, bypassing SQL string parsing.
use arrow::{datatypes::DataType, pyarrow::PyArrowType};
use datafusion_common::ScalarValue;
use lancedb::expr::{DfExpr, col as ldb_col, contains, expr_cast, lit as df_lit, lower, upper};
use pyo3::types::PyBytes;
use pyo3::{Bound, PyAny, PyResult, exceptions::PyValueError, prelude::*, pyfunction};
/// A type-safe DataFusion expression.
@@ -141,7 +143,7 @@ pub fn expr_col(name: &str) -> PyExpr {
/// Create a literal value expression.
///
/// Supported Python types: `bool`, `int`, `float`, `str`.
/// Supported Python types: `bool`, `int`, `float`, `str`, `bytes`.
#[pyfunction]
pub fn expr_lit(value: Bound<'_, PyAny>) -> PyResult<PyExpr> {
// bool must be checked before int because bool is a subclass of int in Python
@@ -157,8 +159,12 @@ pub fn expr_lit(value: Bound<'_, PyAny>) -> PyResult<PyExpr> {
if let Ok(s) = value.extract::<String>() {
return Ok(PyExpr(df_lit(s)));
}
if value.is_instance_of::<PyBytes>() {
let bytes = value.extract::<Vec<u8>>()?;
return Ok(PyExpr(df_lit(ScalarValue::Binary(Some(bytes)))));
}
Err(PyValueError::new_err(format!(
"unsupported literal type: {}. Supported: bool, int, float, str",
"unsupported literal type: {}. Supported: bool, int, float, str, bytes",
value.get_type().name()?
)))
}

View File

@@ -33,6 +33,14 @@ class TestExprConstruction:
e = lit(True)
assert isinstance(e, Expr)
def test_lit_bytes(self):
e = lit(b"\xde\xad\xbe\xef")
assert isinstance(e, Expr)
def test_lit_bytes_empty(self):
e = lit(b"")
assert isinstance(e, Expr)
def test_lit_unsupported_type_raises(self):
with pytest.raises(Exception):
lit([1, 2, 3])
@@ -135,6 +143,43 @@ class TestExprOperators:
assert e.to_sql() == "(name = 'alice')"
class TestExprBytesLiteral:
def test_bytes_to_sql(self):
e = lit(b"\xde\xad\xbe\xef")
assert e.to_sql() == "X'DEADBEEF'"
def test_empty_bytes_to_sql(self):
e = lit(b"")
assert e.to_sql() == "X''"
def test_bytes_repr(self):
e = lit(b"\x01\x02")
assert repr(e) == "Expr(X'0102')"
def test_bytes_equality_expr_sql(self):
e = col("data") == lit(b"\xca\xfe")
assert e.to_sql() == "(data = X'CAFE')"
def test_bytes_ne_expr_sql(self):
e = col("data") != lit(b"\xff")
assert e.to_sql() == "(data <> X'FF')"
def test_bytes_compound_expr_sql(self):
e = (col("data") == lit(b"\x01")) & (col("id") > lit(5))
assert e.to_sql() == "((data = X'01') AND (id > 5))"
def test_bytes_in_function_call(self):
# Regression test: binary literals inside scalar function calls
# used to fail because DataFusion's unparser does not support Binary
# scalars. Now handled via a placeholder-substitution rewrite.
e = func("contains", col("data"), lit(b"\xff"))
assert e.to_sql() == "contains(data, X'FF')"
def test_bytes_in_not(self):
e = ~(col("data") == lit(b"\xff"))
assert e.to_sql() == "NOT (data = X'FF')"
class TestExprStringMethods:
def test_lower(self):
e = col("name").lower()
@@ -385,3 +430,44 @@ class TestColNamingIntegration:
)
assert "upper_name" in result.schema.names
assert sorted(result["upper_name"].to_pylist()) == ["ALICE", "BOB", "CHARLIE"]
# ── bytes / binary column integration tests ───────────────────────────────────
@pytest.fixture
def binary_table(tmp_path):
db = lancedb.connect(str(tmp_path))
data = pa.table(
{
"id": [1, 2, 3],
"payload": pa.array(
[b"\x01\x02", b"\xca\xfe", b"\xff\x00"],
type=pa.binary(),
),
}
)
return db.create_table("binary_test", data)
class TestExprBytesIntegration:
def test_binary_equality_filter(self, binary_table):
result = (
binary_table.search().where(col("payload") == lit(b"\xca\xfe")).to_arrow()
)
assert result.num_rows == 1
assert result["id"][0].as_py() == 2
def test_binary_ne_filter(self, binary_table):
result = (
binary_table.search().where(col("payload") != lit(b"\x01\x02")).to_arrow()
)
assert result.num_rows == 2
def test_binary_compound_filter(self, binary_table):
result = (
binary_table.search()
.where((col("payload") == lit(b"\x01\x02")) | (col("id") == lit(3)))
.to_arrow()
)
assert result.num_rows == 2