From d84176095e529c419f897c7a5d19f281faa12b60 Mon Sep 17 00:00:00 2001 From: discord9 Date: Mon, 16 Mar 2026 19:14:54 +0800 Subject: [PATCH] feat: QueryContext to ScanRequest Signed-off-by: discord9 --- src/query/src/dummy_catalog.rs | 172 +++++++++++++++++++++++++++++++-- src/query/src/error.rs | 10 +- src/query/src/options.rs | 152 +++++++++++++++++++---------- 3 files changed, 275 insertions(+), 59 deletions(-) diff --git a/src/query/src/dummy_catalog.rs b/src/query/src/dummy_catalog.rs index 239cf7cea8..75f612a37e 100644 --- a/src/query/src/dummy_catalog.rs +++ b/src/query/src/dummy_catalog.rs @@ -43,6 +43,7 @@ use table::metadata::{TableId, TableInfoRef}; use table::table::scan::RegionScanExec; use crate::error::{GetRegionMetadataSnafu, Result}; +use crate::options::FlowQueryExtensions; /// Resolve to the given region (specified by [RegionId]) unconditionally. #[derive(Clone, Debug)] @@ -295,14 +296,11 @@ impl DummyTableProviderFactory { region_id, })?; - let scan_request = query_ctx - .as_ref() - .map(|ctx| ScanRequest { - memtable_max_sequence: ctx.get_snapshot(region_id.as_u64()), - sst_min_sequence: ctx.sst_min_sequence(region_id.as_u64()), - ..Default::default() - }) - .unwrap_or_default(); + let scan_request = if let Some(ctx) = query_ctx.as_ref() { + scan_request_from_query_context(region_id, ctx)? 
+ } else { + ScanRequest::default() + }; Ok(DummyTableProvider { region_id, @@ -314,6 +312,32 @@ impl DummyTableProviderFactory { } } +fn scan_request_from_query_context( + region_id: RegionId, + query_ctx: &QueryContext, +) -> Result<ScanRequest> { + let mut scan_request = ScanRequest { + memtable_max_sequence: query_ctx.get_snapshot(region_id.as_u64()), + sst_min_sequence: query_ctx.sst_min_sequence(region_id.as_u64()), + ..Default::default() + }; + + let flow_extensions = FlowQueryExtensions::from_extensions(&query_ctx.extensions())?; + + let should_apply_incremental = flow_extensions.validate_for_scan(region_id)?; + if should_apply_incremental + && let Some(after_seq) = flow_extensions + .incremental_after_seqs + .as_ref() + .and_then(|seqs| seqs.get(&region_id.as_u64())) + .copied() + { + scan_request.memtable_min_sequence = Some(after_seq); + } + + Ok(scan_request) +} + #[async_trait] impl TableProviderFactory for DummyTableProviderFactory { async fn create( @@ -443,3 +467,135 @@ impl CatalogManager for DummyCatalogManager { Box::pin(futures::stream::empty()) } } + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::sync::RwLock; + + use common_error::ext::ErrorExt; + use common_error::status_code::StatusCode; + use session::context::QueryContextBuilder; + + use super::*; + use crate::error::Error; + use crate::options::{FLOW_INCREMENTAL_AFTER_SEQS, FLOW_INCREMENTAL_MODE, FLOW_SINK_TABLE_ID}; + + fn test_region_id() -> RegionId { + RegionId::new(1024, 1) + } + + #[test] + fn test_scan_request_from_query_context_keeps_snapshot_fields() { + let region_id = test_region_id(); + let query_ctx = QueryContextBuilder::default() + .snapshot_seqs(Arc::new(RwLock::new(HashMap::from([( + region_id.as_u64(), + 100, + )])))) + .sst_min_sequences(Arc::new(RwLock::new(HashMap::from([( + region_id.as_u64(), + 90, + )])))) + .build(); + + let request = scan_request_from_query_context(region_id, &query_ctx).unwrap(); + assert_eq!(request.memtable_max_sequence, Some(100)); + 
assert_eq!(request.sst_min_sequence, Some(90)); + assert_eq!(request.memtable_min_sequence, None); + } + + #[test] + fn test_scan_request_from_query_context_applies_incremental_after_seq_for_source_scan() { + let region_id = test_region_id(); + let query_ctx = QueryContextBuilder::default() + .extensions(HashMap::from([ + ( + FLOW_INCREMENTAL_MODE.to_string(), + "memtable_only".to_string(), + ), + ( + FLOW_INCREMENTAL_AFTER_SEQS.to_string(), + format!(r#"{{"{}":55}}"#, region_id.as_u64()), + ), + ])) + .build(); + + let request = scan_request_from_query_context(region_id, &query_ctx).unwrap(); + assert_eq!(request.memtable_min_sequence, Some(55)); + } + + #[test] + fn test_scan_request_from_query_context_does_not_apply_incremental_for_sink_table() { + let region_id = test_region_id(); + let query_ctx = QueryContextBuilder::default() + .extensions(HashMap::from([ + ( + FLOW_INCREMENTAL_MODE.to_string(), + "memtable_only".to_string(), + ), + ( + FLOW_INCREMENTAL_AFTER_SEQS.to_string(), + format!(r#"{{"{}":55}}"#, region_id.as_u64()), + ), + ( + FLOW_SINK_TABLE_ID.to_string(), + region_id.table_id().to_string(), + ), + ])) + .build(); + + let request = scan_request_from_query_context(region_id, &query_ctx).unwrap(); + assert_eq!(request.memtable_min_sequence, None); + } + + #[test] + fn test_scan_request_from_query_context_rejects_missing_memtable_only_region() { + let region_id = test_region_id(); + let query_ctx = QueryContextBuilder::default() + .extensions(HashMap::from([ + ( + FLOW_INCREMENTAL_MODE.to_string(), + "memtable_only".to_string(), + ), + ( + FLOW_INCREMENTAL_AFTER_SEQS.to_string(), + r#"{"9":55}"#.to_string(), + ), + ])) + .build(); + + let err = scan_request_from_query_context(region_id, &query_ctx).unwrap_err(); + assert!(matches!(err, Error::InvalidQueryContextExtension { .. 
})); + } + + #[test] + fn test_scan_request_from_query_context_rejects_invalid_incremental_json() { + let region_id = test_region_id(); + let query_ctx = QueryContextBuilder::default() + .extensions(HashMap::from([( + FLOW_INCREMENTAL_AFTER_SEQS.to_string(), + "not-json".to_string(), + )])) + .build(); + + let err = scan_request_from_query_context(region_id, &query_ctx).unwrap_err(); + assert!(matches!(err, Error::InvalidQueryContextExtension { .. })); + assert_eq!(err.status_code(), StatusCode::InvalidArguments); + } + + #[test] + fn test_scan_request_from_query_context_rejects_invalid_sink_table_id() { + let region_id = test_region_id(); + let query_ctx = QueryContextBuilder::default() + .extensions(HashMap::from([( + FLOW_SINK_TABLE_ID.to_string(), + "abc".to_string(), + )])) + .build(); + + let err = scan_request_from_query_context(region_id, &query_ctx).unwrap_err(); + assert!(matches!(err, Error::InvalidQueryContextExtension { .. })); + assert_eq!(err.status_code(), StatusCode::InvalidArguments); + } +} diff --git a/src/query/src/error.rs b/src/query/src/error.rs index f863a26c4a..b3a4ebeba5 100644 --- a/src/query/src/error.rs +++ b/src/query/src/error.rs @@ -368,6 +368,13 @@ pub enum Error { location: Location, }, + #[snafu(display("Invalid query context extension: {}", reason))] + InvalidQueryContextExtension { + reason: String, + #[snafu(implicit)] + location: Location, + }, + #[snafu(transparent)] Datatypes { source: datatypes::error::Error, @@ -399,7 +406,8 @@ impl ErrorExt for Error { | ColumnSchemaNoDefault { .. } | CteColumnSchemaMismatch { .. } | ConvertValue { .. } - | TryIntoDuration { .. } => StatusCode::InvalidArguments, + | TryIntoDuration { .. } + | InvalidQueryContextExtension { .. } => StatusCode::InvalidArguments, BuildBackend { .. } | ListObjects { .. 
} => StatusCode::StorageUnavailable, diff --git a/src/query/src/options.rs b/src/query/src/options.rs index f807c1c16a..0bcfbb2041 100644 --- a/src/query/src/options.rs +++ b/src/query/src/options.rs @@ -16,8 +16,11 @@ use std::collections::HashMap; use common_base::memory_limit::MemoryLimit; use serde::{Deserialize, Serialize}; +use store_api::storage::RegionId; use table::metadata::TableId; +use crate::error::{Error, InvalidQueryContextExtensionSnafu, Result}; + pub const FLOW_INCREMENTAL_AFTER_SEQS: &str = "flow.incremental_after_seqs"; pub const FLOW_INCREMENTAL_MODE: &str = "flow.incremental_mode"; pub const FLOW_RETURN_REGION_SEQ: &str = "flow.return_region_seq"; @@ -64,17 +67,17 @@ pub struct FlowQueryExtensions { } impl FlowQueryExtensions { - pub fn from_extensions(extensions: &HashMap<String, String>) -> Result<Self, String> { + pub fn from_extensions(extensions: &HashMap<String, String>) -> Result<Self> { let incremental_mode = extensions .get(FLOW_INCREMENTAL_MODE) .map(|value| match value.as_str() { v if v.eq_ignore_ascii_case(FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY) => { Ok(FlowIncrementalMode::MemtableOnly) } - _ => Err(format!( + _ => Err(invalid_query_context_extension(format!( "Invalid value for {}: {}", FLOW_INCREMENTAL_MODE, value - )), + ))), }) .transpose()?; @@ -92,28 +95,31 @@ impl FlowQueryExtensions { let sink_table_id = extensions .get(FLOW_SINK_TABLE_ID) .map(|value| { - value - .parse::<TableId>() - .map_err(|_| format!("Invalid value for {}: {}", FLOW_SINK_TABLE_ID, value)) + value.parse::<TableId>().map_err(|_| { + invalid_query_context_extension(format!( + "Invalid value for {}: {}", + FLOW_SINK_TABLE_ID, value + )) + }) }) .transpose()?; if matches!(incremental_mode, Some(FlowIncrementalMode::MemtableOnly)) { let after_seqs = incremental_after_seqs.as_ref().ok_or_else(|| { - format!( + invalid_query_context_extension(format!( "{} is required when {}={}.", FLOW_INCREMENTAL_AFTER_SEQS, FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY - ) + )) })?; if after_seqs.is_empty() { - return Err(format!( + 
return Err(invalid_query_context_extension(format!( "{} must not be empty when {}={}.", FLOW_INCREMENTAL_AFTER_SEQS, FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY - )); + ))); } } @@ -125,12 +131,8 @@ impl FlowQueryExtensions { }) } - pub fn validate_for_scan( - &self, - source_region_ids: &[u64], - current_scan_table_id: Option<TableId>, - ) -> Result<bool, String> { - if self.sink_table_id.is_some() && self.sink_table_id == current_scan_table_id { + pub fn validate_for_scan(&self, source_region_id: RegionId) -> Result<bool> { + if self.sink_table_id.is_some() && self.sink_table_id == Some(source_region_id.table_id()) { return Ok(false); } @@ -139,19 +141,17 @@ impl FlowQueryExtensions { Some(FlowIncrementalMode::MemtableOnly) ) { let after_seqs = self.incremental_after_seqs.as_ref().ok_or_else(|| { - format!( + invalid_query_context_extension(format!( "{} is required when {}=memtable_only.", FLOW_INCREMENTAL_AFTER_SEQS, FLOW_INCREMENTAL_MODE - ) + )) })?; - for region_id in source_region_ids { - if !after_seqs.contains_key(region_id) { - return Err(format!( - "Missing region {} in {} when {}=memtable_only.", - region_id, FLOW_INCREMENTAL_AFTER_SEQS, FLOW_INCREMENTAL_MODE - )); - } + if !after_seqs.contains_key(&source_region_id.as_u64()) { + return Err(invalid_query_context_extension(format!( + "Missing region {} in {} when {}=memtable_only.", + source_region_id, FLOW_INCREMENTAL_AFTER_SEQS, FLOW_INCREMENTAL_MODE + ))); } } @@ -159,40 +159,64 @@ impl FlowQueryExtensions { } } -fn parse_incremental_after_seqs(value: &str) -> Result<HashMap<u64, u64>, String> { - let raw = serde_json::from_str::<HashMap<String, u64>>(value).map_err(|e| { - format!( +fn parse_incremental_after_seqs(value: &str) -> Result<HashMap<u64, u64>> { + let raw = serde_json::from_str::<HashMap<String, serde_json::Value>>(value).map_err(|e| { + invalid_query_context_extension(format!( "Invalid JSON for {}: {} ({})", FLOW_INCREMENTAL_AFTER_SEQS, value, e - ) + )) })?; raw.into_iter() - .map(|(region_id, seq)| { - region_id - .parse::<u64>() - .map(|region_id| (region_id, seq)) - .map_err(|_| { - format!( 
- "Invalid region id in {}: {}", + .map(|(region_id, raw_seq)| { + let region_id = region_id.parse::<u64>().map_err(|_| { + invalid_query_context_extension(format!( + "Invalid region id in {}: {}", + FLOW_INCREMENTAL_AFTER_SEQS, region_id + )) + })?; + + let seq = match raw_seq { + serde_json::Value::Number(num) => num.as_u64().ok_or_else(|| { + invalid_query_context_extension(format!( + "Invalid sequence value in {} for region {}: {}", + FLOW_INCREMENTAL_AFTER_SEQS, region_id, num + )) + })?, + serde_json::Value::String(s) => s.parse::<u64>().map_err(|_| { + invalid_query_context_extension(format!( + "Invalid sequence string in {} for region {}: {}", + FLOW_INCREMENTAL_AFTER_SEQS, region_id, s + )) + })?, + _ => { + return Err(invalid_query_context_extension(format!( + "Invalid sequence value type in {} for region {}", FLOW_INCREMENTAL_AFTER_SEQS, region_id - ) - }) + ))); + } + }; + + Ok((region_id, seq)) }) .collect() } -fn parse_bool(value: &str) -> Result<bool, String> { +fn parse_bool(value: &str) -> Result<bool> { match value { v if v.eq_ignore_ascii_case("true") => Ok(true), v if v.eq_ignore_ascii_case("false") => Ok(false), - _ => Err(format!( + _ => Err(invalid_query_context_extension(format!( "Invalid value for {}: {}", FLOW_RETURN_REGION_SEQ, value - )), + ))), } } +fn invalid_query_context_extension(reason: String) -> Error { + InvalidQueryContextExtensionSnafu { reason }.build() +} + #[cfg(test)] mod flow_extension_tests { use super::*; @@ -244,7 +268,7 @@ mod flow_extension_tests { )]); let err = FlowQueryExtensions::from_extensions(&exts).unwrap_err(); - assert!(err.contains(FLOW_INCREMENTAL_AFTER_SEQS)); + assert!(format!("{err}").contains(FLOW_INCREMENTAL_AFTER_SEQS)); } #[test] @@ -252,7 +276,7 @@ mod flow_extension_tests { let exts = HashMap::from([(FLOW_INCREMENTAL_MODE.to_string(), "foo".to_string())]); let err = FlowQueryExtensions::from_extensions(&exts).unwrap_err(); - assert!(err.contains(FLOW_INCREMENTAL_MODE)); + 
assert!(format!("{err}").contains(FLOW_INCREMENTAL_MODE)); } #[test] @@ -269,7 +293,32 @@ mod flow_extension_tests { ]); let err = FlowQueryExtensions::from_extensions(&exts).unwrap_err(); - assert!(err.contains(FLOW_INCREMENTAL_AFTER_SEQS)); + assert!(format!("{err}").contains(FLOW_INCREMENTAL_AFTER_SEQS)); + } + + #[test] + fn test_parse_flow_extensions_after_seqs_string_values() { + let exts = HashMap::from([( + FLOW_INCREMENTAL_AFTER_SEQS.to_string(), + r#"{"1":"10","2":"20"}"#.to_string(), + )]); + + let parsed = FlowQueryExtensions::from_extensions(&exts).unwrap(); + assert_eq!( + parsed.incremental_after_seqs.unwrap(), + HashMap::from([(1, 10), (2, 20)]) + ); + } + + #[test] + fn test_parse_flow_extensions_after_seqs_invalid_value_type() { + let exts = HashMap::from([( + FLOW_INCREMENTAL_AFTER_SEQS.to_string(), + r#"{"1":true}"#.to_string(), + )]); + + let err = FlowQueryExtensions::from_extensions(&exts).unwrap_err(); + assert!(format!("{err}").contains(FLOW_INCREMENTAL_AFTER_SEQS)); } #[test] @@ -277,11 +326,13 @@ mod flow_extension_tests { let exts = HashMap::from([(FLOW_SINK_TABLE_ID.to_string(), "x".to_string())]); let err = FlowQueryExtensions::from_extensions(&exts).unwrap_err(); - assert!(err.contains(FLOW_SINK_TABLE_ID)); + assert!(format!("{err}").contains(FLOW_SINK_TABLE_ID)); } #[test] fn test_validate_for_scan_missing_source_region() { + let source_region_id = RegionId::new(100, 2); + let existing_region_id = RegionId::new(100, 1); let exts = HashMap::from([ ( FLOW_INCREMENTAL_MODE.to_string(), @@ -289,17 +340,18 @@ mod flow_extension_tests { ), ( FLOW_INCREMENTAL_AFTER_SEQS.to_string(), - r#"{"1":10}"#.to_string(), + format!(r#"{{"{}":10}}"#, existing_region_id.as_u64()), ), ]); let parsed = FlowQueryExtensions::from_extensions(&exts).unwrap(); - let err = parsed.validate_for_scan(&[1, 2], Some(100)).unwrap_err(); - assert!(err.contains("Missing region 2")); + let err = parsed.validate_for_scan(source_region_id).unwrap_err(); + 
assert!(format!("{err}").contains("Missing region")); } #[test] fn test_validate_for_scan_sink_table_excluded() { + let source_region_id = RegionId::new(1024, 1); let exts = HashMap::from([ ( FLOW_INCREMENTAL_MODE.to_string(), @@ -307,13 +359,13 @@ mod flow_extension_tests { ), ( FLOW_INCREMENTAL_AFTER_SEQS.to_string(), - r#"{"1":10}"#.to_string(), + format!(r#"{{"{}":10}}"#, source_region_id.as_u64()), ), (FLOW_SINK_TABLE_ID.to_string(), "1024".to_string()), ]); let parsed = FlowQueryExtensions::from_extensions(&exts).unwrap(); - let apply_incremental = parsed.validate_for_scan(&[1, 2], Some(1024)).unwrap(); + let apply_incremental = parsed.validate_for_scan(source_region_id).unwrap(); assert!(!apply_incremental); } }