From efd85df6be4bd1b6244c6c78f6f9f711d6f82297 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Mon, 19 Dec 2022 10:53:44 +0800 Subject: [PATCH] feat: add schema check on postgres startup (#758) * feat: add schema check on postgres startup * chore: update pgwire to 0.6.3 * test: add test for unspecified db --- Cargo.lock | 4 +- src/frontend/src/instance.rs | 12 +++- src/frontend/src/server.rs | 1 + src/servers/Cargo.toml | 4 +- src/servers/src/error.rs | 5 ++ src/servers/src/postgres/auth_handler.rs | 83 +++++++++++++++++------- src/servers/src/postgres/server.rs | 4 +- src/servers/src/query_handler.rs | 6 ++ src/servers/tests/mod.rs | 8 ++- src/servers/tests/postgres/mod.rs | 42 ++++++++---- 10 files changed, 126 insertions(+), 43 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b76ae023ee..4c77450f2f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4561,9 +4561,9 @@ dependencies = [ [[package]] name = "pgwire" -version = "0.6.1" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d90fd7db2eab0a1b9cdde0ef2393f99b83c6198b1c2e62595e8d269d59b8ffca" +checksum = "ab6d8c74bed581ab4a5ae0393ae05dc50e6b097d6298bcf97c5c58246b74aee6" dependencies = [ "async-trait", "bytes", diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 721fc5008a..babc06a76f 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -44,7 +44,7 @@ use distributed::DistInstance; use meta_client::client::{MetaClient, MetaClientBuilder}; use meta_client::MetaClientOpts; use servers::query_handler::{ - GrpcAdminHandler, GrpcAdminHandlerRef, GrpcQueryHandler, GrpcQueryHandlerRef, + CatalogHandler, GrpcAdminHandler, GrpcAdminHandlerRef, GrpcQueryHandler, GrpcQueryHandlerRef, InfluxdbLineProtocolHandler, OpentsdbProtocolHandler, PrometheusProtocolHandler, ScriptHandler, ScriptHandlerRef, SqlQueryHandler, SqlQueryHandlerRef, }; @@ -79,6 +79,7 @@ pub trait FrontendInstance: + InfluxdbLineProtocolHandler + PrometheusProtocolHandler + ScriptHandler + + CatalogHandler + Send + Sync + 'static @@ -663,6 +664,15 @@ impl GrpcAdminHandler for Instance { } } +impl CatalogHandler for Instance { + fn is_valid_schema(&self, catalog: &str, schema: &str) -> server_error::Result { + self.catalog_manager + .schema(catalog, schema) + .map(|s| s.is_some()) + .context(server_error::CatalogSnafu) + } +} + #[cfg(test)] mod tests { use std::assert_matches::assert_matches; diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index d3c55b8c97..6c04860192 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -98,6 +98,7 @@ impl Services { ); let pg_server = Box::new(PostgresServer::new( + instance.clone(), instance.clone(), opts.tls.clone(), pg_io_runtime, diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index d708cc5551..b0fb3f13dd 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -12,6 +12,7 @@ axum = "0.6" axum-macros = "0.3" base64 = "0.13" bytes = "1.2" +catalog = { path = "../catalog" } common-base = { path = "../common/base" } common-catalog = { path = "../common/catalog" } common-error = { path = "../common/error" } @@ -34,7 +35,7 @@ num_cpus = "1.13" once_cell = "1.16" openmetrics-parser = "0.4" opensrv-mysql = "0.3" -pgwire = "0.6.1" +pgwire = "0.6.3" prost = "0.11" rand = "0.8" regex = "1.6" @@ -60,7 +61,6 @@ tower-http = { version = "0.3", features = ["full"] } [dev-dependencies] axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" } -catalog = { path = "../catalog" } common-base = { path = "../common/base" } mysql_async = { version = "0.31", default-features = false, features = [ "default-rustls", diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index 5c738cfc12..b039a09c4a 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -20,6 +20,7 @@ use axum::http::StatusCode as HttpStatusCode; use axum::response::{IntoResponse, Response}; use axum::Json; use base64::DecodeError; +use catalog; use common_error::prelude::*; use hyper::header::ToStrError; use serde_json::json; @@ -239,6 +240,9 @@ pub enum Error { source: FromUtf8Error, backtrace: Backtrace, }, + + #[snafu(display("Error accessing catalog: {}", source))] + CatalogError { source: catalog::error::Error }, } pub type Result = std::result::Result; @@ -258,6 +262,7 @@ impl ErrorExt for Error { | InvalidPromRemoteReadQueryResult { .. } | TcpBind { .. } | GrpcReflectionService { .. } + | CatalogError { .. } | BuildingContext { .. } => StatusCode::Internal, InsertScript { source, .. } diff --git a/src/servers/src/postgres/auth_handler.rs b/src/servers/src/postgres/auth_handler.rs index 70e1b06dc0..70a30974f1 100644 --- a/src/servers/src/postgres/auth_handler.rs +++ b/src/servers/src/postgres/auth_handler.rs @@ -16,6 +16,7 @@ use std::collections::HashMap; use std::fmt::Debug; use async_trait::async_trait; +use common_catalog::consts::DEFAULT_CATALOG_NAME; use futures::{Sink, SinkExt}; use pgwire::api::auth::{ServerParameterProvider, StartupHandler}; use pgwire::api::{auth, ClientInfo, PgWireConnectionState}; @@ -28,6 +29,7 @@ use snafu::ResultExt; use crate::auth::{Identity, Password, UserProviderRef}; use crate::error; use crate::error::Result; +use crate::query_handler::CatalogHandlerRef; struct PgPwdVerifier { user_provider: Option, @@ -108,14 +110,20 @@ pub struct PgAuthStartupHandler { verifier: PgPwdVerifier, param_provider: GreptimeDBStartupParameters, force_tls: bool, + catalog_handler: CatalogHandlerRef, } impl PgAuthStartupHandler { - pub fn new(user_provider: Option, force_tls: bool) -> Self { + pub fn new( + user_provider: Option, + force_tls: bool, + catalog_handler: CatalogHandlerRef, + ) -> Self { PgAuthStartupHandler { verifier: PgPwdVerifier { user_provider }, param_provider: GreptimeDBStartupParameters::new(), force_tls, + catalog_handler, } } } @@ -134,21 +142,42 @@ impl StartupHandler for PgAuthStartupHandler { { match message { PgWireFrontendMessage::Startup(ref startup) => { + // check ssl requirement if !client.is_secure() && self.force_tls { - let error_info = ErrorInfo::new( - "FATAL".to_owned(), - "28000".to_owned(), - "No encryption".to_owned(), - ); - let error = ErrorResponse::from(error_info); - - client - .feed(PgWireBackendMessage::ErrorResponse(error)) - .await?; - client.close().await?; + send_error(client, "FATAL", "28000", "No encryption".to_owned()).await?; return Ok(()); } + auth::save_startup_parameters_to_metadata(client, startup); + + // check if db is valid + let db_ref = client.metadata().get(super::METADATA_DATABASE); + if let Some(db) = db_ref { + if !self + .catalog_handler + .is_valid_schema(DEFAULT_CATALOG_NAME, db) + .map_err(|e| PgWireError::ApiError(Box::new(e)))? + { + send_error( + client, + "FATAL", + "3D000", + format!("Database not found: {}", db), + ) + .await?; + return Ok(()); + } + } else { + send_error( + client, + "FATAL", + "3D000", + "Database not specified".to_owned(), + ) + .await?; + return Ok(()); + } + if self.verifier.user_provider.is_some() { client.set_state(PgWireConnectionState::AuthenticationInProgress); client @@ -165,17 +194,13 @@ impl StartupHandler for PgAuthStartupHandler { if let Ok(true) = self.verifier.verify_pwd(pwd.password(), login_info).await { auth::finish_authentication(client, &self.param_provider).await } else { - let error_info = ErrorInfo::new( - "FATAL".to_owned(), - "28P01".to_owned(), + send_error( + client, + "FATAL", + "28P01", "Password authentication failed".to_owned(), - ); - let error = ErrorResponse::from(error_info); - - client - .feed(PgWireBackendMessage::ErrorResponse(error)) - .await?; - client.close().await?; + ) + .await?; } } _ => {} @@ -183,3 +208,17 @@ impl StartupHandler for PgAuthStartupHandler { Ok(()) } } + +async fn send_error(client: &mut C, level: &str, code: &str, message: String) -> PgWireResult<()> +where + C: ClientInfo + Sink + Unpin + Send, + C::Error: Debug, + PgWireError: From<>::Error>, +{ + let error = ErrorResponse::from(ErrorInfo::new(level.to_owned(), code.to_owned(), message)); + client + .feed(PgWireBackendMessage::ErrorResponse(error)) + .await?; + client.close().await?; + Ok(()) +} diff --git a/src/servers/src/postgres/server.rs b/src/servers/src/postgres/server.rs index 5003af92c5..09b2854c4c 100644 --- a/src/servers/src/postgres/server.rs +++ b/src/servers/src/postgres/server.rs @@ -29,7 +29,7 @@ use crate::auth::UserProviderRef; use crate::error::Result; use crate::postgres::auth_handler::PgAuthStartupHandler; use crate::postgres::handler::PostgresServerHandler; -use crate::query_handler::SqlQueryHandlerRef; +use crate::query_handler::{CatalogHandlerRef, SqlQueryHandlerRef}; use crate::server::{AbortableStream, BaseTcpServer, Server}; use crate::tls::TlsOption; @@ -44,6 +44,7 @@ impl PostgresServer { /// Creates a new Postgres server with provided query_handler and async runtime pub fn new( query_handler: SqlQueryHandlerRef, + catalog_handler: CatalogHandlerRef, tls: TlsOption, io_runtime: Arc, user_provider: Option, @@ -52,6 +53,7 @@ impl PostgresServer { let startup_handler = Arc::new(PgAuthStartupHandler::new( user_provider, tls.should_force_tls(), + catalog_handler, )); PostgresServer { base_server: BaseTcpServer::create_server("Postgres", io_runtime), diff --git a/src/servers/src/query_handler.rs b/src/servers/src/query_handler.rs index 3abb848737..1ef549c763 100644 --- a/src/servers/src/query_handler.rs +++ b/src/servers/src/query_handler.rs @@ -43,6 +43,7 @@ pub type OpentsdbProtocolHandlerRef = Arc; pub type PrometheusProtocolHandlerRef = Arc; pub type ScriptHandlerRef = Arc; +pub type CatalogHandlerRef = Arc; #[async_trait] pub trait SqlQueryHandler { @@ -100,3 +101,8 @@ pub trait PrometheusProtocolHandler { /// Handling push gateway requests async fn ingest_metrics(&self, metrics: Metrics) -> Result<()>; } + +pub trait CatalogHandler { + /// check if schema is valid + fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result; +} diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 5f2692bcab..d8d9bc974d 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -23,7 +23,7 @@ use common_query::Output; use query::{QueryEngineFactory, QueryEngineRef}; use servers::error::Result; use servers::query_handler::{ - ScriptHandler, ScriptHandlerRef, SqlQueryHandler, SqlQueryHandlerRef, + CatalogHandler, ScriptHandler, ScriptHandlerRef, SqlQueryHandler, SqlQueryHandlerRef, }; use table::test_util::MemTable; @@ -92,6 +92,12 @@ impl ScriptHandler for DummyInstance { } } +impl CatalogHandler for DummyInstance { + fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result { + Ok(catalog == DEFAULT_CATALOG_NAME && schema == DEFAULT_SCHEMA_NAME) + } +} + fn create_testing_instance(table: MemTable) -> DummyInstance { let table_name = table.table_name().to_string(); let table = Arc::new(table); diff --git a/src/servers/tests/postgres/mod.rs b/src/servers/tests/postgres/mod.rs index 5653251c0d..8c7db2f056 100644 --- a/src/servers/tests/postgres/mod.rs +++ b/src/servers/tests/postgres/mod.rs @@ -31,14 +31,14 @@ use servers::tls::TlsOption; use table::test_util::MemTable; use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage}; -use crate::create_testing_sql_query_handler; +use crate::create_testing_instance; fn create_postgres_server( table: MemTable, check_pwd: bool, tls: TlsOption, ) -> Result> { - let query_handler = create_testing_sql_query_handler(table); + let instance = Arc::new(create_testing_instance(table)); let io_runtime = Arc::new( RuntimeBuilder::default() .worker_threads(4) @@ -55,7 +55,8 @@ fn create_postgres_server( }; Ok(Box::new(PostgresServer::new( - query_handler, + instance.clone(), + instance, tls, io_runtime, user_provider, @@ -239,11 +240,11 @@ async fn test_server_secure_require_client_secure() -> Result<()> { async fn test_using_db() -> Result<()> { let server_port = start_test_server(TlsOption::default()).await?; - let client = create_connection_with_given_db(server_port, "testdb") - .await - .unwrap(); - let result = client.simple_query("SELECT uint32s FROM numbers").await; - assert!(result.is_err()); + let client = create_connection_with_given_db(server_port, "testdb").await; + assert!(client.is_err()); + + let client = create_connection_without_db(server_port).await; + assert!(client.is_err()); let client = create_connection_with_given_db(server_port, DEFAULT_SCHEMA_NAME) .await @@ -284,11 +285,14 @@ async fn create_secure_connection( ) -> std::result::Result { let url = if with_pwd { format!( - "sslmode=require host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2", - port + "sslmode=require host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2, dbname={}", + port, DEFAULT_SCHEMA_NAME ) } else { - format!("host=127.0.0.1 port={} connect_timeout=2", port) + format!( + "host=127.0.0.1 port={} connect_timeout=2 dbname={}", + port, DEFAULT_SCHEMA_NAME + ) }; let mut config = rustls::ClientConfig::builder() @@ -312,11 +316,14 @@ async fn create_plain_connection( ) -> std::result::Result { let url = if with_pwd { format!( - "host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2", - port + "host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2 dbname={}", + port, DEFAULT_SCHEMA_NAME ) } else { - format!("host=127.0.0.1 port={} connect_timeout=2", port) + format!( + "host=127.0.0.1 port={} connect_timeout=2 dbname={}", + port, DEFAULT_SCHEMA_NAME + ) }; let (client, conn) = tokio_postgres::connect(&url, NoTls).await?; tokio::spawn(conn); @@ -336,6 +343,13 @@ async fn create_connection_with_given_db( Ok(client) } +async fn create_connection_without_db(port: u16) -> std::result::Result { + let url = format!("host=127.0.0.1 port={} connect_timeout=2", port); + let (client, conn) = tokio_postgres::connect(&url, NoTls).await?; + tokio::spawn(conn); + Ok(client) +} + fn resolve_result(resp: &SimpleQueryMessage, col_index: usize) -> Option<&str> { match resp { &SimpleQueryMessage::Row(ref r) => r.get(col_index),