diff --git a/python/python/lancedb/query.py b/python/python/lancedb/query.py index 9e62101f..95b691b2 100644 --- a/python/python/lancedb/query.py +++ b/python/python/lancedb/query.py @@ -1329,6 +1329,27 @@ class AsyncQueryBase(object): self._inner.fast_search() return self + def postfilter(self) -> AsyncQuery: + """ + If this is called then filtering will happen after the search instead of + before. + By default filtering will be performed before the search. This is how + filtering is typically understood to work. This prefilter step does add some + additional latency. Creating a scalar index on the filter column(s) can + often improve this latency. However, sometimes a filter is too complex or + scalar indices cannot be applied to the column. In these cases postfiltering + can be used instead of prefiltering to improve latency. + Post filtering applies the filter to the results of the search. This + means we only run the filter on a much smaller set of data. However, it can + cause the query to return fewer than `limit` results (or even no results) if + none of the nearest results match the filter. + Post filtering happens during the "refine stage" (described in more detail in + `AsyncVectorQuery.refine_factor`). This means that setting a higher refine + factor can often help restore some of the results lost by post filtering. + """ + self._inner.postfilter() + return self + async def to_batches( self, *, max_batch_length: Optional[int] = None ) -> AsyncRecordBatchReader: @@ -1632,30 +1653,6 @@ class AsyncVectorQuery(AsyncQueryBase): self._inner.distance_type(distance_type) return self - def postfilter(self) -> AsyncVectorQuery: - """ - If this is called then filtering will happen after the vector search instead of - before. - - By default filtering will be performed before the vector search. This is how - filtering is typically understood to work. This prefilter step does add some - additional latency. 
Creating a scalar index on the filter column(s) can - often improve this latency. However, sometimes a filter is too complex or - scalar indices cannot be applied to the column. In these cases postfiltering - can be used instead of prefiltering to improve latency. - - Post filtering applies the filter to the results of the vector search. This - means we only run the filter on a much smaller set of data. However, it can - cause the query to return fewer than `limit` results (or even no results) if - none of the nearest results match the filter. - - Post filtering happens during the "refine stage" (described in more detail in - @see {@link VectorQuery#refineFactor}). This means that setting a higher refine - factor can often help restore some of the results lost by post filtering. - """ - self._inner.postfilter() - return self - def bypass_vector_index(self) -> AsyncVectorQuery: """ If this is called then any vector index is skipped diff --git a/python/python/tests/test_fts.py b/python/python/tests/test_fts.py index ce649581..594552a0 100644 --- a/python/python/tests/test_fts.py +++ b/python/python/tests/test_fts.py @@ -235,6 +235,29 @@ async def test_search_fts_async(async_table): results = await async_table.query().nearest_to_text("puppy").limit(5).to_list() assert len(results) == 5 + expected_count = await async_table.count_rows( + "count > 5000 and contains(text, 'puppy')" + ) + expected_count = min(expected_count, 10) + + limited_results_pre_filter = await ( + async_table.query() + .nearest_to_text("puppy") + .where("count > 5000") + .limit(10) + .to_list() + ) + assert len(limited_results_pre_filter) == expected_count + limited_results_post_filter = await ( + async_table.query() + .nearest_to_text("puppy") + .where("count > 5000") + .limit(10) + .postfilter() + .to_list() + ) + assert len(limited_results_post_filter) <= expected_count + @pytest.mark.asyncio async def test_search_fts_specify_column_async(async_table): diff --git a/python/src/query.rs 
b/python/src/query.rs index b68e96b7..e3b127ba 100644 --- a/python/src/query.rs +++ b/python/src/query.rs @@ -72,6 +72,10 @@ impl Query { self.inner = self.inner.clone().fast_search(); } + pub fn postfilter(&mut self) { + self.inner = self.inner.clone().postfilter(); + } + pub fn nearest_to(&mut self, vector: Bound<'_, PyAny>) -> PyResult { let data: ArrayData = ArrayData::from_pyarrow_bound(&vector)?; let array = make_array(data); diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index 6118e6b7..135f46a1 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -403,6 +403,26 @@ pub trait QueryBase { /// By default, it is false. fn fast_search(self) -> Self; + /// If this is called then filtering will happen after the search instead of + /// before. + /// + /// By default filtering will be performed before the search. This is how + /// filtering is typically understood to work. This prefilter step does add some + /// additional latency. Creating a scalar index on the filter column(s) can + /// often improve this latency. However, sometimes a filter is too complex or scalar + /// indices cannot be applied to the column. In these cases postfiltering can be + /// used instead of prefiltering to improve latency. + /// + /// Post filtering applies the filter to the results of the search. This means + /// we only run the filter on a much smaller set of data. However, it can cause the + /// query to return fewer than `limit` results (or even no results) if none of the nearest + /// results match the filter. + /// + /// Post filtering happens during the "refine stage" (described in more detail in + /// [`VectorQuery::refine_factor`]). This means that setting a higher refine factor can often + /// help restore some of the results lost by post filtering. + fn postfilter(self) -> Self; + /// Return the `_rowid` meta column from the Table. 
fn with_row_id(self) -> Self; } @@ -442,6 +462,11 @@ impl QueryBase for T { self } + fn postfilter(mut self) -> Self { + self.mut_query().prefilter = false; + self + } + fn with_row_id(mut self) -> Self { self.mut_query().with_row_id = true; self @@ -561,6 +586,9 @@ pub struct Query { /// /// By default, this is false. pub(crate) with_row_id: bool, + + /// If set to false, the filter will be applied after the vector search. + pub(crate) prefilter: bool, } impl Query { @@ -574,6 +602,7 @@ impl Query { select: Select::All, fast_search: false, with_row_id: false, + prefilter: true, } } @@ -678,8 +707,6 @@ pub struct VectorQuery { pub(crate) distance_type: Option, /// Default is true. Set to false to enforce a brute force search. pub(crate) use_index: bool, - /// Apply filter before ANN search/ - pub(crate) prefilter: bool, } impl VectorQuery { @@ -692,7 +719,6 @@ impl VectorQuery { refine_factor: None, distance_type: None, use_index: true, - prefilter: true, } } @@ -782,29 +808,6 @@ impl VectorQuery { self } - /// If this is called then filtering will happen after the vector search instead of - /// before. - /// - /// By default filtering will be performed before the vector search. This is how - /// filtering is typically understood to work. This prefilter step does add some - /// additional latency. Creating a scalar index on the filter column(s) can - /// often improve this latency. However, sometimes a filter is too complex or scalar - /// indices cannot be applied to the column. In these cases postfiltering can be - /// used instead of prefiltering to improve latency. - /// - /// Post filtering applies the filter to the results of the vector search. This means - /// we only run the filter on a much smaller set of data. However, it can cause the - /// query to return fewer than `limit` results (or even no results) if none of the nearest - /// results match the filter. 
- /// - /// Post filtering happens during the "refine stage" (described in more detail in - /// [`Self::refine_factor`]). This means that setting a higher refine factor can often - /// help restore some of the results lost by post filtering. - pub fn postfilter(mut self) -> Self { - self.prefilter = false; - self - } - /// If this is called then any vector index is skipped /// /// An exhaustive (flat) search will be performed. The query vector will diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index f9900b2c..5ceaac55 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -305,7 +305,7 @@ impl TableInternal for RemoteTable { let mut body = serde_json::Value::Object(Default::default()); Self::apply_query_params(&mut body, &query.base)?; - body["prefilter"] = query.prefilter.into(); + body["prefilter"] = query.base.prefilter.into(); body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default()); body["nprobes"] = query.nprobes.into(); body["refine_factor"] = query.refine_factor.into(); diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index a94526ca..ee5e5bba 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -1842,7 +1842,7 @@ impl TableInternal for NativeTable { scanner.nprobs(query.nprobes); scanner.use_index(query.use_index); - scanner.prefilter(query.prefilter); + scanner.prefilter(query.base.prefilter); match query.base.select { Select::Columns(ref columns) => { scanner.project(columns.as_slice())?;