diff --git a/Cargo.lock b/Cargo.lock index d8f6241136..676eaf0822 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9384,6 +9384,7 @@ dependencies = [ "common-macro", "common-meta", "common-query", + "common-telemetry", "criterion 0.7.0", "datafusion-common", "datafusion-expr", @@ -12067,6 +12068,7 @@ dependencies = [ "operator", "otel-arrow-rust", "parking_lot 0.12.4", + "partition", "permutation", "pg_interval_2", "pgwire", diff --git a/src/frontend/src/instance/prom_store.rs b/src/frontend/src/instance/prom_store.rs index 9a323eb989..c8f76753af 100644 --- a/src/frontend/src/instance/prom_store.rs +++ b/src/frontend/src/instance/prom_store.rs @@ -161,12 +161,11 @@ impl Instance { #[async_trait] impl PromStoreProtocolHandler for Instance { - async fn write( + async fn pre_write( &self, - request: RowInsertRequests, + request: &RowInsertRequests, ctx: QueryContextRef, - with_metric_engine: bool, - ) -> ServerResult { + ) -> ServerResult<()> { self.plugins .get::() .as_ref() @@ -175,7 +174,17 @@ impl PromStoreProtocolHandler for Instance { let interceptor_ref = self .plugins .get::>(); - interceptor_ref.pre_write(&request, ctx.clone())?; + interceptor_ref.pre_write(request, ctx)?; + Ok(()) + } + + async fn write( + &self, + request: RowInsertRequests, + ctx: QueryContextRef, + with_metric_engine: bool, + ) -> ServerResult { + self.pre_write(&request, ctx.clone()).await?; let output = if with_metric_engine { let physical_table = ctx diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index 4b51efbd33..4d0db700d1 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -37,6 +37,7 @@ use servers::interceptor::LogIngestInterceptorRef; use servers::metrics_handler::MetricsHandler; use servers::mysql::server::{MysqlServer, MysqlSpawnConfig, MysqlSpawnRef}; use servers::otel_arrow::OtelArrowServiceHandler; +use servers::pending_rows_batcher::PendingRowsBatcher; use servers::postgres::PostgresServer; use 
servers::request_memory_limiter::ServerMemoryLimiter; use servers::server::{Server, ServerHandlers}; @@ -124,12 +125,27 @@ where } if opts.prom_store.enable { + let pending_rows_batcher = if opts.prom_store.with_metric_engine { + PendingRowsBatcher::try_new( + self.instance.partition_manager().clone(), + self.instance.node_manager().clone(), + self.instance.catalog_manager().clone(), + opts.prom_store.pending_rows_flush_interval, + opts.prom_store.max_batch_rows, + opts.prom_store.max_concurrent_flushes, + opts.prom_store.worker_channel_capacity, + opts.prom_store.max_inflight_requests, + ) + } else { + None + }; builder = builder .with_prom_handler( self.instance.clone(), Some(self.instance.clone()), opts.prom_store.with_metric_engine, opts.http.prom_validation_mode, + pending_rows_batcher, ) .with_prometheus_handler(self.instance.clone()); } diff --git a/src/frontend/src/service_config/prom_store.rs b/src/frontend/src/service_config/prom_store.rs index b3adf889d2..99f1eada6d 100644 --- a/src/frontend/src/service_config/prom_store.rs +++ b/src/frontend/src/service_config/prom_store.rs @@ -12,12 +12,40 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::time::Duration; + use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct PromStoreOptions { pub enable: bool, pub with_metric_engine: bool, + #[serde(default, with = "humantime_serde")] + pub pending_rows_flush_interval: Duration, + #[serde(default = "default_max_batch_rows")] + pub max_batch_rows: usize, + #[serde(default = "default_max_concurrent_flushes")] + pub max_concurrent_flushes: usize, + #[serde(default = "default_worker_channel_capacity")] + pub worker_channel_capacity: usize, + #[serde(default = "default_max_inflight_requests")] + pub max_inflight_requests: usize, +} + +fn default_max_batch_rows() -> usize { + 100_000 +} + +fn default_max_concurrent_flushes() -> usize { + 256 +} + +fn default_worker_channel_capacity() -> usize { + 65526 +} + +fn default_max_inflight_requests() -> usize { + 3000 } impl Default for PromStoreOptions { @@ -25,18 +53,43 @@ impl Default for PromStoreOptions { Self { enable: true, with_metric_engine: true, + pending_rows_flush_interval: Duration::ZERO, + max_batch_rows: default_max_batch_rows(), + max_concurrent_flushes: default_max_concurrent_flushes(), + worker_channel_capacity: default_worker_channel_capacity(), + max_inflight_requests: default_max_inflight_requests(), } } } #[cfg(test)] mod tests { + use std::time::Duration; + use super::PromStoreOptions; + use crate::service_config::prom_store::{ + default_max_batch_rows, default_max_concurrent_flushes, default_max_inflight_requests, + default_worker_channel_capacity, + }; #[test] fn test_prom_store_options() { let default = PromStoreOptions::default(); assert!(default.enable); - assert!(default.with_metric_engine) + assert!(default.with_metric_engine); + assert_eq!(default.pending_rows_flush_interval, Duration::ZERO); + assert_eq!(default.max_batch_rows, default_max_batch_rows()); + assert_eq!( + default.max_concurrent_flushes, + default_max_concurrent_flushes() + ); + assert_eq!( + 
default.worker_channel_capacity, + default_worker_channel_capacity() + ); + assert_eq!( + default.max_inflight_requests, + default_max_inflight_requests() + ); } } diff --git a/src/metric-engine/src/batch_modifier.rs b/src/metric-engine/src/batch_modifier.rs index 8a5774889b..d06eaa976b 100644 --- a/src/metric-engine/src/batch_modifier.rs +++ b/src/metric-engine/src/batch_modifier.rs @@ -18,12 +18,11 @@ use std::sync::Arc; use datatypes::arrow::array::{Array, BinaryBuilder, StringArray, UInt64Array}; use datatypes::arrow::datatypes::{DataType, Field, Schema as ArrowSchema}; use datatypes::arrow::record_batch::RecordBatch; -use datatypes::value::ValueRef; use fxhash::FxHasher; use mito_codec::row_converter::SparsePrimaryKeyCodec; use snafu::ResultExt; use store_api::storage::ColumnId; -use store_api::storage::consts::{PRIMARY_KEY_COLUMN_NAME, ReservedColumnId}; +use store_api::storage::consts::PRIMARY_KEY_COLUMN_NAME; use crate::error::{EncodePrimaryKeySnafu, Result, UnexpectedRequestSnafu}; @@ -112,7 +111,6 @@ fn build_tag_arrays<'a>( } /// Modifies a RecordBatch for sparse primary key encoding. 
-#[allow(dead_code)] pub(crate) fn modify_batch_sparse( batch: RecordBatch, table_id: u32, @@ -128,24 +126,17 @@ pub(crate) fn modify_batch_sparse( let mut buffer = Vec::new(); for row in 0..num_rows { buffer.clear(); - let internal = [ - (ReservedColumnId::table_id(), ValueRef::UInt32(table_id)), - ( - ReservedColumnId::tsid(), - ValueRef::UInt64(tsid_array.value(row)), - ), - ]; codec - .encode_to_vec(internal.into_iter(), &mut buffer) + .encode_internal(table_id, tsid_array.value(row), &mut buffer) .context(EncodePrimaryKeySnafu)?; let tags = sorted_tag_columns .iter() .zip(tag_arrays.iter()) .filter(|(_, arr)| !arr.is_null(row)) - .map(|(tc, arr)| (tc.column_id, ValueRef::String(arr.value(row)))); + .map(|(tc, arr)| (tc.column_id, arr.value(row).as_bytes())); codec - .encode_to_vec(tags, &mut buffer) + .encode_raw_tag_value(tags, &mut buffer) .context(EncodePrimaryKeySnafu)?; pk_builder.append_value(&buffer); diff --git a/src/partition/Cargo.toml b/src/partition/Cargo.toml index d498ed8c13..a8e3a8ae11 100644 --- a/src/partition/Cargo.toml +++ b/src/partition/Cargo.toml @@ -15,6 +15,7 @@ common-error.workspace = true common-macro.workspace = true common-meta.workspace = true common-query.workspace = true +common-telemetry.workspace = true datafusion-common.workspace = true datafusion-expr.workspace = true datafusion-physical-expr.workspace = true diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 8e84ef77d6..6531390ca3 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -88,6 +88,7 @@ opentelemetry-proto.workspace = true operator.workspace = true otel-arrow-rust.workspace = true parking_lot.workspace = true +partition.workspace = true pg_interval = { version = "0.5.2", package = "pg_interval_2" } pgwire = { version = "0.38.2", default-features = false, features = [ "server-api-ring", diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index 18ac964f05..5fae7a82db 100644 --- a/src/servers/src/error.rs +++ 
b/src/servers/src/error.rs @@ -56,6 +56,9 @@ pub enum Error { #[snafu(display("Internal error: {}", err_msg))] Internal { err_msg: String }, + #[snafu(display("Pending rows batcher channel closed"))] + BatcherChannelClosed, + #[snafu(display("Unsupported data type: {}, reason: {}", data_type, reason))] UnsupportedDataType { data_type: ConcreteDataType, @@ -684,6 +687,7 @@ impl ErrorExt for Error { use Error::*; match self { Internal { .. } + | BatcherChannelClosed | InternalIo { .. } | TokioIo { .. } | StartHttp { .. } diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 506a240cac..eb2086726a 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -16,7 +16,7 @@ use std::collections::HashMap; use std::convert::Infallible; use std::fmt::Display; use std::net::SocketAddr; -use std::sync::Mutex as StdMutex; +use std::sync::{Arc, Mutex as StdMutex}; use std::time::Duration; use async_trait::async_trait; @@ -75,6 +75,7 @@ use crate::http::result::null_result::NullResponse; use crate::interceptor::LogIngestInterceptorRef; use crate::metrics::http_metrics_layer; use crate::metrics_handler::MetricsHandler; +use crate::pending_rows_batcher::PendingRowsBatcher; use crate::prometheus_handler::PrometheusHandlerRef; use crate::query_handler::sql::ServerSqlQueryHandlerRef; use crate::query_handler::{ @@ -585,12 +586,14 @@ impl HttpServerBuilder { pipeline_handler: Option, prom_store_with_metric_engine: bool, prom_validation_mode: PromValidationMode, + pending_rows_batcher: Option>, ) -> Self { let state = PromStoreState { prom_store_handler: handler, pipeline_handler, prom_store_with_metric_engine, prom_validation_mode, + pending_rows_batcher, }; Self { diff --git a/src/servers/src/http/prom_store.rs b/src/servers/src/http/prom_store.rs index 58c6e0eddd..bfc072e84e 100644 --- a/src/servers/src/http/prom_store.rs +++ b/src/servers/src/http/prom_store.rs @@ -35,6 +35,7 @@ use snafu::prelude::*; use crate::error::{self, InternalSnafu, PipelineSnafu, 
Result}; use crate::http::extractor::PipelineInfo; use crate::http::header::{GREPTIME_DB_HEADER_METRICS, write_cost_header_map}; +use crate::pending_rows_batcher::PendingRowsBatcher; use crate::prom_remote_write::decode::PromSeriesProcessor; use crate::prom_remote_write::decode_remote_write_request; use crate::prom_remote_write::validation::PromValidationMode; @@ -52,6 +53,7 @@ pub struct PromStoreState { pub pipeline_handler: Option, pub prom_store_with_metric_engine: bool, pub prom_validation_mode: PromValidationMode, + pub pending_rows_batcher: Option>, } #[derive(Debug, Serialize, Deserialize)] @@ -92,6 +94,7 @@ pub async fn remote_write( pipeline_handler, prom_store_with_metric_engine, prom_validation_mode, + pending_rows_batcher, } = state; if let Some(_vm_handshake) = params.get_vm_proto_version { @@ -100,9 +103,11 @@ pub async fn remote_write( let db = params.db.clone().unwrap_or_default(); query_ctx.set_channel(Channel::Prometheus); - if let Some(physical_table) = params.physical_table { - query_ctx.set_extension(PHYSICAL_TABLE_PARAM, physical_table); - } + let physical_table = params + .physical_table + .clone() + .unwrap_or_else(|| GREPTIME_PHYSICAL_TABLE.to_string()); + query_ctx.set_extension(PHYSICAL_TABLE_PARAM, physical_table.clone()); let query_ctx = Arc::new(query_ctx); let _timer = crate::metrics::METRIC_HTTP_PROM_STORE_WRITE_ELAPSED .with_label_values(&[db.as_str()]) @@ -135,6 +140,19 @@ pub async fn remote_write( req.as_insert_requests() }; + if prom_store_with_metric_engine && let Some(batcher) = pending_rows_batcher { + for (temp_ctx, reqs) in req.as_req_iter(query_ctx) { + prom_store_handler + .pre_write(&reqs, temp_ctx.clone()) + .await?; + let rows = batcher.submit(reqs, temp_ctx).await?; + crate::metrics::PROM_STORE_REMOTE_WRITE_SAMPLES + .with_label_values(&[db.as_str()]) + .inc_by(rows); + } + return Ok((StatusCode::NO_CONTENT, write_cost_header_map(0)).into_response()); + } + let mut cost = 0; for (temp_ctx, reqs) in 
req.as_req_iter(query_ctx) { let cnt: u64 = reqs diff --git a/src/servers/src/lib.rs b/src/servers/src/lib.rs index 9ee7395691..c44c674b9e 100644 --- a/src/servers/src/lib.rs +++ b/src/servers/src/lib.rs @@ -41,6 +41,7 @@ pub mod mysql; pub mod opentsdb; pub mod otel_arrow; pub mod otlp; +pub mod pending_rows_batcher; mod pipeline; pub mod postgres; pub mod prom_remote_write; diff --git a/src/servers/src/metrics.rs b/src/servers/src/metrics.rs index 25a900ed3d..37f923b73d 100644 --- a/src/servers/src/metrics.rs +++ b/src/servers/src/metrics.rs @@ -121,13 +121,62 @@ lazy_static! { /// Duration to convert prometheus write request to gRPC request. pub static ref METRIC_HTTP_PROM_STORE_CONVERT_ELAPSED: Histogram = METRIC_HTTP_PROM_STORE_CODEC_ELAPSED .with_label_values(&["convert"]); - /// The samples count of Prometheus remote write. + /// The samples count of Prometheus remote write. pub static ref PROM_STORE_REMOTE_WRITE_SAMPLES: IntCounterVec = register_int_counter_vec!( "greptime_servers_prometheus_remote_write_samples", "frontend prometheus remote write samples", &[METRIC_DB_LABEL] ) .unwrap(); + pub static ref PENDING_BATCHES: IntGauge = register_int_gauge!( + "greptime_prom_store_pending_batches", + "Number of pending batches waiting to be flushed" + ) + .unwrap(); + pub static ref PENDING_ROWS: IntGauge = register_int_gauge!( + "greptime_prom_store_pending_rows", + "Number of pending rows waiting to be flushed" + ) + .unwrap(); + pub static ref PENDING_WORKERS: IntGauge = register_int_gauge!( + "greptime_prom_store_pending_workers", + "Number of active pending rows batch workers" + ) + .unwrap(); + pub static ref FLUSH_TOTAL: IntCounter = register_int_counter!( + "greptime_prom_store_flush_total", + "Total number of batch flushes" + ) + .unwrap(); + pub static ref FLUSH_ROWS: Histogram = register_histogram!( + "greptime_prom_store_flush_rows", + "Number of rows per flush", + vec![100.0, 1000.0, 10000.0, 50000.0, 100000.0, 500000.0] + ) + .unwrap(); + pub 
static ref FLUSH_ELAPSED: Histogram = register_histogram!( + "greptime_prom_store_flush_elapsed", + "Elapsed time of pending rows batch flush in seconds", + vec![0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0, 10.0, 60.0, 300.0] + ) + .unwrap(); + pub static ref FLUSH_DROPPED_ROWS: IntCounter = register_int_counter!( + "greptime_pending_rows_flush_dropped_rows", + "Total rows dropped due to pending rows flush failures" + ) + .unwrap(); + pub static ref FLUSH_FAILURES: IntCounter = register_int_counter!( + "greptime_pending_rows_flush_failures", + "Total pending rows flush failures" + ) + .unwrap(); + pub static ref PENDING_ROWS_BATCH_INGEST_STAGE_ELAPSED: HistogramVec = register_histogram_vec!( + "greptime_prom_store_pending_rows_batch_ingest_stage_elapsed", + "Elapsed time of pending rows batch ingestion stages in seconds", + &["stage"], + vec![0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0, 10.0, 60.0] + ) + .unwrap(); /// Http prometheus read duration per database. pub static ref METRIC_HTTP_PROM_STORE_READ_ELAPSED: HistogramVec = register_histogram_vec!( "greptime_servers_http_prometheus_read_elapsed", diff --git a/src/servers/src/pending_rows_batcher.rs b/src/servers/src/pending_rows_batcher.rs new file mode 100644 index 0000000000..f8486e3636 --- /dev/null +++ b/src/servers/src/pending_rows_batcher.rs @@ -0,0 +1,1253 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use api::helper::ColumnDataTypeWrapper; +use api::v1::region::{ + BulkInsertRequest, RegionRequest, RegionRequestHeader, bulk_insert_request, region_request, +}; +use api::v1::value::ValueData; +use api::v1::{ArrowIpc, RowInsertRequests, Rows}; +use arrow::array::{ + ArrayRef, Float64Builder, StringBuilder, TimestampMicrosecondBuilder, + TimestampMillisecondBuilder, TimestampNanosecondBuilder, TimestampSecondBuilder, + new_null_array, +}; +use arrow::compute::{cast, concat_batches, filter_record_batch}; +use arrow::datatypes::{Field, Schema as ArrowSchema}; +use arrow::record_batch::RecordBatch; +use arrow_schema::TimeUnit; +use bytes::Bytes; +use catalog::CatalogManagerRef; +use common_grpc::flight::{FlightEncoder, FlightMessage}; +use common_meta::node_manager::NodeManagerRef; +use common_query::prelude::GREPTIME_PHYSICAL_TABLE; +use common_telemetry::tracing_context::TracingContext; +use common_telemetry::{debug, error, info, warn}; +use dashmap::DashMap; +use dashmap::mapref::entry::Entry; +use datatypes::data_type::DataType; +use datatypes::prelude::ConcreteDataType; +use partition::manager::PartitionRuleManagerRef; +use session::context::QueryContextRef; +use snafu::{ResultExt, ensure}; +use store_api::storage::RegionId; +use tokio::sync::{OwnedSemaphorePermit, Semaphore, broadcast, mpsc, oneshot}; + +use crate::error; +use crate::error::{Error, Result}; +use crate::metrics::{ + FLUSH_DROPPED_ROWS, FLUSH_ELAPSED, FLUSH_FAILURES, FLUSH_ROWS, FLUSH_TOTAL, PENDING_BATCHES, + PENDING_ROWS, PENDING_ROWS_BATCH_INGEST_STAGE_ELAPSED, PENDING_WORKERS, +}; + +const PHYSICAL_TABLE_KEY: &str = "physical_table"; +/// Whether wait for ingestion result before reply to client. 
+const PENDING_ROWS_BATCH_SYNC_ENV: &str = "PENDING_ROWS_BATCH_SYNC"; +const WORKER_IDLE_TIMEOUT_MULTIPLIER: u32 = 3; + +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +struct BatchKey { + catalog: String, + schema: String, + physical_table: String, +} + +#[derive(Debug)] +struct TableBatch { + table_name: String, + batches: Vec, + row_count: usize, +} + +struct PendingBatch { + tables: HashMap, + created_at: Option, + total_row_count: usize, + ctx: Option, + waiters: Vec, +} + +struct FlushWaiter { + response_tx: oneshot::Sender>, + _permit: OwnedSemaphorePermit, +} + +struct FlushBatch { + table_batches: Vec, + total_row_count: usize, + ctx: QueryContextRef, + waiters: Vec, +} + +#[derive(Clone)] +struct PendingWorker { + tx: mpsc::Sender, +} + +enum WorkerCommand { + Submit { + table_batches: Vec<(String, RecordBatch)>, + total_rows: usize, + ctx: QueryContextRef, + response_tx: oneshot::Sender>, + _permit: OwnedSemaphorePermit, + }, +} + +// Batch key is derived from QueryContext; it assumes catalog/schema/physical_table fully +// define the write target and must remain consistent across the batch. +fn batch_key_from_ctx(ctx: &QueryContextRef) -> BatchKey { + let physical_table = ctx + .extension(PHYSICAL_TABLE_KEY) + .unwrap_or(GREPTIME_PHYSICAL_TABLE) + .to_string(); + BatchKey { + catalog: ctx.current_catalog().to_string(), + schema: ctx.current_schema(), + physical_table, + } +} + +/// Prometheus remote write pending rows batcher. 
+pub struct PendingRowsBatcher { + workers: Arc>, + flush_interval: Duration, + max_batch_rows: usize, + partition_manager: PartitionRuleManagerRef, + node_manager: NodeManagerRef, + catalog_manager: CatalogManagerRef, + flush_semaphore: Arc, + inflight_semaphore: Arc, + worker_channel_capacity: usize, + pending_rows_batch_sync: bool, + shutdown: broadcast::Sender<()>, +} + +impl PendingRowsBatcher { + #[allow(clippy::too_many_arguments)] + pub fn try_new( + partition_manager: PartitionRuleManagerRef, + node_manager: NodeManagerRef, + catalog_manager: CatalogManagerRef, + flush_interval: Duration, + max_batch_rows: usize, + max_concurrent_flushes: usize, + worker_channel_capacity: usize, + max_inflight_requests: usize, + ) -> Option> { + // Disable the batcher if flush is disabled or configuration is invalid. + // Zero values for these knobs either cause panics (e.g., zero-capacity channels) + // or deadlocks (e.g., semaphores with no permits). + if flush_interval.is_zero() + || max_batch_rows == 0 + || max_concurrent_flushes == 0 + || worker_channel_capacity == 0 + || max_inflight_requests == 0 + { + return None; + } + + let (shutdown, _) = broadcast::channel(1); + let pending_rows_batch_sync = std::env::var(PENDING_ROWS_BATCH_SYNC_ENV) + .ok() + .as_deref() + .and_then(|v| v.parse::().ok()) + .unwrap_or(true); + let workers = Arc::new(DashMap::new()); + PENDING_WORKERS.set(workers.len() as i64); + + Some(Arc::new(Self { + workers, + flush_interval, + max_batch_rows, + partition_manager, + node_manager, + catalog_manager, + flush_semaphore: Arc::new(Semaphore::new(max_concurrent_flushes)), + inflight_semaphore: Arc::new(Semaphore::new(max_inflight_requests)), + worker_channel_capacity, + pending_rows_batch_sync, + shutdown, + })) + } + + pub async fn submit(&self, requests: RowInsertRequests, ctx: QueryContextRef) -> Result { + let (table_batches, total_rows) = { + let _timer = PENDING_ROWS_BATCH_INGEST_STAGE_ELAPSED + 
.with_label_values(&["submit_build_table_batches"]) + .start_timer(); + build_table_batches(requests)? + }; + if total_rows == 0 { + return Ok(0); + } + let table_batches = { + let _timer = PENDING_ROWS_BATCH_INGEST_STAGE_ELAPSED + .with_label_values(&["submit_align_region_schema"]) + .start_timer(); + self.align_table_batches_to_region_schema(table_batches, &ctx) + .await? + }; + + let permit = { + let _timer = PENDING_ROWS_BATCH_INGEST_STAGE_ELAPSED + .with_label_values(&["submit_acquire_inflight_permit"]) + .start_timer(); + self.inflight_semaphore + .clone() + .acquire_owned() + .await + .map_err(|_| Error::BatcherChannelClosed)? + }; + + let (response_tx, response_rx) = oneshot::channel(); + + let batch_key = batch_key_from_ctx(&ctx); + let mut cmd = Some(WorkerCommand::Submit { + table_batches, + total_rows, + ctx, + response_tx, + _permit: permit, + }); + + { + let _timer = PENDING_ROWS_BATCH_INGEST_STAGE_ELAPSED + .with_label_values(&["submit_send_to_worker"]) + .start_timer(); + + for _ in 0..2 { + let worker = self.get_or_spawn_worker(batch_key.clone()); + let Some(worker_cmd) = cmd.take() else { + break; + }; + + match worker.tx.send(worker_cmd).await { + Ok(()) => break, + Err(err) => { + cmd = Some(err.0); + remove_worker_if_same_channel( + self.workers.as_ref(), + &batch_key, + &worker.tx, + ); + } + } + } + + if cmd.is_some() { + return Err(Error::BatcherChannelClosed); + } + } + + if self.pending_rows_batch_sync { + let result = { + let _timer = PENDING_ROWS_BATCH_INGEST_STAGE_ELAPSED + .with_label_values(&["submit_wait_flush_result"]) + .start_timer(); + response_rx.await.map_err(|_| Error::BatcherChannelClosed)? 
+ }; + result.map(|()| total_rows as u64) + } else { + Ok(total_rows as u64) + } + } + + async fn align_table_batches_to_region_schema( + &self, + table_batches: Vec<(String, RecordBatch)>, + ctx: &QueryContextRef, + ) -> Result> { + let catalog = ctx.current_catalog().to_string(); + let schema = ctx.current_schema(); + let mut region_schemas: HashMap> = HashMap::new(); + let mut aligned_batches = Vec::with_capacity(table_batches.len()); + + for (table_name, record_batch) in table_batches { + let region_schema = if let Some(region_schema) = region_schemas.get(&table_name) { + region_schema.clone() + } else { + let table = self + .catalog_manager + .table(&catalog, &schema, &table_name, Some(ctx.as_ref())) + .await + .map_err(|err| Error::Internal { + err_msg: format!( + "Failed to resolve table {} for pending batch alignment: {}", + table_name, err + ), + })? + .ok_or_else(|| Error::Internal { + err_msg: format!( + "Table not found during pending batch alignment: {}", + table_name + ), + })?; + let region_schema = table.table_info().meta.schema.arrow_schema().clone(); + region_schemas.insert(table_name.clone(), region_schema.clone()); + region_schema + }; + + let record_batch = align_record_batch_to_schema(record_batch, region_schema.as_ref())?; + aligned_batches.push((table_name, record_batch)); + } + + Ok(aligned_batches) + } + + fn get_or_spawn_worker(&self, key: BatchKey) -> PendingWorker { + if let Some(worker) = self.workers.get(&key) + && !worker.tx.is_closed() + { + return worker.clone(); + } + + let entry = self.workers.entry(key.clone()); + match entry { + Entry::Occupied(mut worker) => { + if worker.get().tx.is_closed() { + let new_worker = self.spawn_worker(key); + worker.insert(new_worker.clone()); + PENDING_WORKERS.set(self.workers.len() as i64); + new_worker + } else { + worker.get().clone() + } + } + Entry::Vacant(vacant) => { + let worker = self.spawn_worker(key); + + vacant.insert(worker.clone()); + PENDING_WORKERS.set(self.workers.len() as i64); 
+ worker + } + } + } + + fn spawn_worker(&self, key: BatchKey) -> PendingWorker { + let (tx, rx) = mpsc::channel(self.worker_channel_capacity); + let worker = PendingWorker { tx: tx.clone() }; + let worker_idle_timeout = self + .flush_interval + .checked_mul(WORKER_IDLE_TIMEOUT_MULTIPLIER) + .unwrap_or(self.flush_interval); + + start_worker( + key, + worker.tx.clone(), + self.workers.clone(), + rx, + self.shutdown.clone(), + self.partition_manager.clone(), + self.node_manager.clone(), + self.catalog_manager.clone(), + self.flush_interval, + worker_idle_timeout, + self.max_batch_rows, + self.flush_semaphore.clone(), + ); + + worker + } +} + +impl Drop for PendingRowsBatcher { + fn drop(&mut self) { + let _ = self.shutdown.send(()); + } +} + +impl PendingBatch { + fn new() -> Self { + Self { + tables: HashMap::new(), + created_at: None, + total_row_count: 0, + ctx: None, + waiters: Vec::new(), + } + } +} + +#[allow(clippy::too_many_arguments)] +fn start_worker( + key: BatchKey, + worker_tx: mpsc::Sender, + workers: Arc>, + mut rx: mpsc::Receiver, + shutdown: broadcast::Sender<()>, + partition_manager: PartitionRuleManagerRef, + node_manager: NodeManagerRef, + catalog_manager: CatalogManagerRef, + flush_interval: Duration, + worker_idle_timeout: Duration, + max_batch_rows: usize, + flush_semaphore: Arc, +) { + tokio::spawn(async move { + let mut batch = PendingBatch::new(); + let mut interval = tokio::time::interval(flush_interval); + let mut shutdown_rx = shutdown.subscribe(); + let idle_deadline = tokio::time::Instant::now() + worker_idle_timeout; + let idle_timer = tokio::time::sleep_until(idle_deadline); + tokio::pin!(idle_timer); + + loop { + tokio::select! 
{ + cmd = rx.recv() => { + match cmd { + Some(WorkerCommand::Submit { table_batches, total_rows, ctx, response_tx, _permit }) => { + idle_timer.as_mut().reset(tokio::time::Instant::now() + worker_idle_timeout); + + if batch.total_row_count == 0 { + batch.created_at = Some(Instant::now()); + batch.ctx = Some(ctx); + PENDING_BATCHES.inc(); + } + + batch.waiters.push(FlushWaiter { response_tx, _permit }); + + for (table_name, record_batch) in table_batches { + let entry = batch.tables.entry(table_name.clone()).or_insert_with(|| TableBatch { + table_name, + batches: Vec::new(), + row_count: 0, + }); + entry.row_count += record_batch.num_rows(); + entry.batches.push(record_batch); + } + + batch.total_row_count += total_rows; + PENDING_ROWS.add(total_rows as i64); + + if batch.total_row_count >= max_batch_rows + && let Some(flush) = drain_batch(&mut batch) { + spawn_flush( + flush, + partition_manager.clone(), + node_manager.clone(), + catalog_manager.clone(), + flush_semaphore.clone(), + ).await; + } + } + None => { + if let Some(flush) = drain_batch(&mut batch) { + flush_batch( + flush, + partition_manager.clone(), + node_manager.clone(), + catalog_manager.clone(), + ).await; + } + break; + } + } + } + _ = &mut idle_timer => { + if !should_close_worker_on_idle_timeout(batch.total_row_count, rx.len()) { + idle_timer + .as_mut() + .reset(tokio::time::Instant::now() + worker_idle_timeout); + continue; + } + + debug!( + "Closing idle pending rows worker due to timeout: catalog={}, schema={}, physical_table={}", + key.catalog, + key.schema, + key.physical_table + ); + break; + } + _ = interval.tick() => { + if let Some(created_at) = batch.created_at + && batch.total_row_count > 0 + && created_at.elapsed() >= flush_interval + && let Some(flush) = drain_batch(&mut batch) { + spawn_flush( + flush, + partition_manager.clone(), + node_manager.clone(), + catalog_manager.clone(), + flush_semaphore.clone(), + ).await; + } + } + _ = shutdown_rx.recv() => { + if let Some(flush) = 
drain_batch(&mut batch) { + flush_batch( + flush, + partition_manager.clone(), + node_manager.clone(), + catalog_manager.clone(), + ).await; + } + break; + } + } + } + + remove_worker_if_same_channel(workers.as_ref(), &key, &worker_tx); + }); +} + +fn remove_worker_if_same_channel( + workers: &DashMap, + key: &BatchKey, + worker_tx: &mpsc::Sender, +) -> bool { + if let Some(worker) = workers.get(key) + && worker.tx.same_channel(worker_tx) + { + drop(worker); + workers.remove(key); + PENDING_WORKERS.set(workers.len() as i64); + return true; + } + + false +} + +fn should_close_worker_on_idle_timeout(total_row_count: usize, queued_requests: usize) -> bool { + total_row_count == 0 && queued_requests == 0 +} + +fn drain_batch(batch: &mut PendingBatch) -> Option { + if batch.total_row_count == 0 { + return None; + } + + let ctx = match batch.ctx.take() { + Some(ctx) => ctx, + None => { + flush_with_error(batch, "Pending batch missing context"); + return None; + } + }; + + let total_row_count = batch.total_row_count; + let table_batches = std::mem::take(&mut batch.tables).into_values().collect(); + let waiters = std::mem::take(&mut batch.waiters); + batch.total_row_count = 0; + batch.created_at = None; + + PENDING_ROWS.sub(total_row_count as i64); + PENDING_BATCHES.dec(); + + Some(FlushBatch { + table_batches, + total_row_count, + ctx, + waiters, + }) +} + +async fn spawn_flush( + flush: FlushBatch, + partition_manager: PartitionRuleManagerRef, + node_manager: NodeManagerRef, + catalog_manager: CatalogManagerRef, + semaphore: Arc, +) { + match semaphore.acquire_owned().await { + Ok(permit) => { + tokio::spawn(async move { + let _permit = permit; + flush_batch(flush, partition_manager, node_manager, catalog_manager).await; + }); + } + Err(err) => { + warn!(err; "Flush semaphore closed, flushing inline"); + flush_batch(flush, partition_manager, node_manager, catalog_manager).await; + } + } +} + +async fn flush_batch( + flush: FlushBatch, + partition_manager: 
PartitionRuleManagerRef, + node_manager: NodeManagerRef, + catalog_manager: CatalogManagerRef, +) { + let FlushBatch { + table_batches, + total_row_count, + ctx, + waiters, + } = flush; + let start = Instant::now(); + let mut first_error: Option = None; + + let catalog = ctx.current_catalog().to_string(); + let schema = ctx.current_schema(); + + macro_rules! record_failure { + ($row_count:expr, $msg:expr) => {{ + let msg = $msg; + if first_error.is_none() { + first_error = Some(msg.clone()); + } + mark_flush_failure($row_count, &msg); + }}; + } + + for table_batch in table_batches { + let Some(first_batch) = table_batch.batches.first() else { + continue; + }; + + let schema_ref = first_batch.schema(); + let record_batch = { + let _timer = PENDING_ROWS_BATCH_INGEST_STAGE_ELAPSED + .with_label_values(&["flush_concat_table_batches"]) + .start_timer(); + match concat_batches(&schema_ref, &table_batch.batches) { + Ok(batch) => batch, + Err(err) => { + record_failure!( + table_batch.row_count, + format!( + "Failed to concat table batch {}: {:?}", + table_batch.table_name, err + ) + ); + continue; + } + } + }; + + let table = { + let _timer = PENDING_ROWS_BATCH_INGEST_STAGE_ELAPSED + .with_label_values(&["flush_resolve_table"]) + .start_timer(); + match catalog_manager + .table( + &catalog, + &schema, + &table_batch.table_name, + Some(ctx.as_ref()), + ) + .await + { + Ok(Some(table)) => table, + Ok(None) => { + record_failure!( + table_batch.row_count, + format!( + "Table not found during pending flush: {}", + table_batch.table_name + ) + ); + continue; + } + Err(err) => { + record_failure!( + table_batch.row_count, + format!( + "Failed to resolve table {} for pending flush: {:?}", + table_batch.table_name, err + ) + ); + continue; + } + } + }; + let table_info = table.table_info(); + + let partition_rule = { + let _timer = PENDING_ROWS_BATCH_INGEST_STAGE_ELAPSED + .with_label_values(&["flush_fetch_partition_rule"]) + .start_timer(); + match partition_manager + 
.find_table_partition_rule(&table_info) + .await + { + Ok(rule) => rule, + Err(err) => { + record_failure!( + table_batch.row_count, + format!( + "Failed to fetch partition rule for table {}: {:?}", + table_batch.table_name, err + ) + ); + continue; + } + } + }; + + let region_masks = { + let _timer = PENDING_ROWS_BATCH_INGEST_STAGE_ELAPSED + .with_label_values(&["flush_split_record_batch"]) + .start_timer(); + match partition_rule.0.split_record_batch(&record_batch) { + Ok(masks) => masks, + Err(err) => { + record_failure!( + table_batch.row_count, + format!( + "Failed to split record batch for table {}: {:?}", + table_batch.table_name, err + ) + ); + continue; + } + } + }; + + for (region_number, mask) in region_masks { + if mask.select_none() { + continue; + } + + let region_batch = if mask.select_all() { + record_batch.clone() + } else { + let _timer = PENDING_ROWS_BATCH_INGEST_STAGE_ELAPSED + .with_label_values(&["flush_filter_record_batch"]) + .start_timer(); + match filter_record_batch(&record_batch, mask.array()) { + Ok(batch) => batch, + Err(err) => { + record_failure!( + table_batch.row_count, + format!( + "Failed to filter record batch for table {}: {:?}", + table_batch.table_name, err + ) + ); + continue; + } + } + }; + + let row_count = region_batch.num_rows(); + if row_count == 0 { + continue; + } + + let region_id = RegionId::new(table_info.table_id(), region_number); + let datanode = { + let _timer = PENDING_ROWS_BATCH_INGEST_STAGE_ELAPSED + .with_label_values(&["flush_resolve_region_leader"]) + .start_timer(); + match partition_manager.find_region_leader(region_id).await { + Ok(peer) => peer, + Err(err) => { + record_failure!( + row_count, + format!("Failed to resolve region leader {}: {:?}", region_id, err) + ); + continue; + } + } + }; + + let (schema_bytes, data_header, payload) = { + let _timer = PENDING_ROWS_BATCH_INGEST_STAGE_ELAPSED + .with_label_values(&["flush_encode_ipc"]) + .start_timer(); + match record_batch_to_ipc(region_batch) { + 
Ok(encoded) => encoded, + Err(err) => { + record_failure!( + row_count, + format!( + "Failed to encode Arrow IPC for region {}: {:?}", + region_id, err + ) + ); + continue; + } + } + }; + + let request = RegionRequest { + header: Some(RegionRequestHeader { + tracing_context: TracingContext::from_current_span().to_w3c(), + ..Default::default() + }), + body: Some(region_request::Body::BulkInsert(BulkInsertRequest { + region_id: region_id.as_u64(), + partition_expr_version: None, + body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc { + schema: schema_bytes, + data_header, + payload, + })), + })), + }; + + let datanode = node_manager.datanode(&datanode).await; + let _timer = PENDING_ROWS_BATCH_INGEST_STAGE_ELAPSED + .with_label_values(&["flush_write_region"]) + .start_timer(); + match datanode.handle(request).await { + Ok(_) => { + FLUSH_TOTAL.inc(); + FLUSH_ROWS.observe(row_count as f64); + } + Err(err) => { + record_failure!( + row_count, + format!( + "Bulk insert flush failed for region {}: {:?}", + region_id, err + ) + ); + } + } + } + } + + let elapsed = start.elapsed().as_secs_f64(); + FLUSH_ELAPSED.observe(elapsed); + info!( + "Pending rows batch flushed, total rows: {}, elapsed time: {}s", + total_row_count, elapsed + ); + + notify_waiters(waiters, &first_error); +} + +fn notify_waiters(waiters: Vec, first_error: &Option) { + for waiter in waiters { + let result = match first_error { + Some(err_msg) => Err(Error::Internal { + err_msg: err_msg.clone(), + }), + None => Ok(()), + }; + let _ = waiter.response_tx.send(result); + // waiter._permit is dropped here, releasing the inflight semaphore slot + } +} + +fn mark_flush_failure(row_count: usize, message: &str) { + error!("Pending rows batch flush failed, message: {}", message); + FLUSH_FAILURES.inc(); + FLUSH_DROPPED_ROWS.inc_by(row_count as u64); +} + +fn flush_with_error(batch: &mut PendingBatch, message: &str) { + if batch.total_row_count == 0 { + return; + } + + let row_count = batch.total_row_count; + 
let waiters = std::mem::take(&mut batch.waiters); + batch.tables.clear(); + batch.total_row_count = 0; + batch.created_at = None; + batch.ctx = None; + + PENDING_ROWS.sub(row_count as i64); + PENDING_BATCHES.dec(); + + let err_msg = Some(message.to_string()); + notify_waiters(waiters, &err_msg); + mark_flush_failure(row_count, message); +} + +fn build_table_batches(requests: RowInsertRequests) -> Result<(Vec<(String, RecordBatch)>, usize)> { + let mut table_batches = Vec::with_capacity(requests.inserts.len()); + let mut total_rows = 0; + + for request in requests.inserts { + let Some(rows) = request.rows else { + continue; + }; + if rows.rows.is_empty() { + continue; + } + + let record_batch = rows_to_record_batch(&rows)?; + total_rows += record_batch.num_rows(); + table_batches.push((request.table_name, record_batch)); + } + + Ok((table_batches, total_rows)) +} + +fn align_record_batch_to_schema( + record_batch: RecordBatch, + target_schema: &ArrowSchema, +) -> Result { + let source_schema = record_batch.schema(); + if source_schema.as_ref() == target_schema { + return Ok(record_batch); + } + + for source_field in source_schema.fields() { + if target_schema + .column_with_name(source_field.name()) + .is_none() + { + return Err(Error::Internal { + err_msg: format!( + "Failed to align record batch schema, column '{}' not found in target schema", + source_field.name() + ), + }); + } + } + + let row_count = record_batch.num_rows(); + let mut columns = Vec::with_capacity(target_schema.fields().len()); + for target_field in target_schema.fields() { + let column = if let Some((index, source_field)) = + source_schema.column_with_name(target_field.name()) + { + let source_column = record_batch.column(index).clone(); + if source_field.data_type() == target_field.data_type() { + source_column + } else { + cast(source_column.as_ref(), target_field.data_type()).map_err(|err| { + Error::Internal { + err_msg: format!( + "Failed to cast column '{}' to target type {:?}: {}", + 
target_field.name(), + target_field.data_type(), + err + ), + } + })? + } + } else { + new_null_array(target_field.data_type(), row_count) + }; + columns.push(column); + } + + RecordBatch::try_new(Arc::new(target_schema.clone()), columns).map_err(|err| Error::Internal { + err_msg: format!("Failed to build aligned record batch: {}", err), + }) +} + +fn rows_to_record_batch(rows: &Rows) -> Result { + let row_count = rows.rows.len(); + let column_count = rows.schema.len(); + + for (idx, row) in rows.rows.iter().enumerate() { + ensure!( + row.values.len() == column_count, + error::InternalSnafu { + err_msg: format!( + "Column count mismatch in row {}, expected {}, got {}", + idx, + column_count, + row.values.len() + ) + } + ); + } + + let mut fields = Vec::with_capacity(column_count); + let mut columns = Vec::with_capacity(column_count); + + for (idx, column_schema) in rows.schema.iter().enumerate() { + let datatype_wrapper = ColumnDataTypeWrapper::try_new( + column_schema.datatype, + column_schema.datatype_extension.clone(), + )?; + let data_type = ConcreteDataType::from(datatype_wrapper); + fields.push(Field::new( + column_schema.column_name.clone(), + data_type.as_arrow_type(), + true, + )); + columns.push(build_arrow_array( + rows, + idx, + &column_schema.column_name, + data_type.as_arrow_type(), + row_count, + )?); + } + + RecordBatch::try_new(Arc::new(ArrowSchema::new(fields)), columns).context(error::ArrowSnafu) +} + +fn build_arrow_array( + rows: &Rows, + col_idx: usize, + column_name: &String, + column_data_type: arrow::datatypes::DataType, + row_count: usize, +) -> Result { + macro_rules! build_array { + ($builder:expr, $( $pattern:pat => $value:expr ),+ $(,)?) 
=> {{ + let mut builder = $builder; + for row in &rows.rows { + match row.values[col_idx].value_data.as_ref() { + $(Some($pattern) => builder.append_value($value),)+ + Some(v) => { + return error::InvalidPromRemoteRequestSnafu { + msg: format!("Unexpected value: {:?}", v), + } + .fail(); + } + None => builder.append_null(), + } + } + Arc::new(builder.finish()) as ArrayRef + }}; + } + + let array: ArrayRef = match column_data_type { + arrow::datatypes::DataType::Float64 => { + build_array!(Float64Builder::with_capacity(row_count), ValueData::F64Value(v) => *v) + } + arrow::datatypes::DataType::Utf8 => build_array!( + StringBuilder::with_capacity(row_count, 0), + ValueData::StringValue(v) => v + ), + arrow::datatypes::DataType::Timestamp(u, _) => match u { + TimeUnit::Second => build_array!( + TimestampSecondBuilder::with_capacity(row_count), + ValueData::TimestampSecondValue(v) => *v + ), + TimeUnit::Millisecond => build_array!( + TimestampMillisecondBuilder::with_capacity(row_count), + ValueData::TimestampMillisecondValue(v) => *v + ), + TimeUnit::Microsecond => build_array!( + TimestampMicrosecondBuilder::with_capacity(row_count), + ValueData::DatetimeValue(v) => *v, + ValueData::TimestampMicrosecondValue(v) => *v + ), + TimeUnit::Nanosecond => build_array!( + TimestampNanosecondBuilder::with_capacity(row_count), + ValueData::TimestampNanosecondValue(v) => *v + ), + }, + ty => { + return error::InvalidPromRemoteRequestSnafu { + msg: format!( + "Unexpected column type {:?}, column name: {}", + ty, column_name + ), + } + .fail(); + } + }; + + Ok(array) +} + +fn record_batch_to_ipc(record_batch: RecordBatch) -> Result<(Bytes, Bytes, Bytes)> { + let mut encoder = FlightEncoder::default(); + let schema = encoder.encode_schema(record_batch.schema().as_ref()); + let mut iter = encoder + .encode(FlightMessage::RecordBatch(record_batch)) + .into_iter(); + let Some(flight_data) = iter.next() else { + return Err(Error::Internal { + err_msg: "Failed to encode empty flight 
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use api::v1::value::ValueData;
    use api::v1::{ColumnDataType, ColumnSchema, Row, Rows, SemanticType, Value};
    use arrow::array::{Array, Float64Array, Int32Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema as ArrowSchema};
    use arrow::record_batch::RecordBatch;
    use dashmap::DashMap;
    use tokio::sync::mpsc;

    use super::{
        BatchKey, PendingWorker, WorkerCommand, align_record_batch_to_schema,
        remove_worker_if_same_channel, rows_to_record_batch, should_close_worker_on_idle_timeout,
    };

    /// Two rows over (timestamp, float, string) columns, one of them with a
    /// null field value, convert into a 2x3 record batch.
    #[test]
    fn test_rows_to_record_batch() {
        let rows = Rows {
            schema: vec![
                ColumnSchema {
                    column_name: "ts".to_string(),
                    datatype: ColumnDataType::TimestampMillisecond as i32,
                    semantic_type: SemanticType::Timestamp as i32,
                    ..Default::default()
                },
                ColumnSchema {
                    column_name: "value".to_string(),
                    datatype: ColumnDataType::Float64 as i32,
                    semantic_type: SemanticType::Field as i32,
                    ..Default::default()
                },
                ColumnSchema {
                    column_name: "host".to_string(),
                    datatype: ColumnDataType::String as i32,
                    semantic_type: SemanticType::Tag as i32,
                    ..Default::default()
                },
            ],
            rows: vec![
                Row {
                    values: vec![
                        Value {
                            value_data: Some(ValueData::TimestampMillisecondValue(1000)),
                        },
                        Value {
                            value_data: Some(ValueData::F64Value(42.0)),
                        },
                        Value {
                            value_data: Some(ValueData::StringValue("h1".to_string())),
                        },
                    ],
                },
                Row {
                    values: vec![
                        Value {
                            value_data: Some(ValueData::TimestampMillisecondValue(2000)),
                        },
                        Value { value_data: None },
                        Value {
                            value_data: Some(ValueData::StringValue("h2".to_string())),
                        },
                    ],
                },
            ],
        };

        let rb = rows_to_record_batch(&rows).unwrap();
        assert_eq!(2, rb.num_rows());
        assert_eq!(3, rb.num_columns());
    }

    /// Columns are reordered to the target schema and a column missing from
    /// the source is filled with nulls.
    #[test]
    fn test_align_record_batch_to_schema_reorder_and_fill_missing() {
        let source_schema = Arc::new(ArrowSchema::new(vec![
            Field::new("host", DataType::Utf8, true),
            Field::new("value", DataType::Float64, true),
        ]));
        let source = RecordBatch::try_new(
            source_schema,
            vec![
                Arc::new(StringArray::from(vec!["h1"])),
                Arc::new(Float64Array::from(vec![42.0])),
            ],
        )
        .unwrap();

        let target = ArrowSchema::new(vec![
            Field::new("ts", DataType::Int64, true),
            Field::new("host", DataType::Utf8, true),
            Field::new("value", DataType::Float64, true),
        ]);

        let aligned = align_record_batch_to_schema(source, &target).unwrap();
        assert_eq!(aligned.schema().as_ref(), &target);
        assert_eq!(1, aligned.num_rows());
        assert_eq!(3, aligned.num_columns());
        // The target declares `ts` as Int64, so the filled column is an Int64Array.
        let ts = aligned
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert!(ts.is_null(0));
    }

    /// A source column whose type differs from the target is cast, keeping nulls.
    #[test]
    fn test_align_record_batch_to_schema_cast_column_type() {
        let source_schema = Arc::new(ArrowSchema::new(vec![Field::new(
            "value",
            DataType::Int32,
            true,
        )]));
        let source = RecordBatch::try_new(
            source_schema,
            vec![Arc::new(Int32Array::from(vec![Some(7), None]))],
        )
        .unwrap();

        let target = ArrowSchema::new(vec![Field::new("value", DataType::Int64, true)]);
        let aligned = align_record_batch_to_schema(source, &target).unwrap();
        let value = aligned
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(Some(7), value.iter().next().flatten());
        assert!(value.is_null(1));
    }

    /// The entry is removed when the registered sender shares the given channel.
    #[test]
    fn test_remove_worker_if_same_channel_removes_matching_entry() {
        let workers = DashMap::new();
        let key = BatchKey {
            catalog: "greptime".to_string(),
            schema: "public".to_string(),
            physical_table: "phy".to_string(),
        };

        let (tx, _rx) = mpsc::channel::<WorkerCommand>(1);
        workers.insert(key.clone(), PendingWorker { tx: tx.clone() });

        assert!(remove_worker_if_same_channel(&workers, &key, &tx));
        assert!(!workers.contains_key(&key));
    }

    /// A stale sender must not remove an entry registered with a different channel.
    #[test]
    fn test_remove_worker_if_same_channel_keeps_newer_entry() {
        let workers = DashMap::new();
        let key = BatchKey {
            catalog: "greptime".to_string(),
            schema: "public".to_string(),
            physical_table: "phy".to_string(),
        };

        let (stale_tx, _stale_rx) = mpsc::channel::<WorkerCommand>(1);
        let (fresh_tx, _fresh_rx) = mpsc::channel::<WorkerCommand>(1);
        workers.insert(
            key.clone(),
            PendingWorker {
                tx: fresh_tx.clone(),
            },
        );

        assert!(!remove_worker_if_same_channel(&workers, &key, &stale_tx));
        assert!(workers.contains_key(&key));
        assert!(workers.get(&key).unwrap().tx.same_channel(&fresh_tx));
    }

    /// A worker closes on idle timeout only with no pending rows and no queued requests.
    #[test]
    fn test_worker_idle_timeout_close_decision() {
        assert!(should_close_worker_on_idle_timeout(0, 0));
        assert!(!should_close_worker_on_idle_timeout(1, 0));
        assert!(!should_close_worker_on_idle_timeout(0, 1));
    }
}
+ async fn pre_write(&self, _request: &RowInsertRequests, _ctx: QueryContextRef) -> Result<()> { + Ok(()) + } + /// Handling prometheus remote write requests async fn write( &self, diff --git a/src/servers/tests/http/prom_store_test.rs b/src/servers/tests/http/prom_store_test.rs index b1e974d3d3..c5d5207486 100644 --- a/src/servers/tests/http/prom_store_test.rs +++ b/src/servers/tests/http/prom_store_test.rs @@ -120,7 +120,7 @@ fn make_test_app(tx: mpsc::Sender<(String, Vec)>) -> Router { let instance = Arc::new(DummyInstance { tx }); let server = HttpServerBuilder::new(http_opts) .with_sql_handler(instance.clone()) - .with_prom_handler(instance, None, true, PromValidationMode::Unchecked) + .with_prom_handler(instance, None, true, PromValidationMode::Unchecked, None) .build(); server.build(server.make_app()).unwrap() } diff --git a/tests-integration/src/test_util.rs b/tests-integration/src/test_util.rs index 2bf6e812c7..8e7c3ce8a6 100644 --- a/tests-integration/src/test_util.rs +++ b/tests-integration/src/test_util.rs @@ -623,6 +623,7 @@ pub async fn setup_test_prom_app_with_frontend( Some(frontend_ref.clone()), true, PromValidationMode::Strict, + None, ) .with_prometheus_handler(frontend_ref) .with_greptime_config_options(instance.opts.datanode_options().to_toml().unwrap()) diff --git a/tests-integration/tests/http.rs b/tests-integration/tests/http.rs index 05a34eb5b7..933fcadf6b 100644 --- a/tests-integration/tests/http.rs +++ b/tests-integration/tests/http.rs @@ -1483,6 +1483,11 @@ enable = true [prom_store] enable = true with_metric_engine = true +pending_rows_flush_interval = "0s" +max_batch_rows = 100000 +max_concurrent_flushes = 256 +worker_channel_capacity = 65526 +max_inflight_requests = 3000 [wal] provider = "raft_engine"