Skip to main content

operator/
bulk_insert.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16
17use ahash::{HashMap, HashMapExt};
18use api::v1::flow::{DirtyWindowRequest, DirtyWindowRequests};
19use api::v1::region::{
20    BulkInsertRequest, RegionRequest, RegionRequestHeader, bulk_insert_request, region_request,
21};
22use api::v1::{ArrowIpc, PartitionExprVersion};
23use arrow::array::Array;
24use arrow::record_batch::RecordBatch;
25use bytes::Bytes;
26use common_base::AffectedRows;
27use common_grpc::FlightData;
28use common_grpc::flight::{FlightEncoder, FlightMessage};
29use common_telemetry::error;
30use common_telemetry::tracing_context::TracingContext;
31use snafu::{OptionExt, ResultExt, ensure};
32use store_api::storage::RegionId;
33use table::TableRef;
34use table::metadata::TableInfoRef;
35
36use crate::insert::Inserter;
37use crate::{error, metrics};
38
39impl Inserter {
40    /// Handle bulk insert request.
41    pub async fn handle_bulk_insert(
42        &self,
43        table: TableRef,
44        raw_flight_data: FlightData,
45        record_batch: RecordBatch,
46        schema_bytes: Bytes,
47    ) -> error::Result<AffectedRows> {
48        let table_info = table.table_info();
49        let table_id = table_info.table_id();
50        let db_name = table_info.get_db_string();
51
52        if record_batch.num_rows() == 0 {
53            return Ok(0);
54        }
55
56        let body_size = raw_flight_data.data_body.len();
57        // TODO(yingwen): Fill record batch impure default values. Note that we should override `raw_flight_data` if we have to fill defaults.
58        // notify flownode to update dirty timestamps if flow is configured.
59        self.maybe_update_flow_dirty_window(table_info.clone(), record_batch.clone());
60
61        metrics::BULK_REQUEST_MESSAGE_SIZE.observe(body_size as f64);
62        metrics::BULK_REQUEST_ROWS
63            .with_label_values(&["raw"])
64            .observe(record_batch.num_rows() as f64);
65
66        let partition_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
67            .with_label_values(&["partition"])
68            .start_timer();
69        let (partition_rule, partition_versions) = self
70            .partition_manager
71            .find_table_partition_rule(&table_info)
72            .await
73            .context(error::InvalidPartitionSnafu)?;
74
75        // find partitions for each row in the record batch
76        let region_masks = partition_rule
77            .split_record_batch(&record_batch)
78            .context(error::SplitInsertSnafu)?;
79        partition_timer.observe_duration();
80
81        // fast path: only one region.
82        if region_masks.len() == 1 {
83            metrics::BULK_REQUEST_ROWS
84                .with_label_values(&["rows_per_region"])
85                .observe(record_batch.num_rows() as f64);
86
87            // SAFETY: region masks length checked
88            let (region_number, _) = region_masks.into_iter().next().unwrap();
89            let region_id = RegionId::new(table_id, region_number);
90            let partition_expr_version = partition_versions
91                .get(&region_number)
92                .copied()
93                .unwrap_or_default();
94            let datanode = self
95                .partition_manager
96                .find_region_leader(region_id)
97                .await
98                .context(error::FindRegionLeaderSnafu)?;
99
100            let request = RegionRequest {
101                header: Some(RegionRequestHeader {
102                    tracing_context: TracingContext::from_current_span().to_w3c(),
103                    ..Default::default()
104                }),
105                body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
106                    region_id: region_id.as_u64(),
107                    partition_expr_version: partition_expr_version
108                        .map(|value| PartitionExprVersion { value }),
109                    aligned_schema_version: None,
110                    body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
111                        schema: schema_bytes.clone(),
112                        data_header: raw_flight_data.data_header,
113                        payload: raw_flight_data.data_body,
114                    })),
115                })),
116            };
117
118            let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
119                .with_label_values(&["datanode_handle"])
120                .start_timer();
121            let datanode = self.node_manager.datanode(&datanode).await;
122            let result = datanode
123                .handle(request)
124                .await
125                .context(error::RequestRegionSnafu)
126                .map(|r| r.affected_rows);
127            if let Ok(rows) = result {
128                crate::metrics::DIST_INGEST_ROW_COUNT
129                    .with_label_values(&[db_name.as_str()])
130                    .inc_by(rows as u64);
131            }
132            return result;
133        }
134
135        let mut mask_per_datanode = HashMap::with_capacity(region_masks.len());
136        for (region_number, mask) in region_masks {
137            let region_id = RegionId::new(table_id, region_number);
138            let datanode = self
139                .partition_manager
140                .find_region_leader(region_id)
141                .await
142                .context(error::FindRegionLeaderSnafu)?;
143            mask_per_datanode
144                .entry(datanode)
145                .or_insert_with(Vec::new)
146                .push((region_id, mask));
147        }
148
149        let wait_all_datanode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
150            .with_label_values(&["wait_all_datanode"])
151            .start_timer();
152
153        let mut handles = Vec::with_capacity(mask_per_datanode.len());
154
155        for (peer, masks) in mask_per_datanode {
156            for (region_id, mask) in masks {
157                if mask.select_none() {
158                    continue;
159                }
160                let partition_expr_version = partition_versions
161                    .get(&region_id.region_number())
162                    .copied()
163                    .unwrap_or_default();
164                let rb = record_batch.clone();
165                let schema_bytes = schema_bytes.clone();
166                let node_manager = self.node_manager.clone();
167                let peer = peer.clone();
168                let raw_header_and_data = if mask.select_all() {
169                    Some((
170                        raw_flight_data.data_header.clone(),
171                        raw_flight_data.data_body.clone(),
172                    ))
173                } else {
174                    None
175                };
176                let handle: common_runtime::JoinHandle<error::Result<api::region::RegionResponse>> =
177                    common_runtime::spawn_global(async move {
178                        let (header, payload) = if mask.select_all() {
179                            // SAFETY: raw data must be present, we can avoid re-encoding.
180                            raw_header_and_data.unwrap()
181                        } else {
182                            let filter_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
183                                .with_label_values(&["filter"])
184                                .start_timer();
185                            let batch = arrow::compute::filter_record_batch(&rb, mask.array())
186                                .context(error::ComputeArrowSnafu)?;
187                            filter_timer.observe_duration();
188                            metrics::BULK_REQUEST_ROWS
189                                .with_label_values(&["rows_per_region"])
190                                .observe(batch.num_rows() as f64);
191
192                            let encode_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
193                                .with_label_values(&["encode"])
194                                .start_timer();
195                            let mut iter = FlightEncoder::default()
196                                .encode(FlightMessage::RecordBatch(batch))
197                                .into_iter();
198                            let Some(flight_data) = iter.next() else {
199                                // Safety: `iter` on a type of `Vec1`, which is guaranteed to have
200                                // at least one element.
201                                unreachable!()
202                            };
203                            ensure!(
204                                iter.next().is_none(),
205                                error::NotSupportedSnafu {
206                                    feat: "bulk insert RecordBatch with dictionary arrays",
207                                }
208                            );
209                            encode_timer.observe_duration();
210                            (flight_data.data_header, flight_data.data_body)
211                        };
212                        let _datanode_handle_timer = metrics::HANDLE_BULK_INSERT_ELAPSED
213                            .with_label_values(&["datanode_handle"])
214                            .start_timer();
215                        let request = RegionRequest {
216                            header: Some(RegionRequestHeader {
217                                tracing_context: TracingContext::from_current_span().to_w3c(),
218                                ..Default::default()
219                            }),
220                            body: Some(region_request::Body::BulkInsert(BulkInsertRequest {
221                                region_id: region_id.as_u64(),
222                                partition_expr_version: partition_expr_version
223                                    .map(|value| PartitionExprVersion { value }),
224                                aligned_schema_version: None,
225                                body: Some(bulk_insert_request::Body::ArrowIpc(ArrowIpc {
226                                    schema: schema_bytes,
227                                    data_header: header,
228                                    payload,
229                                })),
230                            })),
231                        };
232
233                        let datanode = node_manager.datanode(&peer).await;
234                        datanode
235                            .handle(request)
236                            .await
237                            .context(error::RequestRegionSnafu)
238                    });
239                handles.push(handle);
240            }
241        }
242
243        let region_responses = futures::future::try_join_all(handles)
244            .await
245            .context(error::JoinTaskSnafu)?;
246        wait_all_datanode_timer.observe_duration();
247        let mut rows_inserted: usize = 0;
248        for res in region_responses {
249            rows_inserted += res?.affected_rows;
250        }
251        crate::metrics::DIST_INGEST_ROW_COUNT
252            .with_label_values(&[db_name.as_str()])
253            .inc_by(rows_inserted as u64);
254        Ok(rows_inserted)
255    }
256
257    fn maybe_update_flow_dirty_window(&self, table_info: TableInfoRef, record_batch: RecordBatch) {
258        let table_id = table_info.table_id();
259        let table_flownode_set_cache = self.table_flownode_set_cache.clone();
260        let node_manager = self.node_manager.clone();
261        common_runtime::spawn_global(async move {
262            let result = table_flownode_set_cache
263                .get(table_id)
264                .await
265                .context(error::RequestInsertsSnafu);
266            let flownodes = match result {
267                Ok(flownodes) => flownodes.unwrap_or_default(),
268                Err(e) => {
269                    error!(e; "Failed to get flownodes for table id: {}", table_id);
270                    return;
271                }
272            };
273
274            let peers: HashSet<_> = flownodes.values().cloned().collect();
275            if peers.is_empty() {
276                return;
277            }
278
279            let Ok(timestamps) = extract_timestamps(
280                &record_batch,
281                &table_info
282                    .meta
283                    .schema
284                    .timestamp_column()
285                    .as_ref()
286                    .unwrap()
287                    .name,
288            )
289            .inspect_err(|e| {
290                error!(e; "Failed to extract timestamps from record batch");
291            }) else {
292                return;
293            };
294
295            for peer in peers {
296                let node_manager = node_manager.clone();
297                let timestamps = timestamps.clone();
298                common_runtime::spawn_global(async move {
299                    if let Err(e) = node_manager
300                        .flownode(&peer)
301                        .await
302                        .handle_mark_window_dirty(DirtyWindowRequests {
303                            requests: vec![DirtyWindowRequest {
304                                table_id,
305                                timestamps,
306                            }],
307                        })
308                        .await
309                        .context(error::RequestInsertsSnafu)
310                    {
311                        error!(e; "Failed to mark timestamps as dirty, table: {}", table_id);
312                    }
313                });
314            }
315        });
316    }
317}
318
319/// Calculate the timestamp range of record batch. Return `None` if record batch is empty.
320fn extract_timestamps(rb: &RecordBatch, timestamp_index_name: &str) -> error::Result<Vec<i64>> {
321    let ts_col = rb
322        .column_by_name(timestamp_index_name)
323        .context(error::ColumnNotFoundSnafu {
324            msg: timestamp_index_name,
325        })?;
326    if rb.num_rows() == 0 {
327        return Ok(vec![]);
328    }
329    let (primitive, _) =
330        datatypes::timestamp::timestamp_array_to_primitive(ts_col).with_context(|| {
331            error::InvalidTimeIndexTypeSnafu {
332                ty: ts_col.data_type().clone(),
333            }
334        })?;
335    Ok(primitive.iter().flatten().collect())
336}