unfreeze client session upon callback

This commit is contained in:
Stas Kelvich
2021-06-25 13:58:37 +03:00
parent 605b90c6c7
commit 1b6d99db7c
7 changed files with 155 additions and 60 deletions

2
Cargo.lock generated
View File

@@ -1355,6 +1355,8 @@ dependencies = [
"hex",
"md5",
"rand",
"serde",
"serde_json",
"tokio",
"tokio-postgres 0.7.2",
"zenith_utils",

View File

@@ -12,7 +12,8 @@ bytes = { version = "1.0.1", features = ['serde'] }
md5 = "0.7.0"
rand = "0.8.3"
hex = "0.4.3"
serde = "1"
serde_json = "1"
tokio = "1.7.1"
tokio-postgres = "0.7.2"

View File

@@ -1,10 +1,17 @@
use anyhow::{bail, Result};
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, net::SocketAddr};
pub struct CPlaneApi {
address: SocketAddr,
}
#[derive(Serialize, Deserialize)]
pub struct DatabaseInfo {
pub addr: SocketAddr,
pub connstr: String,
}
// mock cplane api
impl CPlaneApi {
pub fn new(address: &SocketAddr) -> CPlaneApi {
@@ -44,11 +51,17 @@ impl CPlaneApi {
}
}
pub fn get_database_uri(&self, _user: &String, _database: &String) -> Result<String> {
Ok("user=stas dbname=stas".to_string())
pub fn get_database_uri(&self, _user: &String, _database: &String) -> Result<DatabaseInfo> {
Ok(DatabaseInfo {
addr: "127.0.0.1:5432".parse()?,
connstr: "user=stas dbname=stas".into(),
})
}
pub fn create_database(&self, _user: &String, _database: &String) -> Result<String> {
Ok("user=stas dbname=stas".to_string())
pub fn create_database(&self, _user: &String, _database: &String) -> Result<DatabaseInfo> {
Ok(DatabaseInfo {
addr: "127.0.0.1:5432".parse()?,
connstr: "user=stas dbname=stas".into(),
})
}
}

View File

@@ -6,10 +6,14 @@
/// in somewhat transparent manner (again via communication with control plane API).
///
use std::{
collections::HashMap,
net::{SocketAddr, TcpListener},
sync::{mpsc, Mutex},
thread,
};
use cplane_api::DatabaseInfo;
mod cplane_api;
mod mgmt;
mod proxy;
@@ -26,20 +30,29 @@ pub struct ProxyConf {
pub cplane_address: SocketAddr,
}
pub struct ProxyState {
pub conf: ProxyConf,
pub waiters: Mutex<HashMap<String, mpsc::Sender<anyhow::Result<DatabaseInfo>>>>,
}
fn main() -> anyhow::Result<()> {
let conf = ProxyConf {
proxy_address: "0.0.0.0:4000".parse()?,
mgmt_address: "0.0.0.0:8080".parse()?,
cplane_address: "127.0.0.1:3000".parse()?,
};
let conf: &'static ProxyConf = Box::leak(Box::new(conf));
let state = ProxyState {
conf,
waiters: Mutex::new(HashMap::new()),
};
let state: &'static ProxyState = Box::leak(Box::new(state));
// Check that we can bind to address before further initialization
println!("Starting proxy on {}", conf.proxy_address);
let pageserver_listener = TcpListener::bind(conf.proxy_address)?;
println!("Starting proxy on {}", state.conf.proxy_address);
let pageserver_listener = TcpListener::bind(state.conf.proxy_address)?;
println!("Starting mgmt on {}", conf.mgmt_address);
let mgmt_listener = TcpListener::bind(conf.mgmt_address)?;
println!("Starting mgmt on {}", state.conf.mgmt_address);
let mgmt_listener = TcpListener::bind(state.conf.mgmt_address)?;
let mut threads = Vec::new();
@@ -48,13 +61,13 @@ fn main() -> anyhow::Result<()> {
threads.push(
thread::Builder::new()
.name("Proxy thread".into())
.spawn(move || proxy::thread_main(&conf, pageserver_listener))?,
.spawn(move || proxy::thread_main(&state, pageserver_listener))?,
);
threads.push(
thread::Builder::new()
.name("Mgmt thread".into())
.spawn(move || mgmt::thread_main(&conf, mgmt_listener))?,
.spawn(move || mgmt::thread_main(&state, mgmt_listener))?,
);
for t in threads {

View File

@@ -3,41 +3,70 @@ use std::{
thread,
};
use anyhow::bail;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use zenith_utils::{
postgres_backend::{self, PostgresBackend},
pq_proto::{BeMessage, SINGLE_COL_ROWDESC},
};
use crate::ProxyConf;
use crate::{cplane_api::DatabaseInfo, ProxyState};
///
/// Main proxy listener loop.
///
/// Listens for connections, and launches a new handler thread for each.
///
pub fn thread_main(conf: &'static ProxyConf, listener: TcpListener) -> anyhow::Result<()> {
pub fn thread_main(state: &'static ProxyState, listener: TcpListener) -> anyhow::Result<()> {
loop {
let (socket, peer_addr) = listener.accept()?;
println!("accepted connection from {}", peer_addr);
socket.set_nodelay(true).unwrap();
thread::spawn(move || {
if let Err(err) = mgmt_conn_main(conf, socket) {
if let Err(err) = mgmt_conn_main(state, socket) {
println!("error: {}", err);
}
});
}
}
pub fn mgmt_conn_main(conf: &'static ProxyConf, socket: TcpStream) -> anyhow::Result<()> {
let mut conn_handler = MgmtHandler { conf };
pub fn mgmt_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow::Result<()> {
let mut conn_handler = MgmtHandler { state };
let mut pgbackend = PostgresBackend::new(socket, postgres_backend::AuthType::Trust)?;
pgbackend.run(&mut conn_handler)
}
struct MgmtHandler {
conf: &'static ProxyConf,
state: &'static ProxyState,
}
/// Serialized examples:
// {
// "session_id": "71d6d03e6d93d99a",
// "result": {
// "Success": {
// "addr": "127.0.0.1:5432",
// "connstr": "user=stas dbname=stas"
// }
// }
// }
// {
// "session_id": "71d6d03e6d93d99a",
// "result": {
// "Failure": "oops"
// }
// }
#[derive(Serialize, Deserialize)]
pub struct PsqlSessionResponse {
session_id: String,
result: PsqlSessionResult,
}
#[derive(Serialize, Deserialize)]
pub enum PsqlSessionResult {
Success(DatabaseInfo),
Failure(String),
}
impl postgres_backend::Handler for MgmtHandler {
@@ -46,13 +75,40 @@ impl postgres_backend::Handler for MgmtHandler {
pgb: &mut PostgresBackend,
query_string: Bytes,
) -> anyhow::Result<()> {
let (_, query_string) = query_string
.split_last()
.ok_or_else(|| anyhow::Error::msg("protocol violation"))?;
let (_, query_string) = query_string
.split_last()
.ok_or_else(|| anyhow::Error::msg("protocol violation"))?;
println!("Got mgmt query: {:?}", query_string);
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(b"mgmt_ok")]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
let resp: PsqlSessionResponse = serde_json::from_slice(&query_string)?;
pgb.flush()?;
Ok(())
let waiters = self.state.waiters.lock().unwrap();
let sender = waiters
.get(&resp.session_id)
.ok_or_else(|| anyhow::Error::msg("psql session_id is not found"))?;
match resp.result {
PsqlSessionResult::Success(db_info) => {
sender.send(Ok(db_info))?;
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
pgb.flush()?;
Ok(())
}
PsqlSessionResult::Failure(message) => {
sender.send(Err(anyhow::Error::msg(message.clone())))?;
bail!("psql session request failed: {}", message)
}
}
}
}

View File

@@ -1,9 +1,12 @@
use crate::{cplane_api::CPlaneApi, ProxyConf};
use crate::cplane_api::CPlaneApi;
use crate::cplane_api::DatabaseInfo;
use crate::ProxyState;
use anyhow::bail;
use tokio_postgres::NoTls;
use rand::Rng;
use std::sync::mpsc::channel;
use std::thread;
use tokio::io::AsyncWriteExt;
use zenith_utils::postgres_backend::{PostgresBackend, ProtoState};
@@ -16,7 +19,7 @@ use zenith_utils::{postgres_backend, pq_proto::BeMessage};
/// Listens for connections, and launches a new handler thread for each.
///
pub fn thread_main(
conf: &'static ProxyConf,
state: &'static ProxyState,
listener: std::net::TcpListener,
) -> anyhow::Result<()> {
loop {
@@ -25,15 +28,17 @@ pub fn thread_main(
socket.set_nodelay(true).unwrap();
thread::spawn(move || {
if let Err(err) = proxy_conn_main(conf, socket) {
if let Err(err) = proxy_conn_main(state, socket) {
println!("error: {}", err);
}
});
}
}
// XXX: clean up fields
struct ProxyConnection {
conf: &'static ProxyConf,
state: &'static ProxyState,
existing_user: bool,
cplane: CPlaneApi,
@@ -42,20 +47,23 @@ struct ProxyConnection {
pgb: PostgresBackend,
md5_salt: [u8; 4],
psql_session_id: String,
}
pub fn proxy_conn_main(
conf: &'static ProxyConf,
state: &'static ProxyState,
socket: std::net::TcpStream,
) -> anyhow::Result<()> {
let mut conn = ProxyConnection {
conf,
state,
existing_user: false,
cplane: CPlaneApi::new(&conf.cplane_address),
cplane: CPlaneApi::new(&state.conf.cplane_address),
user: "".into(),
database: "".into(),
pgb: PostgresBackend::new(socket, postgres_backend::AuthType::MD5)?,
md5_salt: [0u8; 4],
psql_session_id: "".into(),
};
// Check StartupMessage
@@ -63,7 +71,7 @@ pub fn proxy_conn_main(
conn.handle_startup()?;
// both scenarious here should end up producing database connection string
let database_uri = if conn.existing_user {
let db_info = if conn.existing_user {
conn.handle_existing_user()?
} else {
conn.handle_new_user()?
@@ -75,11 +83,7 @@ pub fn proxy_conn_main(
.build()
.unwrap();
let _ = runtime.block_on(proxy_pass(
conn.pgb,
"127.0.0.1:5432".to_string(),
database_uri,
))?;
let _ = runtime.block_on(proxy_pass(conn.pgb, db_info))?;
println!("proxy_conn_main done;");
@@ -134,12 +138,12 @@ impl ProxyConnection {
Ok(())
}
fn handle_existing_user(&mut self) -> anyhow::Result<String> {
fn handle_existing_user(&mut self) -> anyhow::Result<DatabaseInfo> {
// ask password
rand::thread_rng().fill(&mut self.md5_salt);
self.pgb
.write_message(&BeMessage::AuthenticationMD5Password(&self.md5_salt))?;
self.pgb.state = ProtoState::Authentication;
self.pgb.state = ProtoState::Authentication; // XXX
// check password
println!("handle_existing_user");
@@ -171,21 +175,21 @@ impl ProxyConnection {
self.cplane.get_database_uri(&self.user, &self.database)
}
fn handle_new_user(&mut self) -> anyhow::Result<String> {
let mut reg_id_buf = [0u8; 8];
rand::thread_rng().fill(&mut reg_id_buf);
let reg_id = hex::encode(reg_id_buf);
fn handle_new_user(&mut self) -> anyhow::Result<DatabaseInfo> {
let mut psql_session_id_buf = [0u8; 8];
rand::thread_rng().fill(&mut psql_session_id_buf);
self.psql_session_id = hex::encode(psql_session_id_buf);
let hello_message = format!("☀️ Welcome to Zenith!
To proceed with database creation open following link:
https://console.zenith.tech/claim_db/{}
https://console.zenith.tech/psql_session/{}
It needed to be done once and we will send you '.pgpass' file which will allow you to access or create
databases without opening the browser.
", reg_id);
", self.psql_session_id);
self.pgb
.write_message_noflush(&BeMessage::AuthenticationOk)?;
@@ -195,10 +199,24 @@ databases without opening the browser.
.write_message(&BeMessage::NoticeResponse(hello_message.to_string()))?;
// await for database creation
let connstring = self.cplane.get_database_uri(&self.user, &self.database)?;
let (tx, rx) = channel::<anyhow::Result<DatabaseInfo>>();
let _ = self
.state
.waiters
.lock()
.unwrap()
.insert(self.psql_session_id.clone(), tx);
// Wait for web console response
// XXX: respond with error to client
let dbinfo = rx.recv()??;
self.pgb.write_message_noflush(&BeMessage::NoticeResponse(
"Connecting to database.".to_string(),
))?;
self.pgb.write_message(&BeMessage::ReadyForQuery)?;
Ok(connstring)
Ok(dbinfo)
}
fn check_auth_md5(&self, md5_response: &[u8]) -> anyhow::Result<()> {
@@ -208,13 +226,9 @@ databases without opening the browser.
}
}
async fn proxy_pass(
pgb: PostgresBackend,
proxy_addr: String,
connstr: String,
) -> anyhow::Result<()> {
let mut socket = tokio::net::TcpStream::connect(proxy_addr).await?;
let config = connstr.parse::<tokio_postgres::Config>()?;
async fn proxy_pass(pgb: PostgresBackend, db_info: DatabaseInfo) -> anyhow::Result<()> {
let mut socket = tokio::net::TcpStream::connect(db_info.addr).await?;
let config = db_info.connstr.parse::<tokio_postgres::Config>()?;
let _ = config.connect_raw(&mut socket, NoTls).await?;
println!("Connected to pg, proxying");

View File

@@ -207,13 +207,9 @@ impl PostgresBackend {
assert!(self.state == ProtoState::Authentication);
let (trailing_null, md5_response) = m.split_last().unwrap();
if *trailing_null != 0 {
let errmsg = "protocol violation";
self.write_message(&BeMessage::ErrorResponse(format!("{}", errmsg)))?;
bail!("auth failed: {}", errmsg);
}
let (_, md5_response) = m
.split_last()
.ok_or_else(|| anyhow::Error::msg("protocol violation"))?;
if let Err(e) = handler.check_auth_md5(self, md5_response) {
self.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?;