Skip to main content

tests_integration/grpc/
network.rs

1// Copyright 2023 Greptime Team
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#[cfg(test)]
16mod tests {
17    use std::io;
18    use std::pin::Pin;
19    use std::task::{Context, Poll};
20    use std::time::Duration;
21
22    use client::Client;
23    use common_grpc::channel_manager::ChannelManager;
24    use futures_util::future::BoxFuture;
25    use http::Uri;
26    use hyper_util::rt::TokioIo;
27    use servers::grpc::GrpcServerConfig;
28    use servers::server::Server;
29    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
30    use tokio::net::TcpStream;
31    use tokio::sync::mpsc;
32    use tower::Service;
33
34    use crate::test_util::{StorageType, setup_grpc_server_with};
35
36    struct NetworkTrafficMonitorableConnector {
37        interested_tx: mpsc::Sender<String>,
38    }
39
40    impl Service<Uri> for NetworkTrafficMonitorableConnector {
41        type Response = TokioIo<CollectGrpcResponseFrameTypeStream>;
42        type Error = String;
43        type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
44
45        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
46            Poll::Ready(Ok(()))
47        }
48
49        fn call(&mut self, uri: Uri) -> Self::Future {
50            let frame_types = self.interested_tx.clone();
51
52            Box::pin(async move {
53                let addr = format!(
54                    "{}:{}",
55                    uri.host().unwrap_or("localhost"),
56                    uri.port_u16().unwrap_or(4001),
57                );
58                let inner = TcpStream::connect(addr).await.map_err(|e| e.to_string())?;
59                Ok(TokioIo::new(CollectGrpcResponseFrameTypeStream {
60                    inner,
61                    frame_types,
62                }))
63            })
64        }
65    }
66
67    struct CollectGrpcResponseFrameTypeStream {
68        inner: TcpStream,
69        frame_types: mpsc::Sender<String>,
70    }
71
72    impl AsyncRead for CollectGrpcResponseFrameTypeStream {
73        fn poll_read(
74            mut self: Pin<&mut Self>,
75            cx: &mut Context<'_>,
76            buf: &mut ReadBuf<'_>,
77        ) -> Poll<io::Result<()>> {
78            let before_len = buf.filled().len();
79
80            let result = Pin::new(&mut self.inner).poll_read(cx, buf);
81            if let Poll::Ready(Ok(())) = &result {
82                let after_len = buf.filled().len();
83
84                let new_data = &buf.filled()[before_len..after_len];
85                if let Some(frame_type) = maybe_decode_frame_type(new_data)
86                    && let Err(_) = self.frame_types.try_send(frame_type.to_string())
87                {
88                    return Poll::Ready(Err(io::Error::other("interested party has gone")));
89                }
90            }
91            result
92        }
93    }
94
/// Returns the name of the HTTP/2 frame type found in the frame header at the start of
/// `data`, or `None` when `data` is too short to hold a complete header.
///
/// Only the first header in `data` is examined. The caller is assumed to pass bytes that
/// begin on a frame boundary; this function cannot verify that.
///
/// Type codes outside the ten known frame types map to `"UNKNOWN"`.
///
/// The returned name is a string literal, so it is `'static` and does not keep `data`
/// borrowed.
fn maybe_decode_frame_type(data: &[u8]) -> Option<&'static str> {
    // An HTTP/2 frame header is 9 octets: 24-bit length, 8-bit type, 8-bit flags and a
    // 32-bit stream identifier.
    const FRAME_HEADER_LEN: usize = 9;
    // The type octet directly follows the three length octets.
    const FRAME_TYPE_OFFSET: usize = 3;

    let header = data.get(..FRAME_HEADER_LEN)?;
    Some(match header[FRAME_TYPE_OFFSET] {
        0x0 => "DATA",
        0x1 => "HEADERS",
        0x2 => "PRIORITY",
        0x3 => "RST_STREAM",
        0x4 => "SETTINGS",
        0x5 => "PUSH_PROMISE",
        0x6 => "PING",
        0x7 => "GOAWAY",
        0x8 => "WINDOW_UPDATE",
        0x9 => "CONTINUATION",
        _ => "UNKNOWN",
    })
}
110
111    impl AsyncWrite for CollectGrpcResponseFrameTypeStream {
112        fn poll_write(
113            mut self: Pin<&mut Self>,
114            cx: &mut Context<'_>,
115            buf: &[u8],
116        ) -> Poll<Result<usize, io::Error>> {
117            Pin::new(&mut self.inner).poll_write(cx, buf)
118        }
119
120        fn poll_flush(
121            mut self: Pin<&mut Self>,
122            cx: &mut Context<'_>,
123        ) -> Poll<Result<(), io::Error>> {
124            Pin::new(&mut self.inner).poll_flush(cx)
125        }
126
127        fn poll_shutdown(
128            mut self: Pin<&mut Self>,
129            cx: &mut Context<'_>,
130        ) -> Poll<Result<(), io::Error>> {
131            Pin::new(&mut self.inner).poll_shutdown(cx)
132        }
133    }
134
135    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
136    async fn test_grpc_max_connection_age() {
137        let config = GrpcServerConfig {
138            max_connection_age: Some(Duration::from_secs(1)),
139            ..Default::default()
140        };
141        let (_db, server) = setup_grpc_server_with(
142            StorageType::File,
143            "test_grpc_max_connection_age",
144            None,
145            Some(config),
146            None,
147        )
148        .await;
149        let addr = server.bind_addr().unwrap().to_string();
150
151        let channel_manager = ChannelManager::new();
152        let client = Client::with_manager_and_urls(channel_manager.clone(), vec![&addr]);
153
154        let (tx, mut rx) = mpsc::channel(1024);
155        channel_manager
156            .reset_with_connector(
157                &addr,
158                NetworkTrafficMonitorableConnector { interested_tx: tx },
159            )
160            .unwrap();
161
162        let recv = tokio::spawn(async move {
163            let sleep = tokio::time::sleep(Duration::from_secs(3));
164            tokio::pin!(sleep);
165
166            let mut frame_types = vec![];
167            loop {
168                tokio::select! {
169                    x = rx.recv() => {
170                        if let Some(x) = x {
171                            frame_types.push(x);
172                        } else {
173                            break;
174                        }
175                    }
176                    _ = &mut sleep => {
177                        break;
178                    }
179                }
180            }
181            frame_types
182        });
183
184        // Drive the gRPC connection, has no special meaning for this keep-alive test.
185        let _ = client.health_check().await;
186
187        let frame_types = recv.await.unwrap();
188        // If "max_connection_age" has taken effects, server will return a "GOAWAY" message.
189        assert!(
190            frame_types.iter().any(|x| x == "GOAWAY"),
191            "{:?}",
192            frame_types
193        );
194
195        server.shutdown().await.unwrap();
196    }
197}