feat: QueryContext to ScanRequest

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2026-03-16 19:14:54 +08:00
parent e6eb272137
commit d84176095e
3 changed files with 275 additions and 59 deletions

View File

@@ -43,6 +43,7 @@ use table::metadata::{TableId, TableInfoRef};
use table::table::scan::RegionScanExec;
use crate::error::{GetRegionMetadataSnafu, Result};
use crate::options::FlowQueryExtensions;
/// Resolve to the given region (specified by [RegionId]) unconditionally.
#[derive(Clone, Debug)]
@@ -295,14 +296,11 @@ impl DummyTableProviderFactory {
region_id,
})?;
let scan_request = query_ctx
.as_ref()
.map(|ctx| ScanRequest {
memtable_max_sequence: ctx.get_snapshot(region_id.as_u64()),
sst_min_sequence: ctx.sst_min_sequence(region_id.as_u64()),
..Default::default()
})
.unwrap_or_default();
let scan_request = if let Some(ctx) = query_ctx.as_ref() {
scan_request_from_query_context(region_id, ctx)?
} else {
ScanRequest::default()
};
Ok(DummyTableProvider {
region_id,
@@ -314,6 +312,32 @@ impl DummyTableProviderFactory {
}
}
fn scan_request_from_query_context(
region_id: RegionId,
query_ctx: &QueryContext,
) -> Result<ScanRequest> {
let mut scan_request = ScanRequest {
memtable_max_sequence: query_ctx.get_snapshot(region_id.as_u64()),
sst_min_sequence: query_ctx.sst_min_sequence(region_id.as_u64()),
..Default::default()
};
let flow_extensions = FlowQueryExtensions::from_extensions(&query_ctx.extensions())?;
let should_apply_incremental = flow_extensions.validate_for_scan(region_id)?;
if should_apply_incremental
&& let Some(after_seq) = flow_extensions
.incremental_after_seqs
.as_ref()
.and_then(|seqs| seqs.get(&region_id.as_u64()))
.copied()
{
scan_request.memtable_min_sequence = Some(after_seq);
}
Ok(scan_request)
}
#[async_trait]
impl TableProviderFactory for DummyTableProviderFactory {
async fn create(
@@ -443,3 +467,135 @@ impl CatalogManager for DummyCatalogManager {
Box::pin(futures::stream::empty())
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::RwLock;
use common_error::ext::ErrorExt;
use common_error::status_code::StatusCode;
use session::context::QueryContextBuilder;
use super::*;
use crate::error::Error;
use crate::options::{FLOW_INCREMENTAL_AFTER_SEQS, FLOW_INCREMENTAL_MODE, FLOW_SINK_TABLE_ID};
fn test_region_id() -> RegionId {
RegionId::new(1024, 1)
}
#[test]
fn test_scan_request_from_query_context_keeps_snapshot_fields() {
let region_id = test_region_id();
let query_ctx = QueryContextBuilder::default()
.snapshot_seqs(Arc::new(RwLock::new(HashMap::from([(
region_id.as_u64(),
100,
)]))))
.sst_min_sequences(Arc::new(RwLock::new(HashMap::from([(
region_id.as_u64(),
90,
)]))))
.build();
let request = scan_request_from_query_context(region_id, &query_ctx).unwrap();
assert_eq!(request.memtable_max_sequence, Some(100));
assert_eq!(request.sst_min_sequence, Some(90));
assert_eq!(request.memtable_min_sequence, None);
}
#[test]
fn test_scan_request_from_query_context_applies_incremental_after_seq_for_source_scan() {
let region_id = test_region_id();
let query_ctx = QueryContextBuilder::default()
.extensions(HashMap::from([
(
FLOW_INCREMENTAL_MODE.to_string(),
"memtable_only".to_string(),
),
(
FLOW_INCREMENTAL_AFTER_SEQS.to_string(),
format!(r#"{{"{}":55}}"#, region_id.as_u64()),
),
]))
.build();
let request = scan_request_from_query_context(region_id, &query_ctx).unwrap();
assert_eq!(request.memtable_min_sequence, Some(55));
}
#[test]
fn test_scan_request_from_query_context_does_not_apply_incremental_for_sink_table() {
let region_id = test_region_id();
let query_ctx = QueryContextBuilder::default()
.extensions(HashMap::from([
(
FLOW_INCREMENTAL_MODE.to_string(),
"memtable_only".to_string(),
),
(
FLOW_INCREMENTAL_AFTER_SEQS.to_string(),
format!(r#"{{"{}":55}}"#, region_id.as_u64()),
),
(
FLOW_SINK_TABLE_ID.to_string(),
region_id.table_id().to_string(),
),
]))
.build();
let request = scan_request_from_query_context(region_id, &query_ctx).unwrap();
assert_eq!(request.memtable_min_sequence, None);
}
#[test]
fn test_scan_request_from_query_context_rejects_missing_memtable_only_region() {
let region_id = test_region_id();
let query_ctx = QueryContextBuilder::default()
.extensions(HashMap::from([
(
FLOW_INCREMENTAL_MODE.to_string(),
"memtable_only".to_string(),
),
(
FLOW_INCREMENTAL_AFTER_SEQS.to_string(),
r#"{"9":55}"#.to_string(),
),
]))
.build();
let err = scan_request_from_query_context(region_id, &query_ctx).unwrap_err();
assert!(matches!(err, Error::InvalidQueryContextExtension { .. }));
}
#[test]
fn test_scan_request_from_query_context_rejects_invalid_incremental_json() {
let region_id = test_region_id();
let query_ctx = QueryContextBuilder::default()
.extensions(HashMap::from([(
FLOW_INCREMENTAL_AFTER_SEQS.to_string(),
"not-json".to_string(),
)]))
.build();
let err = scan_request_from_query_context(region_id, &query_ctx).unwrap_err();
assert!(matches!(err, Error::InvalidQueryContextExtension { .. }));
assert_eq!(err.status_code(), StatusCode::InvalidArguments);
}
#[test]
fn test_scan_request_from_query_context_rejects_invalid_sink_table_id() {
let region_id = test_region_id();
let query_ctx = QueryContextBuilder::default()
.extensions(HashMap::from([(
FLOW_SINK_TABLE_ID.to_string(),
"abc".to_string(),
)]))
.build();
let err = scan_request_from_query_context(region_id, &query_ctx).unwrap_err();
assert!(matches!(err, Error::InvalidQueryContextExtension { .. }));
assert_eq!(err.status_code(), StatusCode::InvalidArguments);
}
}

View File

@@ -368,6 +368,13 @@ pub enum Error {
location: Location,
},
#[snafu(display("Invalid query context extension: {}", reason))]
InvalidQueryContextExtension {
reason: String,
#[snafu(implicit)]
location: Location,
},
#[snafu(transparent)]
Datatypes {
source: datatypes::error::Error,
@@ -399,7 +406,8 @@ impl ErrorExt for Error {
| ColumnSchemaNoDefault { .. }
| CteColumnSchemaMismatch { .. }
| ConvertValue { .. }
| TryIntoDuration { .. } => StatusCode::InvalidArguments,
| TryIntoDuration { .. }
| InvalidQueryContextExtension { .. } => StatusCode::InvalidArguments,
BuildBackend { .. } | ListObjects { .. } => StatusCode::StorageUnavailable,

View File

@@ -16,8 +16,11 @@ use std::collections::HashMap;
use common_base::memory_limit::MemoryLimit;
use serde::{Deserialize, Serialize};
use store_api::storage::RegionId;
use table::metadata::TableId;
use crate::error::{Error, InvalidQueryContextExtensionSnafu, Result};
pub const FLOW_INCREMENTAL_AFTER_SEQS: &str = "flow.incremental_after_seqs";
pub const FLOW_INCREMENTAL_MODE: &str = "flow.incremental_mode";
pub const FLOW_RETURN_REGION_SEQ: &str = "flow.return_region_seq";
@@ -64,17 +67,17 @@ pub struct FlowQueryExtensions {
}
impl FlowQueryExtensions {
pub fn from_extensions(extensions: &HashMap<String, String>) -> Result<Self, String> {
pub fn from_extensions(extensions: &HashMap<String, String>) -> Result<Self> {
let incremental_mode = extensions
.get(FLOW_INCREMENTAL_MODE)
.map(|value| match value.as_str() {
v if v.eq_ignore_ascii_case(FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY) => {
Ok(FlowIncrementalMode::MemtableOnly)
}
_ => Err(format!(
_ => Err(invalid_query_context_extension(format!(
"Invalid value for {}: {}",
FLOW_INCREMENTAL_MODE, value
)),
))),
})
.transpose()?;
@@ -92,28 +95,31 @@ impl FlowQueryExtensions {
let sink_table_id = extensions
.get(FLOW_SINK_TABLE_ID)
.map(|value| {
value
.parse::<TableId>()
.map_err(|_| format!("Invalid value for {}: {}", FLOW_SINK_TABLE_ID, value))
value.parse::<TableId>().map_err(|_| {
invalid_query_context_extension(format!(
"Invalid value for {}: {}",
FLOW_SINK_TABLE_ID, value
))
})
})
.transpose()?;
if matches!(incremental_mode, Some(FlowIncrementalMode::MemtableOnly)) {
let after_seqs = incremental_after_seqs.as_ref().ok_or_else(|| {
format!(
invalid_query_context_extension(format!(
"{} is required when {}={}.",
FLOW_INCREMENTAL_AFTER_SEQS,
FLOW_INCREMENTAL_MODE,
FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY
)
))
})?;
if after_seqs.is_empty() {
return Err(format!(
return Err(invalid_query_context_extension(format!(
"{} must not be empty when {}={}.",
FLOW_INCREMENTAL_AFTER_SEQS,
FLOW_INCREMENTAL_MODE,
FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY
));
)));
}
}
@@ -125,12 +131,8 @@ impl FlowQueryExtensions {
})
}
pub fn validate_for_scan(
&self,
source_region_ids: &[u64],
current_scan_table_id: Option<TableId>,
) -> Result<bool, String> {
if self.sink_table_id.is_some() && self.sink_table_id == current_scan_table_id {
pub fn validate_for_scan(&self, source_region_id: RegionId) -> Result<bool> {
if self.sink_table_id.is_some() && self.sink_table_id == Some(source_region_id.table_id()) {
return Ok(false);
}
@@ -139,19 +141,17 @@ impl FlowQueryExtensions {
Some(FlowIncrementalMode::MemtableOnly)
) {
let after_seqs = self.incremental_after_seqs.as_ref().ok_or_else(|| {
format!(
invalid_query_context_extension(format!(
"{} is required when {}=memtable_only.",
FLOW_INCREMENTAL_AFTER_SEQS, FLOW_INCREMENTAL_MODE
)
))
})?;
for region_id in source_region_ids {
if !after_seqs.contains_key(region_id) {
return Err(format!(
"Missing region {} in {} when {}=memtable_only.",
region_id, FLOW_INCREMENTAL_AFTER_SEQS, FLOW_INCREMENTAL_MODE
));
}
if !after_seqs.contains_key(&source_region_id.as_u64()) {
return Err(invalid_query_context_extension(format!(
"Missing region {} in {} when {}=memtable_only.",
source_region_id, FLOW_INCREMENTAL_AFTER_SEQS, FLOW_INCREMENTAL_MODE
)));
}
}
@@ -159,40 +159,64 @@ impl FlowQueryExtensions {
}
}
fn parse_incremental_after_seqs(value: &str) -> Result<HashMap<u64, u64>, String> {
let raw = serde_json::from_str::<HashMap<String, u64>>(value).map_err(|e| {
format!(
fn parse_incremental_after_seqs(value: &str) -> Result<HashMap<u64, u64>> {
let raw = serde_json::from_str::<HashMap<String, serde_json::Value>>(value).map_err(|e| {
invalid_query_context_extension(format!(
"Invalid JSON for {}: {} ({})",
FLOW_INCREMENTAL_AFTER_SEQS, value, e
)
))
})?;
raw.into_iter()
.map(|(region_id, seq)| {
region_id
.parse::<u64>()
.map(|region_id| (region_id, seq))
.map_err(|_| {
format!(
"Invalid region id in {}: {}",
.map(|(region_id, raw_seq)| {
let region_id = region_id.parse::<u64>().map_err(|_| {
invalid_query_context_extension(format!(
"Invalid region id in {}: {}",
FLOW_INCREMENTAL_AFTER_SEQS, region_id
))
})?;
let seq = match raw_seq {
serde_json::Value::Number(num) => num.as_u64().ok_or_else(|| {
invalid_query_context_extension(format!(
"Invalid sequence value in {} for region {}: {}",
FLOW_INCREMENTAL_AFTER_SEQS, region_id, num
))
})?,
serde_json::Value::String(s) => s.parse::<u64>().map_err(|_| {
invalid_query_context_extension(format!(
"Invalid sequence string in {} for region {}: {}",
FLOW_INCREMENTAL_AFTER_SEQS, region_id, s
))
})?,
_ => {
return Err(invalid_query_context_extension(format!(
"Invalid sequence value type in {} for region {}",
FLOW_INCREMENTAL_AFTER_SEQS, region_id
)
})
)));
}
};
Ok((region_id, seq))
})
.collect()
}
fn parse_bool(value: &str) -> Result<bool, String> {
fn parse_bool(value: &str) -> Result<bool> {
match value {
v if v.eq_ignore_ascii_case("true") => Ok(true),
v if v.eq_ignore_ascii_case("false") => Ok(false),
_ => Err(format!(
_ => Err(invalid_query_context_extension(format!(
"Invalid value for {}: {}",
FLOW_RETURN_REGION_SEQ, value
)),
))),
}
}
fn invalid_query_context_extension(reason: String) -> Error {
InvalidQueryContextExtensionSnafu { reason }.build()
}
#[cfg(test)]
mod flow_extension_tests {
use super::*;
@@ -244,7 +268,7 @@ mod flow_extension_tests {
)]);
let err = FlowQueryExtensions::from_extensions(&exts).unwrap_err();
assert!(err.contains(FLOW_INCREMENTAL_AFTER_SEQS));
assert!(format!("{err}").contains(FLOW_INCREMENTAL_AFTER_SEQS));
}
#[test]
@@ -252,7 +276,7 @@ mod flow_extension_tests {
let exts = HashMap::from([(FLOW_INCREMENTAL_MODE.to_string(), "foo".to_string())]);
let err = FlowQueryExtensions::from_extensions(&exts).unwrap_err();
assert!(err.contains(FLOW_INCREMENTAL_MODE));
assert!(format!("{err}").contains(FLOW_INCREMENTAL_MODE));
}
#[test]
@@ -269,7 +293,32 @@ mod flow_extension_tests {
]);
let err = FlowQueryExtensions::from_extensions(&exts).unwrap_err();
assert!(err.contains(FLOW_INCREMENTAL_AFTER_SEQS));
assert!(format!("{err}").contains(FLOW_INCREMENTAL_AFTER_SEQS));
}
#[test]
fn test_parse_flow_extensions_after_seqs_string_values() {
let exts = HashMap::from([(
FLOW_INCREMENTAL_AFTER_SEQS.to_string(),
r#"{"1":"10","2":"20"}"#.to_string(),
)]);
let parsed = FlowQueryExtensions::from_extensions(&exts).unwrap();
assert_eq!(
parsed.incremental_after_seqs.unwrap(),
HashMap::from([(1, 10), (2, 20)])
);
}
#[test]
fn test_parse_flow_extensions_after_seqs_invalid_value_type() {
let exts = HashMap::from([(
FLOW_INCREMENTAL_AFTER_SEQS.to_string(),
r#"{"1":true}"#.to_string(),
)]);
let err = FlowQueryExtensions::from_extensions(&exts).unwrap_err();
assert!(format!("{err}").contains(FLOW_INCREMENTAL_AFTER_SEQS));
}
#[test]
@@ -277,11 +326,13 @@ mod flow_extension_tests {
let exts = HashMap::from([(FLOW_SINK_TABLE_ID.to_string(), "x".to_string())]);
let err = FlowQueryExtensions::from_extensions(&exts).unwrap_err();
assert!(err.contains(FLOW_SINK_TABLE_ID));
assert!(format!("{err}").contains(FLOW_SINK_TABLE_ID));
}
#[test]
fn test_validate_for_scan_missing_source_region() {
let source_region_id = RegionId::new(100, 2);
let existing_region_id = RegionId::new(100, 1);
let exts = HashMap::from([
(
FLOW_INCREMENTAL_MODE.to_string(),
@@ -289,17 +340,18 @@ mod flow_extension_tests {
),
(
FLOW_INCREMENTAL_AFTER_SEQS.to_string(),
r#"{"1":10}"#.to_string(),
format!(r#"{{"{}":10}}"#, existing_region_id.as_u64()),
),
]);
let parsed = FlowQueryExtensions::from_extensions(&exts).unwrap();
let err = parsed.validate_for_scan(&[1, 2], Some(100)).unwrap_err();
assert!(err.contains("Missing region 2"));
let err = parsed.validate_for_scan(source_region_id).unwrap_err();
assert!(format!("{err}").contains("Missing region"));
}
#[test]
fn test_validate_for_scan_sink_table_excluded() {
let source_region_id = RegionId::new(1024, 1);
let exts = HashMap::from([
(
FLOW_INCREMENTAL_MODE.to_string(),
@@ -307,13 +359,13 @@ mod flow_extension_tests {
),
(
FLOW_INCREMENTAL_AFTER_SEQS.to_string(),
r#"{"1":10}"#.to_string(),
format!(r#"{{"{}":10}}"#, source_region_id.as_u64()),
),
(FLOW_SINK_TABLE_ID.to_string(), "1024".to_string()),
]);
let parsed = FlowQueryExtensions::from_extensions(&exts).unwrap();
let apply_incremental = parsed.validate_for_scan(&[1, 2], Some(1024)).unwrap();
let apply_incremental = parsed.validate_for_scan(source_region_id).unwrap();
assert!(!apply_incremental);
}
}