mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-19 06:20:38 +00:00
383 lines
13 KiB
Rust
383 lines
13 KiB
Rust
// Copyright 2023 Greptime Team
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
use std::sync::Arc;
|
|
|
|
use api::v1::region::RegionRequestHeader;
|
|
use datafusion::arrow::datatypes::Schema;
|
|
use datafusion::execution::TaskContext;
|
|
use datafusion::physical_expr::expressions::Column;
|
|
use datafusion::physical_plan::PhysicalExpr;
|
|
use datafusion::physical_plan::joins::HashTableLookupExpr;
|
|
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
|
|
use datafusion_common::{DataFusionError, Result as DataFusionResult};
|
|
use datafusion_expr::LogicalPlan;
|
|
use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
|
|
use datafusion_proto::physical_plan::from_proto::parse_physical_expr;
|
|
use datafusion_proto::physical_plan::to_proto::serialize_physical_expr;
|
|
use datafusion_proto::protobuf::PhysicalExprNode;
|
|
use prost::Message;
|
|
use serde::{Deserialize, Serialize};
|
|
use store_api::storage::RegionId;
|
|
|
|
/// Wire-format version stamped on every remote dynamic filter update.
///
/// [`DynFilterUpdate::new`] writes this value into
/// [`DynFilterUpdate::protocol_version`].
pub const DYN_FILTER_PROTOCOL_VERSION: u32 = 1;
|
|
|
|
/// Serialized predicate payload for remote dynamic filter updates.
|
|
///
|
|
/// The payload is tagged in JSON so receivers can reject unsupported encodings
|
|
/// before decoding engine-specific bytes. For DataFusion expressions the
|
|
/// `payload` bytes are serialized by `serde_json` as a base64 string, for example:
|
|
///
|
|
/// ```json
|
|
/// { "kind": "datafusion", "payload": "CQgH" }
|
|
/// ```
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
#[non_exhaustive]
|
|
#[serde(tag = "kind", content = "payload", rename_all = "snake_case")]
|
|
pub enum DynFilterPayload {
|
|
/// A serialized DataFusion [`PhysicalExpr`] encoded as a protobuf
|
|
/// [`PhysicalExprNode`].
|
|
Datafusion(#[serde(with = "base64_bytes")] Vec<u8>),
|
|
}
|
|
|
|
mod base64_bytes {
|
|
use base64::Engine;
|
|
use base64::prelude::BASE64_STANDARD;
|
|
use serde::de::Error;
|
|
use serde::{Deserialize, Deserializer, Serializer};
|
|
|
|
pub fn serialize<S>(bytes: &[u8], serializer: S) -> Result<S::Ok, S::Error>
|
|
where
|
|
S: Serializer,
|
|
{
|
|
serializer.serialize_str(&BASE64_STANDARD.encode(bytes))
|
|
}
|
|
|
|
pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
|
|
where
|
|
D: Deserializer<'de>,
|
|
{
|
|
let encoded = String::deserialize(deserializer)?;
|
|
BASE64_STANDARD.decode(encoded).map_err(|err| {
|
|
D::Error::custom(format!("invalid base64 dynamic filter payload: {err}"))
|
|
})
|
|
}
|
|
}
|
|
|
|
impl DynFilterPayload {
|
|
/// Encodes a DataFusion physical expression into a bounded dynamic filter payload.
|
|
///
|
|
/// This rejects expressions that cannot be safely shipped as dynamic filter
|
|
/// predicates and fails if the serialized payload exceeds `max_payload_bytes`.
|
|
pub fn from_datafusion_expr(
|
|
expr: &Arc<dyn PhysicalExpr>,
|
|
max_payload_bytes: usize,
|
|
) -> DataFusionResult<Self> {
|
|
validate_supported_payload_expr(expr)?;
|
|
|
|
let codec = DefaultPhysicalExtensionCodec {};
|
|
let proto = serialize_physical_expr(expr, &codec)?;
|
|
let mut bytes = Vec::new();
|
|
proto.encode(&mut bytes).map_err(|e| {
|
|
DataFusionError::Internal(format!("Failed to encode PhysicalExprNode: {e}"))
|
|
})?;
|
|
|
|
validate_payload_size(bytes.len(), max_payload_bytes)?;
|
|
|
|
Ok(Self::Datafusion(bytes))
|
|
}
|
|
|
|
pub fn decode_datafusion_expr(
|
|
&self,
|
|
task_ctx: &TaskContext,
|
|
input_schema: &Schema,
|
|
max_payload_bytes: usize,
|
|
) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
|
|
let Self::Datafusion(bytes) = self;
|
|
validate_payload_size(bytes.len(), max_payload_bytes)?;
|
|
let codec = DefaultPhysicalExtensionCodec {};
|
|
let proto = PhysicalExprNode::decode(bytes.as_slice()).map_err(|e| {
|
|
DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}"))
|
|
})?;
|
|
|
|
let expr = parse_physical_expr(&proto, task_ctx, input_schema, &codec)?;
|
|
validate_supported_payload_expr(&expr)?;
|
|
validate_decoded_payload_expr(&expr, input_schema)?;
|
|
Ok(expr)
|
|
}
|
|
}
|
|
|
|
fn validate_payload_size(
|
|
payload_size_bytes: usize,
|
|
max_payload_bytes: usize,
|
|
) -> DataFusionResult<()> {
|
|
if payload_size_bytes > max_payload_bytes {
|
|
return Err(DataFusionError::Plan(format!(
|
|
"DynFilterPayload::Datafusion is {} bytes, which exceeds the configured limit of {} bytes",
|
|
payload_size_bytes, max_payload_bytes
|
|
)));
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn validate_supported_payload_expr(expr: &Arc<dyn PhysicalExpr>) -> DataFusionResult<()> {
|
|
expr.apply(|node| {
|
|
if node.as_any().is::<HashTableLookupExpr>() {
|
|
return Err(DataFusionError::Plan(
|
|
"HashTableLookupExpr cannot be encoded into DynFilterPayload::Datafusion"
|
|
.to_string(),
|
|
));
|
|
}
|
|
|
|
Ok(TreeNodeRecursion::Continue)
|
|
})?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn validate_decoded_payload_expr(
|
|
expr: &Arc<dyn PhysicalExpr>,
|
|
input_schema: &Schema,
|
|
) -> DataFusionResult<()> {
|
|
expr.apply(|node| {
|
|
if let Some(column) = node.as_any().downcast_ref::<Column>() {
|
|
if input_schema.fields().get(column.index()).is_none() {
|
|
return Err(DataFusionError::Plan(format!(
|
|
"Decoded Column '{}' references out-of-bounds index {} for input schema of size {}",
|
|
column.name(),
|
|
column.index(),
|
|
input_schema.fields().len()
|
|
)));
|
|
}
|
|
}
|
|
|
|
Ok(TreeNodeRecursion::Continue)
|
|
})?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// A remote dynamic filter update sent from a query coordinator to region servers.
|
|
///
|
|
/// `generation` is monotonic within a `query_id`/`filter_id` pair and matches the
|
|
/// gRPC field name used by `RemoteDynFilterUpdate`. Receivers use it to ignore
|
|
/// stale updates while `is_complete` marks the final payload for the filter.
|
|
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
|
|
pub struct DynFilterUpdate {
|
|
/// Protocol version used by this update payload.
|
|
pub protocol_version: u32,
|
|
/// Internal query identifier that owns this dynamic filter lifecycle.
|
|
pub query_id: String,
|
|
/// Identifier of the dynamic filter within the query.
|
|
pub filter_id: String,
|
|
/// Monotonic update generation for this filter.
|
|
pub generation: u64,
|
|
/// Whether this update completes the dynamic filter stream.
|
|
pub is_complete: bool,
|
|
/// Serialized predicate payload carried by this update.
|
|
pub payload: DynFilterPayload,
|
|
}
|
|
|
|
impl DynFilterUpdate {
|
|
/// Creates a dynamic filter update with the current protocol version.
|
|
pub fn new(
|
|
query_id: String,
|
|
filter_id: String,
|
|
generation: u64,
|
|
is_complete: bool,
|
|
payload: DynFilterPayload,
|
|
) -> Self {
|
|
Self {
|
|
protocol_version: DYN_FILTER_PROTOCOL_VERSION,
|
|
query_id,
|
|
filter_id,
|
|
generation,
|
|
is_complete,
|
|
payload,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// The query request to be handled by the RegionServer (Datanode).
|
|
#[derive(Clone, Debug)]
|
|
pub struct QueryRequest {
|
|
/// The header of this request. Often to store some context of the query. None means all to defaults.
|
|
pub header: Option<RegionRequestHeader>,
|
|
|
|
/// The id of the region to be queried.
|
|
pub region_id: RegionId,
|
|
|
|
/// The form of the query: a logical plan.
|
|
pub plan: LogicalPlan,
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use base64::Engine;
    use base64::prelude::BASE64_STANDARD;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::physical_expr::expressions::Column;

    use super::*;

    /// Single-column schema shared by the decode tests.
    fn host_schema() -> Schema {
        Schema::new(vec![Field::new("host", DataType::Utf8, false)])
    }

    #[test]
    fn dyn_filter_update_sets_protocol_version() {
        let update = DynFilterUpdate::new(
            "query-1".to_string(),
            "filter-1".to_string(),
            3,
            false,
            DynFilterPayload::Datafusion(vec![1, 2, 3]),
        );

        assert_eq!(update.protocol_version, DYN_FILTER_PROTOCOL_VERSION);
        assert_eq!(update.query_id, "query-1");
        assert_eq!(update.filter_id, "filter-1");
        assert_eq!(update.generation, 3);
        assert!(!update.is_complete);
        assert_eq!(update.payload, DynFilterPayload::Datafusion(vec![1, 2, 3]));
    }

    #[test]
    fn dyn_filter_update_json_round_trip_preserves_payload_shape() {
        let update = DynFilterUpdate::new(
            "query-2".to_string(),
            "filter-9".to_string(),
            9,
            true,
            DynFilterPayload::Datafusion(vec![9, 8, 7]),
        );

        let json = serde_json::to_string(&update).unwrap();
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        let decoded: DynFilterUpdate = serde_json::from_str(&json).unwrap();

        assert_eq!(value["generation"], serde_json::json!(9));
        assert!(value.get("epoch").is_none());
        assert_eq!(
            value["payload"],
            serde_json::json!({ "kind": "datafusion", "payload": BASE64_STANDARD.encode([9, 8, 7]) })
        );
        assert_eq!(decoded, update);
        assert!(decoded.is_complete);
        assert_eq!(decoded.payload, DynFilterPayload::Datafusion(vec![9, 8, 7]));
    }

    #[test]
    fn dyn_filter_payload_json_uses_base64_for_empty_and_padded_payloads() {
        let empty = serde_json::to_value(DynFilterPayload::Datafusion(vec![])).unwrap();
        let one = serde_json::to_value(DynFilterPayload::Datafusion(vec![1])).unwrap();
        let two = serde_json::to_value(DynFilterPayload::Datafusion(vec![1, 2])).unwrap();

        assert_eq!(
            empty,
            serde_json::json!({"kind": "datafusion", "payload": ""})
        );
        assert_eq!(
            one,
            serde_json::json!({"kind": "datafusion", "payload": BASE64_STANDARD.encode([1])})
        );
        assert_eq!(
            two,
            serde_json::json!({"kind": "datafusion", "payload": BASE64_STANDARD.encode([1, 2])})
        );
    }

    #[test]
    fn dyn_filter_payload_json_rejects_invalid_base64() {
        let err = serde_json::from_value::<DynFilterPayload>(serde_json::json!({
            "kind": "datafusion",
            "payload": "not base64!",
        }))
        .unwrap_err();

        assert!(
            err.to_string()
                .contains("invalid base64 dynamic filter payload")
        );
    }

    #[test]
    fn dyn_filter_payload_round_trips_physical_column_expr() {
        let schema = host_schema();
        let expr: Arc<dyn PhysicalExpr> =
            Arc::new(Column::new_with_schema("host", &schema).unwrap());

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let decoded = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap();

        let original = expr.as_any().downcast_ref::<Column>().unwrap();
        let decoded = decoded.as_any().downcast_ref::<Column>().unwrap();

        assert_eq!(decoded.name(), original.name());
        assert_eq!(decoded.index(), original.index());
    }

    #[test]
    fn dyn_filter_payload_decode_rejects_invalid_bytes() {
        let schema = host_schema();
        let payload = DynFilterPayload::Datafusion(vec![1, 2, 3]);

        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Internal(_)));
    }

    #[test]
    fn dyn_filter_payload_decode_accepts_column_name_mismatch_when_index_is_valid() {
        let schema = host_schema();
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("service", 0));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let decoded = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap();

        let decoded = decoded.as_any().downcast_ref::<Column>().unwrap();

        assert_eq!(decoded.index(), 0);
    }

    #[test]
    fn dyn_filter_payload_decode_rejects_out_of_bounds_column_index() {
        let schema = host_schema();
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 1));

        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Plan(_)));
    }

    #[test]
    fn dyn_filter_payload_rejects_oversized_payload() {
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 0));

        let err = DynFilterPayload::from_datafusion_expr(&expr, 1).unwrap_err();

        assert!(matches!(err, DataFusionError::Plan(_)));
    }

    #[test]
    fn dyn_filter_payload_decode_rejects_oversized_payload() {
        // The size bound must be enforced on the receiving side too, before any
        // protobuf decoding is attempted.
        let schema = host_schema();
        let payload = DynFilterPayload::Datafusion(vec![0; 8]);

        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 4)
            .unwrap_err();

        assert!(matches!(err, DataFusionError::Plan(_)));
    }

    #[test]
    fn dyn_filter_payload_accepts_payload_exactly_at_limit() {
        let schema = host_schema();
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 0));

        let DynFilterPayload::Datafusion(bytes) =
            DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let limit = bytes.len();

        // The limit is inclusive on both the encode and decode paths.
        let payload = DynFilterPayload::from_datafusion_expr(&expr, limit).unwrap();
        assert_eq!(payload, DynFilterPayload::Datafusion(bytes));
        payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, limit)
            .unwrap();
    }
}
|