mirror of
https://github.com/lancedb/lancedb.git
synced 2026-05-16 19:40:40 +00:00
Closes #3261. ## Summary Adds `bytes` to the accepted types of `lancedb.expr.lit()` so that binary scalars can be used in filter / projection expressions. The previous attempt in #3235 had to be reverted because DataFusion's SQL unparser does not support `Binary` / `LargeBinary` scalars, so any expression containing such a literal would fail in both `to_sql()` and `__repr__`. ## How `expr_to_sql_string` now has two paths: - **Fast path** (no binary literals): delegate to DataFusion's unparser unchanged. - **Slow path**: rewrite each `Binary(Some(bytes))` literal in the tree to a unique string-literal placeholder, run the unparser, then substitute `'<placeholder>'` with `X'<HEX>'` in the resulting SQL. `Binary(None)` / `LargeBinary(None)` are rewritten to `ScalarValue::Null` so the unparser emits plain `NULL`. This keeps DataFusion as the single source of truth for operator and function serialization, so binary literals work in every expression node type the unparser already supports — including nested cases like `contains(col("data"), lit(b"\xff"))`, `NOT (col == lit(b"..."))`, and `col.cast(...) == lit(b"...")`. ## Changes - `rust/lancedb/src/expr/sql.rs`: placeholder-substitution implementation. - `rust/lancedb/src/expr.rs`: 4 new unit tests covering binary literals in equality, compound predicates, scalar function calls, negation, and `NULL` binary literals. - `python/src/expr.rs`: `expr_lit` accepts `PyBytes` and produces `ScalarValue::Binary`. - `python/Cargo.toml` + `Cargo.lock`: pull in `datafusion-common` for `ScalarValue`. - `python/python/lancedb/expr.py`: extend `ExprLike` and `lit()` type annotations / docstrings with `bytes`. - `python/python/lancedb/_lancedb.pyi`: update `expr_lit` stub. - `python/tests/test_expr.py`: unit tests for `to_sql` / `repr` of binary literals and an integration test against a real `pa.binary()` column for equality / inequality / compound filters. 
## Example ```python from lancedb.expr import col, lit, func # Equality against a binary column col("payload") == lit(b"\xca\xfe") # Expr((payload = X'CAFE')) # Nested inside a function call (previously failed) func("contains", col("data"), lit(b"\xff")) # Expr(contains(data, X'FF')) # repr() no longer crashes repr(lit(b"\xde\xad\xbe\xef")) # "Expr(X'DEADBEEF')" ``` ## Verification - [x] `cargo test -p lancedb --lib expr::` — 12/12 pass (was 9; +3 new tests) - [x] `cargo check --features remote --tests --examples` — clean - [x] `cargo clippy --features remote --tests --examples` — no warnings - [x] `cargo fmt --all -- --check` — clean - [x] `pytest python/tests/test_expr.py` — 76/76 pass (was 74; +2 new tests) - [x] `ruff check python` / `ruff format --check python` — clean ## Follow-ups (not in this PR) Issue #3261 also raises the possibility of a *truncated* `__repr__` for very large binary literals. This PR keeps `__repr__` exact (it forwards to `to_sql()`), since truncating display output would diverge from the SQL that actually gets executed. A display-only truncation could be added in a follow-up by giving `__repr__` its own renderer. Made with [Cursor](https://cursor.com) Co-authored-by: Cursor <cursoragent@cursor.com>
474 lines
14 KiB
Python
474 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
|
|
|
|
"""Tests for the type-safe expression builder API."""
|
|
|
|
import pytest
|
|
import pyarrow as pa
|
|
import lancedb
|
|
from lancedb.expr import Expr, col, lit, func
|
|
|
|
|
|
# ── unit tests for Expr construction ─────────────────────────────────────────
|
|
|
|
|
|
class TestExprConstruction:
    """Building ``Expr`` objects through ``col``, ``lit`` and ``func``."""

    def test_col_returns_expr(self):
        column_ref = col("age")
        assert isinstance(column_ref, Expr)

    def test_lit_int(self):
        literal = lit(42)
        assert isinstance(literal, Expr)

    def test_lit_float(self):
        literal = lit(3.14)
        assert isinstance(literal, Expr)

    def test_lit_str(self):
        literal = lit("hello")
        assert isinstance(literal, Expr)

    def test_lit_bool(self):
        literal = lit(True)
        assert isinstance(literal, Expr)

    def test_lit_bytes(self):
        literal = lit(b"\xde\xad\xbe\xef")
        assert isinstance(literal, Expr)

    def test_lit_bytes_empty(self):
        # An empty byte string is still a valid binary literal.
        literal = lit(b"")
        assert isinstance(literal, Expr)

    def test_lit_unsupported_type_raises(self):
        # A Python list is not an accepted literal type and must be rejected.
        with pytest.raises(Exception):
            lit([1, 2, 3])

    def test_func(self):
        lowered = func("lower", col("name"))
        assert isinstance(lowered, Expr)
        assert lowered.to_sql() == "lower(name)"

    def test_func_unknown_raises(self):
        # Looking up a function name that does not exist must fail eagerly.
        with pytest.raises(Exception):
            func("not_a_real_function", col("x"))
|
|
|
|
|
|
class TestExprOperators:
    """Python operator overloads on ``Expr`` and the SQL each one renders."""

    @staticmethod
    def _check(expr, expected_sql):
        # Each operator must produce an Expr whose SQL text matches exactly.
        assert isinstance(expr, Expr)
        assert expr.to_sql() == expected_sql

    def test_eq_operator(self):
        self._check(col("x") == lit(1), "(x = 1)")

    def test_ne_operator(self):
        self._check(col("x") != lit(1), "(x <> 1)")

    def test_lt_operator(self):
        self._check(col("age") < lit(18), "(age < 18)")

    def test_le_operator(self):
        self._check(col("age") <= lit(18), "(age <= 18)")

    def test_gt_operator(self):
        self._check(col("age") > lit(18), "(age > 18)")

    def test_ge_operator(self):
        self._check(col("age") >= lit(18), "(age >= 18)")

    def test_and_operator(self):
        combined = (col("age") > lit(18)) & (col("status") == lit("active"))
        self._check(combined, "((age > 18) AND (status = 'active'))")

    def test_or_operator(self):
        either = (col("a") == lit(1)) | (col("b") == lit(2))
        self._check(either, "((a = 1) OR (b = 2))")

    def test_invert_operator(self):
        negated = ~(col("active") == lit(True))
        self._check(negated, "NOT (active = true)")

    def test_add_operator(self):
        self._check(col("x") + lit(1), "(x + 1)")

    def test_sub_operator(self):
        self._check(col("x") - lit(1), "(x - 1)")

    def test_mul_operator(self):
        self._check(col("price") * lit(1.1), "(price * 1.1)")

    def test_div_operator(self):
        self._check(col("total") / lit(2), "(total / 2)")

    def test_radd(self):
        self._check(lit(1) + col("x"), "(1 + x)")

    def test_rmul(self):
        self._check(lit(2) * col("x"), "(2 * x)")

    def test_coerce_plain_int(self):
        # A bare Python value on the right-hand side is wrapped via lit().
        self._check(col("age") > 18, "(age > 18)")

    def test_coerce_plain_str(self):
        self._check(col("name") == "alice", "(name = 'alice')")
|
|
|
|
|
|
class TestExprBytesLiteral:
    """SQL / repr rendering of ``bytes`` literals as ``X'<HEX>'``.

    Binary literals are rendered by swapping each one for a unique string
    placeholder before unparsing and substituting the hex form afterwards, so
    these tests also cover expressions with several binary literals, and
    binary literals mixed with ordinary string literals.
    """

    def test_bytes_to_sql(self):
        e = lit(b"\xde\xad\xbe\xef")
        assert e.to_sql() == "X'DEADBEEF'"

    def test_empty_bytes_to_sql(self):
        e = lit(b"")
        assert e.to_sql() == "X''"

    def test_bytes_leading_zero_nibbles_sql(self):
        # Every byte must be rendered as exactly two hex digits.
        assert lit(b"\x00\x0a").to_sql() == "X'000A'"

    def test_bytes_repr(self):
        e = lit(b"\x01\x02")
        assert repr(e) == "Expr(X'0102')"

    def test_bytes_equality_expr_sql(self):
        e = col("data") == lit(b"\xca\xfe")
        assert e.to_sql() == "(data = X'CAFE')"

    def test_bytes_ne_expr_sql(self):
        e = col("data") != lit(b"\xff")
        assert e.to_sql() == "(data <> X'FF')"

    def test_bytes_compound_expr_sql(self):
        e = (col("data") == lit(b"\x01")) & (col("id") > lit(5))
        assert e.to_sql() == "((data = X'01') AND (id > 5))"

    def test_multiple_bytes_literals_sql(self):
        # Two distinct binary literals get distinct placeholders; each must be
        # substituted back at its own position, not swapped or merged.
        e = (col("a") == lit(b"\x01")) | (col("b") == lit(b"\x02"))
        assert e.to_sql() == "((a = X'01') OR (b = X'02'))"

    def test_repeated_bytes_literal_sql(self):
        # The same bytes value appearing twice must be rendered both times.
        e = (col("a") == lit(b"\xab")) | (col("b") == lit(b"\xab"))
        assert e.to_sql() == "((a = X'AB') OR (b = X'AB'))"

    def test_bytes_and_str_literals_sql(self):
        # Ordinary string literals must pass through the placeholder
        # substitution unchanged.
        e = (col("name") == lit("abc")) & (col("data") == lit(b"\xab"))
        assert e.to_sql() == "((name = 'abc') AND (data = X'AB'))"

    def test_bytes_in_function_call(self):
        # Regression test: binary literals inside scalar function calls
        # used to fail because DataFusion's unparser does not support Binary
        # scalars. Now handled via a placeholder-substitution rewrite.
        e = func("contains", col("data"), lit(b"\xff"))
        assert e.to_sql() == "contains(data, X'FF')"

    def test_bytes_in_not(self):
        e = ~(col("data") == lit(b"\xff"))
        assert e.to_sql() == "NOT (data = X'FF')"
|
|
|
|
|
|
class TestExprStringMethods:
    """String helper methods on ``Expr`` (``lower``, ``upper``, ``contains``)."""

    def test_lower(self):
        lowered = col("name").lower()
        assert isinstance(lowered, Expr)
        assert lowered.to_sql() == "lower(name)"

    def test_upper(self):
        uppered = col("name").upper()
        assert isinstance(uppered, Expr)
        assert uppered.to_sql() == "upper(name)"

    def test_contains(self):
        needle = lit("hello")
        found = col("text").contains(needle)
        assert isinstance(found, Expr)
        assert found.to_sql() == "contains(text, 'hello')"

    def test_contains_with_str_coerce(self):
        # A plain str argument is accepted in place of lit(...).
        found = col("text").contains("hello")
        assert isinstance(found, Expr)
        assert found.to_sql() == "contains(text, 'hello')"

    def test_chained_lower_eq(self):
        lowered = col("name").lower()
        comparison = lowered == lit("alice")
        assert isinstance(comparison, Expr)
        assert comparison.to_sql() == "(lower(name) = 'alice')"
|
|
|
|
|
|
class TestExprCast:
    """``Expr.cast`` with both type-name strings and pyarrow ``DataType``s."""

    @staticmethod
    def _cast_sql(column, target):
        # Cast the named column, confirm an Expr came back, return its SQL.
        casted = col(column).cast(target)
        assert isinstance(casted, Expr)
        return casted.to_sql()

    def test_cast_string(self):
        assert self._cast_sql("id", "string") == "CAST(id AS VARCHAR)"

    def test_cast_int32(self):
        assert self._cast_sql("score", "int32") == "CAST(score AS INTEGER)"

    def test_cast_float64(self):
        assert self._cast_sql("val", "float64") == "CAST(val AS DOUBLE)"

    def test_cast_pyarrow_type(self):
        assert self._cast_sql("score", pa.int32()) == "CAST(score AS INTEGER)"

    def test_cast_pyarrow_float64(self):
        assert self._cast_sql("val", pa.float64()) == "CAST(val AS DOUBLE)"

    def test_cast_pyarrow_string(self):
        assert self._cast_sql("id", pa.string()) == "CAST(id AS VARCHAR)"

    def test_cast_pyarrow_and_string_equivalent(self):
        # The string spelling and the pyarrow type must render identically.
        from_name = col("x").cast("int32").to_sql()
        from_arrow = col("x").cast(pa.int32()).to_sql()
        assert from_name == from_arrow
|
|
|
|
|
|
class TestExprNamedMethods:
    """Named-method equivalents of the comparison and boolean operators."""

    def test_eq_method(self):
        equal = col("x").eq(lit(1))
        assert isinstance(equal, Expr)
        assert equal.to_sql() == "(x = 1)"

    def test_gt_method(self):
        greater = col("x").gt(lit(0))
        assert isinstance(greater, Expr)
        assert greater.to_sql() == "(x > 0)"

    def test_and_method(self):
        lower_bound = col("x").gt(lit(0))
        upper_bound = col("y").lt(lit(10))
        both = lower_bound.and_(upper_bound)
        assert isinstance(both, Expr)
        assert both.to_sql() == "((x > 0) AND (y < 10))"

    def test_or_method(self):
        is_one = col("x").eq(lit(1))
        is_two = col("x").eq(lit(2))
        either = is_one.or_(is_two)
        assert isinstance(either, Expr)
        assert either.to_sql() == "((x = 1) OR (x = 2))"
|
|
|
|
|
|
class TestExprRepr:
    """``repr``/``to_sql`` rendering and hashability of ``Expr``."""

    def test_repr(self):
        predicate = col("age") > lit(18)
        assert repr(predicate) == "Expr((age > 18))"

    def test_to_sql(self):
        predicate = col("age") > 18
        assert predicate.to_sql() == "(age > 18)"

    def test_unhashable(self):
        # Using an Expr as a dict key must raise TypeError.
        column = col("x")
        with pytest.raises(TypeError):
            {column: 1}
|
|
|
|
|
|
# ── integration tests: end-to-end query against a real table ─────────────────
|
|
|
|
|
|
@pytest.fixture
def simple_table(tmp_path):
    """Create a five-row table ``test`` with id, name, age and score columns."""
    connection = lancedb.connect(str(tmp_path))
    columns = {
        "id": [1, 2, 3, 4, 5],
        "name": ["Alice", "Bob", "Charlie", "alice", "BOB"],
        "age": [25, 17, 30, 22, 15],
        "score": [1.5, 2.0, 3.5, 4.0, 0.5],
    }
    return connection.create_table("test", pa.table(columns))
|
|
|
|
|
|
class TestExprFilter:
    """End-to-end ``where()`` filtering with ``Expr`` predicates.

    Besides the row count, each test pins the exact ``id`` values returned.
    This way a predicate that selects the wrong rows (but the right number of
    them) cannot pass.
    """

    @staticmethod
    def _ids(result):
        # Sort so the checks do not depend on the scan's row order.
        return sorted(result["id"].to_pylist())

    def test_simple_gt_filter(self, simple_table):
        result = simple_table.search().where(col("age") > lit(20)).to_arrow()
        assert result.num_rows == 3  # ages 25, 30, 22
        assert self._ids(result) == [1, 3, 4]

    def test_compound_and_filter(self, simple_table):
        result = (
            simple_table.search()
            .where((col("age") > lit(18)) & (col("score") > lit(2.0)))
            .to_arrow()
        )
        assert result.num_rows == 2  # (30, 3.5) and (22, 4.0)
        assert self._ids(result) == [3, 4]

    def test_string_equality_filter(self, simple_table):
        result = simple_table.search().where(col("name") == lit("Bob")).to_arrow()
        assert result.num_rows == 1
        # Exact match only: "BOB" (id 5) must not be returned.
        assert self._ids(result) == [2]

    def test_or_filter(self, simple_table):
        result = (
            simple_table.search()
            .where((col("age") < lit(18)) | (col("age") > lit(28)))
            .to_arrow()
        )
        assert result.num_rows == 3  # ages 17, 30, 15
        assert self._ids(result) == [2, 3, 5]

    def test_coercion_no_lit(self, simple_table):
        # Python values should be auto-coerced
        result = simple_table.search().where(col("age") > 20).to_arrow()
        assert result.num_rows == 3
        assert self._ids(result) == [1, 3, 4]

    def test_string_sql_still_works(self, simple_table):
        # Backwards compatibility: plain strings still accepted
        result = simple_table.search().where("age > 20").to_arrow()
        assert result.num_rows == 3
        assert self._ids(result) == [1, 3, 4]
|
|
|
|
|
|
class TestExprProjection:
    """``select()`` with computed ``Expr`` columns and plain column names."""

    def test_select_with_expr(self, simple_table):
        projection = {"double_score": col("score") * lit(2)}
        table = simple_table.search().select(projection).to_arrow()
        assert "double_score" in table.schema.names

    def test_select_mixed_str_and_expr(self, simple_table):
        projection = {"id": "id", "double_score": col("score") * lit(2)}
        table = simple_table.search().select(projection).to_arrow()
        names = table.schema.names
        assert "id" in names
        assert "double_score" in names

    def test_select_list_of_columns(self, simple_table):
        # A plain list of column names is still accepted.
        table = simple_table.search().select(["id", "name"]).to_arrow()
        assert table.schema.names == ["id", "name"]
|
|
|
|
|
|
# ── column name edge cases ────────────────────────────────────────────────────
|
|
|
|
|
|
class TestColNaming:
    """Check that ``col()`` renders identifiers without altering them.

    Identifiers that need quoting -- camelCase, embedded spaces, a leading
    digit, or unicode characters -- are wrapped in backticks, the quoting
    style used by the lance SQL parser's dialect.
    """

    def test_camel_case_preserved_in_sql(self):
        # Backtick quoting keeps the mixed case intact on the round trip.
        rendered = col("firstName").to_sql()
        assert rendered == "`firstName`"

    def test_camel_case_in_expression(self):
        rendered = (col("firstName") > lit(18)).to_sql()
        assert rendered == "(`firstName` > 18)"

    def test_space_in_name_quoted(self):
        rendered = col("first name").to_sql()
        assert rendered == "`first name`"

    def test_space_in_expression(self):
        rendered = (col("first name") == lit("A")).to_sql()
        assert rendered == "(`first name` = 'A')"

    def test_leading_digit_quoted(self):
        rendered = col("2fast").to_sql()
        assert rendered == "`2fast`"

    def test_unicode_quoted(self):
        rendered = col("名前").to_sql()
        assert rendered == "`名前`"

    def test_snake_case_unquoted(self):
        # A plain snake_case name is emitted without quotes.
        rendered = col("first_name").to_sql()
        assert rendered == "first_name"
|
|
|
|
|
|
@pytest.fixture
def special_col_table(tmp_path):
    """Create table ``special`` whose column names require SQL quoting."""
    connection = lancedb.connect(str(tmp_path))
    columns = {
        "firstName": ["Alice", "Bob", "Charlie"],
        "first name": ["A", "B", "C"],
        "score": [10, 20, 30],
    }
    return connection.create_table("special", pa.table(columns))
|
|
|
|
|
|
class TestColNamingIntegration:
    """Quoted column names used in real filters and projections."""

    def test_camel_case_filter(self, special_col_table):
        predicate = col("firstName") == lit("Alice")
        table = special_col_table.search().where(predicate).to_arrow()
        assert table.num_rows == 1
        assert table["firstName"][0].as_py() == "Alice"

    def test_space_in_col_filter(self, special_col_table):
        predicate = col("first name") == lit("B")
        table = special_col_table.search().where(predicate).to_arrow()
        assert table.num_rows == 1

    def test_camel_case_projection(self, special_col_table):
        projection = {"upper_name": col("firstName").upper()}
        table = special_col_table.search().select(projection).to_arrow()
        assert "upper_name" in table.schema.names
        upper_values = sorted(table["upper_name"].to_pylist())
        assert upper_values == ["ALICE", "BOB", "CHARLIE"]
|
|
|
|
|
|
# ── bytes / binary column integration tests ───────────────────────────────────
|
|
|
|
|
|
@pytest.fixture
def binary_table(tmp_path):
    """Create table ``binary_test`` with an ``id`` and a ``pa.binary()`` payload."""
    connection = lancedb.connect(str(tmp_path))
    payloads = pa.array(
        [b"\x01\x02", b"\xca\xfe", b"\xff\x00"],
        type=pa.binary(),
    )
    return connection.create_table(
        "binary_test",
        pa.table({"id": [1, 2, 3], "payload": payloads}),
    )
|
|
|
|
|
|
class TestExprBytesIntegration:
    """Filter a real ``pa.binary()`` column with ``bytes`` literals.

    The tests pin the matching ``id`` values, not only the row count. This way
    a binary predicate that selects the wrong rows cannot pass.
    """

    @staticmethod
    def _ids(result):
        # Sort so the checks do not depend on the scan's row order.
        return sorted(result["id"].to_pylist())

    def test_binary_equality_filter(self, binary_table):
        result = (
            binary_table.search().where(col("payload") == lit(b"\xca\xfe")).to_arrow()
        )
        assert result.num_rows == 1
        assert result["id"][0].as_py() == 2

    def test_binary_ne_filter(self, binary_table):
        result = (
            binary_table.search().where(col("payload") != lit(b"\x01\x02")).to_arrow()
        )
        assert result.num_rows == 2
        assert self._ids(result) == [2, 3]

    def test_binary_compound_filter(self, binary_table):
        result = (
            binary_table.search()
            .where((col("payload") == lit(b"\x01\x02")) | (col("id") == lit(3)))
            .to_arrow()
        )
        assert result.num_rows == 2
        assert self._ids(result) == [1, 3]

    def test_binary_equality_with_nul_byte(self, binary_table):
        # Payload b"\xff\x00" (id 3) contains a NUL byte. The whole value must
        # survive SQL rendering so that the equality still matches it.
        result = (
            binary_table.search().where(col("payload") == lit(b"\xff\x00")).to_arrow()
        )
        assert self._ids(result) == [3]
|