Skip to main content

common_grpc/
flight.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15pub mod do_put;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::{AffectedRows, FlightMetadata, Metrics};
21use arrow_flight::utils::flight_data_to_arrow_batch;
22use arrow_flight::{FlightData, SchemaAsIpc};
23use common_base::bytes::Bytes;
24use common_recordbatch::DfRecordBatch;
25use datatypes::arrow;
26use datatypes::arrow::array::ArrayRef;
27use datatypes::arrow::buffer::Buffer;
28use datatypes::arrow::datatypes::{DataType, Schema as ArrowSchema, SchemaRef};
29use datatypes::arrow::error::ArrowError;
30use datatypes::arrow::ipc::{MessageHeader, convert, reader, root_as_message, writer};
31use flatbuffers::FlatBufferBuilder;
32use prost::Message;
33use prost::bytes::Bytes as ProstBytes;
34use snafu::{OptionExt, ResultExt};
35use vec1::{Vec1, vec1};
36
37use crate::error;
38use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result};
39
/// Key of the Flight metadata entry that carries flow query extensions,
/// serialized as JSON key/value pairs.
pub const FLOW_EXTENSIONS_METADATA_KEY: &str = "x-greptime-flow-extensions";
42
/// One logical message of an Arrow Flight stream.
///
/// [FlightEncoder::encode] turns each variant into one or more [FlightData]s, and
/// [FlightDecoder::try_decode] turns [FlightData]s back into these variants.
#[derive(Debug, Clone)]
pub enum FlightMessage {
    /// The Arrow schema for the record batches that follow it in the stream.
    Schema(SchemaRef),
    /// A batch of rows, encoded against the most recent schema.
    RecordBatch(DfRecordBatch),
    /// A row count, optionally accompanied by a metrics string.
    ///
    /// Sent as a data-less message (IPC header type `NONE`) whose `app_metadata` holds an
    /// encoded [FlightMetadata].
    AffectedRows {
        /// Number of affected rows.
        rows: usize,
        /// Optional metrics payload; the decoder reads it back as lossy UTF-8.
        metrics: Option<String>,
    },
    /// A standalone metrics string, carried like `AffectedRows` but without a row count.
    Metrics(String),
}
53
/// Encodes [FlightMessage]s into Arrow Flight [FlightData]s.
///
/// The encoder is stateful: its dictionary tracker is carried across calls to
/// [FlightEncoder::encode].
pub struct FlightEncoder {
    /// IPC write options; `Default` enables LZ4 frame compression, while
    /// `with_compression_disabled` turns compression off.
    write_options: writer::IpcWriteOptions,
    /// Produces the IPC-encoded dictionaries and record batches.
    data_gen: writer::IpcDataGenerator,
    /// Tracks dictionary ids across encoded messages (created with `new(false)`).
    dictionary_tracker: writer::DictionaryTracker,
}
59
60impl Default for FlightEncoder {
61    fn default() -> Self {
62        let write_options = writer::IpcWriteOptions::default()
63            .try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
64            .unwrap();
65
66        Self {
67            write_options,
68            data_gen: writer::IpcDataGenerator::default(),
69            dictionary_tracker: writer::DictionaryTracker::new(false),
70        }
71    }
72}
73
74impl FlightEncoder {
75    /// Creates new [FlightEncoder] with compression disabled.
76    pub fn with_compression_disabled() -> Self {
77        let write_options = writer::IpcWriteOptions::default()
78            .try_with_compression(None)
79            .unwrap();
80
81        Self {
82            write_options,
83            data_gen: writer::IpcDataGenerator::default(),
84            dictionary_tracker: writer::DictionaryTracker::new(false),
85        }
86    }
87
88    /// Encode the Arrow schema to [FlightData].
89    pub fn encode_schema(&self, schema: &ArrowSchema) -> FlightData {
90        SchemaAsIpc::new(schema, &self.write_options).into()
91    }
92
93    /// Encode the [FlightMessage] to a list (at least one element) of [FlightData]s.
94    ///
95    /// Normally only when the [FlightMessage] is an Arrow [RecordBatch] with dictionary arrays
96    /// will the encoder produce more than one [FlightData]s. Other types of [FlightMessage] should
97    /// be encoded to exactly one [FlightData].
98    pub fn encode(&mut self, flight_message: FlightMessage) -> Vec1<FlightData> {
99        match flight_message {
100            FlightMessage::Schema(schema) => {
101                schema.fields().iter().for_each(|x| {
102                    if matches!(x.data_type(), DataType::Dictionary(_, _)) {
103                        self.dictionary_tracker.next_dict_id();
104                    }
105                });
106
107                vec1![self.encode_schema(schema.as_ref())]
108            }
109            FlightMessage::RecordBatch(record_batch) => {
110                let (encoded_dictionaries, encoded_batch) = self
111                    .data_gen
112                    .encode(
113                        &record_batch,
114                        &mut self.dictionary_tracker,
115                        &self.write_options,
116                        &mut Default::default(),
117                    )
118                    .expect("DictionaryTracker configured above to not fail on replacement");
119
120                Vec1::from_vec_push(
121                    encoded_dictionaries.into_iter().map(Into::into).collect(),
122                    encoded_batch.into(),
123                )
124            }
125            FlightMessage::AffectedRows { rows, metrics } => {
126                let metadata = FlightMetadata {
127                    affected_rows: Some(AffectedRows { value: rows as _ }),
128                    metrics: metrics.map(|s| Metrics {
129                        metrics: s.into_bytes(),
130                    }),
131                }
132                .encode_to_vec();
133                vec1![FlightData {
134                    flight_descriptor: None,
135                    data_header: build_none_flight_msg().into(),
136                    app_metadata: metadata.into(),
137                    data_body: ProstBytes::default(),
138                }]
139            }
140            FlightMessage::Metrics(s) => {
141                let metadata = FlightMetadata {
142                    affected_rows: None,
143                    metrics: Some(Metrics {
144                        metrics: s.as_bytes().to_vec(),
145                    }),
146                }
147                .encode_to_vec();
148                vec1![FlightData {
149                    flight_descriptor: None,
150                    data_header: build_none_flight_msg().into(),
151                    app_metadata: metadata.into(),
152                    data_body: ProstBytes::default(),
153                }]
154            }
155        }
156    }
157}
158
/// Decodes Arrow Flight [FlightData]s into [FlightMessage]s.
///
/// The decoder keeps the most recently decoded schema and every dictionary it has read, so the
/// messages of one stream must be fed to it in order (schema first).
#[derive(Default)]
pub struct FlightDecoder {
    /// Schema from the last `Schema` message, or from `try_from_schema_bytes`.
    schema: Option<SchemaRef>,
    /// The raw bytes the current schema was decoded from.
    schema_bytes: Option<bytes::Bytes>,
    /// Dictionary arrays read from `DictionaryBatch` messages, keyed by dictionary id.
    dictionaries_by_id: HashMap<i64, ArrayRef>,
}
165
166impl FlightDecoder {
167    /// Build a [FlightDecoder] instance from provided schema bytes.
168    pub fn try_from_schema_bytes(schema_bytes: &bytes::Bytes) -> Result<Self> {
169        let arrow_schema = convert::try_schema_from_flatbuffer_bytes(&schema_bytes[..])
170            .context(error::ArrowSnafu)?;
171        Ok(Self {
172            schema: Some(Arc::new(arrow_schema)),
173            schema_bytes: Some(schema_bytes.clone()),
174            dictionaries_by_id: HashMap::new(),
175        })
176    }
177
178    pub fn try_decode_record_batch(
179        &mut self,
180        data_header: &bytes::Bytes,
181        data_body: &bytes::Bytes,
182    ) -> Result<DfRecordBatch> {
183        let schema = self
184            .schema
185            .as_ref()
186            .context(InvalidFlightDataSnafu {
187                reason: "Should have decoded schema first!",
188            })?
189            .clone();
190        let message = root_as_message(&data_header[..])
191            .map_err(|err| {
192                ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
193            })
194            .context(error::ArrowSnafu)?;
195        let result = message
196            .header_as_record_batch()
197            .ok_or_else(|| {
198                ArrowError::ParseError(
199                    "Unable to convert flight data header to a record batch".to_string(),
200                )
201            })
202            .and_then(|batch| {
203                reader::read_record_batch(
204                    &Buffer::from(data_body.as_ref()),
205                    batch,
206                    schema,
207                    &HashMap::new(),
208                    None,
209                    &message.version(),
210                )
211            })
212            .context(error::ArrowSnafu)?;
213        Ok(result)
214    }
215
216    /// Try to decode the [FlightData] to a [FlightMessage].
217    ///
218    /// If the [FlightData] is of type `DictionaryBatch` (produced while encoding an Arrow
219    /// [RecordBatch] with dictionary arrays), the decoder will not return any [FlightMessage]s.
220    /// Instead, it will update its internal dictionary cache. Other types of [FlightData] will
221    /// be decoded to exactly one [FlightMessage].
222    pub fn try_decode(&mut self, flight_data: &FlightData) -> Result<Option<FlightMessage>> {
223        let message = root_as_message(&flight_data.data_header).map_err(|e| {
224            InvalidFlightDataSnafu {
225                reason: e.to_string(),
226            }
227            .build()
228        })?;
229        match message.header_type() {
230            MessageHeader::NONE => {
231                let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
232                    .context(DecodeFlightDataSnafu)?;
233                if let Some(AffectedRows { value }) = metadata.affected_rows {
234                    return Ok(Some(FlightMessage::AffectedRows {
235                        rows: value as _,
236                        metrics: metadata
237                            .metrics
238                            .map(|m| String::from_utf8_lossy(&m.metrics).to_string()),
239                    }));
240                }
241                if let Some(Metrics { metrics }) = metadata.metrics {
242                    return Ok(Some(FlightMessage::Metrics(
243                        String::from_utf8_lossy(&metrics).to_string(),
244                    )));
245                }
246                InvalidFlightDataSnafu {
247                    reason: "Expecting FlightMetadata have some meaningful content.",
248                }
249                .fail()
250            }
251            MessageHeader::Schema => {
252                let arrow_schema = Arc::new(ArrowSchema::try_from(flight_data).map_err(|e| {
253                    InvalidFlightDataSnafu {
254                        reason: e.to_string(),
255                    }
256                    .build()
257                })?);
258                self.schema = Some(arrow_schema.clone());
259                self.schema_bytes = Some(flight_data.data_header.clone());
260                Ok(Some(FlightMessage::Schema(arrow_schema)))
261            }
262            MessageHeader::RecordBatch => {
263                let schema = self.schema.clone().context(InvalidFlightDataSnafu {
264                    reason: "Should have decoded schema first!",
265                })?;
266                let arrow_batch = flight_data_to_arrow_batch(
267                    flight_data,
268                    schema.clone(),
269                    &self.dictionaries_by_id,
270                )
271                .map_err(|e| {
272                    InvalidFlightDataSnafu {
273                        reason: e.to_string(),
274                    }
275                    .build()
276                })?;
277                Ok(Some(FlightMessage::RecordBatch(arrow_batch)))
278            }
279            MessageHeader::DictionaryBatch => {
280                let dictionary_batch =
281                    message
282                        .header_as_dictionary_batch()
283                        .context(InvalidFlightDataSnafu {
284                            reason: "could not get dictionary batch from DictionaryBatch message",
285                        })?;
286
287                let schema = self.schema.as_ref().context(InvalidFlightDataSnafu {
288                    reason: "schema message is not present previously",
289                })?;
290
291                reader::read_dictionary(
292                    &flight_data.data_body.clone().into(),
293                    dictionary_batch,
294                    schema,
295                    &mut self.dictionaries_by_id,
296                    &message.version(),
297                )
298                .context(error::ArrowSnafu)?;
299                Ok(None)
300            }
301            other => {
302                let name = other.variant_name().unwrap_or("UNKNOWN");
303                InvalidFlightDataSnafu {
304                    reason: format!("Unsupported FlightData type: {name}"),
305                }
306                .fail()
307            }
308        }
309    }
310
311    pub fn schema(&self) -> Option<&SchemaRef> {
312        self.schema.as_ref()
313    }
314
315    pub fn schema_bytes(&self) -> Option<bytes::Bytes> {
316        self.schema_bytes.clone()
317    }
318}
319
320pub fn flight_messages_to_recordbatches(
321    messages: Vec<FlightMessage>,
322) -> Result<Vec<DfRecordBatch>> {
323    if messages.is_empty() {
324        Ok(vec![])
325    } else {
326        let mut recordbatches = Vec::with_capacity(messages.len() - 1);
327
328        match &messages[0] {
329            FlightMessage::Schema(_schema) => {}
330            _ => {
331                return InvalidFlightDataSnafu {
332                    reason: "First Flight Message must be schema!",
333                }
334                .fail();
335            }
336        };
337
338        for message in messages.into_iter().skip(1) {
339            match message {
340                FlightMessage::RecordBatch(recordbatch) => recordbatches.push(recordbatch),
341                _ => {
342                    return InvalidFlightDataSnafu {
343                        reason: "Expect the following Flight Messages are all Recordbatches!",
344                    }
345                    .fail();
346                }
347            }
348        }
349
350        Ok(recordbatches)
351    }
352}
353
354fn build_none_flight_msg() -> Bytes {
355    let mut builder = FlatBufferBuilder::new();
356
357    let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
358    message.add_version(arrow::ipc::MetadataVersion::V5);
359    message.add_header_type(MessageHeader::NONE);
360    message.add_bodyLength(0);
361
362    let data = message.finish();
363    builder.finish(data, None);
364
365    builder.finished_data().into()
366}
367
#[cfg(test)]
mod test {
    // Round-trip tests for FlightEncoder / FlightDecoder and the helpers in this module.
    use arrow_flight::utils::batches_to_flight_data;
    use datatypes::arrow::array::{
        DictionaryArray, Int32Array, StringArray, UInt8Array, UInt32Array,
    };
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::Error;

    /// Decodes flight data produced by arrow-flight's `batches_to_flight_data`.
    ///
    /// Checks that a record batch arriving before the schema is rejected, and that the schema
    /// and both batches then decode unchanged.
    #[test]
    fn test_try_decode() -> Result<()> {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "n",
            DataType::Int32,
            true,
        )]));

        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as _],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![None, Some(5)])) as _],
        )
        .unwrap();

        // Expected layout: one schema message followed by one message per batch.
        let flight_data =
            batches_to_flight_data(&schema, vec![batch1.clone(), batch2.clone()]).unwrap();
        assert_eq!(flight_data.len(), 3);
        let [d1, d2, d3] = flight_data.as_slice() else {
            unreachable!()
        };

        let decoder = &mut FlightDecoder::default();
        assert!(decoder.schema.is_none());

        // A record batch before any schema must be rejected.
        let result = decoder.try_decode(d2);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Should have decoded schema first!")
        );

        let message = decoder.try_decode(d1)?.unwrap();
        assert!(matches!(message, FlightMessage::Schema(_)));
        let FlightMessage::Schema(decoded_schema) = message else {
            unreachable!()
        };
        assert_eq!(decoded_schema, schema);

        // Decoding the schema message stores it in the decoder.
        let _ = decoder.schema.as_ref().unwrap();

        let message = decoder.try_decode(d2)?.unwrap();
        assert!(matches!(message, FlightMessage::RecordBatch(_)));
        let FlightMessage::RecordBatch(actual_batch) = message else {
            unreachable!()
        };
        assert_eq!(actual_batch, batch1);

        let message = decoder.try_decode(d3)?.unwrap();
        assert!(matches!(message, FlightMessage::RecordBatch(_)));
        let FlightMessage::RecordBatch(actual_batch) = message else {
            unreachable!()
        };
        assert_eq!(actual_batch, batch2);
        Ok(())
    }

    /// `AffectedRows` round-trips through the encoder and decoder with and without metrics.
    #[test]
    fn test_affected_rows_metrics_encode_decode() -> Result<()> {
        let metrics = r#"{"region_watermarks":[{"region_id":42,"watermark":7}]}"#;
        let mut encoder = FlightEncoder::default();
        let encoded = encoder.encode(FlightMessage::AffectedRows {
            rows: 3,
            metrics: Some(metrics.to_string()),
        });

        // Metadata-only messages are encoded to exactly one FlightData.
        assert_eq!(encoded.len(), 1);

        let mut decoder = FlightDecoder::default();
        let decoded = decoder.try_decode(encoded.first())?.unwrap();
        let FlightMessage::AffectedRows {
            rows,
            metrics: decoded_metrics,
        } = decoded
        else {
            unreachable!()
        };
        assert_eq!(rows, 3);
        assert_eq!(decoded_metrics.as_deref(), Some(metrics));

        let encoded = encoder.encode(FlightMessage::AffectedRows {
            rows: 5,
            metrics: None,
        });
        let decoded = decoder.try_decode(encoded.first())?.unwrap();
        let FlightMessage::AffectedRows {
            rows,
            metrics: decoded_metrics,
        } = decoded
        else {
            unreachable!()
        };
        assert_eq!(rows, 5);
        assert!(decoded_metrics.is_none());

        Ok(())
    }

    /// `flight_messages_to_recordbatches` requires a leading schema followed only by
    /// record batches.
    #[test]
    fn test_flight_messages_to_recordbatches() {
        let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)]));
        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(2), None, Some(4)])) as _],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![None, Some(6)])) as _],
        )
        .unwrap();
        let recordbatches = vec![batch1.clone(), batch2.clone()];

        let m1 = FlightMessage::Schema(schema);
        let m2 = FlightMessage::RecordBatch(batch1);
        let m3 = FlightMessage::RecordBatch(batch2);

        // First message is not a schema.
        let result = flight_messages_to_recordbatches(vec![m2.clone(), m1.clone(), m3.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("First Flight Message must be schema!")
        );

        // A schema appears after the first message.
        let result = flight_messages_to_recordbatches(vec![m1.clone(), m2.clone(), m1.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Expect the following Flight Messages are all Recordbatches!")
        );

        let actual = flight_messages_to_recordbatches(vec![m1, m2, m3]).unwrap();
        assert_eq!(actual, recordbatches);
    }

    /// Record batches with dictionary columns encode to a dictionary message plus a batch
    /// message, and decode back to the original values.
    #[test]
    fn test_flight_encode_decode_with_dictionary_array() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::UInt8, true),
            Field::new_dictionary("s", DataType::UInt32, DataType::Utf8, true),
        ]));
        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt8Array::from_iter_values(vec![1, 2, 3])) as _,
                Arc::new(DictionaryArray::new(
                    UInt32Array::from_value(0, 3),
                    Arc::new(StringArray::from_iter_values(["x"])),
                )) as _,
            ],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt8Array::from_iter_values(vec![4, 5, 6, 7, 8])) as _,
                Arc::new(DictionaryArray::new(
                    UInt32Array::from_iter_values([0, 1, 2, 2, 3]),
                    Arc::new(StringArray::from_iter_values(["h", "e", "l", "o"])),
                )) as _,
            ],
        )
        .unwrap();

        let message_1 = FlightMessage::Schema(schema.clone());
        let message_2 = FlightMessage::RecordBatch(batch1);
        let message_3 = FlightMessage::RecordBatch(batch2);

        let mut encoder = FlightEncoder::default();
        let encoded_1 = encoder.encode(message_1);
        let encoded_2 = encoder.encode(message_2);
        let encoded_3 = encoder.encode(message_3);
        // message 1 is Arrow Schema, should be encoded to one FlightData:
        assert_eq!(encoded_1.len(), 1);
        // message 2 and 3 are Arrow RecordBatch with dictionary arrays, should be encoded to
        // multiple FlightData:
        assert_eq!(encoded_2.len(), 2);
        assert_eq!(encoded_3.len(), 2);

        let mut decoder = FlightDecoder::default();
        let decoded_1 = decoder.try_decode(encoded_1.first())?;
        let Some(FlightMessage::Schema(actual_schema)) = decoded_1 else {
            unreachable!()
        };
        assert_eq!(actual_schema, schema);
        let decoded_2 = decoder.try_decode(&encoded_2[0])?;
        // expected to be a dictionary batch message, decoder should return none:
        assert!(decoded_2.is_none());
        let Some(FlightMessage::RecordBatch(decoded_2)) = decoder.try_decode(&encoded_2[1])? else {
            unreachable!()
        };
        let decoded_3 = decoder.try_decode(&encoded_3[0])?;
        // expected to be a dictionary batch message, decoder should return none:
        assert!(decoded_3.is_none());
        let Some(FlightMessage::RecordBatch(decoded_3)) = decoder.try_decode(&encoded_3[1])? else {
            unreachable!()
        };
        // Both batches must decode with their own dictionaries.
        let actual = arrow::util::pretty::pretty_format_batches(&[decoded_2, decoded_3])
            .unwrap()
            .to_string();
        let expected = r"
+---+---+
| i | s |
+---+---+
| 1 | x |
| 2 | x |
| 3 | x |
| 4 | h |
| 5 | e |
| 6 | l |
| 7 | l |
| 8 | o |
+---+---+";
        assert_eq!(actual, expected.trim());
        Ok(())
    }

    #[test]
    fn test_affected_rows_roundtrip_through_flight_codec() {
        // Verify the full FlightEncoder → FlightDecoder pipeline handles
        // the new FlightMessage::AffectedRows variant with optional inline
        // metrics without breaking the wire protocol.
        let mut encoder = FlightEncoder::default();
        let mut decoder = FlightDecoder::default();

        // Without metrics — same wire format as old `AffectedRows(7)`.
        let encoded = encoder.encode(FlightMessage::AffectedRows {
            rows: 7,
            metrics: None,
        });
        let decoded = decoder.try_decode(encoded.first()).unwrap().unwrap();
        assert!(matches!(
            decoded,
            FlightMessage::AffectedRows {
                rows: 7,
                metrics: None,
            }
        ));

        // With metrics — new capability, row count preserved.
        let json = r#"{"region_watermarks":[{"region_id":1,"watermark":99}]}"#;
        let encoded = encoder.encode(FlightMessage::AffectedRows {
            rows: 42,
            metrics: Some(json.to_string()),
        });
        let decoded = decoder.try_decode(encoded.first()).unwrap().unwrap();
        assert!(matches!(
            decoded,
            FlightMessage::AffectedRows {
                rows: 42,
                metrics: Some(_),
            }
        ));
    }

    /// Simulates the wire output of the **old** `FlightMessage::AffectedRows(usize)`
    /// variant and verifies that the **new** `FlightDecoder` handles it.
    #[test]
    fn test_old_affected_rows_format_decoded_by_new_code() {
        use arrow_flight::FlightData;
        use prost::bytes::Bytes as ProstBytes;

        // The old encoder produced FlightData whose app_metadata is
        // FlightMetadata { affected_rows, metrics: None }. The new
        // `AffectedRows { rows, metrics: Option<String> }` variant with
        // `metrics: None` produces the exact same wire bytes.
        let old_wire_bytes = FlightData {
            flight_descriptor: None,
            data_header: build_none_flight_msg().into(),
            app_metadata: FlightMetadata {
                affected_rows: Some(AffectedRows { value: 99 }),
                metrics: None, // old format: no metrics field
            }
            .encode_to_vec()
            .into(),
            data_body: ProstBytes::default(),
        };

        let mut decoder = FlightDecoder::default();
        let decoded = decoder.try_decode(&old_wire_bytes).unwrap().unwrap();
        assert!(matches!(
            decoded,
            FlightMessage::AffectedRows {
                rows: 99,
                metrics: None,
            }
        ));
    }
}