1pub mod do_put;
16
17use std::collections::HashMap;
18use std::sync::Arc;
19
20use api::v1::{AffectedRows, FlightMetadata, Metrics};
21use arrow_flight::utils::flight_data_to_arrow_batch;
22use arrow_flight::{FlightData, SchemaAsIpc};
23use common_base::bytes::Bytes;
24use common_recordbatch::DfRecordBatch;
25use datatypes::arrow;
26use datatypes::arrow::array::ArrayRef;
27use datatypes::arrow::buffer::Buffer;
28use datatypes::arrow::datatypes::{DataType, Schema as ArrowSchema, SchemaRef};
29use datatypes::arrow::error::ArrowError;
30use datatypes::arrow::ipc::{MessageHeader, convert, reader, root_as_message, writer};
31use flatbuffers::FlatBufferBuilder;
32use prost::Message;
33use prost::bytes::Bytes as ProstBytes;
34use snafu::{OptionExt, ResultExt};
35use vec1::{Vec1, vec1};
36
37use crate::error;
38use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result};
39
40pub const FLOW_EXTENSIONS_METADATA_KEY: &str = "x-greptime-flow-extensions";
42
43#[derive(Debug, Clone)]
44pub enum FlightMessage {
45 Schema(SchemaRef),
46 RecordBatch(DfRecordBatch),
47 AffectedRows {
48 rows: usize,
49 metrics: Option<String>,
50 },
51 Metrics(String),
52}
53
54pub struct FlightEncoder {
55 write_options: writer::IpcWriteOptions,
56 data_gen: writer::IpcDataGenerator,
57 dictionary_tracker: writer::DictionaryTracker,
58}
59
60impl Default for FlightEncoder {
61 fn default() -> Self {
62 let write_options = writer::IpcWriteOptions::default()
63 .try_with_compression(Some(arrow::ipc::CompressionType::LZ4_FRAME))
64 .unwrap();
65
66 Self {
67 write_options,
68 data_gen: writer::IpcDataGenerator::default(),
69 dictionary_tracker: writer::DictionaryTracker::new(false),
70 }
71 }
72}
73
74impl FlightEncoder {
75 pub fn with_compression_disabled() -> Self {
77 let write_options = writer::IpcWriteOptions::default()
78 .try_with_compression(None)
79 .unwrap();
80
81 Self {
82 write_options,
83 data_gen: writer::IpcDataGenerator::default(),
84 dictionary_tracker: writer::DictionaryTracker::new(false),
85 }
86 }
87
88 pub fn encode_schema(&self, schema: &ArrowSchema) -> FlightData {
90 SchemaAsIpc::new(schema, &self.write_options).into()
91 }
92
93 pub fn encode(&mut self, flight_message: FlightMessage) -> Vec1<FlightData> {
99 match flight_message {
100 FlightMessage::Schema(schema) => {
101 schema.fields().iter().for_each(|x| {
102 if matches!(x.data_type(), DataType::Dictionary(_, _)) {
103 self.dictionary_tracker.next_dict_id();
104 }
105 });
106
107 vec1![self.encode_schema(schema.as_ref())]
108 }
109 FlightMessage::RecordBatch(record_batch) => {
110 let (encoded_dictionaries, encoded_batch) = self
111 .data_gen
112 .encode(
113 &record_batch,
114 &mut self.dictionary_tracker,
115 &self.write_options,
116 &mut Default::default(),
117 )
118 .expect("DictionaryTracker configured above to not fail on replacement");
119
120 Vec1::from_vec_push(
121 encoded_dictionaries.into_iter().map(Into::into).collect(),
122 encoded_batch.into(),
123 )
124 }
125 FlightMessage::AffectedRows { rows, metrics } => {
126 let metadata = FlightMetadata {
127 affected_rows: Some(AffectedRows { value: rows as _ }),
128 metrics: metrics.map(|s| Metrics {
129 metrics: s.into_bytes(),
130 }),
131 }
132 .encode_to_vec();
133 vec1![FlightData {
134 flight_descriptor: None,
135 data_header: build_none_flight_msg().into(),
136 app_metadata: metadata.into(),
137 data_body: ProstBytes::default(),
138 }]
139 }
140 FlightMessage::Metrics(s) => {
141 let metadata = FlightMetadata {
142 affected_rows: None,
143 metrics: Some(Metrics {
144 metrics: s.as_bytes().to_vec(),
145 }),
146 }
147 .encode_to_vec();
148 vec1![FlightData {
149 flight_descriptor: None,
150 data_header: build_none_flight_msg().into(),
151 app_metadata: metadata.into(),
152 data_body: ProstBytes::default(),
153 }]
154 }
155 }
156 }
157}
158
159#[derive(Default)]
160pub struct FlightDecoder {
161 schema: Option<SchemaRef>,
162 schema_bytes: Option<bytes::Bytes>,
163 dictionaries_by_id: HashMap<i64, ArrayRef>,
164}
165
166impl FlightDecoder {
167 pub fn try_from_schema_bytes(schema_bytes: &bytes::Bytes) -> Result<Self> {
169 let arrow_schema = convert::try_schema_from_flatbuffer_bytes(&schema_bytes[..])
170 .context(error::ArrowSnafu)?;
171 Ok(Self {
172 schema: Some(Arc::new(arrow_schema)),
173 schema_bytes: Some(schema_bytes.clone()),
174 dictionaries_by_id: HashMap::new(),
175 })
176 }
177
178 pub fn try_decode_record_batch(
179 &mut self,
180 data_header: &bytes::Bytes,
181 data_body: &bytes::Bytes,
182 ) -> Result<DfRecordBatch> {
183 let schema = self
184 .schema
185 .as_ref()
186 .context(InvalidFlightDataSnafu {
187 reason: "Should have decoded schema first!",
188 })?
189 .clone();
190 let message = root_as_message(&data_header[..])
191 .map_err(|err| {
192 ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))
193 })
194 .context(error::ArrowSnafu)?;
195 let result = message
196 .header_as_record_batch()
197 .ok_or_else(|| {
198 ArrowError::ParseError(
199 "Unable to convert flight data header to a record batch".to_string(),
200 )
201 })
202 .and_then(|batch| {
203 reader::read_record_batch(
204 &Buffer::from(data_body.as_ref()),
205 batch,
206 schema,
207 &HashMap::new(),
208 None,
209 &message.version(),
210 )
211 })
212 .context(error::ArrowSnafu)?;
213 Ok(result)
214 }
215
216 pub fn try_decode(&mut self, flight_data: &FlightData) -> Result<Option<FlightMessage>> {
223 let message = root_as_message(&flight_data.data_header).map_err(|e| {
224 InvalidFlightDataSnafu {
225 reason: e.to_string(),
226 }
227 .build()
228 })?;
229 match message.header_type() {
230 MessageHeader::NONE => {
231 let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
232 .context(DecodeFlightDataSnafu)?;
233 if let Some(AffectedRows { value }) = metadata.affected_rows {
234 return Ok(Some(FlightMessage::AffectedRows {
235 rows: value as _,
236 metrics: metadata
237 .metrics
238 .map(|m| String::from_utf8_lossy(&m.metrics).to_string()),
239 }));
240 }
241 if let Some(Metrics { metrics }) = metadata.metrics {
242 return Ok(Some(FlightMessage::Metrics(
243 String::from_utf8_lossy(&metrics).to_string(),
244 )));
245 }
246 InvalidFlightDataSnafu {
247 reason: "Expecting FlightMetadata have some meaningful content.",
248 }
249 .fail()
250 }
251 MessageHeader::Schema => {
252 let arrow_schema = Arc::new(ArrowSchema::try_from(flight_data).map_err(|e| {
253 InvalidFlightDataSnafu {
254 reason: e.to_string(),
255 }
256 .build()
257 })?);
258 self.schema = Some(arrow_schema.clone());
259 self.schema_bytes = Some(flight_data.data_header.clone());
260 Ok(Some(FlightMessage::Schema(arrow_schema)))
261 }
262 MessageHeader::RecordBatch => {
263 let schema = self.schema.clone().context(InvalidFlightDataSnafu {
264 reason: "Should have decoded schema first!",
265 })?;
266 let arrow_batch = flight_data_to_arrow_batch(
267 flight_data,
268 schema.clone(),
269 &self.dictionaries_by_id,
270 )
271 .map_err(|e| {
272 InvalidFlightDataSnafu {
273 reason: e.to_string(),
274 }
275 .build()
276 })?;
277 Ok(Some(FlightMessage::RecordBatch(arrow_batch)))
278 }
279 MessageHeader::DictionaryBatch => {
280 let dictionary_batch =
281 message
282 .header_as_dictionary_batch()
283 .context(InvalidFlightDataSnafu {
284 reason: "could not get dictionary batch from DictionaryBatch message",
285 })?;
286
287 let schema = self.schema.as_ref().context(InvalidFlightDataSnafu {
288 reason: "schema message is not present previously",
289 })?;
290
291 reader::read_dictionary(
292 &flight_data.data_body.clone().into(),
293 dictionary_batch,
294 schema,
295 &mut self.dictionaries_by_id,
296 &message.version(),
297 )
298 .context(error::ArrowSnafu)?;
299 Ok(None)
300 }
301 other => {
302 let name = other.variant_name().unwrap_or("UNKNOWN");
303 InvalidFlightDataSnafu {
304 reason: format!("Unsupported FlightData type: {name}"),
305 }
306 .fail()
307 }
308 }
309 }
310
311 pub fn schema(&self) -> Option<&SchemaRef> {
312 self.schema.as_ref()
313 }
314
315 pub fn schema_bytes(&self) -> Option<bytes::Bytes> {
316 self.schema_bytes.clone()
317 }
318}
319
320pub fn flight_messages_to_recordbatches(
321 messages: Vec<FlightMessage>,
322) -> Result<Vec<DfRecordBatch>> {
323 if messages.is_empty() {
324 Ok(vec![])
325 } else {
326 let mut recordbatches = Vec::with_capacity(messages.len() - 1);
327
328 match &messages[0] {
329 FlightMessage::Schema(_schema) => {}
330 _ => {
331 return InvalidFlightDataSnafu {
332 reason: "First Flight Message must be schema!",
333 }
334 .fail();
335 }
336 };
337
338 for message in messages.into_iter().skip(1) {
339 match message {
340 FlightMessage::RecordBatch(recordbatch) => recordbatches.push(recordbatch),
341 _ => {
342 return InvalidFlightDataSnafu {
343 reason: "Expect the following Flight Messages are all Recordbatches!",
344 }
345 .fail();
346 }
347 }
348 }
349
350 Ok(recordbatches)
351 }
352}
353
354fn build_none_flight_msg() -> Bytes {
355 let mut builder = FlatBufferBuilder::new();
356
357 let mut message = arrow::ipc::MessageBuilder::new(&mut builder);
358 message.add_version(arrow::ipc::MetadataVersion::V5);
359 message.add_header_type(MessageHeader::NONE);
360 message.add_bodyLength(0);
361
362 let data = message.finish();
363 builder.finish(data, None);
364
365 builder.finished_data().into()
366}
367
#[cfg(test)]
mod test {
    use arrow_flight::utils::batches_to_flight_data;
    use datatypes::arrow::array::{
        DictionaryArray, Int32Array, StringArray, UInt8Array, UInt32Array,
    };
    use datatypes::arrow::datatypes::{DataType, Field, Schema};

    use super::*;
    use crate::Error;

    /// A schema frame must be decoded before record batches; batches then decode in order.
    #[test]
    fn test_try_decode() -> Result<()> {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "n",
            DataType::Int32,
            true,
        )]));

        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as _],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![None, Some(5)])) as _],
        )
        .unwrap();

        let flight_data =
            batches_to_flight_data(&schema, vec![batch1.clone(), batch2.clone()]).unwrap();
        assert_eq!(flight_data.len(), 3);
        let [d1, d2, d3] = flight_data.as_slice() else {
            unreachable!()
        };

        let decoder = &mut FlightDecoder::default();
        assert!(decoder.schema().is_none());
        assert!(decoder.schema_bytes().is_none());

        let result = decoder.try_decode(d2);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Should have decoded schema first!")
        );

        let Some(FlightMessage::Schema(decoded_schema)) = decoder.try_decode(d1)? else {
            unreachable!()
        };
        assert_eq!(decoded_schema, schema);
        assert_eq!(decoder.schema(), Some(&schema));
        assert!(decoder.schema_bytes().is_some());

        let Some(FlightMessage::RecordBatch(actual_batch)) = decoder.try_decode(d2)? else {
            unreachable!()
        };
        assert_eq!(actual_batch, batch1);

        let Some(FlightMessage::RecordBatch(actual_batch)) = decoder.try_decode(d3)? else {
            unreachable!()
        };
        assert_eq!(actual_batch, batch2);
        Ok(())
    }

    /// A decoder built from raw schema bytes can decode raw record-batch header/body bytes.
    #[test]
    fn test_try_decode_record_batch_from_schema_bytes() -> Result<()> {
        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "n",
            DataType::Int32,
            true,
        )]));
        let batch = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(7), None, Some(9)])) as _],
        )
        .unwrap();

        let flight_data = batches_to_flight_data(&schema, vec![batch.clone()]).unwrap();
        let [schema_data, batch_data] = flight_data.as_slice() else {
            unreachable!()
        };

        // Without a schema, raw record batches cannot be decoded.
        let mut empty = FlightDecoder::default();
        assert!(matches!(
            empty.try_decode_record_batch(&batch_data.data_header, &batch_data.data_body),
            Err(Error::InvalidFlightData { .. })
        ));

        let mut decoder = FlightDecoder::try_from_schema_bytes(&schema_data.data_header)?;
        assert_eq!(decoder.schema(), Some(&schema));
        let decoded =
            decoder.try_decode_record_batch(&batch_data.data_header, &batch_data.data_body)?;
        assert_eq!(decoded, batch);
        Ok(())
    }

    #[test]
    fn test_affected_rows_metrics_encode_decode() -> Result<()> {
        let metrics = r#"{"region_watermarks":[{"region_id":42,"watermark":7}]}"#;
        let mut encoder = FlightEncoder::default();
        let encoded = encoder.encode(FlightMessage::AffectedRows {
            rows: 3,
            metrics: Some(metrics.to_string()),
        });

        assert_eq!(encoded.len(), 1);

        let mut decoder = FlightDecoder::default();
        let decoded = decoder.try_decode(encoded.first())?.unwrap();
        let FlightMessage::AffectedRows {
            rows,
            metrics: decoded_metrics,
        } = decoded
        else {
            unreachable!()
        };
        assert_eq!(rows, 3);
        assert_eq!(decoded_metrics.as_deref(), Some(metrics));

        let encoded = encoder.encode(FlightMessage::AffectedRows {
            rows: 5,
            metrics: None,
        });
        let decoded = decoder.try_decode(encoded.first())?.unwrap();
        let FlightMessage::AffectedRows {
            rows,
            metrics: decoded_metrics,
        } = decoded
        else {
            unreachable!()
        };
        assert_eq!(rows, 5);
        assert!(decoded_metrics.is_none());

        Ok(())
    }

    /// A metrics-only message survives an encode/decode roundtrip unchanged.
    #[test]
    fn test_metrics_only_encode_decode() -> Result<()> {
        let metrics = r#"{"region_watermarks":[{"region_id":3,"watermark":11}]}"#;
        let mut encoder = FlightEncoder::default();
        let encoded = encoder.encode(FlightMessage::Metrics(metrics.to_string()));
        assert_eq!(encoded.len(), 1);

        let mut decoder = FlightDecoder::default();
        let Some(FlightMessage::Metrics(decoded_metrics)) = decoder.try_decode(encoded.first())?
        else {
            unreachable!()
        };
        assert_eq!(decoded_metrics, metrics);
        Ok(())
    }

    /// A `NONE` frame whose metadata has neither affected rows nor metrics is rejected.
    #[test]
    fn test_decode_empty_metadata_is_rejected() {
        let flight_data = FlightData {
            flight_descriptor: None,
            data_header: build_none_flight_msg().into(),
            app_metadata: FlightMetadata::default().encode_to_vec().into(),
            data_body: ProstBytes::default(),
        };

        let result = FlightDecoder::default().try_decode(&flight_data);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Expecting FlightMetadata have some meaningful content.")
        );
    }

    #[test]
    fn test_flight_messages_to_recordbatches() {
        let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)]));
        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![Some(2), None, Some(4)])) as _],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![None, Some(6)])) as _],
        )
        .unwrap();
        let recordbatches = vec![batch1.clone(), batch2.clone()];

        let m1 = FlightMessage::Schema(schema);
        let m2 = FlightMessage::RecordBatch(batch1);
        let m3 = FlightMessage::RecordBatch(batch2);

        let result = flight_messages_to_recordbatches(vec![m2.clone(), m1.clone(), m3.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("First Flight Message must be schema!")
        );

        let result = flight_messages_to_recordbatches(vec![m1.clone(), m2.clone(), m1.clone()]);
        assert!(matches!(result, Err(Error::InvalidFlightData { .. })));
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Expect the following Flight Messages are all Recordbatches!")
        );

        assert!(flight_messages_to_recordbatches(vec![]).unwrap().is_empty());

        let actual = flight_messages_to_recordbatches(vec![m1, m2, m3]).unwrap();
        assert_eq!(actual, recordbatches);
    }

    #[test]
    fn test_flight_encode_decode_with_dictionary_array() -> Result<()> {
        let schema = Arc::new(Schema::new(vec![
            Field::new("i", DataType::UInt8, true),
            Field::new_dictionary("s", DataType::UInt32, DataType::Utf8, true),
        ]));
        let batch1 = DfRecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt8Array::from_iter_values(vec![1, 2, 3])) as _,
                Arc::new(DictionaryArray::new(
                    UInt32Array::from_value(0, 3),
                    Arc::new(StringArray::from_iter_values(["x"])),
                )) as _,
            ],
        )
        .unwrap();
        let batch2 = DfRecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt8Array::from_iter_values(vec![4, 5, 6, 7, 8])) as _,
                Arc::new(DictionaryArray::new(
                    UInt32Array::from_iter_values([0, 1, 2, 2, 3]),
                    Arc::new(StringArray::from_iter_values(["h", "e", "l", "o"])),
                )) as _,
            ],
        )
        .unwrap();

        let message_1 = FlightMessage::Schema(schema.clone());
        let message_2 = FlightMessage::RecordBatch(batch1);
        let message_3 = FlightMessage::RecordBatch(batch2);

        let mut encoder = FlightEncoder::default();
        let encoded_1 = encoder.encode(message_1);
        let encoded_2 = encoder.encode(message_2);
        let encoded_3 = encoder.encode(message_3);
        // Schema: one frame. Each batch: one dictionary frame followed by the batch frame.
        assert_eq!(encoded_1.len(), 1);
        assert_eq!(encoded_2.len(), 2);
        assert_eq!(encoded_3.len(), 2);

        let mut decoder = FlightDecoder::default();
        let decoded_1 = decoder.try_decode(encoded_1.first())?;
        let Some(FlightMessage::Schema(actual_schema)) = decoded_1 else {
            unreachable!()
        };
        assert_eq!(actual_schema, schema);
        // Dictionary frames only update decoder state.
        let decoded_2 = decoder.try_decode(&encoded_2[0])?;
        assert!(decoded_2.is_none());
        let Some(FlightMessage::RecordBatch(decoded_2)) = decoder.try_decode(&encoded_2[1])? else {
            unreachable!()
        };
        let decoded_3 = decoder.try_decode(&encoded_3[0])?;
        assert!(decoded_3.is_none());
        let Some(FlightMessage::RecordBatch(decoded_3)) = decoder.try_decode(&encoded_3[1])? else {
            unreachable!()
        };
        let actual = arrow::util::pretty::pretty_format_batches(&[decoded_2, decoded_3])
            .unwrap()
            .to_string();
        let expected = r"
+---+---+
| i | s |
+---+---+
| 1 | x |
| 2 | x |
| 3 | x |
| 4 | h |
| 5 | e |
| 6 | l |
| 7 | l |
| 8 | o |
+---+---+";
        assert_eq!(actual, expected.trim());
        Ok(())
    }

    #[test]
    fn test_affected_rows_roundtrip_through_flight_codec() {
        let mut encoder = FlightEncoder::default();
        let mut decoder = FlightDecoder::default();

        let encoded = encoder.encode(FlightMessage::AffectedRows {
            rows: 7,
            metrics: None,
        });
        let decoded = decoder.try_decode(encoded.first()).unwrap().unwrap();
        assert!(matches!(
            decoded,
            FlightMessage::AffectedRows {
                rows: 7,
                metrics: None,
            }
        ));

        let json = r#"{"region_watermarks":[{"region_id":1,"watermark":99}]}"#;
        let encoded = encoder.encode(FlightMessage::AffectedRows {
            rows: 42,
            metrics: Some(json.to_string()),
        });
        let decoded = decoder.try_decode(encoded.first()).unwrap().unwrap();
        let FlightMessage::AffectedRows { rows, metrics } = decoded else {
            unreachable!()
        };
        assert_eq!(rows, 42);
        // Check the payload itself, not just that some metrics were present.
        assert_eq!(metrics.as_deref(), Some(json));
    }

    /// Frames built the pre-metrics way (affected rows only) still decode.
    #[test]
    fn test_old_affected_rows_format_decoded_by_new_code() {
        let old_wire_bytes = FlightData {
            flight_descriptor: None,
            data_header: build_none_flight_msg().into(),
            app_metadata: FlightMetadata {
                affected_rows: Some(AffectedRows { value: 99 }),
                metrics: None,
            }
            .encode_to_vec()
            .into(),
            data_body: ProstBytes::default(),
        };

        let mut decoder = FlightDecoder::default();
        let decoded = decoder.try_decode(&old_wire_bytes).unwrap().unwrap();
        assert!(matches!(
            decoded,
            FlightMessage::AffectedRows {
                rows: 99,
                metrics: None,
            }
        ));
    }
}