1use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::atomic::{AtomicBool, Ordering};
18use std::sync::{Arc, RwLock};
19use std::task::{Context, Poll};
20
21use api::v1::auth_header::AuthScheme;
22use api::v1::ddl_request::Expr as DdlExpr;
23use api::v1::greptime_database_client::GreptimeDatabaseClient;
24use api::v1::greptime_request::Request;
25use api::v1::query_request::Query;
26use api::v1::{
27 AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
28 InsertRequests, QueryRequest, RequestHeader, RowInsertRequests,
29};
30use arc_swap::ArcSwapOption;
31use arrow_flight::{FlightData, Ticket};
32use async_stream::stream;
33use base64::Engine;
34use base64::prelude::BASE64_STANDARD;
35use common_catalog::build_db_string;
36use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
37use common_error::ext::BoxedError;
38use common_grpc::flight::do_put::DoPutResponse;
39use common_grpc::flight::{FLOW_EXTENSIONS_METADATA_KEY, FlightDecoder, FlightMessage};
40use common_query::Output;
41use common_recordbatch::adapter::RecordBatchMetrics;
42use common_recordbatch::error::ExternalSnafu;
43use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, RecordBatchStreamWrapper};
44use common_telemetry::tracing::Span;
45use common_telemetry::tracing_context::W3cTrace;
46use common_telemetry::{error, warn};
47use futures::future;
48use futures_util::{Stream, StreamExt, TryStreamExt};
49use prost::Message;
50use snafu::{OptionExt, ResultExt};
51use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
52use tonic::transport::Channel;
53
54use crate::error::{
55 ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
56 InvalidTonicMetadataValueSnafu,
57};
58use crate::{Client, Result, error, from_grpc_response};
59
/// Boxed stream of raw Arrow Flight frames fed into `Database::do_put`.
type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;

/// Boxed stream of per-frame acknowledgements returned by `Database::do_put`.
type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
63
/// Cloneable, thread-safe holder for the terminal (end-of-stream) metrics of
/// a query. Clones share the same inner state, so a caller can keep a handle
/// while the result stream is being drained elsewhere.
#[derive(Debug, Clone, Default)]
pub struct OutputMetrics {
    inner: Arc<OutputMetricsInner>,
}
72
/// Shared state behind [`OutputMetrics`].
#[derive(Debug, Default)]
struct OutputMetricsInner {
    // Latest metrics snapshot, if any has been reported yet.
    metrics: RwLock<Option<RecordBatchMetrics>>,
    // Set once the metrics are final (the producing stream has terminated).
    ready: AtomicBool,
}
78
79impl OutputMetrics {
80 fn new() -> Self {
81 Self::default()
82 }
83
84 pub fn update(&self, metrics: Option<RecordBatchMetrics>) {
86 *self.inner.metrics.write().unwrap() = metrics;
87 }
88
89 pub fn mark_ready(&self) {
91 let _ = self
92 .inner
93 .ready
94 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire);
95 }
96
97 pub fn is_ready(&self) -> bool {
101 self.inner.ready.load(Ordering::Acquire)
102 }
103
104 pub fn get(&self) -> Option<RecordBatchMetrics> {
106 self.inner.metrics.read().unwrap().clone()
107 }
108
109 pub fn region_watermark_map(&self) -> Option<std::collections::HashMap<u64, u64>> {
115 Some(
116 self.get()?
117 .region_watermarks
118 .into_iter()
119 .filter_map(|entry| entry.watermark.map(|seq| (entry.region_id, seq)))
120 .collect::<std::collections::HashMap<_, _>>(),
121 )
122 }
123
124 pub fn participating_regions(&self) -> Option<std::collections::BTreeSet<u64>> {
127 Some(
128 self.get()?
129 .region_watermarks
130 .into_iter()
131 .map(|entry| entry.region_id)
132 .collect::<std::collections::BTreeSet<_>>(),
133 )
134 }
135}
136
/// A query [`Output`] bundled with a handle to its terminal metrics.
/// For streaming outputs the metrics only become ready after the stream
/// has been fully drained (see `attach_terminal_metrics`).
#[derive(Debug)]
pub struct OutputWithMetrics {
    pub output: Output,
    pub metrics: OutputMetrics,
}
147
148impl OutputWithMetrics {
149 pub fn from_output(output: Output) -> Self {
154 let terminal_metrics = OutputMetrics::new();
155 let output = attach_terminal_metrics(output, &terminal_metrics);
156 Self {
157 output,
158 metrics: terminal_metrics,
159 }
160 }
161
162 pub fn region_watermark_map(&self) -> Option<std::collections::HashMap<u64, u64>> {
164 self.metrics.region_watermark_map()
165 }
166
167 pub fn participating_regions(&self) -> Option<std::collections::BTreeSet<u64>> {
169 self.metrics.participating_regions()
170 }
171
172 pub fn into_output(self) -> Output {
174 self.output
175 }
176}
177
178fn parse_terminal_metrics(metrics_json: &str) -> Result<RecordBatchMetrics> {
179 serde_json::from_str(metrics_json).map_err(|e| {
180 IllegalFlightMessagesSnafu {
181 reason: format!("Invalid terminal metrics message: {e}"),
182 }
183 .build()
184 })
185}
186
/// Adapter that forwards a record-batch stream while mirroring the inner
/// stream's metrics into a shared [`OutputMetrics`] handle.
struct StreamWithMetrics {
    stream: common_recordbatch::SendableRecordBatchStream,
    metrics: OutputMetrics,
}
191
192impl StreamWithMetrics {
193 fn new(stream: common_recordbatch::SendableRecordBatchStream, metrics: OutputMetrics) -> Self {
194 Self { stream, metrics }
195 }
196
197 fn sync_terminal_metrics(&self) {
198 self.metrics.update(self.stream.metrics());
199 }
200}
201
202impl RecordBatchStream for StreamWithMetrics {
203 fn name(&self) -> &str {
204 self.stream.name()
205 }
206
207 fn schema(&self) -> datatypes::schema::SchemaRef {
208 self.stream.schema()
209 }
210
211 fn output_ordering(&self) -> Option<&[OrderOption]> {
212 self.stream.output_ordering()
213 }
214
215 fn metrics(&self) -> Option<RecordBatchMetrics> {
216 self.sync_terminal_metrics();
217 self.metrics.get()
218 }
219}
220
221impl Stream for StreamWithMetrics {
222 type Item = common_recordbatch::error::Result<RecordBatch>;
223
224 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
225 let polled = Pin::new(&mut self.stream).poll_next(cx);
226 if let Poll::Ready(None) = &polled {
227 self.sync_terminal_metrics();
228 self.metrics.mark_ready();
229 }
230 polled
231 }
232
233 fn size_hint(&self) -> (usize, Option<usize>) {
234 self.stream.size_hint()
235 }
236}
237
238fn attach_terminal_metrics(output: Output, terminal_metrics: &OutputMetrics) -> Output {
239 let Output { data, meta } = output;
240 let data = match data {
241 common_query::OutputData::Stream(stream) => {
242 terminal_metrics.update(stream.metrics());
243 common_query::OutputData::Stream(Box::pin(StreamWithMetrics::new(
244 stream,
245 terminal_metrics.clone(),
246 )))
247 }
248 other => {
249 terminal_metrics.mark_ready();
250 other
251 }
252 };
253 Output::new(data, meta)
254}
255
/// Converts a decoded Flight message stream into an [`OutputWithMetrics`].
///
/// Protocol expected from the server:
/// - `AffectedRows` first (optionally carrying inline metrics, optionally
///   followed by exactly one `Metrics` message) for DML/DDL results, or
/// - `Schema` first, followed by `RecordBatch` messages (with `Metrics`
///   messages interleaved/trailing) for query results.
///
/// Any other ordering yields an `IllegalFlightMessages` error.
async fn output_from_flight_message_stream<S>(
    mut flight_message_stream: S,
) -> Result<OutputWithMetrics>
where
    S: Stream<Item = Result<FlightMessage>> + Send + Unpin + 'static,
{
    let Some(first_flight_message) = flight_message_stream.next().await else {
        return IllegalFlightMessagesSnafu {
            reason: "Expect the response not to be empty",
        }
        .fail();
    };

    let first_flight_message = first_flight_message?;

    match first_flight_message {
        FlightMessage::AffectedRows { rows, metrics } => {
            let terminal_metrics = OutputMetrics::new();
            // Inline metrics (carried on the AffectedRows message itself)
            // take precedence over a trailing Metrics message.
            if let Some(metrics) = metrics {
                terminal_metrics.update(Some(parse_terminal_metrics(&metrics)?));
            }
            let next_message = flight_message_stream.next().await.transpose()?;
            match next_message {
                // Stream ended: whatever we have is final.
                None => terminal_metrics.mark_ready(),
                // Trailing Metrics is only legal when none were inlined.
                Some(FlightMessage::Metrics(s)) if terminal_metrics.get().is_none() => {
                    terminal_metrics.update(Some(parse_terminal_metrics(&s)?));
                    terminal_metrics.mark_ready();
                }
                Some(FlightMessage::Metrics(_)) => {
                    return IllegalFlightMessagesSnafu {
                        reason: "'AffectedRows' Flight metadata already carries Metrics and cannot be followed by another Metrics message",
                    }
                    .fail();
                }
                Some(other) => {
                    return IllegalFlightMessagesSnafu {
                        reason: format!(
                            "'AffectedRows' Flight message can only be followed by a Metrics message, got {other:?}"
                        ),
                    }
                    .fail();
                }
            }
            Ok(OutputWithMetrics {
                output: Output::new_with_affected_rows(rows),
                metrics: terminal_metrics,
            })
        }
        FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => IllegalFlightMessagesSnafu {
            reason: "The first flight message cannot be a RecordBatch or Metrics message",
        }
        .fail(),
        FlightMessage::Schema(schema) => {
            // Metrics slot shared between the lazily-driven stream below and
            // the RecordBatchStreamWrapper handed to the caller.
            let metrics = Arc::new(ArcSwapOption::from(None));
            let metrics_ref = metrics.clone();
            let schema = Arc::new(
                datatypes::schema::Schema::try_from(schema).context(error::ConvertSchemaSnafu)?,
            );
            let schema_cloned = schema.clone();
            let stream = Box::pin(stream!({
                while let Some(flight_message_item) = flight_message_stream.next().await {
                    let flight_message = match flight_message_item {
                        Ok(message) => message,
                        Err(e) => {
                            // Surface the transport error to the consumer and
                            // stop pulling further messages.
                            yield Err(BoxedError::new(e)).context(ExternalSnafu);
                            break;
                        }
                    };

                    match flight_message {
                        FlightMessage::RecordBatch(arrow_batch) => {
                            yield Ok(RecordBatch::from_df_record_batch(
                                schema_cloned.clone(),
                                arrow_batch,
                            ))
                        }
                        FlightMessage::Metrics(s) => {
                            // Later Metrics messages overwrite earlier ones;
                            // an unparsable payload is reported as a stream
                            // error but does not end the stream.
                            match parse_terminal_metrics(&s) {
                                Ok(m) => {
                                    metrics_ref.swap(Some(Arc::new(m)));
                                }
                                Err(e) => {
                                    yield Err(BoxedError::new(e)).context(ExternalSnafu);
                                }
                            };
                        }
                        FlightMessage::AffectedRows { .. } | FlightMessage::Schema(_) => {
                            yield IllegalFlightMessagesSnafu {
                                reason: format!(
                                    "A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}",
                                    flight_message
                                )
                            }
                            .fail()
                            .map_err(BoxedError::new)
                            .context(ExternalSnafu);
                            break;
                        }
                    }
                }
            }));
            let record_batch_stream = RecordBatchStreamWrapper {
                schema,
                stream,
                output_ordering: None,
                metrics,
                span: Span::current(),
            };
            Ok(OutputWithMetrics::from_output(Output::new_with_stream(
                Box::pin(record_batch_stream),
            )))
        }
    }
}
370
/// A client handle for one logical database on a GreptimeDB server.
///
/// The target database is identified either by `catalog` + `schema`, or by a
/// combined `dbname`; `do_put` prefers `dbname` when it is non-empty.
#[derive(Clone, Debug, Default)]
pub struct Database {
    catalog: String,
    schema: String,
    // Combined database name; used instead of catalog/schema when non-empty.
    dbname: String,
    // Session timezone forwarded in every request header.
    timezone: String,

    client: Client,
    // Flight-call context (currently just the auth header).
    ctx: FlightContext,
}
388
/// A per-call gRPC client plus the peer address it was resolved to,
/// kept together so errors can be logged with the peer.
pub struct DatabaseClient {
    pub addr: String,
    pub inner: GreptimeDatabaseClient<Channel>,
}
393
394impl DatabaseClient {
395 pub fn inspect_err<'a>(&'a self, context: &'a str) -> impl Fn(&tonic::Status) + 'a {
397 let addr = &self.addr;
398 move |status| {
399 error!("Failed to {context} request, peer: {addr}, status: {status:?}");
400 }
401 }
402}
403
404fn make_database_client(client: &Client) -> Result<DatabaseClient> {
405 let (addr, channel) = client.find_channel()?;
406 Ok(DatabaseClient {
407 addr,
408 inner: GreptimeDatabaseClient::new(channel)
409 .max_decoding_message_size(client.max_grpc_recv_message_size())
410 .max_encoding_message_size(client.max_grpc_send_message_size()),
411 })
412}
413
impl Database {
    /// Creates a client addressing the server by an explicit
    /// `catalog` + `schema` pair (`dbname` is left empty).
    pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
        Self {
            catalog: catalog.into(),
            schema: schema.into(),
            dbname: String::default(),
            timezone: String::default(),
            client,
            ctx: FlightContext::default(),
        }
    }

    /// Creates a client addressing the server by a combined database name
    /// (catalog/schema are left empty).
    pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
        Self {
            catalog: String::default(),
            schema: String::default(),
            timezone: String::default(),
            dbname: dbname.into(),
            client,
            ctx: FlightContext::default(),
        }
    }

    pub fn set_catalog(&mut self, catalog: impl Into<String>) {
        self.catalog = catalog.into();
    }

    /// Falls back to the default catalog when none was configured.
    fn catalog_or_default(&self) -> &str {
        if self.catalog.is_empty() {
            DEFAULT_CATALOG_NAME
        } else {
            &self.catalog
        }
    }

    pub fn set_schema(&mut self, schema: impl Into<String>) {
        self.schema = schema.into();
    }

    /// Falls back to the default schema when none was configured.
    fn schema_or_default(&self) -> &str {
        if self.schema.is_empty() {
            DEFAULT_SCHEMA_NAME
        } else {
            &self.schema
        }
    }

    pub fn set_timezone(&mut self, timezone: impl Into<String>) {
        self.timezone = timezone.into();
    }

    /// Sets the auth scheme carried in every request header (also used by
    /// `do_put` to build the basic-auth metadata entry).
    pub fn set_auth(&mut self, auth: AuthScheme) {
        self.ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(auth),
        });
    }

    /// Sends column-oriented inserts; returns the number of affected rows.
    pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
        self.handle(Request::Inserts(requests)).await
    }

    /// Like [`Self::insert`], but attaches `x-greptime-hints` metadata.
    pub async fn insert_with_hints(
        &self,
        requests: InsertRequests,
        hints: &[(&str, &str)],
    ) -> Result<u32> {
        let mut client = make_database_client(&self.client)?;
        let request = self.to_rpc_request(Request::Inserts(requests));

        let mut request = tonic::Request::new(request);
        let metadata = request.metadata_mut();
        Self::put_hints(metadata, hints)?;

        let response = client
            .inner
            .handle(request)
            .await
            .inspect_err(client.inspect_err("insert_with_hints"))?
            .into_inner();
        from_grpc_response(response)
    }

    /// Sends row-oriented inserts; returns the number of affected rows.
    pub async fn row_inserts(&self, requests: RowInsertRequests) -> Result<u32> {
        self.handle(Request::RowInserts(requests)).await
    }

    /// Like [`Self::row_inserts`], but attaches `x-greptime-hints` metadata.
    pub async fn row_inserts_with_hints(
        &self,
        requests: RowInsertRequests,
        hints: &[(&str, &str)],
    ) -> Result<u32> {
        let mut client = make_database_client(&self.client)?;
        let request = self.to_rpc_request(Request::RowInserts(requests));

        let mut request = tonic::Request::new(request);
        let metadata = request.metadata_mut();
        Self::put_hints(metadata, hints)?;

        let response = client
            .inner
            .handle(request)
            .await
            .inspect_err(client.inspect_err("row_inserts_with_hints"))?
            .into_inner();
        from_grpc_response(response)
    }

    /// Encodes `hints` as a single `x-greptime-hints: k=v,k=v,...` metadata
    /// entry. No-op when `hints` is empty.
    // NOTE(review): values containing ',' or '=' are not escaped here —
    // unlike flow extensions, which use JSON; confirm hint values are simple.
    fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
        let Some(value) = hints
            .iter()
            .map(|(k, v)| format!("{}={}", k, v))
            .reduce(|a, b| format!("{},{}", a, b))
        else {
            return Ok(());
        };

        let key = AsciiMetadataKey::from_static("x-greptime-hints");
        let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
        metadata.insert(key, value);
        Ok(())
    }

    /// Encodes flow extension pairs as a JSON array in the
    /// `FLOW_EXTENSIONS_METADATA_KEY` metadata entry, which preserves values
    /// that contain commas. No-op when `flow_extensions` is empty.
    fn put_flow_extensions(
        metadata: &mut MetadataMap,
        flow_extensions: &[(&str, &str)],
    ) -> Result<()> {
        if flow_extensions.is_empty() {
            return Ok(());
        }

        // Serializing a Vec of (&str, &str) pairs cannot fail.
        let value = serde_json::to_string(&flow_extensions.to_vec())
            .expect("flow extension pairs should serialize");
        let key = AsciiMetadataKey::from_static(FLOW_EXTENSIONS_METADATA_KEY);
        let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
        metadata.insert(key, value);
        Ok(())
    }

    /// Sends a single request over the unary gRPC `handle` endpoint and
    /// returns the number of affected rows.
    pub async fn handle(&self, request: Request) -> Result<u32> {
        let mut client = make_database_client(&self.client)?;
        let request = self.to_rpc_request(request);
        let response = client
            .inner
            .handle(request)
            .await
            .inspect_err(client.inspect_err("handle"))?
            .into_inner();
        from_grpc_response(response)
    }

    /// Like [`Self::handle`], but retries up to `max_retries` times on
    /// retryable gRPC errors (see [`is_grpc_retryable`]), re-attaching
    /// `hints` on every attempt. The same resolved channel is reused across
    /// attempts.
    pub async fn handle_with_retry(
        &self,
        request: Request,
        max_retries: u32,
        hints: &[(&str, &str)],
    ) -> Result<u32> {
        let mut client = make_database_client(&self.client)?;
        let mut retries = 0;

        let request = self.to_rpc_request(request);

        loop {
            // tonic::Request is not Clone, so rebuild it from the payload
            // on each attempt.
            let mut tonic_request = tonic::Request::new(request.clone());
            let metadata = tonic_request.metadata_mut();
            Self::put_hints(metadata, hints)?;
            let raw_response = client
                .inner
                .handle(tonic_request)
                .await
                .inspect_err(client.inspect_err("handle"));
            match (raw_response, retries < max_retries) {
                (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
                (Err(err), true) => {
                    if is_grpc_retryable(&err) {
                        retries += 1;
                        warn!("Retrying {} times with error = {:?}", retries, err);
                        continue;
                    } else {
                        error!(
                            err; "Failed to send request to grpc handle, retries = {}, not retryable error, aborting",
                            retries
                        );
                        return Err(err.into());
                    }
                }
                (Err(err), false) => {
                    error!(
                        err; "Failed to send request to grpc handle after {} retries",
                        retries,
                    );
                    return Err(err.into());
                }
            }
        }
    }

    /// Wraps a request in a [`GreptimeRequest`] with this database's header
    /// (catalog/schema/dbname, auth, timezone, and a fresh trace context).
    #[inline]
    fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
        GreptimeRequest {
            header: Some(RequestHeader {
                catalog: self.catalog.clone(),
                schema: self.schema.clone(),
                authorization: self.ctx.auth_header.clone(),
                dbname: self.dbname.clone(),
                timezone: self.timezone.clone(),
                // NOTE(review): starts a new (empty) W3C trace context rather
                // than propagating the current span — confirm intended.
                tracing_context: W3cTrace::new(),
            }),
            request: Some(request),
        }
    }

    /// Runs a SQL statement and returns its output.
    pub async fn sql<S>(&self, sql: S) -> Result<Output>
    where
        S: AsRef<str>,
    {
        self.sql_with_hint(sql, &[]).await
    }

    /// Runs a SQL statement with `x-greptime-hints` metadata attached.
    pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
    where
        S: AsRef<str>,
    {
        let request = Request::Query(QueryRequest {
            query: Some(Query::Sql(sql.as_ref().to_string())),
        });
        self.do_get(request, hints, &[])
            .await
            .map(OutputWithMetrics::into_output)
    }

    /// Runs a SQL statement and keeps the terminal-metrics handle alongside
    /// the output.
    pub async fn sql_with_terminal_metrics<S>(
        &self,
        sql: S,
        hints: &[(&str, &str)],
    ) -> Result<OutputWithMetrics>
    where
        S: AsRef<str>,
    {
        self.query_with_terminal_metrics_and_flow_extensions(
            QueryRequest {
                query: Some(Query::Sql(sql.as_ref().to_string())),
            },
            hints,
            &[],
        )
        .await
    }

    /// Executes a pre-encoded logical plan (substrait/protobuf bytes).
    pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
        self.query_with_terminal_metrics_and_flow_extensions(
            QueryRequest {
                query: Some(Query::LogicalPlan(logical_plan)),
            },
            &[],
            &[],
        )
        .await
        .map(OutputWithMetrics::into_output)
    }

    /// Executes a query with both hints and flow-extension metadata, keeping
    /// the terminal-metrics handle alongside the output.
    pub async fn query_with_terminal_metrics_and_flow_extensions(
        &self,
        request: QueryRequest,
        hints: &[(&str, &str)],
        flow_extensions: &[(&str, &str)],
    ) -> Result<OutputWithMetrics> {
        self.do_get(Request::Query(request), hints, flow_extensions)
            .await
    }

    /// Issues a CREATE TABLE DDL request.
    pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
        let request = Request::Ddl(DdlRequest {
            expr: Some(DdlExpr::CreateTable(expr)),
        });
        self.do_get(request, &[], &[])
            .await
            .map(OutputWithMetrics::into_output)
    }

    /// Issues an ALTER TABLE DDL request.
    pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
        let request = Request::Ddl(DdlRequest {
            expr: Some(DdlExpr::AlterTable(expr)),
        });
        self.do_get(request, &[], &[])
            .await
            .map(OutputWithMetrics::into_output)
    }

    /// Core request path: encodes the request as a Flight ticket, issues
    /// `do_get`, decodes the returned Flight frames, and assembles the
    /// output (see [`output_from_flight_message_stream`]).
    async fn do_get(
        &self,
        request: Request,
        hints: &[(&str, &str)],
        flow_extensions: &[(&str, &str)],
    ) -> Result<OutputWithMetrics> {
        let request = self.to_rpc_request(request);
        let request = Ticket {
            ticket: request.encode_to_vec().into(),
        };

        let mut request = tonic::Request::new(request);
        let metadata = request.metadata_mut();
        Self::put_hints(metadata, hints)?;
        Self::put_flow_extensions(metadata, flow_extensions)?;

        let mut client = self.client.make_flight_client(false, false)?;

        let response = client.mut_inner().do_get(request).await.or_else(|e| {
            let tonic_code = e.code();
            let e: Error = e.into();
            error!(
                "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
                client.addr(),
                tonic_code,
                e
            );
            Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
                addr: client.addr().to_string(),
                tonic_code,
            })
        })?;

        let flight_data_stream = response.into_inner();
        let mut decoder = FlightDecoder::default();

        // Decode each raw frame; a frame that decodes to nothing is illegal.
        let flight_message_stream = flight_data_stream.map(move |flight_data| {
            flight_data
                .map_err(Error::from)
                .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
                .context(IllegalFlightMessagesSnafu {
                    reason: "none message",
                })
        });

        output_from_flight_message_stream(flight_message_stream).await
    }

    /// Uploads a stream of Flight frames via `do_put`, returning the stream
    /// of per-frame responses. Auth (basic only) and the target database
    /// name are passed as gRPC metadata; `dbname` wins over catalog/schema
    /// when set.
    pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
        let mut request = tonic::Request::new(stream);

        if let Some(AuthHeader {
            auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
        }) = &self.ctx.auth_header
        {
            let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
            let value = MetadataValue::from_str(&format!("Basic {encoded}"))
                .context(InvalidTonicMetadataValueSnafu)?;
            request.metadata_mut().insert("x-greptime-auth", value);
        }

        let db_to_put = if !self.dbname.is_empty() {
            &self.dbname
        } else {
            &build_db_string(self.catalog_or_default(), self.schema_or_default())
        };
        request.metadata_mut().insert(
            "x-greptime-db-name",
            MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
        );

        let mut client = self.client.make_flight_client(false, false)?;
        let response = client.mut_inner().do_put(request).await?;
        let response = response
            .into_inner()
            .map_err(Into::into)
            .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
            .boxed();
        Ok(response)
    }
}
820
821pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
823 matches!(err.code(), tonic::Code::Unavailable)
824}
825
/// Per-database context for Flight calls; currently only carries the
/// optional authorization header.
#[derive(Default, Debug, Clone)]
struct FlightContext {
    auth_header: Option<AuthHeader>,
}
830
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::task::{Context, Poll};

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};
    use common_error::status_code::StatusCode;
    use common_query::OutputData;
    use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream};
    use datatypes::prelude::{ConcreteDataType, VectorRef};
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::Int32Vector;
    use futures_util::StreamExt;
    use tonic::{Code, Status};

    use super::*;
    use crate::error::TonicSnafu;

    /// Single-batch stream stub whose `metrics()` behavior is configurable:
    /// with `terminal_metrics_only`, metrics are reported only after the
    /// batch has been consumed (simulating metrics that arrive at stream end).
    struct MockMetricsStream {
        schema: datatypes::schema::SchemaRef,
        batch: Option<RecordBatch>,
        metrics: RecordBatchMetrics,
        terminal_metrics_only: bool,
    }

    impl Stream for MockMetricsStream {
        type Item = common_recordbatch::error::Result<RecordBatch>;

        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // Yields the batch once, then signals end-of-stream.
            Poll::Ready(self.batch.take().map(Ok))
        }
    }

    impl RecordBatchStream for MockMetricsStream {
        fn name(&self) -> &str {
            "MockMetricsStream"
        }

        fn schema(&self) -> datatypes::schema::SchemaRef {
            self.schema.clone()
        }

        fn output_ordering(&self) -> Option<&[OrderOption]> {
            None
        }

        fn metrics(&self) -> Option<RecordBatchMetrics> {
            // Hide the metrics until the batch has been taken when
            // `terminal_metrics_only` is set.
            if self.terminal_metrics_only && self.batch.is_some() {
                return None;
            }
            Some(self.metrics.clone())
        }
    }

    /// JSON metrics payload with a single region (id 7, watermark 42).
    fn terminal_metrics_json() -> String {
        terminal_metrics_json_with_seq(42)
    }

    /// JSON metrics payload with a single region (id 7) and the given
    /// watermark sequence.
    fn terminal_metrics_json_with_seq(seq: u64) -> String {
        serde_json::to_string(&RecordBatchMetrics {
            region_watermarks: vec![common_recordbatch::adapter::RegionWatermarkEntry {
                region_id: 7,
                watermark: Some(seq),
            }],
            ..Default::default()
        })
        .unwrap()
    }

    // Flow extensions are JSON-encoded, so values containing commas must
    // round-trip intact (unlike the comma-joined hints encoding).
    #[test]
    fn test_put_flow_extensions_preserves_comma_bearing_values() {
        let mut metadata = MetadataMap::new();
        Database::put_flow_extensions(
            &mut metadata,
            &[
                ("flow.return_region_seq", "true"),
                ("flow.incremental_after_seqs", r#"{"1":10,"2":20}"#),
            ],
        )
        .unwrap();

        let value = metadata
            .get(FLOW_EXTENSIONS_METADATA_KEY)
            .unwrap()
            .to_str()
            .unwrap();
        let decoded: Vec<(String, String)> = serde_json::from_str(value).unwrap();
        assert_eq!(
            decoded,
            vec![
                ("flow.return_region_seq".to_string(), "true".to_string()),
                (
                    "flow.incremental_after_seqs".to_string(),
                    r#"{"1":10,"2":20}"#.to_string()
                ),
            ]
        );
    }

    #[test]
    fn test_flight_ctx() {
        let mut ctx = FlightContext::default();
        assert!(ctx.auth_header.is_none());

        let basic = AuthScheme::Basic(Basic {
            username: "u".to_string(),
            password: "p".to_string(),
        });

        ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(basic),
        });

        assert!(matches!(
            ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        ));
    }

    #[test]
    fn test_from_tonic_status() {
        let expected = TonicSnafu {
            code: StatusCode::Internal,
            msg: "blabla".to_string(),
            tonic_code: Code::Internal,
        }
        .build();

        let status = Status::new(Code::Internal, "blabla");
        let actual: Error = status.into();

        assert_eq!(expected.to_string(), actual.to_string());
    }

    // Terminal metrics that only appear at end-of-stream must be captured by
    // the StreamWithMetrics wrapper and marked ready once drained.
    #[tokio::test]
    async fn test_query_with_terminal_metrics_tracks_terminal_only_metrics() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "v",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([1, 2])) as VectorRef],
        )
        .unwrap();
        let output = Output::new_with_stream(Box::pin(MockMetricsStream {
            schema,
            batch: Some(batch),
            metrics: RecordBatchMetrics {
                region_watermarks: vec![common_recordbatch::adapter::RegionWatermarkEntry {
                    region_id: 7,
                    watermark: Some(42),
                }],
                ..Default::default()
            },
            terminal_metrics_only: true,
        }));

        let result = OutputWithMetrics::from_output(output);
        let terminal_metrics = result.metrics.clone();
        // Nothing observable before the stream is drained.
        assert!(!terminal_metrics.is_ready());
        assert!(terminal_metrics.get().is_none());

        let OutputData::Stream(mut stream) = result.output.data else {
            panic!("expected stream output");
        };
        while stream.next().await.is_some() {}

        assert!(terminal_metrics.is_ready());
        assert_eq!(
            terminal_metrics.participating_regions(),
            Some(std::collections::BTreeSet::from([7_u64]))
        );
        assert_eq!(
            terminal_metrics.region_watermark_map(),
            Some(std::collections::HashMap::from([(7_u64, 42_u64)]))
        );
    }

    #[test]
    fn test_parse_terminal_metrics_rejects_invalid_json() {
        assert!(parse_terminal_metrics("{not-json}").is_err());
    }

    // Metrics inlined on the AffectedRows message are parsed and ready.
    #[tokio::test]
    async fn test_affected_rows_inline_metrics_are_parsed() {
        let output = output_from_flight_message_stream(futures_util::stream::iter(vec![Ok(
            FlightMessage::AffectedRows {
                rows: 3,
                metrics: Some(terminal_metrics_json()),
            },
        )]
            as Vec<Result<FlightMessage>>))
        .await
        .unwrap();

        assert!(matches!(output.output.data, OutputData::AffectedRows(3)));
        assert!(output.metrics.is_ready());
        assert_eq!(
            output.metrics.region_watermark_map(),
            Some(std::collections::HashMap::from([(7, 42)]))
        );
    }

    // Inline metrics followed by a trailing Metrics message is a protocol
    // violation and must error.
    #[tokio::test]
    async fn test_affected_rows_inline_metrics_rejects_trailing_metrics() {
        let metrics_json = terminal_metrics_json();
        let err = output_from_flight_message_stream(futures_util::stream::iter(vec![
            Ok(FlightMessage::AffectedRows {
                rows: 3,
                metrics: Some(metrics_json.clone()),
            }),
            Ok(FlightMessage::Metrics(metrics_json)),
        ]
            as Vec<Result<FlightMessage>>))
        .await
        .unwrap_err();

        assert!(
            err.to_string().contains("already carries Metrics"),
            "unexpected error: {err:?}"
        );
    }

    // An unparsable Metrics message surfaces as a stream error after the
    // preceding batch, and metrics end up ready-but-empty.
    #[tokio::test]
    async fn test_invalid_terminal_metrics_after_record_batch_yields_batch_then_error() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "v",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([1])) as VectorRef],
        )
        .unwrap();
        let output = output_from_flight_message_stream(futures_util::stream::iter(vec![
            Ok(FlightMessage::Schema(schema.arrow_schema().clone())),
            Ok(FlightMessage::RecordBatch(batch.into_df_record_batch())),
            Ok(FlightMessage::Metrics("{not-json}".to_string())),
        ]
            as Vec<Result<FlightMessage>>))
        .await
        .unwrap();
        let terminal_metrics = output.metrics.clone();
        let OutputData::Stream(mut record_batch_stream) = output.output.data else {
            panic!("expected stream output");
        };

        let batch = record_batch_stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 1);

        let err = record_batch_stream.next().await.unwrap().unwrap_err();
        assert_eq!("External error", err.to_string());
        assert!(
            format!("{err:?}").contains("Invalid terminal metrics message"),
            "unexpected error: {err:?}"
        );
        assert!(record_batch_stream.next().await.is_none());
        assert!(terminal_metrics.is_ready());
        assert!(terminal_metrics.get().is_none());
    }

    // Interleaved Metrics messages do not terminate the batch stream, and
    // later metrics overwrite earlier ones.
    #[tokio::test]
    async fn test_record_batch_stream_continues_after_partial_metrics() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "v",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let first_batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([1])) as VectorRef],
        )
        .unwrap();
        let second_batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([2])) as VectorRef],
        )
        .unwrap();
        let output = output_from_flight_message_stream(futures_util::stream::iter(vec![
            Ok(FlightMessage::Schema(schema.arrow_schema().clone())),
            Ok(FlightMessage::RecordBatch(
                first_batch.into_df_record_batch(),
            )),
            Ok(FlightMessage::Metrics(terminal_metrics_json_with_seq(1))),
            Ok(FlightMessage::RecordBatch(
                second_batch.into_df_record_batch(),
            )),
            Ok(FlightMessage::Metrics(terminal_metrics_json_with_seq(2))),
        ]
            as Vec<Result<FlightMessage>>))
        .await
        .unwrap();
        let terminal_metrics = output.metrics.clone();
        let OutputData::Stream(mut record_batch_stream) = output.output.data else {
            panic!("expected stream output");
        };

        let first_batch = record_batch_stream.next().await.unwrap().unwrap();
        assert_eq!(first_batch.num_rows(), 1);
        let second_batch = record_batch_stream.next().await.unwrap().unwrap();
        assert_eq!(second_batch.num_rows(), 1);
        assert!(record_batch_stream.next().await.is_none());

        assert!(terminal_metrics.is_ready());
        assert_eq!(
            terminal_metrics.region_watermark_map(),
            Some(std::collections::HashMap::from([(7, 2)]))
        );
    }

    // Recorded-but-empty metrics (Some with no regions) must be
    // distinguishable from never-recorded metrics (None).
    #[test]
    fn test_output_metrics_distinguishes_empty_region_watermarks_from_absence() {
        let metrics = OutputMetrics::default();
        metrics.update(Some(RecordBatchMetrics::default()));

        assert_eq!(
            metrics.participating_regions(),
            Some(std::collections::BTreeSet::new())
        );
        assert_eq!(
            metrics.region_watermark_map(),
            Some(std::collections::HashMap::new())
        );
    }
}
1161}