feat: support http basic authentication (#733)

* feat: support http auth

* add some unit test and log

* fix

* cr

* remove unused #[derive(Clone)]
This commit is contained in:
fys
2022-12-13 10:44:33 +08:00
committed by GitHub
parent 9b093463cc
commit c5661ee362
13 changed files with 409 additions and 73 deletions

View File

@@ -67,9 +67,13 @@ pub enum StatusCode {
/// User not exist
UserNotFound = 7000,
/// Unsupported password type
UnsupportedPwdType = 7001,
UnsupportedPasswordType = 7001,
/// Username and password does not match
UserPwdMismatch = 7002,
UserPasswordMismatch = 7002,
/// Not found http authorization header
AuthHeaderNotFound = 7003,
/// Invalid http authorization header
InvalidAuthHeader = 7004,
// ====== End of auth related status code =====
}

View File

@@ -10,6 +10,7 @@ api = { path = "../api" }
async-trait = "0.1"
axum = "0.6"
axum-macros = "0.3"
base64 = "0.13"
bytes = "1.2"
common-base = { path = "../common/base" }
common-catalog = { path = "../common/catalog" }
@@ -23,6 +24,7 @@ common-time = { path = "../common/time" }
datatypes = { path = "../datatypes" }
futures = "0.3"
hex = { version = "0.4" }
http-body = "0.4"
humantime-serde = "1.1"
hyper = { version = "0.14", features = ["full"] }
influxdb_line_protocol = { git = "https://github.com/evenyag/influxdb_iox", branch = "feat/line-protocol" }

View File

@@ -24,7 +24,7 @@ use snafu::{Backtrace, ErrorCompat, Snafu};
pub trait UserProvider: Send + Sync {
fn name(&self) -> &str;
async fn auth(&self, id: Identity<'_>, pwd: Password<'_>) -> Result<UserInfo, Error>;
async fn auth(&self, id: Identity<'_>, password: Password<'_>) -> Result<UserInfo, Error>;
}
pub type UserProviderRef = Arc<dyn UserProvider>;
@@ -37,17 +37,17 @@ pub enum Identity<'a> {
UserId(Username<'a>, Option<HostOrIp<'a>>),
}
pub type HashedPwd<'a> = &'a [u8];
pub type HashedPassword<'a> = &'a [u8];
pub type Salt<'a> = &'a [u8];
pub type Pwd<'a> = &'a [u8];
/// Authentication information sent by the client.
pub enum Password<'a> {
PlainText(Pwd<'a>),
MysqlNativePwd(HashedPwd<'a>, Salt<'a>),
PgMD5(HashedPwd<'a>, Salt<'a>),
PlainText(&'a str),
MysqlNativePassword(HashedPassword<'a>, Salt<'a>),
PgMD5(HashedPassword<'a>, Salt<'a>),
}
#[derive(Clone, Debug)]
pub struct UserInfo {
username: String,
}
@@ -76,25 +76,25 @@ impl UserInfo {
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
#[snafu(display("User not exist"))]
UserNotExist { backtrace: Backtrace },
#[snafu(display("User not found"))]
UserNotFound { backtrace: Backtrace },
#[snafu(display("Unsupported Password Type: {}", pwd_type))]
UnsupportedPwdType {
pwd_type: String,
#[snafu(display("Unsupported password type: {}", password_type))]
UnsupportedPasswordType {
password_type: String,
backtrace: Backtrace,
},
#[snafu(display("Username and password does not match"))]
WrongPwd { backtrace: Backtrace },
UserPasswordMismatch { backtrace: Backtrace },
}
impl ErrorExt for Error {
fn status_code(&self) -> StatusCode {
match self {
Error::UserNotExist { .. } => StatusCode::UserNotFound,
Error::UnsupportedPwdType { .. } => StatusCode::UnsupportedPwdType,
Error::WrongPwd { .. } => StatusCode::UserPwdMismatch,
Error::UserNotFound { .. } => StatusCode::UserNotFound,
Error::UnsupportedPasswordType { .. } => StatusCode::UnsupportedPasswordType,
Error::UserPasswordMismatch { .. } => StatusCode::UserPasswordMismatch,
}
}
@@ -108,11 +108,10 @@ impl ErrorExt for Error {
}
#[cfg(test)]
mod tests {
pub mod test {
use super::{Identity, Password, UserInfo, UserProvider};
use crate::auth;
struct MockUserProvider {}
pub struct MockUserProvider {}
#[async_trait::async_trait]
impl UserProvider for MockUserProvider {
@@ -127,27 +126,34 @@ mod tests {
) -> Result<UserInfo, super::Error> {
match id {
Identity::UserId(username, _host) => match password {
Password::PlainText(pwd) => {
Password::PlainText(password) => {
if username == "greptime" {
if pwd == b"greptime" {
if password == "greptime" {
return Ok(UserInfo {
username: "greptime".to_string(),
});
} else {
return super::WrongPwdSnafu {}.fail();
return super::UserPasswordMismatchSnafu {}.fail();
}
} else {
return super::UserNotExistSnafu {}.fail();
return super::UserNotFoundSnafu {}.fail();
}
}
_ => super::UnsupportedPwdTypeSnafu {
pwd_type: "mysql_native_pwd",
_ => super::UnsupportedPasswordTypeSnafu {
password_type: "mysql_native_password",
}
.fail(),
},
}
}
}
}
#[cfg(test)]
mod tests {
use super::test::MockUserProvider;
use super::{Identity, Password, UserProvider};
use crate::auth;
#[tokio::test]
async fn test_auth_by_plain_text() {
@@ -158,43 +164,46 @@ mod tests {
let auth_result = user_provider
.auth(
Identity::UserId("greptime", None),
Password::PlainText(b"greptime"),
Password::PlainText("greptime"),
)
.await;
assert!(auth_result.is_ok());
assert_eq!("greptime", auth_result.unwrap().user_name());
// auth failed, unsupported pwd type
// auth failed, unsupported password type
let auth_result = user_provider
.auth(
Identity::UserId("greptime", None),
Password::MysqlNativePwd(b"hashed_value", b"salt"),
Password::MysqlNativePassword(b"hashed_value", b"salt"),
)
.await;
assert!(auth_result.is_err());
matches!(
auth_result.err().unwrap(),
auth::Error::UnsupportedPwdType { .. }
auth::Error::UnsupportedPasswordType { .. }
);
// auth failed, err: user not exist.
let auth_result = user_provider
.auth(
Identity::UserId("not_exist_username", None),
Password::PlainText(b"greptime"),
Password::PlainText("greptime"),
)
.await;
assert!(auth_result.is_err());
matches!(auth_result.err().unwrap(), auth::Error::UserNotExist { .. });
matches!(auth_result.err().unwrap(), auth::Error::UserNotFound { .. });
// auth failed, err: wrong password
let auth_result = user_provider
.auth(
Identity::UserId("greptime", None),
Password::PlainText(b"wrong_pwd"),
Password::PlainText("wrong_password"),
)
.await;
assert!(auth_result.is_err());
matches!(auth_result.err().unwrap(), auth::Error::WrongPwd { .. });
matches!(
auth_result.err().unwrap(),
auth::Error::UserPasswordMismatch { .. }
);
}
}

View File

@@ -85,11 +85,11 @@ pub struct ClientInfo {
pub channel: Channel,
}
#[derive(Debug, PartialEq, Eq)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Channel {
GRPC,
HTTP,
MYSQL,
Grpc,
Http,
Mysql,
}
#[derive(Default)]
@@ -105,7 +105,7 @@ mod test {
use std::sync::Arc;
use crate::auth::UserInfo;
use crate::context::Channel::{self, HTTP};
use crate::context::Channel::{self, Http};
use crate::context::{ClientInfo, Context, CtxBuilder};
#[test]
@@ -113,7 +113,7 @@ mod test {
let mut ctx = Context {
client_info: ClientInfo {
client_host: Default::default(),
channel: Channel::GRPC,
channel: Channel::Grpc,
},
user_info: UserInfo::new("greptime"),
quota: Default::default(),
@@ -137,7 +137,7 @@ mod test {
fn test_build() {
let ctx = CtxBuilder::new()
.client_addr("127.0.0.1:4001".to_string())
.set_channel(HTTP)
.set_channel(Http)
.set_user_info(UserInfo::new("greptime"))
.build()
.unwrap();

View File

@@ -14,11 +14,14 @@
use std::any::Any;
use std::net::SocketAddr;
use std::string::FromUtf8Error;
use axum::http::StatusCode as HttpStatusCode;
use axum::response::{IntoResponse, Response};
use axum::Json;
use base64::DecodeError;
use common_error::prelude::*;
use hyper::header::ToStrError;
use serde_json::json;
use crate::auth;
@@ -203,6 +206,33 @@ pub enum Error {
#[snafu(backtrace)]
source: auth::Error,
},
#[snafu(display("Not found http authorization header"))]
NotFoundAuthHeader {},
#[snafu(display("Invalid visibility ASCII chars, source: {}", source))]
InvisibleASCII {
source: ToStrError,
backtrace: Backtrace,
},
#[snafu(display("Unsupported http auth scheme, name: {}", name))]
UnsupportedAuthScheme { name: String },
#[snafu(display("Invalid http authorization header"))]
InvalidAuthorizationHeader { backtrace: Backtrace },
#[snafu(display("Invalid base64 value, source: {:?}", source))]
InvalidBase64Value {
source: DecodeError,
backtrace: Backtrace,
},
#[snafu(display("Invalid utf-8 value, source: {:?}", source))]
InvalidUtf8Value {
source: FromUtf8Error,
backtrace: Backtrace,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -248,6 +278,13 @@ impl ErrorExt for Error {
TlsRequired { .. } => StatusCode::Unknown,
StartFrontend { source, .. } => source.status_code(),
Auth { source, .. } => source.status_code(),
NotFoundAuthHeader { .. } => StatusCode::AuthHeaderNotFound,
InvisibleASCII { .. }
| UnsupportedAuthScheme { .. }
| InvalidAuthorizationHeader { .. }
| InvalidBase64Value { .. }
| InvalidUtf8Value { .. } => StatusCode::InvalidAuthHeader,
}
}

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
mod context;
mod authorize;
pub mod handler;
pub mod influxdb;
pub mod opentsdb;
@@ -26,8 +26,8 @@ use std::time::Duration;
use aide::axum::{routing as apirouting, ApiRouter, IntoApiResponse};
use aide::openapi::{Info, OpenApi, Server as OpenAPIServer};
use async_trait::async_trait;
use axum::body::BoxBody;
use axum::error_handling::HandleErrorLayer;
use axum::middleware::{self};
use axum::response::{Html, Json};
use axum::{routing, BoxError, Extension, Router};
use common_error::prelude::ErrorExt;
@@ -45,9 +45,12 @@ use tokio::sync::oneshot::{self, Sender};
use tokio::sync::Mutex;
use tower::timeout::TimeoutLayer;
use tower::ServiceBuilder;
use tower_http::auth::AsyncRequireAuthorizationLayer;
use tower_http::trace::TraceLayer;
use self::authorize::HttpAuth;
use self::influxdb::influxdb_write;
use crate::auth::UserProviderRef;
use crate::error::{AlreadyStartedSnafu, Result, StartHttpSnafu};
use crate::query_handler::{
InfluxdbLineProtocolHandlerRef, OpentsdbProtocolHandlerRef, PrometheusProtocolHandlerRef,
@@ -65,6 +68,7 @@ pub struct HttpServer {
prom_handler: Option<PrometheusProtocolHandlerRef>,
script_handler: Option<ScriptHandlerRef>,
shutdown_tx: Mutex<Option<Sender<()>>>,
user_provider: Option<UserProviderRef>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -295,6 +299,7 @@ impl HttpServer {
opentsdb_handler: None,
influxdb_handler: None,
prom_handler: None,
user_provider: None,
script_handler: None,
shutdown_tx: Mutex::new(None),
}
@@ -332,6 +337,14 @@ impl HttpServer {
self.prom_handler.get_or_insert(handler);
}
pub fn set_user_provider(&mut self, user_provider: UserProviderRef) {
debug_assert!(
self.user_provider.is_none(),
"User provider can be set only once!"
);
self.user_provider.get_or_insert(user_provider);
}
pub fn make_app(&self) -> Router {
let mut api = OpenApi {
info: Info {
@@ -393,7 +406,9 @@ impl HttpServer {
.layer(TraceLayer::new_for_http())
.layer(TimeoutLayer::new(self.options.timeout))
// custom layer
.layer(middleware::from_fn(context::build_ctx)),
.layer(AsyncRequireAuthorizationLayer::new(
HttpAuth::<BoxBody>::new(self.user_provider.clone()),
)),
)
}

View File

@@ -0,0 +1,282 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::marker::PhantomData;
use axum::http::{self, Request, StatusCode};
use axum::response::Response;
use common_telemetry::error;
use futures::future::BoxFuture;
use http_body::Body;
use snafu::{OptionExt, ResultExt};
use tower_http::auth::AsyncAuthorizeRequest;
use crate::auth::{Identity, UserInfo, UserProviderRef};
use crate::error::{self, Result};
pub struct HttpAuth<RespBody> {
user_provider: Option<UserProviderRef>,
_ty: PhantomData<RespBody>,
}
impl<RespBody> HttpAuth<RespBody> {
pub fn new(user_provider: Option<UserProviderRef>) -> Self {
Self {
user_provider,
_ty: PhantomData,
}
}
}
impl<RespBody> Clone for HttpAuth<RespBody> {
fn clone(&self) -> Self {
Self {
user_provider: self.user_provider.clone(),
_ty: PhantomData,
}
}
}
impl<B, RespBody> AsyncAuthorizeRequest<B> for HttpAuth<RespBody>
where
B: Send + Sync + 'static,
RespBody: Body + Default,
{
type RequestBody = B;
type ResponseBody = RespBody;
type Future = BoxFuture<'static, std::result::Result<Request<B>, Response<Self::ResponseBody>>>;
fn authorize(&mut self, mut request: Request<B>) -> Self::Future {
let user_provider = self.user_provider.clone();
Box::pin(async move {
let user_provider = if let Some(user_provider) = &user_provider {
user_provider
} else {
request.extensions_mut().insert(UserInfo::default());
return Ok(request);
};
let (scheme, credential) = match auth_header(&request) {
Ok(auth_header) => auth_header,
Err(e) => {
error!("failed to get http authorize header, err: {:?}", e);
return Err(unauthorized_resp());
}
};
match scheme {
AuthScheme::Basic => {
let (username, password) = match decode_basic(credential) {
Ok(basic_auth) => basic_auth,
Err(e) => {
error!("failed to decode basic authorize, err: {:?}", e);
return Err(unauthorized_resp());
}
};
match user_provider
.auth(
Identity::UserId(&username, None),
crate::auth::Password::PlainText(&password),
)
.await
{
Ok(user_info) => {
request.extensions_mut().insert(user_info);
Ok(request)
}
Err(e) => {
error!("failed to auth, err: {:?}", e);
Err(unauthorized_resp())
}
}
}
}
})
}
}
fn unauthorized_resp<RespBody>() -> Response<RespBody>
where
RespBody: Body + Default,
{
let mut res = Response::new(RespBody::default());
*res.status_mut() = StatusCode::UNAUTHORIZED;
res
}
#[derive(Debug)]
pub enum AuthScheme {
Basic,
}
impl TryFrom<&str> for AuthScheme {
type Error = error::Error;
fn try_from(value: &str) -> Result<Self> {
match value.to_lowercase().as_str() {
"basic" => Ok(AuthScheme::Basic),
other => error::UnsupportedAuthSchemeSnafu { name: other }.fail(),
}
}
}
type Credential<'a> = &'a str;
fn auth_header<B>(req: &Request<B>) -> Result<(AuthScheme, Credential)> {
let auth_header = req
.headers()
.get(http::header::AUTHORIZATION)
.context(error::NotFoundAuthHeaderSnafu)?
.to_str()
.context(error::InvisibleASCIISnafu)?;
let (auth_scheme, encoded_credentials) = auth_header
.split_once(' ')
.context(error::InvalidAuthorizationHeaderSnafu)?;
if encoded_credentials.contains(' ') {
return error::InvalidAuthorizationHeaderSnafu {}.fail();
}
Ok((auth_scheme.try_into()?, encoded_credentials))
}
type Username = String;
type Password = String;
fn decode_basic(credential: Credential) -> Result<(Username, Password)> {
let decoded = base64::decode(credential).context(error::InvalidBase64ValueSnafu)?;
let as_utf8 = String::from_utf8(decoded).context(error::InvalidUtf8ValueSnafu)?;
if let Some((user_id, password)) = as_utf8.split_once(':') {
return Ok((user_id.to_string(), password.to_string()));
}
error::InvalidAuthorizationHeaderSnafu {}.fail()
}
#[cfg(test)]
mod tests {
use std::marker::PhantomData;
use std::sync::Arc;
use axum::body::BoxBody;
use axum::http;
use hyper::Request;
use tower_http::auth::AsyncAuthorizeRequest;
use super::{auth_header, decode_basic, AuthScheme, HttpAuth};
use crate::auth::test::MockUserProvider;
use crate::auth::{UserInfo, UserProvider};
use crate::error;
use crate::error::Result;
#[tokio::test]
async fn test_http_auth() {
let mut http_auth: HttpAuth<BoxBody> = HttpAuth {
user_provider: None,
_ty: PhantomData,
};
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let req = mock_http_request("Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
let auth_res = http_auth.authorize(req).await.unwrap();
let user_info: &UserInfo = auth_res.extensions().get().unwrap();
let default = UserInfo::default();
assert_eq!(default.user_name(), user_info.user_name());
// In mock user provider, right username:password == "greptime:greptime"
let mock_user_provider = Some(Arc::new(MockUserProvider {}) as Arc<dyn UserProvider>);
let mut http_auth: HttpAuth<BoxBody> = HttpAuth {
user_provider: mock_user_provider,
_ty: PhantomData,
};
// base64encode("greptime:greptime") == "Z3JlcHRpbWU6Z3JlcHRpbWU="
let req = mock_http_request("Basic Z3JlcHRpbWU6Z3JlcHRpbWU=").unwrap();
let req = http_auth.authorize(req).await.unwrap();
let user_info: &UserInfo = req.extensions().get().unwrap();
let default = UserInfo::default();
assert_eq!(default.user_name(), user_info.user_name());
let req = mock_http_request_no_auth().unwrap();
let auth_res = http_auth.authorize(req).await;
assert!(auth_res.is_err());
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let wrong_req = mock_http_request("Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
let auth_res = http_auth.authorize(wrong_req).await;
assert!(auth_res.is_err());
}
#[test]
fn test_decode_basic() {
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let credential = "dXNlcm5hbWU6cGFzc3dvcmQ=";
let (username, pwd) = decode_basic(credential).unwrap();
assert_eq!("username", username);
assert_eq!("password", pwd);
let wrong_credential = "dXNlcm5hbWU6cG Fzc3dvcmQ=";
let result = decode_basic(wrong_credential);
matches!(result.err(), Some(error::Error::InvalidBase64Value { .. }));
}
#[test]
fn test_try_into_auth_scheme() {
let auth_scheme_str = "basic";
let auth_scheme: AuthScheme = auth_scheme_str.try_into().unwrap();
matches!(auth_scheme, AuthScheme::Basic);
let unsupported = "digest";
let auth_scheme: Result<AuthScheme> = unsupported.try_into();
assert!(auth_scheme.is_err());
}
#[test]
fn test_auth_header() {
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let req = mock_http_request("Basic dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
let (auth_scheme, credential) = auth_header(&req).unwrap();
matches!(auth_scheme, AuthScheme::Basic);
assert_eq!("dXNlcm5hbWU6cGFzc3dvcmQ=", credential);
let wrong_req = mock_http_request("Basic dXNlcm5hbWU6 cGFzc3dvcmQ=").unwrap();
let res = auth_header(&wrong_req);
matches!(
res.err(),
Some(error::Error::InvalidAuthorizationHeader { .. })
);
let wrong_req = mock_http_request("Digest dXNlcm5hbWU6cGFzc3dvcmQ=").unwrap();
let res = auth_header(&wrong_req);
matches!(res.err(), Some(error::Error::UnsupportedAuthScheme { .. }));
}
fn mock_http_request(auth_header: &str) -> Result<Request<()>> {
Ok(Request::builder()
.uri("https://www.rust-lang.org/")
.header(http::header::AUTHORIZATION, auth_header)
.body(())
.unwrap())
}
fn mock_http_request_no_auth() -> Result<Request<()>> {
Ok(Request::builder()
.uri("https://www.rust-lang.org/")
.body(())
.unwrap())
}
}

View File

@@ -1,22 +0,0 @@
// Copyright 2022 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::Response;
pub async fn build_ctx<B>(req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
// TODO(fys): auth and set context
Ok(next.run(req).await)
}

View File

@@ -18,12 +18,14 @@ use std::time::Instant;
use aide::transform::TransformOperation;
use axum::extract::{Json, Query, State};
use axum::Extension;
use common_error::status_code::StatusCode;
use common_telemetry::metric;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use session::context::QueryContext;
use crate::auth::UserInfo;
use crate::http::{ApiState, JsonResponse};
#[derive(Debug, Default, Serialize, Deserialize, JsonSchema)]
@@ -37,6 +39,8 @@ pub struct SqlQuery {
pub async fn sql(
State(state): State<ApiState>,
Query(params): Query<SqlQuery>,
// TODO(fys): pass _user_info into query context
_user_info: Extension<UserInfo>,
) -> Json<JsonResponse> {
let sql_handler = &state.sql_handler;
let start = Instant::now();

View File

@@ -27,7 +27,7 @@ use tokio::io::AsyncWrite;
use tokio::sync::RwLock;
use crate::auth::{Identity, Password, UserProviderRef};
use crate::context::Channel::MYSQL;
use crate::context::Channel::Mysql;
use crate::context::{Context, CtxBuilder};
use crate::error::{self, Result};
use crate::mysql::writer::MysqlResultWriter;
@@ -121,14 +121,14 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
if let Some(user_provider) = &self.user_provider {
let user_id = Identity::UserId(&username, Some(&client_addr));
let pwd = match auth_plugin {
"mysql_native_password" => Password::MysqlNativePwd(auth_data, salt),
let password = match auth_plugin {
"mysql_native_password" => Password::MysqlNativePassword(auth_data, salt),
other => {
error!("Unsupported mysql auth plugin: {}", other);
return false;
}
};
match user_provider.auth(user_id, pwd).await {
match user_provider.auth(user_id, password).await {
Ok(userinfo) => {
user_info = Some(userinfo);
}
@@ -142,7 +142,7 @@ impl<W: AsyncWrite + Send + Sync + Unpin> AsyncMysqlShim<W> for MysqlInstanceShi
return match CtxBuilder::new()
.client_addr(client_addr)
.set_channel(MYSQL)
.set_channel(Mysql)
.set_user_info(user_info)
.build()
{

View File

@@ -57,7 +57,7 @@ impl LoginInfo {
}
impl PgPwdVerifier {
async fn verify_pwd(&self, pwd: &str, login: LoginInfo) -> Result<bool> {
async fn verify_pwd(&self, password: &str, login: LoginInfo) -> Result<bool> {
if let Some(user_provider) = &self.user_provider {
let user_name = match login.user {
Some(name) => name,
@@ -68,7 +68,7 @@ impl PgPwdVerifier {
let _user_info = user_provider
.auth(
Identity::UserId(&user_name, None),
Password::PlainText(pwd.as_bytes()),
Password::PlainText(password),
)
.await
.context(error::AuthSnafu)?;

View File

@@ -18,6 +18,7 @@ use axum::body::Body;
use axum::extract::{Json, Query, RawBody, State};
use common_telemetry::metric;
use metrics::counter;
use servers::auth::UserInfo;
use servers::http::{handler as http_handler, script as script_handler, ApiState, JsonOutput};
use table::test_util::MemTable;
@@ -32,6 +33,7 @@ async fn test_sql_not_provided() {
script_handler: None,
}),
Query(http_handler::SqlQuery::default()),
axum::Extension(UserInfo::default()),
)
.await;
assert!(!json.success());
@@ -55,6 +57,7 @@ async fn test_sql_output_rows() {
script_handler: None,
}),
query,
axum::Extension(UserInfo::default()),
)
.await;
assert!(json.success(), "{:?}", json);