1use std::collections::HashMap;
16
17use common_base::memory_limit::MemoryLimit;
18use serde::{Deserialize, Serialize};
19use store_api::storage::RegionId;
20use table::metadata::TableId;
21
22use crate::error::{Error, InvalidQueryContextExtensionSnafu, Result};
23
/// Extension key whose value is a JSON object mapping region ids to sequence numbers
/// (see `parse_incremental_after_seqs`).
pub const FLOW_INCREMENTAL_AFTER_SEQS: &str = "flow.incremental_after_seqs";
/// Extension key selecting the incremental mode; the only accepted value is
/// [`FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY`].
pub const FLOW_INCREMENTAL_MODE: &str = "flow.incremental_mode";
/// Extension key holding a `true`/`false` flag, matched ASCII case-insensitively.
pub const FLOW_RETURN_REGION_SEQ: &str = "flow.return_region_seq";
/// Extension key whose value is parsed as a [`TableId`].
pub const FLOW_SINK_TABLE_ID: &str = "flow.sink_table_id";

/// Value of [`FLOW_INCREMENTAL_MODE`] that maps to `FlowIncrementalMode::MemtableOnly`;
/// compared ASCII case-insensitively.
pub const FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY: &str = "memtable_only";
30
/// Query options.
///
/// Because of `#[serde(default)]`, any field missing from the serialized form is filled in
/// from this type's `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(default)]
pub struct QueryOptions {
    /// Query parallelism; defaults to `0`.
    /// NOTE(review): what `0` means is not visible in this file — presumably "choose
    /// automatically"; confirm against the consumer of this option.
    pub parallelism: usize,
    /// Defaults to `false`.
    /// NOTE(review): presumably permits falling back to another execution path when a query
    /// cannot be handled — confirm against the consumer of this option.
    pub allow_query_fallback: bool,
    /// Defaults to `MemoryLimit::default()`.
    /// NOTE(review): presumably caps the memory pool used by query execution — confirm.
    pub memory_pool_size: MemoryLimit,
}
44
45#[allow(clippy::derivable_impls)]
46impl Default for QueryOptions {
47 fn default() -> Self {
48 Self {
49 parallelism: 0,
50 allow_query_fallback: false,
51 memory_pool_size: MemoryLimit::default(),
52 }
53 }
54}
55
/// Incremental mode requested through the `flow.incremental_mode` query-context extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlowIncrementalMode {
    /// Selected by the extension value `memtable_only` (ASCII case-insensitive).
    /// NOTE(review): presumably restricts the scan to memtable data — confirm against the
    /// scan implementation.
    MemtableOnly,
}
60
/// Typed view of the `flow.*` entries found in a query context's extension map.
///
/// Built by `FlowQueryExtensions::parse_flow_extensions`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct FlowQueryExtensions {
    /// Parsed value of `flow.incremental_after_seqs`: region id (as `u64`) to sequence number.
    /// `None` when the key is absent.
    /// NOTE(review): presumably only rows after the given sequence are wanted — confirm.
    pub incremental_after_seqs: Option<HashMap<u64, u64>>,
    /// Parsed value of `flow.incremental_mode`; `None` when the key is absent.
    pub incremental_mode: Option<FlowIncrementalMode>,
    /// Parsed value of `flow.return_region_seq`; `false` when the key is absent.
    pub return_region_seq: bool,
    /// Parsed value of `flow.sink_table_id`; `None` when the key is absent.
    /// `validate_for_scan` returns `Ok(false)` for regions that belong to this table.
    pub sink_table_id: Option<TableId>,
}
72
73impl FlowQueryExtensions {
74 pub fn parse_flow_extensions(extensions: &HashMap<String, String>) -> Result<Option<Self>> {
80 let has_flow_context = extensions.contains_key(FLOW_INCREMENTAL_AFTER_SEQS)
81 || extensions.contains_key(FLOW_INCREMENTAL_MODE)
82 || extensions.contains_key(FLOW_RETURN_REGION_SEQ)
83 || extensions.contains_key(FLOW_SINK_TABLE_ID);
84
85 if !has_flow_context {
86 return Ok(None);
87 }
88
89 let incremental_mode = extensions
90 .get(FLOW_INCREMENTAL_MODE)
91 .map(|value| match value.as_str() {
92 v if v.eq_ignore_ascii_case(FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY) => {
93 Ok(FlowIncrementalMode::MemtableOnly)
94 }
95 _ => Err(invalid_query_context_extension(format!(
96 "Invalid value for {}: {}",
97 FLOW_INCREMENTAL_MODE, value
98 ))),
99 })
100 .transpose()?;
101
102 let incremental_after_seqs = extensions
103 .get(FLOW_INCREMENTAL_AFTER_SEQS)
104 .map(|value| parse_incremental_after_seqs(value.as_str()))
105 .transpose()?;
106
107 let return_region_seq = extensions
108 .get(FLOW_RETURN_REGION_SEQ)
109 .map(|value| parse_bool(value.as_str()))
110 .transpose()?
111 .unwrap_or(false);
112
113 let sink_table_id = extensions
114 .get(FLOW_SINK_TABLE_ID)
115 .map(|value| {
116 value.parse::<TableId>().map_err(|_| {
117 invalid_query_context_extension(format!(
118 "Invalid value for {}: {}",
119 FLOW_SINK_TABLE_ID, value
120 ))
121 })
122 })
123 .transpose()?;
124
125 if matches!(incremental_mode, Some(FlowIncrementalMode::MemtableOnly)) {
126 let after_seqs = incremental_after_seqs.as_ref().ok_or_else(|| {
127 invalid_query_context_extension(format!(
128 "{} is required when {}={}.",
129 FLOW_INCREMENTAL_AFTER_SEQS,
130 FLOW_INCREMENTAL_MODE,
131 FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY
132 ))
133 })?;
134 if after_seqs.is_empty() {
135 return Err(invalid_query_context_extension(format!(
136 "{} must not be empty when {}={}.",
137 FLOW_INCREMENTAL_AFTER_SEQS,
138 FLOW_INCREMENTAL_MODE,
139 FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY
140 )));
141 }
142 }
143
144 Ok(Some(Self {
145 incremental_after_seqs,
146 incremental_mode,
147 return_region_seq,
148 sink_table_id,
149 }))
150 }
151
152 pub fn validate_for_scan(&self, source_region_id: RegionId) -> Result<bool> {
153 if self.sink_table_id.is_some() && self.sink_table_id == Some(source_region_id.table_id()) {
154 return Ok(false);
155 }
156
157 if matches!(
158 self.incremental_mode,
159 Some(FlowIncrementalMode::MemtableOnly)
160 ) {
161 let after_seqs = self.incremental_after_seqs.as_ref().ok_or_else(|| {
162 invalid_query_context_extension(format!(
163 "{} is required when {}=memtable_only.",
164 FLOW_INCREMENTAL_AFTER_SEQS, FLOW_INCREMENTAL_MODE
165 ))
166 })?;
167
168 if !after_seqs.contains_key(&source_region_id.as_u64()) {
169 return Err(invalid_query_context_extension(format!(
170 "Missing region {} in {} when {}=memtable_only.",
171 source_region_id, FLOW_INCREMENTAL_AFTER_SEQS, FLOW_INCREMENTAL_MODE
172 )));
173 }
174 }
175
176 Ok(self.incremental_after_seqs.is_some())
177 }
178
179 pub fn should_collect_region_watermark(&self) -> bool {
180 should_collect_region_watermark(
181 self.return_region_seq,
182 self.incremental_after_seqs.is_some(),
183 )
184 }
185}
186
187pub fn should_collect_region_watermark_from_extensions(
192 extensions: &HashMap<String, String>,
193) -> bool {
194 let return_region_seq = extensions
195 .get(FLOW_RETURN_REGION_SEQ)
196 .is_some_and(|value| value.eq_ignore_ascii_case("true"));
197 let has_incremental_after_seqs = extensions.contains_key(FLOW_INCREMENTAL_AFTER_SEQS);
198
199 should_collect_region_watermark(return_region_seq, has_incremental_after_seqs)
200}
201
/// Shared rule behind both watermark checks: collect unless neither the return-sequence flag
/// nor the incremental sequences are present.
fn should_collect_region_watermark(
    return_region_seq: bool,
    has_incremental_after_seqs: bool,
) -> bool {
    !matches!(
        (return_region_seq, has_incremental_after_seqs),
        (false, false)
    )
}
208
209fn parse_incremental_after_seqs(value: &str) -> Result<HashMap<u64, u64>> {
210 let raw = serde_json::from_str::<HashMap<String, serde_json::Value>>(value).map_err(|e| {
211 invalid_query_context_extension(format!(
212 "Invalid JSON for {}: {} ({})",
213 FLOW_INCREMENTAL_AFTER_SEQS, value, e
214 ))
215 })?;
216
217 raw.into_iter()
218 .map(|(region_id, raw_seq)| {
219 let region_id = region_id.parse::<u64>().map_err(|_| {
220 invalid_query_context_extension(format!(
221 "Invalid region id in {}: {}",
222 FLOW_INCREMENTAL_AFTER_SEQS, region_id
223 ))
224 })?;
225
226 let seq = match raw_seq {
227 serde_json::Value::Number(num) => num.as_u64().ok_or_else(|| {
228 invalid_query_context_extension(format!(
229 "Invalid sequence value in {} for region {}: {}",
230 FLOW_INCREMENTAL_AFTER_SEQS, region_id, num
231 ))
232 })?,
233 serde_json::Value::String(s) => s.parse::<u64>().map_err(|_| {
234 invalid_query_context_extension(format!(
235 "Invalid sequence string in {} for region {}: {}",
236 FLOW_INCREMENTAL_AFTER_SEQS, region_id, s
237 ))
238 })?,
239 _ => {
240 return Err(invalid_query_context_extension(format!(
241 "Invalid sequence value type in {} for region {}",
242 FLOW_INCREMENTAL_AFTER_SEQS, region_id
243 )));
244 }
245 };
246
247 Ok((region_id, seq))
248 })
249 .collect()
250}
251
252fn parse_bool(value: &str) -> Result<bool> {
253 match value {
254 v if v.eq_ignore_ascii_case("true") => Ok(true),
255 v if v.eq_ignore_ascii_case("false") => Ok(false),
256 _ => Err(invalid_query_context_extension(format!(
257 "Invalid value for {}: {}",
258 FLOW_RETURN_REGION_SEQ, value
259 ))),
260 }
261}
262
263fn invalid_query_context_extension(reason: String) -> Error {
264 InvalidQueryContextExtensionSnafu { reason }.build()
265}
266
#[cfg(test)]
mod flow_extension_tests {
    use super::*;

    /// Builds an owned extension map from borrowed key/value pairs.
    fn extensions_of(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    /// Parses `pairs`, panicking unless the result is `Ok(Some(_))`.
    fn parse_some(pairs: &[(&str, &str)]) -> FlowQueryExtensions {
        FlowQueryExtensions::parse_flow_extensions(&extensions_of(pairs))
            .unwrap()
            .unwrap()
    }

    /// Parses `pairs` and returns the rendered error, panicking if parsing succeeds.
    fn parse_error_message(pairs: &[(&str, &str)]) -> String {
        FlowQueryExtensions::parse_flow_extensions(&extensions_of(pairs))
            .unwrap_err()
            .to_string()
    }

    #[test]
    fn test_parse_flow_extensions_returns_none_for_non_flow_query() {
        let parsed = FlowQueryExtensions::parse_flow_extensions(&HashMap::new()).unwrap();

        assert!(parsed.is_none());
    }

    #[test]
    fn test_parse_flow_extensions_memtable_only_success() {
        let parsed = parse_some(&[
            (FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY),
            (FLOW_INCREMENTAL_AFTER_SEQS, r#"{"1":10,"2":20}"#),
            (FLOW_RETURN_REGION_SEQ, "true"),
            (FLOW_SINK_TABLE_ID, "1024"),
        ]);

        let expected = FlowQueryExtensions {
            incremental_after_seqs: Some(HashMap::from([(1, 10), (2, 20)])),
            incremental_mode: Some(FlowIncrementalMode::MemtableOnly),
            return_region_seq: true,
            sink_table_id: Some(1024),
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_parse_flow_extensions_mode_requires_after_seqs() {
        let message =
            parse_error_message(&[(FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY)]);

        assert!(message.contains(FLOW_INCREMENTAL_AFTER_SEQS));
    }

    #[test]
    fn test_parse_flow_extensions_invalid_mode() {
        let message = parse_error_message(&[(FLOW_INCREMENTAL_MODE, "foo")]);

        assert!(message.contains(FLOW_INCREMENTAL_MODE));
    }

    #[test]
    fn test_parse_flow_extensions_invalid_after_seqs_json() {
        let message = parse_error_message(&[
            (FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY),
            (FLOW_INCREMENTAL_AFTER_SEQS, "not-json"),
        ]);

        assert!(message.contains(FLOW_INCREMENTAL_AFTER_SEQS));
    }

    #[test]
    fn test_parse_flow_extensions_after_seqs_string_values() {
        let parsed = parse_some(&[(FLOW_INCREMENTAL_AFTER_SEQS, r#"{"1":"10","2":"20"}"#)]);

        assert_eq!(
            parsed.incremental_after_seqs.unwrap(),
            HashMap::from([(1, 10), (2, 20)])
        );
    }

    #[test]
    fn test_parse_flow_extensions_after_seqs_invalid_value_type() {
        let message = parse_error_message(&[(FLOW_INCREMENTAL_AFTER_SEQS, r#"{"1":true}"#)]);

        assert!(message.contains(FLOW_INCREMENTAL_AFTER_SEQS));
    }

    #[test]
    fn test_parse_flow_extensions_invalid_sink_table_id() {
        let message = parse_error_message(&[(FLOW_SINK_TABLE_ID, "x")]);

        assert!(message.contains(FLOW_SINK_TABLE_ID));
    }

    #[test]
    fn test_validate_for_scan_missing_source_region() {
        let source_region_id = RegionId::new(100, 2);
        let listed_region_id = RegionId::new(100, 1);
        let after_seqs = format!(r#"{{"{}":10}}"#, listed_region_id.as_u64());

        let parsed = parse_some(&[
            (FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY),
            (FLOW_INCREMENTAL_AFTER_SEQS, after_seqs.as_str()),
        ]);

        let message = parsed
            .validate_for_scan(source_region_id)
            .unwrap_err()
            .to_string();
        assert!(message.contains("Missing region"));
    }

    #[test]
    fn test_validate_for_scan_sink_table_excluded() {
        let source_region_id = RegionId::new(1024, 1);
        let after_seqs = format!(r#"{{"{}":10}}"#, source_region_id.as_u64());

        let parsed = parse_some(&[
            (FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY),
            (FLOW_INCREMENTAL_AFTER_SEQS, after_seqs.as_str()),
            (FLOW_SINK_TABLE_ID, "1024"),
        ]);

        let apply_incremental = parsed.validate_for_scan(source_region_id).unwrap();
        assert!(!apply_incremental);
    }

    #[test]
    fn test_should_collect_region_watermark_defaults_false() {
        assert!(!FlowQueryExtensions::default().should_collect_region_watermark());
    }

    #[test]
    fn test_should_collect_region_watermark_true_for_return_region_seq() {
        let parsed = FlowQueryExtensions {
            return_region_seq: true,
            ..Default::default()
        };

        assert!(parsed.should_collect_region_watermark());
    }

    #[test]
    fn test_should_collect_region_watermark_true_for_incremental_query() {
        let parsed = FlowQueryExtensions {
            incremental_after_seqs: Some(HashMap::from([(1, 10)])),
            ..Default::default()
        };

        assert!(parsed.should_collect_region_watermark());
    }

    #[test]
    fn test_should_collect_region_watermark_from_extensions() {
        let return_seq_on = extensions_of(&[(FLOW_RETURN_REGION_SEQ, "true")]);
        assert!(should_collect_region_watermark_from_extensions(
            &return_seq_on
        ));

        let after_seqs_only = extensions_of(&[(FLOW_INCREMENTAL_AFTER_SEQS, r#"{"1":10}"#)]);
        assert!(should_collect_region_watermark_from_extensions(
            &after_seqs_only
        ));

        let return_seq_off = extensions_of(&[(FLOW_RETURN_REGION_SEQ, "false")]);
        assert!(!should_collect_region_watermark_from_extensions(
            &return_seq_off
        ));

        let empty = HashMap::new();
        assert!(!should_collect_region_watermark_from_extensions(&empty));
    }

    #[test]
    fn test_parse_flow_extensions_return_region_seq_only_returns_some() {
        let parsed = parse_some(&[(FLOW_RETURN_REGION_SEQ, "true")]);

        assert!(parsed.return_region_seq);
    }

    #[test]
    fn test_parse_flow_extensions_sink_table_only_returns_some() {
        let parsed = parse_some(&[(FLOW_SINK_TABLE_ID, "1024")]);

        assert_eq!(parsed.sink_table_id, Some(1024));
    }

    #[test]
    fn test_parse_flow_extensions_incremental_after_seqs_only_returns_some() {
        let parsed = parse_some(&[(FLOW_INCREMENTAL_AFTER_SEQS, r#"{"1":10}"#)]);

        assert_eq!(
            parsed.incremental_after_seqs,
            Some(HashMap::from([(1, 10)]))
        );
    }
}