feat: writer mem limiter for http and grpc service (#7092)

* feat: writer mem limiter for http and grpc service

Signed-off-by: jeremyhi <fengjiachun@gmail.com>

* fix: docs

Signed-off-by: jeremyhi <fengjiachun@gmail.com>

* feat: add metrics for limiter

Signed-off-by: jeremyhi <fengjiachun@gmail.com>

* Apply suggestion from @MichaelScofield

Co-authored-by: LFC <990479+MichaelScofield@users.noreply.github.com>

* chore: refactor try_acquire

Signed-off-by: jeremyhi <fengjiachun@gmail.com>

* chore: make size human readable

Signed-off-by: jeremyhi <fengjiachun@gmail.com>

---------

Signed-off-by: jeremyhi <fengjiachun@gmail.com>
Co-authored-by: LFC <990479+MichaelScofield@users.noreply.github.com>
This commit is contained in:
jeremyhi
2025-10-22 17:30:36 +08:00
committed by GitHub
parent a9a3e0b121
commit 62b51c6736
18 changed files with 704 additions and 8 deletions

View File

@@ -25,12 +25,14 @@
| `http.addr` | String | `127.0.0.1:4000` | The address to bind the HTTP server. |
| `http.timeout` | String | `0s` | HTTP request timeout. Set to 0 to disable timeout. |
| `http.body_limit` | String | `64MB` | HTTP request body limit.<br/>The following units are supported: `B`, `KB`, `KiB`, `MB`, `MiB`, `GB`, `GiB`, `TB`, `TiB`, `PB`, `PiB`.<br/>Set to 0 to disable limit. |
| `http.max_total_body_memory` | String | Unset | Maximum total memory for all concurrent HTTP request bodies.<br/>Set to 0 to disable the limit. Default: "0" (unlimited) |
| `http.enable_cors` | Bool | `true` | HTTP CORS support, it's turned on by default<br/>This allows browser to access http APIs without CORS restrictions |
| `http.cors_allowed_origins` | Array | Unset | Customize allowed origins for HTTP CORS. |
| `http.prom_validation_mode` | String | `strict` | Whether to enable validation for Prometheus remote write requests.<br/>Available options:<br/>- strict: deny invalid UTF-8 strings (default).<br/>- lossy: allow invalid UTF-8 strings, replace invalid characters with REPLACEMENT_CHARACTER(U+FFFD).<br/>- unchecked: do not valid strings. |
| `grpc` | -- | -- | The gRPC server options. |
| `grpc.bind_addr` | String | `127.0.0.1:4001` | The address to bind the gRPC server. |
| `grpc.runtime_size` | Integer | `8` | The number of server worker threads. |
| `grpc.max_total_message_memory` | String | Unset | Maximum total memory for all concurrent gRPC request messages.<br/>Set to 0 to disable the limit. Default: "0" (unlimited) |
| `grpc.max_connection_age` | String | Unset | The maximum connection age for gRPC connection.<br/>The value can be a human-readable time string. For example: `10m` for ten minutes or `1h` for one hour.<br/>Refer to https://grpc.io/docs/guides/keepalive/ for more details. |
| `grpc.tls` | -- | -- | gRPC server TLS options, see `mysql.tls` section. |
| `grpc.tls.mode` | String | `disable` | TLS mode. |
@@ -235,6 +237,7 @@
| `http.addr` | String | `127.0.0.1:4000` | The address to bind the HTTP server. |
| `http.timeout` | String | `0s` | HTTP request timeout. Set to 0 to disable timeout. |
| `http.body_limit` | String | `64MB` | HTTP request body limit.<br/>The following units are supported: `B`, `KB`, `KiB`, `MB`, `MiB`, `GB`, `GiB`, `TB`, `TiB`, `PB`, `PiB`.<br/>Set to 0 to disable limit. |
| `http.max_total_body_memory` | String | Unset | Maximum total memory for all concurrent HTTP request bodies.<br/>Set to 0 to disable the limit. Default: "0" (unlimited) |
| `http.enable_cors` | Bool | `true` | HTTP CORS support, it's turned on by default<br/>This allows browser to access http APIs without CORS restrictions |
| `http.cors_allowed_origins` | Array | Unset | Customize allowed origins for HTTP CORS. |
| `http.prom_validation_mode` | String | `strict` | Whether to enable validation for Prometheus remote write requests.<br/>Available options:<br/>- strict: deny invalid UTF-8 strings (default).<br/>- lossy: allow invalid UTF-8 strings, replace invalid characters with REPLACEMENT_CHARACTER(U+FFFD).<br/>- unchecked: do not valid strings. |
@@ -242,6 +245,7 @@
| `grpc.bind_addr` | String | `127.0.0.1:4001` | The address to bind the gRPC server. |
| `grpc.server_addr` | String | `127.0.0.1:4001` | The address advertised to the metasrv, and used for connections from outside the host.<br/>If left empty or unset, the server will automatically use the IP address of the first network interface<br/>on the host, with the same port number as the one specified in `grpc.bind_addr`. |
| `grpc.runtime_size` | Integer | `8` | The number of server worker threads. |
| `grpc.max_total_message_memory` | String | Unset | Maximum total memory for all concurrent gRPC request messages.<br/>Set to 0 to disable the limit. Default: "0" (unlimited) |
| `grpc.flight_compression` | String | `arrow_ipc` | Compression mode for frontend side Arrow IPC service. Available options:<br/>- `none`: disable all compression<br/>- `transport`: only enable gRPC transport compression (zstd)<br/>- `arrow_ipc`: only enable Arrow IPC compression (lz4)<br/>- `all`: enable all compression.<br/>Default to `none` |
| `grpc.max_connection_age` | String | Unset | The maximum connection age for gRPC connection.<br/>The value can be a human-readable time string. For example: `10m` for ten minutes or `1h` for one hour.<br/>Refer to https://grpc.io/docs/guides/keepalive/ for more details. |
| `grpc.tls` | -- | -- | gRPC server TLS options, see `mysql.tls` section. |

View File

@@ -31,6 +31,10 @@ timeout = "0s"
## The following units are supported: `B`, `KB`, `KiB`, `MB`, `MiB`, `GB`, `GiB`, `TB`, `TiB`, `PB`, `PiB`.
## Set to 0 to disable limit.
body_limit = "64MB"
## Maximum total memory for all concurrent HTTP request bodies.
## Set to 0 to disable the limit. Default: "0" (unlimited)
## @toml2docs:none-default
#+ max_total_body_memory = "1GB"
## HTTP CORS support, it's turned on by default
## This allows browser to access http APIs without CORS restrictions
enable_cors = true
@@ -54,6 +58,10 @@ bind_addr = "127.0.0.1:4001"
server_addr = "127.0.0.1:4001"
## The number of server worker threads.
runtime_size = 8
## Maximum total memory for all concurrent gRPC request messages.
## Set to 0 to disable the limit. Default: "0" (unlimited)
## @toml2docs:none-default
#+ max_total_message_memory = "1GB"
## Compression mode for frontend side Arrow IPC service. Available options:
## - `none`: disable all compression
## - `transport`: only enable gRPC transport compression (zstd)

View File

@@ -36,6 +36,10 @@ timeout = "0s"
## The following units are supported: `B`, `KB`, `KiB`, `MB`, `MiB`, `GB`, `GiB`, `TB`, `TiB`, `PB`, `PiB`.
## Set to 0 to disable limit.
body_limit = "64MB"
## Maximum total memory for all concurrent HTTP request bodies.
## Set to 0 to disable the limit. Default: "0" (unlimited)
## @toml2docs:none-default
#+ max_total_body_memory = "1GB"
## HTTP CORS support, it's turned on by default
## This allows browser to access http APIs without CORS restrictions
enable_cors = true
@@ -56,6 +60,10 @@ prom_validation_mode = "strict"
bind_addr = "127.0.0.1:4001"
## The number of server worker threads.
runtime_size = 8
## Maximum total memory for all concurrent gRPC request messages.
## Set to 0 to disable the limit. Default: "0" (unlimited)
## @toml2docs:none-default
#+ max_total_message_memory = "1GB"
## The maximum connection age for gRPC connection.
## The value can be a human-readable time string. For example: `10m` for ten minutes or `1h` for one hour.
## Refer to https://grpc.io/docs/guides/keepalive/ for more details.

View File

@@ -490,6 +490,7 @@ impl<'a> FlownodeServiceBuilder<'a> {
let config = GrpcServerConfig {
max_recv_message_size: opts.grpc.max_recv_message_size.as_bytes() as usize,
max_send_message_size: opts.grpc.max_send_message_size.as_bytes() as usize,
max_total_message_memory: opts.grpc.max_total_message_memory.as_bytes() as usize,
tls: opts.grpc.tls.clone(),
max_connection_age: opts.grpc.max_connection_age,
};

View File

@@ -20,6 +20,7 @@ use axum::http::StatusCode as HttpStatusCode;
use axum::response::{IntoResponse, Response};
use axum::{Json, http};
use base64::DecodeError;
use common_base::readable_size::ReadableSize;
use common_error::define_into_tonic_status;
use common_error::ext::{BoxedError, ErrorExt};
use common_error::status_code::StatusCode;
@@ -164,6 +165,18 @@ pub enum Error {
location: Location,
},
#[snafu(display(
"Too many concurrent large requests, limit: {}, request size: {}",
ReadableSize(*limit as u64),
ReadableSize(*request_size as u64)
))]
TooManyConcurrentRequests {
limit: usize,
request_size: usize,
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Invalid query: {}", reason))]
InvalidQuery {
reason: String,
@@ -729,6 +742,8 @@ impl ErrorExt for Error {
InvalidUtf8Value { .. } | InvalidHeaderValue { .. } => StatusCode::InvalidArguments,
TooManyConcurrentRequests { .. } => StatusCode::RuntimeResourcesExhausted,
ParsePromQL { source, .. } => source.status_code(),
Other { source, .. } => source.status_code(),

View File

@@ -19,6 +19,7 @@ mod database;
pub mod flight;
pub mod frontend_grpc_handler;
pub mod greptime_handler;
pub mod memory_limit;
pub mod prom_query_gateway;
pub mod region_server;
@@ -51,6 +52,7 @@ use crate::error::{AlreadyStartedSnafu, InternalSnafu, Result, StartGrpcSnafu, T
use crate::metrics::MetricsMiddlewareLayer;
use crate::otel_arrow::{HeaderInterceptor, OtelArrowServiceHandler};
use crate::query_handler::OpenTelemetryProtocolHandlerRef;
use crate::request_limiter::RequestMemoryLimiter;
use crate::server::Server;
use crate::tls::TlsOption;
@@ -67,6 +69,8 @@ pub struct GrpcOptions {
pub max_recv_message_size: ReadableSize,
/// Max gRPC sending(encoding) message size
pub max_send_message_size: ReadableSize,
/// Maximum total memory for all concurrent gRPC request messages. 0 disables the limit.
pub max_total_message_memory: ReadableSize,
/// Compression mode in Arrow Flight service.
pub flight_compression: FlightCompression,
pub runtime_size: usize,
@@ -116,6 +120,7 @@ impl GrpcOptions {
GrpcServerConfig {
max_recv_message_size: self.max_recv_message_size.as_bytes() as usize,
max_send_message_size: self.max_send_message_size.as_bytes() as usize,
max_total_message_memory: self.max_total_message_memory.as_bytes() as usize,
tls: self.tls.clone(),
max_connection_age: self.max_connection_age,
}
@@ -134,6 +139,7 @@ impl Default for GrpcOptions {
server_addr: String::new(),
max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
max_total_message_memory: ReadableSize(0),
flight_compression: FlightCompression::ArrowIpc,
runtime_size: 8,
tls: TlsOption::default(),
@@ -153,6 +159,7 @@ impl GrpcOptions {
server_addr: format!("127.0.0.1:{}", DEFAULT_INTERNAL_GRPC_ADDR_PORT),
max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE,
max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE,
max_total_message_memory: ReadableSize(0),
flight_compression: FlightCompression::ArrowIpc,
runtime_size: 8,
tls: TlsOption::default(),
@@ -217,6 +224,7 @@ pub struct GrpcServer {
bind_addr: Option<SocketAddr>,
name: Option<String>,
config: GrpcServerConfig,
memory_limiter: RequestMemoryLimiter,
}
/// Grpc Server configuration
@@ -226,6 +234,8 @@ pub struct GrpcServerConfig {
pub max_recv_message_size: usize,
// Max gRPC sending(encoding) message size
pub max_send_message_size: usize,
/// Maximum total memory for all concurrent gRPC request messages. 0 disables the limit.
pub max_total_message_memory: usize,
pub tls: TlsOption,
/// Maximum time that a channel may exist.
/// Useful when the server wants to control the reconnection of its clients.
@@ -238,6 +248,7 @@ impl Default for GrpcServerConfig {
Self {
max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE.as_bytes() as usize,
max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE.as_bytes() as usize,
max_total_message_memory: 0,
tls: TlsOption::default(),
max_connection_age: None,
}
@@ -277,6 +288,11 @@ impl GrpcServer {
}
Ok(())
}
/// Get the memory limiter for monitoring current memory usage
pub fn memory_limiter(&self) -> &RequestMemoryLimiter {
&self.memory_limiter
}
}
pub struct HealthCheckHandler;

View File

@@ -38,6 +38,7 @@ use crate::grpc::{GrpcServer, GrpcServerConfig};
use crate::otel_arrow::{HeaderInterceptor, OtelArrowServiceHandler};
use crate::prometheus_handler::PrometheusHandlerRef;
use crate::query_handler::OpenTelemetryProtocolHandlerRef;
use crate::request_limiter::RequestMemoryLimiter;
use crate::tls::TlsOption;
/// Add a gRPC service (`service`) to a `builder`([RoutesBuilder]).
@@ -57,7 +58,17 @@ macro_rules! add_service {
.send_compressed(CompressionEncoding::Gzip)
.send_compressed(CompressionEncoding::Zstd);
$builder.routes_builder_mut().add_service(service_builder);
// Apply memory limiter layer
use $crate::grpc::memory_limit::MemoryLimiterExtensionLayer;
let service_with_limiter = $crate::tower::ServiceBuilder::new()
.layer(MemoryLimiterExtensionLayer::new(
$builder.memory_limiter().clone(),
))
.service(service_builder);
$builder
.routes_builder_mut()
.add_service(service_with_limiter);
};
}
@@ -73,10 +84,12 @@ pub struct GrpcServerBuilder {
HeaderInterceptor,
>,
>,
memory_limiter: RequestMemoryLimiter,
}
impl GrpcServerBuilder {
pub fn new(config: GrpcServerConfig, runtime: Runtime) -> Self {
let memory_limiter = RequestMemoryLimiter::new(config.max_total_message_memory);
Self {
name: None,
config,
@@ -84,6 +97,7 @@ impl GrpcServerBuilder {
routes_builder: RoutesBuilder::default(),
tls_config: None,
otel_arrow_service: None,
memory_limiter,
}
}
@@ -95,6 +109,10 @@ impl GrpcServerBuilder {
&self.runtime
}
pub fn memory_limiter(&self) -> &RequestMemoryLimiter {
&self.memory_limiter
}
pub fn name(self, name: Option<String>) -> Self {
Self { name, ..self }
}
@@ -198,6 +216,7 @@ impl GrpcServerBuilder {
bind_addr: None,
name: self.name,
config: self.config,
memory_limiter: self.memory_limiter,
}
}
}

View File

@@ -20,11 +20,14 @@ use common_error::status_code::StatusCode;
use common_query::OutputData;
use common_telemetry::{debug, warn};
use futures::StreamExt;
use prost::Message;
use tonic::{Request, Response, Status, Streaming};
use crate::grpc::greptime_handler::GreptimeRequestHandler;
use crate::grpc::{TonicResult, cancellation};
use crate::hint_headers;
use crate::metrics::{METRIC_GRPC_MEMORY_USAGE_BYTES, METRIC_GRPC_REQUESTS_REJECTED_TOTAL};
use crate::request_limiter::RequestMemoryLimiter;
pub(crate) struct DatabaseService {
handler: GreptimeRequestHandler,
@@ -48,6 +51,27 @@ impl GreptimeDatabase for DatabaseService {
"GreptimeDatabase::Handle: request from {:?} with hints: {:?}",
remote_addr, hints
);
let _guard = request
.extensions()
.get::<RequestMemoryLimiter>()
.filter(|limiter| limiter.is_enabled())
.and_then(|limiter| {
let message_size = request.get_ref().encoded_len();
limiter
.try_acquire(message_size)
.map(|guard| {
guard.inspect(|g| {
METRIC_GRPC_MEMORY_USAGE_BYTES.set(g.current_usage() as i64);
})
})
.inspect_err(|_| {
METRIC_GRPC_REQUESTS_REJECTED_TOTAL.inc();
})
.transpose()
})
.transpose()?;
let handler = self.handler.clone();
let request_future = async move {
let request = request.into_inner();
@@ -94,6 +118,9 @@ impl GreptimeDatabase for DatabaseService {
"GreptimeDatabase::HandleRequests: request from {:?} with hints: {:?}",
remote_addr, hints
);
let limiter = request.extensions().get::<RequestMemoryLimiter>().cloned();
let handler = self.handler.clone();
let request_future = async move {
let mut affected_rows = 0;
@@ -101,6 +128,25 @@ impl GreptimeDatabase for DatabaseService {
let mut stream = request.into_inner();
while let Some(request) = stream.next().await {
let request = request?;
let _guard = limiter
.as_ref()
.filter(|limiter| limiter.is_enabled())
.and_then(|limiter| {
let message_size = request.encoded_len();
limiter
.try_acquire(message_size)
.map(|guard| {
guard.inspect(|g| {
METRIC_GRPC_MEMORY_USAGE_BYTES.set(g.current_usage() as i64);
})
})
.inspect_err(|_| {
METRIC_GRPC_REQUESTS_REJECTED_TOTAL.inc();
})
.transpose()
})
.transpose()?;
let output = handler.handle_request(request, hints.clone()).await?;
match output.data {
OutputData::AffectedRows(rows) => affected_rows += rows,

View File

@@ -45,6 +45,8 @@ use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu};
pub use crate::grpc::flight::stream::FlightRecordBatchStream;
use crate::grpc::greptime_handler::{GreptimeRequestHandler, get_request_type};
use crate::grpc::{FlightCompression, TonicResult, context_auth};
use crate::metrics::{METRIC_GRPC_MEMORY_USAGE_BYTES, METRIC_GRPC_REQUESTS_REJECTED_TOTAL};
use crate::request_limiter::{RequestMemoryGuard, RequestMemoryLimiter};
use crate::{error, hint_headers};
pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
@@ -211,7 +213,9 @@ impl FlightCraft for GreptimeRequestHandler {
&self,
request: Request<Streaming<FlightData>>,
) -> TonicResult<Response<TonicStream<PutResult>>> {
let (headers, _, stream) = request.into_parts();
let (headers, extensions, stream) = request.into_parts();
let limiter = extensions.get::<RequestMemoryLimiter>().cloned();
let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?;
context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?;
@@ -225,6 +229,7 @@ impl FlightCraft for GreptimeRequestHandler {
query_ctx.current_catalog().to_string(),
query_ctx.current_schema(),
),
limiter,
};
self.put_record_batches(stream, tx, query_ctx).await;
@@ -248,10 +253,15 @@ pub(crate) struct PutRecordBatchRequest {
pub(crate) table_name: TableName,
pub(crate) request_id: i64,
pub(crate) data: FlightData,
pub(crate) _guard: Option<RequestMemoryGuard>,
}
impl PutRecordBatchRequest {
fn try_new(table_name: TableName, flight_data: FlightData) -> Result<Self> {
fn try_new(
table_name: TableName,
flight_data: FlightData,
limiter: Option<&RequestMemoryLimiter>,
) -> Result<Self> {
let request_id = if !flight_data.app_metadata.is_empty() {
let metadata: DoPutMetadata =
serde_json::from_slice(&flight_data.app_metadata).context(ParseJsonSnafu)?;
@@ -259,10 +269,30 @@ impl PutRecordBatchRequest {
} else {
0
};
let _guard = limiter
.filter(|limiter| limiter.is_enabled())
.map(|limiter| {
let message_size = flight_data.encoded_len();
limiter
.try_acquire(message_size)
.map(|guard| {
guard.inspect(|g| {
METRIC_GRPC_MEMORY_USAGE_BYTES.set(g.current_usage() as i64);
})
})
.inspect_err(|_| {
METRIC_GRPC_REQUESTS_REJECTED_TOTAL.inc();
})
})
.transpose()?
.flatten();
Ok(Self {
table_name,
request_id,
data: flight_data,
_guard,
})
}
}
@@ -270,6 +300,7 @@ impl PutRecordBatchRequest {
pub(crate) struct PutRecordBatchRequestStream {
flight_data_stream: Streaming<FlightData>,
state: PutRecordBatchRequestStreamState,
limiter: Option<RequestMemoryLimiter>,
}
enum PutRecordBatchRequestStreamState {
@@ -298,6 +329,7 @@ impl Stream for PutRecordBatchRequestStream {
}
let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
let limiter = self.limiter.clone();
let result = match &mut self.state {
PutRecordBatchRequestStreamState::Init(catalog, schema) => match poll {
@@ -311,8 +343,11 @@ impl Stream for PutRecordBatchRequestStream {
Err(e) => return Poll::Ready(Some(Err(e.into()))),
};
let request =
PutRecordBatchRequest::try_new(table_name.clone(), flight_data);
let request = PutRecordBatchRequest::try_new(
table_name.clone(),
flight_data,
limiter.as_ref(),
);
let request = match request {
Ok(request) => request,
Err(e) => return Poll::Ready(Some(Err(e.into()))),
@@ -333,8 +368,12 @@ impl Stream for PutRecordBatchRequestStream {
},
PutRecordBatchRequestStreamState::Started(table_name) => poll.map(|x| {
x.and_then(|flight_data| {
PutRecordBatchRequest::try_new(table_name.clone(), flight_data)
.map_err(Into::into)
PutRecordBatchRequest::try_new(
table_name.clone(),
flight_data,
limiter.as_ref(),
)
.map_err(Into::into)
})
}),
};

View File

@@ -160,6 +160,7 @@ impl GreptimeRequestHandler {
table_name,
request_id,
data,
_guard,
} = request;
let timer = metrics::GRPC_BULK_INSERT_ELAPSED.start_timer();

View File

@@ -0,0 +1,72 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::task::{Context, Poll};
use futures::future::BoxFuture;
use tonic::server::NamedService;
use tower::{Layer, Service};
use crate::request_limiter::RequestMemoryLimiter;
#[derive(Clone)]
pub struct MemoryLimiterExtensionLayer {
limiter: RequestMemoryLimiter,
}
impl MemoryLimiterExtensionLayer {
pub fn new(limiter: RequestMemoryLimiter) -> Self {
Self { limiter }
}
}
impl<S> Layer<S> for MemoryLimiterExtensionLayer {
type Service = MemoryLimiterExtensionService<S>;
fn layer(&self, service: S) -> Self::Service {
MemoryLimiterExtensionService {
inner: service,
limiter: self.limiter.clone(),
}
}
}
#[derive(Clone)]
pub struct MemoryLimiterExtensionService<S> {
inner: S,
limiter: RequestMemoryLimiter,
}
impl<S: NamedService> NamedService for MemoryLimiterExtensionService<S> {
const NAME: &'static str = S::NAME;
}
impl<S, ReqBody> Service<http::Request<ReqBody>> for MemoryLimiterExtensionService<S>
where
S: Service<http::Request<ReqBody>>,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
req.extensions_mut().insert(self.limiter.clone());
Box::pin(self.inner.call(req))
}
}

View File

@@ -82,6 +82,7 @@ use crate::query_handler::{
OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef,
PromStoreProtocolHandlerRef,
};
use crate::request_limiter::RequestMemoryLimiter;
use crate::server::Server;
pub mod authorize;
@@ -97,6 +98,7 @@ pub mod jaeger;
pub mod logs;
pub mod loki;
pub mod mem_prof;
mod memory_limit;
pub mod opentsdb;
pub mod otlp;
pub mod pprof;
@@ -129,6 +131,7 @@ pub struct HttpServer {
router: StdMutex<Router>,
shutdown_tx: Mutex<Option<Sender<()>>>,
user_provider: Option<UserProviderRef>,
memory_limiter: RequestMemoryLimiter,
// plugins
plugins: Plugins,
@@ -151,6 +154,9 @@ pub struct HttpOptions {
pub body_limit: ReadableSize,
/// Maximum total memory for all concurrent HTTP request bodies. 0 disables the limit.
pub max_total_body_memory: ReadableSize,
/// Validation mode while decoding Prometheus remote write requests.
pub prom_validation_mode: PromValidationMode,
@@ -195,6 +201,7 @@ impl Default for HttpOptions {
timeout: Duration::from_secs(0),
disable_dashboard: false,
body_limit: DEFAULT_BODY_LIMIT,
max_total_body_memory: ReadableSize(0),
cors_allowed_origins: Vec::new(),
enable_cors: true,
prom_validation_mode: PromValidationMode::Strict,
@@ -746,6 +753,8 @@ impl HttpServerBuilder {
}
pub fn build(self) -> HttpServer {
let memory_limiter =
RequestMemoryLimiter::new(self.options.max_total_body_memory.as_bytes() as usize);
HttpServer {
options: self.options,
user_provider: self.user_provider,
@@ -753,6 +762,7 @@ impl HttpServerBuilder {
plugins: self.plugins,
router: StdMutex::new(self.router),
bind_addr: None,
memory_limiter,
}
}
}
@@ -877,6 +887,11 @@ impl HttpServer {
.option_layer(cors_layer)
.option_layer(timeout_layer)
.option_layer(body_limit_layer)
// memory limit layer - must be before body is consumed
.layer(middleware::from_fn_with_state(
self.memory_limiter.clone(),
memory_limit::memory_limit_middleware,
))
// auth layer
.layer(middleware::from_fn_with_state(
AuthState::new(self.user_provider.clone()),

View File

@@ -0,0 +1,52 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Middleware for limiting total memory usage of concurrent HTTP request bodies.
use axum::extract::{Request, State};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use http::StatusCode;
use crate::metrics::{METRIC_HTTP_MEMORY_USAGE_BYTES, METRIC_HTTP_REQUESTS_REJECTED_TOTAL};
use crate::request_limiter::RequestMemoryLimiter;
pub async fn memory_limit_middleware(
State(limiter): State<RequestMemoryLimiter>,
req: Request,
next: Next,
) -> Response {
let content_length = req
.headers()
.get(http::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<usize>().ok())
.unwrap_or(0);
let _guard = match limiter.try_acquire(content_length) {
Ok(guard) => guard.inspect(|g| {
METRIC_HTTP_MEMORY_USAGE_BYTES.set(g.current_usage() as i64);
}),
Err(e) => {
METRIC_HTTP_REQUESTS_REJECTED_TOTAL.inc();
return (
StatusCode::TOO_MANY_REQUESTS,
format!("Request body memory limit exceeded: {}", e),
)
.into_response();
}
};
next.run(req).await
}

View File

@@ -20,6 +20,9 @@
use datafusion_expr::LogicalPlan;
use datatypes::schema::Schema;
use sql::statements::statement::Statement;
// Re-export for use in add_service! macro
#[doc(hidden)]
pub use tower;
pub mod addrs;
pub mod configurator;
@@ -47,6 +50,7 @@ pub mod prometheus_handler;
pub mod proto;
pub mod query_handler;
pub mod repeated_field;
pub mod request_limiter;
mod row_writer;
pub mod server;
pub mod tls;

View File

@@ -298,6 +298,26 @@ lazy_static! {
"greptime_servers_bulk_insert_elapsed",
"servers handle bulk insert elapsed",
).unwrap();
pub static ref METRIC_HTTP_MEMORY_USAGE_BYTES: IntGauge = register_int_gauge!(
"greptime_servers_http_memory_usage_bytes",
"current http request memory usage in bytes"
).unwrap();
pub static ref METRIC_HTTP_REQUESTS_REJECTED_TOTAL: IntCounter = register_int_counter!(
"greptime_servers_http_requests_rejected_total",
"total number of http requests rejected due to memory limit"
).unwrap();
pub static ref METRIC_GRPC_MEMORY_USAGE_BYTES: IntGauge = register_int_gauge!(
"greptime_servers_grpc_memory_usage_bytes",
"current grpc request memory usage in bytes"
).unwrap();
pub static ref METRIC_GRPC_REQUESTS_REJECTED_TOTAL: IntCounter = register_int_counter!(
"greptime_servers_grpc_requests_rejected_total",
"total number of grpc requests rejected due to memory limit"
).unwrap();
}
// Based on https://github.com/hyperium/tonic/blob/master/examples/src/tower/server.rs

View File

@@ -0,0 +1,215 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Request memory limiter for controlling total memory usage of concurrent requests.
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::error::{Result, TooManyConcurrentRequestsSnafu};
/// Limiter for total memory usage of concurrent request bodies.
///
/// Tracks the total memory used by all concurrent request bodies
/// and rejects new requests when the limit is reached.
#[derive(Clone, Default)]
pub struct RequestMemoryLimiter {
inner: Option<Arc<LimiterInner>>,
}
struct LimiterInner {
current_usage: AtomicUsize,
max_memory: usize,
}
impl RequestMemoryLimiter {
/// Create a new memory limiter.
///
/// # Arguments
/// * `max_memory` - Maximum total memory for all concurrent request bodies in bytes (0 = unlimited)
pub fn new(max_memory: usize) -> Self {
if max_memory == 0 {
return Self { inner: None };
}
Self {
inner: Some(Arc::new(LimiterInner {
current_usage: AtomicUsize::new(0),
max_memory,
})),
}
}
/// Try to acquire memory for a request of given size.
///
/// Returns `Ok(RequestMemoryGuard)` if memory was acquired successfully.
/// Returns `Err` if the memory limit would be exceeded.
pub fn try_acquire(&self, request_size: usize) -> Result<Option<RequestMemoryGuard>> {
let Some(inner) = self.inner.as_ref() else {
return Ok(None);
};
let mut new_usage = 0;
let result =
inner
.current_usage
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
new_usage = current.saturating_add(request_size);
if new_usage <= inner.max_memory {
Some(new_usage)
} else {
None
}
});
match result {
Ok(_) => Ok(Some(RequestMemoryGuard {
size: request_size,
limiter: Arc::clone(inner),
usage_snapshot: new_usage,
})),
Err(_current) => TooManyConcurrentRequestsSnafu {
limit: inner.max_memory,
request_size,
}
.fail(),
}
}
/// Check if limiter is enabled
pub fn is_enabled(&self) -> bool {
self.inner.is_some()
}
/// Get current memory usage
pub fn current_usage(&self) -> usize {
self.inner
.as_ref()
.map(|inner| inner.current_usage.load(Ordering::Relaxed))
.unwrap_or(0)
}
/// Get max memory limit
pub fn max_memory(&self) -> usize {
self.inner
.as_ref()
.map(|inner| inner.max_memory)
.unwrap_or(0)
}
}
/// RAII guard that releases memory when dropped
pub struct RequestMemoryGuard {
size: usize,
limiter: Arc<LimiterInner>,
usage_snapshot: usize,
}
impl RequestMemoryGuard {
/// Returns the total memory usage snapshot at the time this guard was acquired.
pub fn current_usage(&self) -> usize {
self.usage_snapshot
}
}
impl Drop for RequestMemoryGuard {
fn drop(&mut self) {
self.limiter
.current_usage
.fetch_sub(self.size, Ordering::Release);
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_limiter_disabled() {
let limiter = RequestMemoryLimiter::new(0);
assert!(!limiter.is_enabled());
assert!(limiter.try_acquire(1000000).unwrap().is_none());
assert_eq!(limiter.current_usage(), 0);
}
#[test]
fn test_limiter_basic() {
let limiter = RequestMemoryLimiter::new(1000);
assert!(limiter.is_enabled());
assert_eq!(limiter.max_memory(), 1000);
assert_eq!(limiter.current_usage(), 0);
// Acquire 400 bytes
let _guard1 = limiter.try_acquire(400).unwrap();
assert_eq!(limiter.current_usage(), 400);
// Acquire another 500 bytes
let _guard2 = limiter.try_acquire(500).unwrap();
assert_eq!(limiter.current_usage(), 900);
// Try to acquire 200 bytes - should fail (900 + 200 > 1000)
let result = limiter.try_acquire(200);
assert!(result.is_err());
assert_eq!(limiter.current_usage(), 900);
// Drop first guard
drop(_guard1);
assert_eq!(limiter.current_usage(), 500);
// Now we can acquire 200 bytes
let _guard3 = limiter.try_acquire(200).unwrap();
assert_eq!(limiter.current_usage(), 700);
}
#[test]
fn test_limiter_exact_limit() {
let limiter = RequestMemoryLimiter::new(1000);
// Acquire exactly the limit
let _guard = limiter.try_acquire(1000).unwrap();
assert_eq!(limiter.current_usage(), 1000);
// Try to acquire 1 more byte - should fail
let result = limiter.try_acquire(1);
assert!(result.is_err());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_limiter_concurrent() {
let limiter = RequestMemoryLimiter::new(1000);
let mut handles = vec![];
// Spawn 10 tasks each trying to acquire 200 bytes
for _ in 0..10 {
let limiter_clone = limiter.clone();
let handle = tokio::spawn(async move { limiter_clone.try_acquire(200) });
handles.push(handle);
}
let mut success_count = 0;
let mut fail_count = 0;
for handle in handles {
match handle.await.unwrap() {
Ok(Some(_)) => success_count += 1,
Err(_) => fail_count += 1,
Ok(None) => unreachable!(),
}
}
// Only 5 tasks should succeed (5 * 200 = 1000)
assert_eq!(success_count, 5);
assert_eq!(fail_count, 5);
}
}

View File

@@ -14,10 +14,12 @@
use api::v1::alter_table_expr::Kind;
use api::v1::promql_request::Promql;
use api::v1::value::ValueData;
use api::v1::{
AddColumn, AddColumns, AlterTableExpr, Basic, Column, ColumnDataType, ColumnDef,
CreateTableExpr, InsertRequest, InsertRequests, PromInstantQuery, PromRangeQuery,
PromqlRequest, RequestHeader, SemanticType, column,
PromqlRequest, RequestHeader, Row, RowInsertRequest, RowInsertRequests, SemanticType, Value,
column,
};
use auth::user_provider_from_option;
use client::{Client, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, Database, OutputData};
@@ -89,6 +91,7 @@ macro_rules! grpc_tests {
test_prom_gateway_query,
test_grpc_timezone,
test_grpc_tls_config,
test_grpc_memory_limit,
);
)*
};
@@ -954,6 +957,7 @@ pub async fn test_grpc_tls_config(store_type: StorageType) {
let config = GrpcServerConfig {
max_recv_message_size: 1024,
max_send_message_size: 1024,
max_total_message_memory: 1024 * 1024 * 1024,
tls,
max_connection_age: None,
};
@@ -996,6 +1000,7 @@ pub async fn test_grpc_tls_config(store_type: StorageType) {
let config = GrpcServerConfig {
max_recv_message_size: 1024,
max_send_message_size: 1024,
max_total_message_memory: 1024 * 1024 * 1024,
tls,
max_connection_age: None,
};
@@ -1007,3 +1012,157 @@ pub async fn test_grpc_tls_config(store_type: StorageType) {
let _ = fe_grpc_server.shutdown().await;
}
pub async fn test_grpc_memory_limit(store_type: StorageType) {
let config = GrpcServerConfig {
max_recv_message_size: 1024 * 1024,
max_send_message_size: 1024 * 1024,
max_total_message_memory: 200,
tls: Default::default(),
max_connection_age: None,
};
let (_db, fe_grpc_server) =
setup_grpc_server_with(store_type, "test_grpc_memory_limit", None, Some(config)).await;
let addr = fe_grpc_server.bind_addr().unwrap().to_string();
let grpc_client = Client::with_urls([&addr]);
let db = Database::new(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, grpc_client);
let table_name = "demo";
let column_schemas = vec![
ColumnDef {
name: "host".to_string(),
data_type: ColumnDataType::String as i32,
is_nullable: false,
default_constraint: vec![],
semantic_type: SemanticType::Tag as i32,
comment: String::new(),
datatype_extension: None,
options: None,
},
ColumnDef {
name: "ts".to_string(),
data_type: ColumnDataType::TimestampMillisecond as i32,
is_nullable: false,
default_constraint: vec![],
semantic_type: SemanticType::Timestamp as i32,
comment: String::new(),
datatype_extension: None,
options: None,
},
ColumnDef {
name: "cpu".to_string(),
data_type: ColumnDataType::Float64 as i32,
is_nullable: true,
default_constraint: vec![],
semantic_type: SemanticType::Field as i32,
comment: String::new(),
datatype_extension: None,
options: None,
},
];
let expr = CreateTableExpr {
catalog_name: DEFAULT_CATALOG_NAME.to_string(),
schema_name: DEFAULT_SCHEMA_NAME.to_string(),
table_name: table_name.to_string(),
desc: String::new(),
column_defs: column_schemas.clone(),
time_index: "ts".to_string(),
primary_keys: vec!["host".to_string()],
create_if_not_exists: true,
table_options: Default::default(),
table_id: None,
engine: MITO_ENGINE.to_string(),
};
db.create(expr).await.unwrap();
// Test that small request succeeds
let small_row_insert = RowInsertRequest {
table_name: table_name.to_owned(),
rows: Some(api::v1::Rows {
schema: column_schemas
.iter()
.map(|c| api::v1::ColumnSchema {
column_name: c.name.clone(),
datatype: c.data_type,
semantic_type: c.semantic_type,
datatype_extension: None,
options: None,
})
.collect(),
rows: vec![Row {
values: vec![
Value {
value_data: Some(ValueData::StringValue("host1".to_string())),
},
Value {
value_data: Some(ValueData::TimestampMillisecondValue(1000)),
},
Value {
value_data: Some(ValueData::F64Value(1.2)),
},
],
}],
}),
};
let result = db
.row_inserts(RowInsertRequests {
inserts: vec![small_row_insert],
})
.await;
assert!(result.is_ok());
// Test that large request exceeds limit
let large_rows: Vec<Row> = (0..100)
.map(|i| Row {
values: vec![
Value {
value_data: Some(ValueData::StringValue(format!("host{}", i))),
},
Value {
value_data: Some(ValueData::TimestampMillisecondValue(1000 + i)),
},
Value {
value_data: Some(ValueData::F64Value(i as f64 * 1.2)),
},
],
})
.collect();
let large_row_insert = RowInsertRequest {
table_name: table_name.to_owned(),
rows: Some(api::v1::Rows {
schema: column_schemas
.iter()
.map(|c| api::v1::ColumnSchema {
column_name: c.name.clone(),
datatype: c.data_type,
semantic_type: c.semantic_type,
datatype_extension: None,
options: None,
})
.collect(),
rows: large_rows,
}),
};
let result = db
.row_inserts(RowInsertRequests {
inserts: vec![large_row_insert],
})
.await;
assert!(result.is_err());
let err = result.unwrap_err();
let err_msg = err.to_string();
assert!(
err_msg.contains("Too many concurrent"),
"Expected memory limit error, got: {}",
err_msg
);
let _ = fe_grpc_server.shutdown().await;
}

View File

@@ -1597,6 +1597,8 @@ fn drop_lines_with_inconsistent_results(input: String) -> String {
"max_background_compactions =",
"max_background_purges =",
"enable_read_cache =",
"max_total_body_memory =",
"max_total_message_memory =",
];
input