feat: dyn filter update abi

Signed-off-by: discord9 <discord9@163.com>
This commit is contained in:
discord9
2026-04-01 14:25:52 +08:00
parent 76cad696c6
commit 727e681fd5
10 changed files with 417 additions and 12 deletions

41
Cargo.lock generated
View File

@@ -2632,10 +2632,13 @@ dependencies = [
"datafusion",
"datafusion-common",
"datafusion-expr",
"datafusion-proto",
"datatypes",
"futures-util",
"once_cell",
"prost 0.14.1",
"serde",
"serde_json",
"snafu 0.8.6",
"sqlparser",
"sqlparser_derive 0.1.1",
@@ -4129,6 +4132,43 @@ dependencies = [
"tokio",
]
[[package]]
name = "datafusion-proto"
version = "52.1.0"
source = "git+https://github.com/GreptimeTeam/datafusion.git?rev=02b82535e0160c4545667f36a03e1ff9d1d2e51f#02b82535e0160c4545667f36a03e1ff9d1d2e51f"
dependencies = [
"arrow 57.3.0",
"chrono",
"datafusion-catalog",
"datafusion-catalog-listing",
"datafusion-common",
"datafusion-datasource",
"datafusion-datasource-arrow",
"datafusion-datasource-csv",
"datafusion-datasource-json",
"datafusion-datasource-parquet",
"datafusion-execution",
"datafusion-expr",
"datafusion-functions-table",
"datafusion-physical-expr",
"datafusion-physical-expr-common",
"datafusion-physical-plan",
"datafusion-proto-common",
"object_store",
"prost 0.14.1",
"rand 0.9.1",
]
[[package]]
name = "datafusion-proto-common"
version = "52.1.0"
source = "git+https://github.com/GreptimeTeam/datafusion.git?rev=02b82535e0160c4545667f36a03e1ff9d1d2e51f#02b82535e0160c4545667f36a03e1ff9d1d2e51f"
dependencies = [
"arrow 57.3.0",
"datafusion-common",
"prost 0.14.1",
]
[[package]]
name = "datafusion-pruning"
version = "52.1.0"
@@ -12137,6 +12177,7 @@ dependencies = [
"derive_more",
"snafu 0.8.6",
"sql",
"uuid",
]
[[package]]

View File

@@ -139,6 +139,7 @@ datafusion-orc = "0.7"
datafusion-pg-catalog = "0.15.1"
datafusion-physical-expr = "=52.1"
datafusion-physical-plan = "=52.1"
datafusion-proto = "=52.1"
datafusion-sql = "=52.1"
datafusion-substrait = "=52.1"
deadpool = "0.12"
@@ -251,7 +252,7 @@ tracing-appender = "0.2"
tracing-opentelemetry = "0.31.0"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json", "fmt"] }
typetag = "0.2"
uuid = { version = "1.17", features = ["serde", "v4", "fast-rng"] }
uuid = { version = "1.17", features = ["serde", "v4", "v7", "fast-rng"] }
vrl = "0.25"
zstd = "0.13"
# DO_NOT_REMOVE_THIS: END_OF_EXTERNAL_DEPENDENCIES
@@ -341,6 +342,7 @@ datafusion-optimizer = { git = "https://github.com/GreptimeTeam/datafusion.git",
datafusion-physical-expr = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" }
datafusion-physical-expr-common = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" }
datafusion-physical-plan = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" }
datafusion-proto = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" }
datafusion-datasource = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" }
datafusion-sql = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" }
datafusion-substrait = { git = "https://github.com/GreptimeTeam/datafusion.git", rev = "02b82535e0160c4545667f36a03e1ff9d1d2e51f" }

View File

@@ -22,8 +22,10 @@ common-time.workspace = true
datafusion.workspace = true
datafusion-common.workspace = true
datafusion-expr.workspace = true
datafusion-proto.workspace = true
datatypes.workspace = true
once_cell.workspace = true
prost.workspace = true
serde.workspace = true
snafu.workspace = true
sqlparser.workspace = true
@@ -33,4 +35,5 @@ store-api.workspace = true
[dev-dependencies]
common-base.workspace = true
futures-util.workspace = true
serde_json.workspace = true
tokio.workspace = true

View File

@@ -12,10 +12,162 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use api::v1::region::RegionRequestHeader;
use datafusion::arrow::datatypes::Schema;
use datafusion::execution::TaskContext;
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_plan::PhysicalExpr;
use datafusion::physical_plan::joins::HashTableLookupExpr;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use datafusion_expr::LogicalPlan;
use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
use datafusion_proto::physical_plan::from_proto::parse_physical_expr;
use datafusion_proto::physical_plan::to_proto::serialize_physical_expr;
use datafusion_proto::protobuf::PhysicalExprNode;
use prost::Message;
use serde::{Deserialize, Serialize};
use store_api::storage::RegionId;
/// Version tag carried in every [`DynFilterUpdate`] so peers can detect
/// incompatible encodings of the dynamic-filter wire format.
pub const DYN_FILTER_PROTOCOL_VERSION: u32 = 1;

/// Serialized dynamic-filter predicate shipped between nodes.
///
/// The only current variant wraps a protobuf-encoded DataFusion
/// `PhysicalExprNode`. The enum is `#[non_exhaustive]` and serde-tagged
/// (`kind`/`payload`) so new encodings can be added without breaking
/// existing consumers or the JSON wire shape.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(tag = "kind", content = "payload", rename_all = "snake_case")]
pub enum DynFilterPayload {
    /// Protobuf bytes of a `PhysicalExprNode`; produced by
    /// [`DynFilterPayload::from_datafusion_expr`].
    Datafusion(Vec<u8>),
}
impl DynFilterPayload {
    /// Encodes `expr` into a protobuf-backed payload.
    ///
    /// The expression is first checked for constructs that cannot cross the
    /// wire (see [`validate_supported_payload_expr`]); the encoded bytes are
    /// then size-checked against `max_payload_bytes`.
    ///
    /// # Errors
    /// Returns a `Plan` error for unsupported or oversized expressions and an
    /// `Internal` error if protobuf encoding fails.
    pub fn from_datafusion_expr(
        expr: &Arc<dyn PhysicalExpr>,
        max_payload_bytes: usize,
    ) -> DataFusionResult<Self> {
        validate_supported_payload_expr(expr)?;
        let node = serialize_physical_expr(expr, &DefaultPhysicalExtensionCodec {})?;
        let mut encoded = Vec::new();
        if let Err(e) = node.encode(&mut encoded) {
            return Err(DataFusionError::Internal(format!(
                "Failed to encode PhysicalExprNode: {e}"
            )));
        }
        validate_payload_size(encoded.len(), max_payload_bytes)?;
        Ok(Self::Datafusion(encoded))
    }

    /// Decodes the payload back into a physical expression resolved against
    /// `input_schema`.
    ///
    /// Mirrors the encoder's validation (size limit, unsupported nodes) and
    /// additionally verifies that every decoded `Column` is consistent with
    /// `input_schema` (see [`validate_decoded_payload_expr`]).
    ///
    /// # Errors
    /// Returns an `Internal` error for malformed bytes and a `Plan` error for
    /// oversized, unsupported, or schema-inconsistent expressions.
    pub fn decode_datafusion_expr(
        &self,
        task_ctx: &TaskContext,
        input_schema: &Schema,
        max_payload_bytes: usize,
    ) -> DataFusionResult<Arc<dyn PhysicalExpr>> {
        let Self::Datafusion(encoded) = self;
        validate_payload_size(encoded.len(), max_payload_bytes)?;
        let node = PhysicalExprNode::decode(encoded.as_slice()).map_err(|e| {
            DataFusionError::Internal(format!("Failed to decode PhysicalExprNode: {e}"))
        })?;
        let codec = DefaultPhysicalExtensionCodec {};
        let expr = parse_physical_expr(&node, task_ctx, input_schema, &codec)?;
        validate_supported_payload_expr(&expr)?;
        validate_decoded_payload_expr(&expr, input_schema)?;
        Ok(expr)
    }
}
/// Checks that an encoded payload fits within the configured size budget.
///
/// # Errors
/// Returns [`DataFusionError::Plan`] when `payload_size_bytes` exceeds
/// `max_payload_bytes`; an oversized filter is a planning/configuration
/// issue, not an engine bug.
fn validate_payload_size(
    payload_size_bytes: usize,
    max_payload_bytes: usize,
) -> DataFusionResult<()> {
    if payload_size_bytes > max_payload_bytes {
        // Inline format args for consistency with the `{e}` style used in
        // this file's other error messages (clippy: uninlined_format_args).
        return Err(DataFusionError::Plan(format!(
            "DynFilterPayload::Datafusion is {payload_size_bytes} bytes, which exceeds the configured limit of {max_payload_bytes} bytes"
        )));
    }
    Ok(())
}
/// Walks `expr` and rejects any node that cannot be serialized into a
/// `DynFilterPayload::Datafusion` payload.
///
/// `HashTableLookupExpr` is refused — presumably it references
/// process-local state that has no wire representation (confirm against
/// the datafusion join implementation).
fn validate_supported_payload_expr(expr: &Arc<dyn PhysicalExpr>) -> DataFusionResult<()> {
    expr.apply(|node| {
        if !node.as_any().is::<HashTableLookupExpr>() {
            return Ok(TreeNodeRecursion::Continue);
        }
        Err(DataFusionError::Plan(
            "HashTableLookupExpr cannot be encoded into DynFilterPayload::Datafusion"
                .to_string(),
        ))
    })
    .map(|_| ())
}
/// Verifies that every `Column` in a freshly decoded expression is
/// consistent with `input_schema`: the index must be in bounds and the
/// field at that index must carry the same name.
///
/// This guards against payloads that were built against a different
/// schema than the one the receiver will evaluate them with.
fn validate_decoded_payload_expr(
    expr: &Arc<dyn PhysicalExpr>,
    input_schema: &Schema,
) -> DataFusionResult<()> {
    expr.apply(|node| {
        // Only Column nodes carry schema references; skip everything else.
        let Some(column) = node.as_any().downcast_ref::<Column>() else {
            return Ok(TreeNodeRecursion::Continue);
        };
        let field = input_schema.fields().get(column.index()).ok_or_else(|| {
            DataFusionError::Plan(format!(
                "Decoded Column '{}' references out-of-bounds index {} for input schema of size {}",
                column.name(),
                column.index(),
                input_schema.fields().len()
            ))
        })?;
        if field.name() != column.name() {
            return Err(DataFusionError::Plan(format!(
                "Decoded Column name/index mismatch: payload has '{}' at index {}, but schema field is '{}'",
                column.name(),
                column.index(),
                field.name()
            )));
        }
        Ok(TreeNodeRecursion::Continue)
    })
    .map(|_| ())
}
/// A single dynamic-filter update message exchanged between nodes.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DynFilterUpdate {
    /// Wire-format version; [`DynFilterUpdate::new`] stamps it with
    /// [`DYN_FILTER_PROTOCOL_VERSION`].
    pub protocol_version: u32,
    /// Identifier of the query this filter belongs to.
    pub query_id: String,
    /// Identifier of the filter within that query.
    pub filter_id: String,
    /// Update counter for this filter — NOTE(review): assumed monotonic per
    /// filter_id; confirm with the producing side.
    pub epoch: u64,
    /// Presumably marks the final update for this filter (no further
    /// refinements expected) — TODO confirm consumer semantics.
    pub is_complete: bool,
    /// The serialized filter predicate.
    pub payload: DynFilterPayload,
}
impl DynFilterUpdate {
    /// Builds an update stamped with the current
    /// [`DYN_FILTER_PROTOCOL_VERSION`]; all other fields are passed
    /// through unchanged.
    pub fn new(
        query_id: String,
        filter_id: String,
        epoch: u64,
        is_complete: bool,
        payload: DynFilterPayload,
    ) -> Self {
        Self {
            protocol_version: DYN_FILTER_PROTOCOL_VERSION,
            payload,
            is_complete,
            epoch,
            filter_id,
            query_id,
        }
    }
}
/// The query request to be handled by the RegionServer (Datanode).
#[derive(Clone, Debug)]
pub struct QueryRequest {
@@ -28,3 +180,102 @@ pub struct QueryRequest {
/// The form of the query: a logical plan.
pub plan: LogicalPlan,
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::physical_expr::expressions::Column;

    use super::*;

    // `new` must stamp the current protocol version and pass the remaining
    // fields through untouched.
    #[test]
    fn dyn_filter_update_sets_protocol_version() {
        let update = DynFilterUpdate::new(
            "query-1".to_string(),
            "filter-1".to_string(),
            3,
            false,
            DynFilterPayload::Datafusion(vec![1, 2, 3]),
        );
        assert_eq!(update.protocol_version, DYN_FILTER_PROTOCOL_VERSION);
        assert!(!update.is_complete);
        assert!(
            matches!(update.payload, DynFilterPayload::Datafusion(ref bytes) if bytes == &vec![1, 2, 3])
        );
    }

    // JSON round-trip must preserve every field, including the serde-tagged
    // payload representation.
    #[test]
    fn dyn_filter_update_json_round_trip_preserves_payload_shape() {
        let update = DynFilterUpdate::new(
            "query-2".to_string(),
            "filter-9".to_string(),
            9,
            true,
            DynFilterPayload::Datafusion(vec![9, 8, 7]),
        );
        let json = serde_json::to_string(&update).unwrap();
        let decoded: DynFilterUpdate = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, update);
        assert!(decoded.is_complete);
        assert!(
            matches!(decoded.payload, DynFilterPayload::Datafusion(ref bytes) if bytes == &vec![9, 8, 7])
        );
    }

    // Encode then decode of a simple Column expr must yield an equivalent
    // Column (same name and index).
    #[test]
    fn dyn_filter_payload_round_trips_physical_column_expr() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let expr: Arc<dyn PhysicalExpr> =
            Arc::new(Column::new_with_schema("host", &schema).unwrap());
        let payload = DynFilterPayload::from_datafusion_expr(&expr, 1024).unwrap();
        let decoded = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap();
        let original = expr.as_any().downcast_ref::<Column>().unwrap();
        let decoded = decoded.as_any().downcast_ref::<Column>().unwrap();
        assert_eq!(decoded.name(), original.name());
        assert_eq!(decoded.index(), original.index());
    }

    // Garbage bytes are not a valid PhysicalExprNode → Internal error from
    // the prost decode step.
    #[test]
    fn dyn_filter_payload_decode_rejects_invalid_bytes() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let payload = DynFilterPayload::Datafusion(vec![1, 2, 3]);
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();
        assert!(matches!(err, DataFusionError::Internal(_)));
    }

    // A decoded Column whose name disagrees with the schema field at its
    // index must be rejected with a Plan error.
    #[test]
    fn dyn_filter_payload_decode_rejects_column_name_index_mismatch() {
        let schema = Schema::new(vec![Field::new("host", DataType::Utf8, false)]);
        let mismatched_expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("service", 0));
        let payload = DynFilterPayload::from_datafusion_expr(&mismatched_expr, 1024).unwrap();
        let err = payload
            .decode_datafusion_expr(&TaskContext::default(), &schema, 1024)
            .unwrap_err();
        assert!(matches!(err, DataFusionError::Plan(_)));
    }

    // A 1-byte limit is smaller than any encoded Column → Plan error from
    // the size check during encoding.
    #[test]
    fn dyn_filter_payload_rejects_oversized_payload() {
        let expr: Arc<dyn PhysicalExpr> = Arc::new(Column::new("host", 0));
        let err = DynFilterPayload::from_datafusion_expr(&expr, 1).unwrap_err();
        assert!(matches!(err, DataFusionError::Plan(_)));
    }
}

View File

@@ -20,7 +20,10 @@ use auth::{Identity, Password, UserInfoRef, UserProviderRef};
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_catalog::parse_catalog_and_schema_from_db_string;
use common_error::ext::ErrorExt;
use session::context::{Channel, QueryContextBuilder, QueryContextRef};
use session::context::{
Channel, QueryContextBuilder, QueryContextRef, REMOTE_QUERY_ID_EXTENSION_KEY,
generate_remote_query_id,
};
use snafu::{OptionExt, ResultExt};
use tonic::Status;
use tonic::metadata::MetadataMap;
@@ -50,6 +53,10 @@ pub fn create_query_context_from_grpc_metadata(
.current_catalog(catalog)
.current_schema(schema)
.channel(Channel::Grpc)
.set_extension(
REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
generate_remote_query_id(),
)
.build(),
))
}

View File

@@ -33,7 +33,10 @@ use common_telemetry::tracing_context::{FutureExt, TracingContext};
use common_telemetry::{debug, error, tracing, warn};
use common_time::timezone::parse_timezone;
use futures_util::StreamExt;
use session::context::{Channel, QueryContextBuilder, QueryContextRef};
use session::context::{
Channel, QueryContextBuilder, QueryContextRef, REMOTE_QUERY_ID_EXTENSION_KEY,
generate_remote_query_id,
};
use session::hints::READ_PREFERENCE_HINT;
use snafu::{OptionExt, ResultExt};
use tokio::sync::mpsc;
@@ -214,7 +217,11 @@ pub(crate) fn create_query_context(
.current_catalog(catalog)
.current_schema(schema)
.timezone(timezone)
.channel(channel);
.channel(channel)
.set_extension(
REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
generate_remote_query_id(),
);
if let Some(x) = extensions
.iter()
@@ -308,9 +315,16 @@ mod tests {
query_context.read_preference(),
ReadPreference::Leader
));
let mut extensions = query_context.extensions().into_iter().collect::<Vec<_>>();
extensions.sort_unstable_by(|a, b| a.0.cmp(&b.0));
assert_eq!(
query_context.extensions().into_iter().collect::<Vec<_>>(),
vec![("auto_create_table".to_string(), "true".to_string())]
extensions[0],
("auto_create_table".to_string(), "true".to_string())
);
assert_eq!(extensions[1].0, REMOTE_QUERY_ID_EXTENSION_KEY.to_string());
assert_eq!(
query_context.remote_query_id(),
Some(extensions[1].1.as_str())
);
}
}

View File

@@ -28,7 +28,9 @@ use common_telemetry::warn;
use common_time::Timezone;
use common_time::timezone::parse_timezone;
use headers::Header;
use session::context::QueryContextBuilder;
use session::context::{
QueryContextBuilder, REMOTE_QUERY_ID_EXTENSION_KEY, generate_remote_query_id,
};
use snafu::{OptionExt, ResultExt, ensure};
use crate::error::{
@@ -64,7 +66,11 @@ pub async fn inner_auth<B>(
let query_ctx_builder = QueryContextBuilder::default()
.current_catalog(catalog.clone())
.current_schema(schema.clone())
.timezone(timezone);
.timezone(timezone)
.set_extension(
REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
generate_remote_query_id(),
);
let query_ctx = query_ctx_builder.build();
let need_auth = need_auth(&req);
@@ -388,6 +394,19 @@ mod tests {
assert!(auth_scheme.is_err());
}
#[test]
fn test_inner_auth_assigns_remote_query_id() {
    // A plain request with no caller-supplied id should still leave
    // inner_auth with a generated remote query id on its QueryContext.
    let req =
        mock_http_request(None, Some("http://127.0.0.1/v1/sql?db=greptime-public")).unwrap();
    let req = futures::executor::block_on(inner_auth::<()>(None, req)).unwrap();
    let query_ctx = req
        .extensions()
        .get::<session::context::QueryContext>()
        .unwrap();
    assert!(query_ctx.remote_query_id().is_some());
}
#[test]
fn test_auth_header() {
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="

View File

@@ -27,3 +27,4 @@ derive_builder.workspace = true
derive_more.workspace = true
snafu.workspace = true
sql.workspace = true
uuid.workspace = true

View File

@@ -31,6 +31,7 @@ use common_time::timezone::parse_timezone;
use datafusion_common::config::ConfigOptions;
use derive_builder::Builder;
use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect};
use uuid::Uuid;
use crate::protocol_ctx::ProtocolCtx;
use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle, PGIntervalStyle};
@@ -40,6 +41,11 @@ pub type QueryContextRef = Arc<QueryContext>;
pub type ConnInfoRef = Arc<ConnInfo>;
const CURSOR_COUNT_WARNING_LIMIT: usize = 10;
pub const REMOTE_QUERY_ID_EXTENSION_KEY: &str = "remote_query_id";
/// Generates a fresh remote query id as a UUID v7 string.
///
/// v7 ids embed a timestamp, so ids sort roughly by creation time —
/// convenient for correlating queries in logs.
pub fn generate_remote_query_id() -> String {
    Uuid::now_v7().to_string()
}
#[derive(Debug, Builder, Clone)]
#[builder(pattern = "owned")]
@@ -152,7 +158,12 @@ impl From<&RegionRequestHeader> for QueryContext {
if let Some(ctx) = &value.query_context {
ctx.clone().into()
} else {
QueryContextBuilder::default().build()
QueryContextBuilder::default()
.set_extension(
REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
generate_remote_query_id(),
)
.build()
}
}
}
@@ -219,7 +230,14 @@ impl From<&QueryContext> for api::v1::QueryContext {
impl QueryContext {
pub fn arc() -> QueryContextRef {
Arc::new(QueryContextBuilder::default().build())
Arc::new(
QueryContextBuilder::default()
.set_extension(
REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
generate_remote_query_id(),
)
.build(),
)
}
/// Create a new datafusion's ConfigOptions instance based on the current QueryContext.
@@ -233,6 +251,10 @@ impl QueryContext {
QueryContextBuilder::default()
.current_catalog(catalog.to_string())
.current_schema(schema.to_string())
.set_extension(
REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
generate_remote_query_id(),
)
.build()
}
@@ -241,6 +263,10 @@ impl QueryContext {
.current_catalog(catalog.to_string())
.current_schema(schema.to_string())
.channel(channel)
.set_extension(
REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
generate_remote_query_id(),
)
.build()
}
@@ -259,6 +285,10 @@ impl QueryContext {
QueryContextBuilder::default()
.current_catalog(catalog)
.current_schema(schema.clone())
.set_extension(
REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
generate_remote_query_id(),
)
.build()
}
@@ -320,6 +350,10 @@ impl QueryContext {
self.extensions.get(key.as_ref()).map(|v| v.as_str())
}
/// Returns the remote query id stored under the
/// [`REMOTE_QUERY_ID_EXTENSION_KEY`] extension, if one was set.
pub fn remote_query_id(&self) -> Option<&str> {
    self.extension(REMOTE_QUERY_ID_EXTENSION_KEY)
}
pub fn extensions(&self) -> HashMap<String, String> {
self.extensions.clone()
}
@@ -483,6 +517,10 @@ impl QueryContext {
impl QueryContextBuilder {
pub fn build(self) -> QueryContext {
let channel = self.channel.unwrap_or_default();
let mut extensions = self.extensions.unwrap_or_default();
extensions
.entry(REMOTE_QUERY_ID_EXTENSION_KEY.to_string())
.or_insert_with(generate_remote_query_id);
QueryContext {
current_catalog: self
.current_catalog
@@ -494,7 +532,7 @@ impl QueryContextBuilder {
sql_dialect: self
.sql_dialect
.unwrap_or_else(|| Arc::new(GreptimeDbDialect {})),
extensions: self.extensions.unwrap_or_default(),
extensions,
configuration_parameter: self
.configuration_parameter
.unwrap_or_else(|| Arc::new(ConfigurationVariables::default())),
@@ -707,6 +745,9 @@ mod test {
assert_eq!("mysql[127.0.0.1:9000]", session.conn_info().to_string());
assert_eq!(100, session.process_id());
let query_ctx = session.new_query_context();
assert!(query_ctx.remote_query_id().is_some());
}
#[test]
@@ -743,4 +784,23 @@ mod test {
assert_eq!(roundtrip_api.channel, api_ctx.channel);
assert_eq!(roundtrip_api.snapshot_seqs, api_ctx.snapshot_seqs);
}
#[test]
fn test_query_context_remote_query_id_round_trip() {
    // A caller-supplied remote query id must survive conversion to the
    // protobuf QueryContext and back (i.e. extensions are round-tripped).
    let query_id = "0195f4fd-c503-7c54-8b8f-7dfb8f6f9c4a";
    let ctx = QueryContextBuilder::default()
        .current_catalog(DEFAULT_CATALOG_NAME.to_string())
        .current_schema("public".to_string())
        .set_extension(
            REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
            query_id.to_string(),
        )
        .build();
    assert_eq!(ctx.remote_query_id(), Some(query_id));
    let proto: api::v1::QueryContext = (&ctx).into();
    let restored = QueryContext::from(proto);
    assert_eq!(restored.remote_query_id(), Some(query_id));
}
}

View File

@@ -30,7 +30,10 @@ use common_recordbatch::cursor::RecordBatchStreamCursor;
pub use common_session::ReadPreference;
use common_time::Timezone;
use common_time::timezone::get_timezone;
use context::{ConfigurationVariables, QueryContextBuilder};
use context::{
ConfigurationVariables, QueryContextBuilder, REMOTE_QUERY_ID_EXTENSION_KEY,
generate_remote_query_id,
};
use derive_more::Debug;
use crate::context::{Channel, ConnInfo, QueryContextRef};
@@ -106,6 +109,10 @@ impl Session {
.channel(self.conn_info.channel)
.process_id(self.process_id)
.conn_info(self.conn_info.clone())
.set_extension(
REMOTE_QUERY_ID_EXTENSION_KEY.to_string(),
generate_remote_query_id(),
)
.build()
.into()
}