diff --git a/Cargo.lock b/Cargo.lock index af94386b..1a53d252 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2571,7 +2571,7 @@ dependencies = [ [[package]] name = "fsst" version = "0.23.2" -source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9" +source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465" dependencies = [ "rand", ] @@ -3533,7 +3533,7 @@ dependencies = [ [[package]] name = "lance" version = "0.23.2" -source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9" +source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465" dependencies = [ "arrow", "arrow-arith", @@ -3593,7 +3593,7 @@ dependencies = [ [[package]] name = "lance-arrow" version = "0.23.2" -source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9" +source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465" dependencies = [ "arrow-array", "arrow-buffer", @@ -3611,7 +3611,7 @@ dependencies = [ [[package]] name = "lance-core" version = "0.23.2" -source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9" +source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465" dependencies = [ "arrow-array", "arrow-buffer", @@ -3648,7 +3648,7 @@ dependencies = [ [[package]] name = "lance-datafusion" version = "0.23.2" -source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9" +source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465" dependencies = [ "arrow", "arrow-array", @@ -3674,7 +3674,7 @@ dependencies = [ [[package]] name = "lance-encoding" version = "0.23.2" -source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9" +source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465" dependencies = [ "arrayref", "arrow", @@ -3713,7 +3713,7 @@ dependencies = [ [[package]] name = "lance-file" version = "0.23.2" -source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9" +source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465" dependencies = [ "arrow-arith", "arrow-array", @@ -3748,7 +3748,7 @@ dependencies = [ [[package]] name = "lance-index" version = "0.23.2" -source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9" +source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465" dependencies = [ "arrow", "arrow-array", @@ -3801,7 +3801,7 @@ dependencies = [ [[package]] name = "lance-io" version = "0.23.2" -source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9" +source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465" dependencies = [ "arrow", "arrow-arith", @@ -3840,7 +3840,7 @@ dependencies = [ [[package]] name = "lance-linalg" version = "0.23.2" -source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9" +source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465" dependencies = [ "arrow-array", "arrow-ord", @@ -3864,7 +3864,7 @@ dependencies = [ [[package]] name = "lance-table" version = "0.23.2" -source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9" +source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465" dependencies = [ "arrow", "arrow-array", @@ -3904,7 +3904,7 @@ dependencies = [ [[package]] name = "lance-testing" version = "0.23.2" -source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9" +source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465" dependencies = [ "arrow-array", "arrow-schema", diff --git a/Cargo.toml b/Cargo.toml index 853a18b0..2f7c0e3a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,14 +23,14 @@ rust-version = "1.78.0" [workspace.dependencies] lance = { "version" = "=0.23.2", "features" = [ "dynamodb", -], git = "https://github.com/lancedb/lance.git", tag = "v0.23.2-beta.1"} -lance-io = {version = "=0.23.2", tag="v0.23.2-beta.1", git = "https://github.com/lancedb/lance.git"} -lance-index = {version = "=0.23.2", tag="v0.23.2-beta.1", git = "https://github.com/lancedb/lance.git"} -lance-linalg = {version = "=0.23.2", tag="v0.23.2-beta.1", git = "https://github.com/lancedb/lance.git"} -lance-table = {version = "=0.23.2", tag="v0.23.2-beta.1", git = "https://github.com/lancedb/lance.git"} -lance-testing = {version = "=0.23.2", tag="v0.23.2-beta.1", git = "https://github.com/lancedb/lance.git"} -lance-datafusion = {version = "=0.23.2", tag="v0.23.2-beta.1", git = "https://github.com/lancedb/lance.git"} -lance-encoding = {version = "=0.23.2", tag="v0.23.2-beta.1", git = "https://github.com/lancedb/lance.git"} +], git = "https://github.com/lancedb/lance.git", tag = "v0.23.2-beta.3" } +lance-io = { version = "=0.23.2", tag = "v0.23.2-beta.3", git = "https://github.com/lancedb/lance.git" } +lance-index = { version = "=0.23.2", tag = "v0.23.2-beta.3", git = "https://github.com/lancedb/lance.git" } +lance-linalg = { version = "=0.23.2", tag = "v0.23.2-beta.3", git = "https://github.com/lancedb/lance.git" } +lance-table = { version = "=0.23.2", tag = "v0.23.2-beta.3", git = "https://github.com/lancedb/lance.git" } +lance-testing = { version = "=0.23.2", tag = "v0.23.2-beta.3", git = "https://github.com/lancedb/lance.git" } +lance-datafusion = { version = "=0.23.2", tag = "v0.23.2-beta.3", git = "https://github.com/lancedb/lance.git" } +lance-encoding = { version = "=0.23.2", tag = "v0.23.2-beta.3", git = "https://github.com/lancedb/lance.git" } # Note that this one does not include pyarrow arrow = { version = "53.2", optional = false } arrow-array = "53.2" diff --git a/python/pyproject.toml b/python/pyproject.toml index e216e6af..097f6ec9 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ name = "lancedb" dynamic = ["version"] dependencies = [ "deprecation", - "pylance==0.23.0", + "pylance==0.23.2b3", "tqdm>=4.27.0", "pydantic>=1.10", "packaging", @@ -55,7 +55,12 @@ tests = [ "tantivy", "pyarrow-stubs", ] -dev = ["ruff", "pre-commit", "pyright", 'typing-extensions>=4.0.0; python_version < "3.11"'] +dev = [ + "ruff", + "pre-commit", + "pyright", + 'typing-extensions>=4.0.0; python_version < "3.11"', +] docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"] clip = ["torch", "pillow", "open-clip"] embeddings = [ diff --git a/rust/lancedb/src/query.rs b/rust/lancedb/src/query.rs index 1a69f8e0..afc47222 100644 --- a/rust/lancedb/src/query.rs +++ b/rust/lancedb/src/query.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use arrow::compute::concat_batches; use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array}; use arrow_schema::DataType; +use datafusion_expr::Expr; use datafusion_physical_plan::ExecutionPlan; use futures::{stream, try_join, FutureExt, TryStreamExt}; use half::f16; @@ -464,7 +465,7 @@ impl QueryBase for T { } fn only_if(mut self, filter: impl AsRef) -> Self { - self.mut_query().filter = Some(filter.as_ref().to_string()); + self.mut_query().filter = Some(QueryFilter::Sql(filter.as_ref().to_string())); self } @@ -577,6 +578,17 @@ pub trait ExecutableQuery { fn explain_plan(&self, verbose: bool) -> impl Future> + Send; } +/// A query filter that can be applied to a query +#[derive(Clone, Debug)] +pub enum QueryFilter { + /// The filter is an SQL string + Sql(String), + /// The filter is a Substrait ExtendedExpression message with a single expression + Substrait(Arc<[u8]>), + /// The filter is a Datafusion expression + Datafusion(Expr), +} + /// A basic query into a table without any kind of search /// /// This will result in a (potentially filtered) scan if executed @@ -589,7 +601,7 @@ pub struct QueryRequest { pub offset: Option, /// Apply filter to the returned rows. - pub filter: Option, + pub filter: Option, /// Perform a full text search on the table. pub full_text_search: Option, diff --git a/rust/lancedb/src/remote/table.rs b/rust/lancedb/src/remote/table.rs index bbdbfacf..e7c182dd 100644 --- a/rust/lancedb/src/remote/table.rs +++ b/rust/lancedb/src/remote/table.rs @@ -7,7 +7,7 @@ use std::sync::{Arc, Mutex}; use crate::index::Index; use crate::index::IndexStatistics; -use crate::query::{QueryRequest, Select, VectorQueryRequest}; +use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest}; use crate::table::{AddDataMode, AnyQuery, Filter}; use crate::utils::{supported_btree_data_type, supported_vector_data_type}; use crate::{DistanceType, Error}; @@ -159,7 +159,13 @@ impl RemoteTable { } if let Some(filter) = ¶ms.filter { - body["filter"] = serde_json::Value::String(filter.clone()); + if let QueryFilter::Sql(filter) = filter { + body["filter"] = serde_json::Value::String(filter.clone()); + } else { + return Err(Error::NotSupported { + message: "querying a remote table with a non-sql filter".to_string(), + }); + } } match ¶ms.select { diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index ca5be432..4ba9d6c0 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -62,7 +62,7 @@ use crate::index::{ }; use crate::index::{IndexConfig, IndexStatisticsImpl}; use crate::query::{ - IntoQueryVector, Query, QueryExecutionOptions, QueryRequest, Select, VectorQuery, + IntoQueryVector, Query, QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQuery, VectorQueryRequest, DEFAULT_TOP_K, }; use crate::utils::{ @@ -2125,7 +2125,17 @@ impl BaseTable for NativeTable { } if let Some(filter) = &query.base.filter { - scanner.filter(filter)?; + match filter { + QueryFilter::Sql(sql) => { + scanner.filter(sql)?; + } + QueryFilter::Substrait(substrait) => { + scanner.filter_substrait(substrait)?; + } + QueryFilter::Datafusion(expr) => { + scanner.filter_expr(expr.clone()); + } + } } if let Some(fts) = &query.base.full_text_search { diff --git a/rust/lancedb/src/table/datafusion.rs b/rust/lancedb/src/table/datafusion.rs index a27a8d92..a613c3b5 100644 --- a/rust/lancedb/src/table/datafusion.rs +++ b/rust/lancedb/src/table/datafusion.rs @@ -17,7 +17,7 @@ use futures::{TryFutureExt, TryStreamExt}; use super::{AnyQuery, BaseTable}; use crate::{ - query::{QueryExecutionOptions, QueryRequest, Select}, + query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select}, Result, }; @@ -161,7 +161,13 @@ impl TableProvider for BaseTableAdapter { .collect(); query.select = Select::Columns(field_names); } - assert!(filters.is_empty()); + if !filters.is_empty() { + let first = filters.first().unwrap().clone(); + let filter = filters[1..] + .iter() + .fold(first, |acc, expr| acc.and(expr.clone())); + query.filter = Some(QueryFilter::Datafusion(filter)); + } if let Some(limit) = limit { query.limit = Some(limit); } else { @@ -180,11 +186,7 @@ impl TableProvider for BaseTableAdapter { &self, filters: &[&Expr], ) -> DataFusionResult> { - // TODO: Pushdown unsupported until we can support datafusion filters in BaseTable::create_plan - Ok(vec![ - TableProviderFilterPushDown::Unsupported; - filters.len() - ]) + Ok(vec![TableProviderFilterPushDown::Exact; filters.len()]) } fn statistics(&self) -> Option { @@ -197,67 +199,182 @@ impl TableProvider for BaseTableAdapter { pub mod tests { use std::{collections::HashMap, sync::Arc}; - use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader}; + use arrow::array::AsArray; + use arrow_array::{ + Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader, UInt32Array, + }; use arrow_schema::{DataType, Field, Schema}; use datafusion::{datasource::provider_as_source, prelude::SessionContext}; use datafusion_catalog::TableProvider; - use datafusion_expr::LogicalPlanBuilder; + use datafusion_execution::SendableRecordBatchStream; + use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder}; use futures::TryStreamExt; use tempfile::tempdir; - use crate::{connect, table::datafusion::BaseTableAdapter}; + use crate::{ + connect, + index::{scalar::BTreeIndexBuilder, Index}, + table::datafusion::BaseTableAdapter, + }; fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static { let metadata = HashMap::from_iter(vec![("foo".to_string(), "bar".to_string())]); let schema = Arc::new( - Schema::new(vec![Field::new("i", DataType::Int32, false)]).with_metadata(metadata), + Schema::new(vec![ + Field::new("i", DataType::Int32, false), + Field::new("indexed", DataType::UInt32, false), + ]) + .with_metadata(metadata), ); RecordBatchIterator::new( vec![RecordBatch::try_new( schema.clone(), - vec![Arc::new(Int32Array::from_iter_values(0..10))], + vec![ + Arc::new(Int32Array::from_iter_values(0..10)), + Arc::new(UInt32Array::from_iter_values(0..10)), + ], )], schema, ) } + struct TestFixture { + _tmp_dir: tempfile::TempDir, + adapter: Arc, + } + + impl TestFixture { + async fn new() -> Self { + let tmp_dir = tempdir().unwrap(); + let dataset_path = tmp_dir.path().join("test.lance"); + let uri = dataset_path.to_str().unwrap(); + + let db = connect(uri).execute().await.unwrap(); + + let tbl = db + .create_table("foo", make_test_batches()) + .execute() + .await + .unwrap(); + + tbl.create_index(&["indexed"], Index::BTree(BTreeIndexBuilder::default())) + .execute() + .await + .unwrap(); + + let adapter = Arc::new( + BaseTableAdapter::try_new(tbl.base_table().clone()) + .await + .unwrap(), + ); + + Self { + _tmp_dir: tmp_dir, + adapter, + } + } + + async fn plan_to_stream(plan: LogicalPlan) -> SendableRecordBatchStream { + SessionContext::new() + .execute_logical_plan(plan) + .await + .unwrap() + .execute_stream() + .await + .unwrap() + } + + async fn plan_to_explain(plan: LogicalPlan) -> String { + let mut explain_stream = SessionContext::new() + .execute_logical_plan(plan) + .await + .unwrap() + .explain(true, false) + .unwrap() + .execute_stream() + .await + .unwrap(); + let batch = explain_stream.try_next().await.unwrap().unwrap(); + assert!(explain_stream.try_next().await.unwrap().is_none()); + + let plan_descs = batch.columns()[0].as_string::(); + let plans = batch.columns()[1].as_string::(); + + for (desc, plan) in plan_descs.iter().zip(plans.iter()) { + if desc.unwrap() == "physical_plan" { + return plan.unwrap().to_string(); + } + } + panic!("No physical plan found in explain output"); + } + + async fn check_plan(plan: LogicalPlan, expected: &str) { + let physical_plan = dbg!(Self::plan_to_explain(plan).await); + let mut lines_checked = 0; + for (actual_line, expected_line) in physical_plan.lines().zip(expected.lines()) { + lines_checked += 1; + let actual_trimmed = actual_line.trim(); + let expected_trimmed = if let Some(ellipsis_pos) = expected_line.find("...") { + expected_line[0..ellipsis_pos].trim() + } else { + expected_line.trim() + }; + assert_eq!(&actual_trimmed[..expected_trimmed.len()], expected_trimmed); + } + assert_eq!(lines_checked, expected.lines().count()); + } + } + #[tokio::test] async fn test_metadata_erased() { - let tmp_dir = tempdir().unwrap(); - let dataset_path = tmp_dir.path().join("test.lance"); - let uri = dataset_path.to_str().unwrap(); + let fixture = TestFixture::new().await; - let db = connect(uri).execute().await.unwrap(); + assert!(fixture.adapter.schema().metadata().is_empty()); - let tbl = db - .create_table("foo", make_test_batches()) - .execute() - .await - .unwrap(); - - let provider = Arc::new( - BaseTableAdapter::try_new(tbl.base_table().clone()) - .await - .unwrap(), - ); - - assert!(provider.schema().metadata().is_empty()); - - let plan = LogicalPlanBuilder::scan("foo", provider_as_source(provider), None) + let plan = LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter), None) .unwrap() .build() .unwrap(); - let mut stream = SessionContext::new() - .execute_logical_plan(plan) - .await - .unwrap() - .execute_stream() - .await - .unwrap(); + let mut stream = TestFixture::plan_to_stream(plan).await; while let Some(batch) = stream.try_next().await.unwrap() { assert!(batch.schema().metadata().is_empty()); } } + + #[tokio::test] + async fn test_filter_pushdown() { + let fixture = TestFixture::new().await; + + // Basic filter, not much different pushed down than run from DF + let plan = + LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter.clone()), None) + .unwrap() + .filter(col("i").gt_eq(lit(5))) + .unwrap() + .build() + .unwrap(); + + TestFixture::check_plan( + plan, + "MetadataEraserExec + RepartitionExec:... + CoalesceBatchesExec:... + FilterExec: i@0 >= 5 + ProjectionExec:... + LanceScan:...", + ) + .await; + + // Filter utilizing scalar index, make sure it gets pushed down + let plan = LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter), None) + .unwrap() + .filter(col("indexed").eq(lit(5))) + .unwrap() + .build() + .unwrap(); + + TestFixture::check_plan(plan, "").await; + } }