From ae7f2cbfe84fed96037856c1fbbd92044e9dcd50 Mon Sep 17 00:00:00 2001 From: Trenton H <797416+stumpylog@users.noreply.github.com> Date: Thu, 11 Jun 2026 07:59:49 -0700 Subject: [PATCH] feat(python): accept Expr in Table.delete and merge when_not_matched_by_source_delete (#3524) Another little pain point as I was working to integrate with paperless-ngx. The read path of table.search() or table.query() already accepted an Expr, but write paths Table.delete and merge_insert(...).when_not_matched_by_source_delete did not. This PR attempts to close that gap, so writes and reads can both use Expr, instead of one side needing to build a string. --- python/python/lancedb/_lancedb.pyi | 2 +- python/python/lancedb/merge.py | 18 ++++--- python/python/lancedb/table.py | 30 +++++++----- python/python/tests/test_table.py | 77 ++++++++++++++++++++++++++++++ python/src/table.rs | 26 ++++++++-- 5 files changed, 129 insertions(+), 24 deletions(-) diff --git a/python/python/lancedb/_lancedb.pyi b/python/python/lancedb/_lancedb.pyi index d8b92d20e..739baaee6 100644 --- a/python/python/lancedb/_lancedb.pyi +++ b/python/python/lancedb/_lancedb.pyi @@ -205,7 +205,7 @@ class Table: async def prewarm_index(self, index_name: str) -> None: ... async def prewarm_data(self, columns: Optional[List[str]] = None) -> None: ... async def list_indices(self) -> list[IndexConfig]: ... - async def delete(self, filter: str) -> DeleteResult: ... + async def delete(self, filter: Union[str, PyExpr]) -> DeleteResult: ... async def add_columns(self, columns: list[tuple[str, str]]) -> AddColumnsResult: ... async def add_columns_with_schema(self, schema: pa.Schema) -> AddColumnsResult: ... async def alter_columns( diff --git a/python/python/lancedb/merge.py b/python/python/lancedb/merge.py index 6085f5a06..299c188a0 100644 --- a/python/python/lancedb/merge.py +++ b/python/python/lancedb/merge.py @@ -5,7 +5,9 @@ from __future__ import annotations from datetime import timedelta -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Union + +from .expr import Expr if TYPE_CHECKING: from .common import DATA @@ -32,6 +34,7 @@ class LanceMergeInsertBuilder(object): self._when_not_matched_insert_all = False self._when_not_matched_by_source_delete = False self._when_not_matched_by_source_condition = None + self._when_not_matched_by_source_condition_expr = None self._timeout = None self._use_index = True self._use_lsm_write = None @@ -62,7 +65,7 @@ class LanceMergeInsertBuilder(object): return self def when_not_matched_by_source_delete( - self, condition: Optional[str] = None + self, condition: Union[str, Expr, None] = None ) -> LanceMergeInsertBuilder: """ Rows that exist only in the target table (old data) will be @@ -71,13 +74,16 @@ class LanceMergeInsertBuilder(object): Parameters ---------- - condition: Optional[str], default None + condition: str or :class:`~lancedb.expr.Expr` or None, default None If None then all such rows will be deleted. Otherwise the - condition will be used as an SQL filter to limit what rows - are deleted. + condition will be used as a filter to limit what rows are deleted. + Can be a SQL string or a type-safe :class:`~lancedb.expr.Expr` + built with :func:`~lancedb.expr.col` and :func:`~lancedb.expr.lit`. """ self._when_not_matched_by_source_delete = True - if condition is not None: + if isinstance(condition, Expr): + self._when_not_matched_by_source_condition_expr = condition._inner + elif condition is not None: self._when_not_matched_by_source_condition = condition return self diff --git a/python/python/lancedb/table.py b/python/python/lancedb/table.py index d25a34b80..64a05b3dd 100644 --- a/python/python/lancedb/table.py +++ b/python/python/lancedb/table.py @@ -61,6 +61,7 @@ from .index import ( HnswFlat, FTS, ) +from .expr import Expr from .merge import LanceMergeInsertBuilder from .pydantic import LanceModel, model_to_dict from .query import ( @@ -1533,7 +1534,7 @@ class Table(ABC): ) -> MergeResult: ... @abstractmethod - def delete(self, where: str) -> DeleteResult: + def delete(self, where: Union[str, Expr]) -> DeleteResult: """Delete rows from the table. This can be used to delete a single row, many rows, all rows, or @@ -1541,10 +1542,10 @@ class Table(ABC): Parameters ---------- - where: str - The SQL where clause to use when deleting rows. - - - For example, 'x = 2' or 'x IN (1, 2, 3)'. + where: str or :class:`~lancedb.expr.Expr` + The filter condition. Can be a SQL string or a type-safe + :class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col` + and :func:`~lancedb.expr.lit`. The filter must not be empty, or it will error. @@ -3423,8 +3424,9 @@ class LanceTable(Table): ) return self - def delete(self, where: str) -> DeleteResult: - return LOOP.run(self._table.delete(where)) + def delete(self, where: Union[str, Expr]) -> DeleteResult: + predicate = where._inner if isinstance(where, Expr) else where + return LOOP.run(self._table.delete(predicate)) def update( self, @@ -5214,6 +5216,7 @@ class AsyncTable: when_not_matched_insert_all=merge._when_not_matched_insert_all, when_not_matched_by_source_delete=merge._when_not_matched_by_source_delete, when_not_matched_by_source_condition=merge._when_not_matched_by_source_condition, + when_not_matched_by_source_condition_expr=merge._when_not_matched_by_source_condition_expr, timeout=merge._timeout, use_index=merge._use_index, use_lsm_write=merge._use_lsm_write, @@ -5221,7 +5224,7 @@ class AsyncTable: ), ) - async def delete(self, where: str) -> DeleteResult: + async def delete(self, where: Union[str, Expr]) -> DeleteResult: """Delete rows from the table. This can be used to delete a single row, many rows, all rows, or @@ -5229,10 +5232,10 @@ class AsyncTable: Parameters ---------- - where: str - The SQL where clause to use when deleting rows. - - - For example, 'x = 2' or 'x IN (1, 2, 3)'. + where: str or :class:`~lancedb.expr.Expr` + The filter condition. Can be a SQL string or a type-safe + :class:`~lancedb.expr.Expr` built with :func:`~lancedb.expr.col` + and :func:`~lancedb.expr.lit`. The filter must not be empty, or it will error. @@ -5271,7 +5274,8 @@ class AsyncTable: x vector 0 3 [5.0, 6.0] """ - return await self._inner.delete(where) + predicate = where._inner if isinstance(where, Expr) else where + return await self._inner.delete(predicate) async def update( self, diff --git a/python/python/tests/test_table.py b/python/python/tests/test_table.py index 913220a1a..6b9d723a7 100644 --- a/python/python/tests/test_table.py +++ b/python/python/tests/test_table.py @@ -22,6 +22,7 @@ import pytest from lancedb.conftest import MockTextEmbeddingFunction from lancedb.db import AsyncConnection, DBConnection from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry +from lancedb.expr import col, lit from lancedb.pydantic import LanceModel, Vector from lancedb.table import LanceTable from pydantic import BaseModel @@ -1966,6 +1967,38 @@ def test_delete(mem_db: DBConnection): assert table.to_arrow()["id"].to_pylist() == [1] +def test_delete_expr(mem_db: DBConnection): + table = mem_db.create_table( + "my_table", + data=[ + {"vector": [1.1, 0.9], "id": 0}, + {"vector": [1.2, 1.9], "id": 1}, + {"vector": [1.3, 2.9], "id": 2}, + ], + ) + assert len(table) == 3 + delete_res = table.delete(col("id") == lit(0)) + assert delete_res.version == 2 + assert len(table) == 2 + assert sorted(table.to_arrow()["id"].to_pylist()) == [1, 2] + + +@pytest.mark.asyncio +async def test_delete_expr_async(mem_db_async: AsyncConnection): + table = await mem_db_async.create_table( + "my_table", + data=[ + {"vector": [1.1, 0.9], "id": 0}, + {"vector": [1.2, 1.9], "id": 1}, + {"vector": [1.3, 2.9], "id": 2}, + ], + ) + assert await table.count_rows() == 3 + await table.delete(col("id") == lit(0)) + assert await table.count_rows() == 2 + assert sorted((await table.to_arrow())["id"].to_pylist()) == [1, 2] + + def test_update(mem_db: DBConnection): table = mem_db.create_table( "my_table", @@ -2151,6 +2184,50 @@ def test_merge_insert(mem_db: DBConnection): ) +def test_merge_insert_by_source_delete_expr(mem_db: DBConnection): + table = mem_db.create_table( + "my_table", + data=pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}), + ) + new_data = pa.table({"a": [2, 4], "b": ["x", "z"]}) + + # replace-range, limiting the source-absent delete with an Expr condition + merge_insert_res = ( + table.merge_insert("a") + .when_matched_update_all() + .when_not_matched_insert_all() + .when_not_matched_by_source_delete(col("a") > lit(2)) + .execute(new_data) + ) + assert merge_insert_res.num_inserted_rows == 1 + assert merge_insert_res.num_updated_rows == 1 + assert merge_insert_res.num_deleted_rows == 1 + + expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]}) + assert table.to_arrow().sort_by("a") == expected + + +@pytest.mark.asyncio +async def test_merge_insert_by_source_delete_expr_async( + mem_db_async: AsyncConnection, +): + data = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + table = await mem_db_async.create_table("some_table", data=data) + new_data = pa.table({"a": [2, 4], "b": ["x", "z"]}) + + # replace-range, limiting the source-absent delete with an Expr condition + await ( + table.merge_insert("a") + .when_matched_update_all() + .when_not_matched_insert_all() + .when_not_matched_by_source_delete(col("a") > lit(2)) + .execute(new_data) + ) + + expected = pa.table({"a": [1, 2, 4], "b": ["a", "x", "z"]}) + assert (await table.to_arrow()).sort_by("a") == expected + + # We vary the data format because there are slight differences in how # subschemas are handled in different formats @pytest.mark.parametrize( diff --git a/python/src/table.rs b/python/src/table.rs index 365d50d2a..fd3857249 100644 --- a/python/src/table.rs +++ b/python/src/table.rs @@ -6,6 +6,7 @@ use crate::runtime::future_into_py; use crate::{ connection::Connection, error::PythonErrorExt, + expr::PyExpr, index::{IndexConfig, extract_index_params}, query::{Query, TakeQuery}, table::scannable::PyScannable, @@ -28,6 +29,12 @@ use pyo3::{ mod scannable; +#[derive(FromPyObject)] +enum PredicateArg { + Expr(PyExpr), + Sql(String), +} + /// Statistics about a compaction operation. #[pyclass(get_all, from_py_object)] #[derive(Clone, Debug)] @@ -561,10 +568,15 @@ impl Table { }) } - pub fn delete(self_: PyRef<'_, Self>, condition: String) -> PyResult> { + #[allow(private_interfaces)] + pub fn delete(self_: PyRef<'_, Self>, condition: PredicateArg) -> PyResult> { let inner = self_.inner_ref()?.clone(); future_into_py(self_.py(), async move { - let result = inner.delete(&condition).await.infer_error()?; + let result = match &condition { + PredicateArg::Expr(e) => inner.delete(&e.0).await, + PredicateArg::Sql(s) => inner.delete(s.as_str()).await, + } + .infer_error()?; Ok(DeleteResult::from(result)) }) } @@ -959,8 +971,13 @@ impl Table { builder.when_not_matched_insert_all(); } if parameters.when_not_matched_by_source_delete { - builder - .when_not_matched_by_source_delete(parameters.when_not_matched_by_source_condition); + if let Some(e) = parameters.when_not_matched_by_source_condition_expr { + builder.when_not_matched_by_source_delete_expr(e.0); + } else { + builder.when_not_matched_by_source_delete( + parameters.when_not_matched_by_source_condition, + ); + } } if let Some(timeout) = parameters.timeout { builder.timeout(timeout); @@ -1196,6 +1213,7 @@ pub struct MergeInsertParams { when_not_matched_insert_all: bool, when_not_matched_by_source_delete: bool, when_not_matched_by_source_condition: Option, + when_not_matched_by_source_condition_expr: Option, timeout: Option, use_index: Option, use_lsm_write: Option,