1use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::atomic::{AtomicBool, Ordering};
18use std::sync::{Arc, RwLock};
19use std::task::{Context, Poll};
20
21use api::v1::auth_header::AuthScheme;
22use api::v1::ddl_request::Expr as DdlExpr;
23use api::v1::greptime_database_client::GreptimeDatabaseClient;
24use api::v1::greptime_request::Request;
25use api::v1::query_request::Query;
26use api::v1::{
27 AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
28 InsertRequests, QueryRequest, RequestHeader, RowInsertRequests,
29};
30use arc_swap::ArcSwapOption;
31use arrow_flight::{FlightData, Ticket};
32use async_stream::stream;
33use base64::Engine;
34use base64::prelude::BASE64_STANDARD;
35use common_catalog::build_db_string;
36use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
37use common_error::ext::BoxedError;
38use common_grpc::flight::do_put::DoPutResponse;
39use common_grpc::flight::{
40 FLOW_EXTENSIONS_METADATA_KEY, FlightDecoder, FlightMessage, SNAPSHOT_SEQS_METADATA_KEY,
41};
42use common_query::Output;
43use common_recordbatch::adapter::RecordBatchMetrics;
44use common_recordbatch::error::ExternalSnafu;
45use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, RecordBatchStreamWrapper};
46use common_telemetry::tracing::Span;
47use common_telemetry::tracing_context::W3cTrace;
48use common_telemetry::{error, warn};
49use futures::future;
50use futures_util::{Stream, StreamExt, TryStreamExt};
51use prost::Message;
52use snafu::{OptionExt, ResultExt};
53use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
54use tonic::transport::Channel;
55
56use crate::error::{
57 ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
58 InvalidTonicMetadataValueSnafu,
59};
60use crate::{Client, Result, error, from_grpc_response};
61
62type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
63
64type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
65
66#[derive(Debug, Clone, Default)]
71pub struct OutputMetrics {
72 inner: Arc<OutputMetricsInner>,
73}
74
75#[derive(Debug, Default)]
76struct OutputMetricsInner {
77 metrics: RwLock<Option<RecordBatchMetrics>>,
78 ready: AtomicBool,
79}
80
81impl OutputMetrics {
82 fn new() -> Self {
83 Self::default()
84 }
85
86 pub fn update(&self, metrics: Option<RecordBatchMetrics>) {
88 *self.inner.metrics.write().unwrap() = metrics;
89 }
90
91 pub fn mark_ready(&self) {
93 let _ = self
94 .inner
95 .ready
96 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire);
97 }
98
99 pub fn is_ready(&self) -> bool {
103 self.inner.ready.load(Ordering::Acquire)
104 }
105
106 pub fn get(&self) -> Option<RecordBatchMetrics> {
108 self.inner.metrics.read().unwrap().clone()
109 }
110
111 pub fn region_watermark_map(&self) -> Option<std::collections::HashMap<u64, u64>> {
117 Some(
118 self.get()?
119 .region_watermarks
120 .into_iter()
121 .filter_map(|entry| entry.watermark.map(|seq| (entry.region_id, seq)))
122 .collect::<std::collections::HashMap<_, _>>(),
123 )
124 }
125
126 pub fn participating_regions(&self) -> Option<std::collections::BTreeSet<u64>> {
129 Some(
130 self.get()?
131 .region_watermarks
132 .into_iter()
133 .map(|entry| entry.region_id)
134 .collect::<std::collections::BTreeSet<_>>(),
135 )
136 }
137}
138
139#[derive(Debug)]
145pub struct OutputWithMetrics {
146 pub output: Output,
147 pub metrics: OutputMetrics,
148}
149
150impl OutputWithMetrics {
151 pub fn from_output(output: Output) -> Self {
156 let terminal_metrics = OutputMetrics::new();
157 let output = attach_terminal_metrics(output, &terminal_metrics);
158 Self {
159 output,
160 metrics: terminal_metrics,
161 }
162 }
163
164 pub fn region_watermark_map(&self) -> Option<std::collections::HashMap<u64, u64>> {
166 self.metrics.region_watermark_map()
167 }
168
169 pub fn participating_regions(&self) -> Option<std::collections::BTreeSet<u64>> {
171 self.metrics.participating_regions()
172 }
173
174 pub fn into_output(self) -> Output {
176 self.output
177 }
178}
179
180fn parse_terminal_metrics(metrics_json: &str) -> Result<RecordBatchMetrics> {
181 serde_json::from_str(metrics_json).map_err(|e| {
182 IllegalFlightMessagesSnafu {
183 reason: format!("Invalid terminal metrics message: {e}"),
184 }
185 .build()
186 })
187}
188
189struct StreamWithMetrics {
190 stream: common_recordbatch::SendableRecordBatchStream,
191 metrics: OutputMetrics,
192}
193
194impl StreamWithMetrics {
195 fn new(stream: common_recordbatch::SendableRecordBatchStream, metrics: OutputMetrics) -> Self {
196 Self { stream, metrics }
197 }
198
199 fn sync_terminal_metrics(&self) {
200 self.metrics.update(self.stream.metrics());
201 }
202}
203
204impl RecordBatchStream for StreamWithMetrics {
205 fn name(&self) -> &str {
206 self.stream.name()
207 }
208
209 fn schema(&self) -> datatypes::schema::SchemaRef {
210 self.stream.schema()
211 }
212
213 fn output_ordering(&self) -> Option<&[OrderOption]> {
214 self.stream.output_ordering()
215 }
216
217 fn metrics(&self) -> Option<RecordBatchMetrics> {
218 self.sync_terminal_metrics();
219 self.metrics.get()
220 }
221}
222
223impl Stream for StreamWithMetrics {
224 type Item = common_recordbatch::error::Result<RecordBatch>;
225
226 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
227 let polled = Pin::new(&mut self.stream).poll_next(cx);
228 if let Poll::Ready(None) = &polled {
229 self.sync_terminal_metrics();
230 self.metrics.mark_ready();
231 }
232 polled
233 }
234
235 fn size_hint(&self) -> (usize, Option<usize>) {
236 self.stream.size_hint()
237 }
238}
239
240fn attach_terminal_metrics(output: Output, terminal_metrics: &OutputMetrics) -> Output {
241 let Output { data, meta } = output;
242 let data = match data {
243 common_query::OutputData::Stream(stream) => {
244 terminal_metrics.update(stream.metrics());
245 common_query::OutputData::Stream(Box::pin(StreamWithMetrics::new(
246 stream,
247 terminal_metrics.clone(),
248 )))
249 }
250 other => {
251 terminal_metrics.mark_ready();
252 other
253 }
254 };
255 Output::new(data, meta)
256}
257
258async fn output_from_flight_message_stream<S>(
259 mut flight_message_stream: S,
260) -> Result<OutputWithMetrics>
261where
262 S: Stream<Item = Result<FlightMessage>> + Send + Unpin + 'static,
263{
264 let Some(first_flight_message) = flight_message_stream.next().await else {
265 return IllegalFlightMessagesSnafu {
266 reason: "Expect the response not to be empty",
267 }
268 .fail();
269 };
270
271 let first_flight_message = first_flight_message?;
272
273 match first_flight_message {
274 FlightMessage::AffectedRows { rows, metrics } => {
275 let terminal_metrics = OutputMetrics::new();
276 if let Some(metrics) = metrics {
277 terminal_metrics.update(Some(parse_terminal_metrics(&metrics)?));
278 }
279 let next_message = flight_message_stream.next().await.transpose()?;
280 match next_message {
281 None => terminal_metrics.mark_ready(),
282 Some(FlightMessage::Metrics(s)) if terminal_metrics.get().is_none() => {
283 terminal_metrics.update(Some(parse_terminal_metrics(&s)?));
284 terminal_metrics.mark_ready();
285 }
286 Some(FlightMessage::Metrics(_)) => {
287 return IllegalFlightMessagesSnafu {
288 reason: "'AffectedRows' Flight metadata already carries Metrics and cannot be followed by another Metrics message",
289 }
290 .fail();
291 }
292 Some(other) => {
293 return IllegalFlightMessagesSnafu {
294 reason: format!(
295 "'AffectedRows' Flight message can only be followed by a Metrics message, got {other:?}"
296 ),
297 }
298 .fail();
299 }
300 }
301 Ok(OutputWithMetrics {
302 output: Output::new_with_affected_rows(rows),
303 metrics: terminal_metrics,
304 })
305 }
306 FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => IllegalFlightMessagesSnafu {
307 reason: "The first flight message cannot be a RecordBatch or Metrics message",
308 }
309 .fail(),
310 FlightMessage::Schema(schema) => {
311 let metrics = Arc::new(ArcSwapOption::from(None));
312 let metrics_ref = metrics.clone();
313 let schema = Arc::new(
314 datatypes::schema::Schema::try_from(schema).context(error::ConvertSchemaSnafu)?,
315 );
316 let schema_cloned = schema.clone();
317 let stream = Box::pin(stream!({
318 while let Some(flight_message_item) = flight_message_stream.next().await {
319 let flight_message = match flight_message_item {
320 Ok(message) => message,
321 Err(e) => {
322 yield Err(BoxedError::new(e)).context(ExternalSnafu);
323 break;
324 }
325 };
326
327 match flight_message {
328 FlightMessage::RecordBatch(arrow_batch) => {
329 yield Ok(RecordBatch::from_df_record_batch(
330 schema_cloned.clone(),
331 arrow_batch,
332 ))
333 }
334 FlightMessage::Metrics(s) => {
335 match parse_terminal_metrics(&s) {
336 Ok(m) => {
337 metrics_ref.swap(Some(Arc::new(m)));
338 }
339 Err(e) => {
340 yield Err(BoxedError::new(e)).context(ExternalSnafu);
341 }
342 };
343 }
344 FlightMessage::AffectedRows { .. } | FlightMessage::Schema(_) => {
345 yield IllegalFlightMessagesSnafu {
346 reason: format!(
347 "A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}",
348 flight_message
349 )
350 }
351 .fail()
352 .map_err(BoxedError::new)
353 .context(ExternalSnafu);
354 break;
355 }
356 }
357 }
358 }));
359 let record_batch_stream = RecordBatchStreamWrapper {
360 schema,
361 stream,
362 output_ordering: None,
363 metrics,
364 span: Span::current(),
365 };
366 Ok(OutputWithMetrics::from_output(Output::new_with_stream(
367 Box::pin(record_batch_stream),
368 )))
369 }
370 }
371}
372
373#[derive(Clone, Debug, Default)]
374pub struct Database {
375 catalog: String,
379 schema: String,
380 dbname: String,
383 timezone: String,
386
387 client: Client,
388 ctx: FlightContext,
389}
390
391pub struct DatabaseClient {
392 pub addr: String,
393 pub inner: GreptimeDatabaseClient<Channel>,
394}
395
396impl DatabaseClient {
397 pub fn inspect_err<'a>(&'a self, context: &'a str) -> impl Fn(&tonic::Status) + 'a {
399 let addr = &self.addr;
400 move |status| {
401 error!("Failed to {context} request, peer: {addr}, status: {status:?}");
402 }
403 }
404}
405
406fn make_database_client(client: &Client) -> Result<DatabaseClient> {
407 let (addr, channel) = client.find_channel()?;
408 Ok(DatabaseClient {
409 addr,
410 inner: GreptimeDatabaseClient::new(channel)
411 .max_decoding_message_size(client.max_grpc_recv_message_size())
412 .max_encoding_message_size(client.max_grpc_send_message_size()),
413 })
414}
415
416impl Database {
417 pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
419 Self {
420 catalog: catalog.into(),
421 schema: schema.into(),
422 dbname: String::default(),
423 timezone: String::default(),
424 client,
425 ctx: FlightContext::default(),
426 }
427 }
428
429 pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
437 Self {
438 catalog: String::default(),
439 schema: String::default(),
440 timezone: String::default(),
441 dbname: dbname.into(),
442 client,
443 ctx: FlightContext::default(),
444 }
445 }
446
447 pub fn set_catalog(&mut self, catalog: impl Into<String>) {
449 self.catalog = catalog.into();
450 }
451
452 fn catalog_or_default(&self) -> &str {
453 if self.catalog.is_empty() {
454 DEFAULT_CATALOG_NAME
455 } else {
456 &self.catalog
457 }
458 }
459
460 pub fn set_schema(&mut self, schema: impl Into<String>) {
462 self.schema = schema.into();
463 }
464
465 fn schema_or_default(&self) -> &str {
466 if self.schema.is_empty() {
467 DEFAULT_SCHEMA_NAME
468 } else {
469 &self.schema
470 }
471 }
472
473 pub fn set_timezone(&mut self, timezone: impl Into<String>) {
475 self.timezone = timezone.into();
476 }
477
478 pub fn set_auth(&mut self, auth: AuthScheme) {
480 self.ctx.auth_header = Some(AuthHeader {
481 auth_scheme: Some(auth),
482 });
483 }
484
485 pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
487 self.handle(Request::Inserts(requests)).await
488 }
489
490 pub async fn insert_with_hints(
492 &self,
493 requests: InsertRequests,
494 hints: &[(&str, &str)],
495 ) -> Result<u32> {
496 let mut client = make_database_client(&self.client)?;
497 let request = self.to_rpc_request(Request::Inserts(requests));
498
499 let mut request = tonic::Request::new(request);
500 let metadata = request.metadata_mut();
501 Self::put_hints(metadata, hints)?;
502
503 let response = client
504 .inner
505 .handle(request)
506 .await
507 .inspect_err(client.inspect_err("insert_with_hints"))?
508 .into_inner();
509 from_grpc_response(response)
510 }
511
512 pub async fn row_inserts(&self, requests: RowInsertRequests) -> Result<u32> {
514 self.handle(Request::RowInserts(requests)).await
515 }
516
517 pub async fn row_inserts_with_hints(
519 &self,
520 requests: RowInsertRequests,
521 hints: &[(&str, &str)],
522 ) -> Result<u32> {
523 let mut client = make_database_client(&self.client)?;
524 let request = self.to_rpc_request(Request::RowInserts(requests));
525
526 let mut request = tonic::Request::new(request);
527 let metadata = request.metadata_mut();
528 Self::put_hints(metadata, hints)?;
529
530 let response = client
531 .inner
532 .handle(request)
533 .await
534 .inspect_err(client.inspect_err("row_inserts_with_hints"))?
535 .into_inner();
536 from_grpc_response(response)
537 }
538
539 fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
540 let Some(value) = hints
541 .iter()
542 .map(|(k, v)| format!("{}={}", k, v))
543 .reduce(|a, b| format!("{},{}", a, b))
544 else {
545 return Ok(());
546 };
547
548 let key = AsciiMetadataKey::from_static("x-greptime-hints");
549 let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
550 metadata.insert(key, value);
551 Ok(())
552 }
553
554 fn put_flow_extensions(
555 metadata: &mut MetadataMap,
556 flow_extensions: &[(&str, &str)],
557 ) -> Result<()> {
558 if flow_extensions.is_empty() {
559 return Ok(());
560 }
561
562 let value = serde_json::to_string(&flow_extensions.to_vec())
563 .expect("flow extension pairs should serialize");
564 Self::put_metadata_value(metadata, FLOW_EXTENSIONS_METADATA_KEY, value)
565 }
566
567 fn put_snapshot_seqs(
568 metadata: &mut MetadataMap,
569 snapshot_seqs: &std::collections::HashMap<u64, u64>,
570 ) -> Result<()> {
571 if snapshot_seqs.is_empty() {
572 return Ok(());
573 }
574
575 let value =
576 serde_json::to_string(snapshot_seqs).expect("snapshot sequence map should serialize");
577 Self::put_metadata_value(metadata, SNAPSHOT_SEQS_METADATA_KEY, value)
578 }
579
580 fn put_metadata_value(
581 metadata: &mut MetadataMap,
582 key: &'static str,
583 value: String,
584 ) -> Result<()> {
585 let key = AsciiMetadataKey::from_static(key);
586 let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
587 metadata.insert(key, value);
588 Ok(())
589 }
590
591 pub async fn handle(&self, request: Request) -> Result<u32> {
593 let mut client = make_database_client(&self.client)?;
594 let request = self.to_rpc_request(request);
595 let response = client
596 .inner
597 .handle(request)
598 .await
599 .inspect_err(client.inspect_err("handle"))?
600 .into_inner();
601 from_grpc_response(response)
602 }
603
604 pub async fn handle_with_retry(
607 &self,
608 request: Request,
609 max_retries: u32,
610 hints: &[(&str, &str)],
611 ) -> Result<u32> {
612 let mut client = make_database_client(&self.client)?;
613 let mut retries = 0;
614
615 let request = self.to_rpc_request(request);
616
617 loop {
618 let mut tonic_request = tonic::Request::new(request.clone());
619 let metadata = tonic_request.metadata_mut();
620 Self::put_hints(metadata, hints)?;
621 let raw_response = client
622 .inner
623 .handle(tonic_request)
624 .await
625 .inspect_err(client.inspect_err("handle"));
626 match (raw_response, retries < max_retries) {
627 (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
628 (Err(err), true) => {
629 if is_grpc_retryable(&err) {
631 retries += 1;
633 warn!("Retrying {} times with error = {:?}", retries, err);
634 continue;
635 } else {
636 error!(
637 err; "Failed to send request to grpc handle, retries = {}, not retryable error, aborting",
638 retries
639 );
640 return Err(err.into());
641 }
642 }
643 (Err(err), false) => {
644 error!(
645 err; "Failed to send request to grpc handle after {} retries",
646 retries,
647 );
648 return Err(err.into());
649 }
650 }
651 }
652 }
653
654 #[inline]
655 fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
656 GreptimeRequest {
657 header: Some(RequestHeader {
658 catalog: self.catalog.clone(),
659 schema: self.schema.clone(),
660 authorization: self.ctx.auth_header.clone(),
661 dbname: self.dbname.clone(),
662 timezone: self.timezone.clone(),
663 tracing_context: W3cTrace::new(),
665 }),
666 request: Some(request),
667 }
668 }
669
670 pub async fn sql<S>(&self, sql: S) -> Result<Output>
672 where
673 S: AsRef<str>,
674 {
675 self.sql_with_hint(sql, &[]).await
676 }
677
678 pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
680 where
681 S: AsRef<str>,
682 {
683 let request = Request::Query(QueryRequest {
684 query: Some(Query::Sql(sql.as_ref().to_string())),
685 });
686 self.do_get(request, hints, &[], &Default::default())
687 .await
688 .map(OutputWithMetrics::into_output)
689 }
690
691 pub async fn sql_with_terminal_metrics<S>(
696 &self,
697 sql: S,
698 hints: &[(&str, &str)],
699 ) -> Result<OutputWithMetrics>
700 where
701 S: AsRef<str>,
702 {
703 self.query_with_terminal_metrics_and_flow_extensions(
704 QueryRequest {
705 query: Some(Query::Sql(sql.as_ref().to_string())),
706 },
707 hints,
708 &[],
709 &Default::default(),
710 )
711 .await
712 }
713
714 pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
716 self.query_with_terminal_metrics_and_flow_extensions(
717 QueryRequest {
718 query: Some(Query::LogicalPlan(logical_plan)),
719 },
720 &[],
721 &[],
722 &Default::default(),
723 )
724 .await
725 .map(OutputWithMetrics::into_output)
726 }
727
728 pub async fn query_with_terminal_metrics_and_flow_extensions(
733 &self,
734 request: QueryRequest,
735 hints: &[(&str, &str)],
736 flow_extensions: &[(&str, &str)],
737 snapshot_seqs: &std::collections::HashMap<u64, u64>,
738 ) -> Result<OutputWithMetrics> {
739 self.do_get(
740 Request::Query(request),
741 hints,
742 flow_extensions,
743 snapshot_seqs,
744 )
745 .await
746 }
747
748 pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
750 let request = Request::Ddl(DdlRequest {
751 expr: Some(DdlExpr::CreateTable(expr)),
752 });
753 self.do_get(request, &[], &[], &Default::default())
754 .await
755 .map(OutputWithMetrics::into_output)
756 }
757
758 pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
760 let request = Request::Ddl(DdlRequest {
761 expr: Some(DdlExpr::AlterTable(expr)),
762 });
763 self.do_get(request, &[], &[], &Default::default())
764 .await
765 .map(OutputWithMetrics::into_output)
766 }
767
768 async fn do_get(
769 &self,
770 request: Request,
771 hints: &[(&str, &str)],
772 flow_extensions: &[(&str, &str)],
773 snapshot_seqs: &std::collections::HashMap<u64, u64>,
774 ) -> Result<OutputWithMetrics> {
775 let request = self.to_rpc_request(request);
776 let request = Ticket {
777 ticket: request.encode_to_vec().into(),
778 };
779
780 let mut request = tonic::Request::new(request);
781 let metadata = request.metadata_mut();
782 Self::put_hints(metadata, hints)?;
783 Self::put_flow_extensions(metadata, flow_extensions)?;
784 Self::put_snapshot_seqs(metadata, snapshot_seqs)?;
785
786 let mut client = self.client.make_flight_client(false, false)?;
787
788 let response = client.mut_inner().do_get(request).await.or_else(|e| {
789 let tonic_code = e.code();
790 let e: Error = e.into();
791 error!(
792 "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
793 client.addr(),
794 tonic_code,
795 e
796 );
797 Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
798 addr: client.addr().to_string(),
799 tonic_code,
800 })
801 })?;
802
803 let flight_data_stream = response.into_inner();
804 let mut decoder = FlightDecoder::default();
805
806 let flight_message_stream = flight_data_stream.map(move |flight_data| {
807 flight_data
808 .map_err(Error::from)
809 .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
810 .context(IllegalFlightMessagesSnafu {
811 reason: "none message",
812 })
813 });
814
815 output_from_flight_message_stream(flight_message_stream).await
816 }
817
818 pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
821 let mut request = tonic::Request::new(stream);
822
823 if let Some(AuthHeader {
824 auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
825 }) = &self.ctx.auth_header
826 {
827 let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
828 let value = MetadataValue::from_str(&format!("Basic {encoded}"))
829 .context(InvalidTonicMetadataValueSnafu)?;
830 request.metadata_mut().insert("x-greptime-auth", value);
831 }
832
833 let db_to_put = if !self.dbname.is_empty() {
834 &self.dbname
835 } else {
836 &build_db_string(self.catalog_or_default(), self.schema_or_default())
837 };
838 request.metadata_mut().insert(
839 "x-greptime-db-name",
840 MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
841 );
842
843 let mut client = self.client.make_flight_client(false, false)?;
844 let response = client.mut_inner().do_put(request).await?;
845 let response = response
846 .into_inner()
847 .map_err(Into::into)
848 .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
849 .boxed();
850 Ok(response)
851 }
852}
853
854pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
856 matches!(err.code(), tonic::Code::Unavailable)
857}
858
859#[derive(Default, Debug, Clone)]
860struct FlightContext {
861 auth_header: Option<AuthHeader>,
862}
863
864#[cfg(test)]
865mod tests {
866 use std::sync::Arc;
867 use std::task::{Context, Poll};
868
869 use api::v1::auth_header::AuthScheme;
870 use api::v1::{AuthHeader, Basic};
871 use common_error::ext::{ErrorExt, RetryHint};
872 use common_error::status_code::StatusCode;
873 use common_error::{GREPTIME_DB_HEADER_ERROR_CODE, GREPTIME_DB_HEADER_ERROR_RETRY_HINT};
874 use common_query::OutputData;
875 use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream};
876 use datatypes::prelude::{ConcreteDataType, VectorRef};
877 use datatypes::schema::{ColumnSchema, Schema};
878 use datatypes::vectors::Int32Vector;
879 use futures_util::StreamExt;
880 use tonic::codegen::http::{HeaderMap, HeaderValue};
881 use tonic::metadata::MetadataMap;
882 use tonic::{Code, Status};
883
884 use super::*;
885 use crate::error::TonicSnafu;
886
887 struct MockMetricsStream {
888 schema: datatypes::schema::SchemaRef,
889 batch: Option<RecordBatch>,
890 metrics: RecordBatchMetrics,
891 terminal_metrics_only: bool,
892 }
893
894 impl Stream for MockMetricsStream {
895 type Item = common_recordbatch::error::Result<RecordBatch>;
896
897 fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
898 Poll::Ready(self.batch.take().map(Ok))
899 }
900 }
901
902 impl RecordBatchStream for MockMetricsStream {
903 fn name(&self) -> &str {
904 "MockMetricsStream"
905 }
906
907 fn schema(&self) -> datatypes::schema::SchemaRef {
908 self.schema.clone()
909 }
910
911 fn output_ordering(&self) -> Option<&[OrderOption]> {
912 None
913 }
914
915 fn metrics(&self) -> Option<RecordBatchMetrics> {
916 if self.terminal_metrics_only && self.batch.is_some() {
917 return None;
918 }
919 Some(self.metrics.clone())
920 }
921 }
922
923 fn terminal_metrics_json() -> String {
924 terminal_metrics_json_with_seq(42)
925 }
926
927 fn terminal_metrics_json_with_seq(seq: u64) -> String {
928 serde_json::to_string(&RecordBatchMetrics {
929 region_watermarks: vec![common_recordbatch::adapter::RegionWatermarkEntry {
930 region_id: 7,
931 watermark: Some(seq),
932 }],
933 ..Default::default()
934 })
935 .unwrap()
936 }
937
/// Flow extension values containing commas (here: nested JSON) must round-trip intact.
#[test]
fn test_put_flow_extensions_preserves_comma_bearing_values() {
    let extensions = [
        ("flow.return_region_seq", "true"),
        ("flow.incremental_after_seqs", r#"{"1":10,"2":20}"#),
    ];
    let mut metadata = MetadataMap::new();
    Database::put_flow_extensions(&mut metadata, &extensions).unwrap();

    let raw = metadata
        .get(FLOW_EXTENSIONS_METADATA_KEY)
        .unwrap()
        .to_str()
        .unwrap();
    let decoded: Vec<(String, String)> = serde_json::from_str(raw).unwrap();

    let expected: Vec<(String, String)> = extensions
        .iter()
        .map(|(k, v)| (k.to_string(), v.to_string()))
        .collect();
    assert_eq!(decoded, expected);
}
967
/// Sequence numbers beyond 2^53 must survive the JSON round trip without precision loss.
#[test]
fn test_put_snapshot_seqs_preserves_u64_precision() {
    let mut snapshot_seqs = std::collections::HashMap::new();
    snapshot_seqs.insert(u64::MAX, u64::MAX - 1);
    snapshot_seqs.insert(9_007_199_254_740_993_u64, 9_007_199_254_740_995_u64);

    let mut metadata = MetadataMap::new();
    Database::put_snapshot_seqs(&mut metadata, &snapshot_seqs).unwrap();

    let raw = metadata
        .get(SNAPSHOT_SEQS_METADATA_KEY)
        .unwrap()
        .to_str()
        .unwrap();
    let decoded: std::collections::HashMap<u64, u64> = serde_json::from_str(raw).unwrap();
    assert_eq!(decoded, snapshot_seqs);
}
986
/// A default context has no auth header; assigning a basic scheme makes it observable.
#[test]
fn test_flight_ctx() {
    let mut ctx = FlightContext::default();
    assert!(ctx.auth_header.is_none());

    let credentials = Basic {
        username: "u".to_string(),
        password: "p".to_string(),
    };
    ctx.auth_header = Some(AuthHeader {
        auth_scheme: Some(AuthScheme::Basic(credentials)),
    });

    let is_basic = matches!(
        ctx.auth_header,
        Some(AuthHeader {
            auth_scheme: Some(AuthScheme::Basic(_)),
        })
    );
    assert!(is_basic);
}
1008
/// A plain `Internal` status converts to a non-retryable `Tonic` error with the same message.
#[test]
fn test_from_tonic_status() {
    let actual: Error = Status::new(Code::Internal, "blabla").into();

    let expected = TonicSnafu {
        code: StatusCode::Internal,
        msg: "blabla".to_string(),
        tonic_code: Code::Internal,
        retry_hint: RetryHint::NonRetryable,
    }
    .build();

    assert_eq!(expected.to_string(), actual.to_string());
    assert_eq!(expected.retry_hint(), actual.retry_hint());
    assert_eq!(expected.should_retry(), actual.should_retry());
}
1026
/// An explicit retry-hint header on the status is honored by the converted error.
#[test]
fn test_from_tonic_status_with_retry_hint() {
    let error_code = HeaderValue::from(StatusCode::Internal as u32);
    let retry_hint = HeaderValue::from_static(RetryHint::Retryable.as_str());

    let mut headers = HeaderMap::new();
    headers.insert(GREPTIME_DB_HEADER_ERROR_CODE, error_code);
    headers.insert(GREPTIME_DB_HEADER_ERROR_RETRY_HINT, retry_hint);

    let status =
        Status::with_metadata(Code::Internal, "blabla", MetadataMap::from_headers(headers));
    let actual: Error = status.into();

    assert_eq!(actual.retry_hint(), RetryHint::Retryable);
    assert!(actual.should_retry());
}
1046
/// Without a retry-hint header, an `InvalidArguments` error code is treated as non-retryable.
#[test]
fn test_from_tonic_status_fallback() {
    let error_code = HeaderValue::from(StatusCode::InvalidArguments as u32);

    let mut headers = HeaderMap::new();
    headers.insert(GREPTIME_DB_HEADER_ERROR_CODE, error_code);

    let status =
        Status::with_metadata(Code::Internal, "blabla", MetadataMap::from_headers(headers));
    let actual: Error = status.into();

    assert_eq!(actual.retry_hint(), RetryHint::NonRetryable);
    assert!(!actual.should_retry());
}
1062
/// An `Unavailable` status converts to an error that should be retried.
#[test]
fn test_should_retry_preserves_transport_retry() {
    let actual: Error = Status::new(Code::Unavailable, "blabla").into();

    assert!(actual.should_retry());
}
1070
1071 #[tokio::test]
1072 async fn test_query_with_terminal_metrics_tracks_terminal_only_metrics() {
1073 let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
1074 "v",
1075 ConcreteDataType::int32_datatype(),
1076 false,
1077 )]));
1078 let batch = RecordBatch::new(
1079 schema.clone(),
1080 vec![Arc::new(Int32Vector::from_slice([1, 2])) as VectorRef],
1081 )
1082 .unwrap();
1083 let output = Output::new_with_stream(Box::pin(MockMetricsStream {
1084 schema,
1085 batch: Some(batch),
1086 metrics: RecordBatchMetrics {
1087 region_watermarks: vec![common_recordbatch::adapter::RegionWatermarkEntry {
1088 region_id: 7,
1089 watermark: Some(42),
1090 }],
1091 ..Default::default()
1092 },
1093 terminal_metrics_only: true,
1094 }));
1095
1096 let result = OutputWithMetrics::from_output(output);
1097 let terminal_metrics = result.metrics.clone();
1098 assert!(!terminal_metrics.is_ready());
1099 assert!(terminal_metrics.get().is_none());
1100
1101 let OutputData::Stream(mut stream) = result.output.data else {
1102 panic!("expected stream output");
1103 };
1104 while stream.next().await.is_some() {}
1105
1106 assert!(terminal_metrics.is_ready());
1107 assert_eq!(
1108 terminal_metrics.participating_regions(),
1109 Some(std::collections::BTreeSet::from([7_u64]))
1110 );
1111 assert_eq!(
1112 terminal_metrics.region_watermark_map(),
1113 Some(std::collections::HashMap::from([(7_u64, 42_u64)]))
1114 );
1115 }
1116
/// Malformed JSON is reported as an error instead of being silently accepted.
#[test]
fn test_parse_terminal_metrics_rejects_invalid_json() {
    let parsed = parse_terminal_metrics("{not-json}");
    assert!(parsed.is_err());
}
1121
1122 #[tokio::test]
1123 async fn test_affected_rows_inline_metrics_are_parsed() {
1124 let output = output_from_flight_message_stream(futures_util::stream::iter(vec![Ok(
1125 FlightMessage::AffectedRows {
1126 rows: 3,
1127 metrics: Some(terminal_metrics_json()),
1128 },
1129 )]
1130 as Vec<Result<FlightMessage>>))
1131 .await
1132 .unwrap();
1133
1134 assert!(matches!(output.output.data, OutputData::AffectedRows(3)));
1135 assert!(output.metrics.is_ready());
1136 assert_eq!(
1137 output.metrics.region_watermark_map(),
1138 Some(std::collections::HashMap::from([(7, 42)]))
1139 );
1140 }
1141
1142 #[tokio::test]
1143 async fn test_affected_rows_inline_metrics_rejects_trailing_metrics() {
1144 let metrics_json = terminal_metrics_json();
1145 let err = output_from_flight_message_stream(futures_util::stream::iter(vec![
1146 Ok(FlightMessage::AffectedRows {
1147 rows: 3,
1148 metrics: Some(metrics_json.clone()),
1149 }),
1150 Ok(FlightMessage::Metrics(metrics_json)),
1151 ]
1152 as Vec<Result<FlightMessage>>))
1153 .await
1154 .unwrap_err();
1155
1156 assert!(
1157 err.to_string().contains("already carries Metrics"),
1158 "unexpected error: {err:?}"
1159 );
1160 }
1161
1162 #[tokio::test]
1163 async fn test_invalid_terminal_metrics_after_record_batch_yields_batch_then_error() {
1164 let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
1165 "v",
1166 ConcreteDataType::int32_datatype(),
1167 false,
1168 )]));
1169 let batch = RecordBatch::new(
1170 schema.clone(),
1171 vec![Arc::new(Int32Vector::from_slice([1])) as VectorRef],
1172 )
1173 .unwrap();
1174 let output = output_from_flight_message_stream(futures_util::stream::iter(vec![
1175 Ok(FlightMessage::Schema(schema.arrow_schema().clone())),
1176 Ok(FlightMessage::RecordBatch(batch.into_df_record_batch())),
1177 Ok(FlightMessage::Metrics("{not-json}".to_string())),
1178 ]
1179 as Vec<Result<FlightMessage>>))
1180 .await
1181 .unwrap();
1182 let terminal_metrics = output.metrics.clone();
1183 let OutputData::Stream(mut record_batch_stream) = output.output.data else {
1184 panic!("expected stream output");
1185 };
1186
1187 let batch = record_batch_stream.next().await.unwrap().unwrap();
1188 assert_eq!(batch.num_rows(), 1);
1189
1190 let err = record_batch_stream.next().await.unwrap().unwrap_err();
1191 assert_eq!("External error", err.to_string());
1192 assert!(
1193 format!("{err:?}").contains("Invalid terminal metrics message"),
1194 "unexpected error: {err:?}"
1195 );
1196 assert!(record_batch_stream.next().await.is_none());
1197 assert!(terminal_metrics.is_ready());
1198 assert!(terminal_metrics.get().is_none());
1199 }
1200
1201 #[tokio::test]
1202 async fn test_record_batch_stream_continues_after_partial_metrics() {
1203 let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
1204 "v",
1205 ConcreteDataType::int32_datatype(),
1206 false,
1207 )]));
1208 let first_batch = RecordBatch::new(
1209 schema.clone(),
1210 vec![Arc::new(Int32Vector::from_slice([1])) as VectorRef],
1211 )
1212 .unwrap();
1213 let second_batch = RecordBatch::new(
1214 schema.clone(),
1215 vec![Arc::new(Int32Vector::from_slice([2])) as VectorRef],
1216 )
1217 .unwrap();
1218 let output = output_from_flight_message_stream(futures_util::stream::iter(vec![
1219 Ok(FlightMessage::Schema(schema.arrow_schema().clone())),
1220 Ok(FlightMessage::RecordBatch(
1221 first_batch.into_df_record_batch(),
1222 )),
1223 Ok(FlightMessage::Metrics(terminal_metrics_json_with_seq(1))),
1224 Ok(FlightMessage::RecordBatch(
1225 second_batch.into_df_record_batch(),
1226 )),
1227 Ok(FlightMessage::Metrics(terminal_metrics_json_with_seq(2))),
1228 ]
1229 as Vec<Result<FlightMessage>>))
1230 .await
1231 .unwrap();
1232 let terminal_metrics = output.metrics.clone();
1233 let OutputData::Stream(mut record_batch_stream) = output.output.data else {
1234 panic!("expected stream output");
1235 };
1236
1237 let first_batch = record_batch_stream.next().await.unwrap().unwrap();
1238 assert_eq!(first_batch.num_rows(), 1);
1239 let second_batch = record_batch_stream.next().await.unwrap().unwrap();
1240 assert_eq!(second_batch.num_rows(), 1);
1241 assert!(record_batch_stream.next().await.is_none());
1242
1243 assert!(terminal_metrics.is_ready());
1244 assert_eq!(
1245 terminal_metrics.region_watermark_map(),
1246 Some(std::collections::HashMap::from([(7, 2)]))
1247 );
1248 }
1249
/// Stored metrics without watermark entries yield `Some(empty)`, not `None`.
#[test]
fn test_output_metrics_distinguishes_empty_region_watermarks_from_absence() {
    let metrics = OutputMetrics::default();
    metrics.update(Some(RecordBatchMetrics::default()));

    let regions = metrics.participating_regions();
    assert_eq!(regions, Some(std::collections::BTreeSet::new()));

    let watermarks = metrics.region_watermark_map();
    assert_eq!(watermarks, Some(std::collections::HashMap::new()));
}
1264}