// client/database.rs
1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::atomic::{AtomicBool, Ordering};
18use std::sync::{Arc, RwLock};
19use std::task::{Context, Poll};
20
21use api::v1::auth_header::AuthScheme;
22use api::v1::ddl_request::Expr as DdlExpr;
23use api::v1::greptime_database_client::GreptimeDatabaseClient;
24use api::v1::greptime_request::Request;
25use api::v1::query_request::Query;
26use api::v1::{
27    AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
28    InsertRequests, QueryRequest, RequestHeader, RowInsertRequests,
29};
30use arc_swap::ArcSwapOption;
31use arrow_flight::{FlightData, Ticket};
32use async_stream::stream;
33use base64::Engine;
34use base64::prelude::BASE64_STANDARD;
35use common_catalog::build_db_string;
36use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
37use common_error::ext::BoxedError;
38use common_grpc::flight::do_put::DoPutResponse;
39use common_grpc::flight::{FLOW_EXTENSIONS_METADATA_KEY, FlightDecoder, FlightMessage};
40use common_query::Output;
41use common_recordbatch::adapter::RecordBatchMetrics;
42use common_recordbatch::error::ExternalSnafu;
43use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, RecordBatchStreamWrapper};
44use common_telemetry::tracing::Span;
45use common_telemetry::tracing_context::W3cTrace;
46use common_telemetry::{error, warn};
47use futures::future;
48use futures_util::{Stream, StreamExt, TryStreamExt};
49use prost::Message;
50use snafu::{OptionExt, ResultExt};
51use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
52use tonic::transport::Channel;
53
54use crate::error::{
55    ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
56    InvalidTonicMetadataValueSnafu,
57};
58use crate::{Client, Result, error, from_grpc_response};
59
/// A boxed stream of raw Arrow Flight data frames, sendable across tasks.
type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;

/// A boxed stream of responses produced by an Arrow Flight "DoPut" call.
type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
63
/// Terminal metrics associated with a query output.
///
/// For streaming outputs, metrics are only final after the stream is fully
/// drained and [`Self::is_ready`] returns `true`.
#[derive(Debug, Clone, Default)]
pub struct OutputMetrics {
    // Shared state: clones handed to stream wrappers and callers observe the
    // same metrics snapshot and readiness flag.
    inner: Arc<OutputMetricsInner>,
}
72
/// Shared state behind [`OutputMetrics`]: the latest metrics snapshot plus a
/// one-way "final" flag.
#[derive(Debug, Default)]
struct OutputMetricsInner {
    // Latest metrics snapshot; `None` until the first update.
    metrics: RwLock<Option<RecordBatchMetrics>>,
    // Set once (false -> true) when the snapshot becomes final.
    ready: AtomicBool,
}
78
79impl OutputMetrics {
80    fn new() -> Self {
81        Self::default()
82    }
83
84    /// Replaces the current terminal metrics snapshot.
85    pub fn update(&self, metrics: Option<RecordBatchMetrics>) {
86        *self.inner.metrics.write().unwrap() = metrics;
87    }
88
89    /// Marks the terminal metrics as final for this output.
90    pub fn mark_ready(&self) {
91        let _ = self
92            .inner
93            .ready
94            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire);
95    }
96
97    /// Returns whether terminal metrics are final.
98    ///
99    /// Streaming outputs become ready only after the stream reaches EOF.
100    pub fn is_ready(&self) -> bool {
101        self.inner.ready.load(Ordering::Acquire)
102    }
103
104    /// Returns the latest terminal metrics snapshot, if any.
105    pub fn get(&self) -> Option<RecordBatchMetrics> {
106        self.inner.metrics.read().unwrap().clone()
107    }
108
109    /// Returns proved per-region watermarks.
110    ///
111    /// Entries whose watermark is `None` are intentionally omitted because they
112    /// represent participating regions whose terminal sequence bound was not
113    /// provable.
114    pub fn region_watermark_map(&self) -> Option<std::collections::HashMap<u64, u64>> {
115        Some(
116            self.get()?
117                .region_watermarks
118                .into_iter()
119                .filter_map(|entry| entry.watermark.map(|seq| (entry.region_id, seq)))
120                .collect::<std::collections::HashMap<_, _>>(),
121        )
122    }
123
124    /// Returns all regions that participated in terminal metric collection,
125    /// including entries whose watermark is `None`.
126    pub fn participating_regions(&self) -> Option<std::collections::BTreeSet<u64>> {
127        Some(
128            self.get()?
129                .region_watermarks
130                .into_iter()
131                .map(|entry| entry.region_id)
132                .collect::<std::collections::BTreeSet<_>>(),
133        )
134    }
135}
136
/// Query output together with a handle for its terminal metrics.
///
/// The contained [`OutputMetrics`] lets callers read stream terminal metrics
/// after consuming `output`. For non-stream outputs, metrics are ready
/// immediately.
#[derive(Debug)]
pub struct OutputWithMetrics {
    // The (possibly wrapped) query output.
    pub output: Output,
    // Handle observing the output's terminal metrics.
    pub metrics: OutputMetrics,
}
147
148impl OutputWithMetrics {
149    /// Wraps an output with a terminal metrics handle.
150    ///
151    /// Stream outputs update the handle as the stream is consumed. Non-stream
152    /// outputs are marked ready immediately.
153    pub fn from_output(output: Output) -> Self {
154        let terminal_metrics = OutputMetrics::new();
155        let output = attach_terminal_metrics(output, &terminal_metrics);
156        Self {
157            output,
158            metrics: terminal_metrics,
159        }
160    }
161
162    /// Returns proved per-region watermarks from the terminal metrics.
163    pub fn region_watermark_map(&self) -> Option<std::collections::HashMap<u64, u64>> {
164        self.metrics.region_watermark_map()
165    }
166
167    /// Returns all regions participating in terminal metric collection.
168    pub fn participating_regions(&self) -> Option<std::collections::BTreeSet<u64>> {
169        self.metrics.participating_regions()
170    }
171
172    /// Drops the terminal metrics handle and returns the original output.
173    pub fn into_output(self) -> Output {
174        self.output
175    }
176}
177
178fn parse_terminal_metrics(metrics_json: &str) -> Result<RecordBatchMetrics> {
179    serde_json::from_str(metrics_json).map_err(|e| {
180        IllegalFlightMessagesSnafu {
181            reason: format!("Invalid terminal metrics message: {e}"),
182        }
183        .build()
184    })
185}
186
/// A record batch stream wrapper that mirrors the inner stream's metrics into
/// an [`OutputMetrics`] handle and marks them final at end of stream.
struct StreamWithMetrics {
    // The wrapped record batch stream.
    stream: common_recordbatch::SendableRecordBatchStream,
    // Shared handle fed with the inner stream's metrics.
    metrics: OutputMetrics,
}
191
192impl StreamWithMetrics {
193    fn new(stream: common_recordbatch::SendableRecordBatchStream, metrics: OutputMetrics) -> Self {
194        Self { stream, metrics }
195    }
196
197    fn sync_terminal_metrics(&self) {
198        self.metrics.update(self.stream.metrics());
199    }
200}
201
// Delegates the RecordBatchStream contract to the wrapped stream; only
// `metrics` adds behavior, mirroring the snapshot into the shared handle.
impl RecordBatchStream for StreamWithMetrics {
    fn name(&self) -> &str {
        self.stream.name()
    }

    fn schema(&self) -> datatypes::schema::SchemaRef {
        self.stream.schema()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        self.stream.output_ordering()
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        // Refresh the shared handle first so callers never observe a stale
        // snapshot relative to the inner stream.
        self.sync_terminal_metrics();
        self.metrics.get()
    }
}
220
impl Stream for StreamWithMetrics {
    type Item = common_recordbatch::error::Result<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let polled = Pin::new(&mut self.stream).poll_next(cx);
        // End of stream: take a final metrics snapshot and mark it terminal
        // so `OutputMetrics::is_ready` starts returning `true`.
        // Note: an `Err` item does not mark readiness — only a clean EOF does.
        if let Poll::Ready(None) = &polled {
            self.sync_terminal_metrics();
            self.metrics.mark_ready();
        }
        polled
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}
237
238fn attach_terminal_metrics(output: Output, terminal_metrics: &OutputMetrics) -> Output {
239    let Output { data, meta } = output;
240    let data = match data {
241        common_query::OutputData::Stream(stream) => {
242            terminal_metrics.update(stream.metrics());
243            common_query::OutputData::Stream(Box::pin(StreamWithMetrics::new(
244                stream,
245                terminal_metrics.clone(),
246            )))
247        }
248        other => {
249            terminal_metrics.mark_ready();
250            other
251        }
252    };
253    Output::new(data, meta)
254}
255
/// Assembles an [`OutputWithMetrics`] from a decoded Flight message stream.
///
/// The first message decides the shape of the output:
/// - `AffectedRows`: a scalar result, optionally followed by exactly one
///   `Metrics` message (and only when metrics were not already inlined in the
///   `AffectedRows` metadata).
/// - `Schema`: starts a record batch stream; subsequent messages must be
///   `RecordBatch`es, with `Metrics` messages updating a shared metrics slot.
/// - `RecordBatch` or `Metrics` first is a protocol violation.
async fn output_from_flight_message_stream<S>(
    mut flight_message_stream: S,
) -> Result<OutputWithMetrics>
where
    S: Stream<Item = Result<FlightMessage>> + Send + Unpin + 'static,
{
    let Some(first_flight_message) = flight_message_stream.next().await else {
        return IllegalFlightMessagesSnafu {
            reason: "Expect the response not to be empty",
        }
        .fail();
    };

    let first_flight_message = first_flight_message?;

    match first_flight_message {
        FlightMessage::AffectedRows { rows, metrics } => {
            let terminal_metrics = OutputMetrics::new();
            // Metrics may arrive inlined in the AffectedRows metadata itself.
            if let Some(metrics) = metrics {
                terminal_metrics.update(Some(parse_terminal_metrics(&metrics)?));
            }
            let next_message = flight_message_stream.next().await.transpose()?;
            match next_message {
                None => terminal_metrics.mark_ready(),
                // A trailing Metrics message is only legal when no inline
                // metrics were present.
                Some(FlightMessage::Metrics(s)) if terminal_metrics.get().is_none() => {
                    terminal_metrics.update(Some(parse_terminal_metrics(&s)?));
                    terminal_metrics.mark_ready();
                }
                Some(FlightMessage::Metrics(_)) => {
                    return IllegalFlightMessagesSnafu {
                        reason: "'AffectedRows' Flight metadata already carries Metrics and cannot be followed by another Metrics message",
                    }
                    .fail();
                }
                Some(other) => {
                    return IllegalFlightMessagesSnafu {
                        reason: format!(
                            "'AffectedRows' Flight message can only be followed by a Metrics message, got {other:?}"
                        ),
                    }
                    .fail();
                }
            }
            Ok(OutputWithMetrics {
                output: Output::new_with_affected_rows(rows),
                metrics: terminal_metrics,
            })
        }
        FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => IllegalFlightMessagesSnafu {
            reason: "The first flight message cannot be a RecordBatch or Metrics message",
        }
        .fail(),
        FlightMessage::Schema(schema) => {
            // Metrics slot shared between the batch stream below (writer) and
            // the RecordBatchStreamWrapper (reader).
            let metrics = Arc::new(ArcSwapOption::from(None));
            let metrics_ref = metrics.clone();
            let schema = Arc::new(
                datatypes::schema::Schema::try_from(schema).context(error::ConvertSchemaSnafu)?,
            );
            let schema_cloned = schema.clone();
            let stream = Box::pin(stream!({
                while let Some(flight_message_item) = flight_message_stream.next().await {
                    let flight_message = match flight_message_item {
                        Ok(message) => message,
                        Err(e) => {
                            // Surface the transport/decode error to the
                            // consumer and terminate the stream.
                            yield Err(BoxedError::new(e)).context(ExternalSnafu);
                            break;
                        }
                    };

                    match flight_message {
                        FlightMessage::RecordBatch(arrow_batch) => {
                            yield Ok(RecordBatch::from_df_record_batch(
                                schema_cloned.clone(),
                                arrow_batch,
                            ))
                        }
                        FlightMessage::Metrics(s) => {
                            // Metrics messages are not yielded downstream;
                            // they only update the shared metrics slot.
                            match parse_terminal_metrics(&s) {
                                Ok(m) => {
                                    metrics_ref.swap(Some(Arc::new(m)));
                                }
                                Err(e) => {
                                    yield Err(BoxedError::new(e)).context(ExternalSnafu);
                                }
                            };
                        }
                        FlightMessage::AffectedRows { .. } | FlightMessage::Schema(_) => {
                            yield IllegalFlightMessagesSnafu {
                                reason: format!(
                                    "A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}",
                                    flight_message
                                )
                            }
                            .fail()
                            .map_err(BoxedError::new)
                            .context(ExternalSnafu);
                            break;
                        }
                    }
                }
            }));
            let record_batch_stream = RecordBatchStreamWrapper {
                schema,
                stream,
                output_ordering: None,
                metrics,
                span: Span::current(),
            };
            Ok(OutputWithMetrics::from_output(Output::new_with_stream(
                Box::pin(record_batch_stream),
            )))
        }
    }
}
370
/// A client-side handle to a database on a GreptimeDB server, carrying the
/// naming context, time zone and auth info stamped onto every request.
#[derive(Clone, Debug, Default)]
pub struct Database {
    // The "catalog" and "schema" to be used in processing the requests at the server side.
    // They are the "hint" or "context", just like how the "database" in a "USE" statement is treated in MySQL.
    // They will be carried in the request header.
    catalog: String,
    schema: String,
    // The dbname follows the same naming rules as our MySQL, PostgreSQL and HTTP
    // protocols. The server treats dbname with priority over catalog/schema.
    dbname: String,
    // The time zone indicates the time zone where the user is located.
    // Some queries need to be aware of the user's time zone to perform some specific actions.
    timezone: String,

    client: Client,
    ctx: FlightContext,
}
388
/// A gRPC client to the GreptimeDatabase service, remembering the peer
/// address for error reporting.
pub struct DatabaseClient {
    // Peer address, used in error logs.
    pub addr: String,
    // The underlying tonic-generated service client.
    pub inner: GreptimeDatabaseClient<Channel>,
}
393
394impl DatabaseClient {
395    /// Returns a closure that logs the error when the request fails.
396    pub fn inspect_err<'a>(&'a self, context: &'a str) -> impl Fn(&tonic::Status) + 'a {
397        let addr = &self.addr;
398        move |status| {
399            error!("Failed to {context} request, peer: {addr}, status: {status:?}");
400        }
401    }
402}
403
404fn make_database_client(client: &Client) -> Result<DatabaseClient> {
405    let (addr, channel) = client.find_channel()?;
406    Ok(DatabaseClient {
407        addr,
408        inner: GreptimeDatabaseClient::new(channel)
409            .max_decoding_message_size(client.max_grpc_recv_message_size())
410            .max_encoding_message_size(client.max_grpc_send_message_size()),
411    })
412}
413
414impl Database {
415    /// Create database service client using catalog and schema
416    pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
417        Self {
418            catalog: catalog.into(),
419            schema: schema.into(),
420            dbname: String::default(),
421            timezone: String::default(),
422            client,
423            ctx: FlightContext::default(),
424        }
425    }
426
427    /// Create database service client using dbname.
428    ///
429    /// This API is designed for external usage. `dbname` is:
430    ///
431    /// - the name of database when using GreptimeDB standalone or cluster
432    /// - the name provided by GreptimeCloud or other multi-tenant GreptimeDB
433    ///   environment
434    pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
435        Self {
436            catalog: String::default(),
437            schema: String::default(),
438            timezone: String::default(),
439            dbname: dbname.into(),
440            client,
441            ctx: FlightContext::default(),
442        }
443    }
444
445    /// Set the catalog for the database client.
446    pub fn set_catalog(&mut self, catalog: impl Into<String>) {
447        self.catalog = catalog.into();
448    }
449
450    fn catalog_or_default(&self) -> &str {
451        if self.catalog.is_empty() {
452            DEFAULT_CATALOG_NAME
453        } else {
454            &self.catalog
455        }
456    }
457
458    /// Set the schema for the database client.
459    pub fn set_schema(&mut self, schema: impl Into<String>) {
460        self.schema = schema.into();
461    }
462
463    fn schema_or_default(&self) -> &str {
464        if self.schema.is_empty() {
465            DEFAULT_SCHEMA_NAME
466        } else {
467            &self.schema
468        }
469    }
470
471    /// Set the timezone for the database client.
472    pub fn set_timezone(&mut self, timezone: impl Into<String>) {
473        self.timezone = timezone.into();
474    }
475
476    /// Set the auth scheme for the database client.
477    pub fn set_auth(&mut self, auth: AuthScheme) {
478        self.ctx.auth_header = Some(AuthHeader {
479            auth_scheme: Some(auth),
480        });
481    }
482
483    /// Make an InsertRequests request to the database.
484    pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
485        self.handle(Request::Inserts(requests)).await
486    }
487
488    /// Make an InsertRequests request to the database with hints.
489    pub async fn insert_with_hints(
490        &self,
491        requests: InsertRequests,
492        hints: &[(&str, &str)],
493    ) -> Result<u32> {
494        let mut client = make_database_client(&self.client)?;
495        let request = self.to_rpc_request(Request::Inserts(requests));
496
497        let mut request = tonic::Request::new(request);
498        let metadata = request.metadata_mut();
499        Self::put_hints(metadata, hints)?;
500
501        let response = client
502            .inner
503            .handle(request)
504            .await
505            .inspect_err(client.inspect_err("insert_with_hints"))?
506            .into_inner();
507        from_grpc_response(response)
508    }
509
510    /// Make a RowInsertRequests request to the database.
511    pub async fn row_inserts(&self, requests: RowInsertRequests) -> Result<u32> {
512        self.handle(Request::RowInserts(requests)).await
513    }
514
515    /// Make a RowInsertRequests request to the database with hints.
516    pub async fn row_inserts_with_hints(
517        &self,
518        requests: RowInsertRequests,
519        hints: &[(&str, &str)],
520    ) -> Result<u32> {
521        let mut client = make_database_client(&self.client)?;
522        let request = self.to_rpc_request(Request::RowInserts(requests));
523
524        let mut request = tonic::Request::new(request);
525        let metadata = request.metadata_mut();
526        Self::put_hints(metadata, hints)?;
527
528        let response = client
529            .inner
530            .handle(request)
531            .await
532            .inspect_err(client.inspect_err("row_inserts_with_hints"))?
533            .into_inner();
534        from_grpc_response(response)
535    }
536
537    fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
538        let Some(value) = hints
539            .iter()
540            .map(|(k, v)| format!("{}={}", k, v))
541            .reduce(|a, b| format!("{},{}", a, b))
542        else {
543            return Ok(());
544        };
545
546        let key = AsciiMetadataKey::from_static("x-greptime-hints");
547        let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
548        metadata.insert(key, value);
549        Ok(())
550    }
551
552    fn put_flow_extensions(
553        metadata: &mut MetadataMap,
554        flow_extensions: &[(&str, &str)],
555    ) -> Result<()> {
556        if flow_extensions.is_empty() {
557            return Ok(());
558        }
559
560        let value = serde_json::to_string(&flow_extensions.to_vec())
561            .expect("flow extension pairs should serialize");
562        let key = AsciiMetadataKey::from_static(FLOW_EXTENSIONS_METADATA_KEY);
563        let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
564        metadata.insert(key, value);
565        Ok(())
566    }
567
568    /// Make a request to the database.
569    pub async fn handle(&self, request: Request) -> Result<u32> {
570        let mut client = make_database_client(&self.client)?;
571        let request = self.to_rpc_request(request);
572        let response = client
573            .inner
574            .handle(request)
575            .await
576            .inspect_err(client.inspect_err("handle"))?
577            .into_inner();
578        from_grpc_response(response)
579    }
580
581    /// Retry if connection fails, max_retries is the max number of retries, so the total wait time
582    /// is `max_retries * GRPC_CONN_TIMEOUT`
583    pub async fn handle_with_retry(
584        &self,
585        request: Request,
586        max_retries: u32,
587        hints: &[(&str, &str)],
588    ) -> Result<u32> {
589        let mut client = make_database_client(&self.client)?;
590        let mut retries = 0;
591
592        let request = self.to_rpc_request(request);
593
594        loop {
595            let mut tonic_request = tonic::Request::new(request.clone());
596            let metadata = tonic_request.metadata_mut();
597            Self::put_hints(metadata, hints)?;
598            let raw_response = client
599                .inner
600                .handle(tonic_request)
601                .await
602                .inspect_err(client.inspect_err("handle"));
603            match (raw_response, retries < max_retries) {
604                (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
605                (Err(err), true) => {
606                    // determine if the error is retryable
607                    if is_grpc_retryable(&err) {
608                        // retry
609                        retries += 1;
610                        warn!("Retrying {} times with error = {:?}", retries, err);
611                        continue;
612                    } else {
613                        error!(
614                            err; "Failed to send request to grpc handle, retries = {}, not retryable error, aborting",
615                            retries
616                        );
617                        return Err(err.into());
618                    }
619                }
620                (Err(err), false) => {
621                    error!(
622                        err; "Failed to send request to grpc handle after {} retries",
623                        retries,
624                    );
625                    return Err(err.into());
626                }
627            }
628        }
629    }
630
631    #[inline]
632    fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
633        GreptimeRequest {
634            header: Some(RequestHeader {
635                catalog: self.catalog.clone(),
636                schema: self.schema.clone(),
637                authorization: self.ctx.auth_header.clone(),
638                dbname: self.dbname.clone(),
639                timezone: self.timezone.clone(),
640                // TODO(Taylor-lagrange): add client grpc tracing
641                tracing_context: W3cTrace::new(),
642            }),
643            request: Some(request),
644        }
645    }
646
647    /// Executes a SQL query without any hints.
648    pub async fn sql<S>(&self, sql: S) -> Result<Output>
649    where
650        S: AsRef<str>,
651    {
652        self.sql_with_hint(sql, &[]).await
653    }
654
655    /// Executes a SQL query with optional hints for query optimization.
656    pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
657    where
658        S: AsRef<str>,
659    {
660        let request = Request::Query(QueryRequest {
661            query: Some(Query::Sql(sql.as_ref().to_string())),
662        });
663        self.do_get(request, hints, &[])
664            .await
665            .map(OutputWithMetrics::into_output)
666    }
667
668    /// Executes a SQL query and returns the output with terminal metrics.
669    ///
670    /// For stream outputs, callers must consume the stream before reading final
671    /// terminal metrics from [`OutputWithMetrics::metrics`].
672    pub async fn sql_with_terminal_metrics<S>(
673        &self,
674        sql: S,
675        hints: &[(&str, &str)],
676    ) -> Result<OutputWithMetrics>
677    where
678        S: AsRef<str>,
679    {
680        self.query_with_terminal_metrics_and_flow_extensions(
681            QueryRequest {
682                query: Some(Query::Sql(sql.as_ref().to_string())),
683            },
684            hints,
685            &[],
686        )
687        .await
688    }
689
690    /// Executes a logical plan directly without SQL parsing.
691    pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
692        self.query_with_terminal_metrics_and_flow_extensions(
693            QueryRequest {
694                query: Some(Query::LogicalPlan(logical_plan)),
695            },
696            &[],
697            &[],
698        )
699        .await
700        .map(OutputWithMetrics::into_output)
701    }
702
703    /// Executes a query and carries flow extensions through Flight metadata.
704    ///
705    /// This is the lower-level terminal-metrics API for Flow callers that need
706    /// to pass JSON-bearing flow extensions without going through hint metadata.
707    pub async fn query_with_terminal_metrics_and_flow_extensions(
708        &self,
709        request: QueryRequest,
710        hints: &[(&str, &str)],
711        flow_extensions: &[(&str, &str)],
712    ) -> Result<OutputWithMetrics> {
713        self.do_get(Request::Query(request), hints, flow_extensions)
714            .await
715    }
716
717    /// Creates a new table using the provided table expression.
718    pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
719        let request = Request::Ddl(DdlRequest {
720            expr: Some(DdlExpr::CreateTable(expr)),
721        });
722        self.do_get(request, &[], &[])
723            .await
724            .map(OutputWithMetrics::into_output)
725    }
726
727    /// Alters an existing table using the provided alter expression.
728    pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
729        let request = Request::Ddl(DdlRequest {
730            expr: Some(DdlExpr::AlterTable(expr)),
731        });
732        self.do_get(request, &[], &[])
733            .await
734            .map(OutputWithMetrics::into_output)
735    }
736
737    async fn do_get(
738        &self,
739        request: Request,
740        hints: &[(&str, &str)],
741        flow_extensions: &[(&str, &str)],
742    ) -> Result<OutputWithMetrics> {
743        let request = self.to_rpc_request(request);
744        let request = Ticket {
745            ticket: request.encode_to_vec().into(),
746        };
747
748        let mut request = tonic::Request::new(request);
749        let metadata = request.metadata_mut();
750        Self::put_hints(metadata, hints)?;
751        Self::put_flow_extensions(metadata, flow_extensions)?;
752
753        let mut client = self.client.make_flight_client(false, false)?;
754
755        let response = client.mut_inner().do_get(request).await.or_else(|e| {
756            let tonic_code = e.code();
757            let e: Error = e.into();
758            error!(
759                "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
760                client.addr(),
761                tonic_code,
762                e
763            );
764            Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
765                addr: client.addr().to_string(),
766                tonic_code,
767            })
768        })?;
769
770        let flight_data_stream = response.into_inner();
771        let mut decoder = FlightDecoder::default();
772
773        let flight_message_stream = flight_data_stream.map(move |flight_data| {
774            flight_data
775                .map_err(Error::from)
776                .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
777                .context(IllegalFlightMessagesSnafu {
778                    reason: "none message",
779                })
780        });
781
782        output_from_flight_message_stream(flight_message_stream).await
783    }
784
785    /// Ingest a stream of [RecordBatch]es that belong to a table, using Arrow Flight's "`DoPut`"
786    /// method. The return value is also a stream, produces [DoPutResponse]s.
787    pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
788        let mut request = tonic::Request::new(stream);
789
790        if let Some(AuthHeader {
791            auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
792        }) = &self.ctx.auth_header
793        {
794            let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
795            let value = MetadataValue::from_str(&format!("Basic {encoded}"))
796                .context(InvalidTonicMetadataValueSnafu)?;
797            request.metadata_mut().insert("x-greptime-auth", value);
798        }
799
800        let db_to_put = if !self.dbname.is_empty() {
801            &self.dbname
802        } else {
803            &build_db_string(self.catalog_or_default(), self.schema_or_default())
804        };
805        request.metadata_mut().insert(
806            "x-greptime-db-name",
807            MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
808        );
809
810        let mut client = self.client.make_flight_client(false, false)?;
811        let response = client.mut_inner().do_put(request).await?;
812        let response = response
813            .into_inner()
814            .map_err(Into::into)
815            .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
816            .boxed();
817        Ok(response)
818    }
819}
820
821/// by grpc standard, only `Unavailable` is retryable, see: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
822pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
823    matches!(err.code(), tonic::Code::Unavailable)
824}
825
/// Per-client Flight call context; currently only carries the optional
/// authorization header attached to every outgoing request.
#[derive(Default, Debug, Clone)]
struct FlightContext {
    // Auth credentials, set via `Database::set_auth`.
    auth_header: Option<AuthHeader>,
}
830
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::task::{Context, Poll};

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};
    use common_error::status_code::StatusCode;
    use common_query::OutputData;
    use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream};
    use datatypes::prelude::{ConcreteDataType, VectorRef};
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::Int32Vector;
    use futures_util::StreamExt;
    use tonic::{Code, Status};

    use super::*;
    use crate::error::TonicSnafu;

    /// Test double implementing [RecordBatchStream]: yields at most one
    /// record batch, and exposes `metrics` either unconditionally or — when
    /// `terminal_metrics_only` is set — only after the batch was consumed,
    /// mimicking a stream whose metrics only become available at termination.
    struct MockMetricsStream {
        schema: datatypes::schema::SchemaRef,
        batch: Option<RecordBatch>,
        metrics: RecordBatchMetrics,
        terminal_metrics_only: bool,
    }

    impl Stream for MockMetricsStream {
        type Item = common_recordbatch::error::Result<RecordBatch>;

        // Emits the single stored batch once, then signals end-of-stream.
        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Poll::Ready(self.batch.take().map(Ok))
        }
    }

    impl RecordBatchStream for MockMetricsStream {
        fn name(&self) -> &str {
            "MockMetricsStream"
        }

        fn schema(&self) -> datatypes::schema::SchemaRef {
            self.schema.clone()
        }

        fn output_ordering(&self) -> Option<&[OrderOption]> {
            None
        }

        fn metrics(&self) -> Option<RecordBatchMetrics> {
            // In terminal-only mode, metrics are hidden while the batch is
            // still pending, i.e. until the stream has been fully drained.
            if self.terminal_metrics_only && self.batch.is_some() {
                return None;
            }
            Some(self.metrics.clone())
        }
    }

    /// JSON metrics payload with a single region (id 7) at watermark 42.
    fn terminal_metrics_json() -> String {
        terminal_metrics_json_with_seq(42)
    }

    /// JSON metrics payload with a single region (id 7) at watermark `seq`.
    fn terminal_metrics_json_with_seq(seq: u64) -> String {
        serde_json::to_string(&RecordBatchMetrics {
            region_watermarks: vec![common_recordbatch::adapter::RegionWatermarkEntry {
                region_id: 7,
                watermark: Some(seq),
            }],
            ..Default::default()
        })
        .unwrap()
    }

    // Flow extension values may themselves contain commas (e.g. JSON maps),
    // so the metadata must round-trip as JSON-encoded pairs rather than a
    // comma-joined list.
    #[test]
    fn test_put_flow_extensions_preserves_comma_bearing_values() {
        let mut metadata = MetadataMap::new();
        Database::put_flow_extensions(
            &mut metadata,
            &[
                ("flow.return_region_seq", "true"),
                ("flow.incremental_after_seqs", r#"{"1":10,"2":20}"#),
            ],
        )
        .unwrap();

        let value = metadata
            .get(FLOW_EXTENSIONS_METADATA_KEY)
            .unwrap()
            .to_str()
            .unwrap();
        let decoded: Vec<(String, String)> = serde_json::from_str(value).unwrap();
        assert_eq!(
            decoded,
            vec![
                ("flow.return_region_seq".to_string(), "true".to_string()),
                (
                    "flow.incremental_after_seqs".to_string(),
                    r#"{"1":10,"2":20}"#.to_string()
                ),
            ]
        );
    }

    // A default FlightContext has no credentials, and a Basic scheme can be
    // installed and pattern-matched back out.
    #[test]
    fn test_flight_ctx() {
        let mut ctx = FlightContext::default();
        assert!(ctx.auth_header.is_none());

        let basic = AuthScheme::Basic(Basic {
            username: "u".to_string(),
            password: "p".to_string(),
        });

        ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(basic),
        });

        assert!(matches!(
            ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        ));
    }

    // Converting a tonic::Status into our Error preserves the status code and
    // message (compared via Display here).
    #[test]
    fn test_from_tonic_status() {
        let expected = TonicSnafu {
            code: StatusCode::Internal,
            msg: "blabla".to_string(),
            tonic_code: Code::Internal,
        }
        .build();

        let status = Status::new(Code::Internal, "blabla");
        let actual: Error = status.into();

        assert_eq!(expected.to_string(), actual.to_string());
    }

    // Metrics that only become available once the stream terminates must be
    // picked up by OutputWithMetrics after the stream is drained: not ready
    // before draining, ready (with region watermarks) afterwards.
    #[tokio::test]
    async fn test_query_with_terminal_metrics_tracks_terminal_only_metrics() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "v",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([1, 2])) as VectorRef],
        )
        .unwrap();
        let output = Output::new_with_stream(Box::pin(MockMetricsStream {
            schema,
            batch: Some(batch),
            metrics: RecordBatchMetrics {
                region_watermarks: vec![common_recordbatch::adapter::RegionWatermarkEntry {
                    region_id: 7,
                    watermark: Some(42),
                }],
                ..Default::default()
            },
            terminal_metrics_only: true,
        }));

        let result = OutputWithMetrics::from_output(output);
        let terminal_metrics = result.metrics.clone();
        // Nothing has been consumed yet, so metrics must not be ready.
        assert!(!terminal_metrics.is_ready());
        assert!(terminal_metrics.get().is_none());

        let OutputData::Stream(mut stream) = result.output.data else {
            panic!("expected stream output");
        };
        while stream.next().await.is_some() {}

        assert!(terminal_metrics.is_ready());
        assert_eq!(
            terminal_metrics.participating_regions(),
            Some(std::collections::BTreeSet::from([7_u64]))
        );
        assert_eq!(
            terminal_metrics.region_watermark_map(),
            Some(std::collections::HashMap::from([(7_u64, 42_u64)]))
        );
    }

    // Malformed JSON in a terminal metrics message must surface as an error.
    #[test]
    fn test_parse_terminal_metrics_rejects_invalid_json() {
        assert!(parse_terminal_metrics("{not-json}").is_err());
    }

    // An AffectedRows message carrying inline metrics yields both the row
    // count and immediately-ready parsed metrics.
    #[tokio::test]
    async fn test_affected_rows_inline_metrics_are_parsed() {
        let output = output_from_flight_message_stream(futures_util::stream::iter(vec![Ok(
            FlightMessage::AffectedRows {
                rows: 3,
                metrics: Some(terminal_metrics_json()),
            },
        )]
            as Vec<Result<FlightMessage>>))
        .await
        .unwrap();

        assert!(matches!(output.output.data, OutputData::AffectedRows(3)));
        assert!(output.metrics.is_ready());
        assert_eq!(
            output.metrics.region_watermark_map(),
            Some(std::collections::HashMap::from([(7, 42)]))
        );
    }

    // A standalone Metrics message after an AffectedRows message that already
    // carried inline metrics is a protocol violation and must be rejected.
    #[tokio::test]
    async fn test_affected_rows_inline_metrics_rejects_trailing_metrics() {
        let metrics_json = terminal_metrics_json();
        let err = output_from_flight_message_stream(futures_util::stream::iter(vec![
            Ok(FlightMessage::AffectedRows {
                rows: 3,
                metrics: Some(metrics_json.clone()),
            }),
            Ok(FlightMessage::Metrics(metrics_json)),
        ]
            as Vec<Result<FlightMessage>>))
        .await
        .unwrap_err();

        assert!(
            err.to_string().contains("already carries Metrics"),
            "unexpected error: {err:?}"
        );
    }

    // An invalid terminal metrics message must not swallow data already
    // produced: the preceding batch is yielded first, then the parse failure
    // surfaces as a stream error, and metrics end up ready-but-absent.
    #[tokio::test]
    async fn test_invalid_terminal_metrics_after_record_batch_yields_batch_then_error() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "v",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([1])) as VectorRef],
        )
        .unwrap();
        let output = output_from_flight_message_stream(futures_util::stream::iter(vec![
            Ok(FlightMessage::Schema(schema.arrow_schema().clone())),
            Ok(FlightMessage::RecordBatch(batch.into_df_record_batch())),
            Ok(FlightMessage::Metrics("{not-json}".to_string())),
        ]
            as Vec<Result<FlightMessage>>))
        .await
        .unwrap();
        let terminal_metrics = output.metrics.clone();
        let OutputData::Stream(mut record_batch_stream) = output.output.data else {
            panic!("expected stream output");
        };

        let batch = record_batch_stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 1);

        let err = record_batch_stream.next().await.unwrap().unwrap_err();
        assert_eq!("External error", err.to_string());
        assert!(
            format!("{err:?}").contains("Invalid terminal metrics message"),
            "unexpected error: {err:?}"
        );
        assert!(record_batch_stream.next().await.is_none());
        // Ready but without a value: the metrics payload was unparseable.
        assert!(terminal_metrics.is_ready());
        assert!(terminal_metrics.get().is_none());
    }

    // Metrics messages interleaved mid-stream must not terminate the record
    // batch stream; later batches still flow and the last metrics win.
    #[tokio::test]
    async fn test_record_batch_stream_continues_after_partial_metrics() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "v",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let first_batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([1])) as VectorRef],
        )
        .unwrap();
        let second_batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([2])) as VectorRef],
        )
        .unwrap();
        let output = output_from_flight_message_stream(futures_util::stream::iter(vec![
            Ok(FlightMessage::Schema(schema.arrow_schema().clone())),
            Ok(FlightMessage::RecordBatch(
                first_batch.into_df_record_batch(),
            )),
            Ok(FlightMessage::Metrics(terminal_metrics_json_with_seq(1))),
            Ok(FlightMessage::RecordBatch(
                second_batch.into_df_record_batch(),
            )),
            Ok(FlightMessage::Metrics(terminal_metrics_json_with_seq(2))),
        ]
            as Vec<Result<FlightMessage>>))
        .await
        .unwrap();
        let terminal_metrics = output.metrics.clone();
        let OutputData::Stream(mut record_batch_stream) = output.output.data else {
            panic!("expected stream output");
        };

        let first_batch = record_batch_stream.next().await.unwrap().unwrap();
        assert_eq!(first_batch.num_rows(), 1);
        let second_batch = record_batch_stream.next().await.unwrap().unwrap();
        assert_eq!(second_batch.num_rows(), 1);
        assert!(record_batch_stream.next().await.is_none());

        assert!(terminal_metrics.is_ready());
        // The second metrics message (seq 2) supersedes the first.
        assert_eq!(
            terminal_metrics.region_watermark_map(),
            Some(std::collections::HashMap::from([(7, 2)]))
        );
    }

    // Metrics reported with an empty watermark list must read back as
    // Some(empty collections) — distinct from metrics never reported (None).
    #[test]
    fn test_output_metrics_distinguishes_empty_region_watermarks_from_absence() {
        let metrics = OutputMetrics::default();
        metrics.update(Some(RecordBatchMetrics::default()));

        assert_eq!(
            metrics.participating_regions(),
            Some(std::collections::BTreeSet::new())
        );
        assert_eq!(
            metrics.region_watermark_map(),
            Some(std::collections::HashMap::new())
        );
    }
}