mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-15 02:50:44 +00:00
feat(python): support bytes in lit() expressions (#3387)
Closes #3261. ## Summary Adds `bytes` to the accepted types of `lancedb.expr.lit()` so that binary scalars can be used in filter / projection expressions. The previous attempt in #3235 had to be reverted because DataFusion's SQL unparser does not support `Binary` / `LargeBinary` scalars, so any expression containing such a literal would fail in both `to_sql()` and `__repr__`. ## How `expr_to_sql_string` now has two paths: - **Fast path** (no binary literals): delegate to DataFusion's unparser unchanged. - **Slow path**: rewrite each `Binary(Some(bytes))` literal in the tree to a unique string-literal placeholder, run the unparser, then substitute `'<placeholder>'` with `X'<HEX>'` in the resulting SQL. `Binary(None)` / `LargeBinary(None)` are rewritten to `ScalarValue::Null` so the unparser emits plain `NULL`. This keeps DataFusion as the single source of truth for operator and function serialization, so binary literals work in every expression node type the unparser already supports — including nested cases like `contains(col("data"), lit(b"\xff"))`, `NOT (col == lit(b"..."))`, and `col.cast(...) == lit(b"...")`. ## Changes - `rust/lancedb/src/expr/sql.rs`: placeholder-substitution implementation. - `rust/lancedb/src/expr.rs`: 4 new unit tests covering binary literals in equality, compound predicates, scalar function calls, negation, and `NULL` binary literals. - `python/src/expr.rs`: `expr_lit` accepts `PyBytes` and produces `ScalarValue::Binary`. - `python/Cargo.toml` + `Cargo.lock`: pull in `datafusion-common` for `ScalarValue`. - `python/python/lancedb/expr.py`: extend `ExprLike` and `lit()` type annotations / docstrings with `bytes`. - `python/python/lancedb/_lancedb.pyi`: update `expr_lit` stub. - `python/tests/test_expr.py`: unit tests for `to_sql` / `repr` of binary literals and an integration test against a real `pa.binary()` column for equality / inequality / compound filters. ## Example ```python from lancedb.expr import col, lit, func # Equality against a binary column col("payload") == lit(b"\xca\xfe") # Expr((payload = X'CAFE')) # Nested inside a function call (previously failed) func("contains", col("data"), lit(b"\xff")) # Expr(contains(data, X'FF')) # repr() no longer crashes repr(lit(b"\xde\xad\xbe\xef")) # "Expr(X'DEADBEEF')" ``` ## Verification - [x] `cargo test -p lancedb --lib expr::` — 12/12 pass (was 9; +3 new tests) - [x] `cargo check --features remote --tests --examples` — clean - [x] `cargo clippy --features remote --tests --examples` — no warnings - [x] `cargo fmt --all -- --check` — clean - [x] `pytest python/tests/test_expr.py` — 76/76 pass (was 74; +2 new tests) - [x] `ruff check python` / `ruff format --check python` — clean ## Follow-ups (not in this PR) Issue #3261 also raises the possibility of a *truncated* `__repr__` for very large binary literals. This PR keeps `__repr__` exact (it forwards to `to_sql()`), since truncating display output would diverge from the SQL that actually gets executed. A display-only truncation could be added in a follow-up by giving `__repr__` its own renderer. Made with [Cursor](https://cursor.com) Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -5115,6 +5115,7 @@ dependencies = [
|
||||
"arrow",
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"datafusion-common",
|
||||
"env_logger",
|
||||
"futures",
|
||||
"lance-core",
|
||||
|
||||
@@ -19,6 +19,7 @@ arrow = { version = "58.0.0", features = ["pyarrow"] }
|
||||
async-trait = "0.1"
|
||||
bytes = "1"
|
||||
lancedb = { path = "../rust/lancedb", default-features = false }
|
||||
datafusion-common.workspace = true
|
||||
lance-core.workspace = true
|
||||
lance-namespace.workspace = true
|
||||
lance-namespace-impls.workspace = true
|
||||
|
||||
@@ -51,7 +51,7 @@ class PyExpr:
|
||||
def to_sql(self) -> str: ...
|
||||
|
||||
def expr_col(name: str) -> PyExpr: ...
|
||||
def expr_lit(value: Union[bool, int, float, str]) -> PyExpr: ...
|
||||
def expr_lit(value: Union[bool, int, float, str, bytes]) -> PyExpr: ...
|
||||
def expr_func(name: str, args: List[PyExpr]) -> PyExpr: ...
|
||||
|
||||
class Session:
|
||||
|
||||
@@ -63,7 +63,7 @@ def _coerce(value: "ExprLike") -> "Expr":
|
||||
|
||||
|
||||
# Type alias used in annotations.
|
||||
ExprLike = Union["Expr", bool, int, float, str]
|
||||
ExprLike = Union["Expr", bool, int, float, str, bytes]
|
||||
|
||||
|
||||
class Expr:
|
||||
@@ -261,13 +261,13 @@ def col(name: str) -> Expr:
|
||||
return Expr(expr_col(name))
|
||||
|
||||
|
||||
def lit(value: Union[bool, int, float, str]) -> Expr:
|
||||
def lit(value: Union[bool, int, float, str, bytes]) -> Expr:
|
||||
"""Create a literal (constant) value expression.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value:
|
||||
A Python ``bool``, ``int``, ``float``, or ``str``.
|
||||
A Python ``bool``, ``int``, ``float``, ``str``, or ``bytes``.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
@@ -8,7 +8,9 @@
|
||||
//! DataFusion [`Expr`] nodes, bypassing SQL string parsing.
|
||||
|
||||
use arrow::{datatypes::DataType, pyarrow::PyArrowType};
|
||||
use datafusion_common::ScalarValue;
|
||||
use lancedb::expr::{DfExpr, col as ldb_col, contains, expr_cast, lit as df_lit, lower, upper};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, PyAny, PyResult, exceptions::PyValueError, prelude::*, pyfunction};
|
||||
|
||||
/// A type-safe DataFusion expression.
|
||||
@@ -141,7 +143,7 @@ pub fn expr_col(name: &str) -> PyExpr {
|
||||
|
||||
/// Create a literal value expression.
|
||||
///
|
||||
/// Supported Python types: `bool`, `int`, `float`, `str`.
|
||||
/// Supported Python types: `bool`, `int`, `float`, `str`, `bytes`.
|
||||
#[pyfunction]
|
||||
pub fn expr_lit(value: Bound<'_, PyAny>) -> PyResult<PyExpr> {
|
||||
// bool must be checked before int because bool is a subclass of int in Python
|
||||
@@ -157,8 +159,12 @@ pub fn expr_lit(value: Bound<'_, PyAny>) -> PyResult<PyExpr> {
|
||||
if let Ok(s) = value.extract::<String>() {
|
||||
return Ok(PyExpr(df_lit(s)));
|
||||
}
|
||||
if value.is_instance_of::<PyBytes>() {
|
||||
let bytes = value.extract::<Vec<u8>>()?;
|
||||
return Ok(PyExpr(df_lit(ScalarValue::Binary(Some(bytes)))));
|
||||
}
|
||||
Err(PyValueError::new_err(format!(
|
||||
"unsupported literal type: {}. Supported: bool, int, float, str",
|
||||
"unsupported literal type: {}. Supported: bool, int, float, str, bytes",
|
||||
value.get_type().name()?
|
||||
)))
|
||||
}
|
||||
|
||||
@@ -33,6 +33,14 @@ class TestExprConstruction:
|
||||
e = lit(True)
|
||||
assert isinstance(e, Expr)
|
||||
|
||||
def test_lit_bytes(self):
|
||||
e = lit(b"\xde\xad\xbe\xef")
|
||||
assert isinstance(e, Expr)
|
||||
|
||||
def test_lit_bytes_empty(self):
|
||||
e = lit(b"")
|
||||
assert isinstance(e, Expr)
|
||||
|
||||
def test_lit_unsupported_type_raises(self):
|
||||
with pytest.raises(Exception):
|
||||
lit([1, 2, 3])
|
||||
@@ -135,6 +143,43 @@ class TestExprOperators:
|
||||
assert e.to_sql() == "(name = 'alice')"
|
||||
|
||||
|
||||
class TestExprBytesLiteral:
|
||||
def test_bytes_to_sql(self):
|
||||
e = lit(b"\xde\xad\xbe\xef")
|
||||
assert e.to_sql() == "X'DEADBEEF'"
|
||||
|
||||
def test_empty_bytes_to_sql(self):
|
||||
e = lit(b"")
|
||||
assert e.to_sql() == "X''"
|
||||
|
||||
def test_bytes_repr(self):
|
||||
e = lit(b"\x01\x02")
|
||||
assert repr(e) == "Expr(X'0102')"
|
||||
|
||||
def test_bytes_equality_expr_sql(self):
|
||||
e = col("data") == lit(b"\xca\xfe")
|
||||
assert e.to_sql() == "(data = X'CAFE')"
|
||||
|
||||
def test_bytes_ne_expr_sql(self):
|
||||
e = col("data") != lit(b"\xff")
|
||||
assert e.to_sql() == "(data <> X'FF')"
|
||||
|
||||
def test_bytes_compound_expr_sql(self):
|
||||
e = (col("data") == lit(b"\x01")) & (col("id") > lit(5))
|
||||
assert e.to_sql() == "((data = X'01') AND (id > 5))"
|
||||
|
||||
def test_bytes_in_function_call(self):
|
||||
# Regression test: binary literals inside scalar function calls
|
||||
# used to fail because DataFusion's unparser does not support Binary
|
||||
# scalars. Now handled via a placeholder-substitution rewrite.
|
||||
e = func("contains", col("data"), lit(b"\xff"))
|
||||
assert e.to_sql() == "contains(data, X'FF')"
|
||||
|
||||
def test_bytes_in_not(self):
|
||||
e = ~(col("data") == lit(b"\xff"))
|
||||
assert e.to_sql() == "NOT (data = X'FF')"
|
||||
|
||||
|
||||
class TestExprStringMethods:
|
||||
def test_lower(self):
|
||||
e = col("name").lower()
|
||||
@@ -385,3 +430,44 @@ class TestColNamingIntegration:
|
||||
)
|
||||
assert "upper_name" in result.schema.names
|
||||
assert sorted(result["upper_name"].to_pylist()) == ["ALICE", "BOB", "CHARLIE"]
|
||||
|
||||
|
||||
# ── bytes / binary column integration tests ───────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def binary_table(tmp_path):
|
||||
db = lancedb.connect(str(tmp_path))
|
||||
data = pa.table(
|
||||
{
|
||||
"id": [1, 2, 3],
|
||||
"payload": pa.array(
|
||||
[b"\x01\x02", b"\xca\xfe", b"\xff\x00"],
|
||||
type=pa.binary(),
|
||||
),
|
||||
}
|
||||
)
|
||||
return db.create_table("binary_test", data)
|
||||
|
||||
|
||||
class TestExprBytesIntegration:
|
||||
def test_binary_equality_filter(self, binary_table):
|
||||
result = (
|
||||
binary_table.search().where(col("payload") == lit(b"\xca\xfe")).to_arrow()
|
||||
)
|
||||
assert result.num_rows == 1
|
||||
assert result["id"][0].as_py() == 2
|
||||
|
||||
def test_binary_ne_filter(self, binary_table):
|
||||
result = (
|
||||
binary_table.search().where(col("payload") != lit(b"\x01\x02")).to_arrow()
|
||||
)
|
||||
assert result.num_rows == 2
|
||||
|
||||
def test_binary_compound_filter(self, binary_table):
|
||||
result = (
|
||||
binary_table.search()
|
||||
.where((col("payload") == lit(b"\x01\x02")) | (col("id") == lit(3)))
|
||||
.to_arrow()
|
||||
)
|
||||
assert result.num_rows == 2
|
||||
|
||||
@@ -138,4 +138,69 @@ mod tests {
|
||||
let sql = expr_to_sql_string(&expr).unwrap();
|
||||
assert!(sql.contains("price"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_literal() {
|
||||
use datafusion_common::ScalarValue;
|
||||
let expr = lit(ScalarValue::Binary(Some(vec![0xde, 0xad, 0xbe, 0xef])));
|
||||
let sql = expr_to_sql_string(&expr).unwrap();
|
||||
assert_eq!(sql, "X'DEADBEEF'");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_literal_in_filter() {
|
||||
use datafusion_common::ScalarValue;
|
||||
let expr = col("data").eq(lit(ScalarValue::Binary(Some(vec![0xca, 0xfe]))));
|
||||
let sql = expr_to_sql_string(&expr).unwrap();
|
||||
assert_eq!(sql, "(data = X'CAFE')");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_literal_compound() {
|
||||
use datafusion_common::ScalarValue;
|
||||
let bin_expr = col("data").eq(lit(ScalarValue::Binary(Some(vec![0x01]))));
|
||||
let int_expr = col("id").gt(lit(5i64));
|
||||
let combined = bin_expr.and(int_expr);
|
||||
let sql = expr_to_sql_string(&combined).unwrap();
|
||||
assert_eq!(sql, "((data = X'01') AND (id > 5))");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_null_binary_literal() {
|
||||
use datafusion_common::ScalarValue;
|
||||
let expr = lit(ScalarValue::Binary(None));
|
||||
let sql = expr_to_sql_string(&expr).unwrap();
|
||||
assert_eq!(sql, "NULL");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_literal_in_function_call() {
|
||||
use datafusion_common::ScalarValue;
|
||||
// Binary literals inside scalar function arguments must also be
|
||||
// serialized correctly (regression test for placeholder rewrite path).
|
||||
let expr = contains(col("data"), lit(ScalarValue::Binary(Some(vec![0xff]))));
|
||||
let sql = expr_to_sql_string(&expr).unwrap();
|
||||
assert_eq!(sql, "contains(data, X'FF')");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binary_literal_in_negation() {
|
||||
use datafusion_common::ScalarValue;
|
||||
use std::ops::Not;
|
||||
let expr = col("data")
|
||||
.eq(lit(ScalarValue::Binary(Some(vec![0xab, 0xcd]))))
|
||||
.not();
|
||||
let sql = expr_to_sql_string(&expr).unwrap();
|
||||
assert_eq!(sql, "NOT (data = X'ABCD')");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_binary_literals() {
|
||||
use datafusion_common::ScalarValue;
|
||||
let lhs = col("a").eq(lit(ScalarValue::Binary(Some(vec![0x01]))));
|
||||
let rhs = col("b").eq(lit(ScalarValue::Binary(Some(vec![0x02, 0x03]))));
|
||||
let expr = lhs.and(rhs);
|
||||
let sql = expr_to_sql_string(&expr).unwrap();
|
||||
assert_eq!(sql, "((a = X'01') AND (b = X'0203'))");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
// SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
||||
|
||||
use datafusion_common::ScalarValue;
|
||||
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
|
||||
use datafusion_expr::Expr;
|
||||
use datafusion_sql::unparser::{self, dialect::Dialect};
|
||||
|
||||
@@ -28,7 +30,36 @@ impl Dialect for LanceSqlDialect {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn expr_to_sql_string(expr: &Expr) -> crate::Result<String> {
|
||||
/// Prefix for placeholder strings inserted in place of binary literals. Chosen
|
||||
/// to be extremely unlikely to occur in user data.
|
||||
const BINARY_PLACEHOLDER_PREFIX: &str = "__lancedb_binary_placeholder_";
|
||||
|
||||
fn bytes_to_hex_sql(bytes: &[u8]) -> String {
|
||||
let hex: String = bytes.iter().map(|b| format!("{b:02X}")).collect();
|
||||
format!("X'{hex}'")
|
||||
}
|
||||
|
||||
/// Returns true if *expr* contains a `Binary` or `LargeBinary` scalar literal
|
||||
/// anywhere in its subtree. DataFusion's SQL unparser cannot serialize those
|
||||
/// variants, so we route such expressions through a placeholder-substitution
|
||||
/// path that emits SQL `X'...'` byte-string literals.
|
||||
fn has_binary_literal(expr: &Expr) -> bool {
|
||||
let mut found = false;
|
||||
let _ = expr.apply(&mut |e: &Expr| {
|
||||
if matches!(
|
||||
e,
|
||||
Expr::Literal(ScalarValue::Binary(_) | ScalarValue::LargeBinary(_), _)
|
||||
) {
|
||||
found = true;
|
||||
Ok(TreeNodeRecursion::Stop)
|
||||
} else {
|
||||
Ok(TreeNodeRecursion::Continue)
|
||||
}
|
||||
});
|
||||
found
|
||||
}
|
||||
|
||||
fn run_unparser(expr: &Expr) -> crate::Result<String> {
|
||||
let ast = unparser::Unparser::new(&LanceSqlDialect)
|
||||
.expr_to_sql(expr)
|
||||
.map_err(|e| crate::Error::InvalidInput {
|
||||
@@ -36,3 +67,49 @@ pub fn expr_to_sql_string(expr: &Expr) -> crate::Result<String> {
|
||||
})?;
|
||||
Ok(ast.to_string())
|
||||
}
|
||||
|
||||
pub fn expr_to_sql_string(expr: &Expr) -> crate::Result<String> {
|
||||
// Fast path: no binary literals — DataFusion's unparser handles everything.
|
||||
if !has_binary_literal(expr) {
|
||||
return run_unparser(expr);
|
||||
}
|
||||
|
||||
// Slow path: DataFusion's unparser cannot serialize `Binary`/`LargeBinary`
|
||||
// scalars, so we rewrite each one to a unique string-literal placeholder,
|
||||
// let the unparser do the rest of the work, then substitute the SQL
|
||||
// `X'...'` byte-string literal back in. This keeps the operator/function
|
||||
// serialization logic centralized in DataFusion and works for every
|
||||
// expression node type the unparser supports.
|
||||
let mut bindings: Vec<Vec<u8>> = Vec::new();
|
||||
let rewritten = expr
|
||||
.clone()
|
||||
.transform(|e: Expr| match e {
|
||||
Expr::Literal(ScalarValue::Binary(Some(bytes)), m)
|
||||
| Expr::Literal(ScalarValue::LargeBinary(Some(bytes)), m) => {
|
||||
let placeholder = format!("{}{}__", BINARY_PLACEHOLDER_PREFIX, bindings.len());
|
||||
bindings.push(bytes);
|
||||
Ok(Transformed::yes(Expr::Literal(
|
||||
ScalarValue::Utf8(Some(placeholder)),
|
||||
m,
|
||||
)))
|
||||
}
|
||||
Expr::Literal(ScalarValue::Binary(None), m)
|
||||
| Expr::Literal(ScalarValue::LargeBinary(None), m) => {
|
||||
Ok(Transformed::yes(Expr::Literal(ScalarValue::Null, m)))
|
||||
}
|
||||
other => Ok(Transformed::no(other)),
|
||||
})
|
||||
.map_err(|e| crate::Error::InvalidInput {
|
||||
message: format!("failed to rewrite expression: {}", e),
|
||||
})?
|
||||
.data;
|
||||
|
||||
let mut sql = run_unparser(&rewritten)?;
|
||||
for (i, bytes) in bindings.iter().enumerate() {
|
||||
// The unparser quotes string literals with single quotes, so the
|
||||
// placeholder appears as `'__lancedb_binary_placeholder_<i>__'`.
|
||||
let quoted = format!("'{}{}__'", BINARY_PLACEHOLDER_PREFIX, i);
|
||||
sql = sql.replace("ed, &bytes_to_hex_sql(bytes));
|
||||
}
|
||||
Ok(sql)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user