#[cfg(test)]
mod test {
    use std::collections::HashMap;
    use std::net::SocketAddr;
    use std::sync::Arc;

    use api::v1::auth_header::AuthScheme;
    use api::v1::{Basic, ColumnDataType, ColumnDef, CreateTableExpr, SemanticType};
    use arrow_flight::FlightDescriptor;
    use auth::user_provider_from_option;
    use client::{Client, Database};
    use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
    use common_grpc::flight::do_put::DoPutMetadata;
    use common_grpc::flight::{FlightEncoder, FlightMessage};
    use common_query::OutputData;
    use common_recordbatch::RecordBatch;
    use common_recordbatch::adapter::RegionWatermarkEntry;
    use datatypes::prelude::{ConcreteDataType, ScalarVector, VectorRef};
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{Int32Vector, StringVector, TimestampMillisecondVector};
    use futures_util::{Stream, StreamExt};
    use itertools::Itertools;
    use servers::grpc::builder::GrpcServerBuilder;
    use servers::grpc::greptime_handler::GreptimeRequestHandler;
    use servers::grpc::{FlightCompression, GrpcServerConfig};
    use servers::server::Server;

    use crate::cluster::GreptimeDbClusterBuilder;
    use crate::grpc::query_and_expect;
    use crate::test_util::{StorageType, setup_grpc_server};
    use crate::tests::test_util::MockInstance;

    /// Query used to read back everything written into `foo`.
    const SELECT_FOO_SQL: &str = "select ts, a, `B` from foo order by ts";

    /// Pretty-printed result of [`SELECT_FOO_SQL`] after the batches produced by
    /// `create_record_batches(1)` have been written into `foo`.
    const EXPECTED_FOO_ROWS: &str = "\
+-------------------------+----+----+
| ts                      | a  | B  |
+-------------------------+----+----+
| 1970-01-01T00:00:00.001 | -1 | s1 |
| 1970-01-01T00:00:00.002 | -2 | s2 |
| 1970-01-01T00:00:00.003 | -3 | s3 |
| 1970-01-01T00:00:00.004 | -4 | s4 |
| 1970-01-01T00:00:00.005 | -5 | s5 |
| 1970-01-01T00:00:00.006 | -6 | s6 |
| 1970-01-01T00:00:00.007 | -7 | s7 |
| 1970-01-01T00:00:00.008 | -8 | s8 |
| 1970-01-01T00:00:00.009 | -9 | s9 |
+-------------------------+----+----+";

    /// Query hint key that the tests below pass to request region watermarks.
    const RETURN_REGION_SEQ_HINT: &str = "flow.return_region_seq";

    #[tokio::test(flavor = "multi_thread")]
    async fn test_standalone_flight_do_put() {
        common_telemetry::init_default_ut_logging();

        let (db, server) =
            setup_grpc_server(StorageType::File, "test_standalone_flight_do_put").await;
        let addr = server.bind_addr().unwrap().to_string();

        let client = Client::with_urls(vec![addr]);
        let client = Database::new_with_dbname("greptime-public", client);

        create_table(&client).await;
        test_put_record_batches(&client, create_record_batches(1)).await;

        query_and_expect(db.frontend().as_ref(), SELECT_FOO_SQL, EXPECTED_FOO_ROWS).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_distributed_flight_do_put() {
        common_telemetry::init_default_ut_logging();

        let db = GreptimeDbClusterBuilder::new("test_distributed_flight_do_put")
            .await
            .build(false)
            .await;

        // Serve Flight on an ephemeral local port in front of the cluster frontend,
        // guarded by a static user provider, so the client below must authenticate.
        let runtime = common_runtime::global_runtime().clone();
        let greptime_request_handler = GreptimeRequestHandler::new(
            db.frontend.instance.clone(),
            user_provider_from_option("static_user_provider:cmd:greptime_user=greptime_pwd").ok(),
            Some(runtime.clone()),
            FlightCompression::default(),
        );
        let mut grpc_server = GrpcServerBuilder::new(GrpcServerConfig::default(), runtime)
            .flight_handler(Arc::new(greptime_request_handler))
            .build();
        grpc_server
            .start("127.0.0.1:0".parse::<SocketAddr>().unwrap())
            .await
            .unwrap();
        let addr = grpc_server.bind_addr().unwrap().to_string();

        let client = Client::with_urls(vec![addr]);
        let mut client = Database::new(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, client);
        client.set_auth(AuthScheme::Basic(Basic {
            username: "greptime_user".to_string(),
            password: "greptime_pwd".to_string(),
        }));

        create_table(&client).await;
        test_put_record_batches(&client, create_record_batches(1)).await;

        query_and_expect(db.fe_instance().as_ref(), SELECT_FOO_SQL, EXPECTED_FOO_ROWS).await;

        // Without the hint, the stream's terminal metrics carry no region watermarks.
        let output = client.sql(SELECT_FOO_SQL).await.unwrap();
        let OutputData::Stream(mut stream) = output.data else {
            panic!("expected stream output");
        };
        assert_eq!(drain_rows(&mut stream).await, 9);
        let metrics = stream.metrics().expect("expected terminal metrics");
        assert!(metrics.region_watermarks.is_empty());

        // With the hint, the terminal metrics map foo's only region to its committed
        // sequence.
        let result = client
            .sql_with_terminal_metrics(SELECT_FOO_SQL, &[(RETURN_REGION_SEQ_HINT, "true")])
            .await
            .unwrap();
        // Take the metrics handle first: the match below moves the stream out of `result`.
        let terminal_metrics = result.metrics.clone();
        let OutputData::Stream(mut stream) = result.output.data else {
            panic!("expected stream output");
        };
        assert_eq!(drain_rows(&mut stream).await, 9);
        assert!(terminal_metrics.is_ready());
        let regions = db.list_all_regions().await;
        assert_eq!(regions.len(), 1);
        let (region_id, region) = regions.into_iter().next().unwrap();
        let foo_watermark = (region_id.as_u64(), region.find_committed_sequence());
        assert_eq!(
            terminal_metrics.region_watermark_map(),
            Some(HashMap::from([foo_watermark]))
        );

        // The plain hinted query exposes the same watermark via the stream's own metrics.
        let output = client
            .sql_with_hint(SELECT_FOO_SQL, &[(RETURN_REGION_SEQ_HINT, "true")])
            .await
            .unwrap();
        let OutputData::Stream(mut stream) = output.data else {
            panic!("expected stream output");
        };
        assert_eq!(drain_rows(&mut stream).await, 9);
        let metrics = stream.metrics().expect("expected terminal metrics");
        assert_eq!(
            metrics.region_watermarks,
            vec![RegionWatermarkEntry {
                region_id: foo_watermark.0,
                watermark: Some(foo_watermark.1),
            }]
        );

        // INSERT ... SELECT without the hint: rows are written, no watermarks returned.
        create_table_named(&client, "bar").await;
        let insert_sql = "insert into bar select ts, a, `B` from foo";
        let result = client
            .sql_with_terminal_metrics(insert_sql, &[])
            .await
            .unwrap();
        let OutputData::AffectedRows(affected_rows) = result.output.data else {
            panic!("expected affected rows output");
        };
        assert_eq!(affected_rows, 9);
        assert!(result.metrics.is_ready());
        assert!(result.region_watermark_map().is_none());

        // A hint value that is not a boolean is rejected.
        let err = client
            .sql_with_terminal_metrics(insert_sql, &[(RETURN_REGION_SEQ_HINT, "not-a-bool")])
            .await
            .unwrap_err();
        let err_msg = format!("{err:?}");
        assert!(
            err_msg.contains("Invalid value for flow.return_region_seq"),
            "unexpected error: {err_msg}"
        );

        // Empty `bar` so the hinted INSERT ... SELECT writes all 9 rows again. Its watermark
        // map is expected to equal the one captured for `foo`, the region it reads from.
        client.sql("truncate table bar").await.unwrap();
        let result = client
            .sql_with_terminal_metrics(insert_sql, &[(RETURN_REGION_SEQ_HINT, "true")])
            .await
            .unwrap();
        let OutputData::AffectedRows(affected_rows) = result.output.data else {
            panic!("expected affected rows output");
        };
        assert_eq!(affected_rows, 9);
        assert_eq!(
            result.region_watermark_map(),
            Some(HashMap::from([foo_watermark]))
        );
    }

    /// Polls `stream` to completion, panicking on the first error item, and returns
    /// the total number of rows yielded.
    ///
    /// Takes the stream by mutable reference so the caller can still read the
    /// stream's terminal metrics afterwards.
    async fn drain_rows<S, E>(stream: &mut S) -> usize
    where
        S: Stream<Item = Result<RecordBatch, E>> + Unpin,
        E: std::fmt::Debug,
    {
        let mut row_count = 0;
        while let Some(batch) = stream.next().await {
            row_count += batch.unwrap().num_rows();
        }
        row_count
    }

    /// Streams `record_batches` into table `foo` through Flight `DoPut` and checks
    /// every acknowledgement.
    ///
    /// The first message carries only the schema and the `foo` path descriptor
    /// (request id 0). Batch `i` follows with request id `i + 1`. Exactly one
    /// response per request is expected, in request order. The schema message must
    /// report 0 affected rows, and each batch message must report that batch's row
    /// count.
    ///
    /// # Panics
    ///
    /// Panics if `record_batches` is empty, if the `DoPut` call fails, or if any
    /// response deviates from the expectations above.
    async fn test_put_record_batches(client: &Database, record_batches: Vec<RecordBatch>) {
        let schema = record_batches
            .first()
            .expect("at least one record batch is required to derive the schema")
            .schema
            .arrow_schema()
            .clone();
        // Captured up front because the batches are moved into the request stream below.
        let expected_rows = record_batches
            .iter()
            .map(RecordBatch::num_rows)
            .collect::<Vec<_>>();

        let schema_message = futures::stream::once(async move {
            let mut schema_data = FlightEncoder::default().encode_schema(schema.as_ref());
            schema_data.app_metadata = serde_json::to_vec(&DoPutMetadata::new(0)).unwrap().into();
            schema_data.flight_descriptor = Some(FlightDescriptor {
                r#type: arrow_flight::flight_descriptor::DescriptorType::Path as i32,
                path: vec!["foo".to_string()],
                ..Default::default()
            });
            schema_data
        });
        let batch_messages = tokio_stream::iter(record_batches)
            .enumerate()
            .flat_map(|(i, batch)| {
                let mut encoder = FlightEncoder::default();
                let mut flight_data =
                    encoder.encode(FlightMessage::RecordBatch(batch.into_df_record_batch()));
                // Serialize once per batch; every FlightData of the batch shares it.
                let app_metadata =
                    serde_json::to_vec(&DoPutMetadata::new((i + 1) as i64)).unwrap();
                for data in flight_data.iter_mut() {
                    data.app_metadata = app_metadata.clone().into();
                }
                tokio_stream::iter(flight_data)
            })
            .boxed();
        let stream = schema_message.chain(batch_messages).boxed();

        let response_stream = client.do_put(stream).await.unwrap();
        let responses = response_stream.collect::<Vec<_>>().await;
        let responses_count = responses.len();
        for (i, response) in responses.into_iter().enumerate() {
            let response =
                response.unwrap_or_else(|e| panic!("DoPut response #{i} failed: {e}"));
            assert_eq!(response.request_id(), i as i64);
            // Request 0 only carries the schema; request `i` carries batch `i - 1`.
            let expected_affected_rows = match i {
                0 => 0,
                _ => *expected_rows
                    .get(i - 1)
                    .unwrap_or_else(|| panic!("unexpected extra DoPut response #{i}")),
            };
            assert_eq!(response.affected_rows(), expected_affected_rows);
        }
        // One response for the schema message plus one per record batch.
        assert_eq!(responses_count, expected_rows.len() + 1);
    }

    /// Builds three record batches of three rows each for the `foo` schema.
    ///
    /// Row values run over `start..start + 9`. For value `x`, `ts = x` (milliseconds),
    /// `a = -x` and `B = "s{x}"`.
    fn create_record_batches(start: i64) -> Vec<RecordBatch> {
        let schema = Arc::new(Schema::new(vec![
            ColumnSchema::new(
                "ts",
                ConcreteDataType::timestamp_millisecond_datatype(),
                false,
            )
            .with_time_index(true),
            ColumnSchema::new("a", ConcreteDataType::int32_datatype(), false),
            ColumnSchema::new("B", ConcreteDataType::string_datatype(), true),
        ]));

        let mut record_batches = Vec::with_capacity(3);
        for chunk in &(start..start + 9).chunks(3) {
            let xs = chunk.collect_vec();
            let columns = vec![
                Arc::new(TimestampMillisecondVector::from_vec(xs.clone())) as VectorRef,
                Arc::new(Int32Vector::from_vec(
                    xs.iter().map(|x| -x as i32).collect::<Vec<i32>>(),
                )),
                Arc::new(StringVector::from_vec(
                    xs.iter().map(|x| format!("s{x}")).collect::<Vec<String>>(),
                )),
            ];
            record_batches.push(RecordBatch::new(schema.clone(), columns).unwrap());
        }
        record_batches
    }

    /// Creates the `foo` table with the schema used throughout these tests.
    async fn create_table(client: &Database) {
        create_table_named(client, "foo").await;
    }

    /// Creates a mito table named `table_name` in schema `public`.
    ///
    /// Columns: `ts` (millisecond time index), `a` (int32 tag, the primary key) and
    /// `B` (nullable string field). Panics unless the server reports 0 affected rows.
    async fn create_table_named(client: &Database, table_name: &str) {
        let output = client
            .create(CreateTableExpr {
                schema_name: "public".to_string(),
                table_name: table_name.to_string(),
                column_defs: vec![
                    ColumnDef {
                        name: "ts".to_string(),
                        data_type: ColumnDataType::TimestampMillisecond as i32,
                        semantic_type: SemanticType::Timestamp as i32,
                        is_nullable: false,
                        ..Default::default()
                    },
                    ColumnDef {
                        name: "a".to_string(),
                        data_type: ColumnDataType::Int32 as i32,
                        semantic_type: SemanticType::Tag as i32,
                        is_nullable: false,
                        ..Default::default()
                    },
                    ColumnDef {
                        name: "B".to_string(),
                        data_type: ColumnDataType::String as i32,
                        semantic_type: SemanticType::Field as i32,
                        is_nullable: true,
                        ..Default::default()
                    },
                ],
                time_index: "ts".to_string(),
                primary_keys: vec!["a".to_string()],
                engine: "mito".to_string(),
                ..Default::default()
            })
            .await
            .unwrap();
        // The server decides the output variant, so a mismatch is a test failure,
        // not an unreachable state.
        let OutputData::AffectedRows(affected_rows) = output.data else {
            panic!("expected affected rows output");
        };
        assert_eq!(affected_rows, 0);
    }
}