1use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::atomic::{AtomicBool, Ordering};
18use std::sync::{Arc, RwLock};
19use std::task::{Context, Poll};
20
21use api::v1::auth_header::AuthScheme;
22use api::v1::ddl_request::Expr as DdlExpr;
23use api::v1::greptime_database_client::GreptimeDatabaseClient;
24use api::v1::greptime_request::Request;
25use api::v1::query_request::Query;
26use api::v1::{
27 AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
28 InsertRequests, QueryRequest, RequestHeader, RowInsertRequests,
29};
30use arc_swap::ArcSwapOption;
31use arrow_flight::{FlightData, Ticket};
32use async_stream::stream;
33use base64::Engine;
34use base64::prelude::BASE64_STANDARD;
35use common_catalog::build_db_string;
36use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
37use common_error::ext::BoxedError;
38use common_grpc::flight::do_put::DoPutResponse;
39use common_grpc::flight::{
40 FLOW_EXTENSIONS_METADATA_KEY, FlightDecoder, FlightMessage, SNAPSHOT_SEQS_METADATA_KEY,
41};
42use common_query::Output;
43use common_recordbatch::adapter::RecordBatchMetrics;
44use common_recordbatch::error::ExternalSnafu;
45use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, RecordBatchStreamWrapper};
46use common_telemetry::tracing::Span;
47use common_telemetry::tracing_context::W3cTrace;
48use common_telemetry::{error, warn};
49use futures::future;
50use futures_util::{Stream, StreamExt, TryStreamExt};
51use prost::Message;
52use snafu::{OptionExt, ResultExt};
53use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
54use tonic::transport::Channel;
55
56use crate::error::{
57 ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
58 InvalidTonicMetadataValueSnafu,
59};
60use crate::{Client, Result, error, from_grpc_response};
61
62type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
63
64type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
65
66#[derive(Debug, Clone, Default)]
71pub struct OutputMetrics {
72 inner: Arc<OutputMetricsInner>,
73}
74
75#[derive(Debug, Default)]
76struct OutputMetricsInner {
77 metrics: RwLock<Option<RecordBatchMetrics>>,
78 ready: AtomicBool,
79}
80
81impl OutputMetrics {
82 fn new() -> Self {
83 Self::default()
84 }
85
86 pub fn update(&self, metrics: Option<RecordBatchMetrics>) {
88 *self.inner.metrics.write().unwrap() = metrics;
89 }
90
91 pub fn mark_ready(&self) {
93 let _ = self
94 .inner
95 .ready
96 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire);
97 }
98
99 pub fn is_ready(&self) -> bool {
103 self.inner.ready.load(Ordering::Acquire)
104 }
105
106 pub fn get(&self) -> Option<RecordBatchMetrics> {
108 self.inner.metrics.read().unwrap().clone()
109 }
110
111 pub fn region_watermark_map(&self) -> Option<std::collections::HashMap<u64, u64>> {
117 Some(
118 self.get()?
119 .region_watermarks
120 .into_iter()
121 .filter_map(|entry| entry.watermark.map(|seq| (entry.region_id, seq)))
122 .collect::<std::collections::HashMap<_, _>>(),
123 )
124 }
125
126 pub fn participating_regions(&self) -> Option<std::collections::BTreeSet<u64>> {
129 Some(
130 self.get()?
131 .region_watermarks
132 .into_iter()
133 .map(|entry| entry.region_id)
134 .collect::<std::collections::BTreeSet<_>>(),
135 )
136 }
137}
138
139#[derive(Debug)]
145pub struct OutputWithMetrics {
146 pub output: Output,
147 pub metrics: OutputMetrics,
148}
149
150impl OutputWithMetrics {
151 pub fn from_output(output: Output) -> Self {
156 let terminal_metrics = OutputMetrics::new();
157 let output = attach_terminal_metrics(output, &terminal_metrics);
158 Self {
159 output,
160 metrics: terminal_metrics,
161 }
162 }
163
164 pub fn region_watermark_map(&self) -> Option<std::collections::HashMap<u64, u64>> {
166 self.metrics.region_watermark_map()
167 }
168
169 pub fn participating_regions(&self) -> Option<std::collections::BTreeSet<u64>> {
171 self.metrics.participating_regions()
172 }
173
174 pub fn into_output(self) -> Output {
176 self.output
177 }
178}
179
180fn parse_terminal_metrics(metrics_json: &str) -> Result<RecordBatchMetrics> {
181 serde_json::from_str(metrics_json).map_err(|e| {
182 IllegalFlightMessagesSnafu {
183 reason: format!("Invalid terminal metrics message: {e}"),
184 }
185 .build()
186 })
187}
188
189struct StreamWithMetrics {
190 stream: common_recordbatch::SendableRecordBatchStream,
191 metrics: OutputMetrics,
192}
193
194impl StreamWithMetrics {
195 fn new(stream: common_recordbatch::SendableRecordBatchStream, metrics: OutputMetrics) -> Self {
196 Self { stream, metrics }
197 }
198
199 fn sync_terminal_metrics(&self) {
200 self.metrics.update(self.stream.metrics());
201 }
202}
203
204impl RecordBatchStream for StreamWithMetrics {
205 fn name(&self) -> &str {
206 self.stream.name()
207 }
208
209 fn schema(&self) -> datatypes::schema::SchemaRef {
210 self.stream.schema()
211 }
212
213 fn output_ordering(&self) -> Option<&[OrderOption]> {
214 self.stream.output_ordering()
215 }
216
217 fn metrics(&self) -> Option<RecordBatchMetrics> {
218 self.sync_terminal_metrics();
219 self.metrics.get()
220 }
221}
222
223impl Stream for StreamWithMetrics {
224 type Item = common_recordbatch::error::Result<RecordBatch>;
225
226 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
227 let polled = Pin::new(&mut self.stream).poll_next(cx);
228 if let Poll::Ready(None) = &polled {
229 self.sync_terminal_metrics();
230 self.metrics.mark_ready();
231 }
232 polled
233 }
234
235 fn size_hint(&self) -> (usize, Option<usize>) {
236 self.stream.size_hint()
237 }
238}
239
240fn attach_terminal_metrics(output: Output, terminal_metrics: &OutputMetrics) -> Output {
241 let Output { data, meta } = output;
242 let data = match data {
243 common_query::OutputData::Stream(stream) => {
244 terminal_metrics.update(stream.metrics());
245 common_query::OutputData::Stream(Box::pin(StreamWithMetrics::new(
246 stream,
247 terminal_metrics.clone(),
248 )))
249 }
250 other => {
251 terminal_metrics.mark_ready();
252 other
253 }
254 };
255 Output::new(data, meta)
256}
257
258async fn output_from_flight_message_stream<S>(
259 mut flight_message_stream: S,
260) -> Result<OutputWithMetrics>
261where
262 S: Stream<Item = Result<FlightMessage>> + Send + Unpin + 'static,
263{
264 let Some(first_flight_message) = flight_message_stream.next().await else {
265 return IllegalFlightMessagesSnafu {
266 reason: "Expect the response not to be empty",
267 }
268 .fail();
269 };
270
271 let first_flight_message = first_flight_message?;
272
273 match first_flight_message {
274 FlightMessage::AffectedRows { rows, metrics } => {
275 let terminal_metrics = OutputMetrics::new();
276 if let Some(metrics) = metrics {
277 terminal_metrics.update(Some(parse_terminal_metrics(&metrics)?));
278 }
279 let next_message = flight_message_stream.next().await.transpose()?;
280 match next_message {
281 None => terminal_metrics.mark_ready(),
282 Some(FlightMessage::Metrics(s)) if terminal_metrics.get().is_none() => {
283 terminal_metrics.update(Some(parse_terminal_metrics(&s)?));
284 terminal_metrics.mark_ready();
285 }
286 Some(FlightMessage::Metrics(_)) => {
287 return IllegalFlightMessagesSnafu {
288 reason: "'AffectedRows' Flight metadata already carries Metrics and cannot be followed by another Metrics message",
289 }
290 .fail();
291 }
292 Some(other) => {
293 return IllegalFlightMessagesSnafu {
294 reason: format!(
295 "'AffectedRows' Flight message can only be followed by a Metrics message, got {other:?}"
296 ),
297 }
298 .fail();
299 }
300 }
301 Ok(OutputWithMetrics {
302 output: Output::new_with_affected_rows(rows),
303 metrics: terminal_metrics,
304 })
305 }
306 FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => IllegalFlightMessagesSnafu {
307 reason: "The first flight message cannot be a RecordBatch or Metrics message",
308 }
309 .fail(),
310 FlightMessage::Schema(schema) => {
311 let metrics = Arc::new(ArcSwapOption::from(None));
312 let metrics_ref = metrics.clone();
313 let schema = Arc::new(
314 datatypes::schema::Schema::try_from(schema).context(error::ConvertSchemaSnafu)?,
315 );
316 let schema_cloned = schema.clone();
317 let stream = Box::pin(stream!({
318 while let Some(flight_message_item) = flight_message_stream.next().await {
319 let flight_message = match flight_message_item {
320 Ok(message) => message,
321 Err(e) => {
322 yield Err(BoxedError::new(e)).context(ExternalSnafu);
323 break;
324 }
325 };
326
327 match flight_message {
328 FlightMessage::RecordBatch(arrow_batch) => {
329 yield Ok(RecordBatch::from_df_record_batch(
330 schema_cloned.clone(),
331 arrow_batch,
332 ))
333 }
334 FlightMessage::Metrics(s) => {
335 match parse_terminal_metrics(&s) {
336 Ok(m) => {
337 metrics_ref.swap(Some(Arc::new(m)));
338 }
339 Err(e) => {
340 yield Err(BoxedError::new(e)).context(ExternalSnafu);
341 }
342 };
343 }
344 FlightMessage::AffectedRows { .. } | FlightMessage::Schema(_) => {
345 yield IllegalFlightMessagesSnafu {
346 reason: format!(
347 "A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}",
348 flight_message
349 )
350 }
351 .fail()
352 .map_err(BoxedError::new)
353 .context(ExternalSnafu);
354 break;
355 }
356 }
357 }
358 }));
359 let record_batch_stream = RecordBatchStreamWrapper {
360 schema,
361 stream,
362 output_ordering: None,
363 metrics,
364 span: Span::current(),
365 };
366 Ok(OutputWithMetrics::from_output(Output::new_with_stream(
367 Box::pin(record_batch_stream),
368 )))
369 }
370 }
371}
372
/// Client handle for issuing requests against one database.
///
/// The target is named either by `catalog` + `schema` or by `dbname`. All of
/// these, plus `timezone`, are copied into every request header this type builds.
#[derive(Clone, Debug, Default)]
pub struct Database {
    /// Catalog name; `do_put` substitutes `DEFAULT_CATALOG_NAME` when it is empty.
    catalog: String,
    /// Schema name; `do_put` substitutes `DEFAULT_SCHEMA_NAME` when it is empty.
    schema: String,
    /// Full database name; when non-empty, `do_put` uses it instead of catalog/schema.
    dbname: String,
    /// Timezone copied into request headers (empty unless set).
    timezone: String,

    client: Client,
    ctx: FlightContext,
}
390
391pub struct DatabaseClient {
392 pub addr: String,
393 pub inner: GreptimeDatabaseClient<Channel>,
394}
395
396impl DatabaseClient {
397 pub fn inspect_err<'a>(&'a self, context: &'a str) -> impl Fn(&tonic::Status) + 'a {
399 let addr = &self.addr;
400 move |status| {
401 error!("Failed to {context} request, peer: {addr}, status: {status:?}");
402 }
403 }
404}
405
406fn make_database_client(client: &Client) -> Result<DatabaseClient> {
407 let (addr, channel) = client.find_channel()?;
408 Ok(DatabaseClient {
409 addr,
410 inner: GreptimeDatabaseClient::new(channel)
411 .max_decoding_message_size(client.max_grpc_recv_message_size())
412 .max_encoding_message_size(client.max_grpc_send_message_size()),
413 })
414}
415
416impl Database {
417 pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
419 Self {
420 catalog: catalog.into(),
421 schema: schema.into(),
422 dbname: String::default(),
423 timezone: String::default(),
424 client,
425 ctx: FlightContext::default(),
426 }
427 }
428
429 pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
437 Self {
438 catalog: String::default(),
439 schema: String::default(),
440 timezone: String::default(),
441 dbname: dbname.into(),
442 client,
443 ctx: FlightContext::default(),
444 }
445 }
446
447 pub fn set_catalog(&mut self, catalog: impl Into<String>) {
449 self.catalog = catalog.into();
450 }
451
452 fn catalog_or_default(&self) -> &str {
453 if self.catalog.is_empty() {
454 DEFAULT_CATALOG_NAME
455 } else {
456 &self.catalog
457 }
458 }
459
460 pub fn set_schema(&mut self, schema: impl Into<String>) {
462 self.schema = schema.into();
463 }
464
465 fn schema_or_default(&self) -> &str {
466 if self.schema.is_empty() {
467 DEFAULT_SCHEMA_NAME
468 } else {
469 &self.schema
470 }
471 }
472
473 pub fn set_timezone(&mut self, timezone: impl Into<String>) {
475 self.timezone = timezone.into();
476 }
477
478 pub fn set_auth(&mut self, auth: AuthScheme) {
480 self.ctx.auth_header = Some(AuthHeader {
481 auth_scheme: Some(auth),
482 });
483 }
484
485 pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
487 self.handle(Request::Inserts(requests)).await
488 }
489
490 pub async fn insert_with_hints(
492 &self,
493 requests: InsertRequests,
494 hints: &[(&str, &str)],
495 ) -> Result<u32> {
496 let mut client = make_database_client(&self.client)?;
497 let request = self.to_rpc_request(Request::Inserts(requests));
498
499 let mut request = tonic::Request::new(request);
500 let metadata = request.metadata_mut();
501 Self::put_hints(metadata, hints)?;
502
503 let response = client
504 .inner
505 .handle(request)
506 .await
507 .inspect_err(client.inspect_err("insert_with_hints"))?
508 .into_inner();
509 from_grpc_response(response)
510 }
511
512 pub async fn row_inserts(&self, requests: RowInsertRequests) -> Result<u32> {
514 self.handle(Request::RowInserts(requests)).await
515 }
516
517 pub async fn row_inserts_with_hints(
519 &self,
520 requests: RowInsertRequests,
521 hints: &[(&str, &str)],
522 ) -> Result<u32> {
523 let mut client = make_database_client(&self.client)?;
524 let request = self.to_rpc_request(Request::RowInserts(requests));
525
526 let mut request = tonic::Request::new(request);
527 let metadata = request.metadata_mut();
528 Self::put_hints(metadata, hints)?;
529
530 let response = client
531 .inner
532 .handle(request)
533 .await
534 .inspect_err(client.inspect_err("row_inserts_with_hints"))?
535 .into_inner();
536 from_grpc_response(response)
537 }
538
539 fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
540 let Some(value) = hints
541 .iter()
542 .map(|(k, v)| format!("{}={}", k, v))
543 .reduce(|a, b| format!("{},{}", a, b))
544 else {
545 return Ok(());
546 };
547
548 let key = AsciiMetadataKey::from_static("x-greptime-hints");
549 let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
550 metadata.insert(key, value);
551 Ok(())
552 }
553
554 fn put_flow_extensions(
555 metadata: &mut MetadataMap,
556 flow_extensions: &[(&str, &str)],
557 ) -> Result<()> {
558 if flow_extensions.is_empty() {
559 return Ok(());
560 }
561
562 let value = serde_json::to_string(&flow_extensions.to_vec())
563 .expect("flow extension pairs should serialize");
564 Self::put_metadata_value(metadata, FLOW_EXTENSIONS_METADATA_KEY, value)
565 }
566
567 fn put_snapshot_seqs(
568 metadata: &mut MetadataMap,
569 snapshot_seqs: &std::collections::HashMap<u64, u64>,
570 ) -> Result<()> {
571 if snapshot_seqs.is_empty() {
572 return Ok(());
573 }
574
575 let value =
576 serde_json::to_string(snapshot_seqs).expect("snapshot sequence map should serialize");
577 Self::put_metadata_value(metadata, SNAPSHOT_SEQS_METADATA_KEY, value)
578 }
579
580 fn put_metadata_value(
581 metadata: &mut MetadataMap,
582 key: &'static str,
583 value: String,
584 ) -> Result<()> {
585 let key = AsciiMetadataKey::from_static(key);
586 let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
587 metadata.insert(key, value);
588 Ok(())
589 }
590
591 pub async fn handle(&self, request: Request) -> Result<u32> {
593 let mut client = make_database_client(&self.client)?;
594 let request = self.to_rpc_request(request);
595 let response = client
596 .inner
597 .handle(request)
598 .await
599 .inspect_err(client.inspect_err("handle"))?
600 .into_inner();
601 from_grpc_response(response)
602 }
603
604 pub async fn handle_with_retry(
607 &self,
608 request: Request,
609 max_retries: u32,
610 hints: &[(&str, &str)],
611 ) -> Result<u32> {
612 let mut client = make_database_client(&self.client)?;
613 let mut retries = 0;
614
615 let request = self.to_rpc_request(request);
616
617 loop {
618 let mut tonic_request = tonic::Request::new(request.clone());
619 let metadata = tonic_request.metadata_mut();
620 Self::put_hints(metadata, hints)?;
621 let raw_response = client
622 .inner
623 .handle(tonic_request)
624 .await
625 .inspect_err(client.inspect_err("handle"));
626 match (raw_response, retries < max_retries) {
627 (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
628 (Err(err), true) => {
629 if is_grpc_retryable(&err) {
631 retries += 1;
633 warn!("Retrying {} times with error = {:?}", retries, err);
634 continue;
635 } else {
636 error!(
637 err; "Failed to send request to grpc handle, retries = {}, not retryable error, aborting",
638 retries
639 );
640 return Err(err.into());
641 }
642 }
643 (Err(err), false) => {
644 error!(
645 err; "Failed to send request to grpc handle after {} retries",
646 retries,
647 );
648 return Err(err.into());
649 }
650 }
651 }
652 }
653
654 #[inline]
655 fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
656 GreptimeRequest {
657 header: Some(RequestHeader {
658 catalog: self.catalog.clone(),
659 schema: self.schema.clone(),
660 authorization: self.ctx.auth_header.clone(),
661 dbname: self.dbname.clone(),
662 timezone: self.timezone.clone(),
663 tracing_context: W3cTrace::new(),
665 }),
666 request: Some(request),
667 }
668 }
669
670 pub async fn sql<S>(&self, sql: S) -> Result<Output>
672 where
673 S: AsRef<str>,
674 {
675 self.sql_with_hint(sql, &[]).await
676 }
677
678 pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
680 where
681 S: AsRef<str>,
682 {
683 let request = Request::Query(QueryRequest {
684 query: Some(Query::Sql(sql.as_ref().to_string())),
685 });
686 self.do_get(request, hints, &[], &Default::default())
687 .await
688 .map(OutputWithMetrics::into_output)
689 }
690
691 pub async fn sql_with_terminal_metrics<S>(
696 &self,
697 sql: S,
698 hints: &[(&str, &str)],
699 ) -> Result<OutputWithMetrics>
700 where
701 S: AsRef<str>,
702 {
703 self.query_with_terminal_metrics_and_flow_extensions(
704 QueryRequest {
705 query: Some(Query::Sql(sql.as_ref().to_string())),
706 },
707 hints,
708 &[],
709 &Default::default(),
710 )
711 .await
712 }
713
714 pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
716 self.query_with_terminal_metrics_and_flow_extensions(
717 QueryRequest {
718 query: Some(Query::LogicalPlan(logical_plan)),
719 },
720 &[],
721 &[],
722 &Default::default(),
723 )
724 .await
725 .map(OutputWithMetrics::into_output)
726 }
727
728 pub async fn query_with_terminal_metrics_and_flow_extensions(
733 &self,
734 request: QueryRequest,
735 hints: &[(&str, &str)],
736 flow_extensions: &[(&str, &str)],
737 snapshot_seqs: &std::collections::HashMap<u64, u64>,
738 ) -> Result<OutputWithMetrics> {
739 self.do_get(
740 Request::Query(request),
741 hints,
742 flow_extensions,
743 snapshot_seqs,
744 )
745 .await
746 }
747
748 pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
750 let request = Request::Ddl(DdlRequest {
751 expr: Some(DdlExpr::CreateTable(expr)),
752 });
753 self.do_get(request, &[], &[], &Default::default())
754 .await
755 .map(OutputWithMetrics::into_output)
756 }
757
758 pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
760 let request = Request::Ddl(DdlRequest {
761 expr: Some(DdlExpr::AlterTable(expr)),
762 });
763 self.do_get(request, &[], &[], &Default::default())
764 .await
765 .map(OutputWithMetrics::into_output)
766 }
767
768 async fn do_get(
769 &self,
770 request: Request,
771 hints: &[(&str, &str)],
772 flow_extensions: &[(&str, &str)],
773 snapshot_seqs: &std::collections::HashMap<u64, u64>,
774 ) -> Result<OutputWithMetrics> {
775 let request = self.to_rpc_request(request);
776 let request = Ticket {
777 ticket: request.encode_to_vec().into(),
778 };
779
780 let mut request = tonic::Request::new(request);
781 let metadata = request.metadata_mut();
782 Self::put_hints(metadata, hints)?;
783 Self::put_flow_extensions(metadata, flow_extensions)?;
784 Self::put_snapshot_seqs(metadata, snapshot_seqs)?;
785
786 let mut client = self.client.make_flight_client(false, false)?;
787
788 let response = client.mut_inner().do_get(request).await.or_else(|e| {
789 let tonic_code = e.code();
790 let e: Error = e.into();
791 error!(
792 "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
793 client.addr(),
794 tonic_code,
795 e
796 );
797 Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
798 addr: client.addr().to_string(),
799 tonic_code,
800 })
801 })?;
802
803 let flight_data_stream = response.into_inner();
804 let mut decoder = FlightDecoder::default();
805
806 let flight_message_stream = flight_data_stream.map(move |flight_data| {
807 flight_data
808 .map_err(Error::from)
809 .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
810 .context(IllegalFlightMessagesSnafu {
811 reason: "none message",
812 })
813 });
814
815 output_from_flight_message_stream(flight_message_stream).await
816 }
817
818 pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
821 let mut request = tonic::Request::new(stream);
822
823 if let Some(AuthHeader {
824 auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
825 }) = &self.ctx.auth_header
826 {
827 let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
828 let value = MetadataValue::from_str(&format!("Basic {encoded}"))
829 .context(InvalidTonicMetadataValueSnafu)?;
830 request.metadata_mut().insert("x-greptime-auth", value);
831 }
832
833 let db_to_put = if !self.dbname.is_empty() {
834 &self.dbname
835 } else {
836 &build_db_string(self.catalog_or_default(), self.schema_or_default())
837 };
838 request.metadata_mut().insert(
839 "x-greptime-db-name",
840 MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
841 );
842
843 let mut client = self.client.make_flight_client(false, false)?;
844 let response = client.mut_inner().do_put(request).await?;
845 let response = response
846 .into_inner()
847 .map_err(Into::into)
848 .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
849 .boxed();
850 Ok(response)
851 }
852}
853
854pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
856 matches!(err.code(), tonic::Code::Unavailable)
857}
858
/// Request context kept by a `Database` handle.
#[derive(Default, Debug, Clone)]
struct FlightContext {
    /// Authorization copied into request headers; `None` means no authorization is attached.
    auth_header: Option<AuthHeader>,
}
863
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::task::{Context, Poll};

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};
    use common_error::status_code::StatusCode;
    use common_query::OutputData;
    use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream};
    use datatypes::prelude::{ConcreteDataType, VectorRef};
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::Int32Vector;
    use futures_util::StreamExt;
    use tonic::{Code, Status};

    use super::*;
    use crate::error::TonicSnafu;

    /// Test stream that yields at most one batch and reports fixed metrics.
    ///
    /// With `terminal_metrics_only` set, `metrics()` returns `None` until the
    /// batch has been taken. This emulates a source whose metrics only appear
    /// once it is exhausted.
    struct MockMetricsStream {
        schema: datatypes::schema::SchemaRef,
        batch: Option<RecordBatch>,
        metrics: RecordBatchMetrics,
        terminal_metrics_only: bool,
    }

    impl Stream for MockMetricsStream {
        type Item = common_recordbatch::error::Result<RecordBatch>;

        // Yields the stored batch on the first poll and `None` on every later poll.
        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Poll::Ready(self.batch.take().map(Ok))
        }
    }

    impl RecordBatchStream for MockMetricsStream {
        fn name(&self) -> &str {
            "MockMetricsStream"
        }

        fn schema(&self) -> datatypes::schema::SchemaRef {
            self.schema.clone()
        }

        fn output_ordering(&self) -> Option<&[OrderOption]> {
            None
        }

        fn metrics(&self) -> Option<RecordBatchMetrics> {
            // Hide the metrics while the batch is still pending if requested.
            if self.terminal_metrics_only && self.batch.is_some() {
                return None;
            }
            Some(self.metrics.clone())
        }
    }

    /// JSON for metrics listing region 7 with watermark 42.
    fn terminal_metrics_json() -> String {
        terminal_metrics_json_with_seq(42)
    }

    /// JSON for metrics listing region 7 with watermark `seq`.
    fn terminal_metrics_json_with_seq(seq: u64) -> String {
        serde_json::to_string(&RecordBatchMetrics {
            region_watermarks: vec![common_recordbatch::adapter::RegionWatermarkEntry {
                region_id: 7,
                watermark: Some(seq),
            }],
            ..Default::default()
        })
        .unwrap()
    }

    // Values containing ',' (here a JSON object) must round-trip unchanged
    // through the JSON-encoded flow-extensions metadata.
    #[test]
    fn test_put_flow_extensions_preserves_comma_bearing_values() {
        let mut metadata = MetadataMap::new();
        Database::put_flow_extensions(
            &mut metadata,
            &[
                ("flow.return_region_seq", "true"),
                ("flow.incremental_after_seqs", r#"{"1":10,"2":20}"#),
            ],
        )
        .unwrap();

        let value = metadata
            .get(FLOW_EXTENSIONS_METADATA_KEY)
            .unwrap()
            .to_str()
            .unwrap();
        let decoded: Vec<(String, String)> = serde_json::from_str(value).unwrap();
        assert_eq!(
            decoded,
            vec![
                ("flow.return_region_seq".to_string(), "true".to_string()),
                (
                    "flow.incremental_after_seqs".to_string(),
                    r#"{"1":10,"2":20}"#.to_string()
                ),
            ]
        );
    }

    // Values above 2^53 must survive the JSON round trip without precision loss.
    #[test]
    fn test_put_snapshot_seqs_preserves_u64_precision() {
        let mut metadata = MetadataMap::new();
        let snapshot_seqs = std::collections::HashMap::from([
            (u64::MAX, u64::MAX - 1),
            (9_007_199_254_740_993_u64, 9_007_199_254_740_995_u64),
        ]);

        Database::put_snapshot_seqs(&mut metadata, &snapshot_seqs).unwrap();

        let value = metadata
            .get(SNAPSHOT_SEQS_METADATA_KEY)
            .unwrap()
            .to_str()
            .unwrap();
        let decoded: std::collections::HashMap<u64, u64> = serde_json::from_str(value).unwrap();
        assert_eq!(decoded, snapshot_seqs);
    }

    #[test]
    fn test_flight_ctx() {
        let mut ctx = FlightContext::default();
        assert!(ctx.auth_header.is_none());

        let basic = AuthScheme::Basic(Basic {
            username: "u".to_string(),
            password: "p".to_string(),
        });

        ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(basic),
        });

        assert!(matches!(
            ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        ));
    }

    #[test]
    fn test_from_tonic_status() {
        let expected = TonicSnafu {
            code: StatusCode::Internal,
            msg: "blabla".to_string(),
            tonic_code: Code::Internal,
        }
        .build();

        let status = Status::new(Code::Internal, "blabla");
        let actual: Error = status.into();

        assert_eq!(expected.to_string(), actual.to_string());
    }

    // Metrics that only appear at end-of-stream must still be captured, and
    // the handle must become ready only after the stream is drained.
    #[tokio::test]
    async fn test_query_with_terminal_metrics_tracks_terminal_only_metrics() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "v",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([1, 2])) as VectorRef],
        )
        .unwrap();
        let output = Output::new_with_stream(Box::pin(MockMetricsStream {
            schema,
            batch: Some(batch),
            metrics: RecordBatchMetrics {
                region_watermarks: vec![common_recordbatch::adapter::RegionWatermarkEntry {
                    region_id: 7,
                    watermark: Some(42),
                }],
                ..Default::default()
            },
            terminal_metrics_only: true,
        }));

        let result = OutputWithMetrics::from_output(output);
        let terminal_metrics = result.metrics.clone();
        assert!(!terminal_metrics.is_ready());
        assert!(terminal_metrics.get().is_none());

        let OutputData::Stream(mut stream) = result.output.data else {
            panic!("expected stream output");
        };
        while stream.next().await.is_some() {}

        assert!(terminal_metrics.is_ready());
        assert_eq!(
            terminal_metrics.participating_regions(),
            Some(std::collections::BTreeSet::from([7_u64]))
        );
        assert_eq!(
            terminal_metrics.region_watermark_map(),
            Some(std::collections::HashMap::from([(7_u64, 42_u64)]))
        );
    }

    #[test]
    fn test_parse_terminal_metrics_rejects_invalid_json() {
        assert!(parse_terminal_metrics("{not-json}").is_err());
    }

    // Metrics carried inline on AffectedRows are parsed, and the handle is
    // ready once the stream ends.
    #[tokio::test]
    async fn test_affected_rows_inline_metrics_are_parsed() {
        let output = output_from_flight_message_stream(futures_util::stream::iter(vec![Ok(
            FlightMessage::AffectedRows {
                rows: 3,
                metrics: Some(terminal_metrics_json()),
            },
        )]
            as Vec<Result<FlightMessage>>))
        .await
        .unwrap();

        assert!(matches!(output.output.data, OutputData::AffectedRows(3)));
        assert!(output.metrics.is_ready());
        assert_eq!(
            output.metrics.region_watermark_map(),
            Some(std::collections::HashMap::from([(7, 42)]))
        );
    }

    // Inline metrics followed by a separate Metrics message is rejected.
    #[tokio::test]
    async fn test_affected_rows_inline_metrics_rejects_trailing_metrics() {
        let metrics_json = terminal_metrics_json();
        let err = output_from_flight_message_stream(futures_util::stream::iter(vec![
            Ok(FlightMessage::AffectedRows {
                rows: 3,
                metrics: Some(metrics_json.clone()),
            }),
            Ok(FlightMessage::Metrics(metrics_json)),
        ]
            as Vec<Result<FlightMessage>>))
        .await
        .unwrap_err();

        assert!(
            err.to_string().contains("already carries Metrics"),
            "unexpected error: {err:?}"
        );
    }

    // An invalid Metrics payload after a batch surfaces as a stream error
    // after the batch. The stream then ends, with the handle ready and no
    // metrics recorded.
    #[tokio::test]
    async fn test_invalid_terminal_metrics_after_record_batch_yields_batch_then_error() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "v",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([1])) as VectorRef],
        )
        .unwrap();
        let output = output_from_flight_message_stream(futures_util::stream::iter(vec![
            Ok(FlightMessage::Schema(schema.arrow_schema().clone())),
            Ok(FlightMessage::RecordBatch(batch.into_df_record_batch())),
            Ok(FlightMessage::Metrics("{not-json}".to_string())),
        ]
            as Vec<Result<FlightMessage>>))
        .await
        .unwrap();
        let terminal_metrics = output.metrics.clone();
        let OutputData::Stream(mut record_batch_stream) = output.output.data else {
            panic!("expected stream output");
        };

        let batch = record_batch_stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 1);

        let err = record_batch_stream.next().await.unwrap().unwrap_err();
        assert_eq!("External error", err.to_string());
        assert!(
            format!("{err:?}").contains("Invalid terminal metrics message"),
            "unexpected error: {err:?}"
        );
        assert!(record_batch_stream.next().await.is_none());
        assert!(terminal_metrics.is_ready());
        assert!(terminal_metrics.get().is_none());
    }

    // Metrics messages interleaved with batches do not end the stream; the
    // last Metrics message (watermark 2) is what remains recorded.
    #[tokio::test]
    async fn test_record_batch_stream_continues_after_partial_metrics() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "v",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let first_batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([1])) as VectorRef],
        )
        .unwrap();
        let second_batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([2])) as VectorRef],
        )
        .unwrap();
        let output = output_from_flight_message_stream(futures_util::stream::iter(vec![
            Ok(FlightMessage::Schema(schema.arrow_schema().clone())),
            Ok(FlightMessage::RecordBatch(
                first_batch.into_df_record_batch(),
            )),
            Ok(FlightMessage::Metrics(terminal_metrics_json_with_seq(1))),
            Ok(FlightMessage::RecordBatch(
                second_batch.into_df_record_batch(),
            )),
            Ok(FlightMessage::Metrics(terminal_metrics_json_with_seq(2))),
        ]
            as Vec<Result<FlightMessage>>))
        .await
        .unwrap();
        let terminal_metrics = output.metrics.clone();
        let OutputData::Stream(mut record_batch_stream) = output.output.data else {
            panic!("expected stream output");
        };

        let first_batch = record_batch_stream.next().await.unwrap().unwrap();
        assert_eq!(first_batch.num_rows(), 1);
        let second_batch = record_batch_stream.next().await.unwrap().unwrap();
        assert_eq!(second_batch.num_rows(), 1);
        assert!(record_batch_stream.next().await.is_none());

        assert!(terminal_metrics.is_ready());
        assert_eq!(
            terminal_metrics.region_watermark_map(),
            Some(std::collections::HashMap::from([(7, 2)]))
        );
    }

    // Recorded-but-empty metrics yield empty collections, not `None`.
    #[test]
    fn test_output_metrics_distinguishes_empty_region_watermarks_from_absence() {
        let metrics = OutputMetrics::default();
        metrics.update(Some(RecordBatchMetrics::default()));

        assert_eq!(
            metrics.participating_regions(),
            Some(std::collections::BTreeSet::new())
        );
        assert_eq!(
            metrics.region_watermark_map(),
            Some(std::collections::HashMap::new())
        );
    }
}