mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-20 23:10:37 +00:00
feat: generating context in http middleware & mysql auth method (#453)
This commit is contained in:
@@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize};
|
||||
pub struct PostgresOptions {
|
||||
pub addr: String,
|
||||
pub runtime_size: usize,
|
||||
pub check_pwd: bool,
|
||||
}
|
||||
|
||||
impl Default for PostgresOptions {
|
||||
@@ -11,6 +12,7 @@ impl Default for PostgresOptions {
|
||||
Self {
|
||||
addr: "0.0.0.0:4003".to_string(),
|
||||
runtime_size: 2,
|
||||
check_pwd: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,8 +73,11 @@ impl Services {
|
||||
.context(error::RuntimeResourceSnafu)?,
|
||||
);
|
||||
|
||||
let pg_server =
|
||||
Box::new(PostgresServer::new(instance.clone(), pg_io_runtime)) as Box<dyn Server>;
|
||||
let pg_server = Box::new(PostgresServer::new(
|
||||
instance.clone(),
|
||||
opts.check_pwd,
|
||||
pg_io_runtime,
|
||||
)) as Box<dyn Server>;
|
||||
|
||||
Some((pg_server, pg_addr))
|
||||
} else {
|
||||
|
||||
@@ -32,6 +32,7 @@ opensrv-mysql = "0.1"
|
||||
pgwire = { version = "0.4" }
|
||||
prost = "0.11"
|
||||
regex = "1.6"
|
||||
rand = "0.8"
|
||||
schemars = "0.8"
|
||||
serde = "1.0"
|
||||
serde_json = "1.0"
|
||||
|
||||
@@ -2,13 +2,13 @@ use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use snafu::OptionExt;
|
||||
|
||||
use crate::context::AuthMethod::Token;
|
||||
use crate::context::Channel::HTTP;
|
||||
use crate::error::{BuildingContextSnafu, Result};
|
||||
|
||||
type CtxFnRef = Arc<dyn Fn(&Context) -> bool + Send + Sync>;
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct Context {
|
||||
pub exec_info: ExecInfo,
|
||||
pub client_info: ClientInfo,
|
||||
@@ -19,16 +19,70 @@ pub struct Context {
|
||||
}
|
||||
|
||||
impl Context {
|
||||
pub fn new() -> Self {
|
||||
Context::default()
|
||||
}
|
||||
|
||||
pub fn add_predicate(&mut self, predicate: CtxFnRef) {
|
||||
self.predicates.push(predicate);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
#[derive(Default)]
|
||||
pub struct CtxBuilder {
|
||||
client_addr: Option<String>,
|
||||
|
||||
username: Option<String>,
|
||||
from_channel: Option<Channel>,
|
||||
auth_method: Option<AuthMethod>,
|
||||
}
|
||||
|
||||
impl CtxBuilder {
|
||||
pub fn new() -> CtxBuilder {
|
||||
CtxBuilder::default()
|
||||
}
|
||||
|
||||
pub fn client_addr(mut self, addr: Option<String>) -> CtxBuilder {
|
||||
self.client_addr = addr;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn set_channel(mut self, channel: Option<Channel>) -> CtxBuilder {
|
||||
self.from_channel = channel;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn set_auth_method(mut self, auth_method: Option<AuthMethod>) -> CtxBuilder {
|
||||
self.auth_method = auth_method;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn set_username(mut self, username: Option<String>) -> CtxBuilder {
|
||||
self.username = username;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<Context> {
|
||||
Ok(Context {
|
||||
client_info: ClientInfo {
|
||||
client_host: self.client_addr.context(BuildingContextSnafu {
|
||||
err_msg: "unknown client addr while building ctx",
|
||||
})?,
|
||||
},
|
||||
user_info: UserInfo {
|
||||
username: self.username,
|
||||
from_channel: self.from_channel.context(BuildingContextSnafu {
|
||||
err_msg: "unknown channel while building ctx",
|
||||
})?,
|
||||
auth_method: self.auth_method.context(BuildingContextSnafu {
|
||||
err_msg: "unknown auth method while building ctx",
|
||||
})?,
|
||||
},
|
||||
|
||||
exec_info: ExecInfo::default(),
|
||||
quota: Quota::default(),
|
||||
predicates: vec![],
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ExecInfo {
|
||||
pub catalog: Option<String>,
|
||||
pub schema: Option<String>,
|
||||
@@ -37,34 +91,29 @@ pub struct ExecInfo {
|
||||
pub trace_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
pub struct ClientInfo {
|
||||
pub client_host: Option<String>,
|
||||
}
|
||||
|
||||
impl ClientInfo {
|
||||
pub fn new(host: Option<String>) -> Self {
|
||||
ClientInfo { client_host: host }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
pub struct UserInfo {
|
||||
pub username: Option<String>,
|
||||
pub from_channel: Option<Channel>,
|
||||
pub auth_method: Option<AuthMethod>,
|
||||
}
|
||||
|
||||
impl UserInfo {
|
||||
pub fn with_http_token(token: String) -> Self {
|
||||
UserInfo {
|
||||
username: None,
|
||||
from_channel: Some(HTTP),
|
||||
auth_method: Some(Token(token)),
|
||||
impl Default for ExecInfo {
|
||||
fn default() -> Self {
|
||||
ExecInfo {
|
||||
catalog: Some("greptime".to_string()),
|
||||
schema: Some("public".to_string()),
|
||||
extra_opts: HashMap::new(),
|
||||
trace_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize, Deserialize)]
|
||||
pub struct ClientInfo {
|
||||
pub client_host: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct UserInfo {
|
||||
pub username: Option<String>,
|
||||
pub from_channel: Channel,
|
||||
pub auth_method: AuthMethod,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum Channel {
|
||||
GRPC,
|
||||
@@ -78,10 +127,17 @@ pub enum AuthMethod {
|
||||
Password {
|
||||
hash_method: AuthHashMethod,
|
||||
hashed_value: Vec<u8>,
|
||||
salt: Vec<u8>,
|
||||
},
|
||||
Token(String),
|
||||
}
|
||||
|
||||
impl Default for AuthMethod {
|
||||
fn default() -> Self {
|
||||
AuthMethod::None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum AuthHashMethod {
|
||||
DoubleSha1,
|
||||
@@ -97,16 +153,26 @@ pub struct Quota {
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::context::AuthMethod::Token;
|
||||
use crate::context::Channel::HTTP;
|
||||
use crate::context::{ClientInfo, Context, ExecInfo, Quota, UserInfo};
|
||||
use crate::context::{Channel, Context, CtxBuilder, UserInfo};
|
||||
|
||||
#[test]
|
||||
fn test_predicate() {
|
||||
let mut ctx = Context::default();
|
||||
let mut ctx = Context {
|
||||
exec_info: Default::default(),
|
||||
client_info: Default::default(),
|
||||
user_info: UserInfo {
|
||||
username: None,
|
||||
from_channel: Channel::GRPC,
|
||||
auth_method: Default::default(),
|
||||
},
|
||||
quota: Default::default(),
|
||||
predicates: vec![],
|
||||
};
|
||||
ctx.add_predicate(Arc::new(|ctx: &Context| {
|
||||
ctx.quota.total > ctx.quota.consumed
|
||||
}));
|
||||
@@ -123,43 +189,27 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn test_build() {
|
||||
let ctx = Context {
|
||||
exec_info: ExecInfo {
|
||||
catalog: Some(String::from("greptime")),
|
||||
schema: Some(String::from("public")),
|
||||
extra_opts: HashMap::new(),
|
||||
trace_id: None,
|
||||
},
|
||||
client_info: ClientInfo::new(Some(String::from("127.0.0.1:4001"))),
|
||||
user_info: UserInfo::with_http_token(String::from("HELLO")),
|
||||
quota: Quota {
|
||||
total: 10,
|
||||
consumed: 5,
|
||||
estimated: 2,
|
||||
},
|
||||
predicates: vec![],
|
||||
};
|
||||
let ctx = CtxBuilder::new()
|
||||
.client_addr(Some("127.0.0.1:4001".to_string()))
|
||||
.set_channel(Some(HTTP))
|
||||
.set_auth_method(Some(Token("HELLO".to_string())))
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(ctx.exec_info.catalog.unwrap(), String::from("greptime"));
|
||||
assert_eq!(ctx.exec_info.schema.unwrap(), String::from("public"));
|
||||
assert_eq!(ctx.exec_info.extra_opts.capacity(), 0);
|
||||
assert_eq!(ctx.exec_info.extra_opts.len(), 0);
|
||||
assert_eq!(ctx.exec_info.trace_id, None);
|
||||
|
||||
assert_eq!(
|
||||
ctx.client_info.client_host.unwrap(),
|
||||
String::from("127.0.0.1:4001")
|
||||
);
|
||||
assert_eq!(ctx.client_info.client_host, String::from("127.0.0.1:4001"));
|
||||
|
||||
assert_eq!(ctx.user_info.username, None);
|
||||
assert_eq!(ctx.user_info.from_channel.unwrap(), HTTP);
|
||||
assert_eq!(
|
||||
ctx.user_info.auth_method.unwrap(),
|
||||
Token(String::from("HELLO"))
|
||||
);
|
||||
assert_eq!(ctx.user_info.from_channel, HTTP);
|
||||
assert_eq!(ctx.user_info.auth_method, Token(String::from("HELLO")));
|
||||
|
||||
assert!(ctx.quota.total > 0);
|
||||
assert!(ctx.quota.consumed > 0);
|
||||
assert!(ctx.quota.estimated > 0);
|
||||
assert_eq!(ctx.quota.total, 0);
|
||||
assert_eq!(ctx.quota.consumed, 0);
|
||||
assert_eq!(ctx.quota.estimated, 0);
|
||||
|
||||
assert_eq!(ctx.predicates.capacity(), 0);
|
||||
}
|
||||
|
||||
@@ -174,6 +174,12 @@ pub enum Error {
|
||||
#[snafu(backtrace)]
|
||||
source: BoxedError,
|
||||
},
|
||||
|
||||
#[snafu(display("Failed to build context, msg: {}", err_msg))]
|
||||
BuildingContext {
|
||||
err_msg: String,
|
||||
backtrace: Backtrace,
|
||||
},
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
@@ -192,7 +198,8 @@ impl ErrorExt for Error {
|
||||
| AlreadyStarted { .. }
|
||||
| InvalidPromRemoteReadQueryResult { .. }
|
||||
| TcpBind { .. }
|
||||
| GrpcReflectionService { .. } => StatusCode::Internal,
|
||||
| GrpcReflectionService { .. }
|
||||
| BuildingContext { .. } => StatusCode::Internal,
|
||||
|
||||
InsertScript { source, .. }
|
||||
| ExecuteScript { source, .. }
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
mod context;
|
||||
pub mod handler;
|
||||
pub mod influxdb;
|
||||
pub mod opentsdb;
|
||||
@@ -11,6 +12,7 @@ use aide::axum::routing as apirouting;
|
||||
use aide::axum::{ApiRouter, IntoApiResponse};
|
||||
use aide::openapi::{Info, OpenApi, Server as OpenAPIServer};
|
||||
use async_trait::async_trait;
|
||||
use axum::middleware::{self};
|
||||
use axum::response::Html;
|
||||
use axum::Extension;
|
||||
use axum::{error_handling::HandleErrorLayer, response::Json, routing, BoxError, Router};
|
||||
@@ -313,7 +315,9 @@ impl HttpServer {
|
||||
.layer(HandleErrorLayer::new(handle_error))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
// TODO(LFC): make timeout configurable
|
||||
.layer(TimeoutLayer::new(Duration::from_secs(30))),
|
||||
.layer(TimeoutLayer::new(Duration::from_secs(30)))
|
||||
// custom layer
|
||||
.layer(middleware::from_fn(context::build_ctx)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
48
src/servers/src/http/context.rs
Normal file
48
src/servers/src/http/context.rs
Normal file
@@ -0,0 +1,48 @@
|
||||
use axum::{
|
||||
http,
|
||||
http::{Request, StatusCode},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use common_telemetry::error;
|
||||
|
||||
use crate::context::{AuthMethod, Channel, CtxBuilder};
|
||||
|
||||
pub async fn build_ctx<B>(mut req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
|
||||
let auth_option = req
|
||||
.headers()
|
||||
.get(http::header::AUTHORIZATION)
|
||||
.map(|header| {
|
||||
header
|
||||
.to_str()
|
||||
.map(|header_str| match header_str.split_once(' ') {
|
||||
Some((name, content)) if name == "Bearer" || name == "TOKEN" => {
|
||||
AuthMethod::Token(String::from(content))
|
||||
}
|
||||
_ => AuthMethod::None,
|
||||
})
|
||||
.unwrap_or(AuthMethod::None)
|
||||
})
|
||||
.or(Some(AuthMethod::None));
|
||||
|
||||
match CtxBuilder::new()
|
||||
.client_addr(
|
||||
req.headers()
|
||||
.get(http::header::HOST)
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.map(|h| h.to_string()),
|
||||
)
|
||||
.set_channel(Some(Channel::HTTP))
|
||||
.set_auth_method(auth_option)
|
||||
.build()
|
||||
{
|
||||
Ok(ctx) => {
|
||||
req.extensions_mut().insert(ctx);
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
Err(e) => {
|
||||
error!(e; "fail to create context");
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,19 @@
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use common_telemetry::error;
|
||||
use opensrv_mysql::AsyncMysqlShim;
|
||||
use opensrv_mysql::ErrorKind;
|
||||
use opensrv_mysql::ParamParser;
|
||||
use opensrv_mysql::QueryResultWriter;
|
||||
use opensrv_mysql::StatementMetaWriter;
|
||||
use rand::RngCore;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::context::AuthHashMethod::DoubleSha1;
|
||||
use crate::context::Channel::MYSQL;
|
||||
use crate::context::{AuthMethod, Context, CtxBuilder};
|
||||
use crate::error::{self, Result};
|
||||
use crate::mysql::writer::MysqlResultWriter;
|
||||
use crate::query_handler::SqlQueryHandlerRef;
|
||||
@@ -14,11 +21,32 @@ use crate::query_handler::SqlQueryHandlerRef;
|
||||
// An intermediate shim for executing MySQL queries.
|
||||
pub struct MysqlInstanceShim {
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
salt: [u8; 20],
|
||||
client_addr: String,
|
||||
ctx: Arc<RwLock<Option<Context>>>,
|
||||
}
|
||||
|
||||
impl MysqlInstanceShim {
|
||||
pub fn create(query_handler: SqlQueryHandlerRef) -> MysqlInstanceShim {
|
||||
MysqlInstanceShim { query_handler }
|
||||
pub fn create(query_handler: SqlQueryHandlerRef, client_addr: String) -> MysqlInstanceShim {
|
||||
// init a random salt
|
||||
let mut bs = vec![0u8; 20];
|
||||
let mut rng = rand::thread_rng();
|
||||
rng.fill_bytes(bs.as_mut());
|
||||
|
||||
let mut scramble: [u8; 20] = [0; 20];
|
||||
for i in 0..20 {
|
||||
scramble[i] = bs[i] & 0x7fu8;
|
||||
if scramble[i] == b'\0' || scramble[i] == b'$' {
|
||||
scramble[i] += 1;
|
||||
}
|
||||
}
|
||||
|
||||
MysqlInstanceShim {
|
||||
query_handler,
|
||||
salt: scramble,
|
||||
client_addr,
|
||||
ctx: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +54,48 @@ impl MysqlInstanceShim {
|
||||
impl<W: io::Write + Send + Sync> AsyncMysqlShim<W> for MysqlInstanceShim {
|
||||
type Error = error::Error;
|
||||
|
||||
fn salt(&self) -> [u8; 20] {
|
||||
self.salt
|
||||
}
|
||||
|
||||
async fn authenticate(
|
||||
&self,
|
||||
_auth_plugin: &str,
|
||||
username: &[u8],
|
||||
salt: &[u8],
|
||||
auth_data: &[u8],
|
||||
) -> bool {
|
||||
// if not specified then **root** will be used
|
||||
let username = String::from_utf8_lossy(username);
|
||||
let client_addr = self.client_addr.clone();
|
||||
let auth_method = match auth_data.len() {
|
||||
0 => AuthMethod::None,
|
||||
_ => AuthMethod::Password {
|
||||
hash_method: DoubleSha1,
|
||||
hashed_value: auth_data.to_vec(),
|
||||
salt: salt.to_vec(),
|
||||
},
|
||||
};
|
||||
|
||||
return match CtxBuilder::new()
|
||||
.client_addr(Some(client_addr))
|
||||
.set_channel(Some(MYSQL))
|
||||
.set_username(Some(username.to_string()))
|
||||
.set_auth_method(Some(auth_method))
|
||||
.build()
|
||||
{
|
||||
Ok(ctx) => {
|
||||
let mut a = self.ctx.write().await;
|
||||
*a = Some(ctx);
|
||||
true
|
||||
}
|
||||
Err(e) => {
|
||||
error!(e; "create ctx failed when authing mysql conn");
|
||||
false
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
async fn on_prepare<'a>(
|
||||
&'a mut self,
|
||||
_: &'a str,
|
||||
|
||||
@@ -59,7 +59,7 @@ impl MysqlServer {
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
) -> Result<()> {
|
||||
info!("MySQL connection coming from: {}", stream.peer_addr()?);
|
||||
let shim = MysqlInstanceShim::create(query_handler);
|
||||
let shim = MysqlInstanceShim::create(query_handler, stream.peer_addr()?.to_string());
|
||||
// TODO(LFC): Relate "handler" with MySQL session; also deal with panics there.
|
||||
let _handler = io_runtime.spawn(AsyncMysqlIntermediary::run_on(shim, stream));
|
||||
Ok(())
|
||||
|
||||
110
src/servers/src/postgres/auth_handler.rs
Normal file
110
src/servers/src/postgres/auth_handler.rs
Normal file
@@ -0,0 +1,110 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{Sink, SinkExt};
|
||||
use pgwire::api::auth::{ServerParameterProvider, StartupHandler};
|
||||
use pgwire::api::{auth, ClientInfo, PgWireConnectionState};
|
||||
use pgwire::error::ErrorInfo;
|
||||
use pgwire::error::{PgWireError, PgWireResult};
|
||||
use pgwire::messages::response::ErrorResponse;
|
||||
use pgwire::messages::startup::Authentication;
|
||||
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
|
||||
|
||||
struct PgPwdVerifier;
|
||||
|
||||
impl PgPwdVerifier {
|
||||
async fn verify_pwd(&self, _pwd: &str, _meta: HashMap<String, String>) -> PgWireResult<bool> {
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
|
||||
struct GreptimeDBStartupParameters {
|
||||
version: &'static str,
|
||||
}
|
||||
|
||||
impl GreptimeDBStartupParameters {
|
||||
fn new() -> GreptimeDBStartupParameters {
|
||||
GreptimeDBStartupParameters {
|
||||
version: env!("CARGO_PKG_VERSION"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerParameterProvider for GreptimeDBStartupParameters {
|
||||
fn server_parameters<C>(&self, _client: &C) -> Option<HashMap<String, String>>
|
||||
where
|
||||
C: ClientInfo,
|
||||
{
|
||||
let mut params = HashMap::with_capacity(1);
|
||||
params.insert("server_version".to_owned(), self.version.to_owned());
|
||||
|
||||
Some(params)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PgAuthStartupHandler {
|
||||
verifier: PgPwdVerifier,
|
||||
param_provider: GreptimeDBStartupParameters,
|
||||
with_pwd: bool,
|
||||
}
|
||||
|
||||
impl PgAuthStartupHandler {
|
||||
pub fn new(with_pwd: bool) -> Self {
|
||||
PgAuthStartupHandler {
|
||||
verifier: PgPwdVerifier,
|
||||
param_provider: GreptimeDBStartupParameters::new(),
|
||||
with_pwd,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl StartupHandler for PgAuthStartupHandler {
|
||||
async fn on_startup<C>(
|
||||
&self,
|
||||
client: &mut C,
|
||||
message: &PgWireFrontendMessage,
|
||||
) -> PgWireResult<()>
|
||||
where
|
||||
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
|
||||
C::Error: Debug,
|
||||
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
|
||||
{
|
||||
match message {
|
||||
PgWireFrontendMessage::Startup(ref startup) => {
|
||||
auth::save_startup_parameters_to_metadata(client, startup);
|
||||
if self.with_pwd {
|
||||
client.set_state(PgWireConnectionState::AuthenticationInProgress);
|
||||
client
|
||||
.send(PgWireBackendMessage::Authentication(
|
||||
Authentication::CleartextPassword,
|
||||
))
|
||||
.await?;
|
||||
} else {
|
||||
auth::finish_authentication(client, &self.param_provider).await;
|
||||
}
|
||||
}
|
||||
PgWireFrontendMessage::Password(ref pwd) => {
|
||||
let meta = client.metadata().clone();
|
||||
if let Ok(true) = self.verifier.verify_pwd(pwd.password(), meta).await {
|
||||
auth::finish_authentication(client, &self.param_provider).await
|
||||
} else {
|
||||
let error_info = ErrorInfo::new(
|
||||
"FATAL".to_owned(),
|
||||
"28P01".to_owned(),
|
||||
"Password authentication failed".to_owned(),
|
||||
);
|
||||
let error = ErrorResponse::from(error_info);
|
||||
|
||||
client
|
||||
.feed(PgWireBackendMessage::ErrorResponse(error))
|
||||
.await?;
|
||||
client.close().await?;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
mod auth_handler;
|
||||
mod handler;
|
||||
mod server;
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
@@ -6,77 +5,31 @@ use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use common_runtime::Runtime;
|
||||
use common_telemetry::logging::error;
|
||||
use futures::{Sink, StreamExt};
|
||||
use pgwire::api::auth::{self, ServerParameterProvider, StartupHandler};
|
||||
use pgwire::api::ClientInfo;
|
||||
use pgwire::error::{PgWireError, PgWireResult};
|
||||
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
|
||||
use futures::StreamExt;
|
||||
use pgwire::tokio::process_socket;
|
||||
use tokio;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::postgres::auth_handler::PgAuthStartupHandler;
|
||||
use crate::postgres::handler::PostgresServerHandler;
|
||||
use crate::query_handler::SqlQueryHandlerRef;
|
||||
use crate::server::{AbortableStream, BaseTcpServer, Server};
|
||||
|
||||
struct SimpleStartupHandler;
|
||||
|
||||
#[async_trait]
|
||||
impl StartupHandler for SimpleStartupHandler {
|
||||
async fn on_startup<C>(
|
||||
&self,
|
||||
client: &mut C,
|
||||
message: &PgWireFrontendMessage,
|
||||
) -> PgWireResult<()>
|
||||
where
|
||||
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
|
||||
C::Error: std::fmt::Debug,
|
||||
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
|
||||
{
|
||||
if let PgWireFrontendMessage::Startup(ref startup) = message {
|
||||
auth::save_startup_parameters_to_metadata(client, startup);
|
||||
auth::finish_authentication(client, &GreptimeDBStartupParameters::new()).await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
struct GreptimeDBStartupParameters {
|
||||
version: &'static str,
|
||||
}
|
||||
|
||||
impl GreptimeDBStartupParameters {
|
||||
fn new() -> GreptimeDBStartupParameters {
|
||||
GreptimeDBStartupParameters {
|
||||
version: env!("CARGO_PKG_VERSION"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerParameterProvider for GreptimeDBStartupParameters {
|
||||
fn server_parameters<C>(&self, _client: &C) -> Option<HashMap<String, String>>
|
||||
where
|
||||
C: ClientInfo,
|
||||
{
|
||||
let mut params = HashMap::with_capacity(1);
|
||||
params.insert("server_version".to_owned(), self.version.to_owned());
|
||||
|
||||
Some(params)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PostgresServer {
|
||||
base_server: BaseTcpServer,
|
||||
auth_handler: Arc<SimpleStartupHandler>,
|
||||
auth_handler: Arc<PgAuthStartupHandler>,
|
||||
query_handler: Arc<PostgresServerHandler>,
|
||||
}
|
||||
|
||||
impl PostgresServer {
|
||||
/// Creates a new Postgres server with provided query_handler and async runtime
|
||||
pub fn new(query_handler: SqlQueryHandlerRef, io_runtime: Arc<Runtime>) -> PostgresServer {
|
||||
pub fn new(
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
check_pwd: bool,
|
||||
io_runtime: Arc<Runtime>,
|
||||
) -> PostgresServer {
|
||||
let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler));
|
||||
let startup_handler = Arc::new(SimpleStartupHandler);
|
||||
let startup_handler = Arc::new(PgAuthStartupHandler::new(check_pwd));
|
||||
PostgresServer {
|
||||
base_server: BaseTcpServer::create_server("Postgres", io_runtime),
|
||||
auth_handler: startup_handler,
|
||||
|
||||
@@ -63,10 +63,10 @@ async fn test_shutdown_mysql_server() -> Result<()> {
|
||||
let server_port = server_addr.port();
|
||||
|
||||
let mut join_handles = vec![];
|
||||
for _ in 0..2 {
|
||||
for index in 0..2 {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
for _ in 0..1000 {
|
||||
match create_connection(server_port).await {
|
||||
match create_connection(server_port, index == 1).await {
|
||||
Ok(mut connection) => {
|
||||
let result: u32 = connection
|
||||
.query_first("SELECT uint32s FROM numbers LIMIT 1")
|
||||
@@ -114,7 +114,7 @@ async fn test_query_all_datatypes() -> Result<()> {
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = mysql_server.start(listening).await.unwrap();
|
||||
|
||||
let mut connection = create_connection(server_addr.port()).await.unwrap();
|
||||
let mut connection = create_connection(server_addr.port(), false).await.unwrap();
|
||||
let mut result = connection
|
||||
.query_iter("SELECT * FROM all_datatypes LIMIT 3")
|
||||
.await
|
||||
@@ -149,11 +149,13 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
let threads = 4;
|
||||
let expect_executed_queries_per_worker = 1000;
|
||||
let mut join_handles = vec![];
|
||||
for _ in 0..threads {
|
||||
for index in 0..threads {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
let mut rand: StdRng = rand::SeedableRng::from_entropy();
|
||||
|
||||
let mut connection = create_connection(server_port).await.unwrap();
|
||||
let mut connection = create_connection(server_port, index % 2 == 0)
|
||||
.await
|
||||
.unwrap();
|
||||
for _ in 0..expect_executed_queries_per_worker {
|
||||
let expected: u32 = rand.gen_range(0..100);
|
||||
let result: u32 = connection
|
||||
@@ -168,7 +170,9 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
|
||||
let should_recreate_conn = expected == 1;
|
||||
if should_recreate_conn {
|
||||
connection = create_connection(server_port).await.unwrap();
|
||||
connection = create_connection(server_port, index % 2 == 0)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
expect_executed_queries_per_worker
|
||||
@@ -182,11 +186,16 @@ async fn test_query_concurrently() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_connection(port: u16) -> mysql_async::Result<mysql_async::Conn> {
|
||||
let opts = mysql_async::OptsBuilder::default()
|
||||
async fn create_connection(port: u16, with_pwd: bool) -> mysql_async::Result<mysql_async::Conn> {
|
||||
let mut opts = mysql_async::OptsBuilder::default()
|
||||
.ip_or_hostname("127.0.0.1")
|
||||
.tcp_port(port)
|
||||
.prefer_socket(false)
|
||||
.wait_timeout(Some(1000));
|
||||
|
||||
if with_pwd {
|
||||
opts = opts.pass(Some("default_pwd".to_string()));
|
||||
}
|
||||
|
||||
mysql_async::Conn::new(opts).await
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ use tokio_postgres::{Client, Error as PgError, NoTls, SimpleQueryMessage};
|
||||
|
||||
use crate::create_testing_sql_query_handler;
|
||||
|
||||
fn create_postgres_server(table: MemTable) -> Result<Box<dyn Server>> {
|
||||
fn create_postgres_server(table: MemTable, check_pwd: bool) -> Result<Box<dyn Server>> {
|
||||
let query_handler = create_testing_sql_query_handler(table);
|
||||
let io_runtime = Arc::new(
|
||||
RuntimeBuilder::default()
|
||||
@@ -22,14 +22,18 @@ fn create_postgres_server(table: MemTable) -> Result<Box<dyn Server>> {
|
||||
.build()
|
||||
.unwrap(),
|
||||
);
|
||||
Ok(Box::new(PostgresServer::new(query_handler, io_runtime)))
|
||||
Ok(Box::new(PostgresServer::new(
|
||||
query_handler,
|
||||
check_pwd,
|
||||
io_runtime,
|
||||
)))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
pub async fn test_start_postgres_server() -> Result<()> {
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let pg_server = create_postgres_server(table)?;
|
||||
let pg_server = create_postgres_server(table, false)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let result = pg_server.start(listening).await;
|
||||
assert!(result.is_ok());
|
||||
@@ -43,12 +47,19 @@ pub async fn test_start_postgres_server() -> Result<()> {
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_shutdown_pg_server() -> Result<()> {
|
||||
async fn test_shutdown_pg_server_range() -> Result<()> {
|
||||
assert!(test_shutdown_pg_server(false).await.is_ok());
|
||||
assert!(test_shutdown_pg_server(true).await.is_ok());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let postgres_server = create_postgres_server(table)?;
|
||||
let postgres_server = create_postgres_server(table, with_pwd)?;
|
||||
let result = postgres_server.shutdown().await;
|
||||
assert!(result
|
||||
.unwrap_err()
|
||||
@@ -63,7 +74,7 @@ async fn test_shutdown_pg_server() -> Result<()> {
|
||||
for _ in 0..2 {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
for _ in 0..1000 {
|
||||
match create_connection(server_port).await {
|
||||
match create_connection(server_port, with_pwd).await {
|
||||
Ok(connection) => {
|
||||
match connection
|
||||
.simple_query("SELECT uint32s FROM numbers LIMIT 1")
|
||||
@@ -107,7 +118,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let pg_server = create_postgres_server(table)?;
|
||||
let pg_server = create_postgres_server(table, false)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = pg_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
@@ -119,7 +130,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
let mut rand: StdRng = rand::SeedableRng::from_entropy();
|
||||
|
||||
let mut client = create_connection(server_port).await.unwrap();
|
||||
let mut client = create_connection(server_port, false).await.unwrap();
|
||||
|
||||
for _k in 0..expect_executed_queries_per_worker {
|
||||
let expected: u32 = rand.gen_range(0..100);
|
||||
@@ -140,7 +151,7 @@ async fn test_query_pg_concurrently() -> Result<()> {
|
||||
// 1/100 chance to reconnect
|
||||
let should_recreate_conn = expected == 1;
|
||||
if should_recreate_conn {
|
||||
client = create_connection(server_port).await.unwrap();
|
||||
client = create_connection(server_port, false).await.unwrap();
|
||||
}
|
||||
}
|
||||
expect_executed_queries_per_worker
|
||||
@@ -154,8 +165,15 @@ async fn test_query_pg_concurrently() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn create_connection(port: u16) -> std::result::Result<Client, PgError> {
|
||||
let url = format!("host=127.0.0.1 port={} connect_timeout=2", port);
|
||||
async fn create_connection(port: u16, with_pwd: bool) -> std::result::Result<Client, PgError> {
|
||||
let url = if with_pwd {
|
||||
format!(
|
||||
"host=127.0.0.1 port={} user=test_user password=test_pwd connect_timeout=2",
|
||||
port
|
||||
)
|
||||
} else {
|
||||
format!("host=127.0.0.1 port={} connect_timeout=2", port)
|
||||
};
|
||||
let (client, conn) = tokio_postgres::connect(&url, NoTls).await?;
|
||||
tokio::spawn(conn);
|
||||
Ok(client)
|
||||
|
||||
Reference in New Issue
Block a user