1use std::sync::Arc;
16
17use api::region::RegionResponse;
18use api::v1::ResponseHeader;
19use api::v1::region::{
20 RegionRequest, RegionRequestHeader, RemoteDynFilterRequest, RemoteDynFilterUnregister,
21 RemoteDynFilterUpdate, region_request, remote_dyn_filter_request,
22};
23use arc_swap::ArcSwapOption;
24use arrow_flight::Ticket;
25use async_stream::stream;
26use async_trait::async_trait;
27use common_error::ext::BoxedError;
28use common_error::status_code::StatusCode;
29use common_grpc::flight::{FlightDecoder, FlightMessage};
30use common_meta::error::{self as meta_error, Result as MetaResult};
31use common_meta::node_manager::Datanode;
32use common_query::request::QueryRequest;
33use common_recordbatch::error::ExternalSnafu;
34use common_recordbatch::{RecordBatch, RecordBatchStreamWrapper, SendableRecordBatchStream};
35use common_telemetry::error;
36use common_telemetry::tracing::Span;
37use common_telemetry::tracing_context::TracingContext;
38use prost::Message;
39use query::query_engine::DefaultSerializer;
40use snafu::{OptionExt, ResultExt, location};
41use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
42use tokio_stream::StreamExt;
43
44use crate::error::{
45 self, ConvertFlightDataSnafu, FlightGetSnafu, IllegalDatabaseResponseSnafu,
46 IllegalFlightMessagesSnafu, MissingFieldSnafu, Result, ServerSnafu,
47};
48use crate::{Client, Error, metrics};
49
/// Sends region requests (unary gRPC) and region queries (Arrow Flight `DoGet`)
/// through the wrapped [`Client`].
#[derive(Debug)]
pub struct RegionRequester {
    /// Underlying connection used for both the raw region gRPC client and the Flight client.
    client: Client,
    /// Passed as the first flag to `Client::make_flight_client`; presumably enables
    /// compression of outgoing Flight data — TODO confirm against `Client`.
    send_compression: bool,
    /// Passed as the second flag to `Client::make_flight_client`; presumably allows the
    /// server to reply with compressed Flight data — TODO confirm against `Client`.
    accept_compression: bool,
}
56
57#[async_trait]
58impl Datanode for RegionRequester {
59 async fn handle(&self, request: RegionRequest) -> MetaResult<RegionResponse> {
60 self.handle_inner(request).await.map_err(|err| {
61 if err.should_retry() {
62 meta_error::Error::RetryLater {
63 source: BoxedError::new(err),
64 clean_poisons: false,
65 }
66 } else {
67 meta_error::Error::External {
68 source: BoxedError::new(err),
69 location: location!(),
70 }
71 }
72 })
73 }
74
75 async fn handle_query(&self, request: QueryRequest) -> MetaResult<SendableRecordBatchStream> {
76 let plan = DFLogicalSubstraitConvertor
77 .encode(&request.plan, DefaultSerializer)
78 .map_err(BoxedError::new)
79 .context(meta_error::ExternalSnafu)?
80 .to_vec();
81 let request = api::v1::region::QueryRequest {
82 header: request.header,
83 region_id: request.region_id.as_u64(),
84 plan,
85 };
86
87 let ticket = Ticket {
88 ticket: request.encode_to_vec().into(),
89 };
90 self.do_get_inner(ticket)
91 .await
92 .map_err(BoxedError::new)
93 .context(meta_error::ExternalSnafu)
94 }
95}
96
impl RegionRequester {
    /// Creates a requester over `client` with the given Flight compression flags.
    pub fn new(client: Client, send_compression: bool, accept_compression: bool) -> Self {
        Self {
            client,
            send_compression,
            accept_compression,
        }
    }

    /// Issues a Flight `DoGet` with `ticket` and adapts the response into a
    /// [`SendableRecordBatchStream`].
    ///
    /// The first Flight message is awaited eagerly and must be a schema, so the
    /// returned stream already knows its schema. The remaining messages are
    /// consumed lazily by the returned stream: record batches are yielded, and
    /// metrics messages (JSON text) are parsed into the stream's shared metrics slot.
    ///
    /// # Errors
    ///
    /// - Errors from `Client::make_flight_client`.
    /// - `FlightGet` (with address and tonic code) if the `DoGet` call fails; the
    ///   failure is also logged.
    /// - `IllegalFlightMessages` if the response is empty, the first message is not
    ///   a schema, or decoding produced no message.
    /// - Transport or `ConvertFlightData` errors for the first message.
    /// - `ConvertSchema` if the Arrow schema cannot be converted.
    ///
    /// Errors met after the schema are yielded as stream items (wrapped as
    /// `External`), after which the stream ends.
    pub async fn do_get_inner(&self, ticket: Ticket) -> Result<SendableRecordBatchStream> {
        let mut flight_client = self
            .client
            .make_flight_client(self.send_compression, self.accept_compression)?;
        let response = flight_client
            .mut_inner()
            .do_get(ticket)
            .await
            .or_else(|e| {
                let tonic_code = e.code();
                let e: error::Error = e.into();
                error!(
                    e; "Failed to do Flight get, addr: {}, code: {}",
                    flight_client.addr(),
                    tonic_code
                );
                Err(BoxedError::new(e)).with_context(|_| FlightGetSnafu {
                    addr: flight_client.addr().to_string(),
                    tonic_code,
                })
            })?;

        let flight_data_stream = response.into_inner();
        let mut decoder = FlightDecoder::default();

        // Decode every raw Flight data item. Transport errors are converted into
        // `Error`, and a decode that yields no message is treated as illegal.
        let mut flight_message_stream = flight_data_stream.map(move |flight_data| {
            flight_data
                .map_err(Error::from)
                .and_then(|data| decoder.try_decode(&data).context(ConvertFlightDataSnafu))?
                .context(IllegalFlightMessagesSnafu {
                    reason: "none message",
                })
        });

        let Some(first_flight_message) = flight_message_stream.next().await else {
            return IllegalFlightMessagesSnafu {
                reason: "Expect the response not to be empty",
            }
            .fail();
        };
        let FlightMessage::Schema(schema) = first_flight_message? else {
            return IllegalFlightMessagesSnafu {
                reason: "Expect schema to be the first flight message",
            }
            .fail();
        };

        // Shared metrics slot: the stream below writes into it, and the wrapper
        // returned to the caller exposes it.
        let metrics = Arc::new(ArcSwapOption::from(None));
        let metrics_ref = metrics.clone();

        // Captured now so that polling the stream later is attached to the caller's trace.
        let tracing_context = TracingContext::from_current_span();

        let schema = Arc::new(
            datatypes::schema::Schema::try_from(schema).context(error::ConvertSchemaSnafu)?,
        );
        let schema_cloned = schema.clone();
        let stream = Box::pin(stream!({
            let _span = tracing_context.attach(common_telemetry::tracing::info_span!(
                "poll_flight_data_stream"
            ));

            // A record batch that was read ahead while looking for a trailing
            // metrics message; it is handled before pulling from the stream again.
            let mut buffered_message: Option<FlightMessage> = None;
            let mut stream_ended = false;

            while !stream_ended {
                let flight_message_item = if let Some(msg) = buffered_message.take() {
                    Some(Ok(msg))
                } else {
                    flight_message_stream.next().await
                };

                let flight_message = match flight_message_item {
                    Some(Ok(message)) => message,
                    Some(Err(e)) => {
                        yield Err(BoxedError::new(e)).context(ExternalSnafu);
                        break;
                    }
                    None => break,
                };

                match flight_message {
                    FlightMessage::RecordBatch(record_batch) => {
                        let result_to_yield =
                            RecordBatch::from_df_record_batch(schema_cloned.clone(), record_batch);

                        // Read one message ahead before yielding this batch: a
                        // trailing Metrics message updates the metrics slot, another
                        // RecordBatch is buffered for the next iteration, and anything
                        // else ends the stream with an error.
                        // NOTE: on the error arms below, `result_to_yield` is dropped
                        // without being yielded.
                        if let Some(next_flight_message_result) = flight_message_stream.next().await
                        {
                            match next_flight_message_result {
                                Ok(FlightMessage::Metrics(s)) => {
                                    // A JSON parse failure yields `None`, which clears
                                    // any previously stored metrics.
                                    let m = serde_json::from_str(&s).ok().map(Arc::new);
                                    metrics_ref.swap(m);
                                }
                                Ok(FlightMessage::RecordBatch(rb)) => {
                                    buffered_message = Some(FlightMessage::RecordBatch(rb));
                                }
                                Ok(_) => {
                                    yield IllegalFlightMessagesSnafu {
                                        reason: "A RecordBatch message can only be succeeded by a Metrics message or another RecordBatch message"
                                    }
                                    .fail()
                                    .map_err(BoxedError::new)
                                    .context(ExternalSnafu);
                                    break;
                                }
                                Err(e) => {
                                    yield Err(BoxedError::new(e)).context(ExternalSnafu);
                                    break;
                                }
                            }
                        } else {
                            // The underlying stream is exhausted; this is the last batch.
                            stream_ended = true;
                        }

                        yield Ok(result_to_yield);
                    }
                    FlightMessage::Metrics(s) => {
                        // A Metrics message that is not directly after a batch is stored
                        // and terminates the stream; later messages are not read.
                        let m = serde_json::from_str(&s).ok().map(Arc::new);
                        metrics_ref.swap(m);
                        break;
                    }
                    _ => {
                        yield IllegalFlightMessagesSnafu {
                            reason: "A Schema message must be succeeded exclusively by a set of RecordBatch messages"
                        }
                        .fail()
                        .map_err(BoxedError::new)
                        .context(ExternalSnafu);
                        break;
                    }
                }
            }
        }));
        let record_batch_stream = RecordBatchStreamWrapper {
            schema,
            stream,
            output_ordering: None,
            metrics,
            span: Span::current(),
        };
        Ok(Box::pin(record_batch_stream))
    }

    /// Sends `request` through the raw region gRPC client and validates the response header.
    ///
    /// The call is timed with `METRIC_REGION_REQUEST_GRPC`, labelled with the string
    /// form of the request body (presumably its variant name — confirm).
    ///
    /// # Errors
    ///
    /// - `MissingField` if `request.body` is `None`.
    /// - Errors from `Client::raw_region_client`.
    /// - `RegionServer` (with target address and gRPC code) if the call fails.
    /// - Errors from [`check_response_header`] for a missing, unknown or failed status.
    async fn handle_inner(&self, request: RegionRequest) -> Result<RegionResponse> {
        let request_type = request
            .body
            .as_ref()
            .with_context(|| MissingFieldSnafu { field: "body" })?
            .as_ref()
            .to_string();
        let _timer = metrics::METRIC_REGION_REQUEST_GRPC
            .with_label_values(&[request_type.as_str()])
            .start_timer();

        let (addr, mut client) = self.client.raw_region_client()?;

        let response = client
            .handle(request)
            .await
            .map_err(|e| {
                let code = e.code();
                error::Error::RegionServer {
                    addr,
                    code,
                    source: BoxedError::new(error::Error::from(e)),
                    location: location!(),
                }
            })?
            .into_inner();

        check_response_header(&response.header)?;

        Ok(RegionResponse::from_region_response(response))
    }

    /// Public entry point for region requests; identical to `handle_inner` and returning
    /// this crate's error type rather than the metadata error type.
    pub async fn handle(&self, request: RegionRequest) -> Result<RegionResponse> {
        self.handle_inner(request).await
    }

    /// Sends a remote dynamic filter `Update` action for `query_id`.
    pub async fn handle_remote_dyn_filter_update(
        &self,
        query_id: impl Into<String>,
        update: RemoteDynFilterUpdate,
    ) -> Result<RegionResponse> {
        self.handle_inner(build_remote_dyn_filter_request(
            query_id.into(),
            remote_dyn_filter_request::Action::Update(update),
        ))
        .await
    }

    /// Sends a remote dynamic filter `Unregister` action for `query_id`.
    pub async fn handle_remote_dyn_filter_unregister(
        &self,
        query_id: impl Into<String>,
        unregister: RemoteDynFilterUnregister,
    ) -> Result<RegionResponse> {
        self.handle_inner(build_remote_dyn_filter_request(
            query_id.into(),
            remote_dyn_filter_request::Action::Unregister(unregister),
        ))
        .await
    }
}
315
316fn build_remote_dyn_filter_request(
317 query_id: String,
318 action: remote_dyn_filter_request::Action,
319) -> RegionRequest {
320 RegionRequest {
321 header: Some(RegionRequestHeader {
322 tracing_context: TracingContext::from_current_span().to_w3c(),
323 ..Default::default()
324 }),
325 body: Some(region_request::Body::RemoteDynFilter(
326 RemoteDynFilterRequest {
327 query_id,
328 action: Some(action),
329 },
330 )),
331 }
332}
333
334pub fn check_response_header(header: &Option<ResponseHeader>) -> Result<()> {
335 let status = header
336 .as_ref()
337 .and_then(|header| header.status.as_ref())
338 .context(IllegalDatabaseResponseSnafu {
339 err_msg: "either response header or status is missing",
340 })?;
341
342 if StatusCode::is_success(status.status_code) {
343 Ok(())
344 } else {
345 let code =
346 StatusCode::from_u32(status.status_code).context(IllegalDatabaseResponseSnafu {
347 err_msg: format!("unknown server status: {:?}", status),
348 })?;
349 ServerSnafu {
350 code,
351 msg: status.err_msg.clone(),
352 }
353 .fail()
354 }
355}
356
#[cfg(test)]
mod test {
    use api::v1::Status as PbStatus;
    use api::v1::region::{RemoteDynFilterUpdate, region_request, remote_dyn_filter_request};

    use super::*;
    use crate::Error::{IllegalDatabaseResponse, Server};

    /// Builds a response header whose status has the given code and message.
    fn header_with_status(status_code: u32, err_msg: &str) -> Option<ResponseHeader> {
        Some(ResponseHeader {
            status: Some(PbStatus {
                status_code,
                err_msg: err_msg.to_string(),
            }),
        })
    }

    #[test]
    fn test_check_response_header() {
        // Missing header and header-without-status are both malformed responses.
        for malformed in [None, Some(ResponseHeader { status: None })] {
            let err = check_response_header(&malformed).unwrap_err();
            assert!(matches!(err, IllegalDatabaseResponse { .. }));
        }

        // A success status passes validation.
        let ok = check_response_header(&header_with_status(StatusCode::Success as u32, ""));
        assert!(ok.is_ok());

        // An undecodable status code is rejected as an illegal response.
        let unknown = check_response_header(&header_with_status(u32::MAX, ""));
        assert!(matches!(
            unknown.unwrap_err(),
            IllegalDatabaseResponse { .. }
        ));

        // A known failure code surfaces as a server error with the original message.
        let failed =
            check_response_header(&header_with_status(StatusCode::Internal as u32, "blabla"));
        let Server { code, msg, .. } = failed.unwrap_err() else {
            unreachable!()
        };
        assert_eq!(code, StatusCode::Internal);
        assert_eq!(msg, "blabla");
    }

    #[test]
    fn test_build_remote_dyn_filter_request_sets_header_and_body() {
        let update = RemoteDynFilterUpdate {
            filter_id: "filter-1".to_string(),
            payload: vec![1, 2, 3],
            generation: 7,
            is_complete: false,
        };
        let request = build_remote_dyn_filter_request(
            "query-1".to_string(),
            remote_dyn_filter_request::Action::Update(update),
        );

        request.header.expect("remote dyn filter header must exist");

        let body = request.body.expect("remote dyn filter body must exist");
        let region_request::Body::RemoteDynFilter(remote_request) = body else {
            panic!("expected remote dyn filter request body");
        };

        assert_eq!(remote_request.query_id, "query-1");
        assert!(matches!(
            remote_request.action,
            Some(remote_dyn_filter_request::Action::Update(_))
        ));
    }
}