fix tests

This commit is contained in:
Conrad Ludgate
2024-04-18 17:23:33 +01:00
parent f40bf86575
commit 4c573e13d4

View File

@@ -4,7 +4,8 @@ use parking_lot::RwLock;
use pin_project_lite::pin_project;
use rand::Rng;
use smallvec::SmallVec;
use std::{collections::HashMap, sync::Arc, sync::Weak, time::Duration};
use std::sync::Weak;
use std::{collections::HashMap, sync::Arc, time::Duration};
use std::{
fmt,
task::{ready, Poll},
@@ -621,7 +622,7 @@ impl<C: ClientInnerExt, I: Future<Output = ()>> Future for DbConnection<C, I> {
.as_mut()
.reset(Instant::now() + *this.idle);
info!("connection idle");
if let Some(pool) = this.pool.clone().upgrade() {
if let Some(pool) = this.pool.upgrade() {
// remove client from pool - should close the connection if it's idle.
// does nothing if the client is currently checked-out and in-use
if pool
@@ -759,7 +760,7 @@ impl<C: ClientInnerExt> Deref for Client<C> {
}
impl<C: ClientInnerExt> Client<C> {
fn do_drop(&mut self) -> Option<impl FnOnce()> {
fn do_drop(&mut self) {
let conn_info = self.conn_info.clone();
let client = self
.inner
@@ -770,69 +771,59 @@ impl<C: ClientInnerExt> Client<C> {
if client.is_closed() {
info!(%conn_id, "pool: throwing away connection '{conn_info}' because connection is closed");
return None;
return;
}
if let Some(conn_pool) = std::mem::take(&mut self.pool).upgrade() {
let current_span = self.span.clone();
if let Some(conn_pool) = self.pool.upgrade() {
// return connection to the pool
return Some(move || {
let _span = current_span.enter();
EndpointConnPool::put(&conn_pool, &conn_info, client);
});
let _span = self.span.enter();
EndpointConnPool::put(&conn_pool, &conn_info, client);
}
None
}
}
impl<C: ClientInnerExt> Drop for Client<C> {
fn drop(&mut self) {
if let Some(drop) = self.do_drop() {
tokio::task::spawn_blocking(drop);
}
self.do_drop();
}
}
#[cfg(test)]
mod tests {
use std::{mem, sync::atomic::AtomicBool};
use crate::{BranchId, EndpointId, ProjectId};
use super::*;
struct MockClient(Arc<AtomicBool>);
impl MockClient {
fn new(is_closed: bool) -> Self {
MockClient(Arc::new(is_closed.into()))
}
}
struct MockClient(CancellationToken);
impl ClientInnerExt for MockClient {
fn is_closed(&self) -> bool {
self.0.load(atomic::Ordering::Relaxed)
self.0.is_cancelled()
}
fn get_process_id(&self) -> i32 {
0
}
}
fn create_inner() -> ClientInner<MockClient> {
create_inner_with(MockClient::new(false))
}
fn create_inner_with(client: MockClient) -> ClientInner<MockClient> {
ClientInner {
inner: client,
session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
cancel: CancellationToken::new(),
aux: MetricsAuxInfo {
fn create_inner(
global_pool: Arc<GlobalConnPool<MockClient>>,
conn_info: ConnInfo,
) -> (Client<MockClient>, CancellationToken) {
let cancelled = CancellationToken::new();
let client = poll_client(
global_pool,
&mut RequestMonitoring::test(),
conn_info,
MockClient(cancelled.clone()),
cancelled.clone().cancelled_owned(),
uuid::Uuid::new_v4(),
MetricsAuxInfo {
endpoint_id: (&EndpointId::from("endpoint")).into(),
project_id: (&ProjectId::from("project")).into(),
branch_id: (&BranchId::from("branch")).into(),
cold_start_info: crate::console::messages::ColdStartInfo::Warm,
},
conn_id: uuid::Uuid::new_v4(),
}
);
(client, cancelled)
}
#[tokio::test]
@@ -859,51 +850,36 @@ mod tests {
dbname: "dbname".into(),
password: "password".as_bytes().into(),
};
let ep_pool = Arc::downgrade(
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
);
{
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
let (mut client, _) = create_inner(pool.clone(), conn_info.clone());
assert_eq!(0, pool.get_global_connections_count());
client.inner().1.discard();
drop(client);
// Discard should not add the connection from the pool.
assert_eq!(0, pool.get_global_connections_count());
}
{
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
client.do_drop().unwrap()();
mem::forget(client); // drop the client
let (client, _) = create_inner(pool.clone(), conn_info.clone());
drop(client);
assert_eq!(1, pool.get_global_connections_count());
}
{
let mut closed_client = Client::new(
create_inner_with(MockClient::new(true)),
conn_info.clone(),
ep_pool.clone(),
);
closed_client.do_drop().unwrap()();
mem::forget(closed_client); // drop the client
// The closed client shouldn't be added to the pool.
let (client, cancel) = create_inner(pool.clone(), conn_info.clone());
cancel.cancel();
drop(client);
// The closed client shouldn't be added to the pool.
assert_eq!(1, pool.get_global_connections_count());
}
let is_closed: Arc<AtomicBool> = Arc::new(false.into());
{
let mut client = Client::new(
create_inner_with(MockClient(is_closed.clone())),
conn_info.clone(),
ep_pool.clone(),
);
client.do_drop().unwrap()();
mem::forget(client); // drop the client
let cancel = {
let (client, cancel) = create_inner(pool.clone(), conn_info.clone());
drop(client);
// The client should be added to the pool.
assert_eq!(2, pool.get_global_connections_count());
}
cancel
};
{
let mut client = Client::new(create_inner(), conn_info, ep_pool);
client.do_drop().unwrap()();
mem::forget(client); // drop the client
let client = create_inner(pool.clone(), conn_info.clone());
drop(client);
// The client shouldn't be added to the pool. Because the ep-pool is full.
assert_eq!(2, pool.get_global_connections_count());
}
@@ -917,25 +893,19 @@ mod tests {
dbname: "dbname".into(),
password: "password".as_bytes().into(),
};
let ep_pool = Arc::downgrade(
&pool.get_or_create_endpoint_pool(&conn_info.endpoint_cache_key().unwrap()),
);
{
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
client.do_drop().unwrap()();
mem::forget(client); // drop the client
let client = create_inner(pool.clone(), conn_info.clone());
drop(client);
assert_eq!(3, pool.get_global_connections_count());
}
{
let mut client = Client::new(create_inner(), conn_info.clone(), ep_pool.clone());
client.do_drop().unwrap()();
mem::forget(client); // drop the client
let client = create_inner(pool.clone(), conn_info.clone());
drop(client);
// The client shouldn't be added to the pool. Because the global pool is full.
assert_eq!(3, pool.get_global_connections_count());
}
is_closed.store(true, atomic::Ordering::Relaxed);
cancel.cancel();
// Do gc for all shards.
pool.gc(0);
pool.gc(1);