feat: push filters down into DF table provider (#2128)

This commit is contained in:
Weston Pace
2025-02-25 14:46:28 -08:00
committed by GitHub
parent 979a2d3d9d
commit c4f99e82e5
7 changed files with 216 additions and 66 deletions

24
Cargo.lock generated
View File

@@ -2571,7 +2571,7 @@ dependencies = [
[[package]]
name = "fsst"
version = "0.23.2"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465"
dependencies = [
"rand",
]
@@ -3533,7 +3533,7 @@ dependencies = [
[[package]]
name = "lance"
version = "0.23.2"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465"
dependencies = [
"arrow",
"arrow-arith",
@@ -3593,7 +3593,7 @@ dependencies = [
[[package]]
name = "lance-arrow"
version = "0.23.2"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -3611,7 +3611,7 @@ dependencies = [
[[package]]
name = "lance-core"
version = "0.23.2"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465"
dependencies = [
"arrow-array",
"arrow-buffer",
@@ -3648,7 +3648,7 @@ dependencies = [
[[package]]
name = "lance-datafusion"
version = "0.23.2"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465"
dependencies = [
"arrow",
"arrow-array",
@@ -3674,7 +3674,7 @@ dependencies = [
[[package]]
name = "lance-encoding"
version = "0.23.2"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465"
dependencies = [
"arrayref",
"arrow",
@@ -3713,7 +3713,7 @@ dependencies = [
[[package]]
name = "lance-file"
version = "0.23.2"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465"
dependencies = [
"arrow-arith",
"arrow-array",
@@ -3748,7 +3748,7 @@ dependencies = [
[[package]]
name = "lance-index"
version = "0.23.2"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465"
dependencies = [
"arrow",
"arrow-array",
@@ -3801,7 +3801,7 @@ dependencies = [
[[package]]
name = "lance-io"
version = "0.23.2"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465"
dependencies = [
"arrow",
"arrow-arith",
@@ -3840,7 +3840,7 @@ dependencies = [
[[package]]
name = "lance-linalg"
version = "0.23.2"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465"
dependencies = [
"arrow-array",
"arrow-ord",
@@ -3864,7 +3864,7 @@ dependencies = [
[[package]]
name = "lance-table"
version = "0.23.2"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465"
dependencies = [
"arrow",
"arrow-array",
@@ -3904,7 +3904,7 @@ dependencies = [
[[package]]
name = "lance-testing"
version = "0.23.2"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.1#c69a5a21389eb64f4b51810045bcb4cada9234e9"
source = "git+https://github.com/lancedb/lance.git?tag=v0.23.2-beta.3#59d65964d1113e7c06ea2af76a166eef6fffc465"
dependencies = [
"arrow-array",
"arrow-schema",

View File

@@ -23,14 +23,14 @@ rust-version = "1.78.0"
[workspace.dependencies]
lance = { "version" = "=0.23.2", "features" = [
"dynamodb",
], git = "https://github.com/lancedb/lance.git", tag = "v0.23.2-beta.1"}
lance-io = {version = "=0.23.2", tag="v0.23.2-beta.1", git = "https://github.com/lancedb/lance.git"}
lance-index = {version = "=0.23.2", tag="v0.23.2-beta.1", git = "https://github.com/lancedb/lance.git"}
lance-linalg = {version = "=0.23.2", tag="v0.23.2-beta.1", git = "https://github.com/lancedb/lance.git"}
lance-table = {version = "=0.23.2", tag="v0.23.2-beta.1", git = "https://github.com/lancedb/lance.git"}
lance-testing = {version = "=0.23.2", tag="v0.23.2-beta.1", git = "https://github.com/lancedb/lance.git"}
lance-datafusion = {version = "=0.23.2", tag="v0.23.2-beta.1", git = "https://github.com/lancedb/lance.git"}
lance-encoding = {version = "=0.23.2", tag="v0.23.2-beta.1", git = "https://github.com/lancedb/lance.git"}
], git = "https://github.com/lancedb/lance.git", tag = "v0.23.2-beta.3" }
lance-io = { version = "=0.23.2", tag = "v0.23.2-beta.3", git = "https://github.com/lancedb/lance.git" }
lance-index = { version = "=0.23.2", tag = "v0.23.2-beta.3", git = "https://github.com/lancedb/lance.git" }
lance-linalg = { version = "=0.23.2", tag = "v0.23.2-beta.3", git = "https://github.com/lancedb/lance.git" }
lance-table = { version = "=0.23.2", tag = "v0.23.2-beta.3", git = "https://github.com/lancedb/lance.git" }
lance-testing = { version = "=0.23.2", tag = "v0.23.2-beta.3", git = "https://github.com/lancedb/lance.git" }
lance-datafusion = { version = "=0.23.2", tag = "v0.23.2-beta.3", git = "https://github.com/lancedb/lance.git" }
lance-encoding = { version = "=0.23.2", tag = "v0.23.2-beta.3", git = "https://github.com/lancedb/lance.git" }
# Note that this one does not include pyarrow
arrow = { version = "53.2", optional = false }
arrow-array = "53.2"

View File

@@ -4,7 +4,7 @@ name = "lancedb"
dynamic = ["version"]
dependencies = [
"deprecation",
"pylance==0.23.0",
"pylance==0.23.2b3",
"tqdm>=4.27.0",
"pydantic>=1.10",
"packaging",
@@ -55,7 +55,12 @@ tests = [
"tantivy",
"pyarrow-stubs",
]
dev = ["ruff", "pre-commit", "pyright", 'typing-extensions>=4.0.0; python_version < "3.11"']
dev = [
"ruff",
"pre-commit",
"pyright",
'typing-extensions>=4.0.0; python_version < "3.11"',
]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
clip = ["torch", "pillow", "open-clip"]
embeddings = [

View File

@@ -7,6 +7,7 @@ use std::sync::Arc;
use arrow::compute::concat_batches;
use arrow_array::{make_array, Array, Float16Array, Float32Array, Float64Array};
use arrow_schema::DataType;
use datafusion_expr::Expr;
use datafusion_physical_plan::ExecutionPlan;
use futures::{stream, try_join, FutureExt, TryStreamExt};
use half::f16;
@@ -464,7 +465,7 @@ impl<T: HasQuery> QueryBase for T {
}
fn only_if(mut self, filter: impl AsRef<str>) -> Self {
self.mut_query().filter = Some(filter.as_ref().to_string());
self.mut_query().filter = Some(QueryFilter::Sql(filter.as_ref().to_string()));
self
}
@@ -577,6 +578,17 @@ pub trait ExecutableQuery {
fn explain_plan(&self, verbose: bool) -> impl Future<Output = Result<String>> + Send;
}
/// A query filter that can be applied to a query
#[derive(Clone, Debug)]
pub enum QueryFilter {
/// The filter is an SQL string
Sql(String),
/// The filter is a Substrait ExtendedExpression message with a single expression
Substrait(Arc<[u8]>),
/// The filter is a Datafusion expression
Datafusion(Expr),
}
/// A basic query into a table without any kind of search
///
/// This will result in a (potentially filtered) scan if executed
@@ -589,7 +601,7 @@ pub struct QueryRequest {
pub offset: Option<usize>,
/// Apply filter to the returned rows.
pub filter: Option<String>,
pub filter: Option<QueryFilter>,
/// Perform a full text search on the table.
pub full_text_search: Option<FullTextSearchQuery>,

View File

@@ -7,7 +7,7 @@ use std::sync::{Arc, Mutex};
use crate::index::Index;
use crate::index::IndexStatistics;
use crate::query::{QueryRequest, Select, VectorQueryRequest};
use crate::query::{QueryFilter, QueryRequest, Select, VectorQueryRequest};
use crate::table::{AddDataMode, AnyQuery, Filter};
use crate::utils::{supported_btree_data_type, supported_vector_data_type};
use crate::{DistanceType, Error};
@@ -159,7 +159,13 @@ impl<S: HttpSend> RemoteTable<S> {
}
if let Some(filter) = &params.filter {
body["filter"] = serde_json::Value::String(filter.clone());
if let QueryFilter::Sql(filter) = filter {
body["filter"] = serde_json::Value::String(filter.clone());
} else {
return Err(Error::NotSupported {
message: "querying a remote table with a non-sql filter".to_string(),
});
}
}
match &params.select {

View File

@@ -62,7 +62,7 @@ use crate::index::{
};
use crate::index::{IndexConfig, IndexStatisticsImpl};
use crate::query::{
IntoQueryVector, Query, QueryExecutionOptions, QueryRequest, Select, VectorQuery,
IntoQueryVector, Query, QueryExecutionOptions, QueryFilter, QueryRequest, Select, VectorQuery,
VectorQueryRequest, DEFAULT_TOP_K,
};
use crate::utils::{
@@ -2125,7 +2125,17 @@ impl BaseTable for NativeTable {
}
if let Some(filter) = &query.base.filter {
scanner.filter(filter)?;
match filter {
QueryFilter::Sql(sql) => {
scanner.filter(sql)?;
}
QueryFilter::Substrait(substrait) => {
scanner.filter_substrait(substrait)?;
}
QueryFilter::Datafusion(expr) => {
scanner.filter_expr(expr.clone());
}
}
}
if let Some(fts) = &query.base.full_text_search {

View File

@@ -17,7 +17,7 @@ use futures::{TryFutureExt, TryStreamExt};
use super::{AnyQuery, BaseTable};
use crate::{
query::{QueryExecutionOptions, QueryRequest, Select},
query::{QueryExecutionOptions, QueryFilter, QueryRequest, Select},
Result,
};
@@ -161,7 +161,13 @@ impl TableProvider for BaseTableAdapter {
.collect();
query.select = Select::Columns(field_names);
}
assert!(filters.is_empty());
if !filters.is_empty() {
let first = filters.first().unwrap().clone();
let filter = filters[1..]
.iter()
.fold(first, |acc, expr| acc.and(expr.clone()));
query.filter = Some(QueryFilter::Datafusion(filter));
}
if let Some(limit) = limit {
query.limit = Some(limit);
} else {
@@ -180,11 +186,7 @@ impl TableProvider for BaseTableAdapter {
&self,
filters: &[&Expr],
) -> DataFusionResult<Vec<TableProviderFilterPushDown>> {
// TODO: Pushdown unsupported until we can support datafusion filters in BaseTable::create_plan
Ok(vec![
TableProviderFilterPushDown::Unsupported;
filters.len()
])
Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
}
fn statistics(&self) -> Option<Statistics> {
@@ -197,67 +199,182 @@ impl TableProvider for BaseTableAdapter {
pub mod tests {
use std::{collections::HashMap, sync::Arc};
use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
use arrow::array::AsArray;
use arrow_array::{
Int32Array, RecordBatch, RecordBatchIterator, RecordBatchReader, UInt32Array,
};
use arrow_schema::{DataType, Field, Schema};
use datafusion::{datasource::provider_as_source, prelude::SessionContext};
use datafusion_catalog::TableProvider;
use datafusion_expr::LogicalPlanBuilder;
use datafusion_execution::SendableRecordBatchStream;
use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder};
use futures::TryStreamExt;
use tempfile::tempdir;
use crate::{connect, table::datafusion::BaseTableAdapter};
use crate::{
connect,
index::{scalar::BTreeIndexBuilder, Index},
table::datafusion::BaseTableAdapter,
};
fn make_test_batches() -> impl RecordBatchReader + Send + Sync + 'static {
let metadata = HashMap::from_iter(vec![("foo".to_string(), "bar".to_string())]);
let schema = Arc::new(
Schema::new(vec![Field::new("i", DataType::Int32, false)]).with_metadata(metadata),
Schema::new(vec![
Field::new("i", DataType::Int32, false),
Field::new("indexed", DataType::UInt32, false),
])
.with_metadata(metadata),
);
RecordBatchIterator::new(
vec![RecordBatch::try_new(
schema.clone(),
vec![Arc::new(Int32Array::from_iter_values(0..10))],
vec![
Arc::new(Int32Array::from_iter_values(0..10)),
Arc::new(UInt32Array::from_iter_values(0..10)),
],
)],
schema,
)
}
struct TestFixture {
_tmp_dir: tempfile::TempDir,
adapter: Arc<BaseTableAdapter>,
}
impl TestFixture {
async fn new() -> Self {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = dataset_path.to_str().unwrap();
let db = connect(uri).execute().await.unwrap();
let tbl = db
.create_table("foo", make_test_batches())
.execute()
.await
.unwrap();
tbl.create_index(&["indexed"], Index::BTree(BTreeIndexBuilder::default()))
.execute()
.await
.unwrap();
let adapter = Arc::new(
BaseTableAdapter::try_new(tbl.base_table().clone())
.await
.unwrap(),
);
Self {
_tmp_dir: tmp_dir,
adapter,
}
}
async fn plan_to_stream(plan: LogicalPlan) -> SendableRecordBatchStream {
SessionContext::new()
.execute_logical_plan(plan)
.await
.unwrap()
.execute_stream()
.await
.unwrap()
}
async fn plan_to_explain(plan: LogicalPlan) -> String {
let mut explain_stream = SessionContext::new()
.execute_logical_plan(plan)
.await
.unwrap()
.explain(true, false)
.unwrap()
.execute_stream()
.await
.unwrap();
let batch = explain_stream.try_next().await.unwrap().unwrap();
assert!(explain_stream.try_next().await.unwrap().is_none());
let plan_descs = batch.columns()[0].as_string::<i32>();
let plans = batch.columns()[1].as_string::<i32>();
for (desc, plan) in plan_descs.iter().zip(plans.iter()) {
if desc.unwrap() == "physical_plan" {
return plan.unwrap().to_string();
}
}
panic!("No physical plan found in explain output");
}
async fn check_plan(plan: LogicalPlan, expected: &str) {
let physical_plan = dbg!(Self::plan_to_explain(plan).await);
let mut lines_checked = 0;
for (actual_line, expected_line) in physical_plan.lines().zip(expected.lines()) {
lines_checked += 1;
let actual_trimmed = actual_line.trim();
let expected_trimmed = if let Some(ellipsis_pos) = expected_line.find("...") {
expected_line[0..ellipsis_pos].trim()
} else {
expected_line.trim()
};
assert_eq!(&actual_trimmed[..expected_trimmed.len()], expected_trimmed);
}
assert_eq!(lines_checked, expected.lines().count());
}
}
#[tokio::test]
async fn test_metadata_erased() {
let tmp_dir = tempdir().unwrap();
let dataset_path = tmp_dir.path().join("test.lance");
let uri = dataset_path.to_str().unwrap();
let fixture = TestFixture::new().await;
let db = connect(uri).execute().await.unwrap();
assert!(fixture.adapter.schema().metadata().is_empty());
let tbl = db
.create_table("foo", make_test_batches())
.execute()
.await
.unwrap();
let provider = Arc::new(
BaseTableAdapter::try_new(tbl.base_table().clone())
.await
.unwrap(),
);
assert!(provider.schema().metadata().is_empty());
let plan = LogicalPlanBuilder::scan("foo", provider_as_source(provider), None)
let plan = LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter), None)
.unwrap()
.build()
.unwrap();
let mut stream = SessionContext::new()
.execute_logical_plan(plan)
.await
.unwrap()
.execute_stream()
.await
.unwrap();
let mut stream = TestFixture::plan_to_stream(plan).await;
while let Some(batch) = stream.try_next().await.unwrap() {
assert!(batch.schema().metadata().is_empty());
}
}
#[tokio::test]
async fn test_filter_pushdown() {
let fixture = TestFixture::new().await;
// Basic filter, not much different pushed down than run from DF
let plan =
LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter.clone()), None)
.unwrap()
.filter(col("i").gt_eq(lit(5)))
.unwrap()
.build()
.unwrap();
TestFixture::check_plan(
plan,
"MetadataEraserExec
RepartitionExec:...
CoalesceBatchesExec:...
FilterExec: i@0 >= 5
ProjectionExec:...
LanceScan:...",
)
.await;
// Filter utilizing scalar index, make sure it gets pushed down
let plan = LogicalPlanBuilder::scan("foo", provider_as_source(fixture.adapter), None)
.unwrap()
.filter(col("indexed").eq(lit(5)))
.unwrap()
.build()
.unwrap();
TestFixture::check_plan(plan, "").await;
}
}