diff --git a/rust/vectordb/src/query.rs b/rust/vectordb/src/query.rs index efc0f9cb..e6580936 100644 --- a/rust/vectordb/src/query.rs +++ b/rust/vectordb/src/query.rs @@ -32,6 +32,7 @@ pub struct Query { pub refine_factor: Option, pub metric_type: Option, pub use_index: bool, + pub prefilter: bool, } impl Query { @@ -56,6 +57,7 @@ impl Query { use_index: true, filter: None, select: None, + prefilter: false, } } @@ -74,6 +76,8 @@ impl Query { )?; scanner.nprobs(self.nprobes); scanner.use_index(self.use_index); + scanner.prefilter(self.prefilter); + self.select.as_ref().map(|p| scanner.project(p.as_slice())); self.filter.as_ref().map(|f| scanner.filter(f)); self.refine_factor.map(|rf| scanner.refine(rf)); @@ -158,6 +162,11 @@ impl Query { self.select = columns; self } + + pub fn prefilter(mut self, prefilter: bool) -> Query { + self.prefilter = prefilter; + self + } } #[cfg(test)] @@ -167,7 +176,9 @@ mod tests { use super::*; use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader}; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; + use futures::StreamExt; use lance::dataset::Dataset; + use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector}; use crate::query::Query; @@ -200,13 +211,43 @@ mod tests { #[tokio::test] async fn test_execute() { - let batches = make_test_batches(); - let ds = Dataset::write(batches, "memory://foo", None).await.unwrap(); + let batches = make_non_empty_batches(); + let ds = Arc::new(Dataset::write(batches, "memory://foo", None).await.unwrap()); - let vector = Float32Array::from_iter_values([0.1; 128]); - let query = Query::new(Arc::new(ds), vector.clone()); - let result = query.execute().await; - assert_eq!(result.is_ok(), true); + let vector = Float32Array::from_iter_values([0.1; 4]); + + let query = Query::new(ds.clone(), vector.clone()); + let result = query + .limit(10) + .filter(Some("id % 2 == 0".to_string())) + .execute() + .await; + let mut stream = result.expect("should have result"); + // should only have one batch + while let Some(batch) = stream.next().await { + // post filter should have removed some rows + assert!(batch.expect("should be Ok").num_rows() < 10); + } + + let query = Query::new(ds, vector.clone()); + let result = query + .limit(10) + .filter(Some("id % 2 == 0".to_string())) + .prefilter(true) + .execute() + .await; + let mut stream = result.expect("should have result"); + // should only have one batch + while let Some(batch) = stream.next().await { + // pre filter should return 10 rows + assert!(batch.expect("should be Ok").num_rows() == 10); + } + } + + fn make_non_empty_batches() -> impl RecordBatchReader + Send + 'static { + let vec = Box::new(RandomVector::new().named("vector".to_string())); + let id = Box::new(IncrementingInt32::new().named("id".to_string())); + BatchGenerator::new().col(vec).col(id).batch(512) } fn make_test_batches() -> impl RecordBatchReader + Send + 'static {