1use std::collections::HashMap;
16
17use common_base::memory_limit::MemoryLimit;
18use serde::{Deserialize, Serialize};
19use store_api::storage::RegionId;
20use table::metadata::TableId;
21
22use crate::error::{Error, InvalidQueryContextExtensionSnafu, Result};
23
/// Query-context extension key whose value is a JSON object mapping region ids to sequences.
pub const FLOW_INCREMENTAL_AFTER_SEQS: &str = "flow.incremental_after_seqs";
/// Query-context extension key selecting the incremental mode
/// (see [`FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY`]).
pub const FLOW_INCREMENTAL_MODE: &str = "flow.incremental_mode";
/// Query-context extension key whose value is a `true`/`false` flag (case-insensitive).
pub const FLOW_RETURN_REGION_SEQ: &str = "flow.return_region_seq";
/// Query-context extension key whose value is the sink table id, written as a decimal integer.
pub const FLOW_SINK_TABLE_ID: &str = "flow.sink_table_id";

/// The only value accepted for [`FLOW_INCREMENTAL_MODE`]; it is matched case-insensitively.
pub const FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY: &str = "memtable_only";
30
/// Query engine options.
///
/// Because of `#[serde(default)]`, fields missing from the serialized form are filled in from
/// the `Default` implementation of this struct.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(default)]
pub struct QueryOptions {
    /// Query parallelism; the `Default` implementation sets it to `0`.
    /// NOTE(review): `0` presumably means "choose automatically" — confirm against the consumer.
    pub parallelism: usize,
    /// Whether query fallback is allowed; the `Default` implementation sets it to `false`.
    pub allow_query_fallback: bool,
    /// Memory pool size limit; the `Default` implementation uses `MemoryLimit::default()`.
    pub memory_pool_size: MemoryLimit,
}
44
45#[allow(clippy::derivable_impls)]
46impl Default for QueryOptions {
47 fn default() -> Self {
48 Self {
49 parallelism: 0,
50 allow_query_fallback: false,
51 memory_pool_size: MemoryLimit::default(),
52 }
53 }
54}
55
/// Incremental mode requested through the [`FLOW_INCREMENTAL_MODE`] extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlowIncrementalMode {
    /// Selected by the value [`FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY`].
    /// NOTE(review): presumably restricts the scan to memtable data — confirm against the scanner.
    MemtableOnly,
}
60
/// Flow-related settings decoded from the query-context extension map.
///
/// Built by [`FlowQueryExtensions::parse_flow_extensions`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct FlowQueryExtensions {
    /// Map parsed from [`FLOW_INCREMENTAL_AFTER_SEQS`]. Keys are region ids in their `u64` form
    /// (the form `validate_for_scan` looks up via `RegionId::as_u64`); values are sequences.
    pub incremental_after_seqs: Option<HashMap<u64, u64>>,
    /// Mode parsed from [`FLOW_INCREMENTAL_MODE`], if that extension was supplied.
    pub incremental_mode: Option<FlowIncrementalMode>,
    /// Flag parsed from [`FLOW_RETURN_REGION_SEQ`]; `false` when the extension is absent.
    pub return_region_seq: bool,
    /// Table id parsed from [`FLOW_SINK_TABLE_ID`], if that extension was supplied.
    pub sink_table_id: Option<TableId>,
}
72
73impl FlowQueryExtensions {
74 pub fn parse_flow_extensions(extensions: &HashMap<String, String>) -> Result<Option<Self>> {
80 let has_flow_context = extensions.contains_key(FLOW_INCREMENTAL_AFTER_SEQS)
81 || extensions.contains_key(FLOW_INCREMENTAL_MODE)
82 || extensions.contains_key(FLOW_RETURN_REGION_SEQ)
83 || extensions.contains_key(FLOW_SINK_TABLE_ID);
84
85 if !has_flow_context {
86 return Ok(None);
87 }
88
89 let incremental_mode = extensions
90 .get(FLOW_INCREMENTAL_MODE)
91 .map(|value| match value.as_str() {
92 v if v.eq_ignore_ascii_case(FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY) => {
93 Ok(FlowIncrementalMode::MemtableOnly)
94 }
95 _ => Err(invalid_query_context_extension(format!(
96 "Invalid value for {}: {}",
97 FLOW_INCREMENTAL_MODE, value
98 ))),
99 })
100 .transpose()?;
101
102 let incremental_after_seqs = extensions
103 .get(FLOW_INCREMENTAL_AFTER_SEQS)
104 .map(|value| parse_incremental_after_seqs(value.as_str()))
105 .transpose()?;
106
107 let return_region_seq = extensions
108 .get(FLOW_RETURN_REGION_SEQ)
109 .map(|value| parse_bool(value.as_str()))
110 .transpose()?
111 .unwrap_or(false);
112
113 let sink_table_id = extensions
114 .get(FLOW_SINK_TABLE_ID)
115 .map(|value| {
116 value.parse::<TableId>().map_err(|_| {
117 invalid_query_context_extension(format!(
118 "Invalid value for {}: {}",
119 FLOW_SINK_TABLE_ID, value
120 ))
121 })
122 })
123 .transpose()?;
124
125 if matches!(incremental_mode, Some(FlowIncrementalMode::MemtableOnly)) {
126 let after_seqs = incremental_after_seqs.as_ref().ok_or_else(|| {
127 invalid_query_context_extension(format!(
128 "{} is required when {}={}.",
129 FLOW_INCREMENTAL_AFTER_SEQS,
130 FLOW_INCREMENTAL_MODE,
131 FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY
132 ))
133 })?;
134 if after_seqs.is_empty() {
135 return Err(invalid_query_context_extension(format!(
136 "{} must not be empty when {}={}.",
137 FLOW_INCREMENTAL_AFTER_SEQS,
138 FLOW_INCREMENTAL_MODE,
139 FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY
140 )));
141 }
142 }
143
144 Ok(Some(Self {
145 incremental_after_seqs,
146 incremental_mode,
147 return_region_seq,
148 sink_table_id,
149 }))
150 }
151
152 pub fn validate_for_scan(&self, source_region_id: RegionId) -> Result<bool> {
153 if self.sink_table_id.is_some() && self.sink_table_id == Some(source_region_id.table_id()) {
154 return Ok(false);
155 }
156
157 if matches!(
158 self.incremental_mode,
159 Some(FlowIncrementalMode::MemtableOnly)
160 ) {
161 let after_seqs = self.incremental_after_seqs.as_ref().ok_or_else(|| {
162 invalid_query_context_extension(format!(
163 "{} is required when {}=memtable_only.",
164 FLOW_INCREMENTAL_AFTER_SEQS, FLOW_INCREMENTAL_MODE
165 ))
166 })?;
167
168 if !after_seqs.contains_key(&source_region_id.as_u64()) {
169 return Err(invalid_query_context_extension(format!(
170 "Missing region {} in {} when {}=memtable_only.",
171 source_region_id, FLOW_INCREMENTAL_AFTER_SEQS, FLOW_INCREMENTAL_MODE
172 )));
173 }
174 }
175
176 Ok(self.incremental_after_seqs.is_some())
177 }
178
179 pub fn should_collect_region_watermark(&self) -> bool {
180 self.return_region_seq || self.incremental_after_seqs.is_some()
181 }
182}
183
184fn parse_incremental_after_seqs(value: &str) -> Result<HashMap<u64, u64>> {
185 let raw = serde_json::from_str::<HashMap<String, serde_json::Value>>(value).map_err(|e| {
186 invalid_query_context_extension(format!(
187 "Invalid JSON for {}: {} ({})",
188 FLOW_INCREMENTAL_AFTER_SEQS, value, e
189 ))
190 })?;
191
192 raw.into_iter()
193 .map(|(region_id, raw_seq)| {
194 let region_id = region_id.parse::<u64>().map_err(|_| {
195 invalid_query_context_extension(format!(
196 "Invalid region id in {}: {}",
197 FLOW_INCREMENTAL_AFTER_SEQS, region_id
198 ))
199 })?;
200
201 let seq = match raw_seq {
202 serde_json::Value::Number(num) => num.as_u64().ok_or_else(|| {
203 invalid_query_context_extension(format!(
204 "Invalid sequence value in {} for region {}: {}",
205 FLOW_INCREMENTAL_AFTER_SEQS, region_id, num
206 ))
207 })?,
208 serde_json::Value::String(s) => s.parse::<u64>().map_err(|_| {
209 invalid_query_context_extension(format!(
210 "Invalid sequence string in {} for region {}: {}",
211 FLOW_INCREMENTAL_AFTER_SEQS, region_id, s
212 ))
213 })?,
214 _ => {
215 return Err(invalid_query_context_extension(format!(
216 "Invalid sequence value type in {} for region {}",
217 FLOW_INCREMENTAL_AFTER_SEQS, region_id
218 )));
219 }
220 };
221
222 Ok((region_id, seq))
223 })
224 .collect()
225}
226
227fn parse_bool(value: &str) -> Result<bool> {
228 match value {
229 v if v.eq_ignore_ascii_case("true") => Ok(true),
230 v if v.eq_ignore_ascii_case("false") => Ok(false),
231 _ => Err(invalid_query_context_extension(format!(
232 "Invalid value for {}: {}",
233 FLOW_RETURN_REGION_SEQ, value
234 ))),
235 }
236}
237
238fn invalid_query_context_extension(reason: String) -> Error {
239 InvalidQueryContextExtensionSnafu { reason }.build()
240}
241
#[cfg(test)]
mod flow_extension_tests {
    use super::*;

    /// Builds an owned extension map from borrowed key/value pairs.
    fn extensions_of(entries: &[(&str, &str)]) -> HashMap<String, String> {
        entries
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    /// Parses `entries` and unwraps down to the extensions, panicking on an error or `None`.
    fn parse_expecting_some(entries: &[(&str, &str)]) -> FlowQueryExtensions {
        FlowQueryExtensions::parse_flow_extensions(&extensions_of(entries))
            .unwrap()
            .unwrap()
    }

    /// Parses `entries`, expects a failure, and returns the rendered error message.
    fn parse_error_text(entries: &[(&str, &str)]) -> String {
        FlowQueryExtensions::parse_flow_extensions(&extensions_of(entries))
            .unwrap_err()
            .to_string()
    }

    #[test]
    fn test_parse_flow_extensions_returns_none_for_non_flow_query() {
        let parsed = FlowQueryExtensions::parse_flow_extensions(&extensions_of(&[])).unwrap();

        assert!(parsed.is_none());
    }

    #[test]
    fn test_parse_flow_extensions_memtable_only_success() {
        let parsed = parse_expecting_some(&[
            (FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY),
            (FLOW_INCREMENTAL_AFTER_SEQS, r#"{"1":10,"2":20}"#),
            (FLOW_RETURN_REGION_SEQ, "true"),
            (FLOW_SINK_TABLE_ID, "1024"),
        ]);
        let expected_seqs: HashMap<u64, u64> = HashMap::from([(1, 10), (2, 20)]);

        assert_eq!(
            parsed.incremental_mode,
            Some(FlowIncrementalMode::MemtableOnly)
        );
        assert_eq!(parsed.incremental_after_seqs, Some(expected_seqs));
        assert!(parsed.return_region_seq);
        assert_eq!(parsed.sink_table_id, Some(1024));
    }

    #[test]
    fn test_parse_flow_extensions_mode_requires_after_seqs() {
        let message =
            parse_error_text(&[(FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY)]);

        assert!(message.contains(FLOW_INCREMENTAL_AFTER_SEQS));
    }

    #[test]
    fn test_parse_flow_extensions_invalid_mode() {
        let message = parse_error_text(&[(FLOW_INCREMENTAL_MODE, "foo")]);

        assert!(message.contains(FLOW_INCREMENTAL_MODE));
    }

    #[test]
    fn test_parse_flow_extensions_invalid_after_seqs_json() {
        let message = parse_error_text(&[
            (FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY),
            (FLOW_INCREMENTAL_AFTER_SEQS, "not-json"),
        ]);

        assert!(message.contains(FLOW_INCREMENTAL_AFTER_SEQS));
    }

    #[test]
    fn test_parse_flow_extensions_after_seqs_string_values() {
        let parsed = parse_expecting_some(&[(FLOW_INCREMENTAL_AFTER_SEQS, r#"{"1":"10","2":"20"}"#)]);
        let expected_seqs: HashMap<u64, u64> = HashMap::from([(1, 10), (2, 20)]);

        assert_eq!(parsed.incremental_after_seqs, Some(expected_seqs));
    }

    #[test]
    fn test_parse_flow_extensions_after_seqs_invalid_value_type() {
        let message = parse_error_text(&[(FLOW_INCREMENTAL_AFTER_SEQS, r#"{"1":true}"#)]);

        assert!(message.contains(FLOW_INCREMENTAL_AFTER_SEQS));
    }

    #[test]
    fn test_parse_flow_extensions_invalid_sink_table_id() {
        let message = parse_error_text(&[(FLOW_SINK_TABLE_ID, "x")]);

        assert!(message.contains(FLOW_SINK_TABLE_ID));
    }

    #[test]
    fn test_validate_for_scan_missing_source_region() {
        let source_region_id = RegionId::new(100, 2);
        let existing_region_id = RegionId::new(100, 1);
        // Only the *other* region of the table is listed in the after-sequences map.
        let after_seqs = format!(r#"{{"{}":10}}"#, existing_region_id.as_u64());

        let parsed = parse_expecting_some(&[
            (FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY),
            (FLOW_INCREMENTAL_AFTER_SEQS, after_seqs.as_str()),
        ]);

        let message = parsed
            .validate_for_scan(source_region_id)
            .unwrap_err()
            .to_string();
        assert!(message.contains("Missing region"));
    }

    #[test]
    fn test_validate_for_scan_sink_table_excluded() {
        let source_region_id = RegionId::new(1024, 1);
        let after_seqs = format!(r#"{{"{}":10}}"#, source_region_id.as_u64());

        let parsed = parse_expecting_some(&[
            (FLOW_INCREMENTAL_MODE, FLOW_INCREMENTAL_MODE_MEMTABLE_ONLY),
            (FLOW_INCREMENTAL_AFTER_SEQS, after_seqs.as_str()),
            (FLOW_SINK_TABLE_ID, "1024"),
        ]);

        let apply_incremental = parsed.validate_for_scan(source_region_id).unwrap();
        assert!(!apply_incremental);
    }

    #[test]
    fn test_should_collect_region_watermark_defaults_false() {
        let defaults = FlowQueryExtensions::default();

        assert!(!defaults.should_collect_region_watermark());
    }

    #[test]
    fn test_should_collect_region_watermark_true_for_return_region_seq() {
        let with_flag = FlowQueryExtensions {
            return_region_seq: true,
            ..Default::default()
        };

        assert!(with_flag.should_collect_region_watermark());
    }

    #[test]
    fn test_should_collect_region_watermark_true_for_incremental_query() {
        let with_seqs = FlowQueryExtensions {
            incremental_after_seqs: Some(HashMap::from([(1, 10)])),
            ..Default::default()
        };

        assert!(with_seqs.should_collect_region_watermark());
    }

    #[test]
    fn test_parse_flow_extensions_return_region_seq_only_returns_some() {
        let parsed = parse_expecting_some(&[(FLOW_RETURN_REGION_SEQ, "true")]);

        assert!(parsed.return_region_seq);
    }

    #[test]
    fn test_parse_flow_extensions_sink_table_only_returns_some() {
        let parsed = parse_expecting_some(&[(FLOW_SINK_TABLE_ID, "1024")]);

        assert_eq!(parsed.sink_table_id, Some(1024));
    }

    #[test]
    fn test_parse_flow_extensions_incremental_after_seqs_only_returns_some() {
        let parsed = parse_expecting_some(&[(FLOW_INCREMENTAL_AFTER_SEQS, r#"{"1":10}"#)]);
        let expected_seqs: HashMap<u64, u64> = HashMap::from([(1, 10)]);

        assert_eq!(parsed.incremental_after_seqs, Some(expected_seqs));
    }
}