refactor: bundle the lightweight axum test client (#3669)

* refactor: bundle the lightweight axum test client

Signed-off-by: tison <wander4096@gmail.com>

* address comments

Signed-off-by: tison <wander4096@gmail.com>

---------

Signed-off-by: tison <wander4096@gmail.com>
This commit is contained in:
tison
2024-04-09 10:33:26 +08:00
committed by GitHub
parent ea9367f371
commit 883b7fce96
11 changed files with 269 additions and 31 deletions

25
Cargo.lock generated
View File

@@ -793,24 +793,6 @@ dependencies = [
"syn 2.0.43",
]
[[package]]
name = "axum-test-helper"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "298f62fa902c2515c169ab0bfb56c593229f33faa01131215d58e3d4898e3aa9"
dependencies = [
"axum",
"bytes",
"http",
"http-body",
"hyper",
"reqwest",
"serde",
"tokio",
"tower",
"tower-service",
]
[[package]]
name = "backon"
version = "0.4.1"
@@ -3988,9 +3970,9 @@ dependencies = [
[[package]]
name = "http"
version = "0.2.11"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8947b1a6fad4393052c7ba1f4cd97bed3e953a95c79c92ad9b051a04611d9fbb"
checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1"
dependencies = [
"bytes",
"fnv",
@@ -9030,7 +9012,6 @@ dependencies = [
"auth",
"axum",
"axum-macros",
"axum-test-helper",
"base64 0.21.5",
"bytes",
"catalog",
@@ -9061,6 +9042,7 @@ dependencies = [
"hashbrown 0.14.3",
"headers",
"hostname",
"http",
"http-body",
"humantime-serde",
"hyper",
@@ -10105,7 +10087,6 @@ dependencies = [
"async-trait",
"auth",
"axum",
"axum-test-helper",
"catalog",
"chrono",
"client",

View File

@@ -133,6 +133,7 @@ reqwest = { version = "0.11", default-features = false, features = [
"json",
"rustls-tls-native-roots",
"stream",
"multipart",
] }
rskafka = "0.5"
rust_decimal = "1.33"

View File

@@ -50,7 +50,8 @@ derive_builder.workspace = true
futures = "0.3"
hashbrown = "0.14"
headers = "0.3"
hostname = "0.3.1"
hostname = "0.3"
http = "0.2"
http-body = "0.4"
humantime-serde.workspace = true
hyper = { version = "0.14", features = ["full"] }
@@ -109,7 +110,6 @@ tikv-jemalloc-ctl = { version = "0.5", features = ["use_std"] }
[dev-dependencies]
auth = { workspace = true, features = ["testing"] }
axum-test-helper = "0.3"
catalog = { workspace = true, features = ["testing"] }
client.workspace = true
common-base.workspace = true

View File

@@ -94,6 +94,9 @@ pub mod greptime_result_v1;
pub mod influxdb_result_v1;
pub mod table_result;
#[cfg(any(test, feature = "testing"))]
pub mod test_helpers;
pub const HTTP_API_VERSION: &str = "v1";
pub const HTTP_API_PREFIX: &str = "/v1/";
/// Default http body limit (64M).
@@ -824,7 +827,6 @@ mod test {
use axum::handler::Handler;
use axum::http::StatusCode;
use axum::routing::get;
use axum_test_helper::TestClient;
use common_query::Output;
use common_recordbatch::RecordBatches;
use datatypes::prelude::*;
@@ -838,6 +840,7 @@ mod test {
use super::*;
use crate::error::Error;
use crate::http::test_helpers::TestClient;
use crate::query_handler::grpc::GrpcQueryHandler;
use crate::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler};

View File

@@ -0,0 +1,254 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Axum Test Client
//!
//! ```rust
//! use axum::Router;
//! use axum::http::StatusCode;
//! use axum::routing::get;
//! use crate::servers::http::test_helpers::TestClient;
//!
//! let async_block = async {
//! // you can replace this Router with your own app
//! let app = Router::new().route("/", get(|| async {}));
//!
//! // initiate the TestClient with the previous declared Router
//! let client = TestClient::new(app);
//!
//! let res = client.get("/").await;
//! assert_eq!(res.status(), StatusCode::OK);
//! };
//!
//! // Create a runtime for executing the async block. This runtime is local
//! // to the main function and does not require any global setup.
//! let runtime = tokio::runtime::Builder::new_current_thread()
//! .enable_all()
//! .build()
//! .unwrap();
//!
//! // Use the local runtime to block on the async block.
//! runtime.block_on(async_block);
//! ```
use std::convert::TryFrom;
use std::net::{SocketAddr, TcpListener};
use axum::body::HttpBody;
use axum::BoxError;
use bytes::Bytes;
use common_telemetry::info;
use http::header::{HeaderName, HeaderValue};
use http::{Request, StatusCode};
use hyper::service::Service;
use hyper::{Body, Server};
use tower::make::Shared;
/// Test client to Axum servers.
pub struct TestClient {
client: reqwest::Client,
addr: SocketAddr,
}
impl TestClient {
/// Create a new test client.
pub fn new<S, ResBody>(svc: S) -> Self
where
S: Service<Request<Body>, Response = http::Response<ResBody>> + Clone + Send + 'static,
ResBody: HttpBody + Send + 'static,
ResBody::Data: Send,
ResBody::Error: Into<BoxError>,
S::Future: Send,
S::Error: Into<BoxError>,
{
let listener = TcpListener::bind("127.0.0.1:0").expect("Could not bind ephemeral socket");
let addr = listener.local_addr().unwrap();
info!("Listening on {}", addr);
tokio::spawn(async move {
let server = Server::from_tcp(listener).unwrap().serve(Shared::new(svc));
server.await.expect("server error");
});
let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.unwrap();
TestClient { client, addr }
}
/// Returns the base URL (http://ip:port) for this TestClient
///
/// this is useful when trying to check if Location headers in responses
/// are generated correctly as Location contains an absolute URL
pub fn base_url(&self) -> String {
format!("http://{}", self.addr)
}
/// Create a GET request.
pub fn get(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.get(format!("http://{}{}", self.addr, url)),
}
}
/// Create a HEAD request.
pub fn head(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.head(format!("http://{}{}", self.addr, url)),
}
}
/// Create a POST request.
pub fn post(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.post(format!("http://{}{}", self.addr, url)),
}
}
/// Create a PUT request.
pub fn put(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.put(format!("http://{}{}", self.addr, url)),
}
}
/// Create a PATCH request.
pub fn patch(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.patch(format!("http://{}{}", self.addr, url)),
}
}
/// Create a DELETE request.
pub fn delete(&self, url: &str) -> RequestBuilder {
RequestBuilder {
builder: self.client.delete(format!("http://{}{}", self.addr, url)),
}
}
}
/// Builder for test requests.
pub struct RequestBuilder {
builder: reqwest::RequestBuilder,
}
impl RequestBuilder {
pub async fn send(self) -> TestResponse {
TestResponse {
response: self.builder.send().await.unwrap(),
}
}
/// Set the request body.
pub fn body(mut self, body: impl Into<reqwest::Body>) -> Self {
self.builder = self.builder.body(body);
self
}
/// Set the request forms.
pub fn form<T: serde::Serialize + ?Sized>(mut self, form: &T) -> Self {
self.builder = self.builder.form(&form);
self
}
/// Set the request JSON body.
pub fn json<T>(mut self, json: &T) -> Self
where
T: serde::Serialize,
{
self.builder = self.builder.json(json);
self
}
/// Set a request header.
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
self.builder = self.builder.header(key, value);
self
}
/// Set a request multipart form.
pub fn multipart(mut self, form: reqwest::multipart::Form) -> Self {
self.builder = self.builder.multipart(form);
self
}
}
/// A wrapper around [`reqwest::Response`] that provides common methods with internal `unwrap()`s.
///
/// This is convenient for tests where panics are what you want. For access to
/// non-panicking versions or the complete `Response` API use `into_inner()` or
/// `as_ref()`.
pub struct TestResponse {
response: reqwest::Response,
}
impl TestResponse {
/// Get the response body as text.
pub async fn text(self) -> String {
self.response.text().await.unwrap()
}
/// Get the response body as bytes.
pub async fn bytes(self) -> Bytes {
self.response.bytes().await.unwrap()
}
/// Get the response body as JSON.
pub async fn json<T>(self) -> T
where
T: serde::de::DeserializeOwned,
{
self.response.json().await.unwrap()
}
/// Get the response status.
pub fn status(&self) -> StatusCode {
self.response.status()
}
/// Get the response headers.
pub fn headers(&self) -> &http::HeaderMap {
self.response.headers()
}
/// Get the response in chunks.
pub async fn chunk(&mut self) -> Option<Bytes> {
self.response.chunk().await.unwrap()
}
/// Get the response in chunks as text.
pub async fn chunk_text(&mut self) -> Option<String> {
let chunk = self.chunk().await?;
Some(String::from_utf8(chunk.to_vec()).unwrap())
}
/// Get the inner [`reqwest::Response`] for less convenient but more complete access.
pub fn into_inner(self) -> reqwest::Response {
self.response
}
}
impl AsRef<reqwest::Response> for TestResponse {
fn as_ref(&self) -> &reqwest::Response {
&self.response
}
}

View File

@@ -13,8 +13,8 @@
// limitations under the License.
use axum::Router;
use axum_test_helper::TestClient;
use common_test_util::ports;
use servers::http::test_helpers::TestClient;
use servers::http::{HttpOptions, HttpServerBuilder};
use table::test_util::MemTable;

View File

@@ -19,7 +19,6 @@ use api::v1::RowInsertRequests;
use async_trait::async_trait;
use auth::tests::{DatabaseAuthInfo, MockUserProvider};
use axum::{http, Router};
use axum_test_helper::TestClient;
use common_query::Output;
use common_test_util::ports;
use query::parser::PromQuery;
@@ -27,6 +26,7 @@ use query::plan::LogicalPlan;
use query::query_engine::DescribeResult;
use servers::error::{Error, Result};
use servers::http::header::constants::GREPTIME_DB_HEADER_NAME;
use servers::http::test_helpers::TestClient;
use servers::http::{HttpOptions, HttpServerBuilder};
use servers::influxdb::InfluxdbRequest;
use servers::query_handler::grpc::GrpcQueryHandler;

View File

@@ -17,13 +17,13 @@ use std::sync::Arc;
use api::v1::greptime_request::Request;
use async_trait::async_trait;
use axum::Router;
use axum_test_helper::TestClient;
use common_query::Output;
use common_test_util::ports;
use query::parser::PromQuery;
use query::plan::LogicalPlan;
use query::query_engine::DescribeResult;
use servers::error::{self, Result};
use servers::http::test_helpers::TestClient;
use servers::http::{HttpOptions, HttpServerBuilder};
use servers::opentsdb::codec::DataPoint;
use servers::query_handler::grpc::GrpcQueryHandler;

View File

@@ -21,7 +21,6 @@ use api::v1::greptime_request::Request;
use api::v1::RowInsertRequests;
use async_trait::async_trait;
use axum::Router;
use axum_test_helper::TestClient;
use common_query::Output;
use common_test_util::ports;
use prost::Message;
@@ -30,6 +29,7 @@ use query::plan::LogicalPlan;
use query::query_engine::DescribeResult;
use servers::error::{Error, Result};
use servers::http::header::{CONTENT_ENCODING_SNAPPY, CONTENT_TYPE_PROTOBUF};
use servers::http::test_helpers::TestClient;
use servers::http::{HttpOptions, HttpServerBuilder};
use servers::prom_store;
use servers::prom_store::{snappy_compress, Metrics};

View File

@@ -16,7 +16,6 @@ arrow-flight.workspace = true
async-trait = "0.1"
auth.workspace = true
axum.workspace = true
axum-test-helper = "0.3.0"
catalog.workspace = true
chrono.workspace = true
client = { workspace = true, features = ["testing"] }

View File

@@ -17,7 +17,6 @@ use std::collections::BTreeMap;
use api::prom_store::remote::WriteRequest;
use auth::user_provider_from_option;
use axum::http::{HeaderName, StatusCode};
use axum_test_helper::TestClient;
use common_error::status_code::StatusCode as ErrorCode;
use prost::Message;
use serde_json::json;
@@ -27,6 +26,7 @@ use servers::http::handler::HealthResponse;
use servers::http::header::GREPTIME_TIMEZONE_HEADER_NAME;
use servers::http::influxdb_result_v1::{InfluxdbOutput, InfluxdbV1Response};
use servers::http::prometheus::{PrometheusJsonResponse, PrometheusResponse};
use servers::http::test_helpers::TestClient;
use servers::http::GreptimeQueryOutput;
use servers::prom_store;
use tests_integration::test_util::{