diff --git a/src/client/src/database.rs b/src/client/src/database.rs index a9ab7fd888..0ca16caa04 100644 --- a/src/client/src/database.rs +++ b/src/client/src/database.rs @@ -61,6 +61,8 @@ type FlightDataStream = Pin + Send>>; type DoPutResponseStream = Pin>>>; +const FLOW_EXTENSIONS_METADATA_KEY: &str = "x-greptime-flow-extensions"; + #[derive(Debug, Clone, Default)] pub struct OutputMetrics { metrics: Arc>>, @@ -528,6 +530,22 @@ impl Database { Ok(()) } + fn put_flow_extensions( + metadata: &mut MetadataMap, + flow_extensions: &[(&str, &str)], + ) -> Result<()> { + if flow_extensions.is_empty() { + return Ok(()); + } + + let value = serde_json::to_string(&flow_extensions.to_vec()) + .expect("flow extension pairs should serialize"); + let key = AsciiMetadataKey::from_static(FLOW_EXTENSIONS_METADATA_KEY); + let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?; + metadata.insert(key, value); + Ok(()) + } + /// Make a request to the database. pub async fn handle(&self, request: Request) -> Result { let mut client = make_database_client(&self.client)?; @@ -623,7 +641,7 @@ impl Database { let request = Request::Query(QueryRequest { query: Some(Query::Sql(sql.as_ref().to_string())), }); - self.do_get(request, hints) + self.do_get(request, hints, &[]) .await .map(OutputWithMetrics::into_output) } @@ -662,7 +680,18 @@ impl Database { request: QueryRequest, hints: &[(&str, &str)], ) -> Result { - self.do_get(Request::Query(request), hints).await + self.query_with_terminal_metrics_and_flow_extensions(request, hints, &[]) + .await + } + + pub async fn query_with_terminal_metrics_and_flow_extensions( + &self, + request: QueryRequest, + hints: &[(&str, &str)], + flow_extensions: &[(&str, &str)], + ) -> Result { + self.do_get(Request::Query(request), hints, flow_extensions) + .await } /// Creates a new table using the provided table expression. @@ -670,7 +699,7 @@ impl Database { let request = Request::Ddl(DdlRequest { expr: Some(DdlExpr::CreateTable(expr)), }); - self.do_get(request, &[]) + self.do_get(request, &[], &[]) .await .map(OutputWithMetrics::into_output) } @@ -680,19 +709,26 @@ impl Database { let request = Request::Ddl(DdlRequest { expr: Some(DdlExpr::AlterTable(expr)), }); - self.do_get(request, &[]) + self.do_get(request, &[], &[]) .await .map(OutputWithMetrics::into_output) } - async fn do_get(&self, request: Request, hints: &[(&str, &str)]) -> Result { + async fn do_get( + &self, + request: Request, + hints: &[(&str, &str)], + flow_extensions: &[(&str, &str)], + ) -> Result { let request = self.to_rpc_request(request); let request = Ticket { ticket: request.encode_to_vec().into(), }; let mut request = tonic::Request::new(request); - Self::put_hints(request.metadata_mut(), hints)?; + let metadata = request.metadata_mut(); + Self::put_hints(metadata, hints)?; + Self::put_flow_extensions(metadata, flow_extensions)?; let mut client = self.client.make_flight_client(false, false)?; @@ -838,6 +874,36 @@ mod tests { .unwrap() } + #[test] + fn test_put_flow_extensions_preserves_comma_bearing_values() { + let mut metadata = MetadataMap::new(); + Database::put_flow_extensions( + &mut metadata, + &[ + ("flow.return_region_seq", "true"), + ("flow.incremental_after_seqs", r#"{"1":10,"2":20}"#), + ], + ) + .unwrap(); + + let value = metadata + .get(FLOW_EXTENSIONS_METADATA_KEY) + .unwrap() + .to_str() + .unwrap(); + let decoded: Vec<(String, String)> = serde_json::from_str(value).unwrap(); + assert_eq!( + decoded, + vec![ + ("flow.return_region_seq".to_string(), "true".to_string()), + ( + "flow.incremental_after_seqs".to_string(), + r#"{"1":10,"2":20}"#.to_string() + ), + ] + ); + } + #[test] fn test_flight_ctx() { let mut ctx = FlightContext::default(); diff --git a/src/flow/src/batching_mode/frontend_client.rs b/src/flow/src/batching_mode/frontend_client.rs index d8f8a044ac..6fd80d12de 100644 --- a/src/flow/src/batching_mode/frontend_client.rs +++ b/src/flow/src/batching_mode/frontend_client.rs @@ -356,18 +356,13 @@ impl FrontendClient { query, batch_opts, .. } => { let query_parallelism = query.parallelism.to_string(); - let mut hints = vec![ + let hints = vec![ (QUERY_PARALLELISM_HINT, query_parallelism.as_str()), (READ_PREFERENCE_HINT, batch_opts.read_preference.as_ref()), ]; - // PR2b only sends simple flow hint values such as - // `flow.return_region_seq=true`. The distributed client forwards - // hints through `x-greptime-hints`, whose existing comma-separated - // encoding is not suitable for comma-bearing values. - hints.extend_from_slice(extensions); let db = self.get_random_active_frontend(catalog, schema).await?; db.database - .query_with_terminal_metrics(request, &hints) + .query_with_terminal_metrics_and_flow_extensions(request, &hints, extensions) .await .map_err(BoxedError::new) .context(ExternalSnafu) diff --git a/src/servers/src/grpc/flight.rs b/src/servers/src/grpc/flight.rs index ddd0b694a8..c85019a99c 100644 --- a/src/servers/src/grpc/flight.rs +++ b/src/servers/src/grpc/flight.rs @@ -58,6 +58,8 @@ use crate::request_memory_limiter::ServerMemoryLimiter; use crate::request_memory_metrics::RequestMemoryMetrics; use crate::{error, hint_headers}; +const FLOW_EXTENSIONS_METADATA_KEY: &str = "x-greptime-flow-extensions"; + pub type TonicStream = Pin> + Send + 'static>>; /// A subset of [FlightService] @@ -191,7 +193,8 @@ impl FlightCraft for GreptimeRequestHandler { &self, request: Request, ) -> TonicResult>> { - let hints = hint_headers::extract_hints(request.metadata()); + let mut hints = hint_headers::extract_hints(request.metadata()); + hints.extend(extract_flow_extensions(request.metadata())?); let ticket = request.into_inner().ticket; let request = @@ -524,6 +527,26 @@ impl Stream for PutRecordBatchRequestStream { } } +fn extract_flow_extensions( + metadata: &tonic::metadata::MetadataMap, +) -> TonicResult> { + let Some(value) = metadata.get(FLOW_EXTENSIONS_METADATA_KEY) else { + return Ok(vec![]); + }; + + let value = value.to_str().map_err(|e| { + Status::invalid_argument(format!( + "Invalid {FLOW_EXTENSIONS_METADATA_KEY} metadata value: {e}" + )) + })?; + + serde_json::from_str::>(value).map_err(|e| { + Status::invalid_argument(format!( + "Invalid {FLOW_EXTENSIONS_METADATA_KEY} metadata JSON: {e}" + )) + }) +} + fn to_flight_data_stream( output: Output, tracing_context: TracingContext, @@ -568,3 +591,46 @@ fn to_flight_data_stream( } } } + +#[cfg(test)] +mod tests { + use tonic::metadata::{AsciiMetadataValue, MetadataMap}; + + use super::*; + + #[test] + fn test_extract_flow_extensions_preserves_comma_bearing_values() { + let mut metadata = MetadataMap::new(); + metadata.insert( + FLOW_EXTENSIONS_METADATA_KEY, + AsciiMetadataValue::try_from( + r#"[["flow.return_region_seq","true"],["flow.incremental_after_seqs","{\"1\":10,\"2\":20}"]]"#, + ) + .unwrap(), + ); + + let extensions = extract_flow_extensions(&metadata).unwrap(); + assert_eq!( + extensions, + vec![ + ("flow.return_region_seq".to_string(), "true".to_string()), + ( + "flow.incremental_after_seqs".to_string(), + r#"{"1":10,"2":20}"#.to_string() + ), + ] + ); + } + + #[test] + fn test_extract_flow_extensions_rejects_invalid_json() { + let mut metadata = MetadataMap::new(); + metadata.insert( + FLOW_EXTENSIONS_METADATA_KEY, + AsciiMetadataValue::try_from("not-json").unwrap(), + ); + + let err = extract_flow_extensions(&metadata).unwrap_err(); + assert_eq!(err.code(), tonic::Code::InvalidArgument); + } +}