mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-07 13:32:57 +00:00
unfreeze client session upon callback
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -1355,6 +1355,8 @@ dependencies = [
|
||||
"hex",
|
||||
"md5",
|
||||
"rand",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"tokio-postgres 0.7.2",
|
||||
"zenith_utils",
|
||||
|
||||
@@ -12,7 +12,8 @@ bytes = { version = "1.0.1", features = ['serde'] }
|
||||
md5 = "0.7.0"
|
||||
rand = "0.8.3"
|
||||
hex = "0.4.3"
|
||||
|
||||
serde = "1"
|
||||
serde_json = "1"
|
||||
tokio = "1.7.1"
|
||||
tokio-postgres = "0.7.2"
|
||||
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
use anyhow::{bail, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{collections::HashMap, net::SocketAddr};
|
||||
|
||||
pub struct CPlaneApi {
|
||||
address: SocketAddr,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct DatabaseInfo {
|
||||
pub addr: SocketAddr,
|
||||
pub connstr: String,
|
||||
}
|
||||
|
||||
// mock cplane api
|
||||
impl CPlaneApi {
|
||||
pub fn new(address: &SocketAddr) -> CPlaneApi {
|
||||
@@ -44,11 +51,17 @@ impl CPlaneApi {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_database_uri(&self, _user: &String, _database: &String) -> Result<String> {
|
||||
Ok("user=stas dbname=stas".to_string())
|
||||
pub fn get_database_uri(&self, _user: &String, _database: &String) -> Result<DatabaseInfo> {
|
||||
Ok(DatabaseInfo {
|
||||
addr: "127.0.0.1:5432".parse()?,
|
||||
connstr: "user=stas dbname=stas".into(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn create_database(&self, _user: &String, _database: &String) -> Result<String> {
|
||||
Ok("user=stas dbname=stas".to_string())
|
||||
pub fn create_database(&self, _user: &String, _database: &String) -> Result<DatabaseInfo> {
|
||||
Ok(DatabaseInfo {
|
||||
addr: "127.0.0.1:5432".parse()?,
|
||||
connstr: "user=stas dbname=stas".into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,10 +6,14 @@
|
||||
/// in somewhat transparent manner (again via communication with control plane API).
|
||||
///
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
net::{SocketAddr, TcpListener},
|
||||
sync::{mpsc, Mutex},
|
||||
thread,
|
||||
};
|
||||
|
||||
use cplane_api::DatabaseInfo;
|
||||
|
||||
mod cplane_api;
|
||||
mod mgmt;
|
||||
mod proxy;
|
||||
@@ -26,20 +30,29 @@ pub struct ProxyConf {
|
||||
pub cplane_address: SocketAddr,
|
||||
}
|
||||
|
||||
pub struct ProxyState {
|
||||
pub conf: ProxyConf,
|
||||
pub waiters: Mutex<HashMap<String, mpsc::Sender<anyhow::Result<DatabaseInfo>>>>,
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let conf = ProxyConf {
|
||||
proxy_address: "0.0.0.0:4000".parse()?,
|
||||
mgmt_address: "0.0.0.0:8080".parse()?,
|
||||
cplane_address: "127.0.0.1:3000".parse()?,
|
||||
};
|
||||
let conf: &'static ProxyConf = Box::leak(Box::new(conf));
|
||||
let state = ProxyState {
|
||||
conf,
|
||||
waiters: Mutex::new(HashMap::new()),
|
||||
};
|
||||
let state: &'static ProxyState = Box::leak(Box::new(state));
|
||||
|
||||
// Check that we can bind to address before further initialization
|
||||
println!("Starting proxy on {}", conf.proxy_address);
|
||||
let pageserver_listener = TcpListener::bind(conf.proxy_address)?;
|
||||
println!("Starting proxy on {}", state.conf.proxy_address);
|
||||
let pageserver_listener = TcpListener::bind(state.conf.proxy_address)?;
|
||||
|
||||
println!("Starting mgmt on {}", conf.mgmt_address);
|
||||
let mgmt_listener = TcpListener::bind(conf.mgmt_address)?;
|
||||
println!("Starting mgmt on {}", state.conf.mgmt_address);
|
||||
let mgmt_listener = TcpListener::bind(state.conf.mgmt_address)?;
|
||||
|
||||
let mut threads = Vec::new();
|
||||
|
||||
@@ -48,13 +61,13 @@ fn main() -> anyhow::Result<()> {
|
||||
threads.push(
|
||||
thread::Builder::new()
|
||||
.name("Proxy thread".into())
|
||||
.spawn(move || proxy::thread_main(&conf, pageserver_listener))?,
|
||||
.spawn(move || proxy::thread_main(&state, pageserver_listener))?,
|
||||
);
|
||||
|
||||
threads.push(
|
||||
thread::Builder::new()
|
||||
.name("Mgmt thread".into())
|
||||
.spawn(move || mgmt::thread_main(&conf, mgmt_listener))?,
|
||||
.spawn(move || mgmt::thread_main(&state, mgmt_listener))?,
|
||||
);
|
||||
|
||||
for t in threads {
|
||||
|
||||
@@ -3,41 +3,70 @@ use std::{
|
||||
thread,
|
||||
};
|
||||
|
||||
use anyhow::bail;
|
||||
use bytes::Bytes;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use zenith_utils::{
|
||||
postgres_backend::{self, PostgresBackend},
|
||||
pq_proto::{BeMessage, SINGLE_COL_ROWDESC},
|
||||
};
|
||||
|
||||
use crate::ProxyConf;
|
||||
use crate::{cplane_api::DatabaseInfo, ProxyState};
|
||||
|
||||
///
|
||||
/// Main proxy listener loop.
|
||||
///
|
||||
/// Listens for connections, and launches a new handler thread for each.
|
||||
///
|
||||
pub fn thread_main(conf: &'static ProxyConf, listener: TcpListener) -> anyhow::Result<()> {
|
||||
pub fn thread_main(state: &'static ProxyState, listener: TcpListener) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let (socket, peer_addr) = listener.accept()?;
|
||||
println!("accepted connection from {}", peer_addr);
|
||||
socket.set_nodelay(true).unwrap();
|
||||
|
||||
thread::spawn(move || {
|
||||
if let Err(err) = mgmt_conn_main(conf, socket) {
|
||||
if let Err(err) = mgmt_conn_main(state, socket) {
|
||||
println!("error: {}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mgmt_conn_main(conf: &'static ProxyConf, socket: TcpStream) -> anyhow::Result<()> {
|
||||
let mut conn_handler = MgmtHandler { conf };
|
||||
pub fn mgmt_conn_main(state: &'static ProxyState, socket: TcpStream) -> anyhow::Result<()> {
|
||||
let mut conn_handler = MgmtHandler { state };
|
||||
let mut pgbackend = PostgresBackend::new(socket, postgres_backend::AuthType::Trust)?;
|
||||
pgbackend.run(&mut conn_handler)
|
||||
}
|
||||
|
||||
struct MgmtHandler {
|
||||
conf: &'static ProxyConf,
|
||||
state: &'static ProxyState,
|
||||
}
|
||||
/// Serialized examples:
|
||||
// {
|
||||
// "session_id": "71d6d03e6d93d99a",
|
||||
// "result": {
|
||||
// "Success": {
|
||||
// "addr": "127.0.0.1:5432",
|
||||
// "connstr": "user=stas dbname=stas"
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// {
|
||||
// "session_id": "71d6d03e6d93d99a",
|
||||
// "result": {
|
||||
// "Failure": "oops"
|
||||
// }
|
||||
// }
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct PsqlSessionResponse {
|
||||
session_id: String,
|
||||
result: PsqlSessionResult,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub enum PsqlSessionResult {
|
||||
Success(DatabaseInfo),
|
||||
Failure(String),
|
||||
}
|
||||
|
||||
impl postgres_backend::Handler for MgmtHandler {
|
||||
@@ -46,13 +75,40 @@ impl postgres_backend::Handler for MgmtHandler {
|
||||
pgb: &mut PostgresBackend,
|
||||
query_string: Bytes,
|
||||
) -> anyhow::Result<()> {
|
||||
let (_, query_string) = query_string
|
||||
.split_last()
|
||||
.ok_or_else(|| anyhow::Error::msg("protocol violation"))?;
|
||||
|
||||
let (_, query_string) = query_string
|
||||
.split_last()
|
||||
.ok_or_else(|| anyhow::Error::msg("protocol violation"))?;
|
||||
|
||||
println!("Got mgmt query: {:?}", query_string);
|
||||
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(b"mgmt_ok")]))?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
let resp: PsqlSessionResponse = serde_json::from_slice(&query_string)?;
|
||||
|
||||
pgb.flush()?;
|
||||
Ok(())
|
||||
let waiters = self.state.waiters.lock().unwrap();
|
||||
|
||||
let sender = waiters
|
||||
.get(&resp.session_id)
|
||||
.ok_or_else(|| anyhow::Error::msg("psql session_id is not found"))?;
|
||||
|
||||
match resp.result {
|
||||
PsqlSessionResult::Success(db_info) => {
|
||||
sender.send(Ok(db_info))?;
|
||||
|
||||
pgb.write_message_noflush(&SINGLE_COL_ROWDESC)?
|
||||
.write_message_noflush(&BeMessage::DataRow(&[Some(b"ok")]))?
|
||||
.write_message_noflush(&BeMessage::CommandComplete(b"SELECT 1"))?;
|
||||
pgb.flush()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
PsqlSessionResult::Failure(message) => {
|
||||
sender.send(Err(anyhow::Error::msg(message.clone())))?;
|
||||
|
||||
bail!("psql session request failed: {}", message)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
use crate::{cplane_api::CPlaneApi, ProxyConf};
|
||||
use crate::cplane_api::CPlaneApi;
|
||||
use crate::cplane_api::DatabaseInfo;
|
||||
use crate::ProxyState;
|
||||
|
||||
use anyhow::bail;
|
||||
use tokio_postgres::NoTls;
|
||||
|
||||
use rand::Rng;
|
||||
use std::sync::mpsc::channel;
|
||||
use std::thread;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use zenith_utils::postgres_backend::{PostgresBackend, ProtoState};
|
||||
@@ -16,7 +19,7 @@ use zenith_utils::{postgres_backend, pq_proto::BeMessage};
|
||||
/// Listens for connections, and launches a new handler thread for each.
|
||||
///
|
||||
pub fn thread_main(
|
||||
conf: &'static ProxyConf,
|
||||
state: &'static ProxyState,
|
||||
listener: std::net::TcpListener,
|
||||
) -> anyhow::Result<()> {
|
||||
loop {
|
||||
@@ -25,15 +28,17 @@ pub fn thread_main(
|
||||
socket.set_nodelay(true).unwrap();
|
||||
|
||||
thread::spawn(move || {
|
||||
if let Err(err) = proxy_conn_main(conf, socket) {
|
||||
if let Err(err) = proxy_conn_main(state, socket) {
|
||||
println!("error: {}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// XXX: clean up fields
|
||||
struct ProxyConnection {
|
||||
conf: &'static ProxyConf,
|
||||
state: &'static ProxyState,
|
||||
|
||||
existing_user: bool,
|
||||
cplane: CPlaneApi,
|
||||
|
||||
@@ -42,20 +47,23 @@ struct ProxyConnection {
|
||||
|
||||
pgb: PostgresBackend,
|
||||
md5_salt: [u8; 4],
|
||||
|
||||
psql_session_id: String,
|
||||
}
|
||||
|
||||
pub fn proxy_conn_main(
|
||||
conf: &'static ProxyConf,
|
||||
state: &'static ProxyState,
|
||||
socket: std::net::TcpStream,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut conn = ProxyConnection {
|
||||
conf,
|
||||
state,
|
||||
existing_user: false,
|
||||
cplane: CPlaneApi::new(&conf.cplane_address),
|
||||
cplane: CPlaneApi::new(&state.conf.cplane_address),
|
||||
user: "".into(),
|
||||
database: "".into(),
|
||||
pgb: PostgresBackend::new(socket, postgres_backend::AuthType::MD5)?,
|
||||
md5_salt: [0u8; 4],
|
||||
psql_session_id: "".into(),
|
||||
};
|
||||
|
||||
// Check StartupMessage
|
||||
@@ -63,7 +71,7 @@ pub fn proxy_conn_main(
|
||||
conn.handle_startup()?;
|
||||
|
||||
// both scenarious here should end up producing database connection string
|
||||
let database_uri = if conn.existing_user {
|
||||
let db_info = if conn.existing_user {
|
||||
conn.handle_existing_user()?
|
||||
} else {
|
||||
conn.handle_new_user()?
|
||||
@@ -75,11 +83,7 @@ pub fn proxy_conn_main(
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let _ = runtime.block_on(proxy_pass(
|
||||
conn.pgb,
|
||||
"127.0.0.1:5432".to_string(),
|
||||
database_uri,
|
||||
))?;
|
||||
let _ = runtime.block_on(proxy_pass(conn.pgb, db_info))?;
|
||||
|
||||
println!("proxy_conn_main done;");
|
||||
|
||||
@@ -134,12 +138,12 @@ impl ProxyConnection {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_existing_user(&mut self) -> anyhow::Result<String> {
|
||||
fn handle_existing_user(&mut self) -> anyhow::Result<DatabaseInfo> {
|
||||
// ask password
|
||||
rand::thread_rng().fill(&mut self.md5_salt);
|
||||
self.pgb
|
||||
.write_message(&BeMessage::AuthenticationMD5Password(&self.md5_salt))?;
|
||||
self.pgb.state = ProtoState::Authentication;
|
||||
self.pgb.state = ProtoState::Authentication; // XXX
|
||||
|
||||
// check password
|
||||
println!("handle_existing_user");
|
||||
@@ -171,21 +175,21 @@ impl ProxyConnection {
|
||||
self.cplane.get_database_uri(&self.user, &self.database)
|
||||
}
|
||||
|
||||
fn handle_new_user(&mut self) -> anyhow::Result<String> {
|
||||
let mut reg_id_buf = [0u8; 8];
|
||||
rand::thread_rng().fill(&mut reg_id_buf);
|
||||
let reg_id = hex::encode(reg_id_buf);
|
||||
fn handle_new_user(&mut self) -> anyhow::Result<DatabaseInfo> {
|
||||
let mut psql_session_id_buf = [0u8; 8];
|
||||
rand::thread_rng().fill(&mut psql_session_id_buf);
|
||||
self.psql_session_id = hex::encode(psql_session_id_buf);
|
||||
|
||||
let hello_message = format!("☀️ Welcome to Zenith!
|
||||
|
||||
To proceed with database creation open following link:
|
||||
|
||||
https://console.zenith.tech/claim_db/{}
|
||||
https://console.zenith.tech/psql_session/{}
|
||||
|
||||
It needed to be done once and we will send you '.pgpass' file which will allow you to access or create
|
||||
databases without opening the browser.
|
||||
|
||||
", reg_id);
|
||||
", self.psql_session_id);
|
||||
|
||||
self.pgb
|
||||
.write_message_noflush(&BeMessage::AuthenticationOk)?;
|
||||
@@ -195,10 +199,24 @@ databases without opening the browser.
|
||||
.write_message(&BeMessage::NoticeResponse(hello_message.to_string()))?;
|
||||
|
||||
// await for database creation
|
||||
let connstring = self.cplane.get_database_uri(&self.user, &self.database)?;
|
||||
let (tx, rx) = channel::<anyhow::Result<DatabaseInfo>>();
|
||||
let _ = self
|
||||
.state
|
||||
.waiters
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(self.psql_session_id.clone(), tx);
|
||||
|
||||
// Wait for web console response
|
||||
// XXX: respond with error to client
|
||||
let dbinfo = rx.recv()??;
|
||||
|
||||
self.pgb.write_message_noflush(&BeMessage::NoticeResponse(
|
||||
"Connecting to database.".to_string(),
|
||||
))?;
|
||||
self.pgb.write_message(&BeMessage::ReadyForQuery)?;
|
||||
|
||||
Ok(connstring)
|
||||
Ok(dbinfo)
|
||||
}
|
||||
|
||||
fn check_auth_md5(&self, md5_response: &[u8]) -> anyhow::Result<()> {
|
||||
@@ -208,13 +226,9 @@ databases without opening the browser.
|
||||
}
|
||||
}
|
||||
|
||||
async fn proxy_pass(
|
||||
pgb: PostgresBackend,
|
||||
proxy_addr: String,
|
||||
connstr: String,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut socket = tokio::net::TcpStream::connect(proxy_addr).await?;
|
||||
let config = connstr.parse::<tokio_postgres::Config>()?;
|
||||
async fn proxy_pass(pgb: PostgresBackend, db_info: DatabaseInfo) -> anyhow::Result<()> {
|
||||
let mut socket = tokio::net::TcpStream::connect(db_info.addr).await?;
|
||||
let config = db_info.connstr.parse::<tokio_postgres::Config>()?;
|
||||
let _ = config.connect_raw(&mut socket, NoTls).await?;
|
||||
|
||||
println!("Connected to pg, proxying");
|
||||
|
||||
@@ -207,13 +207,9 @@ impl PostgresBackend {
|
||||
|
||||
assert!(self.state == ProtoState::Authentication);
|
||||
|
||||
let (trailing_null, md5_response) = m.split_last().unwrap();
|
||||
|
||||
if *trailing_null != 0 {
|
||||
let errmsg = "protocol violation";
|
||||
self.write_message(&BeMessage::ErrorResponse(format!("{}", errmsg)))?;
|
||||
bail!("auth failed: {}", errmsg);
|
||||
}
|
||||
let (_, md5_response) = m
|
||||
.split_last()
|
||||
.ok_or_else(|| anyhow::Error::msg("protocol violation"))?;
|
||||
|
||||
if let Err(e) = handler.check_auth_md5(self, md5_response) {
|
||||
self.write_message(&BeMessage::ErrorResponse(format!("{}", e)))?;
|
||||
|
||||
Reference in New Issue
Block a user