From 64aeee84a863caf55aec5bec0dff69d9bf820a89 Mon Sep 17 00:00:00 2001 From: Shengan Zhang Date: Thu, 14 May 2026 15:24:52 -0700 Subject: [PATCH] feat(python): support `bytes` in `lit()` expressions (#3387) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #3261. ## Summary Adds `bytes` to the accepted types of `lancedb.expr.lit()` so that binary scalars can be used in filter / projection expressions. The previous attempt in #3235 had to be reverted because DataFusion's SQL unparser does not support `Binary` / `LargeBinary` scalars, so any expression containing such a literal would fail in both `to_sql()` and `__repr__`. ## How `expr_to_sql_string` now has two paths: - **Fast path** (no binary literals): delegate to DataFusion's unparser unchanged. - **Slow path**: rewrite each `Binary(Some(bytes))` literal in the tree to a unique string-literal placeholder, run the unparser, then substitute `''` with `X''` in the resulting SQL. `Binary(None)` / `LargeBinary(None)` are rewritten to `ScalarValue::Null` so the unparser emits plain `NULL`. This keeps DataFusion as the single source of truth for operator and function serialization, so binary literals work in every expression node type the unparser already supports — including nested cases like `contains(col("data"), lit(b"\xff"))`, `NOT (col == lit(b"..."))`, and `col.cast(...) == lit(b"...")`. ## Changes - `rust/lancedb/src/expr/sql.rs`: placeholder-substitution implementation. - `rust/lancedb/src/expr.rs`: 4 new unit tests covering binary literals in equality, compound predicates, scalar function calls, negation, and `NULL` binary literals. - `python/src/expr.rs`: `expr_lit` accepts `PyBytes` and produces `ScalarValue::Binary`. - `python/Cargo.toml` + `Cargo.lock`: pull in `datafusion-common` for `ScalarValue`. - `python/python/lancedb/expr.py`: extend `ExprLike` and `lit()` type annotations / docstrings with `bytes`. - `python/python/lancedb/_lancedb.pyi`: update `expr_lit` stub. - `python/tests/test_expr.py`: unit tests for `to_sql` / `repr` of binary literals and an integration test against a real `pa.binary()` column for equality / inequality / compound filters. ## Example ```python from lancedb.expr import col, lit, func # Equality against a binary column col("payload") == lit(b"\xca\xfe") # Expr((payload = X'CAFE')) # Nested inside a function call (previously failed) func("contains", col("data"), lit(b"\xff")) # Expr(contains(data, X'FF')) # repr() no longer crashes repr(lit(b"\xde\xad\xbe\xef")) # "Expr(X'DEADBEEF')" ``` ## Verification - [x] `cargo test -p lancedb --lib expr::` — 12/12 pass (was 9; +3 new tests) - [x] `cargo check --features remote --tests --examples` — clean - [x] `cargo clippy --features remote --tests --examples` — no warnings - [x] `cargo fmt --all -- --check` — clean - [x] `pytest python/tests/test_expr.py` — 76/76 pass (was 74; +2 new tests) - [x] `ruff check python` / `ruff format --check python` — clean ## Follow-ups (not in this PR) Issue #3261 also raises the possibility of a *truncated* `__repr__` for very large binary literals. This PR keeps `__repr__` exact (it forwards to `to_sql()`), since truncating display output would diverge from the SQL that actually gets executed. A display-only truncation could be added in a follow-up by giving `__repr__` its own renderer. Made with [Cursor](https://cursor.com) Co-authored-by: Cursor --- Cargo.lock | 1 + python/Cargo.toml | 1 + python/python/lancedb/_lancedb.pyi | 2 +- python/python/lancedb/expr.py | 6 +-- python/src/expr.rs | 10 +++- python/tests/test_expr.py | 86 ++++++++++++++++++++++++++++++ rust/lancedb/src/expr.rs | 65 ++++++++++++++++++++++ rust/lancedb/src/expr/sql.rs | 79 ++++++++++++++++++++++++++- 8 files changed, 243 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4b9d561bc..ba1014d56 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5115,6 +5115,7 @@ dependencies = [ "arrow", "async-trait", "bytes", + "datafusion-common", "env_logger", "futures", "lance-core", diff --git a/python/Cargo.toml b/python/Cargo.toml index fce27e65a..417ed523d 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -19,6 +19,7 @@ arrow = { version = "58.0.0", features = ["pyarrow"] } async-trait = "0.1" bytes = "1" lancedb = { path = "../rust/lancedb", default-features = false } +datafusion-common.workspace = true lance-core.workspace = true lance-namespace.workspace = true lance-namespace-impls.workspace = true diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 8811723e2..d6a8d71d6 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -51,7 +51,7 @@ class PyExpr: def to_sql(self) -> str: ... def expr_col(name: str) -> PyExpr: ... -def expr_lit(value: Union[bool, int, float, str]) -> PyExpr: ... +def expr_lit(value: Union[bool, int, float, str, bytes]) -> PyExpr: ... def expr_func(name: str, args: List[PyExpr]) -> PyExpr: ... class Session: diff --git a/python/python/lancedb/expr.py b/python/python/lancedb/expr.py index 5a568d66a..08cb91c50 100644 --- a/python/python/lancedb/expr.py +++ b/python/python/lancedb/expr.py @@ -63,7 +63,7 @@ def _coerce(value: "ExprLike") -> "Expr": # Type alias used in annotations. -ExprLike = Union["Expr", bool, int, float, str] +ExprLike = Union["Expr", bool, int, float, str, bytes] class Expr: @@ -261,13 +261,13 @@ def col(name: str) -> Expr: return Expr(expr_col(name)) -def lit(value: Union[bool, int, float, str]) -> Expr: +def lit(value: Union[bool, int, float, str, bytes]) -> Expr: """Create a literal (constant) value expression. Parameters ---------- value: - A Python ``bool``, ``int``, ``float``, or ``str``. + A Python ``bool``, ``int``, ``float``, ``str``, or ``bytes``. Examples -------- diff --git a/python/src/expr.rs b/python/src/expr.rs index 7d29fcd05..b322c5bdf 100644 --- a/python/src/expr.rs +++ b/python/src/expr.rs @@ -8,7 +8,9 @@ //! DataFusion [`Expr`] nodes, bypassing SQL string parsing. use arrow::{datatypes::DataType, pyarrow::PyArrowType}; +use datafusion_common::ScalarValue; use lancedb::expr::{DfExpr, col as ldb_col, contains, expr_cast, lit as df_lit, lower, upper}; +use pyo3::types::PyBytes; use pyo3::{Bound, PyAny, PyResult, exceptions::PyValueError, prelude::*, pyfunction}; /// A type-safe DataFusion expression. @@ -141,7 +143,7 @@ pub fn expr_col(name: &str) -> PyExpr { /// Create a literal value expression. /// -/// Supported Python types: `bool`, `int`, `float`, `str`. +/// Supported Python types: `bool`, `int`, `float`, `str`, `bytes`. #[pyfunction] pub fn expr_lit(value: Bound<'_, PyAny>) -> PyResult { // bool must be checked before int because bool is a subclass of int in Python @@ -157,8 +159,12 @@ pub fn expr_lit(value: Bound<'_, PyAny>) -> PyResult { if let Ok(s) = value.extract::() { return Ok(PyExpr(df_lit(s))); } + if value.is_instance_of::() { + let bytes = value.extract::>()?; + return Ok(PyExpr(df_lit(ScalarValue::Binary(Some(bytes))))); + } Err(PyValueError::new_err(format!( - "unsupported literal type: {}. Supported: bool, int, float, str", + "unsupported literal type: {}. Supported: bool, int, float, str, bytes", value.get_type().name()? ))) } diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 339aca323..c4f68b1e2 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -33,6 +33,14 @@ class TestExprConstruction: e = lit(True) assert isinstance(e, Expr) + def test_lit_bytes(self): + e = lit(b"\xde\xad\xbe\xef") + assert isinstance(e, Expr) + + def test_lit_bytes_empty(self): + e = lit(b"") + assert isinstance(e, Expr) + def test_lit_unsupported_type_raises(self): with pytest.raises(Exception): lit([1, 2, 3]) @@ -135,6 +143,43 @@ class TestExprOperators: assert e.to_sql() == "(name = 'alice')" +class TestExprBytesLiteral: + def test_bytes_to_sql(self): + e = lit(b"\xde\xad\xbe\xef") + assert e.to_sql() == "X'DEADBEEF'" + + def test_empty_bytes_to_sql(self): + e = lit(b"") + assert e.to_sql() == "X''" + + def test_bytes_repr(self): + e = lit(b"\x01\x02") + assert repr(e) == "Expr(X'0102')" + + def test_bytes_equality_expr_sql(self): + e = col("data") == lit(b"\xca\xfe") + assert e.to_sql() == "(data = X'CAFE')" + + def test_bytes_ne_expr_sql(self): + e = col("data") != lit(b"\xff") + assert e.to_sql() == "(data <> X'FF')" + + def test_bytes_compound_expr_sql(self): + e = (col("data") == lit(b"\x01")) & (col("id") > lit(5)) + assert e.to_sql() == "((data = X'01') AND (id > 5))" + + def test_bytes_in_function_call(self): + # Regression test: binary literals inside scalar function calls + # used to fail because DataFusion's unparser does not support Binary + # scalars. Now handled via a placeholder-substitution rewrite. + e = func("contains", col("data"), lit(b"\xff")) + assert e.to_sql() == "contains(data, X'FF')" + + def test_bytes_in_not(self): + e = ~(col("data") == lit(b"\xff")) + assert e.to_sql() == "NOT (data = X'FF')" + + class TestExprStringMethods: def test_lower(self): e = col("name").lower() @@ -385,3 +430,44 @@ class TestColNamingIntegration: ) assert "upper_name" in result.schema.names assert sorted(result["upper_name"].to_pylist()) == ["ALICE", "BOB", "CHARLIE"] + + +# ── bytes / binary column integration tests ─────────────────────────────────── + + +@pytest.fixture +def binary_table(tmp_path): + db = lancedb.connect(str(tmp_path)) + data = pa.table( + { + "id": [1, 2, 3], + "payload": pa.array( + [b"\x01\x02", b"\xca\xfe", b"\xff\x00"], + type=pa.binary(), + ), + } + ) + return db.create_table("binary_test", data) + + +class TestExprBytesIntegration: + def test_binary_equality_filter(self, binary_table): + result = ( + binary_table.search().where(col("payload") == lit(b"\xca\xfe")).to_arrow() + ) + assert result.num_rows == 1 + assert result["id"][0].as_py() == 2 + + def test_binary_ne_filter(self, binary_table): + result = ( + binary_table.search().where(col("payload") != lit(b"\x01\x02")).to_arrow() + ) + assert result.num_rows == 2 + + def test_binary_compound_filter(self, binary_table): + result = ( + binary_table.search() + .where((col("payload") == lit(b"\x01\x02")) | (col("id") == lit(3))) + .to_arrow() + ) + assert result.num_rows == 2 diff --git a/rust/lancedb/src/expr.rs b/rust/lancedb/src/expr.rs index 4d700497f..02b6b7d08 100644 --- a/rust/lancedb/src/expr.rs +++ b/rust/lancedb/src/expr.rs @@ -138,4 +138,69 @@ mod tests { let sql = expr_to_sql_string(&expr).unwrap(); assert!(sql.contains("price")); } + + #[test] + fn test_binary_literal() { + use datafusion_common::ScalarValue; + let expr = lit(ScalarValue::Binary(Some(vec![0xde, 0xad, 0xbe, 0xef]))); + let sql = expr_to_sql_string(&expr).unwrap(); + assert_eq!(sql, "X'DEADBEEF'"); + } + + #[test] + fn test_binary_literal_in_filter() { + use datafusion_common::ScalarValue; + let expr = col("data").eq(lit(ScalarValue::Binary(Some(vec![0xca, 0xfe])))); + let sql = expr_to_sql_string(&expr).unwrap(); + assert_eq!(sql, "(data = X'CAFE')"); + } + + #[test] + fn test_binary_literal_compound() { + use datafusion_common::ScalarValue; + let bin_expr = col("data").eq(lit(ScalarValue::Binary(Some(vec![0x01])))); + let int_expr = col("id").gt(lit(5i64)); + let combined = bin_expr.and(int_expr); + let sql = expr_to_sql_string(&combined).unwrap(); + assert_eq!(sql, "((data = X'01') AND (id > 5))"); + } + + #[test] + fn test_null_binary_literal() { + use datafusion_common::ScalarValue; + let expr = lit(ScalarValue::Binary(None)); + let sql = expr_to_sql_string(&expr).unwrap(); + assert_eq!(sql, "NULL"); + } + + #[test] + fn test_binary_literal_in_function_call() { + use datafusion_common::ScalarValue; + // Binary literals inside scalar function arguments must also be + // serialized correctly (regression test for placeholder rewrite path). + let expr = contains(col("data"), lit(ScalarValue::Binary(Some(vec![0xff])))); + let sql = expr_to_sql_string(&expr).unwrap(); + assert_eq!(sql, "contains(data, X'FF')"); + } + + #[test] + fn test_binary_literal_in_negation() { + use datafusion_common::ScalarValue; + use std::ops::Not; + let expr = col("data") + .eq(lit(ScalarValue::Binary(Some(vec![0xab, 0xcd])))) + .not(); + let sql = expr_to_sql_string(&expr).unwrap(); + assert_eq!(sql, "NOT (data = X'ABCD')"); + } + + #[test] + fn test_multiple_binary_literals() { + use datafusion_common::ScalarValue; + let lhs = col("a").eq(lit(ScalarValue::Binary(Some(vec![0x01])))); + let rhs = col("b").eq(lit(ScalarValue::Binary(Some(vec![0x02, 0x03])))); + let expr = lhs.and(rhs); + let sql = expr_to_sql_string(&expr).unwrap(); + assert_eq!(sql, "((a = X'01') AND (b = X'0203'))"); + } } diff --git a/rust/lancedb/src/expr/sql.rs b/rust/lancedb/src/expr/sql.rs index f9cd81b50..23b89821a 100644 --- a/rust/lancedb/src/expr/sql.rs +++ b/rust/lancedb/src/expr/sql.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors +use datafusion_common::ScalarValue; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_expr::Expr; use datafusion_sql::unparser::{self, dialect::Dialect}; @@ -28,7 +30,36 @@ impl Dialect for LanceSqlDialect { } } -pub fn expr_to_sql_string(expr: &Expr) -> crate::Result { +/// Prefix for placeholder strings inserted in place of binary literals. Chosen +/// to be extremely unlikely to occur in user data. +const BINARY_PLACEHOLDER_PREFIX: &str = "__lancedb_binary_placeholder_"; + +fn bytes_to_hex_sql(bytes: &[u8]) -> String { + let hex: String = bytes.iter().map(|b| format!("{b:02X}")).collect(); + format!("X'{hex}'") +} + +/// Returns true if *expr* contains a `Binary` or `LargeBinary` scalar literal +/// anywhere in its subtree. DataFusion's SQL unparser cannot serialize those +/// variants, so we route such expressions through a placeholder-substitution +/// path that emits SQL `X'...'` byte-string literals. +fn has_binary_literal(expr: &Expr) -> bool { + let mut found = false; + let _ = expr.apply(&mut |e: &Expr| { + if matches!( + e, + Expr::Literal(ScalarValue::Binary(_) | ScalarValue::LargeBinary(_), _) + ) { + found = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }); + found +} + +fn run_unparser(expr: &Expr) -> crate::Result { let ast = unparser::Unparser::new(&LanceSqlDialect) .expr_to_sql(expr) .map_err(|e| crate::Error::InvalidInput { @@ -36,3 +67,49 @@ pub fn expr_to_sql_string(expr: &Expr) -> crate::Result { })?; Ok(ast.to_string()) } + +pub fn expr_to_sql_string(expr: &Expr) -> crate::Result { + // Fast path: no binary literals — DataFusion's unparser handles everything. + if !has_binary_literal(expr) { + return run_unparser(expr); + } + + // Slow path: DataFusion's unparser cannot serialize `Binary`/`LargeBinary` + // scalars, so we rewrite each one to a unique string-literal placeholder, + // let the unparser do the rest of the work, then substitute the SQL + // `X'...'` byte-string literal back in. This keeps the operator/function + // serialization logic centralized in DataFusion and works for every + // expression node type the unparser supports. + let mut bindings: Vec> = Vec::new(); + let rewritten = expr + .clone() + .transform(|e: Expr| match e { + Expr::Literal(ScalarValue::Binary(Some(bytes)), m) + | Expr::Literal(ScalarValue::LargeBinary(Some(bytes)), m) => { + let placeholder = format!("{}{}__", BINARY_PLACEHOLDER_PREFIX, bindings.len()); + bindings.push(bytes); + Ok(Transformed::yes(Expr::Literal( + ScalarValue::Utf8(Some(placeholder)), + m, + ))) + } + Expr::Literal(ScalarValue::Binary(None), m) + | Expr::Literal(ScalarValue::LargeBinary(None), m) => { + Ok(Transformed::yes(Expr::Literal(ScalarValue::Null, m))) + } + other => Ok(Transformed::no(other)), + }) + .map_err(|e| crate::Error::InvalidInput { + message: format!("failed to rewrite expression: {}", e), + })? + .data; + + let mut sql = run_unparser(&rewritten)?; + for (i, bytes) in bindings.iter().enumerate() { + // The unparser quotes string literals with single quotes, so the + // placeholder appears as `'__lancedb_binary_placeholder___'`. + let quoted = format!("'{}{}__'", BINARY_PLACEHOLDER_PREFIX, i); + sql = sql.replace("ed, &bytes_to_hex_sql(bytes)); + } + Ok(sql) +}