mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-08 14:22:58 +00:00
feat: add auth to grpc handler (#1051)
* chore: get header in grpc & temp save * chore: change authscheme to include data str * chore: add auth to grpc flight handler * chore: add unit test & hold for now since grpc api doesnt accept req input * chore: minor change * chore: minor change * chore: add flight context to database interface * chore: add test * chore: update proto version & fix cr issue * chore: add test * chore: minor update
This commit is contained in:
3
Cargo.lock
generated
3
Cargo.lock
generated
@@ -2975,7 +2975,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
|
||||
[[package]]
|
||||
name = "greptime-proto"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=3e6349be127b65a8b42a38cda9d527ec423ca77d#3e6349be127b65a8b42a38cda9d527ec423ca77d"
|
||||
source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=1599ae2a0d1d8f42ee23ed26e4ad7a7b34134c60#1599ae2a0d1d8f42ee23ed26e4ad7a7b34134c60"
|
||||
dependencies = [
|
||||
"prost 0.11.6",
|
||||
"tonic",
|
||||
@@ -6685,6 +6685,7 @@ dependencies = [
|
||||
"bytes",
|
||||
"catalog",
|
||||
"chrono",
|
||||
"client",
|
||||
"common-base",
|
||||
"common-catalog",
|
||||
"common-error",
|
||||
|
||||
@@ -10,7 +10,7 @@ common-base = { path = "../common/base" }
|
||||
common-error = { path = "../common/error" }
|
||||
common-time = { path = "../common/time" }
|
||||
datatypes = { path = "../datatypes" }
|
||||
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "3e6349be127b65a8b42a38cda9d527ec423ca77d" }
|
||||
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "1599ae2a0d1d8f42ee23ed26e4ad7a7b34134c60" }
|
||||
prost.workspace = true
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
tonic.workspace = true
|
||||
|
||||
@@ -14,12 +14,13 @@
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use api::v1::auth_header::AuthScheme;
|
||||
use api::v1::ddl_request::Expr as DdlExpr;
|
||||
use api::v1::greptime_request::Request;
|
||||
use api::v1::query_request::Query;
|
||||
use api::v1::{
|
||||
AlterExpr, CreateTableExpr, DdlRequest, DropTableExpr, GreptimeRequest, InsertRequest,
|
||||
QueryRequest, RequestHeader,
|
||||
AlterExpr, AuthHeader, CreateTableExpr, DdlRequest, DropTableExpr, GreptimeRequest,
|
||||
InsertRequest, QueryRequest, RequestHeader,
|
||||
};
|
||||
use arrow_flight::{FlightData, Ticket};
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
@@ -42,6 +43,7 @@ pub struct Database {
|
||||
schema: String,
|
||||
|
||||
client: Client,
|
||||
ctx: FlightContext,
|
||||
}
|
||||
|
||||
impl Database {
|
||||
@@ -50,6 +52,7 @@ impl Database {
|
||||
catalog: catalog.into(),
|
||||
schema: schema.into(),
|
||||
client,
|
||||
ctx: FlightContext::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,6 +64,12 @@ impl Database {
|
||||
self.schema = schema.into();
|
||||
}
|
||||
|
||||
pub fn set_auth(&mut self, auth: AuthScheme) {
|
||||
self.ctx.auth_header = Some(AuthHeader {
|
||||
auth_scheme: Some(auth),
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn insert(&self, request: InsertRequest) -> Result<Output> {
|
||||
self.do_get(Request::Insert(request)).await
|
||||
}
|
||||
@@ -105,6 +114,7 @@ impl Database {
|
||||
header: Some(RequestHeader {
|
||||
catalog: self.catalog.clone(),
|
||||
schema: self.schema.clone(),
|
||||
authorization: self.ctx.auth_header.clone(),
|
||||
}),
|
||||
request: Some(request),
|
||||
};
|
||||
@@ -164,12 +174,18 @@ fn get_metadata_value(e: &tonic::Status, key: &str) -> Option<String> {
|
||||
.and_then(|v| String::from_utf8(v.as_bytes().to_vec()).ok())
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct FlightContext {
|
||||
auth_header: Option<AuthHeader>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use api::helper::ColumnDataTypeWrapper;
|
||||
use api::v1::Column;
|
||||
use api::v1::auth_header::AuthScheme;
|
||||
use api::v1::{AuthHeader, Basic, Column};
|
||||
use common_grpc::select::{null_mask, values};
|
||||
use common_grpc_expr::column_to_vector;
|
||||
use datatypes::prelude::{Vector, VectorRef};
|
||||
@@ -179,6 +195,8 @@ mod tests {
|
||||
UInt32Vector, UInt64Vector, UInt8Vector,
|
||||
};
|
||||
|
||||
use crate::database::FlightContext;
|
||||
|
||||
#[test]
|
||||
fn test_column_to_vector() {
|
||||
let mut column = create_test_column(Arc::new(BooleanVector::from(vec![true])));
|
||||
@@ -262,4 +280,26 @@ mod tests {
|
||||
datatype: wrapper.datatype() as i32,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flight_ctx() {
|
||||
let mut ctx = FlightContext::default();
|
||||
assert!(ctx.auth_header.is_none());
|
||||
|
||||
let basic = AuthScheme::Basic(Basic {
|
||||
username: "u".to_string(),
|
||||
password: "p".to_string(),
|
||||
});
|
||||
|
||||
ctx.auth_header = Some(AuthHeader {
|
||||
auth_scheme: Some(basic),
|
||||
});
|
||||
|
||||
assert!(matches!(
|
||||
ctx.auth_header,
|
||||
Some(AuthHeader {
|
||||
auth_scheme: Some(AuthScheme::Basic(_)),
|
||||
})
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,6 +89,7 @@ impl Services {
|
||||
Ok(Self {
|
||||
grpc_server: GrpcServer::new(
|
||||
ServerGrpcQueryHandlerAdaptor::arc(instance),
|
||||
None,
|
||||
grpc_runtime,
|
||||
),
|
||||
mysql_server,
|
||||
|
||||
@@ -66,6 +66,7 @@ impl Services {
|
||||
|
||||
let grpc_server = GrpcServer::new(
|
||||
ServerGrpcQueryHandlerAdaptor::arc(instance.clone()),
|
||||
user_provider.clone(),
|
||||
grpc_runtime,
|
||||
);
|
||||
|
||||
|
||||
@@ -114,6 +114,7 @@ pub(crate) async fn create_datanode_client(
|
||||
// https://github.com/hyperium/tonic/blob/master/examples/src/mock/mock.rs
|
||||
let datanode_service = GrpcServer::new(
|
||||
ServerGrpcQueryHandlerAdaptor::arc(datanode_instance),
|
||||
None,
|
||||
runtime,
|
||||
)
|
||||
.create_service();
|
||||
|
||||
@@ -18,11 +18,11 @@ pub mod kv;
|
||||
pub mod memory;
|
||||
|
||||
use api::v1::meta::{
|
||||
store_server, BatchPutRequest, BatchPutResponse, CompareAndPutRequest, CompareAndPutResponse,
|
||||
DeleteRangeRequest, DeleteRangeResponse, MoveValueRequest, MoveValueResponse, PutRequest,
|
||||
PutResponse, RangeRequest, RangeResponse,
|
||||
store_server, BatchGetRequest, BatchGetResponse, BatchPutRequest, BatchPutResponse,
|
||||
CompareAndPutRequest, CompareAndPutResponse, DeleteRangeRequest, DeleteRangeResponse,
|
||||
MoveValueRequest, MoveValueResponse, PutRequest, PutResponse, RangeRequest, RangeResponse,
|
||||
};
|
||||
use tonic::{Request, Response};
|
||||
use tonic::{Request, Response, Status};
|
||||
|
||||
use crate::metasrv::MetaSrv;
|
||||
use crate::service::GrpcResult;
|
||||
@@ -43,6 +43,14 @@ impl store_server::Store for MetaSrv {
|
||||
Ok(Response::new(res))
|
||||
}
|
||||
|
||||
async fn batch_get(
|
||||
&self,
|
||||
_request: Request<BatchGetRequest>,
|
||||
) -> Result<Response<BatchGetResponse>, Status> {
|
||||
// TODO(fys): please fix this
|
||||
unimplemented!()
|
||||
}
|
||||
|
||||
async fn batch_put(&self, req: Request<BatchPutRequest>) -> GrpcResult<BatchPutResponse> {
|
||||
let req = req.into_inner();
|
||||
let res = self.kv_store().batch_put(req).await?;
|
||||
|
||||
@@ -69,6 +69,7 @@ tower-http = { version = "0.3", features = ["full"] }
|
||||
|
||||
[dev-dependencies]
|
||||
axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" }
|
||||
client = { path = "../client" }
|
||||
common-base = { path = "../common/base" }
|
||||
mysql_async = { version = "0.31", default-features = false, features = [
|
||||
"default-rustls",
|
||||
|
||||
@@ -22,7 +22,6 @@ use axum::Json;
|
||||
use base64::DecodeError;
|
||||
use catalog;
|
||||
use common_error::prelude::*;
|
||||
use hyper::header::ToStrError;
|
||||
use serde_json::json;
|
||||
use tonic::codegen::http::{HeaderMap, HeaderValue};
|
||||
use tonic::metadata::MetadataMap;
|
||||
@@ -150,7 +149,7 @@ pub enum Error {
|
||||
|
||||
#[snafu(display("Invalid OpenTSDB line, source: {}", source))]
|
||||
InvalidOpentsdbLine {
|
||||
source: std::string::FromUtf8Error,
|
||||
source: FromUtf8Error,
|
||||
backtrace: Backtrace,
|
||||
},
|
||||
|
||||
@@ -216,7 +215,7 @@ pub enum Error {
|
||||
source: auth::Error,
|
||||
},
|
||||
|
||||
#[snafu(display("Not found http authorization header"))]
|
||||
#[snafu(display("Not found http or grpc authorization header"))]
|
||||
NotFoundAuthHeader {},
|
||||
|
||||
#[snafu(display("Not found influx http authorization info"))]
|
||||
@@ -224,7 +223,7 @@ pub enum Error {
|
||||
|
||||
#[snafu(display("Invalid visibility ASCII chars, source: {}", source))]
|
||||
InvisibleASCII {
|
||||
source: ToStrError,
|
||||
source: hyper::header::ToStrError,
|
||||
backtrace: Backtrace,
|
||||
},
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod flight;
|
||||
pub mod flight;
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
@@ -28,6 +28,7 @@ use tokio::sync::oneshot::{self, Sender};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_stream::wrappers::TcpListenerStream;
|
||||
|
||||
use crate::auth::UserProviderRef;
|
||||
use crate::error::{AlreadyStartedSnafu, Result, StartGrpcSnafu, TcpBindSnafu};
|
||||
use crate::grpc::flight::FlightHandler;
|
||||
use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
|
||||
@@ -35,21 +36,31 @@ use crate::server::Server;
|
||||
|
||||
pub struct GrpcServer {
|
||||
query_handler: ServerGrpcQueryHandlerRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
shutdown_tx: Mutex<Option<Sender<()>>>,
|
||||
runtime: Arc<Runtime>,
|
||||
}
|
||||
|
||||
impl GrpcServer {
|
||||
pub fn new(query_handler: ServerGrpcQueryHandlerRef, runtime: Arc<Runtime>) -> Self {
|
||||
pub fn new(
|
||||
query_handler: ServerGrpcQueryHandlerRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
runtime: Arc<Runtime>,
|
||||
) -> Self {
|
||||
Self {
|
||||
query_handler,
|
||||
user_provider,
|
||||
shutdown_tx: Mutex::new(None),
|
||||
runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_service(&self) -> FlightServiceServer<impl FlightService> {
|
||||
let service = FlightHandler::new(self.query_handler.clone(), self.runtime.clone());
|
||||
let service = FlightHandler::new(
|
||||
self.query_handler.clone(),
|
||||
self.user_provider.clone(),
|
||||
self.runtime.clone(),
|
||||
);
|
||||
FlightServiceServer::new(service)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,7 +17,8 @@ mod stream;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use api::v1::{GreptimeRequest, RequestHeader};
|
||||
use api::v1::auth_header::AuthScheme;
|
||||
use api::v1::{Basic, GreptimeRequest, RequestHeader};
|
||||
use arrow_flight::flight_service_server::FlightService;
|
||||
use arrow_flight::{
|
||||
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
|
||||
@@ -33,21 +34,33 @@ use session::context::{QueryContext, QueryContextRef};
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
use tonic::{Request, Response, Status, Streaming};
|
||||
|
||||
use crate::auth::{Identity, UserProviderRef};
|
||||
use crate::error;
|
||||
use crate::error::Error::Auth;
|
||||
use crate::error::{NotFoundAuthHeaderSnafu, UnsupportedAuthSchemeSnafu};
|
||||
use crate::grpc::flight::stream::FlightRecordBatchStream;
|
||||
use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
|
||||
|
||||
type TonicResult<T> = Result<T, Status>;
|
||||
type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + Sync + 'static>>;
|
||||
|
||||
pub(crate) struct FlightHandler {
|
||||
pub struct FlightHandler {
|
||||
handler: ServerGrpcQueryHandlerRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
runtime: Arc<Runtime>,
|
||||
}
|
||||
|
||||
impl FlightHandler {
|
||||
pub(crate) fn new(handler: ServerGrpcQueryHandlerRef, runtime: Arc<Runtime>) -> Self {
|
||||
Self { handler, runtime }
|
||||
pub fn new(
|
||||
handler: ServerGrpcQueryHandlerRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
runtime: Arc<Runtime>,
|
||||
) -> Self {
|
||||
Self {
|
||||
handler,
|
||||
user_provider,
|
||||
runtime,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -97,6 +110,13 @@ impl FlightService for FlightHandler {
|
||||
})?;
|
||||
let query_ctx = create_query_context(request.header.as_ref());
|
||||
|
||||
auth(
|
||||
self.user_provider.as_ref(),
|
||||
request.header.as_ref(),
|
||||
&query_ctx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let handler = self.handler.clone();
|
||||
|
||||
// Executes requests in another runtime to
|
||||
@@ -189,3 +209,42 @@ fn create_query_context(header: Option<&RequestHeader>) -> QueryContextRef {
|
||||
};
|
||||
ctx
|
||||
}
|
||||
|
||||
async fn auth(
|
||||
user_provider: Option<&UserProviderRef>,
|
||||
request_header: Option<&RequestHeader>,
|
||||
query_ctx: &QueryContextRef,
|
||||
) -> TonicResult<()> {
|
||||
let Some(user_provider) = user_provider else { return Ok(()) };
|
||||
|
||||
let user_info = match request_header
|
||||
.context(NotFoundAuthHeaderSnafu)?
|
||||
.clone()
|
||||
.authorization
|
||||
.context(NotFoundAuthHeaderSnafu)?
|
||||
.auth_scheme
|
||||
.context(NotFoundAuthHeaderSnafu)?
|
||||
{
|
||||
AuthScheme::Basic(Basic { username, password }) => user_provider
|
||||
.authenticate(
|
||||
Identity::UserId(&username, None),
|
||||
crate::auth::Password::PlainText(&password),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| Auth { source: e }),
|
||||
AuthScheme::Token(_) => UnsupportedAuthSchemeSnafu {
|
||||
name: "Token AuthScheme",
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
.map_err(|e| Status::unauthenticated(e.to_string()))?;
|
||||
|
||||
user_provider
|
||||
.authorize(
|
||||
&query_ctx.current_catalog(),
|
||||
&query_ctx.current_schema(),
|
||||
&user_info,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| Status::permission_denied(e.to_string()))
|
||||
}
|
||||
|
||||
@@ -194,9 +194,9 @@ async fn authenticate<B: Send + Sync + 'static>(
|
||||
get_influxdb_credentials(request)?.context(NotFoundInfluxAuthSnafu)?
|
||||
} else {
|
||||
// normal http auth
|
||||
let (scheme, credential) = auth_header(request)?;
|
||||
let scheme = auth_header(request)?;
|
||||
match scheme {
|
||||
AuthScheme::Basic => decode_basic(credential)?,
|
||||
AuthScheme::Basic(username, password) => (username, password),
|
||||
}
|
||||
};
|
||||
|
||||
@@ -219,15 +219,27 @@ where
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum AuthScheme {
|
||||
Basic,
|
||||
Basic(Username, Password),
|
||||
}
|
||||
|
||||
type Username = String;
|
||||
type Password = String;
|
||||
|
||||
impl TryFrom<&str> for AuthScheme {
|
||||
type Error = error::Error;
|
||||
|
||||
fn try_from(value: &str) -> Result<Self> {
|
||||
match value.to_lowercase().as_str() {
|
||||
"basic" => Ok(AuthScheme::Basic),
|
||||
let (scheme, encoded_credentials) = value
|
||||
.split_once(' ')
|
||||
.context(InvalidAuthorizationHeaderSnafu)?;
|
||||
ensure!(
|
||||
!encoded_credentials.contains(' '),
|
||||
InvalidAuthorizationHeaderSnafu
|
||||
);
|
||||
|
||||
match scheme.to_lowercase().as_str() {
|
||||
"basic" => decode_basic(encoded_credentials)
|
||||
.map(|(username, password)| AuthScheme::Basic(username, password)),
|
||||
other => UnsupportedAuthSchemeSnafu { name: other }.fail(),
|
||||
}
|
||||
}
|
||||
@@ -235,7 +247,7 @@ impl TryFrom<&str> for AuthScheme {
|
||||
|
||||
type Credential<'a> = &'a str;
|
||||
|
||||
fn auth_header<B>(req: &Request<B>) -> Result<(AuthScheme, Credential)> {
|
||||
fn auth_header<B>(req: &Request<B>) -> Result<AuthScheme> {
|
||||
let auth_header = req
|
||||
.headers()
|
||||
.get(http::header::AUTHORIZATION)
|
||||
@@ -243,20 +255,9 @@ fn auth_header<B>(req: &Request<B>) -> Result<(AuthScheme, Credential)> {
|
||||
.to_str()
|
||||
.context(InvisibleASCIISnafu)?;
|
||||
|
||||
let (auth_scheme, encoded_credentials) = auth_header
|
||||
.split_once(' ')
|
||||
.context(InvalidAuthorizationHeaderSnafu)?;
|
||||
|
||||
if encoded_credentials.contains(' ') {
|
||||
return InvalidAuthorizationHeaderSnafu {}.fail();
|
||||
}
|
||||
|
||||
Ok((auth_scheme.try_into()?, encoded_credentials))
|
||||
auth_header.try_into()
|
||||
}
|
||||
|
||||
type Username = String;
|
||||
type Password = String;
|
||||
|
||||
fn decode_basic(credential: Credential) -> Result<(Username, Password)> {
|
||||
let decoded = base64::decode(credential).context(error::InvalidBase64ValueSnafu)?;
|
||||
let as_utf8 = String::from_utf8(decoded).context(error::InvalidUtf8ValueSnafu)?;
|
||||
@@ -324,8 +325,12 @@ mod tests {
|
||||
#[test]
|
||||
fn test_try_into_auth_scheme() {
|
||||
let auth_scheme_str = "basic";
|
||||
let auth_scheme: AuthScheme = auth_scheme_str.try_into().unwrap();
|
||||
matches!(auth_scheme, AuthScheme::Basic);
|
||||
let re: Result<AuthScheme> = auth_scheme_str.try_into();
|
||||
assert!(re.is_err());
|
||||
|
||||
let auth_scheme_str = "basic dGVzdDp0ZXN0";
|
||||
let scheme: AuthScheme = auth_scheme_str.try_into().unwrap();
|
||||
matches!(scheme, AuthScheme::Basic(username, pwd) if username == "test" && pwd == "test");
|
||||
|
||||
let unsupported = "digest";
|
||||
let auth_scheme: Result<AuthScheme> = unsupported.try_into();
|
||||
@@ -337,9 +342,8 @@ mod tests {
|
||||
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
|
||||
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
|
||||
|
||||
let (auth_scheme, credential) = auth_header(&req).unwrap();
|
||||
matches!(auth_scheme, AuthScheme::Basic);
|
||||
assert_eq!("dXNlcm5hbWU6cGFzc3dvcmQ=", credential);
|
||||
let auth_scheme = auth_header(&req).unwrap();
|
||||
matches!(auth_scheme, AuthScheme::Basic(username, pwd) if username == "username" && pwd == "password");
|
||||
|
||||
let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6 cGFzc3dvcmQ="), None).unwrap();
|
||||
let res = auth_header(&wrong_req);
|
||||
|
||||
141
src/servers/tests/grpc/mod.rs
Normal file
141
src/servers/tests/grpc/mod.rs
Normal file
@@ -0,0 +1,141 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use api::v1::auth_header::AuthScheme;
|
||||
use api::v1::Basic;
|
||||
use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
|
||||
use async_trait::async_trait;
|
||||
use client::{Client, Database};
|
||||
use common_runtime::{Builder as RuntimeBuilder, Runtime};
|
||||
use servers::auth::UserProviderRef;
|
||||
use servers::error::{Result, StartGrpcSnafu, TcpBindSnafu};
|
||||
use servers::grpc::flight::FlightHandler;
|
||||
use servers::query_handler::grpc::ServerGrpcQueryHandlerRef;
|
||||
use servers::server::Server;
|
||||
use snafu::ResultExt;
|
||||
use table::test_util::MemTable;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio_stream::wrappers::TcpListenerStream;
|
||||
|
||||
use crate::auth::MockUserProvider;
|
||||
use crate::{create_testing_grpc_query_handler, LOCALHOST_WITH_0};
|
||||
|
||||
struct MockGrpcServer {
|
||||
query_handler: ServerGrpcQueryHandlerRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
runtime: Arc<Runtime>,
|
||||
}
|
||||
|
||||
impl MockGrpcServer {
|
||||
fn new(
|
||||
query_handler: ServerGrpcQueryHandlerRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
runtime: Arc<Runtime>,
|
||||
) -> Self {
|
||||
Self {
|
||||
query_handler,
|
||||
user_provider,
|
||||
runtime,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_service(&self) -> FlightServiceServer<impl FlightService> {
|
||||
let service = FlightHandler::new(
|
||||
self.query_handler.clone(),
|
||||
self.user_provider.clone(),
|
||||
self.runtime.clone(),
|
||||
);
|
||||
FlightServiceServer::new(service)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Server for MockGrpcServer {
|
||||
async fn shutdown(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start(&self, addr: SocketAddr) -> Result<SocketAddr> {
|
||||
let (listener, addr) = {
|
||||
let listener = TcpListener::bind(addr)
|
||||
.await
|
||||
.context(TcpBindSnafu { addr })?;
|
||||
let addr = listener.local_addr().context(TcpBindSnafu { addr })?;
|
||||
(listener, addr)
|
||||
};
|
||||
|
||||
let service = self.create_service();
|
||||
// Would block to serve requests.
|
||||
tokio::spawn(async move {
|
||||
tonic::transport::Server::builder()
|
||||
.add_service(service)
|
||||
.serve_with_incoming(TcpListenerStream::new(listener))
|
||||
.await
|
||||
.context(StartGrpcSnafu)
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
Ok(addr)
|
||||
}
|
||||
}
|
||||
|
||||
fn create_grpc_server(table: MemTable) -> Result<Arc<dyn Server>> {
|
||||
let query_handler = create_testing_grpc_query_handler(table);
|
||||
let io_runtime = Arc::new(
|
||||
RuntimeBuilder::default()
|
||||
.worker_threads(4)
|
||||
.thread_name("grpc-io-handlers")
|
||||
.build()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let provider = MockUserProvider::default();
|
||||
|
||||
Ok(Arc::new(MockGrpcServer::new(
|
||||
query_handler,
|
||||
Some(Arc::new(provider)),
|
||||
io_runtime,
|
||||
)))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_grpc_server_startup() {
|
||||
let server = create_grpc_server(MemTable::default_numbers_table()).unwrap();
|
||||
let re = server.start(LOCALHOST_WITH_0.parse().unwrap()).await;
|
||||
assert!(re.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_grpc_query() {
|
||||
let server = create_grpc_server(MemTable::default_numbers_table()).unwrap();
|
||||
let re = server.start(LOCALHOST_WITH_0.parse().unwrap()).await;
|
||||
assert!(re.is_ok());
|
||||
|
||||
let grpc_client = Client::with_urls(vec![re.unwrap().to_string()]);
|
||||
let mut db = Database::with_client(grpc_client);
|
||||
|
||||
let re = db.sql("select * from numbers").await;
|
||||
assert!(re.is_err());
|
||||
|
||||
let greptime = "greptime".to_string();
|
||||
db.set_auth(AuthScheme::Basic(Basic {
|
||||
username: greptime.clone(),
|
||||
password: greptime.clone(),
|
||||
}));
|
||||
let re = db.sql("select * from numbers").await;
|
||||
assert!(re.is_ok());
|
||||
}
|
||||
@@ -15,6 +15,8 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use api::v1::greptime_request::{Request as GreptimeRequest, Request};
|
||||
use api::v1::query_request::Query;
|
||||
use async_trait::async_trait;
|
||||
use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider};
|
||||
use catalog::{CatalogList, CatalogProvider, SchemaProvider};
|
||||
@@ -25,14 +27,17 @@ use query::parser::{PromQuery, QueryLanguageParser, QueryStatement};
|
||||
use query::{QueryEngineFactory, QueryEngineRef};
|
||||
use script::engine::{CompileContext, EvalContext, Script, ScriptEngine};
|
||||
use script::python::{PyEngine, PyScript};
|
||||
use servers::error::{Error, Result};
|
||||
use servers::error::{Error, NotSupportedSnafu, Result};
|
||||
use servers::query_handler::grpc::{GrpcQueryHandler, ServerGrpcQueryHandlerRef};
|
||||
use servers::query_handler::sql::{ServerSqlQueryHandlerRef, SqlQueryHandler};
|
||||
use servers::query_handler::{ScriptHandler, ScriptHandlerRef};
|
||||
use session::context::QueryContextRef;
|
||||
use snafu::ensure;
|
||||
use sql::statements::statement::Statement;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
mod auth;
|
||||
mod grpc;
|
||||
mod http;
|
||||
mod interceptor;
|
||||
mod mysql;
|
||||
@@ -40,6 +45,8 @@ mod opentsdb;
|
||||
mod postgres;
|
||||
mod py_script;
|
||||
|
||||
const LOCALHOST_WITH_0: &str = "127.0.0.1:0";
|
||||
|
||||
struct DummyInstance {
|
||||
query_engine: QueryEngineRef,
|
||||
py_engine: Arc<PyEngine>,
|
||||
@@ -80,7 +87,7 @@ impl SqlQueryHandler for DummyInstance {
|
||||
|
||||
async fn do_statement_query(
|
||||
&self,
|
||||
_stmt: sql::statements::statement::Statement,
|
||||
_stmt: Statement,
|
||||
_query_ctx: QueryContextRef,
|
||||
) -> Result<Output> {
|
||||
unimplemented!()
|
||||
@@ -137,6 +144,39 @@ impl ScriptHandler for DummyInstance {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl GrpcQueryHandler for DummyInstance {
|
||||
type Error = Error;
|
||||
|
||||
async fn do_query(
|
||||
&self,
|
||||
request: GreptimeRequest,
|
||||
ctx: QueryContextRef,
|
||||
) -> std::result::Result<Output, Self::Error> {
|
||||
let output = match request {
|
||||
Request::Insert(_) => unimplemented!(),
|
||||
Request::Query(query_request) => {
|
||||
let query = query_request.query.unwrap();
|
||||
match query {
|
||||
Query::Sql(sql) => {
|
||||
let mut result = SqlQueryHandler::do_query(self, &sql, ctx).await;
|
||||
ensure!(
|
||||
result.len() == 1,
|
||||
NotSupportedSnafu {
|
||||
feat: "execute multiple statements in SQL query string through GRPC interface"
|
||||
}
|
||||
);
|
||||
result.remove(0)?
|
||||
}
|
||||
Query::LogicalPlan(_) => unimplemented!(),
|
||||
}
|
||||
}
|
||||
Request::Ddl(_) => unimplemented!(),
|
||||
};
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
fn create_testing_instance(table: MemTable) -> DummyInstance {
|
||||
let table_name = table.table_name().to_string();
|
||||
let table = Arc::new(table);
|
||||
@@ -164,3 +204,7 @@ fn create_testing_script_handler(table: MemTable) -> ScriptHandlerRef {
|
||||
fn create_testing_sql_query_handler(table: MemTable) -> ServerSqlQueryHandlerRef {
|
||||
Arc::new(create_testing_instance(table)) as _
|
||||
}
|
||||
|
||||
fn create_testing_grpc_query_handler(table: MemTable) -> ServerGrpcQueryHandlerRef {
|
||||
Arc::new(create_testing_instance(table)) as _
|
||||
}
|
||||
|
||||
@@ -350,6 +350,7 @@ pub async fn setup_grpc_server(
|
||||
let fe_instance_ref = Arc::new(fe_instance);
|
||||
let fe_grpc_server = Arc::new(GrpcServer::new(
|
||||
ServerGrpcQueryHandlerAdaptor::arc(fe_instance_ref),
|
||||
None,
|
||||
runtime,
|
||||
));
|
||||
let grpc_server_clone = fe_grpc_server.clone();
|
||||
|
||||
Reference in New Issue
Block a user