diff --git a/src/datatypes/src/timestamp.rs b/src/datatypes/src/timestamp.rs
index dacc125eee..6d699d3d56 100644
--- a/src/datatypes/src/timestamp.rs
+++ b/src/datatypes/src/timestamp.rs
@@ -12,6 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use arrow_array::{
+    ArrayRef, PrimitiveArray, TimestampMicrosecondArray, TimestampMillisecondArray,
+    TimestampNanosecondArray, TimestampSecondArray,
+};
+use arrow_schema::DataType;
 use common_time::timestamp::TimeUnit;
 use common_time::Timestamp;
 use paste::paste;
@@ -138,6 +143,41 @@ define_timestamp_with_unit!(Millisecond);
 define_timestamp_with_unit!(Microsecond);
 define_timestamp_with_unit!(Nanosecond);
 
+pub fn timestamp_array_to_primitive(
+    ts_array: &ArrayRef,
+) -> Option<(
+    PrimitiveArray<arrow::datatypes::Int64Type>,
+    arrow::datatypes::TimeUnit,
+)> {
+    let DataType::Timestamp(unit, _) = ts_array.data_type() else {
+        return None;
+    };
+
+    let ts_primitive = match unit {
+        arrow_schema::TimeUnit::Second => ts_array
+            .as_any()
+            .downcast_ref::<TimestampSecondArray>()
+            .unwrap()
+            .reinterpret_cast::<arrow::datatypes::Int64Type>(),
+        arrow_schema::TimeUnit::Millisecond => ts_array
+            .as_any()
+            .downcast_ref::<TimestampMillisecondArray>()
+            .unwrap()
+            .reinterpret_cast::<arrow::datatypes::Int64Type>(),
+        arrow_schema::TimeUnit::Microsecond => ts_array
+            .as_any()
+            .downcast_ref::<TimestampMicrosecondArray>()
+            .unwrap()
+            .reinterpret_cast::<arrow::datatypes::Int64Type>(),
+        arrow_schema::TimeUnit::Nanosecond => ts_array
+            .as_any()
+            .downcast_ref::<TimestampNanosecondArray>()
+            .unwrap()
+            .reinterpret_cast::<arrow::datatypes::Int64Type>(),
+    };
+    Some((ts_primitive, *unit))
+}
+
 #[cfg(test)]
 mod tests {
     use common_time::timezone::set_default_timezone;
diff --git a/src/mito2/src/error.rs b/src/mito2/src/error.rs
index 6ea0386f03..e23e9e154a 100644
--- a/src/mito2/src/error.rs
+++ b/src/mito2/src/error.rs
@@ -1009,6 +1009,18 @@ pub enum Error {
         #[snafu(implicit)]
         location: Location,
     },
+
+    #[snafu(display(
+        "Inconsistent timestamp column length, expect: {}, actual: {}",
+        expected,
+        actual
+    ))]
+    InconsistentTimestampLength {
+        expected: usize,
+        actual: usize,
+        #[snafu(implicit)]
+        location: Location,
+    },
 }
 
 pub type Result<T, E = Error> = std::result::Result<T, E>;
@@ -1166,6 +1178,8 @@ impl ErrorExt for Error {
 
             #[cfg(feature = "enterprise")]
             ScanExternalRange { source, .. } => source.status_code(),
+
+            InconsistentTimestampLength { .. } => StatusCode::InvalidArguments,
         }
     }
 
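Note (not part of the patch): a minimal sketch of how the new `timestamp_array_to_primitive` helper from `src/datatypes/src/timestamp.rs` can be called. The `min_max_of` function and the sample values are hypothetical, and the sketch assumes the workspace's `datatypes` and `arrow` crates are available as dependencies.

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, TimestampMillisecondArray};
use datatypes::timestamp::timestamp_array_to_primitive;

/// Hypothetical caller: view any timestamp-typed array as i64 values and
/// compute its min/max, the same pattern the bulk-insert worker uses below.
fn min_max_of(ts_array: &ArrayRef) -> Option<(i64, i64)> {
    // Returns `None` if the array is not a timestamp array.
    let (primitive, _unit) = timestamp_array_to_primitive(ts_array)?;
    // `min`/`max` return `None` for empty or all-null arrays.
    Some((
        arrow::compute::min(&primitive)?,
        arrow::compute::max(&primitive)?,
    ))
}

fn main() {
    let ts: ArrayRef = Arc::new(TimestampMillisecondArray::from(vec![3_000i64, 1_000, 2_000]));
    assert_eq!(min_max_of(&ts), Some((1_000, 3_000)));
}
```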
diff --git a/src/mito2/src/worker/handle_bulk_insert.rs b/src/mito2/src/worker/handle_bulk_insert.rs
index b739cef380..e648edf307 100644
--- a/src/mito2/src/worker/handle_bulk_insert.rs
+++ b/src/mito2/src/worker/handle_bulk_insert.rs
@@ -15,15 +15,11 @@
 //! Handles bulk insert requests.
 
 use datatypes::arrow;
-use datatypes::arrow::array::{
-    TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
-    TimestampSecondArray,
-};
-use datatypes::arrow::datatypes::{DataType, TimeUnit};
 use store_api::logstore::LogStore;
 use store_api::metadata::RegionMetadataRef;
 use store_api::region_request::RegionBulkInsertsRequest;
 
+use crate::error::InconsistentTimestampLengthSnafu;
 use crate::memtable::bulk::part::BulkPart;
 use crate::request::{OptionOutputTx, SenderBulkRequest};
 use crate::worker::RegionWorkerLoop;
@@ -41,6 +37,10 @@ impl<S: LogStore> RegionWorkerLoop<S> {
             .with_label_values(&["process_bulk_req"])
             .start_timer();
         let batch = request.payload;
+        if batch.num_rows() == 0 {
+            sender.send(Ok(0));
+            return;
+        }
 
         let Some((ts_index, ts)) = batch
             .schema()
@@ -60,55 +60,23 @@ impl<S: LogStore> RegionWorkerLoop<S> {
             return;
         };
 
-        let DataType::Timestamp(unit, _) = ts.data_type() else {
-            // safety: ts data type must be a timestamp type.
-            unreachable!()
-        };
+        if batch.num_rows() != ts.len() {
+            sender.send(
+                InconsistentTimestampLengthSnafu {
+                    expected: batch.num_rows(),
+                    actual: ts.len(),
+                }
+                .fail(),
+            );
+            return;
+        }
 
-        let (min_ts, max_ts) = match unit {
-            TimeUnit::Second => {
-                let ts = ts.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
-                (
-                    //safety: ts array must contain at least one row so this won't return None.
-                    arrow::compute::min(ts).unwrap(),
-                    arrow::compute::max(ts).unwrap(),
-                )
-            }
+        // safety: ts data type must be a timestamp type.
+        let (ts_primitive, _) = datatypes::timestamp::timestamp_array_to_primitive(ts).unwrap();
 
-            TimeUnit::Millisecond => {
-                let ts = ts
-                    .as_any()
-                    .downcast_ref::<TimestampMillisecondArray>()
-                    .unwrap();
-                (
-                    //safety: ts array must contain at least one row so this won't return None.
-                    arrow::compute::min(ts).unwrap(),
-                    arrow::compute::max(ts).unwrap(),
-                )
-            }
-            TimeUnit::Microsecond => {
-                let ts = ts
-                    .as_any()
-                    .downcast_ref::<TimestampMicrosecondArray>()
-                    .unwrap();
-                (
-                    //safety: ts array must contain at least one row so this won't return None.
-                    arrow::compute::min(ts).unwrap(),
-                    arrow::compute::max(ts).unwrap(),
-                )
-            }
-            TimeUnit::Nanosecond => {
-                let ts = ts
-                    .as_any()
-                    .downcast_ref::<TimestampNanosecondArray>()
-                    .unwrap();
-                (
-                    //safety: ts array must contain at least one row so this won't return None.
-                    arrow::compute::min(ts).unwrap(),
-                    arrow::compute::max(ts).unwrap(),
-                )
-            }
-        };
+        // safety: we've checked ts.len() == batch.num_rows() and batch is not empty
+        let min_ts = arrow::compute::min(&ts_primitive).unwrap();
+        let max_ts = arrow::compute::max(&ts_primitive).unwrap();
 
         let part = BulkPart {
             batch,
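Note (illustration only): the sketch below shows how the `InconsistentTimestampLengthSnafu` selector used above turns the length check into an `Err` value. It uses a deliberately simplified local error type instead of mito2's real `Error` enum (no `location` field, no `ErrorExt` impl), and the `check_timestamp_len` function is hypothetical.

```rust
use snafu::{ensure, Snafu};

#[derive(Debug, Snafu)]
enum Error {
    #[snafu(display(
        "Inconsistent timestamp column length, expect: {}, actual: {}",
        expected,
        actual
    ))]
    InconsistentTimestampLength { expected: usize, actual: usize },
}

/// Mirrors the check added to the bulk-insert worker: the batch row count
/// must match the timestamp column length.
fn check_timestamp_len(expected: usize, actual: usize) -> Result<(), Error> {
    // `ensure!` roughly expands to `InconsistentTimestampLengthSnafu { .. }.fail()`
    // when the condition is false, the same selector the worker sends back.
    ensure!(
        expected == actual,
        InconsistentTimestampLengthSnafu { expected, actual }
    );
    Ok(())
}

fn main() {
    assert!(check_timestamp_len(3, 3).is_ok());
    let err = check_timestamp_len(3, 2).unwrap_err();
    assert_eq!(
        err.to_string(),
        "Inconsistent timestamp column length, expect: 3, actual: 2"
    );
}
```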
diff --git a/src/operator/src/bulk_insert.rs b/src/operator/src/bulk_insert.rs
index 4aa90b6cf8..7ef442a63c 100644
--- a/src/operator/src/bulk_insert.rs
+++ b/src/operator/src/bulk_insert.rs
@@ -20,11 +20,7 @@ use api::v1::region::{
     bulk_insert_request, region_request, BulkInsertRequest, RegionRequest, RegionRequestHeader,
 };
 use api::v1::ArrowIpc;
-use arrow::array::{
-    Array, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
-    TimestampSecondArray,
-};
-use arrow::datatypes::{DataType, Int64Type, TimeUnit};
+use arrow::array::Array;
 use arrow::record_batch::RecordBatch;
 use common_base::AffectedRows;
 use common_grpc::flight::{FlightDecoder, FlightEncoder, FlightMessage};
@@ -62,6 +58,10 @@ impl Inserter {
         };
         decode_timer.observe_duration();
 
+        if record_batch.num_rows() == 0 {
+            return Ok(0);
+        }
+
         // notify flownode to update dirty timestamps if flow is configured.
         self.maybe_update_flow_dirty_window(table_info, record_batch.clone());
 
@@ -155,6 +155,9 @@ impl Inserter {
         let mut raw_data_bytes = None;
         for (peer, masks) in mask_per_datanode {
             for (region_id, mask) in masks {
+                if mask.select_none() {
+                    continue;
+                }
                 let rb = record_batch.clone();
                 let schema_bytes = schema_bytes.clone();
                 let node_manager = self.node_manager.clone();
@@ -304,32 +307,11 @@ fn extract_timestamps(rb: &RecordBatch, timestamp_index_name: &str) -> error::Re
     if rb.num_rows() == 0 {
         return Ok(vec![]);
     }
-    let primitive = match ts_col.data_type() {
-        DataType::Timestamp(unit, _) => match unit {
-            TimeUnit::Second => ts_col
-                .as_any()
-                .downcast_ref::<TimestampSecondArray>()
-                .unwrap()
-                .reinterpret_cast::<Int64Type>(),
-            TimeUnit::Millisecond => ts_col
-                .as_any()
-                .downcast_ref::<TimestampMillisecondArray>()
-                .unwrap()
-                .reinterpret_cast::<Int64Type>(),
-            TimeUnit::Microsecond => ts_col
-                .as_any()
-                .downcast_ref::<TimestampMicrosecondArray>()
-                .unwrap()
-                .reinterpret_cast::<Int64Type>(),
-            TimeUnit::Nanosecond => ts_col
-                .as_any()
-                .downcast_ref::<TimestampNanosecondArray>()
-                .unwrap()
-                .reinterpret_cast::<Int64Type>(),
-        },
-        t => {
-            return error::InvalidTimeIndexTypeSnafu { ty: t.clone() }.fail();
-        }
-    };
+    let (primitive, _) =
+        datatypes::timestamp::timestamp_array_to_primitive(ts_col).with_context(|| {
+            error::InvalidTimeIndexTypeSnafu {
+                ty: ts_col.data_type().clone(),
+            }
+        })?;
     Ok(primitive.iter().flatten().collect())
 }
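Note (illustration only): a self-contained version of the operator-side conversion after this refactor. The shared helper replaces the per-unit `match`, and snafu's `with_context` maps the `None` case (a non-timestamp column) to a typed error. `timestamps_as_i64` and the local `InvalidTimeIndexType` error are stand-ins for the real `extract_timestamps` and `error::InvalidTimeIndexType` in the operator crate.

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Int64Array};
use arrow::datatypes::DataType;
use snafu::{OptionExt, Snafu};

#[derive(Debug, Snafu)]
#[snafu(display("Invalid time index type: {:?}", ty))]
struct InvalidTimeIndexType {
    ty: DataType,
}

/// Collect a timestamp column into plain i64 values, as the patched
/// `extract_timestamps` does via the shared helper.
fn timestamps_as_i64(ts_col: &ArrayRef) -> Result<Vec<i64>, InvalidTimeIndexType> {
    if ts_col.is_empty() {
        return Ok(vec![]);
    }
    let (primitive, _unit) = datatypes::timestamp::timestamp_array_to_primitive(ts_col)
        .with_context(|| InvalidTimeIndexTypeSnafu {
            ty: ts_col.data_type().clone(),
        })?;
    // `flatten()` skips nulls; a time index column is expected to be non-nullable.
    Ok(primitive.iter().flatten().collect())
}

fn main() {
    // A non-timestamp column is rejected with the typed error instead of a panic.
    let not_a_timestamp: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
    assert!(timestamps_as_i64(&not_a_timestamp).is_err());
}
```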