Skip to main content

client/
database.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::pin::Pin;
16use std::str::FromStr;
17use std::sync::atomic::{AtomicBool, Ordering};
18use std::sync::{Arc, RwLock};
19use std::task::{Context, Poll};
20
21use api::v1::auth_header::AuthScheme;
22use api::v1::ddl_request::Expr as DdlExpr;
23use api::v1::greptime_database_client::GreptimeDatabaseClient;
24use api::v1::greptime_request::Request;
25use api::v1::query_request::Query;
26use api::v1::{
27    AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
28    InsertRequests, QueryRequest, RequestHeader, RowInsertRequests,
29};
30use arc_swap::ArcSwapOption;
31use arrow_flight::{FlightData, Ticket};
32use async_stream::stream;
33use base64::Engine;
34use base64::prelude::BASE64_STANDARD;
35use common_catalog::build_db_string;
36use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
37use common_error::ext::BoxedError;
38use common_grpc::flight::do_put::DoPutResponse;
39use common_grpc::flight::{
40    FLOW_EXTENSIONS_METADATA_KEY, FlightDecoder, FlightMessage, SNAPSHOT_SEQS_METADATA_KEY,
41};
42use common_query::Output;
43use common_recordbatch::adapter::RecordBatchMetrics;
44use common_recordbatch::error::ExternalSnafu;
45use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream, RecordBatchStreamWrapper};
46use common_telemetry::tracing::Span;
47use common_telemetry::tracing_context::W3cTrace;
48use common_telemetry::{error, warn};
49use futures::future;
50use futures_util::{Stream, StreamExt, TryStreamExt};
51use prost::Message;
52use snafu::{OptionExt, ResultExt};
53use tonic::metadata::{AsciiMetadataKey, AsciiMetadataValue, MetadataMap, MetadataValue};
54use tonic::transport::Channel;
55
56use crate::error::{
57    ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu,
58    InvalidTonicMetadataValueSnafu,
59};
60use crate::{Client, Result, error, from_grpc_response};
61
62type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
63
64type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
65
66/// Terminal metrics associated with a query output.
67///
68/// For streaming outputs, metrics are only final after the stream is fully
69/// drained and [`Self::is_ready`] returns `true`.
70#[derive(Debug, Clone, Default)]
71pub struct OutputMetrics {
72    inner: Arc<OutputMetricsInner>,
73}
74
75#[derive(Debug, Default)]
76struct OutputMetricsInner {
77    metrics: RwLock<Option<RecordBatchMetrics>>,
78    ready: AtomicBool,
79}
80
81impl OutputMetrics {
82    fn new() -> Self {
83        Self::default()
84    }
85
86    /// Replaces the current terminal metrics snapshot.
87    pub fn update(&self, metrics: Option<RecordBatchMetrics>) {
88        *self.inner.metrics.write().unwrap() = metrics;
89    }
90
91    /// Marks the terminal metrics as final for this output.
92    pub fn mark_ready(&self) {
93        let _ = self
94            .inner
95            .ready
96            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire);
97    }
98
99    /// Returns whether terminal metrics are final.
100    ///
101    /// Streaming outputs become ready only after the stream reaches EOF.
102    pub fn is_ready(&self) -> bool {
103        self.inner.ready.load(Ordering::Acquire)
104    }
105
106    /// Returns the latest terminal metrics snapshot, if any.
107    pub fn get(&self) -> Option<RecordBatchMetrics> {
108        self.inner.metrics.read().unwrap().clone()
109    }
110
111    /// Returns proved per-region watermarks.
112    ///
113    /// Entries whose watermark is `None` are intentionally omitted because they
114    /// represent participating regions whose terminal sequence bound was not
115    /// provable.
116    pub fn region_watermark_map(&self) -> Option<std::collections::HashMap<u64, u64>> {
117        Some(
118            self.get()?
119                .region_watermarks
120                .into_iter()
121                .filter_map(|entry| entry.watermark.map(|seq| (entry.region_id, seq)))
122                .collect::<std::collections::HashMap<_, _>>(),
123        )
124    }
125
126    /// Returns all regions that participated in terminal metric collection,
127    /// including entries whose watermark is `None`.
128    pub fn participating_regions(&self) -> Option<std::collections::BTreeSet<u64>> {
129        Some(
130            self.get()?
131                .region_watermarks
132                .into_iter()
133                .map(|entry| entry.region_id)
134                .collect::<std::collections::BTreeSet<_>>(),
135        )
136    }
137}
138
139/// Query output together with a handle for its terminal metrics.
140///
141/// The contained [`OutputMetrics`] lets callers read stream terminal metrics
142/// after consuming `output`. For non-stream outputs, metrics are ready
143/// immediately.
144#[derive(Debug)]
145pub struct OutputWithMetrics {
146    pub output: Output,
147    pub metrics: OutputMetrics,
148}
149
150impl OutputWithMetrics {
151    /// Wraps an output with a terminal metrics handle.
152    ///
153    /// Stream outputs update the handle as the stream is consumed. Non-stream
154    /// outputs are marked ready immediately.
155    pub fn from_output(output: Output) -> Self {
156        let terminal_metrics = OutputMetrics::new();
157        let output = attach_terminal_metrics(output, &terminal_metrics);
158        Self {
159            output,
160            metrics: terminal_metrics,
161        }
162    }
163
164    /// Returns proved per-region watermarks from the terminal metrics.
165    pub fn region_watermark_map(&self) -> Option<std::collections::HashMap<u64, u64>> {
166        self.metrics.region_watermark_map()
167    }
168
169    /// Returns all regions participating in terminal metric collection.
170    pub fn participating_regions(&self) -> Option<std::collections::BTreeSet<u64>> {
171        self.metrics.participating_regions()
172    }
173
174    /// Drops the terminal metrics handle and returns the original output.
175    pub fn into_output(self) -> Output {
176        self.output
177    }
178}
179
180fn parse_terminal_metrics(metrics_json: &str) -> Result<RecordBatchMetrics> {
181    serde_json::from_str(metrics_json).map_err(|e| {
182        IllegalFlightMessagesSnafu {
183            reason: format!("Invalid terminal metrics message: {e}"),
184        }
185        .build()
186    })
187}
188
189struct StreamWithMetrics {
190    stream: common_recordbatch::SendableRecordBatchStream,
191    metrics: OutputMetrics,
192}
193
194impl StreamWithMetrics {
195    fn new(stream: common_recordbatch::SendableRecordBatchStream, metrics: OutputMetrics) -> Self {
196        Self { stream, metrics }
197    }
198
199    fn sync_terminal_metrics(&self) {
200        self.metrics.update(self.stream.metrics());
201    }
202}
203
204impl RecordBatchStream for StreamWithMetrics {
205    fn name(&self) -> &str {
206        self.stream.name()
207    }
208
209    fn schema(&self) -> datatypes::schema::SchemaRef {
210        self.stream.schema()
211    }
212
213    fn output_ordering(&self) -> Option<&[OrderOption]> {
214        self.stream.output_ordering()
215    }
216
217    fn metrics(&self) -> Option<RecordBatchMetrics> {
218        self.sync_terminal_metrics();
219        self.metrics.get()
220    }
221}
222
223impl Stream for StreamWithMetrics {
224    type Item = common_recordbatch::error::Result<RecordBatch>;
225
226    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
227        let polled = Pin::new(&mut self.stream).poll_next(cx);
228        if let Poll::Ready(None) = &polled {
229            self.sync_terminal_metrics();
230            self.metrics.mark_ready();
231        }
232        polled
233    }
234
235    fn size_hint(&self) -> (usize, Option<usize>) {
236        self.stream.size_hint()
237    }
238}
239
240fn attach_terminal_metrics(output: Output, terminal_metrics: &OutputMetrics) -> Output {
241    let Output { data, meta } = output;
242    let data = match data {
243        common_query::OutputData::Stream(stream) => {
244            terminal_metrics.update(stream.metrics());
245            common_query::OutputData::Stream(Box::pin(StreamWithMetrics::new(
246                stream,
247                terminal_metrics.clone(),
248            )))
249        }
250        other => {
251            terminal_metrics.mark_ready();
252            other
253        }
254    };
255    Output::new(data, meta)
256}
257
258async fn output_from_flight_message_stream<S>(
259    mut flight_message_stream: S,
260) -> Result<OutputWithMetrics>
261where
262    S: Stream<Item = Result<FlightMessage>> + Send + Unpin + 'static,
263{
264    let Some(first_flight_message) = flight_message_stream.next().await else {
265        return IllegalFlightMessagesSnafu {
266            reason: "Expect the response not to be empty",
267        }
268        .fail();
269    };
270
271    let first_flight_message = first_flight_message?;
272
273    match first_flight_message {
274        FlightMessage::AffectedRows { rows, metrics } => {
275            let terminal_metrics = OutputMetrics::new();
276            if let Some(metrics) = metrics {
277                terminal_metrics.update(Some(parse_terminal_metrics(&metrics)?));
278            }
279            let next_message = flight_message_stream.next().await.transpose()?;
280            match next_message {
281                None => terminal_metrics.mark_ready(),
282                Some(FlightMessage::Metrics(s)) if terminal_metrics.get().is_none() => {
283                    terminal_metrics.update(Some(parse_terminal_metrics(&s)?));
284                    terminal_metrics.mark_ready();
285                }
286                Some(FlightMessage::Metrics(_)) => {
287                    return IllegalFlightMessagesSnafu {
288                        reason: "'AffectedRows' Flight metadata already carries Metrics and cannot be followed by another Metrics message",
289                    }
290                    .fail();
291                }
292                Some(other) => {
293                    return IllegalFlightMessagesSnafu {
294                        reason: format!(
295                            "'AffectedRows' Flight message can only be followed by a Metrics message, got {other:?}"
296                        ),
297                    }
298                    .fail();
299                }
300            }
301            Ok(OutputWithMetrics {
302                output: Output::new_with_affected_rows(rows),
303                metrics: terminal_metrics,
304            })
305        }
306        FlightMessage::RecordBatch(_) | FlightMessage::Metrics(_) => IllegalFlightMessagesSnafu {
307            reason: "The first flight message cannot be a RecordBatch or Metrics message",
308        }
309        .fail(),
310        FlightMessage::Schema(schema) => {
311            let metrics = Arc::new(ArcSwapOption::from(None));
312            let metrics_ref = metrics.clone();
313            let schema = Arc::new(
314                datatypes::schema::Schema::try_from(schema).context(error::ConvertSchemaSnafu)?,
315            );
316            let schema_cloned = schema.clone();
317            let stream = Box::pin(stream!({
318                while let Some(flight_message_item) = flight_message_stream.next().await {
319                    let flight_message = match flight_message_item {
320                        Ok(message) => message,
321                        Err(e) => {
322                            yield Err(BoxedError::new(e)).context(ExternalSnafu);
323                            break;
324                        }
325                    };
326
327                    match flight_message {
328                        FlightMessage::RecordBatch(arrow_batch) => {
329                            yield Ok(RecordBatch::from_df_record_batch(
330                                schema_cloned.clone(),
331                                arrow_batch,
332                            ))
333                        }
334                        FlightMessage::Metrics(s) => {
335                            match parse_terminal_metrics(&s) {
336                                Ok(m) => {
337                                    metrics_ref.swap(Some(Arc::new(m)));
338                                }
339                                Err(e) => {
340                                    yield Err(BoxedError::new(e)).context(ExternalSnafu);
341                                }
342                            };
343                        }
344                        FlightMessage::AffectedRows { .. } | FlightMessage::Schema(_) => {
345                            yield IllegalFlightMessagesSnafu {
346                                reason: format!(
347                                    "A Schema message must be succeeded exclusively by a set of RecordBatch messages, flight_message: {:?}",
348                                    flight_message
349                                )
350                            }
351                            .fail()
352                            .map_err(BoxedError::new)
353                            .context(ExternalSnafu);
354                            break;
355                        }
356                    }
357                }
358            }));
359            let record_batch_stream = RecordBatchStreamWrapper {
360                schema,
361                stream,
362                output_ordering: None,
363                metrics,
364                span: Span::current(),
365            };
366            Ok(OutputWithMetrics::from_output(Output::new_with_stream(
367                Box::pin(record_batch_stream),
368            )))
369        }
370    }
371}
372
373#[derive(Clone, Debug, Default)]
374pub struct Database {
375    // The "catalog" and "schema" to be used in processing the requests at the server side.
376    // They are the "hint" or "context", just like how the "database" in "USE" statement is treated in MySQL.
377    // They will be carried in the request header.
378    catalog: String,
379    schema: String,
380    // The dbname follows naming rule as out mysql, postgres and http
381    // protocol. The server treat dbname in priority of catalog/schema.
382    dbname: String,
383    // The time zone indicates the time zone where the user is located.
384    // Some queries need to be aware of the user's time zone to perform some specific actions.
385    timezone: String,
386
387    client: Client,
388    ctx: FlightContext,
389}
390
391pub struct DatabaseClient {
392    pub addr: String,
393    pub inner: GreptimeDatabaseClient<Channel>,
394}
395
396impl DatabaseClient {
397    /// Returns a closure that logs the error when the request fails.
398    pub fn inspect_err<'a>(&'a self, context: &'a str) -> impl Fn(&tonic::Status) + 'a {
399        let addr = &self.addr;
400        move |status| {
401            error!("Failed to {context} request, peer: {addr}, status: {status:?}");
402        }
403    }
404}
405
406fn make_database_client(client: &Client) -> Result<DatabaseClient> {
407    let (addr, channel) = client.find_channel()?;
408    Ok(DatabaseClient {
409        addr,
410        inner: GreptimeDatabaseClient::new(channel)
411            .max_decoding_message_size(client.max_grpc_recv_message_size())
412            .max_encoding_message_size(client.max_grpc_send_message_size()),
413    })
414}
415
416impl Database {
417    /// Create database service client using catalog and schema
418    pub fn new(catalog: impl Into<String>, schema: impl Into<String>, client: Client) -> Self {
419        Self {
420            catalog: catalog.into(),
421            schema: schema.into(),
422            dbname: String::default(),
423            timezone: String::default(),
424            client,
425            ctx: FlightContext::default(),
426        }
427    }
428
429    /// Create database service client using dbname.
430    ///
431    /// This API is designed for external usage. `dbname` is:
432    ///
433    /// - the name of database when using GreptimeDB standalone or cluster
434    /// - the name provided by GreptimeCloud or other multi-tenant GreptimeDB
435    ///   environment
436    pub fn new_with_dbname(dbname: impl Into<String>, client: Client) -> Self {
437        Self {
438            catalog: String::default(),
439            schema: String::default(),
440            timezone: String::default(),
441            dbname: dbname.into(),
442            client,
443            ctx: FlightContext::default(),
444        }
445    }
446
447    /// Set the catalog for the database client.
448    pub fn set_catalog(&mut self, catalog: impl Into<String>) {
449        self.catalog = catalog.into();
450    }
451
452    fn catalog_or_default(&self) -> &str {
453        if self.catalog.is_empty() {
454            DEFAULT_CATALOG_NAME
455        } else {
456            &self.catalog
457        }
458    }
459
460    /// Set the schema for the database client.
461    pub fn set_schema(&mut self, schema: impl Into<String>) {
462        self.schema = schema.into();
463    }
464
465    fn schema_or_default(&self) -> &str {
466        if self.schema.is_empty() {
467            DEFAULT_SCHEMA_NAME
468        } else {
469            &self.schema
470        }
471    }
472
473    /// Set the timezone for the database client.
474    pub fn set_timezone(&mut self, timezone: impl Into<String>) {
475        self.timezone = timezone.into();
476    }
477
478    /// Set the auth scheme for the database client.
479    pub fn set_auth(&mut self, auth: AuthScheme) {
480        self.ctx.auth_header = Some(AuthHeader {
481            auth_scheme: Some(auth),
482        });
483    }
484
485    /// Make an InsertRequests request to the database.
486    pub async fn insert(&self, requests: InsertRequests) -> Result<u32> {
487        self.handle(Request::Inserts(requests)).await
488    }
489
490    /// Make an InsertRequests request to the database with hints.
491    pub async fn insert_with_hints(
492        &self,
493        requests: InsertRequests,
494        hints: &[(&str, &str)],
495    ) -> Result<u32> {
496        let mut client = make_database_client(&self.client)?;
497        let request = self.to_rpc_request(Request::Inserts(requests));
498
499        let mut request = tonic::Request::new(request);
500        let metadata = request.metadata_mut();
501        Self::put_hints(metadata, hints)?;
502
503        let response = client
504            .inner
505            .handle(request)
506            .await
507            .inspect_err(client.inspect_err("insert_with_hints"))?
508            .into_inner();
509        from_grpc_response(response)
510    }
511
512    /// Make a RowInsertRequests request to the database.
513    pub async fn row_inserts(&self, requests: RowInsertRequests) -> Result<u32> {
514        self.handle(Request::RowInserts(requests)).await
515    }
516
517    /// Make a RowInsertRequests request to the database with hints.
518    pub async fn row_inserts_with_hints(
519        &self,
520        requests: RowInsertRequests,
521        hints: &[(&str, &str)],
522    ) -> Result<u32> {
523        let mut client = make_database_client(&self.client)?;
524        let request = self.to_rpc_request(Request::RowInserts(requests));
525
526        let mut request = tonic::Request::new(request);
527        let metadata = request.metadata_mut();
528        Self::put_hints(metadata, hints)?;
529
530        let response = client
531            .inner
532            .handle(request)
533            .await
534            .inspect_err(client.inspect_err("row_inserts_with_hints"))?
535            .into_inner();
536        from_grpc_response(response)
537    }
538
539    fn put_hints(metadata: &mut MetadataMap, hints: &[(&str, &str)]) -> Result<()> {
540        let Some(value) = hints
541            .iter()
542            .map(|(k, v)| format!("{}={}", k, v))
543            .reduce(|a, b| format!("{},{}", a, b))
544        else {
545            return Ok(());
546        };
547
548        let key = AsciiMetadataKey::from_static("x-greptime-hints");
549        let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
550        metadata.insert(key, value);
551        Ok(())
552    }
553
554    fn put_flow_extensions(
555        metadata: &mut MetadataMap,
556        flow_extensions: &[(&str, &str)],
557    ) -> Result<()> {
558        if flow_extensions.is_empty() {
559            return Ok(());
560        }
561
562        let value = serde_json::to_string(&flow_extensions.to_vec())
563            .expect("flow extension pairs should serialize");
564        Self::put_metadata_value(metadata, FLOW_EXTENSIONS_METADATA_KEY, value)
565    }
566
567    fn put_snapshot_seqs(
568        metadata: &mut MetadataMap,
569        snapshot_seqs: &std::collections::HashMap<u64, u64>,
570    ) -> Result<()> {
571        if snapshot_seqs.is_empty() {
572            return Ok(());
573        }
574
575        let value =
576            serde_json::to_string(snapshot_seqs).expect("snapshot sequence map should serialize");
577        Self::put_metadata_value(metadata, SNAPSHOT_SEQS_METADATA_KEY, value)
578    }
579
580    fn put_metadata_value(
581        metadata: &mut MetadataMap,
582        key: &'static str,
583        value: String,
584    ) -> Result<()> {
585        let key = AsciiMetadataKey::from_static(key);
586        let value = AsciiMetadataValue::from_str(&value).context(InvalidTonicMetadataValueSnafu)?;
587        metadata.insert(key, value);
588        Ok(())
589    }
590
591    /// Make a request to the database.
592    pub async fn handle(&self, request: Request) -> Result<u32> {
593        let mut client = make_database_client(&self.client)?;
594        let request = self.to_rpc_request(request);
595        let response = client
596            .inner
597            .handle(request)
598            .await
599            .inspect_err(client.inspect_err("handle"))?
600            .into_inner();
601        from_grpc_response(response)
602    }
603
604    /// Retry if connection fails, max_retries is the max number of retries, so the total wait time
605    /// is `max_retries * GRPC_CONN_TIMEOUT`
606    pub async fn handle_with_retry(
607        &self,
608        request: Request,
609        max_retries: u32,
610        hints: &[(&str, &str)],
611    ) -> Result<u32> {
612        let mut client = make_database_client(&self.client)?;
613        let mut retries = 0;
614
615        let request = self.to_rpc_request(request);
616
617        loop {
618            let mut tonic_request = tonic::Request::new(request.clone());
619            let metadata = tonic_request.metadata_mut();
620            Self::put_hints(metadata, hints)?;
621            let raw_response = client
622                .inner
623                .handle(tonic_request)
624                .await
625                .inspect_err(client.inspect_err("handle"));
626            match (raw_response, retries < max_retries) {
627                (Ok(resp), _) => return from_grpc_response(resp.into_inner()),
628                (Err(err), true) => {
629                    // determine if the error is retryable
630                    if is_grpc_retryable(&err) {
631                        // retry
632                        retries += 1;
633                        warn!("Retrying {} times with error = {:?}", retries, err);
634                        continue;
635                    } else {
636                        error!(
637                            err; "Failed to send request to grpc handle, retries = {}, not retryable error, aborting",
638                            retries
639                        );
640                        return Err(err.into());
641                    }
642                }
643                (Err(err), false) => {
644                    error!(
645                        err; "Failed to send request to grpc handle after {} retries",
646                        retries,
647                    );
648                    return Err(err.into());
649                }
650            }
651        }
652    }
653
654    #[inline]
655    fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
656        GreptimeRequest {
657            header: Some(RequestHeader {
658                catalog: self.catalog.clone(),
659                schema: self.schema.clone(),
660                authorization: self.ctx.auth_header.clone(),
661                dbname: self.dbname.clone(),
662                timezone: self.timezone.clone(),
663                // TODO(Taylor-lagrange): add client grpc tracing
664                tracing_context: W3cTrace::new(),
665            }),
666            request: Some(request),
667        }
668    }
669
670    /// Executes a SQL query without any hints.
671    pub async fn sql<S>(&self, sql: S) -> Result<Output>
672    where
673        S: AsRef<str>,
674    {
675        self.sql_with_hint(sql, &[]).await
676    }
677
678    /// Executes a SQL query with optional hints for query optimization.
679    pub async fn sql_with_hint<S>(&self, sql: S, hints: &[(&str, &str)]) -> Result<Output>
680    where
681        S: AsRef<str>,
682    {
683        let request = Request::Query(QueryRequest {
684            query: Some(Query::Sql(sql.as_ref().to_string())),
685        });
686        self.do_get(request, hints, &[], &Default::default())
687            .await
688            .map(OutputWithMetrics::into_output)
689    }
690
691    /// Executes a SQL query and returns the output with terminal metrics.
692    ///
693    /// For stream outputs, callers must consume the stream before reading final
694    /// terminal metrics from [`OutputWithMetrics::metrics`].
695    pub async fn sql_with_terminal_metrics<S>(
696        &self,
697        sql: S,
698        hints: &[(&str, &str)],
699    ) -> Result<OutputWithMetrics>
700    where
701        S: AsRef<str>,
702    {
703        self.query_with_terminal_metrics_and_flow_extensions(
704            QueryRequest {
705                query: Some(Query::Sql(sql.as_ref().to_string())),
706            },
707            hints,
708            &[],
709            &Default::default(),
710        )
711        .await
712    }
713
714    /// Executes a logical plan directly without SQL parsing.
715    pub async fn logical_plan(&self, logical_plan: Vec<u8>) -> Result<Output> {
716        self.query_with_terminal_metrics_and_flow_extensions(
717            QueryRequest {
718                query: Some(Query::LogicalPlan(logical_plan)),
719            },
720            &[],
721            &[],
722            &Default::default(),
723        )
724        .await
725        .map(OutputWithMetrics::into_output)
726    }
727
728    /// Executes a query and carries flow extensions through Flight metadata.
729    ///
730    /// This is the lower-level terminal-metrics API for Flow callers that need
731    /// to pass JSON-bearing flow extensions without going through hint metadata.
732    pub async fn query_with_terminal_metrics_and_flow_extensions(
733        &self,
734        request: QueryRequest,
735        hints: &[(&str, &str)],
736        flow_extensions: &[(&str, &str)],
737        snapshot_seqs: &std::collections::HashMap<u64, u64>,
738    ) -> Result<OutputWithMetrics> {
739        self.do_get(
740            Request::Query(request),
741            hints,
742            flow_extensions,
743            snapshot_seqs,
744        )
745        .await
746    }
747
748    /// Creates a new table using the provided table expression.
749    pub async fn create(&self, expr: CreateTableExpr) -> Result<Output> {
750        let request = Request::Ddl(DdlRequest {
751            expr: Some(DdlExpr::CreateTable(expr)),
752        });
753        self.do_get(request, &[], &[], &Default::default())
754            .await
755            .map(OutputWithMetrics::into_output)
756    }
757
758    /// Alters an existing table using the provided alter expression.
759    pub async fn alter(&self, expr: AlterTableExpr) -> Result<Output> {
760        let request = Request::Ddl(DdlRequest {
761            expr: Some(DdlExpr::AlterTable(expr)),
762        });
763        self.do_get(request, &[], &[], &Default::default())
764            .await
765            .map(OutputWithMetrics::into_output)
766    }
767
768    async fn do_get(
769        &self,
770        request: Request,
771        hints: &[(&str, &str)],
772        flow_extensions: &[(&str, &str)],
773        snapshot_seqs: &std::collections::HashMap<u64, u64>,
774    ) -> Result<OutputWithMetrics> {
775        let request = self.to_rpc_request(request);
776        let request = Ticket {
777            ticket: request.encode_to_vec().into(),
778        };
779
780        let mut request = tonic::Request::new(request);
781        let metadata = request.metadata_mut();
782        Self::put_hints(metadata, hints)?;
783        Self::put_flow_extensions(metadata, flow_extensions)?;
784        Self::put_snapshot_seqs(metadata, snapshot_seqs)?;
785
786        let mut client = self.client.make_flight_client(false, false)?;
787
788        let response = client.mut_inner().do_get(request).await.or_else(|e| {
789            let tonic_code = e.code();
790            let e: Error = e.into();
791            error!(
792                "Failed to do Flight get, addr: {}, code: {}, source: {:?}",
793                client.addr(),
794                tonic_code,
795                e
796            );
797            Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
798                addr: client.addr().to_string(),
799                tonic_code,
800            })
801        })?;
802
803        let flight_data_stream = response.into_inner();
804        let mut decoder = FlightDecoder::default();
805
806        let flight_message_stream = flight_data_stream.map(move |flight_data| {
807            flight_data
808                .map_err(Error::from)
809                .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
810                .context(IllegalFlightMessagesSnafu {
811                    reason: "none message",
812                })
813        });
814
815        output_from_flight_message_stream(flight_message_stream).await
816    }
817
818    /// Ingest a stream of [RecordBatch]es that belong to a table, using Arrow Flight's "`DoPut`"
819    /// method. The return value is also a stream, produces [DoPutResponse]s.
820    pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
821        let mut request = tonic::Request::new(stream);
822
823        if let Some(AuthHeader {
824            auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
825        }) = &self.ctx.auth_header
826        {
827            let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
828            let value = MetadataValue::from_str(&format!("Basic {encoded}"))
829                .context(InvalidTonicMetadataValueSnafu)?;
830            request.metadata_mut().insert("x-greptime-auth", value);
831        }
832
833        let db_to_put = if !self.dbname.is_empty() {
834            &self.dbname
835        } else {
836            &build_db_string(self.catalog_or_default(), self.schema_or_default())
837        };
838        request.metadata_mut().insert(
839            "x-greptime-db-name",
840            MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
841        );
842
843        let mut client = self.client.make_flight_client(false, false)?;
844        let response = client.mut_inner().do_put(request).await?;
845        let response = response
846            .into_inner()
847            .map_err(Into::into)
848            .and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
849            .boxed();
850        Ok(response)
851    }
852}
853
854/// by grpc standard, only `Unavailable` is retryable, see: https://github.com/grpc/grpc/blob/master/doc/statuscodes.md#status-codes-and-their-use-in-grpc
855pub fn is_grpc_retryable(err: &tonic::Status) -> bool {
856    matches!(err.code(), tonic::Code::Unavailable)
857}
858
859#[derive(Default, Debug, Clone)]
860struct FlightContext {
861    auth_header: Option<AuthHeader>,
862}
863
// Unit tests covering the Flight metadata helpers, the flight context, tonic status
// conversion, and how terminal metrics are attached to query outputs.
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::task::{Context, Poll};

    use api::v1::auth_header::AuthScheme;
    use api::v1::{AuthHeader, Basic};
    use common_error::status_code::StatusCode;
    use common_query::OutputData;
    use common_recordbatch::{OrderOption, RecordBatch, RecordBatchStream};
    use datatypes::prelude::{ConcreteDataType, VectorRef};
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::Int32Vector;
    use futures_util::StreamExt;
    use tonic::{Code, Status};

    use super::*;
    use crate::error::TonicSnafu;

    /// Test stream that yields at most one batch and reports a fixed set of metrics.
    ///
    /// When `terminal_metrics_only` is set, `metrics()` returns `None` while the batch is
    /// still pending, so metrics only become visible after the stream has been drained.
    struct MockMetricsStream {
        schema: datatypes::schema::SchemaRef,
        // The single batch to yield; becomes `None` once it has been polled out.
        batch: Option<RecordBatch>,
        metrics: RecordBatchMetrics,
        terminal_metrics_only: bool,
    }

    impl Stream for MockMetricsStream {
        type Item = common_recordbatch::error::Result<RecordBatch>;

        // Yields the stored batch once, then ends the stream; never returns `Pending`.
        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            Poll::Ready(self.batch.take().map(Ok))
        }
    }

    impl RecordBatchStream for MockMetricsStream {
        fn name(&self) -> &str {
            "MockMetricsStream"
        }

        fn schema(&self) -> datatypes::schema::SchemaRef {
            self.schema.clone()
        }

        fn output_ordering(&self) -> Option<&[OrderOption]> {
            None
        }

        fn metrics(&self) -> Option<RecordBatchMetrics> {
            // Hide metrics until the batch has been consumed when configured as terminal-only.
            if self.terminal_metrics_only && self.batch.is_some() {
                return None;
            }
            Some(self.metrics.clone())
        }
    }

    /// Serialized metrics carrying a single watermark of 42 for region 7.
    fn terminal_metrics_json() -> String {
        terminal_metrics_json_with_seq(42)
    }

    /// Serialized `RecordBatchMetrics` JSON with one region watermark (region 7, watermark
    /// `seq`); every other field uses its default.
    fn terminal_metrics_json_with_seq(seq: u64) -> String {
        serde_json::to_string(&RecordBatchMetrics {
            region_watermarks: vec![common_recordbatch::adapter::RegionWatermarkEntry {
                region_id: 7,
                watermark: Some(seq),
            }],
            ..Default::default()
        })
        .unwrap()
    }

    // Extension values containing commas (here a JSON object) must round-trip intact through
    // the JSON-encoded flow-extensions metadata entry, in their original order.
    #[test]
    fn test_put_flow_extensions_preserves_comma_bearing_values() {
        let mut metadata = MetadataMap::new();
        Database::put_flow_extensions(
            &mut metadata,
            &[
                ("flow.return_region_seq", "true"),
                ("flow.incremental_after_seqs", r#"{"1":10,"2":20}"#),
            ],
        )
        .unwrap();

        let value = metadata
            .get(FLOW_EXTENSIONS_METADATA_KEY)
            .unwrap()
            .to_str()
            .unwrap();
        let decoded: Vec<(String, String)> = serde_json::from_str(value).unwrap();
        assert_eq!(
            decoded,
            vec![
                ("flow.return_region_seq".to_string(), "true".to_string()),
                (
                    "flow.incremental_after_seqs".to_string(),
                    r#"{"1":10,"2":20}"#.to_string()
                ),
            ]
        );
    }

    // Snapshot sequences above 2^53 (including u64::MAX) must survive the JSON metadata
    // encoding without losing precision.
    #[test]
    fn test_put_snapshot_seqs_preserves_u64_precision() {
        let mut metadata = MetadataMap::new();
        let snapshot_seqs = std::collections::HashMap::from([
            (u64::MAX, u64::MAX - 1),
            (9_007_199_254_740_993_u64, 9_007_199_254_740_995_u64),
        ]);

        Database::put_snapshot_seqs(&mut metadata, &snapshot_seqs).unwrap();

        let value = metadata
            .get(SNAPSHOT_SEQS_METADATA_KEY)
            .unwrap()
            .to_str()
            .unwrap();
        let decoded: std::collections::HashMap<u64, u64> = serde_json::from_str(value).unwrap();
        assert_eq!(decoded, snapshot_seqs);
    }

    // A default flight context has no auth header, and a basic-auth header can be stored on it.
    #[test]
    fn test_flight_ctx() {
        let mut ctx = FlightContext::default();
        assert!(ctx.auth_header.is_none());

        let basic = AuthScheme::Basic(Basic {
            username: "u".to_string(),
            password: "p".to_string(),
        });

        ctx.auth_header = Some(AuthHeader {
            auth_scheme: Some(basic),
        });

        assert!(matches!(
            ctx.auth_header,
            Some(AuthHeader {
                auth_scheme: Some(AuthScheme::Basic(_)),
            })
        ));
    }

    // Converting a tonic `Status` into the crate `Error` yields the same display text as an
    // equivalent `TonicSnafu` error built by hand.
    #[test]
    fn test_from_tonic_status() {
        let expected = TonicSnafu {
            code: StatusCode::Internal,
            msg: "blabla".to_string(),
            tonic_code: Code::Internal,
        }
        .build();

        let status = Status::new(Code::Internal, "blabla");
        let actual: Error = status.into();

        assert_eq!(expected.to_string(), actual.to_string());
    }

    // Metrics exposed only after the stream is exhausted must still be captured: the
    // `OutputWithMetrics` handle is not ready before draining and holds the watermark after.
    #[tokio::test]
    async fn test_query_with_terminal_metrics_tracks_terminal_only_metrics() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "v",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([1, 2])) as VectorRef],
        )
        .unwrap();
        let output = Output::new_with_stream(Box::pin(MockMetricsStream {
            schema,
            batch: Some(batch),
            metrics: RecordBatchMetrics {
                region_watermarks: vec![common_recordbatch::adapter::RegionWatermarkEntry {
                    region_id: 7,
                    watermark: Some(42),
                }],
                ..Default::default()
            },
            terminal_metrics_only: true,
        }));

        let result = OutputWithMetrics::from_output(output);
        let terminal_metrics = result.metrics.clone();
        assert!(!terminal_metrics.is_ready());
        assert!(terminal_metrics.get().is_none());

        let OutputData::Stream(mut stream) = result.output.data else {
            panic!("expected stream output");
        };
        // Drain the stream so the terminal metrics become available.
        while stream.next().await.is_some() {}

        assert!(terminal_metrics.is_ready());
        assert_eq!(
            terminal_metrics.participating_regions(),
            Some(std::collections::BTreeSet::from([7_u64]))
        );
        assert_eq!(
            terminal_metrics.region_watermark_map(),
            Some(std::collections::HashMap::from([(7_u64, 42_u64)]))
        );
    }

    #[test]
    fn test_parse_terminal_metrics_rejects_invalid_json() {
        assert!(parse_terminal_metrics("{not-json}").is_err());
    }

    // An `AffectedRows` message carrying inline metrics makes the metrics ready immediately.
    #[tokio::test]
    async fn test_affected_rows_inline_metrics_are_parsed() {
        let output = output_from_flight_message_stream(futures_util::stream::iter(vec![Ok(
            FlightMessage::AffectedRows {
                rows: 3,
                metrics: Some(terminal_metrics_json()),
            },
        )]
            as Vec<Result<FlightMessage>>))
        .await
        .unwrap();

        assert!(matches!(output.output.data, OutputData::AffectedRows(3)));
        assert!(output.metrics.is_ready());
        assert_eq!(
            output.metrics.region_watermark_map(),
            Some(std::collections::HashMap::from([(7, 42)]))
        );
    }

    // A separate `Metrics` message after an `AffectedRows` that already carried inline
    // metrics is rejected with an "already carries Metrics" error.
    #[tokio::test]
    async fn test_affected_rows_inline_metrics_rejects_trailing_metrics() {
        let metrics_json = terminal_metrics_json();
        let err = output_from_flight_message_stream(futures_util::stream::iter(vec![
            Ok(FlightMessage::AffectedRows {
                rows: 3,
                metrics: Some(metrics_json.clone()),
            }),
            Ok(FlightMessage::Metrics(metrics_json)),
        ]
            as Vec<Result<FlightMessage>>))
        .await
        .unwrap_err();

        assert!(
            err.to_string().contains("already carries Metrics"),
            "unexpected error: {err:?}"
        );
    }

    // A malformed terminal `Metrics` message after a batch does not drop the batch: the batch
    // is yielded first, then an external error, then end of stream. The metrics handle ends up
    // ready but empty.
    #[tokio::test]
    async fn test_invalid_terminal_metrics_after_record_batch_yields_batch_then_error() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "v",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([1])) as VectorRef],
        )
        .unwrap();
        let output = output_from_flight_message_stream(futures_util::stream::iter(vec![
            Ok(FlightMessage::Schema(schema.arrow_schema().clone())),
            Ok(FlightMessage::RecordBatch(batch.into_df_record_batch())),
            Ok(FlightMessage::Metrics("{not-json}".to_string())),
        ]
            as Vec<Result<FlightMessage>>))
        .await
        .unwrap();
        let terminal_metrics = output.metrics.clone();
        let OutputData::Stream(mut record_batch_stream) = output.output.data else {
            panic!("expected stream output");
        };

        let batch = record_batch_stream.next().await.unwrap().unwrap();
        assert_eq!(batch.num_rows(), 1);

        // The parse failure surfaces as an external error; its cause is visible in Debug output.
        let err = record_batch_stream.next().await.unwrap().unwrap_err();
        assert_eq!("External error", err.to_string());
        assert!(
            format!("{err:?}").contains("Invalid terminal metrics message"),
            "unexpected error: {err:?}"
        );
        assert!(record_batch_stream.next().await.is_none());
        assert!(terminal_metrics.is_ready());
        assert!(terminal_metrics.get().is_none());
    }

    // A `Metrics` message between two batches does not end the stream; the final metrics
    // reflect the last `Metrics` message received (watermark 2).
    #[tokio::test]
    async fn test_record_batch_stream_continues_after_partial_metrics() {
        let schema = Arc::new(Schema::new(vec![ColumnSchema::new(
            "v",
            ConcreteDataType::int32_datatype(),
            false,
        )]));
        let first_batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([1])) as VectorRef],
        )
        .unwrap();
        let second_batch = RecordBatch::new(
            schema.clone(),
            vec![Arc::new(Int32Vector::from_slice([2])) as VectorRef],
        )
        .unwrap();
        let output = output_from_flight_message_stream(futures_util::stream::iter(vec![
            Ok(FlightMessage::Schema(schema.arrow_schema().clone())),
            Ok(FlightMessage::RecordBatch(
                first_batch.into_df_record_batch(),
            )),
            Ok(FlightMessage::Metrics(terminal_metrics_json_with_seq(1))),
            Ok(FlightMessage::RecordBatch(
                second_batch.into_df_record_batch(),
            )),
            Ok(FlightMessage::Metrics(terminal_metrics_json_with_seq(2))),
        ]
            as Vec<Result<FlightMessage>>))
        .await
        .unwrap();
        let terminal_metrics = output.metrics.clone();
        let OutputData::Stream(mut record_batch_stream) = output.output.data else {
            panic!("expected stream output");
        };

        let first_batch = record_batch_stream.next().await.unwrap().unwrap();
        assert_eq!(first_batch.num_rows(), 1);
        let second_batch = record_batch_stream.next().await.unwrap().unwrap();
        assert_eq!(second_batch.num_rows(), 1);
        assert!(record_batch_stream.next().await.is_none());

        assert!(terminal_metrics.is_ready());
        assert_eq!(
            terminal_metrics.region_watermark_map(),
            Some(std::collections::HashMap::from([(7, 2)]))
        );
    }

    // Metrics with no region watermarks yield `Some(empty)` collections rather than `None`,
    // so "no regions" stays distinguishable from "no metrics".
    #[test]
    fn test_output_metrics_distinguishes_empty_region_watermarks_from_absence() {
        let metrics = OutputMetrics::default();
        metrics.update(Some(RecordBatchMetrics::default()));

        assert_eq!(
            metrics.participating_regions(),
            Some(std::collections::BTreeSet::new())
        );
        assert_eq!(
            metrics.region_watermark_map(),
            Some(std::collections::HashMap::new())
        );
    }
}