refactor: per review

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2026-04-29 16:54:24 +08:00
parent 5166282269
commit 04fd62dccc
7 changed files with 497 additions and 246 deletions

View File

@@ -41,7 +41,10 @@ use crate::error::{DecodeFlightDataSnafu, InvalidFlightDataSnafu, Result};
pub enum FlightMessage {
Schema(SchemaRef),
RecordBatch(DfRecordBatch),
AffectedRows(usize),
AffectedRows {
rows: usize,
metrics: Option<String>,
},
Metrics(String),
}
@@ -116,10 +119,12 @@ impl FlightEncoder {
encoded_batch.into(),
)
}
FlightMessage::AffectedRows(rows) => {
FlightMessage::AffectedRows { rows, metrics } => {
let metadata = FlightMetadata {
affected_rows: Some(AffectedRows { value: rows as _ }),
metrics: None,
metrics: metrics.map(|s| Metrics {
metrics: s.into_bytes(),
}),
}
.encode_to_vec();
vec1![FlightData {
@@ -223,7 +228,12 @@ impl FlightDecoder {
let metadata = FlightMetadata::decode(flight_data.app_metadata.clone())
.context(DecodeFlightDataSnafu)?;
if let Some(AffectedRows { value }) = metadata.affected_rows {
return Ok(Some(FlightMessage::AffectedRows(value as _)));
return Ok(Some(FlightMessage::AffectedRows {
rows: value as _,
metrics: metadata
.metrics
.map(|m| String::from_utf8_lossy(&m.metrics).to_string()),
}));
}
if let Some(Metrics { metrics }) = metadata.metrics {
return Ok(Some(FlightMessage::Metrics(
@@ -426,6 +436,47 @@ mod test {
Ok(())
}
#[test]
fn test_affected_rows_metrics_encode_decode() -> Result<()> {
let metrics = r#"{"region_watermarks":[{"region_id":42,"watermark":7}]}"#;
let mut encoder = FlightEncoder::default();
let encoded = encoder.encode(FlightMessage::AffectedRows {
rows: 3,
metrics: Some(metrics.to_string()),
});
assert_eq!(encoded.len(), 1);
let mut decoder = FlightDecoder::default();
let decoded = decoder.try_decode(encoded.first())?.unwrap();
let FlightMessage::AffectedRows {
rows,
metrics: decoded_metrics,
} = decoded
else {
unreachable!()
};
assert_eq!(rows, 3);
assert_eq!(decoded_metrics.as_deref(), Some(metrics));
let encoded = encoder.encode(FlightMessage::AffectedRows {
rows: 5,
metrics: None,
});
let decoded = decoder.try_decode(encoded.first())?.unwrap();
let FlightMessage::AffectedRows {
rows,
metrics: decoded_metrics,
} = decoded
else {
unreachable!()
};
assert_eq!(rows, 5);
assert!(decoded_metrics.is_none());
Ok(())
}
#[test]
fn test_flight_messages_to_recordbatches() {
let schema = Arc::new(Schema::new(vec![Field::new("m", DataType::Int32, true)]));