mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-20 06:50:37 +00:00
refactor: make postgres handler stateful (#914)
* feat: update pgwire to 0.8 and unify postgres handler * fix: correct password message matching
This commit is contained in:
847
Cargo.lock
generated
847
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -25,6 +25,7 @@ common-runtime = { path = "../common/runtime" }
|
||||
common-telemetry = { path = "../common/telemetry" }
|
||||
common-time = { path = "../common/time" }
|
||||
datatypes = { path = "../datatypes" }
|
||||
derive_builder = "0.12"
|
||||
digest = "0.10"
|
||||
futures = "0.3"
|
||||
hex = { version = "0.4" }
|
||||
@@ -37,7 +38,7 @@ num_cpus = "1.13"
|
||||
once_cell = "1.16"
|
||||
openmetrics-parser = "0.4"
|
||||
opensrv-mysql = { git = "https://github.com/datafuselabs/opensrv", rev = "b44c9d1360da297b305abf33aecfa94888e1554c" }
|
||||
pgwire = "0.6.3"
|
||||
pgwire = "0.8"
|
||||
pin-project = "1.0"
|
||||
prost.workspace = true
|
||||
query = { path = "../query" }
|
||||
|
||||
@@ -23,4 +23,83 @@ pub(crate) const METADATA_CATALOG: &str = "catalog";
|
||||
/// key to store our parsed schema
|
||||
pub(crate) const METADATA_SCHEMA: &str = "schema";
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use derive_builder::Builder;
|
||||
use pgwire::api::auth::ServerParameterProvider;
|
||||
use pgwire::api::stmt::NoopQueryParser;
|
||||
use pgwire::api::store::MemPortalStore;
|
||||
use pgwire::api::{ClientInfo, MakeHandler};
|
||||
pub use server::PostgresServer;
|
||||
use session::context::{QueryContext, QueryContextRef};
|
||||
|
||||
use self::auth_handler::PgLoginVerifier;
|
||||
use crate::auth::UserProviderRef;
|
||||
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
|
||||
pub(crate) struct GreptimeDBStartupParameters {
|
||||
version: &'static str,
|
||||
}
|
||||
|
||||
impl GreptimeDBStartupParameters {
|
||||
fn new() -> GreptimeDBStartupParameters {
|
||||
GreptimeDBStartupParameters {
|
||||
version: env!("CARGO_PKG_VERSION"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerParameterProvider for GreptimeDBStartupParameters {
|
||||
fn server_parameters<C>(&self, _client: &C) -> Option<HashMap<String, String>>
|
||||
where
|
||||
C: ClientInfo,
|
||||
{
|
||||
let mut params = HashMap::with_capacity(4);
|
||||
params.insert("server_version".to_owned(), self.version.to_owned());
|
||||
params.insert("server_encoding".to_owned(), "UTF8".to_owned());
|
||||
params.insert("client_encoding".to_owned(), "UTF8".to_owned());
|
||||
params.insert("DateStyle".to_owned(), "ISO YMD".to_owned());
|
||||
|
||||
Some(params)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresServerHandler {
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
login_verifier: PgLoginVerifier,
|
||||
force_tls: bool,
|
||||
param_provider: Arc<GreptimeDBStartupParameters>,
|
||||
|
||||
query_ctx: QueryContextRef,
|
||||
portal_store: Arc<MemPortalStore<String>>,
|
||||
query_parser: Arc<NoopQueryParser>,
|
||||
}
|
||||
|
||||
#[derive(Builder)]
|
||||
pub(crate) struct MakePostgresServerHandler {
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
#[builder(default = "Arc::new(GreptimeDBStartupParameters::new())")]
|
||||
param_provider: Arc<GreptimeDBStartupParameters>,
|
||||
#[builder(default = "Arc::new(NoopQueryParser::new())")]
|
||||
query_parser: Arc<NoopQueryParser>,
|
||||
force_tls: bool,
|
||||
}
|
||||
|
||||
impl MakeHandler for MakePostgresServerHandler {
|
||||
type Handler = Arc<PostgresServerHandler>;
|
||||
|
||||
fn make(&self) -> Self::Handler {
|
||||
Arc::new(PostgresServerHandler {
|
||||
query_handler: self.query_handler.clone(),
|
||||
login_verifier: PgLoginVerifier::new(self.user_provider.clone()),
|
||||
force_tls: self.force_tls,
|
||||
param_provider: self.param_provider.clone(),
|
||||
|
||||
query_ctx: QueryContext::arc(),
|
||||
portal_store: Arc::new(MemPortalStore::new()),
|
||||
query_parser: self.query_parser.clone(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,29 +12,35 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{Sink, SinkExt};
|
||||
use pgwire::api::auth::{ServerParameterProvider, StartupHandler};
|
||||
use pgwire::api::auth::StartupHandler;
|
||||
use pgwire::api::{auth, ClientInfo, PgWireConnectionState};
|
||||
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
|
||||
use pgwire::messages::response::ErrorResponse;
|
||||
use pgwire::messages::startup::Authentication;
|
||||
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
|
||||
use session::context::UserInfo;
|
||||
use session::context::{QueryContextRef, UserInfo};
|
||||
use snafu::ResultExt;
|
||||
|
||||
use super::PostgresServerHandler;
|
||||
use crate::auth::{Identity, Password, UserProviderRef};
|
||||
use crate::error;
|
||||
use crate::error::Result;
|
||||
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
|
||||
struct PgLoginVerifier {
|
||||
pub(crate) struct PgLoginVerifier {
|
||||
user_provider: Option<UserProviderRef>,
|
||||
}
|
||||
|
||||
impl PgLoginVerifier {
|
||||
pub(crate) fn new(user_provider: Option<UserProviderRef>) -> Self {
|
||||
Self { user_provider }
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
struct LoginInfo {
|
||||
user: Option<String>,
|
||||
@@ -107,61 +113,24 @@ impl PgLoginVerifier {
|
||||
}
|
||||
}
|
||||
|
||||
struct GreptimeDBStartupParameters {
|
||||
version: &'static str,
|
||||
}
|
||||
|
||||
impl GreptimeDBStartupParameters {
|
||||
fn new() -> GreptimeDBStartupParameters {
|
||||
GreptimeDBStartupParameters {
|
||||
version: env!("CARGO_PKG_VERSION"),
|
||||
}
|
||||
fn set_query_context_from_client_info<C>(client: &C, query_context: QueryContextRef)
|
||||
where
|
||||
C: ClientInfo,
|
||||
{
|
||||
if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
|
||||
query_context.set_current_catalog(current_catalog);
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerParameterProvider for GreptimeDBStartupParameters {
|
||||
fn server_parameters<C>(&self, _client: &C) -> Option<HashMap<String, String>>
|
||||
where
|
||||
C: ClientInfo,
|
||||
{
|
||||
let mut params = HashMap::with_capacity(4);
|
||||
params.insert("server_version".to_owned(), self.version.to_owned());
|
||||
params.insert("server_encoding".to_owned(), "UTF8".to_owned());
|
||||
params.insert("client_encoding".to_owned(), "UTF8".to_owned());
|
||||
params.insert("DateStyle".to_owned(), "ISO YMD".to_owned());
|
||||
|
||||
Some(params)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PgAuthStartupHandler {
|
||||
verifier: PgLoginVerifier,
|
||||
param_provider: GreptimeDBStartupParameters,
|
||||
force_tls: bool,
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
}
|
||||
|
||||
impl PgAuthStartupHandler {
|
||||
pub fn new(
|
||||
user_provider: Option<UserProviderRef>,
|
||||
force_tls: bool,
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
) -> Self {
|
||||
PgAuthStartupHandler {
|
||||
verifier: PgLoginVerifier { user_provider },
|
||||
param_provider: GreptimeDBStartupParameters::new(),
|
||||
force_tls,
|
||||
query_handler,
|
||||
}
|
||||
if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
|
||||
query_context.set_current_schema(current_schema);
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl StartupHandler for PgAuthStartupHandler {
|
||||
impl StartupHandler for PostgresServerHandler {
|
||||
async fn on_startup<C>(
|
||||
&self,
|
||||
client: &mut C,
|
||||
message: &PgWireFrontendMessage,
|
||||
message: PgWireFrontendMessage,
|
||||
) -> PgWireResult<()>
|
||||
where
|
||||
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
|
||||
@@ -194,7 +163,7 @@ impl StartupHandler for PgAuthStartupHandler {
|
||||
}
|
||||
}
|
||||
|
||||
if self.verifier.user_provider.is_some() {
|
||||
if self.login_verifier.user_provider.is_some() {
|
||||
client.set_state(PgWireConnectionState::AuthenticationInProgress);
|
||||
client
|
||||
.send(PgWireBackendMessage::Authentication(
|
||||
@@ -202,14 +171,22 @@ impl StartupHandler for PgAuthStartupHandler {
|
||||
))
|
||||
.await?;
|
||||
} else {
|
||||
auth::finish_authentication(client, &self.param_provider).await;
|
||||
set_query_context_from_client_info(client, self.query_ctx.clone());
|
||||
auth::finish_authentication(client, self.param_provider.as_ref()).await;
|
||||
}
|
||||
}
|
||||
PgWireFrontendMessage::Password(ref pwd) => {
|
||||
PgWireFrontendMessage::PasswordMessageFamily(pwd) => {
|
||||
// the newer version of pgwire has a few variant password
|
||||
// message like cleartext/md5 password, saslresponse, etc. Here
|
||||
// we must manually coerce it into password
|
||||
let pwd = pwd.into_password()?;
|
||||
|
||||
let login_info = LoginInfo::from_client_info(client);
|
||||
// do authenticate
|
||||
let authenticate_result =
|
||||
self.verifier.verify_pwd(pwd.password(), &login_info).await;
|
||||
let authenticate_result = self
|
||||
.login_verifier
|
||||
.verify_pwd(pwd.password(), &login_info)
|
||||
.await;
|
||||
if !matches!(authenticate_result, Ok(true)) {
|
||||
return send_error(
|
||||
client,
|
||||
@@ -220,7 +197,7 @@ impl StartupHandler for PgAuthStartupHandler {
|
||||
.await;
|
||||
}
|
||||
// do authorize
|
||||
let authorize_result = self.verifier.authorize(&login_info).await;
|
||||
let authorize_result = self.login_verifier.authorize(&login_info).await;
|
||||
if !matches!(authorize_result, Ok(true)) {
|
||||
return send_error(
|
||||
client,
|
||||
@@ -230,7 +207,8 @@ impl StartupHandler for PgAuthStartupHandler {
|
||||
)
|
||||
.await;
|
||||
}
|
||||
auth::finish_authentication(client, &self.param_provider).await;
|
||||
set_query_context_from_client_info(client, self.query_ctx.clone());
|
||||
auth::finish_authentication(client, self.param_provider.as_ref()).await;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
@@ -25,46 +25,24 @@ use futures::{future, stream, Stream, StreamExt};
|
||||
use pgwire::api::portal::Portal;
|
||||
use pgwire::api::query::{ExtendedQueryHandler, SimpleQueryHandler};
|
||||
use pgwire::api::results::{text_query_response, FieldInfo, Response, Tag, TextDataRowEncoder};
|
||||
use pgwire::api::stmt::NoopQueryParser;
|
||||
use pgwire::api::store::MemPortalStore;
|
||||
use pgwire::api::{ClientInfo, Type};
|
||||
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
|
||||
use session::context::QueryContext;
|
||||
|
||||
use super::PostgresServerHandler;
|
||||
use crate::error::{self, Error, Result};
|
||||
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
|
||||
pub struct PostgresServerHandler {
|
||||
query_handler: ServerSqlQueryHandlerRef,
|
||||
}
|
||||
|
||||
impl PostgresServerHandler {
|
||||
pub fn new(query_handler: ServerSqlQueryHandlerRef) -> Self {
|
||||
PostgresServerHandler { query_handler }
|
||||
}
|
||||
}
|
||||
|
||||
fn query_context_from_client_info<C>(client: &C) -> Arc<QueryContext>
|
||||
where
|
||||
C: ClientInfo,
|
||||
{
|
||||
let query_context = QueryContext::new();
|
||||
if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
|
||||
query_context.set_current_catalog(current_catalog);
|
||||
}
|
||||
if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
|
||||
query_context.set_current_schema(current_schema);
|
||||
}
|
||||
|
||||
Arc::new(query_context)
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SimpleQueryHandler for PostgresServerHandler {
|
||||
async fn do_query<C>(&self, client: &C, query: &str) -> PgWireResult<Vec<Response>>
|
||||
async fn do_query<C>(&self, _client: &C, query: &str) -> PgWireResult<Vec<Response>>
|
||||
where
|
||||
C: ClientInfo + Unpin + Send + Sync,
|
||||
{
|
||||
let query_ctx = query_context_from_client_info(client);
|
||||
let outputs = self.query_handler.do_query(query, query_ctx).await;
|
||||
let outputs = self
|
||||
.query_handler
|
||||
.do_query(query, self.query_ctx.clone())
|
||||
.await;
|
||||
|
||||
let mut results = Vec::with_capacity(outputs.len());
|
||||
|
||||
@@ -201,16 +179,43 @@ fn type_translate(origin: &ConcreteDataType) -> Result<Type> {
|
||||
|
||||
#[async_trait]
|
||||
impl ExtendedQueryHandler for PostgresServerHandler {
|
||||
type Statement = String;
|
||||
type QueryParser = NoopQueryParser;
|
||||
type PortalStore = MemPortalStore<Self::Statement>;
|
||||
|
||||
fn portal_store(&self) -> Arc<Self::PortalStore> {
|
||||
self.portal_store.clone()
|
||||
}
|
||||
|
||||
fn query_parser(&self) -> Arc<Self::QueryParser> {
|
||||
self.query_parser.clone()
|
||||
}
|
||||
|
||||
async fn do_query<C>(
|
||||
&self,
|
||||
_client: &mut C,
|
||||
_portal: &Portal,
|
||||
_portal: &Portal<Self::Statement>,
|
||||
_max_rows: usize,
|
||||
) -> PgWireResult<Response>
|
||||
where
|
||||
C: ClientInfo + Unpin + Send + Sync,
|
||||
{
|
||||
unimplemented!()
|
||||
Ok(Response::Error(Box::new(ErrorInfo::new(
|
||||
"ERROR".to_owned(),
|
||||
"XX000".to_owned(),
|
||||
"Extended query is not implemented on this server yet".to_owned(),
|
||||
))))
|
||||
}
|
||||
|
||||
async fn do_describe<C>(
|
||||
&self,
|
||||
_client: &mut C,
|
||||
_statement: &Self::Statement,
|
||||
) -> PgWireResult<Vec<FieldInfo>>
|
||||
where
|
||||
C: ClientInfo + Unpin + Send + Sync,
|
||||
{
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -25,18 +25,16 @@ use pgwire::tokio::process_socket;
|
||||
use tokio;
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
use super::{MakePostgresServerHandler, MakePostgresServerHandlerBuilder};
|
||||
use crate::auth::UserProviderRef;
|
||||
use crate::error::Result;
|
||||
use crate::postgres::auth_handler::PgAuthStartupHandler;
|
||||
use crate::postgres::handler::PostgresServerHandler;
|
||||
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
use crate::server::{AbortableStream, BaseTcpServer, Server};
|
||||
use crate::tls::TlsOption;
|
||||
|
||||
pub struct PostgresServer {
|
||||
base_server: BaseTcpServer,
|
||||
auth_handler: Arc<PgAuthStartupHandler>,
|
||||
query_handler: Arc<PostgresServerHandler>,
|
||||
make_handler: Arc<MakePostgresServerHandler>,
|
||||
tls: TlsOption,
|
||||
}
|
||||
|
||||
@@ -48,16 +46,17 @@ impl PostgresServer {
|
||||
io_runtime: Arc<Runtime>,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
) -> PostgresServer {
|
||||
let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler.clone()));
|
||||
let startup_handler = Arc::new(PgAuthStartupHandler::new(
|
||||
user_provider,
|
||||
tls.should_force_tls(),
|
||||
query_handler,
|
||||
));
|
||||
let make_handler = Arc::new(
|
||||
MakePostgresServerHandlerBuilder::default()
|
||||
.query_handler(query_handler.clone())
|
||||
.user_provider(user_provider.clone())
|
||||
.force_tls(tls.should_force_tls())
|
||||
.build()
|
||||
.unwrap(),
|
||||
);
|
||||
PostgresServer {
|
||||
base_server: BaseTcpServer::create_server("Postgres", io_runtime),
|
||||
auth_handler: startup_handler,
|
||||
query_handler: postgres_handler,
|
||||
make_handler,
|
||||
tls,
|
||||
}
|
||||
}
|
||||
@@ -68,14 +67,11 @@ impl PostgresServer {
|
||||
accepting_stream: AbortableStream,
|
||||
tls_acceptor: Option<Arc<TlsAcceptor>>,
|
||||
) -> impl Future<Output = ()> {
|
||||
let auth_handler = self.auth_handler.clone();
|
||||
let query_handler = self.query_handler.clone();
|
||||
|
||||
let handler = self.make_handler.clone();
|
||||
accepting_stream.for_each(move |tcp_stream| {
|
||||
let io_runtime = io_runtime.clone();
|
||||
let auth_handler = auth_handler.clone();
|
||||
let query_handler = query_handler.clone();
|
||||
let tls_acceptor = tls_acceptor.clone();
|
||||
let handler = handler.clone();
|
||||
|
||||
async move {
|
||||
match tcp_stream {
|
||||
@@ -89,9 +85,9 @@ impl PostgresServer {
|
||||
io_runtime.spawn(process_socket(
|
||||
io_stream,
|
||||
tls_acceptor.clone(),
|
||||
auth_handler.clone(),
|
||||
query_handler.clone(),
|
||||
query_handler.clone(),
|
||||
handler.clone(),
|
||||
handler.clone(),
|
||||
handler,
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -89,13 +89,13 @@ async fn test_shutdown_pg_server_range() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_schema_validating() -> Result<()> {
|
||||
async fn generate_server(auth_info: DatabaseAuthInfo<'_>) -> Result<(Box<dyn Server>, u16)> {
|
||||
let table = MemTable::default_numbers_table();
|
||||
let postgres_server =
|
||||
create_postgres_server(table, true, Default::default(), Some(auth_info))?;
|
||||
let listening = "127.0.0.1:5432".parse::<SocketAddr>().unwrap();
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = postgres_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
Ok((postgres_server, server_port))
|
||||
@@ -103,8 +103,8 @@ async fn test_schema_validating() -> Result<()> {
|
||||
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let (pg_server, server_port) = generate_server(DatabaseAuthInfo {
|
||||
catalog: "greptime",
|
||||
schema: "public",
|
||||
catalog: DEFAULT_CATALOG_NAME,
|
||||
schema: DEFAULT_SCHEMA_NAME,
|
||||
username: "greptime",
|
||||
})
|
||||
.await?;
|
||||
@@ -115,8 +115,8 @@ async fn test_schema_validating() -> Result<()> {
|
||||
assert!(result.is_ok());
|
||||
|
||||
let (pg_server, server_port) = generate_server(DatabaseAuthInfo {
|
||||
catalog: "greptime",
|
||||
schema: "public",
|
||||
catalog: DEFAULT_CATALOG_NAME,
|
||||
schema: DEFAULT_SCHEMA_NAME,
|
||||
username: "no_right_user",
|
||||
})
|
||||
.await?;
|
||||
@@ -141,7 +141,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
|
||||
.to_string()
|
||||
.contains("Postgres server is not started."));
|
||||
|
||||
let listening = "127.0.0.1:5432".parse::<SocketAddr>().unwrap();
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = postgres_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user