// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#[cfg(test)]
mod test {
    //! Integration tests for ingesting record batches through Arrow Flight `DoPut`
    //! against standalone and distributed deployments.

    use std::collections::HashMap;
    use std::net::SocketAddr;
    use std::sync::Arc;

    use api::v1::auth_header::AuthScheme;
    use api::v1::{Basic, ColumnDataType, ColumnDef, CreateTableExpr, SemanticType};
    use arrow_flight::FlightDescriptor;
    use auth::user_provider_from_option;
    use client::{Client, Database};
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
    use common_grpc::flight::do_put::DoPutMetadata;
    use common_grpc::flight::{FlightEncoder, FlightMessage};
    use common_query::OutputData;
    use common_recordbatch::RecordBatch;
    use common_recordbatch::adapter::RegionWatermarkEntry;
    use datatypes::prelude::{ConcreteDataType, ScalarVector, VectorRef};
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{Int32Vector, StringVector, TimestampMillisecondVector};
    use futures_util::StreamExt;
    use itertools::Itertools;
    use servers::grpc::builder::GrpcServerBuilder;
    use servers::grpc::greptime_handler::GreptimeRequestHandler;
    use servers::grpc::{FlightCompression, GrpcServerConfig};
    use servers::server::Server;

    use crate::cluster::GreptimeDbClusterBuilder;
    use crate::grpc::query_and_expect;
    use crate::test_util::{StorageType, setup_grpc_server};
    use crate::tests::test_util::MockInstance;

    /// Reads back every row written to table `foo`, ordered by time index.
    const SELECT_FOO_SQL: &str = "select ts, a, `B` from foo order by ts";

    /// Copies every row of `foo` into `bar`.
    const INSERT_BAR_FROM_FOO_SQL: &str = "insert into bar select ts, a, `B` from foo";

    /// Query hint under which the frontend reports per-region watermarks in the
    /// terminal metrics of a request.
    const RETURN_REGION_SEQ_HINT: &str = "flow.return_region_seq";

    /// Expected result of [`SELECT_FOO_SQL`] after `create_record_batches(1)` has
    /// been written into `foo`.
    const EXPECTED_FOO_ROWS: &str = "\
+-------------------------+----+----+
| ts                      | a  | B  |
+-------------------------+----+----+
| 1970-01-01T00:00:00.001 | -1 | s1 |
| 1970-01-01T00:00:00.002 | -2 | s2 |
| 1970-01-01T00:00:00.003 | -3 | s3 |
| 1970-01-01T00:00:00.004 | -4 | s4 |
| 1970-01-01T00:00:00.005 | -5 | s5 |
| 1970-01-01T00:00:00.006 | -6 | s6 |
| 1970-01-01T00:00:00.007 | -7 | s7 |
| 1970-01-01T00:00:00.008 | -8 | s8 |
| 1970-01-01T00:00:00.009 | -9 | s9 |
+-------------------------+----+----+";

    /// Writes three record batches into a standalone instance via Flight `DoPut`
    /// and reads them back with SQL.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_standalone_flight_do_put() {
        common_telemetry::init_default_ut_logging();

        let (db, server) =
            setup_grpc_server(StorageType::File, "test_standalone_flight_do_put").await;
        let addr = server.bind_addr().unwrap().to_string();

        let client = Client::with_urls(vec![addr]);
        let client = Database::new_with_dbname("greptime-public", client);

        create_table(&client).await;

        let record_batches = create_record_batches(1);
        test_put_record_batches(&client, record_batches).await;

        query_and_expect(db.frontend().as_ref(), SELECT_FOO_SQL, EXPECTED_FOO_ROWS).await;
    }

    /// Writes record batches into a distributed cluster via an authenticated
    /// Flight `DoPut` and reads them back.
    ///
    /// Also checks region-watermark reporting driven by the
    /// [`RETURN_REGION_SEQ_HINT`] hint:
    /// - no watermarks without the hint;
    /// - watermarks for both streaming SELECTs and `INSERT ... SELECT`;
    /// - rejection of a non-boolean hint value.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_distributed_flight_do_put() {
        common_telemetry::init_default_ut_logging();

        let db = GreptimeDbClusterBuilder::new("test_distributed_flight_do_put")
            .await
            .build(false)
            .await;

        // Serve Flight from the cluster frontend behind a static user provider, so
        // the client below must authenticate.
        let runtime = common_runtime::global_runtime().clone();
        let greptime_request_handler = GreptimeRequestHandler::new(
            db.frontend.instance.clone(),
            user_provider_from_option("static_user_provider:cmd:greptime_user=greptime_pwd").ok(),
            Some(runtime.clone()),
            FlightCompression::default(),
        );
        let mut grpc_server = GrpcServerBuilder::new(GrpcServerConfig::default(), runtime)
            .flight_handler(Arc::new(greptime_request_handler))
            .build();
        grpc_server
            .start("127.0.0.1:0".parse::<SocketAddr>().unwrap())
            .await
            .unwrap();
        let addr = grpc_server.bind_addr().unwrap().to_string();

        let client = Client::with_urls(vec![addr]);
        let mut client = Database::new(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, client);
        client.set_auth(AuthScheme::Basic(Basic {
            username: "greptime_user".to_string(),
            password: "greptime_pwd".to_string(),
        }));

        create_table(&client).await;

        let record_batches = create_record_batches(1);
        test_put_record_batches(&client, record_batches).await;

        query_and_expect(db.fe_instance().as_ref(), SELECT_FOO_SQL, EXPECTED_FOO_ROWS).await;

        // Without the hint, the terminal metrics carry no region watermarks.
        let output = client.sql(SELECT_FOO_SQL).await.unwrap();
        let OutputData::Stream(mut stream) = output.data else {
            panic!("expected stream output");
        };
        assert_eq!(drain_rows(&mut stream).await, 9);
        let metrics = stream.metrics().expect("expected terminal metrics");
        assert!(metrics.region_watermarks.is_empty());

        // With the hint, the terminal metrics report the committed sequence of the
        // single region backing `foo`.
        let result = client
            .sql_with_terminal_metrics(SELECT_FOO_SQL, &[(RETURN_REGION_SEQ_HINT, "true")])
            .await
            .unwrap();
        let terminal_metrics = result.metrics.clone();
        let OutputData::Stream(mut stream) = result.output.data else {
            panic!("expected stream output");
        };
        assert_eq!(drain_rows(&mut stream).await, 9);
        assert!(terminal_metrics.is_ready());

        let regions = db.list_all_regions().await;
        assert_eq!(regions.len(), 1);
        let (region_id, region) = regions.into_iter().next().unwrap();
        let foo_watermark = (region_id.as_u64(), region.find_committed_sequence());
        assert_eq!(
            terminal_metrics.region_watermark_map(),
            Some(HashMap::from([foo_watermark]))
        );

        // The same watermark is exposed on the stream's own terminal metrics.
        let output = client
            .sql_with_hint(SELECT_FOO_SQL, &[(RETURN_REGION_SEQ_HINT, "true")])
            .await
            .unwrap();
        let OutputData::Stream(mut stream) = output.data else {
            panic!("expected stream output");
        };
        assert_eq!(drain_rows(&mut stream).await, 9);

        let metrics = stream.metrics().expect("expected terminal metrics");
        assert_eq!(
            metrics.region_watermarks,
            vec![RegionWatermarkEntry {
                region_id: foo_watermark.0,
                watermark: Some(foo_watermark.1),
            }]
        );

        // `INSERT ... SELECT` without the hint reports no watermarks.
        create_table_named(&client, "bar").await;
        let result = client
            .sql_with_terminal_metrics(INSERT_BAR_FROM_FOO_SQL, &[])
            .await
            .unwrap();
        let OutputData::AffectedRows(affected_rows) = result.output.data else {
            panic!("expected affected rows output");
        };
        assert_eq!(affected_rows, 9);
        assert!(result.metrics.is_ready());
        assert!(result.region_watermark_map().is_none());

        // A hint value that is not a boolean is rejected.
        let err = client
            .sql_with_terminal_metrics(
                INSERT_BAR_FROM_FOO_SQL,
                &[(RETURN_REGION_SEQ_HINT, "not-a-bool")],
            )
            .await
            .unwrap_err();
        let err_msg = format!("{err:?}");
        assert!(err_msg.contains("Invalid value for flow.return_region_seq"));

        client.sql("truncate table bar").await.unwrap();

        // With the hint, `INSERT ... SELECT` reports the watermark of the region it
        // read from (`foo`), which has not been written since it was captured above.
        let result = client
            .sql_with_terminal_metrics(
                INSERT_BAR_FROM_FOO_SQL,
                &[(RETURN_REGION_SEQ_HINT, "true")],
            )
            .await
            .unwrap();
        let OutputData::AffectedRows(affected_rows) = result.output.data else {
            panic!("expected affected rows output");
        };
        assert_eq!(affected_rows, 9);
        assert_eq!(
            result.region_watermark_map(),
            Some(HashMap::from([foo_watermark]))
        );
    }

    /// Polls `stream` to completion and returns the total number of rows it
    /// yielded, panicking on the first error.
    ///
    /// The stream is borrowed mutably rather than consumed, so callers can still
    /// read its terminal metrics afterwards.
    async fn drain_rows<S, E>(stream: &mut S) -> usize
    where
        S: futures::Stream<Item = Result<RecordBatch, E>> + Unpin,
        E: std::fmt::Debug,
    {
        let mut rows = 0;
        while let Some(batch) = stream.next().await {
            rows += batch.unwrap().num_rows();
        }
        rows
    }

    /// Streams `record_batches` into table `foo` through Flight `DoPut` and checks
    /// the per-request acknowledgements.
    ///
    /// Message layout:
    /// - The first message carries only the schema, with request id 0, and names
    ///   the target table in its flight descriptor.
    /// - Each record batch follows with request id `i + 1`.
    ///
    /// Every message must be acknowledged in order with its own request id. The
    /// schema message reports 0 affected rows; each batch reports 3, the chunk
    /// size used by [`create_record_batches`].
    ///
    /// # Panics
    ///
    /// Panics if `record_batches` is empty (the schema is taken from the first
    /// batch), or if any acknowledgement is missing, failed or unexpected.
    async fn test_put_record_batches(client: &Database, record_batches: Vec<RecordBatch>) {
        assert!(
            !record_batches.is_empty(),
            "at least one record batch is required to derive the DoPut schema"
        );
        let requests_count = record_batches.len();
        let schema = record_batches[0].schema.arrow_schema().clone();

        let stream = futures::stream::once(async move {
            let mut schema_data = FlightEncoder::default().encode_schema(schema.as_ref());
            let metadata = DoPutMetadata::new(0);
            schema_data.app_metadata = serde_json::to_vec(&metadata).unwrap().into();
            // The first message of a DoPut stream must name the target table in its
            // flight descriptor.
            schema_data.flight_descriptor = Some(FlightDescriptor {
                r#type: arrow_flight::flight_descriptor::DescriptorType::Path as i32,
                path: vec!["foo".to_string()],
                ..Default::default()
            });
            schema_data
        })
        .chain(
            tokio_stream::iter(record_batches)
                .enumerate()
                .flat_map(|(i, batch)| {
                    let mut encoder = FlightEncoder::default();
                    let message = FlightMessage::RecordBatch(batch.into_df_record_batch());
                    let mut data = encoder.encode(message);
                    // Request ids start at 1 because 0 is taken by the schema message.
                    let metadata = DoPutMetadata::new((i + 1) as i64);
                    data.iter_mut().for_each(|flight_data| {
                        flight_data.app_metadata = serde_json::to_vec(&metadata).unwrap().into()
                    });
                    tokio_stream::iter(data)
                })
                .boxed(),
        )
        .boxed();

        let response_stream = client.do_put(stream).await.unwrap();
        let responses = response_stream.collect::<Vec<_>>().await;

        // One acknowledgement for the schema message plus one per record batch.
        // Checked first so a missing acknowledgement fails with a clear count.
        assert_eq!(requests_count + 1, responses.len());

        for (i, response) in responses.into_iter().enumerate() {
            let response =
                response.unwrap_or_else(|e| panic!("DoPut request {i} failed: {e}"));
            assert_eq!(response.request_id(), i as i64);
            let expected_rows = if i == 0 { 0 } else { 3 };
            assert_eq!(response.affected_rows(), expected_rows);
        }
    }

    /// Builds three record batches of three rows each for the timestamps
    /// `start..start + 9`.
    ///
    /// Column values for each timestamp `x`:
    /// - `ts`: `x` milliseconds;
    /// - `a`: `-x` as `i32`;
    /// - `B`: `"s{x}"`.
    fn create_record_batches(start: i64) -> Vec<RecordBatch> {
        let schema = Arc::new(Schema::new(vec![
            ColumnSchema::new(
                "ts",
                ConcreteDataType::timestamp_millisecond_datatype(),
                false,
            )
            .with_time_index(true),
            ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false),
            ColumnSchema::new("B", ConcreteDataType::string_datatype(), true),
        ]));

        let mut record_batches = Vec::with_capacity(3);
        for chunk in &(start..start + 9).chunks(3) {
            let timestamps = chunk.collect_vec();
            let negated = timestamps.iter().map(|&x| -x as i32).collect_vec();
            let labels = timestamps.iter().map(|x| format!("s{x}")).collect_vec();

            record_batches.push(
                RecordBatch::new(
                    schema.clone(),
                    vec![
                        Arc::new(TimestampMillisecondVector::from_vec(timestamps)) as VectorRef,
                        Arc::new(Int32Vector::from_vec(negated)),
                        Arc::new(StringVector::from_vec(labels)),
                    ],
                )
                .unwrap(),
            );
        }
        record_batches
    }

    /// Creates table `foo`, the target of the `DoPut` tests.
    async fn create_table(client: &Database) {
        create_table_named(client, "foo").await;
    }

    /// Creates a `mito` table named `table_name` in schema `public` and asserts
    /// that the DDL reports 0 affected rows.
    ///
    /// Equivalent SQL:
    ///
    /// ```text
    /// create table <table_name> (
    ///   ts timestamp(3) time index,
    ///   a int primary key,
    ///   `B` string null,
    /// )
    /// ```
    async fn create_table_named(client: &Database, table_name: &str) {
        let output = client
            .create(CreateTableExpr {
                schema_name: "public".to_string(),
                table_name: table_name.to_string(),
                column_defs: vec![
                    ColumnDef {
                        name: "ts".to_string(),
                        data_type: ColumnDataType::TimestampMillisecond as i32,
                        semantic_type: SemanticType::Timestamp as i32,
                        is_nullable: false,
                        ..Default::default()
                    },
                    ColumnDef {
                        name: "a".to_string(),
                        data_type: ColumnDataType::Int32 as i32,
                        semantic_type: SemanticType::Tag as i32,
                        is_nullable: false,
                        ..Default::default()
                    },
                    ColumnDef {
                        name: "B".to_string(),
                        data_type: ColumnDataType::String as i32,
                        semantic_type: SemanticType::Field as i32,
                        is_nullable: true,
                        ..Default::default()
                    },
                ],
                time_index: "ts".to_string(),
                primary_keys: vec!["a".to_string()],
                engine: "mito".to_string(),
                ..Default::default()
            })
            .await
            .unwrap();
        let OutputData::AffectedRows(affected_rows) = output.data else {
            panic!("expected affected rows output when creating table {table_name}");
        };
        assert_eq!(affected_rows, 0);
    }
}