Skip to main content

client/
region.rs

// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::sync::Arc;
16
17use api::region::RegionResponse;
18use api::v1::ResponseHeader;
19use api::v1::region::{
20    RegionRequest, RegionRequestHeader, RemoteDynFilterRequest, RemoteDynFilterUnregister,
21    RemoteDynFilterUpdate, region_request, remote_dyn_filter_request,
22};
23use arc_swap::ArcSwapOption;
24use arrow_flight::Ticket;
25use async_stream::stream;
26use async_trait::async_trait;
27use common_error::ext::BoxedError;
28use common_error::status_code::StatusCode;
29use common_grpc::flight::{FlightDecoder, FlightMessage};
30use common_meta::error::{self as meta_error, Result as MetaResult};
31use common_meta::node_manager::Datanode;
32use common_query::request::QueryRequest;
33use common_recordbatch::error::ExternalSnafu;
34use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper, SendableRecordBatchStream};
35use common_telemetry::error;
36use common_telemetry::tracing::Span;
37use common_telemetry::tracing_context::TracingContext;
38use prost::Message;
39use query::query_engine::DefaultSerializer;
40use snafu::{OptionExt, ResultExt, location};
41use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
42use tokio_stream::StreamExt;
43
44use crate::error::{
45    self, ConvertFlightDataSnafu, FlightGetSnafu, IllegalDatabaseResponseSnafu,
46    IllegalFlightMessagesSnafu, MissingFieldSnafu, Result, ServerSnafu,
47};
48use crate::{Client, Error, metrics};
49
50#[derive(Debug)]
51pub struct RegionRequester {
52    client: Client,
53    send_compression: bool,
54    accept_compression: bool,
55}
56
57#[async_trait]
58impl Datanode for RegionRequester {
59    async fn handle(&self, request: RegionRequest) -> MetaResult<RegionResponse> {
60        self.handle_inner(request).await.map_err(|err| {
61            if err.should_retry() {
62                meta_error::Error::RetryLater {
63                    source: BoxedError::new(err),
64                    clean_poisons: false,
65                }
66            } else {
67                meta_error::Error::External {
68                    source: BoxedError::new(err),
69                    location: location!(),
70                }
71            }
72        })
73    }
74
75    async fn handle_query(&self, request: QueryRequest) -> MetaResult<SendableRecordBatchStream> {
76        let plan = DFLogicalSubstraitConvertor
77            .encode(&request.plan, DefaultSerializer)
78            .map_err(BoxedError::new)
79            .context(meta_error::ExternalSnafu)?
80            .to_vec();
81        let request = api::v1::region::QueryRequest {
82            header: request.header,
83            region_id: request.region_id.as_u64(),
84            plan,
85        };
86
87        let ticket = Ticket {
88            ticket: request.encode_to_vec().into(),
89        };
90        self.do_get_inner(ticket)
91            .await
92            .map_err(BoxedError::new)
93            .context(meta_error::ExternalSnafu)
94    }
95}
96
97impl RegionRequester {
98    pub fn new(client: Client, send_compression: bool, accept_compression: bool) -> Self {
99        Self {
100            client,
101            send_compression,
102            accept_compression,
103        }
104    }
105
106    pub async fn do_get_inner(&self, ticket: Ticket) -> Result<SendableRecordBatchStream> {
107        let mut flight_client = self
108            .client
109            .make_flight_client(self.send_compression, self.accept_compression)?;
110        let response = flight_client
111            .mut_inner()
112            .do_get(ticket)
113            .await
114            .or_else(|e| {
115                let tonic_code = e.code();
116                let e: error::Error = e.into();
117                error!(
118                    e; "Failed to do Flight get, addr: {}, code: {}",
119                    flight_client.addr(),
120                    tonic_code
121                );
122                Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
123                    addr: flight_client.addr().to_string(),
124                    tonic_code,
125                })
126            })?;
127
128        let flight_data_stream = response.into_inner();
129        let mut decoder = FlightDecoder::default();
130
131        let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
132            flight_data
133                .map_err(Error::from)
134                .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
135                .context(IllegalFlightMessagesSnafu {
136                    reason: "none message",
137                })
138        });
139
140        let Some(first_flight_message) = flight_message_stream.next().await else {
141            return IllegalFlightMessagesSnafu {
142                reason: "Expect the response not to be empty",
143            }
144            .fail();
145        };
146        let FlightMessage::Schema(schema) = first_flight_message? else {
147            return IllegalFlightMessagesSnafu {
148                reason: "Expect schema to be the first flight message",
149            }
150            .fail();
151        };
152
153        let metrics = Arc::new(ArcSwapOption::from(None));
154        let metrics_ref = metrics.clone();
155
156        let tracing_context = TracingContext::from_current_span();
157
158        let schema = Arc::new(
159            datatypes::schema::Schema::try_from(schema).context(error::ConvertSchemaSnafu)?,
160        );
161        let schema_cloned = schema.clone();
162        let stream = Box::pin(stream!({
163            let _span = tracing_context.attach(common_telemetry::tracing::info_span!(
164                "poll_flight_data_stream"
165            ));
166
167            let mut buffered_message: Option<FlightMessage> = None;
168            let mut stream_ended = false;
169
170            while !stream_ended {
171                // get the next message from the buffered message or read from the flight message stream
172                let flight_message_item = if let Some(msg) = buffered_message.take() {
173                    Some(Ok(msg))
174                } else {
175                    flight_message_stream.next().await
176                };
177
178                let flight_message = match flight_message_item {
179                    Some(Ok(message)) => message,
180                    Some(Err(e)) => {
181                        yield Err(BoxedError::new(e)).context(ExternalSnafu);
182                        break;
183                    }
184                    None => break,
185                };
186
187                match flight_message {
188                    FlightMessage::RecordBatch(record_batch) => {
189                        let result_to_yield =
190                            RecordBatch::from_df_record_batch(schema_cloned.clone(), record_batch);
191
192                        // get the next message from the stream. normally it should be a metrics message.
193                        if let Some(next_flight_message_result) = flight_message_stream.next().await
194                        {
195                            match next_flight_message_result {
196                                Ok(FlightMessage::Metrics(s)) => {
197                                    let m = serde_json::from_str(&s).ok().map(Arc::new);
198                                    metrics_ref.swap(m);
199                                }
200                                Ok(FlightMessage::RecordBatch(rb)) => {
201                                    // for some reason it's not a metrics message, so we need to buffer this record batch
202                                    // and yield it in the next iteration.
203                                    buffered_message = Some(FlightMessage::RecordBatch(rb));
204                                }
205                                Ok(_) => {
206                                    yield IllegalFlightMessagesSnafu {
207                                        reason: "A RecordBatch message can only be succeeded by a Metrics message or another RecordBatch message"
208                                    }
209                                    .fail()
210                                    .map_err(BoxedError::new)
211                                    .context(ExternalSnafu);
212                                    break;
213                                }
214                                Err(e) => {
215                                    yield Err(BoxedError::new(e)).context(ExternalSnafu);
216                                    break;
217                                }
218                            }
219                        } else {
220                            // the stream has ended
221                            stream_ended = true;
222                        }
223
224                        yield Ok(result_to_yield);
225                    }
226                    FlightMessage::Metrics(s) => {
227                        // just a branch in case of some metrics message comes after other things.
228                        let m = serde_json::from_str(&s).ok().map(Arc::new);
229                        metrics_ref.swap(m);
230                        break;
231                    }
232                    _ => {
233                        yield IllegalFlightMessagesSnafu {
234                            reason: "A Schema message must be succeeded exclusively by a set of RecordBatch messages"
235                        }
236                        .fail()
237                        .map_err(BoxedError::new)
238                        .context(ExternalSnafu);
239                        break;
240                    }
241                }
242            }
243        }));
244        let record_batch_stream = RecordBatchStreamWrapper {
245            schema,
246            stream,
247            output_ordering: None,
248            metrics,
249            span: Span::current(),
250        };
251        Ok(Box::pin(record_batch_stream))
252    }
253
254    async fn handle_inner(&self, request: RegionRequest) -> Result<RegionResponse> {
255        let request_type = request
256            .body
257            .as_ref()
258            .with_context(|| MissingFieldSnafu { field: "body" })?
259            .as_ref()
260            .to_string();
261        let _timer = metrics::METRIC_REGION_REQUEST_GRPC
262            .with_label_values(&[request_type.as_str()])
263            .start_timer();
264
265        let (addr, mut client) = self.client.raw_region_client()?;
266
267        let response = client
268            .handle(request)
269            .await
270            .map_err(|e| {
271                let code = e.code();
272                // Uses `Error::RegionServer` instead of `Error::Server`
273                error::Error::RegionServer {
274                    addr,
275                    code,
276                    source: BoxedError::new(error::Error::from(e)),
277                    location: location!(),
278                }
279            })?
280            .into_inner();
281
282        check_response_header(&response.header)?;
283
284        Ok(RegionResponse::from_region_response(response))
285    }
286
287    pub async fn handle(&self, request: RegionRequest) -> Result<RegionResponse> {
288        self.handle_inner(request).await
289    }
290
291    pub async fn handle_remote_dyn_filter_update(
292        &self,
293        query_id: impl Into<String>,
294        update: RemoteDynFilterUpdate,
295    ) -> Result<RegionResponse> {
296        self.handle_inner(build_remote_dyn_filter_request(
297            query_id.into(),
298            remote_dyn_filter_request::Action::Update(update),
299        ))
300        .await
301    }
302
303    pub async fn handle_remote_dyn_filter_unregister(
304        &self,
305        query_id: impl Into<String>,
306        unregister: RemoteDynFilterUnregister,
307    ) -> Result<RegionResponse> {
308        self.handle_inner(build_remote_dyn_filter_request(
309            query_id.into(),
310            remote_dyn_filter_request::Action::Unregister(unregister),
311        ))
312        .await
313    }
314}
315
316fn build_remote_dyn_filter_request(
317    query_id: String,
318    action: remote_dyn_filter_request::Action,
319) -> RegionRequest {
320    RegionRequest {
321        header: Some(RegionRequestHeader {
322            tracing_context: TracingContext::from_current_span().to_w3c(),
323            ..Default::default()
324        }),
325        body: Some(region_request::Body::RemoteDynFilter(
326            RemoteDynFilterRequest {
327                query_id,
328                action: Some(action),
329            },
330        )),
331    }
332}
333
334pub fn check_response_header(header: &Option<ResponseHeader>) -> Result<()> {
335    let status = header
336        .as_ref()
337        .and_then(|header| header.status.as_ref())
338        .context(IllegalDatabaseResponseSnafu {
339            err_msg: "either response header or status is missing",
340        })?;
341
342    if StatusCode::is_success(status.status_code) {
343        Ok(())
344    } else {
345        let code =
346            StatusCode::from_u32(status.status_code).context(IllegalDatabaseResponseSnafu {
347                err_msg: format!("unknown server status: {:?}", status),
348            })?;
349        ServerSnafu {
350            code,
351            msg: status.err_msg.clone(),
352        }
353        .fail()
354    }
355}
356
#[cfg(test)]
mod test {
    use api::v1::Status as PbStatus;
    use api::v1::region::{RemoteDynFilterUpdate, region_request, remote_dyn_filter_request};

    use super::*;
    use crate::Error::{IllegalDatabaseResponse, Server};

    /// Builds a present header whose status holds the given code and message.
    fn header_with_status(status_code: u32, err_msg: &str) -> Option<ResponseHeader> {
        Some(ResponseHeader {
            status: Some(PbStatus {
                status_code,
                err_msg: err_msg.to_string(),
            }),
        })
    }

    #[test]
    fn test_check_response_header() {
        // A missing header and a header without status are both rejected.
        for incomplete in [None, Some(ResponseHeader { status: None })] {
            let err = check_response_header(&incomplete).unwrap_err();
            assert!(matches!(err, IllegalDatabaseResponse { .. }));
        }

        // A success status passes.
        let success = header_with_status(StatusCode::Success as u32, "");
        assert!(check_response_header(&success).is_ok());

        // A status code that maps to no known `StatusCode` is an illegal response.
        let unknown = header_with_status(u32::MAX, "");
        let err = check_response_header(&unknown).unwrap_err();
        assert!(matches!(err, IllegalDatabaseResponse { .. }));

        // A known failure code surfaces as a server error with the original message.
        let internal = header_with_status(StatusCode::Internal as u32, "blabla");
        let Server { code, msg, .. } = check_response_header(&internal).unwrap_err() else {
            unreachable!()
        };
        assert_eq!(code, StatusCode::Internal);
        assert_eq!(msg, "blabla");
    }

    #[test]
    fn test_build_remote_dyn_filter_request_sets_header_and_body() {
        let update = RemoteDynFilterUpdate {
            filter_id: "filter-1".to_string(),
            payload: vec![1, 2, 3],
            generation: 7,
            is_complete: false,
        };
        let request = build_remote_dyn_filter_request(
            "query-1".to_string(),
            remote_dyn_filter_request::Action::Update(update),
        );

        request.header.expect("remote dyn filter header must exist");

        let body = request.body.expect("remote dyn filter body must exist");
        let region_request::Body::RemoteDynFilter(remote_request) = body else {
            panic!("expected remote dyn filter request body");
        };

        assert_eq!(remote_request.query_id, "query-1");
        assert!(matches!(
            remote_request.action,
            Some(remote_dyn_filter_request::Action::Update(_))
        ));
    }
}