1use std::collections::HashSet;
16
17use ahash::{HashMap, HashMapExt};
18use api::v1::flow::{DirtyWindowRequest, DirtyWindowRequests};
19use api::v1::region::{
20 BulkInsertRequest, RegionRequest, RegionRequestHeader, bulk_insert_request, region_request,
21};
22use api::v1::{ArrowIpc, PartitionExprVersion};
23use arrow::array::Array;
24use arrow::record_batch::RecordBatch;
25use bytes::Bytes;
26use common_base::AffectedRows;
27use common_grpc::FlightData;
28use common_grpc::flight::{FlightEncoder, FlightMessage};
29use common_telemetry::error;
30use common_telemetry::tracing_context::TracingContext;
31use snafu::{OptionExt, ResultExt, ensure};
32use store_api::storage::RegionId;
33use table::TableRef;
34use table::metadata::TableInfoRef;
35
36use crate::insert::Inserter;
37use crate::{error, metrics};
38
39impl Inserter {
40 pub async fn handle_bulk_insert(
42 &self,
43 table: TableRef,
44 raw_flight_data: FlightData,
45 record_batch: RecordBatch,
46 schema_bytes: Bytes,
47 ) -> error::Result<AffectedRows> {
48 let table_info = table.table_info();
49 let table_id = table_info.table_id();
50 let db_name = table_info.get_db_string();
51
52 if record_batch.num_rows() == 0 {
53 return Ok(0);
54 }
55
56 let body_size = raw_flight_data.data_body.len();
57 self.maybe_update_flow_dirty_window(table_info.clone(), record_batch.clone());
60
61 metrics::BULK_REQUEST_MESSAGE_SIZE.observe(body_size as f64);
62 metrics::BULK_REQUEST_ROWS
63 .with_label_values(&["raw"])
64 .observe(record_batch.num_rows() as f64);
65
66 let partition_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
67 .with_label_values(&["partition"])
68 .start_timer();
69 let (partition_rule, partition_versions) = self
70 .partition_manager
71 .find_table_partition_rule(&table_info)
72 .await
73 .context(error::InvalidPartitionSnafu)?;
74
75 let region_masks = partition_rule
77 .split_record_batch(&record_batch)
78 .context(error::SplitInsertSnafu)?;
79 partition_timer.observe_duration();
80
81 if region_masks.len() == 1 {
83 metrics::BULK_REQUEST_ROWS
84 .with_label_values(&["rows_per_region"])
85 .observe(record_batch.num_rows() as f64);
86
87 let (region_number, _) = region_masks.into_iter().next().unwrap();
89 let region_id = RegionId::new(table_id, region_number);
90 let partition_expr_version = partition_versions
91 .get(®ion_number)
92 .copied()
93 .unwrap_or_default();
94 let datanode = self
95 .partition_manager
96 .find_region_leader(region_id)
97 .await
98 .context(error::FindRegionLeaderSnafu)?;
99
100 let request = RegionRequest {
101 header: Some(RegionRequestHeader {
102 tracing_context: TracingContext::from_current_span().to_w3c(),
103 ..Default::default()
104 }),
105 body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
106 region_id: region_id.as_u64(),
107 partition_expr_version: partition_expr_version
108 .map(|value| PartitionExprVersion { value }),
109 aligned_schema_version: None,
110 body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
111 schema: schema_bytes.clone(),
112 data_header: raw_flight_data.data_header,
113 payload: raw_flight_data.data_body,
114 })),
115 })),
116 };
117
118 let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
119 .with_label_values(&["datanode_handle"])
120 .start_timer();
121 let datanode = self.node_manager.datanode(&datanode).await;
122 let result = datanode
123 .handle(request)
124 .await
125 .context(error::RequestRegionSnafu)
126 .map(|r| r.affected_rows);
127 if let Ok(rows) = result {
128 crate::metrics::DIST_INGEST_ROW_COUNT
129 .with_label_values(&[db_name.as_str()])
130 .inc_by(rows as u64);
131 }
132 return result;
133 }
134
135 let mut mask_per_datanode = HashMap::with_capacity(region_masks.len());
136 for (region_number, mask) in region_masks {
137 let region_id = RegionId::new(table_id, region_number);
138 let datanode = self
139 .partition_manager
140 .find_region_leader(region_id)
141 .await
142 .context(error::FindRegionLeaderSnafu)?;
143 mask_per_datanode
144 .entry(datanode)
145 .or_insert_with(Vec::new)
146 .push((region_id, mask));
147 }
148
149 let wait_all_datanode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
150 .with_label_values(&["wait_all_datanode"])
151 .start_timer();
152
153 let mut handles = Vec::with_capacity(mask_per_datanode.len());
154
155 for (peer, masks) in mask_per_datanode {
156 for (region_id, mask) in masks {
157 if mask.select_none() {
158 continue;
159 }
160 let partition_expr_version = partition_versions
161 .get(®ion_id.region_number())
162 .copied()
163 .unwrap_or_default();
164 let rb = record_batch.clone();
165 let schema_bytes = schema_bytes.clone();
166 let node_manager = self.node_manager.clone();
167 let peer = peer.clone();
168 let raw_header_and_data = if mask.select_all() {
169 Some((
170 raw_flight_data.data_header.clone(),
171 raw_flight_data.data_body.clone(),
172 ))
173 } else {
174 None
175 };
176 let handle: common_runtime::JoinHandle<error::Result<api::region::RegionResponse>> =
177 common_runtime::spawn_global(async move {
178 let (header, payload) = if mask.select_all() {
179 raw_header_and_data.unwrap()
181 } else {
182 let filter_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
183 .with_label_values(&["filter"])
184 .start_timer();
185 let batch = arrow::compute::filter_record_batch(&rb, mask.array())
186 .context(error::ComputeArrowSnafu)?;
187 filter_timer.observe_duration();
188 metrics::BULK_REQUEST_ROWS
189 .with_label_values(&["rows_per_region"])
190 .observe(batch.num_rows() as f64);
191
192 let encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
193 .with_label_values(&["encode"])
194 .start_timer();
195 let mut iter = FlightEncoder::default()
196 .encode(FlightMessage::RecordBatch(batch))
197 .into_iter();
198 let Some(flight_data) = iter.next() else {
199 unreachable!()
202 };
203 ensure!(
204 iter.next().is_none(),
205 error::NotSupportedSnafu {
206 feat: "bulk insert RecordBatch with dictionary arrays",
207 }
208 );
209 encode_timer.observe_duration();
210 (flight_data.data_header, flight_data.data_body)
211 };
212 let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
213 .with_label_values(&["datanode_handle"])
214 .start_timer();
215 let request = RegionRequest {
216 header: Some(RegionRequestHeader {
217 tracing_context: TracingContext::from_current_span().to_w3c(),
218 ..Default::default()
219 }),
220 body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
221 region_id: region_id.as_u64(),
222 partition_expr_version: partition_expr_version
223 .map(|value| PartitionExprVersion { value }),
224 aligned_schema_version: None,
225 body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
226 schema: schema_bytes,
227 data_header: header,
228 payload,
229 })),
230 })),
231 };
232
233 let datanode = node_manager.datanode(&peer).await;
234 datanode
235 .handle(request)
236 .await
237 .context(error::RequestRegionSnafu)
238 });
239 handles.push(handle);
240 }
241 }
242
243 let region_responses = futures::future::try_join_all(handles)
244 .await
245 .context(error::JoinTaskSnafu)?;
246 wait_all_datanode_timer.observe_duration();
247 let mut rows_inserted: usize = 0;
248 for res in region_responses {
249 rows_inserted += res?.affected_rows;
250 }
251 crate::metrics::DIST_INGEST_ROW_COUNT
252 .with_label_values(&[db_name.as_str()])
253 .inc_by(rows_inserted as u64);
254 Ok(rows_inserted)
255 }
256
257 fn maybe_update_flow_dirty_window(&self, table_info: TableInfoRef, record_batch: RecordBatch) {
258 let table_id = table_info.table_id();
259 let table_flownode_set_cache = self.table_flownode_set_cache.clone();
260 let node_manager = self.node_manager.clone();
261 common_runtime::spawn_global(async move {
262 let result = table_flownode_set_cache
263 .get(table_id)
264 .await
265 .context(error::RequestInsertsSnafu);
266 let flownodes = match result {
267 Ok(flownodes) => flownodes.unwrap_or_default(),
268 Err(e) => {
269 error!(e; "Failed to get flownodes for table id: {}", table_id);
270 return;
271 }
272 };
273
274 let peers: HashSet<_> = flownodes.values().cloned().collect();
275 if peers.is_empty() {
276 return;
277 }
278
279 let Ok(timestamps) = extract_timestamps(
280 &record_batch,
281 &table_info
282 .meta
283 .schema
284 .timestamp_column()
285 .as_ref()
286 .unwrap()
287 .name,
288 )
289 .inspect_err(|e| {
290 error!(e; "Failed to extract timestamps from record batch");
291 }) else {
292 return;
293 };
294
295 for peer in peers {
296 let node_manager = node_manager.clone();
297 let timestamps = timestamps.clone();
298 common_runtime::spawn_global(async move {
299 if let Err(e) = node_manager
300 .flownode(&peer)
301 .await
302 .handle_mark_window_dirty(DirtyWindowRequests {
303 requests: vec![DirtyWindowRequest {
304 table_id,
305 timestamps,
306 }],
307 })
308 .await
309 .context(error::RequestInsertsSnafu)
310 {
311 error!(e; "Failed to mark timestamps as dirty, table: {}", table_id);
312 }
313 });
314 }
315 });
316 }
317}
318
319fn extract_timestamps(rb: &RecordBatch, timestamp_index_name: &str) -> error::Result<Vec<i64>> {
321 let ts_col = rb
322 .column_by_name(timestamp_index_name)
323 .context(error::ColumnNotFoundSnafu {
324 msg: timestamp_index_name,
325 })?;
326 if rb.num_rows() == 0 {
327 return Ok(vec![]);
328 }
329 let (primitive, _) =
330 datatypes::timestamp::timestamp_array_to_primitive(ts_col).with_context(|| {
331 error::InvalidTimeIndexTypeSnafu {
332 ty: ts_col.data_type().clone(),
333 }
334 })?;
335 Ok(primitive.iter().flatten().collect())
336}