diff --git a/Cargo.lock b/Cargo.lock index 4b9d561bc..ba1014d56 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5115,6 +5115,7 @@ dependencies = [ "arrow", "async-trait", "bytes", + "datafusion-common", "env_logger", "futures", "lance-core", diff --git a/python/Cargo.toml b/python/Cargo.toml index fce27e65a..417ed523d 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -19,6 +19,7 @@ arrow = { version = "58.0.0", features = ["pyarrow"] } async-trait = "0.1" bytes = "1" lancedb = { path = "../rust/lancedb", default-features = false } +datafusion-common.workspace = true lance-core.workspace = true lance-namespace.workspace = true lance-namespace-impls.workspace = true diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index 8811723e2..d6a8d71d6 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -51,7 +51,7 @@ class PyExpr: def to_sql(self) -> str: ... def expr_col(name: str) -> PyExpr: ... -def expr_lit(value: Union[bool, int, float, str]) -> PyExpr: ... +def expr_lit(value: Union[bool, int, float, str, bytes]) -> PyExpr: ... def expr_func(name: str, args: List[PyExpr]) -> PyExpr: ... class Session: diff --git a/python/python/lancedb/expr.py b/python/python/lancedb/expr.py index 5a568d66a..08cb91c50 100644 --- a/python/python/lancedb/expr.py +++ b/python/python/lancedb/expr.py @@ -63,7 +63,7 @@ def _coerce(value: "ExprLike") -> "Expr": # Type alias used in annotations. -ExprLike = Union["Expr", bool, int, float, str] +ExprLike = Union["Expr", bool, int, float, str, bytes] class Expr: @@ -261,13 +261,13 @@ def col(name: str) -> Expr: return Expr(expr_col(name)) -def lit(value: Union[bool, int, float, str]) -> Expr: +def lit(value: Union[bool, int, float, str, bytes]) -> Expr: """Create a literal (constant) value expression. Parameters ---------- value: - A Python ``bool``, ``int``, ``float``, or ``str``. + A Python ``bool``, ``int``, ``float``, ``str``, or ``bytes``. 
Examples -------- diff --git a/python/src/expr.rs b/python/src/expr.rs index 7d29fcd05..b322c5bdf 100644 --- a/python/src/expr.rs +++ b/python/src/expr.rs @@ -8,7 +8,9 @@ //! DataFusion [`Expr`] nodes, bypassing SQL string parsing. use arrow::{datatypes::DataType, pyarrow::PyArrowType}; +use datafusion_common::ScalarValue; use lancedb::expr::{DfExpr, col as ldb_col, contains, expr_cast, lit as df_lit, lower, upper}; +use pyo3::types::PyBytes; use pyo3::{Bound, PyAny, PyResult, exceptions::PyValueError, prelude::*, pyfunction}; /// A type-safe DataFusion expression. @@ -141,7 +143,7 @@ pub fn expr_col(name: &str) -> PyExpr { /// Create a literal value expression. /// -/// Supported Python types: `bool`, `int`, `float`, `str`. +/// Supported Python types: `bool`, `int`, `float`, `str`, `bytes`. #[pyfunction] pub fn expr_lit(value: Bound<'_, PyAny>) -> PyResult<PyExpr> { // bool must be checked before int because bool is a subclass of int in Python @@ -157,8 +159,12 @@ pub fn expr_lit(value: Bound<'_, PyAny>) -> PyResult<PyExpr> { if let Ok(s) = value.extract::<String>() { return Ok(PyExpr(df_lit(s))); } + if value.is_instance_of::<PyBytes>() { + let bytes = value.extract::<Vec<u8>>()?; + return Ok(PyExpr(df_lit(ScalarValue::Binary(Some(bytes))))); + } Err(PyValueError::new_err(format!( - "unsupported literal type: {}. Supported: bool, int, float, str", + "unsupported literal type: {}. Supported: bool, int, float, str, bytes", value.get_type().name()?
))) } diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 339aca323..c4f68b1e2 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -33,6 +33,14 @@ class TestExprConstruction: e = lit(True) assert isinstance(e, Expr) + def test_lit_bytes(self): + e = lit(b"\xde\xad\xbe\xef") + assert isinstance(e, Expr) + + def test_lit_bytes_empty(self): + e = lit(b"") + assert isinstance(e, Expr) + def test_lit_unsupported_type_raises(self): with pytest.raises(Exception): lit([1, 2, 3]) @@ -135,6 +143,43 @@ class TestExprOperators: assert e.to_sql() == "(name = 'alice')" +class TestExprBytesLiteral: + def test_bytes_to_sql(self): + e = lit(b"\xde\xad\xbe\xef") + assert e.to_sql() == "X'DEADBEEF'" + + def test_empty_bytes_to_sql(self): + e = lit(b"") + assert e.to_sql() == "X''" + + def test_bytes_repr(self): + e = lit(b"\x01\x02") + assert repr(e) == "Expr(X'0102')" + + def test_bytes_equality_expr_sql(self): + e = col("data") == lit(b"\xca\xfe") + assert e.to_sql() == "(data = X'CAFE')" + + def test_bytes_ne_expr_sql(self): + e = col("data") != lit(b"\xff") + assert e.to_sql() == "(data <> X'FF')" + + def test_bytes_compound_expr_sql(self): + e = (col("data") == lit(b"\x01")) & (col("id") > lit(5)) + assert e.to_sql() == "((data = X'01') AND (id > 5))" + + def test_bytes_in_function_call(self): + # Regression test: binary literals inside scalar function calls + # used to fail because DataFusion's unparser does not support Binary + # scalars. Now handled via a placeholder-substitution rewrite. 
+ e = func("contains", col("data"), lit(b"\xff")) + assert e.to_sql() == "contains(data, X'FF')" + + def test_bytes_in_not(self): + e = ~(col("data") == lit(b"\xff")) + assert e.to_sql() == "NOT (data = X'FF')" + + class TestExprStringMethods: def test_lower(self): e = col("name").lower() @@ -385,3 +430,44 @@ class TestColNamingIntegration: ) assert "upper_name" in result.schema.names assert sorted(result["upper_name"].to_pylist()) == ["ALICE", "BOB", "CHARLIE"] + + +# ── bytes / binary column integration tests ─────────────────────────────────── + + +@pytest.fixture +def binary_table(tmp_path): + db = lancedb.connect(str(tmp_path)) + data = pa.table( + { + "id": [1, 2, 3], + "payload": pa.array( + [b"\x01\x02", b"\xca\xfe", b"\xff\x00"], + type=pa.binary(), + ), + } + ) + return db.create_table("binary_test", data) + + +class TestExprBytesIntegration: + def test_binary_equality_filter(self, binary_table): + result = ( + binary_table.search().where(col("payload") == lit(b"\xca\xfe")).to_arrow() + ) + assert result.num_rows == 1 + assert result["id"][0].as_py() == 2 + + def test_binary_ne_filter(self, binary_table): + result = ( + binary_table.search().where(col("payload") != lit(b"\x01\x02")).to_arrow() + ) + assert result.num_rows == 2 + + def test_binary_compound_filter(self, binary_table): + result = ( + binary_table.search() + .where((col("payload") == lit(b"\x01\x02")) | (col("id") == lit(3))) + .to_arrow() + ) + assert result.num_rows == 2 diff --git a/rust/lancedb/src/expr.rs b/rust/lancedb/src/expr.rs index 4d700497f..02b6b7d08 100644 --- a/rust/lancedb/src/expr.rs +++ b/rust/lancedb/src/expr.rs @@ -138,4 +138,69 @@ mod tests { let sql = expr_to_sql_string(&expr).unwrap(); assert!(sql.contains("price")); } + + #[test] + fn test_binary_literal() { + use datafusion_common::ScalarValue; + let expr = lit(ScalarValue::Binary(Some(vec![0xde, 0xad, 0xbe, 0xef]))); + let sql = expr_to_sql_string(&expr).unwrap(); + assert_eq!(sql, "X'DEADBEEF'"); + } + + #[test] + 
fn test_binary_literal_in_filter() { + use datafusion_common::ScalarValue; + let expr = col("data").eq(lit(ScalarValue::Binary(Some(vec![0xca, 0xfe])))); + let sql = expr_to_sql_string(&expr).unwrap(); + assert_eq!(sql, "(data = X'CAFE')"); + } + + #[test] + fn test_binary_literal_compound() { + use datafusion_common::ScalarValue; + let bin_expr = col("data").eq(lit(ScalarValue::Binary(Some(vec![0x01])))); + let int_expr = col("id").gt(lit(5i64)); + let combined = bin_expr.and(int_expr); + let sql = expr_to_sql_string(&combined).unwrap(); + assert_eq!(sql, "((data = X'01') AND (id > 5))"); + } + + #[test] + fn test_null_binary_literal() { + use datafusion_common::ScalarValue; + let expr = lit(ScalarValue::Binary(None)); + let sql = expr_to_sql_string(&expr).unwrap(); + assert_eq!(sql, "NULL"); + } + + #[test] + fn test_binary_literal_in_function_call() { + use datafusion_common::ScalarValue; + // Binary literals inside scalar function arguments must also be + // serialized correctly (regression test for placeholder rewrite path). 
+ let expr = contains(col("data"), lit(ScalarValue::Binary(Some(vec![0xff])))); + let sql = expr_to_sql_string(&expr).unwrap(); + assert_eq!(sql, "contains(data, X'FF')"); + } + + #[test] + fn test_binary_literal_in_negation() { + use datafusion_common::ScalarValue; + use std::ops::Not; + let expr = col("data") + .eq(lit(ScalarValue::Binary(Some(vec![0xab, 0xcd])))) + .not(); + let sql = expr_to_sql_string(&expr).unwrap(); + assert_eq!(sql, "NOT (data = X'ABCD')"); + } + + #[test] + fn test_multiple_binary_literals() { + use datafusion_common::ScalarValue; + let lhs = col("a").eq(lit(ScalarValue::Binary(Some(vec![0x01])))); + let rhs = col("b").eq(lit(ScalarValue::Binary(Some(vec![0x02, 0x03])))); + let expr = lhs.and(rhs); + let sql = expr_to_sql_string(&expr).unwrap(); + assert_eq!(sql, "((a = X'01') AND (b = X'0203'))"); + } } diff --git a/rust/lancedb/src/expr/sql.rs b/rust/lancedb/src/expr/sql.rs index f9cd81b50..23b89821a 100644 --- a/rust/lancedb/src/expr/sql.rs +++ b/rust/lancedb/src/expr/sql.rs @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The LanceDB Authors +use datafusion_common::ScalarValue; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_expr::Expr; use datafusion_sql::unparser::{self, dialect::Dialect}; @@ -28,7 +30,36 @@ impl Dialect for LanceSqlDialect { } } -pub fn expr_to_sql_string(expr: &Expr) -> crate::Result<String> { +/// Prefix for placeholder strings inserted in place of binary literals. Chosen +/// to be extremely unlikely to occur in user data. +const BINARY_PLACEHOLDER_PREFIX: &str = "__lancedb_binary_placeholder_"; + +fn bytes_to_hex_sql(bytes: &[u8]) -> String { + let hex: String = bytes.iter().map(|b| format!("{b:02X}")).collect(); + format!("X'{hex}'") +} + +/// Returns true if *expr* contains a `Binary` or `LargeBinary` scalar literal +/// anywhere in its subtree.
DataFusion's SQL unparser cannot serialize those +/// variants, so we route such expressions through a placeholder-substitution +/// path that emits SQL `X'...'` byte-string literals. +fn has_binary_literal(expr: &Expr) -> bool { + let mut found = false; + let _ = expr.apply(&mut |e: &Expr| { + if matches!( + e, + Expr::Literal(ScalarValue::Binary(_) | ScalarValue::LargeBinary(_), _) + ) { + found = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + }); + found +} + +fn run_unparser(expr: &Expr) -> crate::Result<String> { let ast = unparser::Unparser::new(&LanceSqlDialect) .expr_to_sql(expr) .map_err(|e| crate::Error::InvalidInput { @@ -36,3 +67,49 @@ pub fn expr_to_sql_string(expr: &Expr) -> crate::Result<String> { })?; Ok(ast.to_string()) } + +pub fn expr_to_sql_string(expr: &Expr) -> crate::Result<String> { + // Fast path: no binary literals — DataFusion's unparser handles everything. + if !has_binary_literal(expr) { + return run_unparser(expr); + } + + // Slow path: DataFusion's unparser cannot serialize `Binary`/`LargeBinary` + // scalars, so we rewrite each one to a unique string-literal placeholder, + // let the unparser do the rest of the work, then substitute the SQL + // `X'...'` byte-string literal back in. This keeps the operator/function + // serialization logic centralized in DataFusion and works for every + // expression node type the unparser supports.
+ let mut bindings: Vec<Vec<u8>> = Vec::new(); + let rewritten = expr + .clone() + .transform(|e: Expr| match e { + Expr::Literal(ScalarValue::Binary(Some(bytes)), m) + | Expr::Literal(ScalarValue::LargeBinary(Some(bytes)), m) => { + let placeholder = format!("{}{}__", BINARY_PLACEHOLDER_PREFIX, bindings.len()); + bindings.push(bytes); + Ok(Transformed::yes(Expr::Literal( + ScalarValue::Utf8(Some(placeholder)), + m, + ))) + } + Expr::Literal(ScalarValue::Binary(None), m) + | Expr::Literal(ScalarValue::LargeBinary(None), m) => { + Ok(Transformed::yes(Expr::Literal(ScalarValue::Null, m))) + } + other => Ok(Transformed::no(other)), + }) + .map_err(|e| crate::Error::InvalidInput { + message: format!("failed to rewrite expression: {}", e), + })? + .data; + + let mut sql = run_unparser(&rewritten)?; + for (i, bytes) in bindings.iter().enumerate() { + // The unparser quotes string literals with single quotes, so the + // placeholder appears as `'__lancedb_binary_placeholder_<i>__'`. + let quoted = format!("'{}{}__'", BINARY_PLACEHOLDER_PREFIX, i); + sql = sql.replace(&quoted, &bytes_to_hex_sql(bytes)); + } + Ok(sql) +}