feat: add schema check on postgres startup (#758)

* feat: add schema check on postgres startup

* chore: update pgwire to 0.6.3

* test: add test for unspecified db
This commit is contained in:
Ning Sun
2022-12-19 10:53:44 +08:00
committed by GitHub
parent ea1896493b
commit efd85df6be
10 changed files with 126 additions and 43 deletions

4
Cargo.lock generated
View File

@@ -4561,9 +4561,9 @@ dependencies = [
[[package]]
name = "pgwire"
version = "0.6.1"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d90fd7db2eab0a1b9cdde0ef2393f99b83c6198b1c2e62595e8d269d59b8ffca"
checksum = "ab6d8c74bed581ab4a5ae0393ae05dc50e6b097d6298bcf97c5c58246b74aee6"
dependencies = [
"async-trait",
"bytes",

View File

@@ -44,7 +44,7 @@ use distributed::DistInstance;
use meta_client::client::{MetaClient, MetaClientBuilder};
use meta_client::MetaClientOpts;
use servers::query_handler::{
GrpcAdminHandler, GrpcAdminHandlerRef, GrpcQueryHandler, GrpcQueryHandlerRef,
CatalogHandler, GrpcAdminHandler, GrpcAdminHandlerRef, GrpcQueryHandler, GrpcQueryHandlerRef,
InfluxdbLineProtocolHandler, OpentsdbProtocolHandler, PrometheusProtocolHandler, ScriptHandler,
ScriptHandlerRef, SqlQueryHandler, SqlQueryHandlerRef,
};
@@ -79,6 +79,7 @@ pub trait FrontendInstance:
+ InfluxdbLineProtocolHandler
+ PrometheusProtocolHandler
+ ScriptHandler
+ CatalogHandler
+ Send
+ Sync
+ 'static
@@ -663,6 +664,15 @@ impl GrpcAdminHandler for Instance {
}
}
impl CatalogHandler for Instance {
fn is_valid_schema(&self, catalog: &str, schema: &str) -> server_error::Result<bool> {
self.catalog_manager
.schema(catalog, schema)
.map(|s| s.is_some())
.context(server_error::CatalogSnafu)
}
}
#[cfg(test)]
mod tests {
use std::assert_matches::assert_matches;

View File

@@ -98,6 +98,7 @@ impl Services {
);
let pg_server = Box::new(PostgresServer::new(
instance.clone(),
instance.clone(),
opts.tls.clone(),
pg_io_runtime,

View File

@@ -12,6 +12,7 @@ axum = "0.6"
axum-macros = "0.3"
base64 = "0.13"
bytes = "1.2"
catalog = { path = "../catalog" }
common-base = { path = "../common/base" }
common-catalog = { path = "../common/catalog" }
common-error = { path = "../common/error" }
@@ -34,7 +35,7 @@ num_cpus = "1.13"
once_cell = "1.16"
openmetrics-parser = "0.4"
opensrv-mysql = "0.3"
pgwire = "0.6.1"
pgwire = "0.6.3"
prost = "0.11"
rand = "0.8"
regex = "1.6"
@@ -60,7 +61,6 @@ tower-http = { version = "0.3", features = ["full"] }
[dev-dependencies]
axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" }
catalog = { path = "../catalog" }
common-base = { path = "../common/base" }
mysql_async = { version = "0.31", default-features = false, features = [
"default-rustls",

View File

@@ -20,6 +20,7 @@ use axum::http::StatusCode as HttpStatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use base64::DecodeError;
use catalog;
use common_error::prelude::*;
use hyper::header::ToStrError;
use serde_json::json;
@@ -239,6 +240,9 @@ pub enum Error {
source: FromUtf8Error,
backtrace: Backtrace,
},
#[snafu(display("Error accessing catalog: {}", source))]
CatalogError { source: catalog::error::Error },
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -258,6 +262,7 @@ impl ErrorExt for Error {
| InvalidPromRemoteReadQueryResult { .. }
| TcpBind { .. }
| GrpcReflectionService { .. }
| CatalogError { .. }
| BuildingContext { .. } => StatusCode::Internal,
InsertScript { source, .. }

View File

@@ -16,6 +16,7 @@ use std::collections::HashMap;
use std::fmt::Debug;
use async_trait::async_trait;
use common_catalog::consts::DEFAULT_CATALOG_NAME;
use futures::{Sink, SinkExt};
use pgwire::api::auth::{ServerParameterProvider, StartupHandler};
use pgwire::api::{auth, ClientInfo, PgWireConnectionState};
@@ -28,6 +29,7 @@ use snafu::ResultExt;
use crate::auth::{Identity, Password, UserProviderRef};
use crate::error;
use crate::error::Result;
use crate::query_handler::CatalogHandlerRef;
struct PgPwdVerifier {
user_provider: Option<UserProviderRef>,
@@ -108,14 +110,20 @@ pub struct PgAuthStartupHandler {
verifier: PgPwdVerifier,
param_provider: GreptimeDBStartupParameters,
force_tls: bool,
catalog_handler: CatalogHandlerRef,
}
impl PgAuthStartupHandler {
pub fn new(user_provider: Option<UserProviderRef>, force_tls: bool) -> Self {
pub fn new(
user_provider: Option<UserProviderRef>,
force_tls: bool,
catalog_handler: CatalogHandlerRef,
) -> Self {
PgAuthStartupHandler {
verifier: PgPwdVerifier { user_provider },
param_provider: GreptimeDBStartupParameters::new(),
force_tls,
catalog_handler,
}
}
}
@@ -134,21 +142,42 @@ impl StartupHandler for PgAuthStartupHandler {
{
match message {
PgWireFrontendMessage::Startup(ref startup) => {
// check ssl requirement
if !client.is_secure() && self.force_tls {
let error_info = ErrorInfo::new(
"FATAL".to_owned(),
"28000".to_owned(),
"No encryption".to_owned(),
);
let error = ErrorResponse::from(error_info);
client
.feed(PgWireBackendMessage::ErrorResponse(error))
.await?;
client.close().await?;
send_error(client, "FATAL", "28000", "No encryption".to_owned()).await?;
return Ok(());
}
auth::save_startup_parameters_to_metadata(client, startup);
// check if db is valid
let db_ref = client.metadata().get(super::METADATA_DATABASE);
if let Some(db) = db_ref {
if !self
.catalog_handler
.is_valid_schema(DEFAULT_CATALOG_NAME, db)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
{
send_error(
client,
"FATAL",
"3D000",
format!("Database not found: {}", db),
)
.await?;
return Ok(());
}
} else {
send_error(
client,
"FATAL",
"3D000",
"Database not specified".to_owned(),
)
.await?;
return Ok(());
}
if self.verifier.user_provider.is_some() {
client.set_state(PgWireConnectionState::AuthenticationInProgress);
client
@@ -165,17 +194,13 @@ impl StartupHandler for PgAuthStartupHandler {
if let Ok(true) = self.verifier.verify_pwd(pwd.password(), login_info).await {
auth::finish_authentication(client, &self.param_provider).await
} else {
let error_info = ErrorInfo::new(
"FATAL".to_owned(),
"28P01".to_owned(),
send_error(
client,
"FATAL",
"28P01",
"Password authentication failed".to_owned(),
);
let error = ErrorResponse::from(error_info);
client
.feed(PgWireBackendMessage::ErrorResponse(error))
.await?;
client.close().await?;
)
.await?;
}
}
_ => {}
@@ -183,3 +208,17 @@ impl StartupHandler for PgAuthStartupHandler {
Ok(())
}
}
async fn send_error<C>(client: &mut C, level: &str, code: &str, message: String) -> PgWireResult<()>
where
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
let error = ErrorResponse::from(ErrorInfo::new(level.to_owned(), code.to_owned(), message));
client
.feed(PgWireBackendMessage::ErrorResponse(error))
.await?;
client.close().await?;
Ok(())
}

View File

@@ -29,7 +29,7 @@ use crate::auth::UserProviderRef;
use crate::error::Result;
use crate::postgres::auth_handler::PgAuthStartupHandler;
use crate::postgres::handler::PostgresServerHandler;
use crate::query_handler::SqlQueryHandlerRef;
use crate::query_handler::{CatalogHandlerRef, SqlQueryHandlerRef};
use crate::server::{AbortableStream, BaseTcpServer, Server};
use crate::tls::TlsOption;
@@ -44,6 +44,7 @@ impl PostgresServer {
/// Creates a new Postgres server with provided query_handler and async runtime
pub fn new(
query_handler: SqlQueryHandlerRef,
catalog_handler: CatalogHandlerRef,
tls: TlsOption,
io_runtime: Arc<Runtime>,
user_provider: Option<UserProviderRef>,
@@ -52,6 +53,7 @@ impl PostgresServer {
let startup_handler = Arc::new(PgAuthStartupHandler::new(
user_provider,
tls.should_force_tls(),
catalog_handler,
));
PostgresServer {
base_server: BaseTcpServer::create_server("Postgres", io_runtime),

View File

@@ -43,6 +43,7 @@ pub type OpentsdbProtocolHandlerRef = Arc<dyn OpentsdbProtocolHandler + Send + S
pub type InfluxdbLineProtocolHandlerRef = Arc<dyn InfluxdbLineProtocolHandler + Send + Sync>;
pub type PrometheusProtocolHandlerRef = Arc<dyn PrometheusProtocolHandler + Send + Sync>;
pub type ScriptHandlerRef = Arc<dyn ScriptHandler + Send + Sync>;
pub type CatalogHandlerRef = Arc<dyn CatalogHandler + Send + Sync>;
#[async_trait]
pub trait SqlQueryHandler {
@@ -100,3 +101,8 @@ pub trait PrometheusProtocolHandler {
/// Handling push gateway requests
async fn ingest_metrics(&self, metrics: Metrics) -> Result<()>;
}
pub trait CatalogHandler {
/// check if schema is valid
fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result<bool>;
}

View File

@@ -23,7 +23,7 @@ use common_query::Output;
use query::{QueryEngineFactory, QueryEngineRef};
use servers::error::Result;
use servers::query_handler::{
ScriptHandler, ScriptHandlerRef, SqlQueryHandler, SqlQueryHandlerRef,
CatalogHandler, ScriptHandler, ScriptHandlerRef, SqlQueryHandler, SqlQueryHandlerRef,
};
use table::test_util::MemTable;
@@ -92,6 +92,12 @@ impl ScriptHandler for DummyInstance {
}
}
impl CatalogHandler for DummyInstance {
fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result<bool> {
Ok(catalog == DEFAULT_CATALOG_NAME && schema == DEFAULT_SCHEMA_NAME)
}
}
fn create_testing_instance(table: MemTable) -> DummyInstance {
let table_name = table.table_name().to_string();
let table = Arc::new(table);

View File

@@ -31,14 +31,14 @@ use servers::tls::TlsOption;
use table::test_util::MemTable;
use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage};
use crate::create_testing_sql_query_handler;
use crate::create_testing_instance;
fn create_postgres_server(
table: MemTable,
check_pwd: bool,
tls: TlsOption,
) -> Result<Box<dyn Server>> {
let query_handler = create_testing_sql_query_handler(table);
let instance = Arc::new(create_testing_instance(table));
let io_runtime = Arc::new(
RuntimeBuilder::default()
.worker_threads(4)
@@ -55,7 +55,8 @@ fn create_postgres_server(
};
Ok(Box::new(PostgresServer::new(
query_handler,
instance.clone(),
instance,
tls,
io_runtime,
user_provider,
@@ -239,11 +240,11 @@ async fn test_server_secure_require_client_secure() -> Result<()> {
async fn test_using_db() -> Result<()> {
let server_port = start_test_server(TlsOption::default()).await?;
let client = create_connection_with_given_db(server_port, "testdb")
.await
.unwrap();
let result = client.simple_query("SELECT uint32s FROM numbers").await;
assert!(result.is_err());
let client = create_connection_with_given_db(server_port, "testdb").await;
assert!(client.is_err());
let client = create_connection_without_db(server_port).await;
assert!(client.is_err());
let client = create_connection_with_given_db(server_port, DEFAULT_SCHEMA_NAME)
.await
@@ -284,11 +285,14 @@ async fn create_secure_connection(
) -> std::result::Result<Client, PgError> {
let url = if with_pwd {
format!(
"sslmode=require host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2",
port
"sslmode=require host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2, dbname={}",
port, DEFAULT_SCHEMA_NAME
)
} else {
format!("host=127.0.0.1 port={} connect_timeout=2", port)
format!(
"host=127.0.0.1 port={} connect_timeout=2 dbname={}",
port, DEFAULT_SCHEMA_NAME
)
};
let mut config = rustls::ClientConfig::builder()
@@ -312,11 +316,14 @@ async fn create_plain_connection(
) -> std::result::Result<Client, PgError> {
let url = if with_pwd {
format!(
"host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2",
port
"host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2 dbname={}",
port, DEFAULT_SCHEMA_NAME
)
} else {
format!("host=127.0.0.1 port={} connect_timeout=2", port)
format!(
"host=127.0.0.1 port={} connect_timeout=2 dbname={}",
port, DEFAULT_SCHEMA_NAME
)
};
let (client, conn) = tokio_postgres::connect(&url, NoTls).await?;
tokio::spawn(conn);
@@ -336,6 +343,13 @@ async fn create_connection_with_given_db(
Ok(client)
}
async fn create_connection_without_db(port: u16) -> std::result::Result<Client, PgError> {
let url = format!("host=127.0.0.1 port={} connect_timeout=2", port);
let (client, conn) = tokio_postgres::connect(&url, NoTls).await?;
tokio::spawn(conn);
Ok(client)
}
fn resolve_result(resp: &SimpleQueryMessage, col_index: usize) -> Option<&str> {
match resp {
&SimpleQueryMessage::Row(ref r) => r.get(col_index),