diff --git a/src/client/src/database.rs b/src/client/src/database.rs index d8e538242e..f786186388 100644 --- a/src/client/src/database.rs +++ b/src/client/src/database.rs @@ -315,7 +315,7 @@ impl Database { let mut flight_message_stream = flight_data_stream.map(move |flight_data| { flight_data .map_err(Error::from) - .and_then(|data| decoder.try_decode(data).context(ConvertFlightDataSnafu)) + .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu)) }); let Some(first_flight_message) = flight_message_stream.next().await else { diff --git a/src/client/src/region.rs b/src/client/src/region.rs index 15d123936d..c31d3a1e17 100644 --- a/src/client/src/region.rs +++ b/src/client/src/region.rs @@ -125,7 +125,7 @@ impl RegionRequester { let mut flight_message_stream = flight_data_stream.map(move |flight_data| { flight_data .map_err(Error::from) - .and_then(|data| decoder.try_decode(data).context(ConvertFlightDataSnafu)) + .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu)) }); let Some(first_flight_message) = flight_message_stream.next().await else { diff --git a/src/common/grpc/src/flight.rs b/src/common/grpc/src/flight.rs index 37a725e0cc..63f4d05289 100644 --- a/src/common/grpc/src/flight.rs +++ b/src/common/grpc/src/flight.rs @@ -127,7 +127,7 @@ pub struct FlightDecoder { } impl FlightDecoder { - pub fn try_decode(&mut self, flight_data: FlightData) -> Result { + pub fn try_decode(&mut self, flight_data: &FlightData) -> Result { let message = root_as_message(&flight_data.data_header).map_err(|e| { InvalidFlightDataSnafu { reason: e.to_string(), @@ -136,7 +136,7 @@ impl FlightDecoder { })?; match message.header_type() { MessageHeader::NONE => { - let metadata = FlightMetadata::decode(flight_data.app_metadata) + let metadata = FlightMetadata::decode(flight_data.app_metadata.clone()) .context(DecodeFlightDataSnafu)?; if let Some(AffectedRows { value }) = metadata.affected_rows { return Ok(FlightMessage::AffectedRows(value as _)); 
@@ -152,7 +152,7 @@ impl FlightDecoder { .fail() } MessageHeader::Schema => { - let arrow_schema = ArrowSchema::try_from(&flight_data).map_err(|e| { + let arrow_schema = ArrowSchema::try_from(flight_data).map_err(|e| { InvalidFlightDataSnafu { reason: e.to_string(), } @@ -172,7 +172,7 @@ impl FlightDecoder { let arrow_schema = schema.arrow_schema().clone(); let arrow_batch = - flight_data_to_arrow_batch(&flight_data, arrow_schema, &HashMap::new()) + flight_data_to_arrow_batch(flight_data, arrow_schema, &HashMap::new()) .map_err(|e| { InvalidFlightDataSnafu { reason: e.to_string(), @@ -287,14 +287,14 @@ mod test { let decoder = &mut FlightDecoder::default(); assert!(decoder.schema.is_none()); - let result = decoder.try_decode(d2.clone()); + let result = decoder.try_decode(d2); assert!(matches!(result, Err(Error::InvalidFlightData { .. }))); assert!(result .unwrap_err() .to_string() .contains("Should have decoded schema first!")); - let message = decoder.try_decode(d1.clone()).unwrap(); + let message = decoder.try_decode(d1).unwrap(); assert!(matches!(message, FlightMessage::Schema(_))); let FlightMessage::Schema(decoded_schema) = message else { unreachable!() @@ -303,14 +303,14 @@ mod test { let _ = decoder.schema.as_ref().unwrap(); - let message = decoder.try_decode(d2.clone()).unwrap(); + let message = decoder.try_decode(d2).unwrap(); assert!(matches!(message, FlightMessage::Recordbatch(_))); let FlightMessage::Recordbatch(actual_batch) = message else { unreachable!() }; assert_eq!(actual_batch, batch1); - let message = decoder.try_decode(d3.clone()).unwrap(); + let message = decoder.try_decode(d3).unwrap(); assert!(matches!(message, FlightMessage::Recordbatch(_))); let FlightMessage::Recordbatch(actual_batch) = message else { unreachable!() diff --git a/src/operator/src/bulk_insert.rs b/src/operator/src/bulk_insert.rs index 6d199f3025..7a51bd4904 100644 --- a/src/operator/src/bulk_insert.rs +++ b/src/operator/src/bulk_insert.rs @@ -48,7 +48,7 @@ impl Inserter { 
let body_size = data.data_body.len(); // Build region server requests let message = decoder - .try_decode(data) + .try_decode(&data) .context(error::DecodeFlightDataSnafu)?; let FlightMessage::Recordbatch(rb) = message else { return Ok(0); @@ -82,6 +82,51 @@ impl Inserter { .context(error::SplitInsertSnafu)?; partition_timer.observe_duration(); + // fast path: only one region. + if region_masks.len() == 1 { + metrics::BULK_REQUEST_ROWS + .with_label_values(&["rows_per_region"]) + .observe(record_batch.num_rows() as f64); + + // SAFETY: region masks length checked + let (region_number, _) = region_masks.into_iter().next().unwrap(); + let region_id = RegionId::new(table_id, region_number); + let datanode = self + .partition_manager + .find_region_leader(region_id) + .await + .context(error::FindRegionLeaderSnafu)?; + let payload = { + let _encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED + .with_label_values(&["encode"]) + .start_timer(); + Bytes::from(data.encode_to_vec()) + }; + let request = RegionRequest { + header: Some(RegionRequestHeader { + tracing_context: TracingContext::from_current_span().to_w3c(), + ..Default::default() + }), + body: Some(region_request::Body::BulkInsert(BulkInsertRequest { + body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc { + region_id: region_id.as_u64(), + schema: schema_bytes, + payload, + })), + })), + }; + + let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED + .with_label_values(&["datanode_handle"]) + .start_timer(); + let datanode = self.node_manager.datanode(&datanode).await; + return datanode + .handle(request) + .await + .context(error::RequestRegionSnafu) + .map(|r| r.affected_rows); + } + let mut mask_per_datanode = HashMap::with_capacity(region_masks.len()); for (region_number, mask) in region_masks { let region_id = RegionId::new(table_id, region_number); @@ -104,6 +149,7 @@ impl Inserter { let record_batch_schema = 
Arc::new(Schema::try_from(record_batch.schema()).context(error::ConvertSchemaSnafu)?); + let mut raw_data_bytes = None; for (peer, masks) in mask_per_datanode { for (region_id, mask) in masks { let rb = record_batch.clone(); @@ -111,30 +157,45 @@ impl Inserter { let record_batch_schema = record_batch_schema.clone(); let node_manager = self.node_manager.clone(); let peer = peer.clone(); + let raw_data = if mask.select_all() { + Some( + raw_data_bytes + .get_or_insert_with(|| Bytes::from(data.encode_to_vec())) + .clone(), + ) + } else { + None + }; let handle: common_runtime::JoinHandle> = common_runtime::spawn_global(async move { - let filter_timer = metrics::HANDLE_BULK_INSERT_ELAPSED - .with_label_values(&["filter"]) - .start_timer(); - let rb = arrow::compute::filter_record_batch(&rb, &mask) - .context(error::ComputeArrowSnafu)?; - filter_timer.observe_duration(); - metrics::BULK_REQUEST_ROWS - .with_label_values(&["rows_per_region"]) - .observe(rb.num_rows() as f64); - - let encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED - .with_label_values(&["encode"]) - .start_timer(); - let batch = RecordBatch::try_from_df_record_batch(record_batch_schema, rb) - .context(error::BuildRecordBatchSnafu)?; - let payload = Bytes::from( - FlightEncoder::default() - .encode(FlightMessage::Recordbatch(batch)) - .encode_to_vec(), - ); - encode_timer.observe_duration(); + let payload = if mask.select_all() { + // SAFETY: raw data must be present, we can avoid re-encoding. 
+ raw_data.unwrap() + } else { + let filter_timer = metrics::HANDLE_BULK_INSERT_ELAPSED + .with_label_values(&["filter"]) + .start_timer(); + let rb = arrow::compute::filter_record_batch(&rb, mask.array()) + .context(error::ComputeArrowSnafu)?; + filter_timer.observe_duration(); + metrics::BULK_REQUEST_ROWS + .with_label_values(&["rows_per_region"]) + .observe(rb.num_rows() as f64); + let encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED + .with_label_values(&["encode"]) + .start_timer(); + let batch = + RecordBatch::try_from_df_record_batch(record_batch_schema, rb) + .context(error::BuildRecordBatchSnafu)?; + let payload = Bytes::from( + FlightEncoder::default() + .encode(FlightMessage::Recordbatch(batch)) + .encode_to_vec(), + ); + encode_timer.observe_duration(); + payload + }; let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED .with_label_values(&["datanode_handle"]) .start_timer(); diff --git a/src/partition/src/multi_dim.rs b/src/partition/src/multi_dim.rs index 9bbbf8f015..83c1deab61 100644 --- a/src/partition/src/multi_dim.rs +++ b/src/partition/src/multi_dim.rs @@ -14,6 +14,7 @@ use std::any::Any; use std::cmp::Ordering; +use std::collections::hash_map::Entry; use std::collections::HashMap; use std::sync::{Arc, RwLock}; @@ -34,6 +35,7 @@ use crate::error::{ UndefinedColumnSnafu, }; use crate::expr::{Operand, PartitionExpr, RestrictedOp}; +use crate::partition::RegionMask; use crate::PartitionRule; /// The default region number when no partition exprs are matched. 
@@ -209,14 +211,15 @@ impl MultiDimPartitionRule { pub fn split_record_batch( &self, record_batch: &RecordBatch, - ) -> Result> { + ) -> Result> { let num_rows = record_batch.num_rows(); if self.regions.len() == 1 { - return Ok( - [(self.regions[0], BooleanArray::from(vec![true; num_rows]))] - .into_iter() - .collect(), - ); + return Ok([( + self.regions[0], + RegionMask::from(BooleanArray::from(vec![true; num_rows])), + )] + .into_iter() + .collect()); } let physical_exprs = { let cache_read_guard = self.physical_expr_cache.read().unwrap(); @@ -240,34 +243,56 @@ impl MultiDimPartitionRule { } }; - let mut result: HashMap = physical_exprs + let mut result: HashMap = physical_exprs .iter() .zip(self.regions.iter()) - .map(|(expr, region_num)| { - let ColumnarValue::Array(column) = expr + .filter_map(|(expr, region_num)| { + let col_val = match expr .evaluate(record_batch) - .context(error::EvaluateRecordBatchSnafu)? - else { + .context(error::EvaluateRecordBatchSnafu) + { + Ok(array) => array, + Err(e) => { + return Some(Err(e)); + } + }; + let ColumnarValue::Array(column) = col_val else { unreachable!("Expected an array") }; - Ok(( - *region_num, - column + let array = + match column .as_any() .downcast_ref::() .with_context(|| error::UnexpectedColumnTypeSnafu { data_type: column.data_type().clone(), - })? - .clone(), - )) + }) { + Ok(array) => array, + Err(e) => { + return Some(Err(e)); + } + }; + let selected_rows = array.true_count(); + if selected_rows == 0 { + // skip empty region in results. 
+ return None; + } + Some(Ok(( + *region_num, + RegionMask::new(array.clone(), selected_rows), + ))) }) .collect::>()?; - let mut selected = BooleanArray::new(BooleanBuffer::new_unset(num_rows), None); - for region_selection in result.values() { - selected = arrow::compute::kernels::boolean::or(&selected, region_selection) - .context(error::ComputeArrowKernelSnafu)?; - } + let selected = if result.len() == 1 { + result.values().next().unwrap().array().clone() + } else { + let mut selected = BooleanArray::new(BooleanBuffer::new_unset(num_rows), None); + for region_mask in result.values() { + selected = arrow::compute::kernels::boolean::or(&selected, region_mask.array()) + .context(error::ComputeArrowKernelSnafu)?; + } + selected + }; // fast path: all rows are selected if selected.true_count() == num_rows { @@ -277,12 +302,20 @@ impl MultiDimPartitionRule { // find unselected rows and assign to default region let unselected = arrow::compute::kernels::boolean::not(&selected) .context(error::ComputeArrowKernelSnafu)?; - let default_region_selection = result - .entry(DEFAULT_REGION) - .or_insert_with(|| unselected.clone()); - *default_region_selection = - arrow::compute::kernels::boolean::or(default_region_selection, &unselected) - .context(error::ComputeArrowKernelSnafu)?; + match result.entry(DEFAULT_REGION) { + Entry::Occupied(mut o) => { + // merge default region with unselected rows. + let default_region_mask = RegionMask::from( + arrow::compute::kernels::boolean::or(o.get().array(), &unselected) + .context(error::ComputeArrowKernelSnafu)?, + ); + o.insert(default_region_mask); + } + Entry::Vacant(v) => { + // default region has no rows, simply put all unselected rows to default region. 
+ v.insert(RegionMask::from(unselected)); + } + } Ok(result) } } @@ -303,7 +336,7 @@ impl PartitionRule for MultiDimPartitionRule { fn split_record_batch( &self, record_batch: &RecordBatch, - ) -> Result> { + ) -> Result> { self.split_record_batch(record_batch) } } @@ -845,7 +878,7 @@ mod test_split_record_batch { assert_eq!(result.len(), expected.len()); for (region, value) in &result { assert_eq!( - value, + value.array(), expected.get(region).unwrap(), "failed on region: {}", region @@ -904,7 +937,7 @@ mod test_split_record_batch { let expected = rule.split_record_batch_naive(&batch).unwrap(); assert_eq!(result.len(), expected.len()); for (region, value) in &result { - assert_eq!(value, expected.get(region).unwrap()); + assert_eq!(value.array(), expected.get(region).unwrap()); } } @@ -937,9 +970,117 @@ mod test_split_record_batch { .unwrap(); let result = rule.split_record_batch(&batch).unwrap(); let expected = rule.split_record_batch_naive(&batch).unwrap(); - assert_eq!(result.len(), expected.len()); for (region, value) in &result { - assert_eq!(value, expected.get(region).unwrap()); + assert_eq!(value.array(), expected.get(region).unwrap()); } } + + #[test] + fn test_default_region_with_unselected_rows() { + // Create a rule where some rows won't match any partition + let rule = MultiDimPartitionRule::try_new( + vec!["host".to_string(), "value".to_string()], + vec![1, 2, 3], + vec![ + col("value").eq(Value::Int64(10)), + col("value").eq(Value::Int64(20)), + col("value").eq(Value::Int64(30)), + ], + ) + .unwrap(); + + let schema = test_schema(); + let host_array = + StringArray::from(vec!["server1", "server2", "server3", "server4", "server5"]); + let value_array = Int64Array::from(vec![10, 20, 30, 40, 50]); + let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)]) + .unwrap(); + + let result = rule.split_record_batch(&batch).unwrap(); + + // Check that we have 4 regions (3 defined + default) + assert_eq!(result.len(), 4); + + 
// Check that default region (0) contains the unselected rows + assert!(result.contains_key(&DEFAULT_REGION)); + let default_mask = result.get(&DEFAULT_REGION).unwrap(); + + // The default region should have 2 rows (with values 40 and 50) + assert_eq!(default_mask.selected_rows(), 2); + + // Verify each region has the correct number of rows + assert_eq!(result.get(&1).unwrap().selected_rows(), 1); // value = 10 + assert_eq!(result.get(&2).unwrap().selected_rows(), 1); // value = 20 + assert_eq!(result.get(&3).unwrap().selected_rows(), 1); // value = 30 + } + + #[test] + fn test_default_region_with_existing_default() { + // Create a rule where some rows are explicitly assigned to default region + // and some rows are implicitly assigned to default region + let rule = MultiDimPartitionRule::try_new( + vec!["host".to_string(), "value".to_string()], + vec![0, 1, 2], + vec![ + col("value").eq(Value::Int64(10)), // Explicitly assign value=10 to region 0 (default) + col("value").eq(Value::Int64(20)), + col("value").eq(Value::Int64(30)), + ], + ) + .unwrap(); + + let schema = test_schema(); + let host_array = + StringArray::from(vec!["server1", "server2", "server3", "server4", "server5"]); + let value_array = Int64Array::from(vec![10, 20, 30, 40, 50]); + let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)]) + .unwrap(); + + let result = rule.split_record_batch(&batch).unwrap(); + + // Check that we have 3 regions + assert_eq!(result.len(), 3); + + // Check that default region contains both explicitly assigned and unselected rows + assert!(result.contains_key(&DEFAULT_REGION)); + let default_mask = result.get(&DEFAULT_REGION).unwrap(); + + // The default region should have 3 rows (value=10, 40, 50) + assert_eq!(default_mask.selected_rows(), 3); + + // Verify each region has the correct number of rows + assert_eq!(result.get(&1).unwrap().selected_rows(), 1); // value = 20 + assert_eq!(result.get(&2).unwrap().selected_rows(), 1); // value 
= 30 + } + + #[test] + fn test_all_rows_selected() { + // Test the fast path where all rows are selected by some partition + let rule = MultiDimPartitionRule::try_new( + vec!["value".to_string()], + vec![1, 2], + vec![ + col("value").lt(Value::Int64(30)), + col("value").gt_eq(Value::Int64(30)), + ], + ) + .unwrap(); + + let schema = test_schema(); + let host_array = StringArray::from(vec!["server1", "server2", "server3", "server4"]); + let value_array = Int64Array::from(vec![10, 20, 30, 40]); + let batch = RecordBatch::try_new(schema, vec![Arc::new(host_array), Arc::new(value_array)]) + .unwrap(); + + let result = rule.split_record_batch(&batch).unwrap(); + + // Check that we have 2 regions and no default region + assert_eq!(result.len(), 2); + assert!(result.contains_key(&1)); + assert!(result.contains_key(&2)); + + // Verify each region has the correct number of rows + assert_eq!(result.get(&1).unwrap().selected_rows(), 2); // values < 30 + assert_eq!(result.get(&2).unwrap().selected_rows(), 2); // values >= 30 + } } diff --git a/src/partition/src/partition.rs b/src/partition/src/partition.rs index a190d33eca..bcaa7dbb31 100644 --- a/src/partition/src/partition.rs +++ b/src/partition/src/partition.rs @@ -41,11 +41,12 @@ pub trait PartitionRule: Sync + Send { fn find_region(&self, values: &[Value]) -> Result; /// Split the record batch into multiple regions by the partition values. - /// The result is a map from region number to a boolean array, where the boolean array is true for the rows that match the partition values. + /// The result is a map from region number to a region mask, in which the array is true for the rows that match the partition values. + /// Regions with no rows selected may not appear in the result map. fn split_record_batch( &self, record_batch: &RecordBatch, - ) -> Result>; + ) -> Result>; } /// The right bound(exclusive) of partition range. 
@@ -177,6 +178,48 @@ impl PartitionExpr { } } +pub struct RegionMask { + array: BooleanArray, + selected_rows: usize, +} + +impl From for RegionMask { + fn from(array: BooleanArray) -> Self { + let selected_rows = array.true_count(); + Self { + array, + selected_rows, + } + } +} + +impl RegionMask { + pub fn new(array: BooleanArray, selected_rows: usize) -> Self { + Self { + array, + selected_rows, + } + } + + pub fn array(&self) -> &BooleanArray { + &self.array + } + + /// All rows are selected. + pub fn select_all(&self) -> bool { + self.selected_rows == self.array.len() + } + + /// No row is selected. + pub fn select_none(&self) -> bool { + self.selected_rows == 0 + } + + pub fn selected_rows(&self) -> usize { + self.selected_rows + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/partition/src/splitter.rs b/src/partition/src/splitter.rs index 87c04a4942..5ff9b04e84 100644 --- a/src/partition/src/splitter.rs +++ b/src/partition/src/splitter.rs @@ -136,10 +136,10 @@ mod tests { use api::v1::value::ValueData; use api::v1::{ColumnDataType, SemanticType}; - use datatypes::arrow::array::BooleanArray; use serde::{Deserialize, Serialize}; use super::*; + use crate::partition::RegionMask; use crate::PartitionRule; fn mock_rows() -> Rows { @@ -214,7 +214,7 @@ mod tests { fn split_record_batch( &self, _record_batch: &datatypes::arrow::array::RecordBatch, - ) -> Result> { + ) -> Result> { unimplemented!() } } @@ -244,7 +244,7 @@ mod tests { fn split_record_batch( &self, _record_batch: &datatypes::arrow::array::RecordBatch, - ) -> Result> { + ) -> Result> { unimplemented!() } } @@ -268,7 +268,7 @@ mod tests { fn split_record_batch( &self, _record_batch: &datatypes::arrow::array::RecordBatch, - ) -> Result> { + ) -> Result> { unimplemented!() } } diff --git a/src/servers/src/grpc/flight/stream.rs b/src/servers/src/grpc/flight/stream.rs index f5b8811dcf..dea9f40af2 100644 --- a/src/servers/src/grpc/flight/stream.rs +++ b/src/servers/src/grpc/flight/stream.rs @@ 
-167,7 +167,7 @@ mod test { let decoder = &mut FlightDecoder::default(); let mut flight_messages = raw_data .into_iter() - .map(|x| decoder.try_decode(x).unwrap()) + .map(|x| decoder.try_decode(&x).unwrap()) .collect::>(); assert_eq!(flight_messages.len(), 2); diff --git a/src/store-api/src/region_request.rs b/src/store-api/src/region_request.rs index 2bd9a89978..bd4c080d15 100644 --- a/src/store-api/src/region_request.rs +++ b/src/store-api/src/region_request.rs @@ -340,9 +340,10 @@ fn make_region_bulk_inserts(request: BulkInsertRequest) -> Result