refactor: make postgres handler stateful (#914)

* feat: update pgwire to 0.8 and unify postgres handler

* fix: correct password message matching
This commit is contained in:
Ning Sun
2023-01-31 14:19:18 +08:00
committed by GitHub
parent b2ad0e972b
commit 39df25a8f6
7 changed files with 598 additions and 544 deletions

847
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -25,6 +25,7 @@ common-runtime = { path = "../common/runtime" }
common-telemetry = { path = "../common/telemetry" }
common-time = { path = "../common/time" }
datatypes = { path = "../datatypes" }
derive_builder = "0.12"
digest = "0.10"
futures = "0.3"
hex = { version = "0.4" }
@@ -37,7 +38,7 @@ num_cpus = "1.13"
once_cell = "1.16"
openmetrics-parser = "0.4"
opensrv-mysql = { git = "https://github.com/datafuselabs/opensrv", rev = "b44c9d1360da297b305abf33aecfa94888e1554c" }
pgwire = "0.6.3"
pgwire = "0.8"
pin-project = "1.0"
prost.workspace = true
query = { path = "../query" }

View File

@@ -23,4 +23,83 @@ pub(crate) const METADATA_CATALOG: &str = "catalog";
/// key to store our parsed schema
pub(crate) const METADATA_SCHEMA: &str = "schema";
use std::collections::HashMap;
use std::sync::Arc;
use derive_builder::Builder;
use pgwire::api::auth::ServerParameterProvider;
use pgwire::api::stmt::NoopQueryParser;
use pgwire::api::store::MemPortalStore;
use pgwire::api::{ClientInfo, MakeHandler};
pub use server::PostgresServer;
use session::context::{QueryContext, QueryContextRef};
use self::auth_handler::PgLoginVerifier;
use crate::auth::UserProviderRef;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
pub(crate) struct GreptimeDBStartupParameters {
version: &'static str,
}
impl GreptimeDBStartupParameters {
fn new() -> GreptimeDBStartupParameters {
GreptimeDBStartupParameters {
version: env!("CARGO_PKG_VERSION"),
}
}
}
impl ServerParameterProvider for GreptimeDBStartupParameters {
fn server_parameters<C>(&self, _client: &C) -> Option<HashMap<String, String>>
where
C: ClientInfo,
{
let mut params = HashMap::with_capacity(4);
params.insert("server_version".to_owned(), self.version.to_owned());
params.insert("server_encoding".to_owned(), "UTF8".to_owned());
params.insert("client_encoding".to_owned(), "UTF8".to_owned());
params.insert("DateStyle".to_owned(), "ISO YMD".to_owned());
Some(params)
}
}
pub struct PostgresServerHandler {
query_handler: ServerSqlQueryHandlerRef,
login_verifier: PgLoginVerifier,
force_tls: bool,
param_provider: Arc<GreptimeDBStartupParameters>,
query_ctx: QueryContextRef,
portal_store: Arc<MemPortalStore<String>>,
query_parser: Arc<NoopQueryParser>,
}
#[derive(Builder)]
pub(crate) struct MakePostgresServerHandler {
query_handler: ServerSqlQueryHandlerRef,
user_provider: Option<UserProviderRef>,
#[builder(default = "Arc::new(GreptimeDBStartupParameters::new())")]
param_provider: Arc<GreptimeDBStartupParameters>,
#[builder(default = "Arc::new(NoopQueryParser::new())")]
query_parser: Arc<NoopQueryParser>,
force_tls: bool,
}
impl MakeHandler for MakePostgresServerHandler {
type Handler = Arc<PostgresServerHandler>;
fn make(&self) -> Self::Handler {
Arc::new(PostgresServerHandler {
query_handler: self.query_handler.clone(),
login_verifier: PgLoginVerifier::new(self.user_provider.clone()),
force_tls: self.force_tls,
param_provider: self.param_provider.clone(),
query_ctx: QueryContext::arc(),
portal_store: Arc::new(MemPortalStore::new()),
query_parser: self.query_parser.clone(),
})
}
}

View File

@@ -12,29 +12,35 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::fmt::Debug;
use async_trait::async_trait;
use futures::{Sink, SinkExt};
use pgwire::api::auth::{ServerParameterProvider, StartupHandler};
use pgwire::api::auth::StartupHandler;
use pgwire::api::{auth, ClientInfo, PgWireConnectionState};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::response::ErrorResponse;
use pgwire::messages::startup::Authentication;
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
use session::context::UserInfo;
use session::context::{QueryContextRef, UserInfo};
use snafu::ResultExt;
use super::PostgresServerHandler;
use crate::auth::{Identity, Password, UserProviderRef};
use crate::error;
use crate::error::Result;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
struct PgLoginVerifier {
pub(crate) struct PgLoginVerifier {
user_provider: Option<UserProviderRef>,
}
impl PgLoginVerifier {
pub(crate) fn new(user_provider: Option<UserProviderRef>) -> Self {
Self { user_provider }
}
}
#[allow(dead_code)]
struct LoginInfo {
user: Option<String>,
@@ -107,61 +113,24 @@ impl PgLoginVerifier {
}
}
struct GreptimeDBStartupParameters {
version: &'static str,
}
impl GreptimeDBStartupParameters {
fn new() -> GreptimeDBStartupParameters {
GreptimeDBStartupParameters {
version: env!("CARGO_PKG_VERSION"),
}
fn set_query_context_from_client_info<C>(client: &C, query_context: QueryContextRef)
where
C: ClientInfo,
{
if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
query_context.set_current_catalog(current_catalog);
}
}
impl ServerParameterProvider for GreptimeDBStartupParameters {
fn server_parameters<C>(&self, _client: &C) -> Option<HashMap<String, String>>
where
C: ClientInfo,
{
let mut params = HashMap::with_capacity(4);
params.insert("server_version".to_owned(), self.version.to_owned());
params.insert("server_encoding".to_owned(), "UTF8".to_owned());
params.insert("client_encoding".to_owned(), "UTF8".to_owned());
params.insert("DateStyle".to_owned(), "ISO YMD".to_owned());
Some(params)
}
}
pub struct PgAuthStartupHandler {
verifier: PgLoginVerifier,
param_provider: GreptimeDBStartupParameters,
force_tls: bool,
query_handler: ServerSqlQueryHandlerRef,
}
impl PgAuthStartupHandler {
pub fn new(
user_provider: Option<UserProviderRef>,
force_tls: bool,
query_handler: ServerSqlQueryHandlerRef,
) -> Self {
PgAuthStartupHandler {
verifier: PgLoginVerifier { user_provider },
param_provider: GreptimeDBStartupParameters::new(),
force_tls,
query_handler,
}
if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
query_context.set_current_schema(current_schema);
}
}
#[async_trait]
impl StartupHandler for PgAuthStartupHandler {
impl StartupHandler for PostgresServerHandler {
async fn on_startup<C>(
&self,
client: &mut C,
message: &PgWireFrontendMessage,
message: PgWireFrontendMessage,
) -> PgWireResult<()>
where
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
@@ -194,7 +163,7 @@ impl StartupHandler for PgAuthStartupHandler {
}
}
if self.verifier.user_provider.is_some() {
if self.login_verifier.user_provider.is_some() {
client.set_state(PgWireConnectionState::AuthenticationInProgress);
client
.send(PgWireBackendMessage::Authentication(
@@ -202,14 +171,22 @@ impl StartupHandler for PgAuthStartupHandler {
))
.await?;
} else {
auth::finish_authentication(client, &self.param_provider).await;
set_query_context_from_client_info(client, self.query_ctx.clone());
auth::finish_authentication(client, self.param_provider.as_ref()).await;
}
}
PgWireFrontendMessage::Password(ref pwd) => {
PgWireFrontendMessage::PasswordMessageFamily(pwd) => {
// the newer version of pgwire has a few variant password
// message like cleartext/md5 password, saslresponse, etc. Here
// we must manually coerce it into password
let pwd = pwd.into_password()?;
let login_info = LoginInfo::from_client_info(client);
// do authenticate
let authenticate_result =
self.verifier.verify_pwd(pwd.password(), &login_info).await;
let authenticate_result = self
.login_verifier
.verify_pwd(pwd.password(), &login_info)
.await;
if !matches!(authenticate_result, Ok(true)) {
return send_error(
client,
@@ -220,7 +197,7 @@ impl StartupHandler for PgAuthStartupHandler {
.await;
}
// do authorize
let authorize_result = self.verifier.authorize(&login_info).await;
let authorize_result = self.login_verifier.authorize(&login_info).await;
if !matches!(authorize_result, Ok(true)) {
return send_error(
client,
@@ -230,7 +207,8 @@ impl StartupHandler for PgAuthStartupHandler {
)
.await;
}
auth::finish_authentication(client, &self.param_provider).await;
set_query_context_from_client_info(client, self.query_ctx.clone());
auth::finish_authentication(client, self.param_provider.as_ref()).await;
}
_ => {}
}

View File

@@ -25,46 +25,24 @@ use futures::{future, stream, Stream, StreamExt};
use pgwire::api::portal::Portal;
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
use pgwire::api::results::{text_query_response, FieldInfo, Response, Tag, TextDataRowEncoder};
use pgwire::api::stmt::NoopQueryParser;
use pgwire::api::store::MemPortalStore;
use pgwire::api::{ClientInfo, Type};
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use session::context::QueryContext;
use super::PostgresServerHandler;
use crate::error::{self, Error, Result};
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
pub struct PostgresServerHandler {
query_handler: ServerSqlQueryHandlerRef,
}
impl PostgresServerHandler {
pub fn new(query_handler: ServerSqlQueryHandlerRef) -> Self {
PostgresServerHandler { query_handler }
}
}
fn query_context_from_client_info<C>(client: &C) -> Arc<QueryContext>
where
C: ClientInfo,
{
let query_context = QueryContext::new();
if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
query_context.set_current_catalog(current_catalog);
}
if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
query_context.set_current_schema(current_schema);
}
Arc::new(query_context)
}
#[async_trait]
impl SimpleQueryHandler for PostgresServerHandler {
async fn do_query<C>(&self, client: &C, query: &str) -> PgWireResult<Vec<Response>>
async fn do_query<C>(&self, _client: &C, query: &str) -> PgWireResult<Vec<Response>>
where
C: ClientInfo + Unpin + Send + Sync,
{
let query_ctx = query_context_from_client_info(client);
let outputs = self.query_handler.do_query(query, query_ctx).await;
let outputs = self
.query_handler
.do_query(query, self.query_ctx.clone())
.await;
let mut results = Vec::with_capacity(outputs.len());
@@ -201,16 +179,43 @@ fn type_translate(origin: &ConcreteDataType) -> Result<Type> {
#[async_trait]
impl ExtendedQueryHandler for PostgresServerHandler {
type Statement = String;
type QueryParser = NoopQueryParser;
type PortalStore = MemPortalStore<Self::Statement>;
fn portal_store(&self) -> Arc<Self::PortalStore> {
self.portal_store.clone()
}
fn query_parser(&self) -> Arc<Self::QueryParser> {
self.query_parser.clone()
}
async fn do_query<C>(
&self,
_client: &mut C,
_portal: &Portal,
_portal: &Portal<Self::Statement>,
_max_rows: usize,
) -> PgWireResult<Response>
where
C: ClientInfo + Unpin + Send + Sync,
{
unimplemented!()
Ok(Response::Error(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
"XX000".to_owned(),
"Extended query is not implemented on this server yet".to_owned(),
))))
}
async fn do_describe<C>(
&self,
_client: &mut C,
_statement: &Self::Statement,
) -> PgWireResult<Vec<FieldInfo>>
where
C: ClientInfo + Unpin + Send + Sync,
{
Ok(vec![])
}
}

View File

@@ -25,18 +25,16 @@ use pgwire::tokio::process_socket;
use tokio;
use tokio_rustls::TlsAcceptor;
use super::{MakePostgresServerHandler, MakePostgresServerHandlerBuilder};
use crate::auth::UserProviderRef;
use crate::error::Result;
use crate::postgres::auth_handler::PgAuthStartupHandler;
use crate::postgres::handler::PostgresServerHandler;
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
use crate::server::{AbortableStream, BaseTcpServer, Server};
use crate::tls::TlsOption;
pub struct PostgresServer {
base_server: BaseTcpServer,
auth_handler: Arc<PgAuthStartupHandler>,
query_handler: Arc<PostgresServerHandler>,
make_handler: Arc<MakePostgresServerHandler>,
tls: TlsOption,
}
@@ -48,16 +46,17 @@ impl PostgresServer {
io_runtime: Arc<Runtime>,
user_provider: Option<UserProviderRef>,
) -> PostgresServer {
let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler.clone()));
let startup_handler = Arc::new(PgAuthStartupHandler::new(
user_provider,
tls.should_force_tls(),
query_handler,
));
let make_handler = Arc::new(
MakePostgresServerHandlerBuilder::default()
.query_handler(query_handler.clone())
.user_provider(user_provider.clone())
.force_tls(tls.should_force_tls())
.build()
.unwrap(),
);
PostgresServer {
base_server: BaseTcpServer::create_server("Postgres", io_runtime),
auth_handler: startup_handler,
query_handler: postgres_handler,
make_handler,
tls,
}
}
@@ -68,14 +67,11 @@ impl PostgresServer {
accepting_stream: AbortableStream,
tls_acceptor: Option<Arc<TlsAcceptor>>,
) -> impl Future<Output = ()> {
let auth_handler = self.auth_handler.clone();
let query_handler = self.query_handler.clone();
let handler = self.make_handler.clone();
accepting_stream.for_each(move |tcp_stream| {
let io_runtime = io_runtime.clone();
let auth_handler = auth_handler.clone();
let query_handler = query_handler.clone();
let tls_acceptor = tls_acceptor.clone();
let handler = handler.clone();
async move {
match tcp_stream {
@@ -89,9 +85,9 @@ impl PostgresServer {
io_runtime.spawn(process_socket(
io_stream,
tls_acceptor.clone(),
auth_handler.clone(),
query_handler.clone(),
query_handler.clone(),
handler.clone(),
handler.clone(),
handler,
));
}
};

View File

@@ -89,13 +89,13 @@ async fn test_shutdown_pg_server_range() -> Result<()> {
Ok(())
}
#[tokio::test]
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_schema_validating() -> Result<()> {
async fn generate_server(auth_info: DatabaseAuthInfo<'_>) -> Result<(Box<dyn Server>, u16)> {
let table = MemTable::default_numbers_table();
let postgres_server =
create_postgres_server(table, true, Default::default(), Some(auth_info))?;
let listening = "127.0.0.1:5432".parse::<SocketAddr>().unwrap();
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = postgres_server.start(listening).await.unwrap();
let server_port = server_addr.port();
Ok((postgres_server, server_port))
@@ -103,8 +103,8 @@ async fn test_schema_validating() -> Result<()> {
common_telemetry::init_default_ut_logging();
let (pg_server, server_port) = generate_server(DatabaseAuthInfo {
catalog: "greptime",
schema: "public",
catalog: DEFAULT_CATALOG_NAME,
schema: DEFAULT_SCHEMA_NAME,
username: "greptime",
})
.await?;
@@ -115,8 +115,8 @@ async fn test_schema_validating() -> Result<()> {
assert!(result.is_ok());
let (pg_server, server_port) = generate_server(DatabaseAuthInfo {
catalog: "greptime",
schema: "public",
catalog: DEFAULT_CATALOG_NAME,
schema: DEFAULT_SCHEMA_NAME,
username: "no_right_user",
})
.await?;
@@ -141,7 +141,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
.to_string()
.contains("Postgres server is not started."));
let listening = "127.0.0.1:5432".parse::<SocketAddr>().unwrap();
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = postgres_server.start(listening).await.unwrap();
let server_port = server_addr.port();