feat: refactor for test (#375)

* chore: add set_header macro & remove some unnecessary code

* chore: channel_manager with connector
Jiachun Feng
2022-11-01 17:34:54 +08:00
committed by GitHub
parent ea2ebc0e87
commit 2d4a44414d
12 changed files with 222 additions and 137 deletions

Cargo.lock (generated)

@@ -983,6 +983,7 @@ dependencies = [
"snafu",
"tokio",
"tonic",
"tower",
]
[[package]]


@@ -4,91 +4,51 @@ pub const PROTOCOL_VERSION: u64 = 1;
 impl RequestHeader {
     #[inline]
-    pub fn new((cluster_id, member_id): (u64, u64)) -> Option<Self> {
-        Some(Self {
+    pub fn new((cluster_id, member_id): (u64, u64)) -> Self {
+        Self {
             protocol_version: PROTOCOL_VERSION,
             cluster_id,
             member_id,
-        })
+        }
     }
 }

 impl ResponseHeader {
     #[inline]
-    pub fn success(cluster_id: u64) -> Option<Self> {
-        Some(Self {
+    pub fn success(cluster_id: u64) -> Self {
+        Self {
             protocol_version: PROTOCOL_VERSION,
             cluster_id,
             ..Default::default()
-        })
+        }
     }

     #[inline]
-    pub fn failed(cluster_id: u64, error: Error) -> Option<Self> {
-        Some(Self {
+    pub fn failed(cluster_id: u64, error: Error) -> Self {
+        Self {
             protocol_version: PROTOCOL_VERSION,
             cluster_id,
             error: Some(error),
-        })
+        }
     }
 }

-impl TableName {
-    pub fn new(
-        catalog: impl Into<String>,
-        schema: impl Into<String>,
-        table: impl Into<String>,
-    ) -> Self {
-        Self {
-            catalog_name: catalog.into(),
-            schema_name: schema.into(),
-            table_name: table.into(),
-        }
-    }
-}
-
-impl RouteRequest {
-    pub fn add_table(mut self, table_name: TableName) -> Self {
-        self.table_names.push(table_name);
-        self
-    }
-}
-
-impl CreateRequest {
-    pub fn add_partition(mut self, partition: Partition) -> Self {
-        self.partitions.push(partition);
-        self
-    }
-}
-
-impl Region {
-    pub fn new(id: u64, name: impl Into<String>, partition: Partition) -> Self {
-        Self {
-            id,
-            name: name.into(),
-            partition: Some(partition),
-            ..Default::default()
-        }
-    }
-
-    pub fn attr(mut self, key: impl Into<String>, val: impl Into<String>) -> Self {
-        self.attrs.insert(key.into(), val.into());
-        self
-    }
-}
-
-impl Partition {
-    pub fn new() -> Self {
-        Default::default()
-    }
-
-    pub fn column_list(mut self, column_list: Vec<Vec<u8>>) -> Self {
-        self.column_list = column_list;
-        self
-    }
-
-    pub fn value_list(mut self, value_list: Vec<Vec<u8>>) -> Self {
-        self.value_list = value_list;
-        self
-    }
-}
+macro_rules! gen_set_header {
+    ($req: ty) => {
+        impl $req {
+            #[inline]
+            pub fn set_header(&mut self, (cluster_id, member_id): (u64, u64)) {
+                self.header = Some(RequestHeader::new((cluster_id, member_id)));
+            }
+        }
+    };
+}
+
+gen_set_header!(HeartbeatRequest);
+gen_set_header!(RouteRequest);
+gen_set_header!(CreateRequest);
+gen_set_header!(RangeRequest);
+gen_set_header!(PutRequest);
+gen_set_header!(BatchPutRequest);
+gen_set_header!(CompareAndPutRequest);
+gen_set_header!(DeleteRangeRequest);
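
To make the macro concrete, here is a minimal sketch (not part of the diff) of what a `gen_set_header!` invocation gives callers, written as a small test; it assumes only what the diff shows: a prost-generated request type with a `header` field and a `Default` impl.

    #[test]
    fn set_header_fills_protocol_fields() {
        // `set_header` is the macro-generated method; `RouteRequest` is one
        // of the eight request types listed above.
        let mut req = RouteRequest::default();
        req.set_header((1, 2)); // (cluster_id, member_id)

        let header = req.header.as_ref().unwrap();
        assert_eq!(PROTOCOL_VERSION, header.protocol_version);
        assert_eq!(1, header.cluster_id);
        assert_eq!(2, header.member_id);
    }
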


@@ -13,6 +13,7 @@ datafusion = { git = "https://github.com/apache/arrow-datafusion.git", branch =
snafu = { version = "0.7", features = ["backtraces"] }
tokio = { version = "1.0", features = ["full"] }
tonic = "0.8"
+tower = "0.4"
[dependencies.arrow]
package = "arrow2"


@@ -6,6 +6,8 @@ use std::time::Duration;
use snafu::ResultExt;
use tonic::transport::Channel as InnerChannel;
use tonic::transport::Endpoint;
+use tonic::transport::Uri;
+use tower::make::MakeConnection;
use crate::error;
use crate::error::Result;
@@ -55,6 +57,53 @@ impl ChannelManager {
return Ok(ch.channel.clone());
}
+        let endpoint = self.build_endpoint(addr)?;
+        let inner_channel = endpoint.connect_lazy();
+        let channel = Channel {
+            channel: inner_channel.clone(),
+            access: 1,
+            use_default_connector: true,
+        };
+        pool.put(addr, channel);
+        Ok(inner_channel)
+    }
+
+    pub fn reset_with_connector<C>(
+        &self,
+        addr: impl AsRef<str>,
+        connector: C,
+    ) -> Result<InnerChannel>
+    where
+        C: MakeConnection<Uri> + Send + 'static,
+        C::Connection: Unpin + Send + 'static,
+        C::Future: Send + 'static,
+        Box<dyn std::error::Error + Send + Sync>: From<C::Error> + Send + 'static,
+    {
+        let addr = addr.as_ref();
+        let endpoint = self.build_endpoint(addr)?;
+        let inner_channel = endpoint.connect_with_connector_lazy(connector);
+        let channel = Channel {
+            channel: inner_channel.clone(),
+            access: 1,
+            use_default_connector: false,
+        };
+        let mut pool = self.pool.lock().unwrap();
+        pool.put(addr, channel);
+        Ok(inner_channel)
+    }
+
+    pub fn retain_channel<F>(&self, f: F)
+    where
+        F: FnMut(&String, &mut Channel) -> bool,
+    {
+        let mut pool = self.pool.lock().unwrap();
+        pool.retain_channel(f);
+    }
+
+    fn build_endpoint(&self, addr: &str) -> Result<Endpoint> {
         let mut endpoint =
             Endpoint::new(format!("http://{}", addr)).context(error::CreateChannelSnafu)?;
@@ -92,14 +141,7 @@ impl ChannelManager {
             .tcp_keepalive(self.config.tcp_keepalive)
             .tcp_nodelay(self.config.tcp_nodelay);
-        let inner_channel = endpoint.connect_lazy();
-        let channel = Channel {
-            channel: inner_channel.clone(),
-            access: 1,
-        };
-        pool.put(addr, channel);
-        Ok(inner_channel)
+
+        Ok(endpoint)
}
}
@@ -252,6 +294,24 @@ impl ChannelConfig {
}
}
+#[derive(Debug)]
+pub struct Channel {
+    channel: InnerChannel,
+    access: usize,
+    use_default_connector: bool,
+}
+
+impl Channel {
+    #[inline]
+    pub fn access(&self) -> usize {
+        self.access
+    }
+
+    #[inline]
+    pub fn use_default_connector(&self) -> bool {
+        self.use_default_connector
+    }
+}
#[derive(Debug)]
struct Pool {
channels: HashMap<String, Channel>,
@@ -277,12 +337,6 @@ impl Pool {
}
}
-#[derive(Debug)]
-struct Channel {
-    channel: InnerChannel,
-    access: usize,
-}
async fn recycle_channel_in_loop(pool: Arc<Mutex<Pool>>, interval_secs: u64) {
let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
@@ -302,6 +356,8 @@ async fn recycle_channel_in_loop(pool: Arc<Mutex<Pool>>, interval_secs: u64) {
#[cfg(test)]
mod tests {
+    use tower::service_fn;
use super::*;
#[should_panic]
@@ -326,19 +382,7 @@ mod tests {
channels: HashMap::default(),
};
let pool = Arc::new(Mutex::new(pool));
-        let config = ChannelConfig::new()
-            .timeout(Duration::from_secs(1))
-            .connect_timeout(Duration::from_secs(1))
-            .concurrency_limit(1)
-            .rate_limit(1, Duration::from_secs(1))
-            .initial_stream_window_size(1)
-            .initial_connection_window_size(1)
-            .http2_keep_alive_interval(Duration::from_secs(1))
-            .http2_keep_alive_timeout(Duration::from_secs(1))
-            .http2_keep_alive_while_idle(true)
-            .http2_adaptive_window(true)
-            .tcp_keepalive(Duration::from_secs(1))
-            .tcp_nodelay(true);
+        let config = ChannelConfig::new();
let mgr = ChannelManager { pool, config };
let addr = "test_uri";
@@ -419,4 +463,68 @@ mod tests {
cfg
);
}
+
+    #[test]
+    fn test_build_endpoint() {
+        let pool = Pool {
+            channels: HashMap::default(),
+        };
+        let pool = Arc::new(Mutex::new(pool));
+
+        let config = ChannelConfig::new()
+            .timeout(Duration::from_secs(3))
+            .connect_timeout(Duration::from_secs(5))
+            .concurrency_limit(6)
+            .rate_limit(5, Duration::from_secs(1))
+            .initial_stream_window_size(10)
+            .initial_connection_window_size(20)
+            .http2_keep_alive_interval(Duration::from_secs(1))
+            .http2_keep_alive_timeout(Duration::from_secs(3))
+            .http2_keep_alive_while_idle(true)
+            .http2_adaptive_window(true)
+            .tcp_keepalive(Duration::from_secs(2))
+            .tcp_nodelay(true);
+        let mgr = ChannelManager { pool, config };
+
+        let res = mgr.build_endpoint("test_addr");
+        assert!(res.is_ok());
+    }
+
+    #[tokio::test]
+    async fn test_channel_with_connector() {
+        let pool = Pool {
+            channels: HashMap::default(),
+        };
+        let pool = Arc::new(Mutex::new(pool));
+        let config = ChannelConfig::new();
+        let mgr = ChannelManager { pool, config };
+
+        let addr = "test_addr";
+        let res = mgr.get(addr);
+        assert!(res.is_ok());
+
+        mgr.retain_channel(|addr, channel| {
+            assert_eq!("test_addr", addr);
+            assert!(channel.use_default_connector());
+            true
+        });
+
+        let (client, _) = tokio::io::duplex(1024);
+        let mut client = Some(client);
+        let res = mgr.reset_with_connector(
+            addr,
+            service_fn(move |_| {
+                let client = client.take().unwrap();
+                async move { Ok::<_, std::io::Error>(client) }
+            }),
+        );
+        assert!(res.is_ok());
+
+        mgr.retain_channel(|addr, channel| {
+            assert_eq!("test_addr", addr);
+            assert!(!channel.use_default_connector());
+            true
+        });
+    }
}
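
Beyond the tests above, `retain_channel` gives callers a pruning hook over the pool. A minimal sketch (hypothetical usage, not in this commit) that evicts every pooled channel still built with the default connector, so the next `get` re-creates it:

    // Assumes only the public API added above: `ChannelManager`,
    // `retain_channel`, and `Channel::use_default_connector`.
    fn prune_default_channels(mgr: &ChannelManager) {
        mgr.retain_channel(|addr, channel| {
            let keep = !channel.use_default_connector();
            if !keep {
                println!("evicting pooled channel for {}", addr);
            }
            keep // returning false drops the entry from the pool
        });
    }
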


@@ -40,7 +40,7 @@ impl HeartbeatSender {
#[inline]
pub async fn send(&self, mut req: HeartbeatRequest) -> Result<()> {
-        req.header = RequestHeader::new(self.id);
+        req.set_header(self.id);
self.sender.send(req).await.map_err(|e| {
error::SendHeartbeatSnafu {
err_msg: e.to_string(),
@@ -158,7 +158,7 @@ impl Inner {
let mut leader = None;
for addr in &self.peers {
let req = AskLeaderRequest {
-            header: header.clone(),
+            header: Some(header.clone()),
};
let mut client = self.make_client(addr)?;
match client.ask_leader(req).await {
@@ -183,7 +183,7 @@ impl Inner {
let (sender, receiver) = mpsc::channel::<HeartbeatRequest>(128);
let handshake = HeartbeatRequest {
-            header: RequestHeader::new(self.id),
+            header: Some(RequestHeader::new(self.id)),
..Default::default()
};
sender.send(handshake).await.map_err(|e| {


@@ -3,7 +3,6 @@ use std::sync::Arc;
use api::v1::meta::router_client::RouterClient;
use api::v1::meta::CreateRequest;
-use api::v1::meta::RequestHeader;
use api::v1::meta::RouteRequest;
use api::v1::meta::RouteResponse;
use common_grpc::channel_manager::ChannelManager;
@@ -92,7 +91,7 @@ impl Inner {
async fn route(&self, mut req: RouteRequest) -> Result<RouteResponse> {
let mut client = self.random_client()?;
-        req.header = RequestHeader::new(self.id);
+        req.set_header(self.id);
let res = client.route(req).await.context(error::TonicStatusSnafu)?;
Ok(res.into_inner())
@@ -100,7 +99,7 @@ impl Inner {
async fn create(&self, mut req: CreateRequest) -> Result<RouteResponse> {
let mut client = self.random_client()?;
-        req.header = RequestHeader::new(self.id);
+        req.set_header(self.id);
let res = client.create(req).await.context(error::TonicStatusSnafu)?;
Ok(res.into_inner())


@@ -12,7 +12,6 @@ use api::v1::meta::PutRequest;
use api::v1::meta::PutResponse;
use api::v1::meta::RangeRequest;
use api::v1::meta::RangeResponse;
-use api::v1::meta::RequestHeader;
use common_grpc::channel_manager::ChannelManager;
use snafu::ensure;
use snafu::OptionExt;
@@ -117,7 +116,7 @@ impl Inner {
async fn range(&self, mut req: RangeRequest) -> Result<RangeResponse> {
let mut client = self.random_client()?;
-        req.header = RequestHeader::new(self.id);
+        req.set_header(self.id);
let res = client.range(req).await.context(error::TonicStatusSnafu)?;
Ok(res.into_inner())
@@ -125,7 +124,7 @@ impl Inner {
async fn put(&self, mut req: PutRequest) -> Result<PutResponse> {
let mut client = self.random_client()?;
-        req.header = RequestHeader::new(self.id);
+        req.set_header(self.id);
let res = client.put(req).await.context(error::TonicStatusSnafu)?;
Ok(res.into_inner())
@@ -133,7 +132,7 @@ impl Inner {
async fn batch_put(&self, mut req: BatchPutRequest) -> Result<BatchPutResponse> {
let mut client = self.random_client()?;
-        req.header = RequestHeader::new(self.id);
+        req.set_header(self.id);
let res = client
.batch_put(req)
.await
@@ -147,7 +146,7 @@ impl Inner {
mut req: CompareAndPutRequest,
) -> Result<CompareAndPutResponse> {
let mut client = self.random_client()?;
-        req.header = RequestHeader::new(self.id);
+        req.set_header(self.id);
let res = client
.compare_and_put(req)
.await
@@ -158,7 +157,7 @@ impl Inner {
async fn delete_range(&self, mut req: DeleteRangeRequest) -> Result<DeleteRangeResponse> {
let mut client = self.random_client()?;
-        req.header = RequestHeader::new(self.id);
+        req.set_header(self.id);
let res = client
.delete_range(req)
.await


@@ -46,7 +46,7 @@ mod tests {
};
let req = HeartbeatRequest {
-            header: RequestHeader::new((1, 2)),
+            header: Some(RequestHeader::new((1, 2))),
..Default::default()
};
let mut acc = HeartbeatAccumulator::default();


@@ -35,10 +35,10 @@ impl FromStr for LeaseKey {
let cluster_id = caps[1].to_string();
let node_id = caps[2].to_string();
-        let cluster_id = cluster_id.parse::<u64>().context(error::ParseNumSnafu {
+        let cluster_id: u64 = cluster_id.parse().context(error::ParseNumSnafu {
err_msg: format!("invalid cluster_id: {}", cluster_id),
})?;
-        let node_id = node_id.parse::<u64>().context(error::ParseNumSnafu {
+        let node_id: u64 = node_id.parse().context(error::ParseNumSnafu {
err_msg: format!("invalid node_id: {}", node_id),
})?;


@@ -150,7 +150,7 @@ mod tests {
let meta_srv = MetaSrv::new(MetaSrvOptions::default(), kv_store).await;
let req = AskLeaderRequest {
-            header: RequestHeader::new((1, 1)),
+            header: Some(RequestHeader::new((1, 1))),
};
let res = meta_srv.ask_leader(req.into_request()).await.unwrap();


@@ -127,8 +127,9 @@ async fn handle_create(
region_routes,
};
+        let header = Some(ResponseHeader::success(cluster_id));
         Ok(RouteResponse {
-            header: ResponseHeader::success(cluster_id),
+            header,
peers,
table_routes: vec![table_route],
})
@@ -153,13 +154,25 @@ mod tests {
let meta_srv = MetaSrv::new(MetaSrvOptions::default(), kv_store).await;
let req = RouteRequest {
-            header: RequestHeader::new((1, 1)),
-            ..Default::default()
+            header: Some(RequestHeader::new((1, 1))),
+            table_names: vec![
+                TableName {
+                    catalog_name: "catalog1".to_string(),
+                    schema_name: "schema1".to_string(),
+                    table_name: "table1".to_string(),
+                },
+                TableName {
+                    catalog_name: "catalog1".to_string(),
+                    schema_name: "schema1".to_string(),
+                    table_name: "table2".to_string(),
+                },
+                TableName {
+                    catalog_name: "catalog1".to_string(),
+                    schema_name: "schema1".to_string(),
+                    table_name: "table3".to_string(),
+                },
+            ],
         };
-        let req = req
-            .add_table(TableName::new("catalog1", "schema1", "table1"))
-            .add_table(TableName::new("catalog1", "schema1", "table2"))
-            .add_table(TableName::new("catalog1", "schema1", "table3"));
let _res = meta_srv.route(req.into_request()).await.unwrap();
}
@@ -185,19 +198,24 @@ mod tests {
#[tokio::test]
async fn test_handle_create() {
let kv_store = Arc::new(NoopKvStore {});
-        let table_name = TableName::new("test_catalog", "test_db", "table1");
-        let req = CreateRequest {
-            header: RequestHeader::new((1, 1)),
-            table_name: Some(table_name),
-            ..Default::default()
-        };
-        let p0 = Partition::new()
-            .column_list(vec![b"col1".to_vec(), b"col2".to_vec()])
-            .value_list(vec![b"v1".to_vec(), b"v2".to_vec()]);
-        let p1 = Partition::new()
-            .column_list(vec![b"col1".to_vec(), b"col2".to_vec()])
-            .value_list(vec![b"v11".to_vec(), b"v22".to_vec()]);
-        let req = req.add_partition(p0).add_partition(p1);
+        let table_name = TableName {
+            catalog_name: "test_catalog".to_string(),
+            schema_name: "test_db".to_string(),
+            table_name: "table1".to_string(),
+        };
+
+        let p0 = Partition {
+            column_list: vec![b"col1".to_vec(), b"col2".to_vec()],
+            value_list: vec![b"v1".to_vec(), b"v2".to_vec()],
+        };
+        let p1 = Partition {
+            column_list: vec![b"col1".to_vec(), b"col2".to_vec()],
+            value_list: vec![b"v11".to_vec(), b"v22".to_vec()],
+        };
+
+        let req = CreateRequest {
+            header: Some(RequestHeader::new((1, 1))),
+            table_name: Some(table_name),
+            partitions: vec![p0, p1],
+        };
let ctx = Context {
datanode_lease_secs: 10,
kv_store,


@@ -69,8 +69,9 @@ impl KvStore for EtcdStore {
.map(|kv| KvPair::new(kv).into())
.collect::<Vec<_>>();
+        let header = Some(ResponseHeader::success(cluster_id));
         Ok(RangeResponse {
-            header: ResponseHeader::success(cluster_id),
+            header,
kvs,
more: res.more(),
})
@@ -93,10 +94,8 @@ impl KvStore for EtcdStore {
let prev_kv = res.prev_key().map(|kv| KvPair::new(kv).into());
-        Ok(PutResponse {
-            header: ResponseHeader::success(cluster_id),
-            prev_kv,
-        })
+        let header = Some(ResponseHeader::success(cluster_id));
+
+        Ok(PutResponse { header, prev_kv })
}
async fn batch_put(&self, req: BatchPutRequest) -> Result<BatchPutResponse> {
@@ -131,10 +130,8 @@ impl KvStore for EtcdStore {
}
}
-        Ok(BatchPutResponse {
-            header: ResponseHeader::success(cluster_id),
-            prev_kvs,
-        })
+        let header = Some(ResponseHeader::success(cluster_id));
+
+        Ok(BatchPutResponse { header, prev_kvs })
}
async fn compare_and_put(&self, req: CompareAndPutRequest) -> Result<CompareAndPutResponse> {
@@ -186,8 +183,9 @@ impl KvStore for EtcdStore {
}
}
+        let header = Some(ResponseHeader::success(cluster_id));
         Ok(CompareAndPutResponse {
-            header: ResponseHeader::success(cluster_id),
+            header,
success,
prev_kv,
})
@@ -213,8 +211,9 @@ impl KvStore for EtcdStore {
.map(|kv| KvPair::new(kv).into())
.collect::<Vec<_>>();
+        let header = Some(ResponseHeader::success(cluster_id));
         Ok(DeleteRangeResponse {
-            header: ResponseHeader::success(cluster_id),
+            header,
deleted: res.deleted(),
prev_kvs,
})
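
For the error path, the same header-first shape works with `ResponseHeader::failed`; a hedged sketch (assumed, not in this diff) of building a failed `RangeResponse`, where `err` is an `api::v1::meta::Error` describing the failure:

    // Mirrors the success pattern above using `ResponseHeader::failed`
    // from this commit; prost-generated responses derive `Default`.
    fn failed_range_response(cluster_id: u64, err: Error) -> RangeResponse {
        let header = Some(ResponseHeader::failed(cluster_id, err));
        RangeResponse {
            header,
            ..Default::default()
        }
    }
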