mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2025-12-22 22:20:02 +00:00
feat: add schema check on postgres startup (#758)
* feat: add schema check on postgres startup * chore: update pgwire to 0.6.3 * test: add test for unspecified db
This commit is contained in:
4
Cargo.lock
generated
4
Cargo.lock
generated
@@ -4561,9 +4561,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "pgwire"
|
||||
version = "0.6.1"
|
||||
version = "0.6.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d90fd7db2eab0a1b9cdde0ef2393f99b83c6198b1c2e62595e8d269d59b8ffca"
|
||||
checksum = "ab6d8c74bed581ab4a5ae0393ae05dc50e6b097d6298bcf97c5c58246b74aee6"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bytes",
|
||||
|
||||
@@ -44,7 +44,7 @@ use distributed::DistInstance;
|
||||
use meta_client::client::{MetaClient, MetaClientBuilder};
|
||||
use meta_client::MetaClientOpts;
|
||||
use servers::query_handler::{
|
||||
GrpcAdminHandler, GrpcAdminHandlerRef, GrpcQueryHandler, GrpcQueryHandlerRef,
|
||||
CatalogHandler, GrpcAdminHandler, GrpcAdminHandlerRef, GrpcQueryHandler, GrpcQueryHandlerRef,
|
||||
InfluxdbLineProtocolHandler, OpentsdbProtocolHandler, PrometheusProtocolHandler, ScriptHandler,
|
||||
ScriptHandlerRef, SqlQueryHandler, SqlQueryHandlerRef,
|
||||
};
|
||||
@@ -79,6 +79,7 @@ pub trait FrontendInstance:
|
||||
+ InfluxdbLineProtocolHandler
|
||||
+ PrometheusProtocolHandler
|
||||
+ ScriptHandler
|
||||
+ CatalogHandler
|
||||
+ Send
|
||||
+ Sync
|
||||
+ 'static
|
||||
@@ -663,6 +664,15 @@ impl GrpcAdminHandler for Instance {
|
||||
}
|
||||
}
|
||||
|
||||
impl CatalogHandler for Instance {
|
||||
fn is_valid_schema(&self, catalog: &str, schema: &str) -> server_error::Result<bool> {
|
||||
self.catalog_manager
|
||||
.schema(catalog, schema)
|
||||
.map(|s| s.is_some())
|
||||
.context(server_error::CatalogSnafu)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::assert_matches::assert_matches;
|
||||
|
||||
@@ -98,6 +98,7 @@ impl Services {
|
||||
);
|
||||
|
||||
let pg_server = Box::new(PostgresServer::new(
|
||||
instance.clone(),
|
||||
instance.clone(),
|
||||
opts.tls.clone(),
|
||||
pg_io_runtime,
|
||||
|
||||
@@ -12,6 +12,7 @@ axum = "0.6"
|
||||
axum-macros = "0.3"
|
||||
base64 = "0.13"
|
||||
bytes = "1.2"
|
||||
catalog = { path = "../catalog" }
|
||||
common-base = { path = "../common/base" }
|
||||
common-catalog = { path = "../common/catalog" }
|
||||
common-error = { path = "../common/error" }
|
||||
@@ -34,7 +35,7 @@ num_cpus = "1.13"
|
||||
once_cell = "1.16"
|
||||
openmetrics-parser = "0.4"
|
||||
opensrv-mysql = "0.3"
|
||||
pgwire = "0.6.1"
|
||||
pgwire = "0.6.3"
|
||||
prost = "0.11"
|
||||
rand = "0.8"
|
||||
regex = "1.6"
|
||||
@@ -60,7 +61,6 @@ tower-http = { version = "0.3", features = ["full"] }
|
||||
|
||||
[dev-dependencies]
|
||||
axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" }
|
||||
catalog = { path = "../catalog" }
|
||||
common-base = { path = "../common/base" }
|
||||
mysql_async = { version = "0.31", default-features = false, features = [
|
||||
"default-rustls",
|
||||
|
||||
@@ -20,6 +20,7 @@ use axum::http::StatusCode as HttpStatusCode;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::Json;
|
||||
use base64::DecodeError;
|
||||
use catalog;
|
||||
use common_error::prelude::*;
|
||||
use hyper::header::ToStrError;
|
||||
use serde_json::json;
|
||||
@@ -239,6 +240,9 @@ pub enum Error {
|
||||
source: FromUtf8Error,
|
||||
backtrace: Backtrace,
|
||||
},
|
||||
|
||||
#[snafu(display("Error accessing catalog: {}", source))]
|
||||
CatalogError { source: catalog::error::Error },
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
@@ -258,6 +262,7 @@ impl ErrorExt for Error {
|
||||
| InvalidPromRemoteReadQueryResult { .. }
|
||||
| TcpBind { .. }
|
||||
| GrpcReflectionService { .. }
|
||||
| CatalogError { .. }
|
||||
| BuildingContext { .. } => StatusCode::Internal,
|
||||
|
||||
InsertScript { source, .. }
|
||||
|
||||
@@ -16,6 +16,7 @@ use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use common_catalog::consts::DEFAULT_CATALOG_NAME;
|
||||
use futures::{Sink, SinkExt};
|
||||
use pgwire::api::auth::{ServerParameterProvider, StartupHandler};
|
||||
use pgwire::api::{auth, ClientInfo, PgWireConnectionState};
|
||||
@@ -28,6 +29,7 @@ use snafu::ResultExt;
|
||||
use crate::auth::{Identity, Password, UserProviderRef};
|
||||
use crate::error;
|
||||
use crate::error::Result;
|
||||
use crate::query_handler::CatalogHandlerRef;
|
||||
|
||||
struct PgPwdVerifier {
|
||||
user_provider: Option<UserProviderRef>,
|
||||
@@ -108,14 +110,20 @@ pub struct PgAuthStartupHandler {
|
||||
verifier: PgPwdVerifier,
|
||||
param_provider: GreptimeDBStartupParameters,
|
||||
force_tls: bool,
|
||||
catalog_handler: CatalogHandlerRef,
|
||||
}
|
||||
|
||||
impl PgAuthStartupHandler {
|
||||
pub fn new(user_provider: Option<UserProviderRef>, force_tls: bool) -> Self {
|
||||
pub fn new(
|
||||
user_provider: Option<UserProviderRef>,
|
||||
force_tls: bool,
|
||||
catalog_handler: CatalogHandlerRef,
|
||||
) -> Self {
|
||||
PgAuthStartupHandler {
|
||||
verifier: PgPwdVerifier { user_provider },
|
||||
param_provider: GreptimeDBStartupParameters::new(),
|
||||
force_tls,
|
||||
catalog_handler,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -134,21 +142,42 @@ impl StartupHandler for PgAuthStartupHandler {
|
||||
{
|
||||
match message {
|
||||
PgWireFrontendMessage::Startup(ref startup) => {
|
||||
// check ssl requirement
|
||||
if !client.is_secure() && self.force_tls {
|
||||
let error_info = ErrorInfo::new(
|
||||
"FATAL".to_owned(),
|
||||
"28000".to_owned(),
|
||||
"No encryption".to_owned(),
|
||||
);
|
||||
let error = ErrorResponse::from(error_info);
|
||||
|
||||
client
|
||||
.feed(PgWireBackendMessage::ErrorResponse(error))
|
||||
.await?;
|
||||
client.close().await?;
|
||||
send_error(client, "FATAL", "28000", "No encryption".to_owned()).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
auth::save_startup_parameters_to_metadata(client, startup);
|
||||
|
||||
// check if db is valid
|
||||
let db_ref = client.metadata().get(super::METADATA_DATABASE);
|
||||
if let Some(db) = db_ref {
|
||||
if !self
|
||||
.catalog_handler
|
||||
.is_valid_schema(DEFAULT_CATALOG_NAME, db)
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
|
||||
{
|
||||
send_error(
|
||||
client,
|
||||
"FATAL",
|
||||
"3D000",
|
||||
format!("Database not found: {}", db),
|
||||
)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
} else {
|
||||
send_error(
|
||||
client,
|
||||
"FATAL",
|
||||
"3D000",
|
||||
"Database not specified".to_owned(),
|
||||
)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if self.verifier.user_provider.is_some() {
|
||||
client.set_state(PgWireConnectionState::AuthenticationInProgress);
|
||||
client
|
||||
@@ -165,17 +194,13 @@ impl StartupHandler for PgAuthStartupHandler {
|
||||
if let Ok(true) = self.verifier.verify_pwd(pwd.password(), login_info).await {
|
||||
auth::finish_authentication(client, &self.param_provider).await
|
||||
} else {
|
||||
let error_info = ErrorInfo::new(
|
||||
"FATAL".to_owned(),
|
||||
"28P01".to_owned(),
|
||||
send_error(
|
||||
client,
|
||||
"FATAL",
|
||||
"28P01",
|
||||
"Password authentication failed".to_owned(),
|
||||
);
|
||||
let error = ErrorResponse::from(error_info);
|
||||
|
||||
client
|
||||
.feed(PgWireBackendMessage::ErrorResponse(error))
|
||||
.await?;
|
||||
client.close().await?;
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
@@ -183,3 +208,17 @@ impl StartupHandler for PgAuthStartupHandler {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_error<C>(client: &mut C, level: &str, code: &str, message: String) -> PgWireResult<()>
|
||||
where
|
||||
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
|
||||
C::Error: Debug,
|
||||
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
|
||||
{
|
||||
let error = ErrorResponse::from(ErrorInfo::new(level.to_owned(), code.to_owned(), message));
|
||||
client
|
||||
.feed(PgWireBackendMessage::ErrorResponse(error))
|
||||
.await?;
|
||||
client.close().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ use crate::auth::UserProviderRef;
|
||||
use crate::error::Result;
|
||||
use crate::postgres::auth_handler::PgAuthStartupHandler;
|
||||
use crate::postgres::handler::PostgresServerHandler;
|
||||
use crate::query_handler::SqlQueryHandlerRef;
|
||||
use crate::query_handler::{CatalogHandlerRef, SqlQueryHandlerRef};
|
||||
use crate::server::{AbortableStream, BaseTcpServer, Server};
|
||||
use crate::tls::TlsOption;
|
||||
|
||||
@@ -44,6 +44,7 @@ impl PostgresServer {
|
||||
/// Creates a new Postgres server with provided query_handler and async runtime
|
||||
pub fn new(
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
catalog_handler: CatalogHandlerRef,
|
||||
tls: TlsOption,
|
||||
io_runtime: Arc<Runtime>,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
@@ -52,6 +53,7 @@ impl PostgresServer {
|
||||
let startup_handler = Arc::new(PgAuthStartupHandler::new(
|
||||
user_provider,
|
||||
tls.should_force_tls(),
|
||||
catalog_handler,
|
||||
));
|
||||
PostgresServer {
|
||||
base_server: BaseTcpServer::create_server("Postgres", io_runtime),
|
||||
|
||||
@@ -43,6 +43,7 @@ pub type OpentsdbProtocolHandlerRef = Arc<dyn OpentsdbProtocolHandler + Send + S
|
||||
pub type InfluxdbLineProtocolHandlerRef = Arc<dyn InfluxdbLineProtocolHandler + Send + Sync>;
|
||||
pub type PrometheusProtocolHandlerRef = Arc<dyn PrometheusProtocolHandler + Send + Sync>;
|
||||
pub type ScriptHandlerRef = Arc<dyn ScriptHandler + Send + Sync>;
|
||||
pub type CatalogHandlerRef = Arc<dyn CatalogHandler + Send + Sync>;
|
||||
|
||||
#[async_trait]
|
||||
pub trait SqlQueryHandler {
|
||||
@@ -100,3 +101,8 @@ pub trait PrometheusProtocolHandler {
|
||||
/// Handling push gateway requests
|
||||
async fn ingest_metrics(&self, metrics: Metrics) -> Result<()>;
|
||||
}
|
||||
|
||||
pub trait CatalogHandler {
|
||||
/// check if schema is valid
|
||||
fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result<bool>;
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ use common_query::Output;
|
||||
use query::{QueryEngineFactory, QueryEngineRef};
|
||||
use servers::error::Result;
|
||||
use servers::query_handler::{
|
||||
ScriptHandler, ScriptHandlerRef, SqlQueryHandler, SqlQueryHandlerRef,
|
||||
CatalogHandler, ScriptHandler, ScriptHandlerRef, SqlQueryHandler, SqlQueryHandlerRef,
|
||||
};
|
||||
use table::test_util::MemTable;
|
||||
|
||||
@@ -92,6 +92,12 @@ impl ScriptHandler for DummyInstance {
|
||||
}
|
||||
}
|
||||
|
||||
impl CatalogHandler for DummyInstance {
|
||||
fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result<bool> {
|
||||
Ok(catalog == DEFAULT_CATALOG_NAME && schema == DEFAULT_SCHEMA_NAME)
|
||||
}
|
||||
}
|
||||
|
||||
fn create_testing_instance(table: MemTable) -> DummyInstance {
|
||||
let table_name = table.table_name().to_string();
|
||||
let table = Arc::new(table);
|
||||
|
||||
@@ -31,14 +31,14 @@ use servers::tls::TlsOption;
|
||||
use table::test_util::MemTable;
|
||||
use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage};
|
||||
|
||||
use crate::create_testing_sql_query_handler;
|
||||
use crate::create_testing_instance;
|
||||
|
||||
fn create_postgres_server(
|
||||
table: MemTable,
|
||||
check_pwd: bool,
|
||||
tls: TlsOption,
|
||||
) -> Result<Box<dyn Server>> {
|
||||
let query_handler = create_testing_sql_query_handler(table);
|
||||
let instance = Arc::new(create_testing_instance(table));
|
||||
let io_runtime = Arc::new(
|
||||
RuntimeBuilder::default()
|
||||
.worker_threads(4)
|
||||
@@ -55,7 +55,8 @@ fn create_postgres_server(
|
||||
};
|
||||
|
||||
Ok(Box::new(PostgresServer::new(
|
||||
query_handler,
|
||||
instance.clone(),
|
||||
instance,
|
||||
tls,
|
||||
io_runtime,
|
||||
user_provider,
|
||||
@@ -239,11 +240,11 @@ async fn test_server_secure_require_client_secure() -> Result<()> {
|
||||
async fn test_using_db() -> Result<()> {
|
||||
let server_port = start_test_server(TlsOption::default()).await?;
|
||||
|
||||
let client = create_connection_with_given_db(server_port, "testdb")
|
||||
.await
|
||||
.unwrap();
|
||||
let result = client.simple_query("SELECT uint32s FROM numbers").await;
|
||||
assert!(result.is_err());
|
||||
let client = create_connection_with_given_db(server_port, "testdb").await;
|
||||
assert!(client.is_err());
|
||||
|
||||
let client = create_connection_without_db(server_port).await;
|
||||
assert!(client.is_err());
|
||||
|
||||
let client = create_connection_with_given_db(server_port, DEFAULT_SCHEMA_NAME)
|
||||
.await
|
||||
@@ -284,11 +285,14 @@ async fn create_secure_connection(
|
||||
) -> std::result::Result<Client, PgError> {
|
||||
let url = if with_pwd {
|
||||
format!(
|
||||
"sslmode=require host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2",
|
||||
port
|
||||
"sslmode=require host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2, dbname={}",
|
||||
port, DEFAULT_SCHEMA_NAME
|
||||
)
|
||||
} else {
|
||||
format!("host=127.0.0.1 port={} connect_timeout=2", port)
|
||||
format!(
|
||||
"host=127.0.0.1 port={} connect_timeout=2 dbname={}",
|
||||
port, DEFAULT_SCHEMA_NAME
|
||||
)
|
||||
};
|
||||
|
||||
let mut config = rustls::ClientConfig::builder()
|
||||
@@ -312,11 +316,14 @@ async fn create_plain_connection(
|
||||
) -> std::result::Result<Client, PgError> {
|
||||
let url = if with_pwd {
|
||||
format!(
|
||||
"host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2",
|
||||
port
|
||||
"host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2 dbname={}",
|
||||
port, DEFAULT_SCHEMA_NAME
|
||||
)
|
||||
} else {
|
||||
format!("host=127.0.0.1 port={} connect_timeout=2", port)
|
||||
format!(
|
||||
"host=127.0.0.1 port={} connect_timeout=2 dbname={}",
|
||||
port, DEFAULT_SCHEMA_NAME
|
||||
)
|
||||
};
|
||||
let (client, conn) = tokio_postgres::connect(&url, NoTls).await?;
|
||||
tokio::spawn(conn);
|
||||
@@ -336,6 +343,13 @@ async fn create_connection_with_given_db(
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
async fn create_connection_without_db(port: u16) -> std::result::Result<Client, PgError> {
|
||||
let url = format!("host=127.0.0.1 port={} connect_timeout=2", port);
|
||||
let (client, conn) = tokio_postgres::connect(&url, NoTls).await?;
|
||||
tokio::spawn(conn);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn resolve_result(resp: &SimpleQueryMessage, col_index: usize) -> Option<&str> {
|
||||
match resp {
|
||||
&SimpleQueryMessage::Row(ref r) => r.get(col_index),
|
||||
|
||||
Reference in New Issue
Block a user