diff --git a/docs/src/js/classes/Query.md b/docs/src/js/classes/Query.md index bdf4764b7..15d78036b 100644 --- a/docs/src/js/classes/Query.md +++ b/docs/src/js/classes/Query.md @@ -518,6 +518,9 @@ x > 5 OR y = 'test' Filtering performance can often be improved by creating a scalar index on the filter column(s). + +Calling this multiple times combines the filters with a logical AND rather +than replacing the previous filter. ``` #### Inherited from diff --git a/docs/src/js/classes/VectorQuery.md b/docs/src/js/classes/VectorQuery.md index fb010f65f..5ff8beec7 100644 --- a/docs/src/js/classes/VectorQuery.md +++ b/docs/src/js/classes/VectorQuery.md @@ -767,6 +767,9 @@ x > 5 OR y = 'test' Filtering performance can often be improved by creating a scalar index on the filter column(s). + +Calling this multiple times combines the filters with a logical AND rather +than replacing the previous filter. ``` #### Inherited from diff --git a/nodejs/__test__/query.test.ts b/nodejs/__test__/query.test.ts index a2975c133..da001b1eb 100644 --- a/nodejs/__test__/query.test.ts +++ b/nodejs/__test__/query.test.ts @@ -215,6 +215,20 @@ describe("Query orderBy", () => { expect(results[2].score).toBeCloseTo(4.1, 0.001); }); + it("should combine repeated where clauses with AND", async () => { + const results = await table + .query() + .where("score > 1.0") + .where("score < 3.0") + .orderBy({ columnName: "score" }) + .toArray(); + // Only rows matching both predicates should be returned, rather than the + // second where() silently replacing the first. 
+ expect(results.length).toBe(2); + expect(results[0].score).toBeCloseTo(1.2, 0.001); + expect(results[1].score).toBeCloseTo(2.8, 0.001); + }); + it("should support method chaining with limit", async () => { const results = await table .query() diff --git a/nodejs/lancedb/query.ts b/nodejs/lancedb/query.ts index f985eaf83..04a8a69d8 100644 --- a/nodejs/lancedb/query.ts +++ b/nodejs/lancedb/query.ts @@ -362,6 +362,9 @@ export class StandardQueryBase< * * Filtering performance can often be improved by creating a scalar index * on the filter column(s). + * + * Calling this multiple times combines the filters with a logical AND rather + * than replacing the previous filter. */ where(predicate: string): this { this.doCall((inner: NativeQueryType) => inner.onlyIf(predicate)); diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 847ba54db..748f88af5 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -119,6 +119,27 @@ def _filter_to_sql(filter: Optional[Union[str, Expr]]) -> Optional[str]: return filter +def _combine_where( + existing: Optional[Union[str, Expr]], new: Union[str, Expr] +) -> Union[str, Expr]: + """Combine a new filter with an existing one using a logical AND. + + Calling ``where`` more than once composes the filters with AND instead of + replacing the previous filter. Two :class:`~lancedb.expr.Expr` filters are + combined as an expression; otherwise both filters are lowered to SQL strings + and combined as SQL. 
+ """ + if existing is None: + return new + existing_is_expr = isinstance(existing, Expr) + new_is_expr = isinstance(new, Expr) + if existing_is_expr and new_is_expr: + return existing & new + existing_sql = existing.to_sql() if existing_is_expr else existing + new_sql = new.to_sql() if new_is_expr else new + return f"({existing_sql}) AND ({new_sql})" + + def _projection_to_scanner_kwargs( columns: Optional[ Union[ @@ -1148,8 +1169,13 @@ class LanceQueryBuilder(ABC): ------- LanceQueryBuilder The LanceQueryBuilder object. + + Notes + ----- + Calling this multiple times combines the filters with a logical AND + rather than replacing the previous filter. """ - self._where = where + self._where = _combine_where(self._where, where) self._postfilter = not prefilter return self @@ -1693,8 +1719,13 @@ class LanceVectorQueryBuilder(LanceQueryBuilder): ------- LanceQueryBuilder The LanceQueryBuilder object. + + Notes + ----- + Calling this multiple times combines the filters with a logical AND + rather than replacing the previous filter. """ - self._where = where + self._where = _combine_where(self._where, where) if prefilter is not None: self._postfilter = not prefilter return self @@ -2894,6 +2925,9 @@ class AsyncStandardQuery(AsyncQueryBase): Filtering performance can often be improved by creating a scalar index on the filter column(s). + + Calling this multiple times combines the filters with a logical AND + rather than replacing the previous filter. 
""" if isinstance(predicate, Expr): self._inner.where_expr(predicate._inner) diff --git a/python/python/tests/test_query.py b/python/python/tests/test_query.py index 8f977bf91..b60efc68e 100644 --- a/python/python/tests/test_query.py +++ b/python/python/tests/test_query.py @@ -502,6 +502,61 @@ def test_with_row_id(table: lancedb.table.Table): assert rs["_rowid"].to_pylist() == [0, 1] +def test_where_repeated_combines_with_and(table: lancedb.table.Table): + # Calling where() more than once should AND the filters together instead of + # silently replacing the previous one (regression test for #2649). + builder = table.search().where("id >= 1").where("id < 2") + assert builder._where == "(id >= 1) AND (id < 2)" + + ids = [row["id"] for row in builder.limit(10).to_list()] + assert ids == [1] + + +def test_where_repeated_combines_expr(table: lancedb.table.Table): + from lancedb.expr import col, lit + + builder = table.search().where(col("id") >= lit(1)).where(col("id") < lit(2)) + ids = [row["id"] for row in builder.limit(10).to_list()] + assert ids == [1] + + +def test_where_mixed_filter_kinds_combines(table: lancedb.table.Table): + # Mixing a SQL string filter with an expression filter lowers the + # expression to SQL and combines them as SQL strings. 
+ from lancedb.expr import col, lit + + builder = table.search().where("id >= 1").where(col("id") < lit(2)) + ids = [row["id"] for row in builder.limit(10).to_list()] + assert ids == [1] + + +@pytest.mark.asyncio +async def test_where_repeated_combines_with_and_async(table_async: AsyncTable): + ids = [ + row["id"] + for row in ( + await table_async.query().where("id >= 1").where("id < 2").to_list() + ) + ] + assert ids == [1] + + +@pytest.mark.asyncio +async def test_where_mixed_filter_kinds_combines_async(table_async: AsyncTable): + from lancedb.expr import col, lit + + ids = [ + row["id"] + for row in ( + await table_async.query() + .where("id >= 1") + .where(col("id") < lit(2)) + .to_list() + ) + ] + assert ids == [1] + + def test_distance_range(table: lancedb.table.Table): q = [0, 0] rs = table.search(q).to_arrow() diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index 893ea7f9b..d00ca5836 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -401,6 +401,9 @@ pub trait QueryBase { /// /// Filtering performance can often be improved by creating a scalar index /// on the filter column(s). + /// + /// Calling this multiple times combines the filters with a logical AND + /// (i.e. `(previous) AND (new)`) rather than replacing the previous filter. fn only_if(self, filter: impl AsRef) -> Self; /// Only return rows which match the filter, using an expression builder. @@ -423,6 +426,9 @@ pub trait QueryBase { /// /// Note: Expression filters are not supported for remote/server-side queries. /// Use [`QueryBase::only_if`] with SQL strings for remote tables. + /// + /// Calling this multiple times combines the expressions with a logical AND + /// rather than replacing the previous filter. fn only_if_expr(self, filter: datafusion_expr::Expr) -> Self; /// Perform a full text search on the table. 
@@ -535,12 +541,13 @@ impl QueryBase for T { } fn only_if(mut self, filter: impl AsRef) -> Self { - self.mut_query().filter = Some(QueryFilter::Sql(filter.as_ref().to_string())); + self.mut_query() + .add_filter(QueryFilter::Sql(filter.as_ref().to_string())); self } fn only_if_expr(mut self, filter: datafusion_expr::Expr) -> Self { - self.mut_query().filter = Some(QueryFilter::Datafusion(filter)); + self.mut_query().add_filter(QueryFilter::Datafusion(filter)); self } @@ -716,6 +723,39 @@ pub enum QueryFilter { Datafusion(Expr), } +/// Combine two filters with a logical AND. +/// +/// This is used when a query receives more than one filter (for example when +/// `where`/`only_if` is called multiple times) so the filters are composed +/// with AND rather than the later filter silently replacing the earlier one. +/// +/// SQL string and expression filters are combined within their own +/// representation. When the two representations are mixed, the expression is +/// lowered to SQL (via [`crate::expr::expr_to_sql_string`]) and the filters are +/// combined as SQL strings. Substrait filters cannot be combined and return an +/// error. 
+fn and_filters(existing: QueryFilter, new: QueryFilter) -> Result { + match (existing, new) { + (QueryFilter::Sql(lhs), QueryFilter::Sql(rhs)) => { + Ok(QueryFilter::Sql(format!("({lhs}) AND ({rhs})"))) + } + (QueryFilter::Datafusion(lhs), QueryFilter::Datafusion(rhs)) => { + Ok(QueryFilter::Datafusion(lhs.and(rhs))) + } + (QueryFilter::Sql(lhs), QueryFilter::Datafusion(rhs)) => { + let rhs = crate::expr::expr_to_sql_string(&rhs)?; + Ok(QueryFilter::Sql(format!("({lhs}) AND ({rhs})"))) + } + (QueryFilter::Datafusion(lhs), QueryFilter::Sql(rhs)) => { + let lhs = crate::expr::expr_to_sql_string(&lhs)?; + Ok(QueryFilter::Sql(format!("({lhs}) AND ({rhs})"))) + } + _ => Err(Error::InvalidInput { + message: "cannot combine a Substrait filter with another filter".to_string(), + }), + } +} + /// A basic query into a table without any kind of search /// /// This will result in a (potentially filtered) scan if executed @@ -730,6 +770,13 @@ pub struct QueryRequest { /// Apply filter to the returned rows. pub filter: Option, + /// An error recorded while combining repeated filters that could not be + /// composed (see [`QueryRequest::add_filter`]). It is surfaced when the + /// query is executed via [`QueryRequest::check_filter`]. We defer the error + /// because the builder methods that set filters return `Self` rather than a + /// `Result`. + pub(crate) filter_error: Option, + /// Perform a full text search on the table. pub full_text_search: Option, @@ -775,6 +822,7 @@ impl Default for QueryRequest { limit: None, offset: None, filter: None, + filter_error: None, full_text_search: None, select: Select::All, fast_search: false, @@ -788,6 +836,41 @@ impl Default for QueryRequest { } } +impl QueryRequest { + /// Add a filter, combining it with any existing filter using a logical AND. + /// + /// If the new filter cannot be combined with the existing one (a Substrait + /// filter is involved, or an expression cannot be lowered to SQL) the error + /// is recorded and surfaced later by [`Self::check_filter`]. 
+ pub(crate) fn add_filter(&mut self, new: QueryFilter) { + self.filter = Some(match self.filter.take() { + None => new, + Some(existing) => match and_filters(existing, new) { + Ok(combined) => combined, + Err(err) => { + // The filters were consumed while attempting to combine + // them; the recorded error is surfaced by `check_filter` + // before the query executes. + self.filter_error = Some(err.to_string()); + return; + } + }, + }); + } + + /// Return an error if combining filters failed (see [`Self::add_filter`]). + /// + /// This must be called by every backend before executing a query. + pub(crate) fn check_filter(&self) -> Result<()> { + if let Some(message) = &self.filter_error { + return Err(Error::InvalidInput { + message: message.clone(), + }); + } + Ok(()) + } +} + /// A builder for LanceDB queries. /// /// See [`crate::Table::query`] for more details on queries @@ -1682,6 +1765,70 @@ mod tests { } } + #[tokio::test] + async fn test_repeated_only_if_combines_with_and() { + use crate::expr::{col, lit}; + + let tmp_dir = tempdir().unwrap(); + let dataset_path = tmp_dir.path().join("test.lance"); + let uri = dataset_path.to_str().unwrap(); + + let conn = connect(uri).execute().await.unwrap(); + let table = conn + .create_table("my_table", make_non_empty_batches()) + .execute() + .await + .unwrap(); + + let query = table.query().only_if("id > 0").only_if("id < 100"); + match &query.request.filter { + Some(QueryFilter::Sql(sql)) => assert_eq!(sql, "(id > 0) AND (id < 100)"), + other => panic!("expected combined SQL filter, got {other:?}"), + } + + // A single filter is left untouched. + let query = table.query().only_if("id > 0"); + match &query.request.filter { + Some(QueryFilter::Sql(sql)) => assert_eq!(sql, "id > 0"), + other => panic!("expected single SQL filter, got {other:?}"), + } + + // Expression filters are combined with a logical AND as well. 
+ let query = table + .query() + .only_if_expr(col("id").gt(lit(0i32))) + .only_if_expr(col("id").lt(lit(100i32))); + match &query.request.filter { + Some(QueryFilter::Datafusion(expr)) => { + assert_eq!( + expr, + &col("id").gt(lit(0i32)).and(col("id").lt(lit(100i32))) + ); + } + other => panic!("expected combined Datafusion filter, got {other:?}"), + } + + // Mixing an SQL string filter with an expression filter lowers the + // expression to SQL and combines them as SQL strings. + let query = table + .query() + .only_if("id > 0") + .only_if_expr(col("id").lt(lit(100i32))); + match &query.request.filter { + Some(QueryFilter::Sql(sql)) => { + let expected = format!( + "(id > 0) AND ({})", + crate::expr::expr_to_sql_string(&col("id").lt(lit(100i32))).unwrap() + ); + assert_eq!(sql, &expected); + } + other => panic!("expected combined SQL filter, got {other:?}"), + } + assert!(query.request.check_filter().is_ok()); + // The combined filter executes without error. + query.execute().await.unwrap(); + } + #[tokio::test] async fn test_select_with_transform() { // TODO: Switch back to memory://foo after https://github.com/lancedb/lancedb/issues/1051 diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index 3dfb8218c..29495b73d 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -612,6 +612,7 @@ impl RemoteTable { body: &mut serde_json::Value, params: &QueryRequest, ) -> Result<()> { + params.check_filter()?; body["prefilter"] = params.prefilter.into(); if let Some(offset) = params.offset { body["offset"] = serde_json::Value::Number(serde_json::Number::from(offset)); diff --git a/rust/lancedb/src/table/query.rs b/rust/lancedb/src/table/query.rs index e38fec598..cfc4b17d3 100644 --- a/rust/lancedb/src/table/query.rs +++ b/rust/lancedb/src/table/query.rs @@ -35,6 +35,15 @@ pub enum AnyQuery { VectorQuery(VectorQueryRequest), } +impl AnyQuery { + pub(crate) fn base(&self) -> &QueryRequest { + match self { + 
Self::Query(query) => query, + Self::VectorQuery(query) => &query.base, + } + } +} + //Decide between namespace or local pub async fn execute_query( table: &NativeTable, @@ -108,6 +117,7 @@ pub async fn create_plan( AnyQuery::VectorQuery(query) => query.clone(), AnyQuery::Query(query) => VectorQueryRequest::from_plain_query(query.clone()), }; + query.base.check_filter()?; let ds_ref = table.dataset.get().await?; let schema = ds_ref.schema(); @@ -357,6 +367,7 @@ async fn execute_namespace_query( /// Convert an AnyQuery to the namespace QueryTableRequest format. fn convert_to_namespace_query(query: &AnyQuery) -> Result { + query.base().check_filter()?; match query { AnyQuery::VectorQuery(vq) => { // Extract the query vector(s)