From 62b51c673635ad597c2a1a5200512beeccf52fb0 Mon Sep 17 00:00:00 2001 From: jeremyhi Date: Wed, 22 Oct 2025 17:30:36 +0800 Subject: [PATCH] feat: writer mem limiter for http and grpc service (#7092) * feat: writer mem limiter for http and grpc service Signed-off-by: jeremyhi * fix: docs Signed-off-by: jeremyhi * feat: add metrics for limiter Signed-off-by: jeremyhi * Apply suggestion from @MichaelScofield Co-authored-by: LFC <990479+MichaelScofield@users.noreply.github.com> * chore: refactor try_acquire Signed-off-by: jeremyhi * chore: make size human readable Signed-off-by: jeremyhi --------- Signed-off-by: jeremyhi Co-authored-by: LFC <990479+MichaelScofield@users.noreply.github.com> --- config/config.md | 4 + config/frontend.example.toml | 8 + config/standalone.example.toml | 8 + src/flow/src/server.rs | 1 + src/servers/src/error.rs | 15 ++ src/servers/src/grpc.rs | 16 ++ src/servers/src/grpc/builder.rs | 21 ++- src/servers/src/grpc/database.rs | 46 +++++ src/servers/src/grpc/flight.rs | 51 +++++- src/servers/src/grpc/greptime_handler.rs | 1 + src/servers/src/grpc/memory_limit.rs | 72 ++++++++ src/servers/src/http.rs | 15 ++ src/servers/src/http/memory_limit.rs | 52 ++++++ src/servers/src/lib.rs | 4 + src/servers/src/metrics.rs | 20 +++ src/servers/src/request_limiter.rs | 215 +++++++++++++++++++++++ tests-integration/tests/grpc.rs | 161 ++++++++++++++++- tests-integration/tests/http.rs | 2 + 18 files changed, 704 insertions(+), 8 deletions(-) create mode 100644 src/servers/src/grpc/memory_limit.rs create mode 100644 src/servers/src/http/memory_limit.rs create mode 100644 src/servers/src/request_limiter.rs diff --git a/config/config.md b/config/config.md index 46a0aee1a7..72d48b5bcb 100644 --- a/config/config.md +++ b/config/config.md @@ -25,12 +25,14 @@ | `http.addr` | String | `127.0.0.1:4000` | The address to bind the HTTP server. | | `http.timeout` | String | `0s` | HTTP request timeout. Set to 0 to disable timeout. | | `http.body_limit` | String | `64MB` | HTTP request body limit.
The following units are supported: `B`, `KB`, `KiB`, `MB`, `MiB`, `GB`, `GiB`, `TB`, `TiB`, `PB`, `PiB`.
Set to 0 to disable limit. | +| `http.max_total_body_memory` | String | Unset | Maximum total memory for all concurrent HTTP request bodies.
Set to 0 to disable the limit. Default: "0" (unlimited) | | `http.enable_cors` | Bool | `true` | HTTP CORS support, it's turned on by default
This allows browser to access http APIs without CORS restrictions | | `http.cors_allowed_origins` | Array | Unset | Customize allowed origins for HTTP CORS. | | `http.prom_validation_mode` | String | `strict` | Whether to enable validation for Prometheus remote write requests.
Available options:
- strict: deny invalid UTF-8 strings (default).
- lossy: allow invalid UTF-8 strings, replace invalid characters with REPLACEMENT_CHARACTER(U+FFFD).
- unchecked: do not validate strings. | | `grpc` | -- | -- | The gRPC server options. | | `grpc.bind_addr` | String | `127.0.0.1:4001` | The address to bind the gRPC server. | | `grpc.runtime_size` | Integer | `8` | The number of server worker threads. | +| `grpc.max_total_message_memory` | String | Unset | Maximum total memory for all concurrent gRPC request messages.
Set to 0 to disable the limit. Default: "0" (unlimited) | | `grpc.max_connection_age` | String | Unset | The maximum connection age for gRPC connection.
The value can be a human-readable time string. For example: `10m` for ten minutes or `1h` for one hour.
Refer to https://grpc.io/docs/guides/keepalive/ for more details. | | `grpc.tls` | -- | -- | gRPC server TLS options, see `mysql.tls` section. | | `grpc.tls.mode` | String | `disable` | TLS mode. | @@ -235,6 +237,7 @@ | `http.addr` | String | `127.0.0.1:4000` | The address to bind the HTTP server. | | `http.timeout` | String | `0s` | HTTP request timeout. Set to 0 to disable timeout. | | `http.body_limit` | String | `64MB` | HTTP request body limit.
The following units are supported: `B`, `KB`, `KiB`, `MB`, `MiB`, `GB`, `GiB`, `TB`, `TiB`, `PB`, `PiB`.
Set to 0 to disable limit. | +| `http.max_total_body_memory` | String | Unset | Maximum total memory for all concurrent HTTP request bodies.
Set to 0 to disable the limit. Default: "0" (unlimited) | | `http.enable_cors` | Bool | `true` | HTTP CORS support, it's turned on by default
This allows browser to access http APIs without CORS restrictions | | `http.cors_allowed_origins` | Array | Unset | Customize allowed origins for HTTP CORS. | | `http.prom_validation_mode` | String | `strict` | Whether to enable validation for Prometheus remote write requests.
Available options:
- strict: deny invalid UTF-8 strings (default).
- lossy: allow invalid UTF-8 strings, replace invalid characters with REPLACEMENT_CHARACTER(U+FFFD).
- unchecked: do not validate strings. | @@ -242,6 +245,7 @@ | `grpc.bind_addr` | String | `127.0.0.1:4001` | The address to bind the gRPC server. | | `grpc.server_addr` | String | `127.0.0.1:4001` | The address advertised to the metasrv, and used for connections from outside the host.
If left empty or unset, the server will automatically use the IP address of the first network interface
on the host, with the same port number as the one specified in `grpc.bind_addr`. | | `grpc.runtime_size` | Integer | `8` | The number of server worker threads. | +| `grpc.max_total_message_memory` | String | Unset | Maximum total memory for all concurrent gRPC request messages.
Set to 0 to disable the limit. Default: "0" (unlimited) | | `grpc.flight_compression` | String | `arrow_ipc` | Compression mode for frontend side Arrow IPC service. Available options:
- `none`: disable all compression
- `transport`: only enable gRPC transport compression (zstd)
- `arrow_ipc`: only enable Arrow IPC compression (lz4)
- `all`: enable all compression.
Default to `none` | | `grpc.max_connection_age` | String | Unset | The maximum connection age for gRPC connection.
The value can be a human-readable time string. For example: `10m` for ten minutes or `1h` for one hour.
Refer to https://grpc.io/docs/guides/keepalive/ for more details. | | `grpc.tls` | -- | -- | gRPC server TLS options, see `mysql.tls` section. | diff --git a/config/frontend.example.toml b/config/frontend.example.toml index b26d88323e..9ffcdad540 100644 --- a/config/frontend.example.toml +++ b/config/frontend.example.toml @@ -31,6 +31,10 @@ timeout = "0s" ## The following units are supported: `B`, `KB`, `KiB`, `MB`, `MiB`, `GB`, `GiB`, `TB`, `TiB`, `PB`, `PiB`. ## Set to 0 to disable limit. body_limit = "64MB" +## Maximum total memory for all concurrent HTTP request bodies. +## Set to 0 to disable the limit. Default: "0" (unlimited) +## @toml2docs:none-default +#+ max_total_body_memory = "1GB" ## HTTP CORS support, it's turned on by default ## This allows browser to access http APIs without CORS restrictions enable_cors = true @@ -54,6 +58,10 @@ bind_addr = "127.0.0.1:4001" server_addr = "127.0.0.1:4001" ## The number of server worker threads. runtime_size = 8 +## Maximum total memory for all concurrent gRPC request messages. +## Set to 0 to disable the limit. Default: "0" (unlimited) +## @toml2docs:none-default +#+ max_total_message_memory = "1GB" ## Compression mode for frontend side Arrow IPC service. Available options: ## - `none`: disable all compression ## - `transport`: only enable gRPC transport compression (zstd) diff --git a/config/standalone.example.toml b/config/standalone.example.toml index 5fae0f444f..744dbbe751 100644 --- a/config/standalone.example.toml +++ b/config/standalone.example.toml @@ -36,6 +36,10 @@ timeout = "0s" ## The following units are supported: `B`, `KB`, `KiB`, `MB`, `MiB`, `GB`, `GiB`, `TB`, `TiB`, `PB`, `PiB`. ## Set to 0 to disable limit. body_limit = "64MB" +## Maximum total memory for all concurrent HTTP request bodies. +## Set to 0 to disable the limit. Default: "0" (unlimited) +## @toml2docs:none-default +#+ max_total_body_memory = "1GB" ## HTTP CORS support, it's turned on by default ## This allows browser to access http APIs without CORS restrictions enable_cors = true @@ -56,6 +60,10 @@ prom_validation_mode = "strict" bind_addr = "127.0.0.1:4001" ## The number of server worker threads. runtime_size = 8 +## Maximum total memory for all concurrent gRPC request messages. +## Set to 0 to disable the limit. Default: "0" (unlimited) +## @toml2docs:none-default +#+ max_total_message_memory = "1GB" ## The maximum connection age for gRPC connection. ## The value can be a human-readable time string. For example: `10m` for ten minutes or `1h` for one hour. ## Refer to https://grpc.io/docs/guides/keepalive/ for more details. 
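Reviewer note (not part of the patch): the TOML options above feed directly into the new `RequestMemoryLimiter` introduced in `src/servers/src/request_limiter.rs` further down in this diff. A minimal sketch of that mapping, using only types that appear in the patch; the `build_limiter` helper and its parameter name are purely illustrative:

```rust
use common_base::readable_size::ReadableSize;
use servers::request_limiter::RequestMemoryLimiter;

/// Illustrative only: how `max_total_body_memory` / `max_total_message_memory`
/// (parsed from the TOML into a `ReadableSize`) becomes the per-server limiter.
/// A value of 0, or leaving the option unset, keeps the limiter disabled.
fn build_limiter(max_total_memory: ReadableSize) -> RequestMemoryLimiter {
    RequestMemoryLimiter::new(max_total_memory.as_bytes() as usize)
}
```

With `max_total_message_memory = "1GB"`, all concurrently decoded gRPC request messages share a single 1 GB budget; once it is exhausted, new requests are rejected immediately rather than queued.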
diff --git a/src/flow/src/server.rs b/src/flow/src/server.rs index 3f46203ba0..eae97756a5 100644 --- a/src/flow/src/server.rs +++ b/src/flow/src/server.rs @@ -490,6 +490,7 @@ impl<'a> FlownodeServiceBuilder<'a> { let config = GrpcServerConfig { max_recv_message_size: opts.grpc.max_recv_message_size.as_bytes() as usize, max_send_message_size: opts.grpc.max_send_message_size.as_bytes() as usize, + max_total_message_memory: opts.grpc.max_total_message_memory.as_bytes() as usize, tls: opts.grpc.tls.clone(), max_connection_age: opts.grpc.max_connection_age, }; diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index d36bdd1494..c7e5c5d07a 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -20,6 +20,7 @@ use axum::http::StatusCode as HttpStatusCode; use axum::response::{IntoResponse, Response}; use axum::{Json, http}; use base64::DecodeError; +use common_base::readable_size::ReadableSize; use common_error::define_into_tonic_status; use common_error::ext::{BoxedError, ErrorExt}; use common_error::status_code::StatusCode; @@ -164,6 +165,18 @@ pub enum Error { location: Location, }, + #[snafu(display( + "Too many concurrent large requests, limit: {}, request size: {}", + ReadableSize(*limit as u64), + ReadableSize(*request_size as u64) + ))] + TooManyConcurrentRequests { + limit: usize, + request_size: usize, + #[snafu(implicit)] + location: Location, + }, + #[snafu(display("Invalid query: {}", reason))] InvalidQuery { reason: String, @@ -729,6 +742,8 @@ impl ErrorExt for Error { InvalidUtf8Value { .. } | InvalidHeaderValue { .. } => StatusCode::InvalidArguments, + TooManyConcurrentRequests { .. } => StatusCode::RuntimeResourcesExhausted, + ParsePromQL { source, .. } => source.status_code(), Other { source, .. } => source.status_code(), diff --git a/src/servers/src/grpc.rs b/src/servers/src/grpc.rs index 2f759db2a0..1c479a04de 100644 --- a/src/servers/src/grpc.rs +++ b/src/servers/src/grpc.rs @@ -19,6 +19,7 @@ mod database; pub mod flight; pub mod frontend_grpc_handler; pub mod greptime_handler; +pub mod memory_limit; pub mod prom_query_gateway; pub mod region_server; @@ -51,6 +52,7 @@ use crate::error::{AlreadyStartedSnafu, InternalSnafu, Result, StartGrpcSnafu, T use crate::metrics::MetricsMiddlewareLayer; use crate::otel_arrow::{HeaderInterceptor, OtelArrowServiceHandler}; use crate::query_handler::OpenTelemetryProtocolHandlerRef; +use crate::request_limiter::RequestMemoryLimiter; use crate::server::Server; use crate::tls::TlsOption; @@ -67,6 +69,8 @@ pub struct GrpcOptions { pub max_recv_message_size: ReadableSize, /// Max gRPC sending(encoding) message size pub max_send_message_size: ReadableSize, + /// Maximum total memory for all concurrent gRPC request messages. 0 disables the limit. + pub max_total_message_memory: ReadableSize, /// Compression mode in Arrow Flight service. 
pub flight_compression: FlightCompression, pub runtime_size: usize, @@ -116,6 +120,7 @@ impl GrpcOptions { GrpcServerConfig { max_recv_message_size: self.max_recv_message_size.as_bytes() as usize, max_send_message_size: self.max_send_message_size.as_bytes() as usize, + max_total_message_memory: self.max_total_message_memory.as_bytes() as usize, tls: self.tls.clone(), max_connection_age: self.max_connection_age, } @@ -134,6 +139,7 @@ impl Default for GrpcOptions { server_addr: String::new(), max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE, max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE, + max_total_message_memory: ReadableSize(0), flight_compression: FlightCompression::ArrowIpc, runtime_size: 8, tls: TlsOption::default(), @@ -153,6 +159,7 @@ impl GrpcOptions { server_addr: format!("127.0.0.1:{}", DEFAULT_INTERNAL_GRPC_ADDR_PORT), max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE, max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE, + max_total_message_memory: ReadableSize(0), flight_compression: FlightCompression::ArrowIpc, runtime_size: 8, tls: TlsOption::default(), @@ -217,6 +224,7 @@ pub struct GrpcServer { bind_addr: Option, name: Option, config: GrpcServerConfig, + memory_limiter: RequestMemoryLimiter, } /// Grpc Server configuration @@ -226,6 +234,8 @@ pub struct GrpcServerConfig { pub max_recv_message_size: usize, // Max gRPC sending(encoding) message size pub max_send_message_size: usize, + /// Maximum total memory for all concurrent gRPC request messages. 0 disables the limit. + pub max_total_message_memory: usize, pub tls: TlsOption, /// Maximum time that a channel may exist. /// Useful when the server wants to control the reconnection of its clients. @@ -238,6 +248,7 @@ impl Default for GrpcServerConfig { Self { max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE.as_bytes() as usize, max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE.as_bytes() as usize, + max_total_message_memory: 0, tls: TlsOption::default(), max_connection_age: None, } @@ -277,6 +288,11 @@ impl GrpcServer { } Ok(()) } + + /// Get the memory limiter for monitoring current memory usage + pub fn memory_limiter(&self) -> &RequestMemoryLimiter { + &self.memory_limiter + } } pub struct HealthCheckHandler; diff --git a/src/servers/src/grpc/builder.rs b/src/servers/src/grpc/builder.rs index 75a0bb13c3..ae5c226138 100644 --- a/src/servers/src/grpc/builder.rs +++ b/src/servers/src/grpc/builder.rs @@ -38,6 +38,7 @@ use crate::grpc::{GrpcServer, GrpcServerConfig}; use crate::otel_arrow::{HeaderInterceptor, OtelArrowServiceHandler}; use crate::prometheus_handler::PrometheusHandlerRef; use crate::query_handler::OpenTelemetryProtocolHandlerRef; +use crate::request_limiter::RequestMemoryLimiter; use crate::tls::TlsOption; /// Add a gRPC service (`service`) to a `builder`([RoutesBuilder]). @@ -57,7 +58,17 @@ macro_rules! 
add_service { .send_compressed(CompressionEncoding::Gzip) .send_compressed(CompressionEncoding::Zstd); - $builder.routes_builder_mut().add_service(service_builder); + // Apply memory limiter layer + use $crate::grpc::memory_limit::MemoryLimiterExtensionLayer; + let service_with_limiter = $crate::tower::ServiceBuilder::new() + .layer(MemoryLimiterExtensionLayer::new( + $builder.memory_limiter().clone(), + )) + .service(service_builder); + + $builder + .routes_builder_mut() + .add_service(service_with_limiter); }; } @@ -73,10 +84,12 @@ pub struct GrpcServerBuilder { HeaderInterceptor, >, >, + memory_limiter: RequestMemoryLimiter, } impl GrpcServerBuilder { pub fn new(config: GrpcServerConfig, runtime: Runtime) -> Self { + let memory_limiter = RequestMemoryLimiter::new(config.max_total_message_memory); Self { name: None, config, @@ -84,6 +97,7 @@ impl GrpcServerBuilder { routes_builder: RoutesBuilder::default(), tls_config: None, otel_arrow_service: None, + memory_limiter, } } @@ -95,6 +109,10 @@ impl GrpcServerBuilder { &self.runtime } + pub fn memory_limiter(&self) -> &RequestMemoryLimiter { + &self.memory_limiter + } + pub fn name(self, name: Option) -> Self { Self { name, ..self } } @@ -198,6 +216,7 @@ impl GrpcServerBuilder { bind_addr: None, name: self.name, config: self.config, + memory_limiter: self.memory_limiter, } } } diff --git a/src/servers/src/grpc/database.rs b/src/servers/src/grpc/database.rs index 13c328399d..5d132c434e 100644 --- a/src/servers/src/grpc/database.rs +++ b/src/servers/src/grpc/database.rs @@ -20,11 +20,14 @@ use common_error::status_code::StatusCode; use common_query::OutputData; use common_telemetry::{debug, warn}; use futures::StreamExt; +use prost::Message; use tonic::{Request, Response, Status, Streaming}; use crate::grpc::greptime_handler::GreptimeRequestHandler; use crate::grpc::{TonicResult, cancellation}; use crate::hint_headers; +use crate::metrics::{METRIC_GRPC_MEMORY_USAGE_BYTES, METRIC_GRPC_REQUESTS_REJECTED_TOTAL}; +use crate::request_limiter::RequestMemoryLimiter; pub(crate) struct DatabaseService { handler: GreptimeRequestHandler, @@ -48,6 +51,27 @@ impl GreptimeDatabase for DatabaseService { "GreptimeDatabase::Handle: request from {:?} with hints: {:?}", remote_addr, hints ); + + let _guard = request + .extensions() + .get::() + .filter(|limiter| limiter.is_enabled()) + .and_then(|limiter| { + let message_size = request.get_ref().encoded_len(); + limiter + .try_acquire(message_size) + .map(|guard| { + guard.inspect(|g| { + METRIC_GRPC_MEMORY_USAGE_BYTES.set(g.current_usage() as i64); + }) + }) + .inspect_err(|_| { + METRIC_GRPC_REQUESTS_REJECTED_TOTAL.inc(); + }) + .transpose() + }) + .transpose()?; + let handler = self.handler.clone(); let request_future = async move { let request = request.into_inner(); @@ -94,6 +118,9 @@ impl GreptimeDatabase for DatabaseService { "GreptimeDatabase::HandleRequests: request from {:?} with hints: {:?}", remote_addr, hints ); + + let limiter = request.extensions().get::().cloned(); + let handler = self.handler.clone(); let request_future = async move { let mut affected_rows = 0; @@ -101,6 +128,25 @@ impl GreptimeDatabase for DatabaseService { let mut stream = request.into_inner(); while let Some(request) = stream.next().await { let request = request?; + + let _guard = limiter + .as_ref() + .filter(|limiter| limiter.is_enabled()) + .and_then(|limiter| { + let message_size = request.encoded_len(); + limiter + .try_acquire(message_size) + .map(|guard| { + guard.inspect(|g| { + 
METRIC_GRPC_MEMORY_USAGE_BYTES.set(g.current_usage() as i64); + }) + }) + .inspect_err(|_| { + METRIC_GRPC_REQUESTS_REJECTED_TOTAL.inc(); + }) + .transpose() + }) + .transpose()?; let output = handler.handle_request(request, hints.clone()).await?; match output.data { OutputData::AffectedRows(rows) => affected_rows += rows, diff --git a/src/servers/src/grpc/flight.rs b/src/servers/src/grpc/flight.rs index bb431bfdae..44b307fe71 100644 --- a/src/servers/src/grpc/flight.rs +++ b/src/servers/src/grpc/flight.rs @@ -45,6 +45,8 @@ use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu}; pub use crate::grpc::flight::stream::FlightRecordBatchStream; use crate::grpc::greptime_handler::{GreptimeRequestHandler, get_request_type}; use crate::grpc::{FlightCompression, TonicResult, context_auth}; +use crate::metrics::{METRIC_GRPC_MEMORY_USAGE_BYTES, METRIC_GRPC_REQUESTS_REJECTED_TOTAL}; +use crate::request_limiter::{RequestMemoryGuard, RequestMemoryLimiter}; use crate::{error, hint_headers}; pub type TonicStream = Pin> + Send + 'static>>; @@ -211,7 +213,9 @@ impl FlightCraft for GreptimeRequestHandler { &self, request: Request>, ) -> TonicResult>> { - let (headers, _, stream) = request.into_parts(); + let (headers, extensions, stream) = request.into_parts(); + + let limiter = extensions.get::().cloned(); let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?; context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?; @@ -225,6 +229,7 @@ impl FlightCraft for GreptimeRequestHandler { query_ctx.current_catalog().to_string(), query_ctx.current_schema(), ), + limiter, }; self.put_record_batches(stream, tx, query_ctx).await; @@ -248,10 +253,15 @@ pub(crate) struct PutRecordBatchRequest { pub(crate) table_name: TableName, pub(crate) request_id: i64, pub(crate) data: FlightData, + pub(crate) _guard: Option, } impl PutRecordBatchRequest { - fn try_new(table_name: TableName, flight_data: FlightData) -> Result { + fn try_new( + table_name: TableName, + flight_data: FlightData, + limiter: Option<&RequestMemoryLimiter>, + ) -> Result { let request_id = if !flight_data.app_metadata.is_empty() { let metadata: DoPutMetadata = serde_json::from_slice(&flight_data.app_metadata).context(ParseJsonSnafu)?; @@ -259,10 +269,30 @@ impl PutRecordBatchRequest { } else { 0 }; + + let _guard = limiter + .filter(|limiter| limiter.is_enabled()) + .map(|limiter| { + let message_size = flight_data.encoded_len(); + limiter + .try_acquire(message_size) + .map(|guard| { + guard.inspect(|g| { + METRIC_GRPC_MEMORY_USAGE_BYTES.set(g.current_usage() as i64); + }) + }) + .inspect_err(|_| { + METRIC_GRPC_REQUESTS_REJECTED_TOTAL.inc(); + }) + }) + .transpose()? 
+ .flatten(); + Ok(Self { table_name, request_id, data: flight_data, + _guard, }) } } @@ -270,6 +300,7 @@ impl PutRecordBatchRequest { pub(crate) struct PutRecordBatchRequestStream { flight_data_stream: Streaming, state: PutRecordBatchRequestStreamState, + limiter: Option, } enum PutRecordBatchRequestStreamState { @@ -298,6 +329,7 @@ impl Stream for PutRecordBatchRequestStream { } let poll = ready!(self.flight_data_stream.poll_next_unpin(cx)); + let limiter = self.limiter.clone(); let result = match &mut self.state { PutRecordBatchRequestStreamState::Init(catalog, schema) => match poll { @@ -311,8 +343,11 @@ impl Stream for PutRecordBatchRequestStream { Err(e) => return Poll::Ready(Some(Err(e.into()))), }; - let request = - PutRecordBatchRequest::try_new(table_name.clone(), flight_data); + let request = PutRecordBatchRequest::try_new( + table_name.clone(), + flight_data, + limiter.as_ref(), + ); let request = match request { Ok(request) => request, Err(e) => return Poll::Ready(Some(Err(e.into()))), @@ -333,8 +368,12 @@ impl Stream for PutRecordBatchRequestStream { }, PutRecordBatchRequestStreamState::Started(table_name) => poll.map(|x| { x.and_then(|flight_data| { - PutRecordBatchRequest::try_new(table_name.clone(), flight_data) - .map_err(Into::into) + PutRecordBatchRequest::try_new( + table_name.clone(), + flight_data, + limiter.as_ref(), + ) + .map_err(Into::into) }) }), }; diff --git a/src/servers/src/grpc/greptime_handler.rs b/src/servers/src/grpc/greptime_handler.rs index e19fc4352b..095c36abb1 100644 --- a/src/servers/src/grpc/greptime_handler.rs +++ b/src/servers/src/grpc/greptime_handler.rs @@ -160,6 +160,7 @@ impl GreptimeRequestHandler { table_name, request_id, data, + _guard, } = request; let timer = metrics::GRPC_BULK_INSERT_ELAPSED.start_timer(); diff --git a/src/servers/src/grpc/memory_limit.rs b/src/servers/src/grpc/memory_limit.rs new file mode 100644 index 0000000000..a3dee9da57 --- /dev/null +++ b/src/servers/src/grpc/memory_limit.rs @@ -0,0 +1,72 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::task::{Context, Poll}; + +use futures::future::BoxFuture; +use tonic::server::NamedService; +use tower::{Layer, Service}; + +use crate::request_limiter::RequestMemoryLimiter; + +#[derive(Clone)] +pub struct MemoryLimiterExtensionLayer { + limiter: RequestMemoryLimiter, +} + +impl MemoryLimiterExtensionLayer { + pub fn new(limiter: RequestMemoryLimiter) -> Self { + Self { limiter } + } +} + +impl Layer for MemoryLimiterExtensionLayer { + type Service = MemoryLimiterExtensionService; + + fn layer(&self, service: S) -> Self::Service { + MemoryLimiterExtensionService { + inner: service, + limiter: self.limiter.clone(), + } + } +} + +#[derive(Clone)] +pub struct MemoryLimiterExtensionService { + inner: S, + limiter: RequestMemoryLimiter, +} + +impl NamedService for MemoryLimiterExtensionService { + const NAME: &'static str = S::NAME; +} + +impl Service> for MemoryLimiterExtensionService +where + S: Service>, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: http::Request) -> Self::Future { + req.extensions_mut().insert(self.limiter.clone()); + Box::pin(self.inner.call(req)) + } +} diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 946e22ba5b..404b087535 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -82,6 +82,7 @@ use crate::query_handler::{ OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef, PromStoreProtocolHandlerRef, }; +use crate::request_limiter::RequestMemoryLimiter; use crate::server::Server; pub mod authorize; @@ -97,6 +98,7 @@ pub mod jaeger; pub mod logs; pub mod loki; pub mod mem_prof; +mod memory_limit; pub mod opentsdb; pub mod otlp; pub mod pprof; @@ -129,6 +131,7 @@ pub struct HttpServer { router: StdMutex, shutdown_tx: Mutex>>, user_provider: Option, + memory_limiter: RequestMemoryLimiter, // plugins plugins: Plugins, @@ -151,6 +154,9 @@ pub struct HttpOptions { pub body_limit: ReadableSize, + /// Maximum total memory for all concurrent HTTP request bodies. 0 disables the limit. + pub max_total_body_memory: ReadableSize, + /// Validation mode while decoding Prometheus remote write requests. 
pub prom_validation_mode: PromValidationMode, @@ -195,6 +201,7 @@ impl Default for HttpOptions { timeout: Duration::from_secs(0), disable_dashboard: false, body_limit: DEFAULT_BODY_LIMIT, + max_total_body_memory: ReadableSize(0), cors_allowed_origins: Vec::new(), enable_cors: true, prom_validation_mode: PromValidationMode::Strict, @@ -746,6 +753,8 @@ impl HttpServerBuilder { } pub fn build(self) -> HttpServer { + let memory_limiter = + RequestMemoryLimiter::new(self.options.max_total_body_memory.as_bytes() as usize); HttpServer { options: self.options, user_provider: self.user_provider, @@ -753,6 +762,7 @@ impl HttpServerBuilder { plugins: self.plugins, router: StdMutex::new(self.router), bind_addr: None, + memory_limiter, } } } @@ -877,6 +887,11 @@ impl HttpServer { .option_layer(cors_layer) .option_layer(timeout_layer) .option_layer(body_limit_layer) + // memory limit layer - must be before body is consumed + .layer(middleware::from_fn_with_state( + self.memory_limiter.clone(), + memory_limit::memory_limit_middleware, + )) // auth layer .layer(middleware::from_fn_with_state( AuthState::new(self.user_provider.clone()), diff --git a/src/servers/src/http/memory_limit.rs b/src/servers/src/http/memory_limit.rs new file mode 100644 index 0000000000..346b5d3409 --- /dev/null +++ b/src/servers/src/http/memory_limit.rs @@ -0,0 +1,52 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Middleware for limiting total memory usage of concurrent HTTP request bodies. + +use axum::extract::{Request, State}; +use axum::middleware::Next; +use axum::response::{IntoResponse, Response}; +use http::StatusCode; + +use crate::metrics::{METRIC_HTTP_MEMORY_USAGE_BYTES, METRIC_HTTP_REQUESTS_REJECTED_TOTAL}; +use crate::request_limiter::RequestMemoryLimiter; + +pub async fn memory_limit_middleware( + State(limiter): State, + req: Request, + next: Next, +) -> Response { + let content_length = req + .headers() + .get(http::header::CONTENT_LENGTH) + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.parse::().ok()) + .unwrap_or(0); + + let _guard = match limiter.try_acquire(content_length) { + Ok(guard) => guard.inspect(|g| { + METRIC_HTTP_MEMORY_USAGE_BYTES.set(g.current_usage() as i64); + }), + Err(e) => { + METRIC_HTTP_REQUESTS_REJECTED_TOTAL.inc(); + return ( + StatusCode::TOO_MANY_REQUESTS, + format!("Request body memory limit exceeded: {}", e), + ) + .into_response(); + } + }; + + next.run(req).await +} diff --git a/src/servers/src/lib.rs b/src/servers/src/lib.rs index 7172934e66..c73883f0da 100644 --- a/src/servers/src/lib.rs +++ b/src/servers/src/lib.rs @@ -20,6 +20,9 @@ use datafusion_expr::LogicalPlan; use datatypes::schema::Schema; use sql::statements::statement::Statement; +// Re-export for use in add_service! 
macro +#[doc(hidden)] +pub use tower; pub mod addrs; pub mod configurator; @@ -47,6 +50,7 @@ pub mod prometheus_handler; pub mod proto; pub mod query_handler; pub mod repeated_field; +pub mod request_limiter; mod row_writer; pub mod server; pub mod tls; diff --git a/src/servers/src/metrics.rs b/src/servers/src/metrics.rs index af44e697db..8662465f94 100644 --- a/src/servers/src/metrics.rs +++ b/src/servers/src/metrics.rs @@ -298,6 +298,26 @@ lazy_static! { "greptime_servers_bulk_insert_elapsed", "servers handle bulk insert elapsed", ).unwrap(); + + pub static ref METRIC_HTTP_MEMORY_USAGE_BYTES: IntGauge = register_int_gauge!( + "greptime_servers_http_memory_usage_bytes", + "current http request memory usage in bytes" + ).unwrap(); + + pub static ref METRIC_HTTP_REQUESTS_REJECTED_TOTAL: IntCounter = register_int_counter!( + "greptime_servers_http_requests_rejected_total", + "total number of http requests rejected due to memory limit" + ).unwrap(); + + pub static ref METRIC_GRPC_MEMORY_USAGE_BYTES: IntGauge = register_int_gauge!( + "greptime_servers_grpc_memory_usage_bytes", + "current grpc request memory usage in bytes" + ).unwrap(); + + pub static ref METRIC_GRPC_REQUESTS_REJECTED_TOTAL: IntCounter = register_int_counter!( + "greptime_servers_grpc_requests_rejected_total", + "total number of grpc requests rejected due to memory limit" + ).unwrap(); } // Based on https://github.com/hyperium/tonic/blob/master/examples/src/tower/server.rs diff --git a/src/servers/src/request_limiter.rs b/src/servers/src/request_limiter.rs new file mode 100644 index 0000000000..62fb4cf216 --- /dev/null +++ b/src/servers/src/request_limiter.rs @@ -0,0 +1,215 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Request memory limiter for controlling total memory usage of concurrent requests. + +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use crate::error::{Result, TooManyConcurrentRequestsSnafu}; + +/// Limiter for total memory usage of concurrent request bodies. +/// +/// Tracks the total memory used by all concurrent request bodies +/// and rejects new requests when the limit is reached. +#[derive(Clone, Default)] +pub struct RequestMemoryLimiter { + inner: Option>, +} + +struct LimiterInner { + current_usage: AtomicUsize, + max_memory: usize, +} + +impl RequestMemoryLimiter { + /// Create a new memory limiter. + /// + /// # Arguments + /// * `max_memory` - Maximum total memory for all concurrent request bodies in bytes (0 = unlimited) + pub fn new(max_memory: usize) -> Self { + if max_memory == 0 { + return Self { inner: None }; + } + + Self { + inner: Some(Arc::new(LimiterInner { + current_usage: AtomicUsize::new(0), + max_memory, + })), + } + } + + /// Try to acquire memory for a request of given size. + /// + /// Returns `Ok(RequestMemoryGuard)` if memory was acquired successfully. + /// Returns `Err` if the memory limit would be exceeded. 
+ pub fn try_acquire(&self, request_size: usize) -> Result> { + let Some(inner) = self.inner.as_ref() else { + return Ok(None); + }; + + let mut new_usage = 0; + let result = + inner + .current_usage + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| { + new_usage = current.saturating_add(request_size); + if new_usage <= inner.max_memory { + Some(new_usage) + } else { + None + } + }); + + match result { + Ok(_) => Ok(Some(RequestMemoryGuard { + size: request_size, + limiter: Arc::clone(inner), + usage_snapshot: new_usage, + })), + Err(_current) => TooManyConcurrentRequestsSnafu { + limit: inner.max_memory, + request_size, + } + .fail(), + } + } + + /// Check if limiter is enabled + pub fn is_enabled(&self) -> bool { + self.inner.is_some() + } + + /// Get current memory usage + pub fn current_usage(&self) -> usize { + self.inner + .as_ref() + .map(|inner| inner.current_usage.load(Ordering::Relaxed)) + .unwrap_or(0) + } + + /// Get max memory limit + pub fn max_memory(&self) -> usize { + self.inner + .as_ref() + .map(|inner| inner.max_memory) + .unwrap_or(0) + } +} + +/// RAII guard that releases memory when dropped +pub struct RequestMemoryGuard { + size: usize, + limiter: Arc, + usage_snapshot: usize, +} + +impl RequestMemoryGuard { + /// Returns the total memory usage snapshot at the time this guard was acquired. + pub fn current_usage(&self) -> usize { + self.usage_snapshot + } +} + +impl Drop for RequestMemoryGuard { + fn drop(&mut self) { + self.limiter + .current_usage + .fetch_sub(self.size, Ordering::Release); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_limiter_disabled() { + let limiter = RequestMemoryLimiter::new(0); + assert!(!limiter.is_enabled()); + assert!(limiter.try_acquire(1000000).unwrap().is_none()); + assert_eq!(limiter.current_usage(), 0); + } + + #[test] + fn test_limiter_basic() { + let limiter = RequestMemoryLimiter::new(1000); + assert!(limiter.is_enabled()); + assert_eq!(limiter.max_memory(), 1000); + assert_eq!(limiter.current_usage(), 0); + + // Acquire 400 bytes + let _guard1 = limiter.try_acquire(400).unwrap(); + assert_eq!(limiter.current_usage(), 400); + + // Acquire another 500 bytes + let _guard2 = limiter.try_acquire(500).unwrap(); + assert_eq!(limiter.current_usage(), 900); + + // Try to acquire 200 bytes - should fail (900 + 200 > 1000) + let result = limiter.try_acquire(200); + assert!(result.is_err()); + assert_eq!(limiter.current_usage(), 900); + + // Drop first guard + drop(_guard1); + assert_eq!(limiter.current_usage(), 500); + + // Now we can acquire 200 bytes + let _guard3 = limiter.try_acquire(200).unwrap(); + assert_eq!(limiter.current_usage(), 700); + } + + #[test] + fn test_limiter_exact_limit() { + let limiter = RequestMemoryLimiter::new(1000); + + // Acquire exactly the limit + let _guard = limiter.try_acquire(1000).unwrap(); + assert_eq!(limiter.current_usage(), 1000); + + // Try to acquire 1 more byte - should fail + let result = limiter.try_acquire(1); + assert!(result.is_err()); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn test_limiter_concurrent() { + let limiter = RequestMemoryLimiter::new(1000); + let mut handles = vec![]; + + // Spawn 10 tasks each trying to acquire 200 bytes + for _ in 0..10 { + let limiter_clone = limiter.clone(); + let handle = tokio::spawn(async move { limiter_clone.try_acquire(200) }); + handles.push(handle); + } + + let mut success_count = 0; + let mut fail_count = 0; + + for handle in handles { + match handle.await.unwrap() { + 
Ok(Some(_)) => success_count += 1, + Err(_) => fail_count += 1, + Ok(None) => unreachable!(), + } + } + + // Only 5 tasks should succeed (5 * 200 = 1000) + assert_eq!(success_count, 5); + assert_eq!(fail_count, 5); + } +} diff --git a/tests-integration/tests/grpc.rs b/tests-integration/tests/grpc.rs index b9e56564a5..6f82d4fc55 100644 --- a/tests-integration/tests/grpc.rs +++ b/tests-integration/tests/grpc.rs @@ -14,10 +14,12 @@ use api::v1::alter_table_expr::Kind; use api::v1::promql_request::Promql; +use api::v1::value::ValueData; use api::v1::{ AddColumn, AddColumns, AlterTableExpr, Basic, Column, ColumnDataType, ColumnDef, CreateTableExpr, InsertRequest, InsertRequests, PromInstantQuery, PromRangeQuery, - PromqlRequest, RequestHeader, SemanticType, column, + PromqlRequest, RequestHeader, Row, RowInsertRequest, RowInsertRequests, SemanticType, Value, + column, }; use auth::user_provider_from_option; use client::{Client, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, Database, OutputData}; @@ -89,6 +91,7 @@ macro_rules! grpc_tests { test_prom_gateway_query, test_grpc_timezone, test_grpc_tls_config, + test_grpc_memory_limit, ); )* }; @@ -954,6 +957,7 @@ pub async fn test_grpc_tls_config(store_type: StorageType) { let config = GrpcServerConfig { max_recv_message_size: 1024, max_send_message_size: 1024, + max_total_message_memory: 1024 * 1024 * 1024, tls, max_connection_age: None, }; @@ -996,6 +1000,7 @@ pub async fn test_grpc_tls_config(store_type: StorageType) { let config = GrpcServerConfig { max_recv_message_size: 1024, max_send_message_size: 1024, + max_total_message_memory: 1024 * 1024 * 1024, tls, max_connection_age: None, }; @@ -1007,3 +1012,157 @@ pub async fn test_grpc_tls_config(store_type: StorageType) { let _ = fe_grpc_server.shutdown().await; } + +pub async fn test_grpc_memory_limit(store_type: StorageType) { + let config = GrpcServerConfig { + max_recv_message_size: 1024 * 1024, + max_send_message_size: 1024 * 1024, + max_total_message_memory: 200, + tls: Default::default(), + max_connection_age: None, + }; + let (_db, fe_grpc_server) = + setup_grpc_server_with(store_type, "test_grpc_memory_limit", None, Some(config)).await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); + + let grpc_client = Client::with_urls([&addr]); + let db = Database::new(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, grpc_client); + + let table_name = "demo"; + + let column_schemas = vec![ + ColumnDef { + name: "host".to_string(), + data_type: ColumnDataType::String as i32, + is_nullable: false, + default_constraint: vec![], + semantic_type: SemanticType::Tag as i32, + comment: String::new(), + datatype_extension: None, + options: None, + }, + ColumnDef { + name: "ts".to_string(), + data_type: ColumnDataType::TimestampMillisecond as i32, + is_nullable: false, + default_constraint: vec![], + semantic_type: SemanticType::Timestamp as i32, + comment: String::new(), + datatype_extension: None, + options: None, + }, + ColumnDef { + name: "cpu".to_string(), + data_type: ColumnDataType::Float64 as i32, + is_nullable: true, + default_constraint: vec![], + semantic_type: SemanticType::Field as i32, + comment: String::new(), + datatype_extension: None, + options: None, + }, + ]; + + let expr = CreateTableExpr { + catalog_name: DEFAULT_CATALOG_NAME.to_string(), + schema_name: DEFAULT_SCHEMA_NAME.to_string(), + table_name: table_name.to_string(), + desc: String::new(), + column_defs: column_schemas.clone(), + time_index: "ts".to_string(), + primary_keys: vec!["host".to_string()], + create_if_not_exists: 
true, + table_options: Default::default(), + table_id: None, + engine: MITO_ENGINE.to_string(), + }; + + db.create(expr).await.unwrap(); + + // Test that small request succeeds + let small_row_insert = RowInsertRequest { + table_name: table_name.to_owned(), + rows: Some(api::v1::Rows { + schema: column_schemas + .iter() + .map(|c| api::v1::ColumnSchema { + column_name: c.name.clone(), + datatype: c.data_type, + semantic_type: c.semantic_type, + datatype_extension: None, + options: None, + }) + .collect(), + rows: vec![Row { + values: vec![ + Value { + value_data: Some(ValueData::StringValue("host1".to_string())), + }, + Value { + value_data: Some(ValueData::TimestampMillisecondValue(1000)), + }, + Value { + value_data: Some(ValueData::F64Value(1.2)), + }, + ], + }], + }), + }; + + let result = db + .row_inserts(RowInsertRequests { + inserts: vec![small_row_insert], + }) + .await; + assert!(result.is_ok()); + + // Test that large request exceeds limit + let large_rows: Vec = (0..100) + .map(|i| Row { + values: vec![ + Value { + value_data: Some(ValueData::StringValue(format!("host{}", i))), + }, + Value { + value_data: Some(ValueData::TimestampMillisecondValue(1000 + i)), + }, + Value { + value_data: Some(ValueData::F64Value(i as f64 * 1.2)), + }, + ], + }) + .collect(); + + let large_row_insert = RowInsertRequest { + table_name: table_name.to_owned(), + rows: Some(api::v1::Rows { + schema: column_schemas + .iter() + .map(|c| api::v1::ColumnSchema { + column_name: c.name.clone(), + datatype: c.data_type, + semantic_type: c.semantic_type, + datatype_extension: None, + options: None, + }) + .collect(), + rows: large_rows, + }), + }; + + let result = db + .row_inserts(RowInsertRequests { + inserts: vec![large_row_insert], + }) + .await; + assert!(result.is_err()); + let err = result.unwrap_err(); + let err_msg = err.to_string(); + assert!( + err_msg.contains("Too many concurrent"), + "Expected memory limit error, got: {}", + err_msg + ); + + let _ = fe_grpc_server.shutdown().await; +} diff --git a/tests-integration/tests/http.rs b/tests-integration/tests/http.rs index 538392e437..d5ed2ed4e6 100644 --- a/tests-integration/tests/http.rs +++ b/tests-integration/tests/http.rs @@ -1597,6 +1597,8 @@ fn drop_lines_with_inconsistent_results(input: String) -> String { "max_background_compactions =", "max_background_purges =", "enable_read_cache =", + "max_total_body_memory =", + "max_total_message_memory =", ]; input
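To recap the semantics the integration tests above rely on, here is a small self-contained sketch of the limiter's acquire/release cycle. It mirrors the unit tests in `request_limiter.rs`; the `main` wrapper and the concrete sizes are just for illustration:

```rust
use servers::request_limiter::RequestMemoryLimiter;

fn main() {
    // A 1 KiB budget shared by all in-flight request bodies/messages on one server.
    let limiter = RequestMemoryLimiter::new(1024);

    // Reserving within the budget returns an RAII guard; the inner Option is Some
    // because the limiter is enabled (non-zero limit).
    let guard = limiter.try_acquire(800).unwrap().unwrap();
    assert_eq!(guard.current_usage(), 800);
    assert_eq!(limiter.current_usage(), 800);

    // A request that would push usage past the budget fails fast with
    // TooManyConcurrentRequests; nothing blocks or queues.
    assert!(limiter.try_acquire(400).is_err());

    // Dropping the guard releases the reservation for later requests.
    drop(guard);
    assert_eq!(limiter.current_usage(), 0);
    assert!(limiter.try_acquire(400).is_ok());
}
```

The HTTP middleware sizes its reservation from the `Content-Length` header, while the gRPC paths use the Prost `encoded_len()` of each message (or `FlightData` frame), and each server builds its own limiter, so the HTTP and gRPC budgets are tracked independently.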