mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-03 20:02:54 +00:00
feat: GRPC unary insert method
This commit is contained in:
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -3085,7 +3085,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
|
||||
[[package]]
|
||||
name = "greptime-proto"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=3a715150563b89d5dfc81a5838eac1f66a5658a1#3a715150563b89d5dfc81a5838eac1f66a5658a1"
|
||||
source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=0a7b790ed41364b5599dff806d1080bd59c5c9f6#0a7b790ed41364b5599dff806d1080bd59c5c9f6"
|
||||
dependencies = [
|
||||
"prost",
|
||||
"tonic",
|
||||
|
||||
@@ -10,7 +10,7 @@ common-base = { path = "../common/base" }
|
||||
common-error = { path = "../common/error" }
|
||||
common-time = { path = "../common/time" }
|
||||
datatypes = { path = "../datatypes" }
|
||||
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "3a715150563b89d5dfc81a5838eac1f66a5658a1" }
|
||||
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "0a7b790ed41364b5599dff806d1080bd59c5c9f6" }
|
||||
prost.workspace = true
|
||||
snafu = { version = "0.7", features = ["backtraces"] }
|
||||
tonic.workspace = true
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use api::v1::greptime_database_client::GreptimeDatabaseClient;
|
||||
use arrow_flight::flight_service_client::FlightServiceClient;
|
||||
use common_grpc::channel_manager::ChannelManager;
|
||||
use parking_lot::RwLock;
|
||||
@@ -23,6 +24,10 @@ use tonic::transport::Channel;
|
||||
use crate::load_balance::{LoadBalance, Loadbalancer};
|
||||
use crate::{error, Result};
|
||||
|
||||
pub(crate) struct DatabaseClient {
|
||||
pub(crate) inner: GreptimeDatabaseClient<Channel>,
|
||||
}
|
||||
|
||||
pub(crate) struct FlightClient {
|
||||
addr: String,
|
||||
client: FlightServiceClient<Channel>,
|
||||
@@ -118,7 +123,7 @@ impl Client {
|
||||
self.inner.set_peers(urls);
|
||||
}
|
||||
|
||||
pub(crate) fn make_client(&self) -> Result<FlightClient> {
|
||||
fn find_channel(&self) -> Result<(String, Channel)> {
|
||||
let addr = self
|
||||
.inner
|
||||
.get_peer()
|
||||
@@ -131,11 +136,23 @@ impl Client {
|
||||
.channel_manager
|
||||
.get(&addr)
|
||||
.context(error::CreateChannelSnafu { addr: &addr })?;
|
||||
Ok((addr, channel))
|
||||
}
|
||||
|
||||
pub(crate) fn make_flight_client(&self) -> Result<FlightClient> {
|
||||
let (addr, channel) = self.find_channel()?;
|
||||
Ok(FlightClient {
|
||||
addr,
|
||||
client: FlightServiceClient::new(channel),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn make_database_client(&self) -> Result<DatabaseClient> {
|
||||
let (_, channel) = self.find_channel()?;
|
||||
Ok(DatabaseClient {
|
||||
inner: GreptimeDatabaseClient::new(channel),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -12,15 +12,14 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use api::v1::auth_header::AuthScheme;
|
||||
use api::v1::ddl_request::Expr as DdlExpr;
|
||||
use api::v1::greptime_request::Request;
|
||||
use api::v1::query_request::Query;
|
||||
use api::v1::{
|
||||
AlterExpr, AuthHeader, CreateTableExpr, DdlRequest, DropTableExpr, FlushTableExpr,
|
||||
GreptimeRequest, InsertRequest, PromRangeQuery, QueryRequest, RequestHeader,
|
||||
greptime_response, AffectedRows, AlterExpr, AuthHeader, CreateTableExpr, DdlRequest,
|
||||
DropTableExpr, FlushTableExpr, GreptimeRequest, InsertRequest, PromRangeQuery, QueryRequest,
|
||||
RequestHeader,
|
||||
};
|
||||
use arrow_flight::{FlightData, Ticket};
|
||||
use common_error::prelude::*;
|
||||
@@ -31,7 +30,9 @@ use futures_util::{TryFutureExt, TryStreamExt};
|
||||
use prost::Message;
|
||||
use snafu::{ensure, ResultExt};
|
||||
|
||||
use crate::error::{ConvertFlightDataSnafu, IllegalFlightMessagesSnafu};
|
||||
use crate::error::{
|
||||
ConvertFlightDataSnafu, IllegalDatabaseResponseSnafu, IllegalFlightMessagesSnafu,
|
||||
};
|
||||
use crate::{error, Client, Result};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -78,8 +79,26 @@ impl Database {
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn insert(&self, request: InsertRequest) -> Result<Output> {
|
||||
self.do_get(Request::Insert(request)).await
|
||||
pub async fn insert(&self, request: InsertRequest) -> Result<u32> {
|
||||
let mut client = self.client.make_database_client()?.inner;
|
||||
let request = GreptimeRequest {
|
||||
header: Some(RequestHeader {
|
||||
catalog: self.catalog.clone(),
|
||||
schema: self.schema.clone(),
|
||||
authorization: self.ctx.auth_header.clone(),
|
||||
}),
|
||||
request: Some(Request::Insert(request)),
|
||||
};
|
||||
let response = client
|
||||
.handle(request)
|
||||
.await?
|
||||
.into_inner()
|
||||
.response
|
||||
.context(IllegalDatabaseResponseSnafu {
|
||||
err_msg: "GreptimeResponse is empty",
|
||||
})?;
|
||||
let greptime_response::Response::AffectedRows(AffectedRows { value }) = response;
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
pub async fn sql(&self, sql: &str) -> Result<Output> {
|
||||
@@ -155,7 +174,7 @@ impl Database {
|
||||
ticket: request.encode_to_vec().into(),
|
||||
};
|
||||
|
||||
let mut client = self.client.make_client()?;
|
||||
let mut client = self.client.make_flight_client()?;
|
||||
|
||||
// TODO(LFC): Streaming get flight data.
|
||||
let flight_data: Vec<FlightData> = client
|
||||
@@ -164,22 +183,22 @@ impl Database {
|
||||
.and_then(|response| response.into_inner().try_collect())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let code = get_metadata_value(&e, INNER_ERROR_CODE)
|
||||
.and_then(|s| StatusCode::from_str(&s).ok())
|
||||
.unwrap_or(StatusCode::Unknown);
|
||||
let msg = get_metadata_value(&e, INNER_ERROR_MSG).unwrap_or(e.to_string());
|
||||
error::ExternalSnafu { code, msg }
|
||||
let tonic_code = e.code();
|
||||
let e: error::Error = e.into();
|
||||
let code = e.status_code();
|
||||
let msg = e.to_string();
|
||||
error::ServerSnafu { code, msg }
|
||||
.fail::<()>()
|
||||
.map_err(BoxedError::new)
|
||||
.context(error::FlightGetSnafu {
|
||||
tonic_code: e.code(),
|
||||
tonic_code,
|
||||
addr: client.addr(),
|
||||
})
|
||||
.map_err(|error| {
|
||||
logging::error!(
|
||||
"Failed to do Flight get, addr: {}, code: {}, source: {}",
|
||||
client.addr(),
|
||||
e.code(),
|
||||
tonic_code,
|
||||
error
|
||||
);
|
||||
error
|
||||
@@ -210,12 +229,6 @@ impl Database {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_metadata_value(e: &tonic::Status, key: &str) -> Option<String> {
|
||||
e.metadata()
|
||||
.get(key)
|
||||
.and_then(|v| String::from_utf8(v.as_bytes().to_vec()).ok())
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct FlightContext {
|
||||
auth_header: Option<AuthHeader>,
|
||||
|
||||
@@ -13,9 +13,10 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::any::Any;
|
||||
use std::str::FromStr;
|
||||
|
||||
use common_error::prelude::*;
|
||||
use tonic::Code;
|
||||
use tonic::{Code, Status};
|
||||
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(visibility(pub))]
|
||||
@@ -68,6 +69,13 @@ pub enum Error {
|
||||
/// Error deserialized from gRPC metadata
|
||||
#[snafu(display("{}", msg))]
|
||||
ExternalError { code: StatusCode, msg: String },
|
||||
|
||||
// Server error carried in Tonic Status's metadata.
|
||||
#[snafu(display("{}", msg))]
|
||||
Server { code: StatusCode, msg: String },
|
||||
|
||||
#[snafu(display("Illegal Database response: {err_msg}"))]
|
||||
IllegalDatabaseResponse { err_msg: String },
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
@@ -77,7 +85,10 @@ impl ErrorExt for Error {
|
||||
match self {
|
||||
Error::IllegalFlightMessages { .. }
|
||||
| Error::ColumnDataType { .. }
|
||||
| Error::MissingField { .. } => StatusCode::Internal,
|
||||
| Error::MissingField { .. }
|
||||
| Error::IllegalDatabaseResponse { .. } => StatusCode::Internal,
|
||||
|
||||
Error::Server { code, .. } => *code,
|
||||
Error::FlightGet { source, .. } => source.status_code(),
|
||||
Error::CreateChannel { source, .. } | Error::ConvertFlightData { source } => {
|
||||
source.status_code()
|
||||
@@ -95,3 +106,21 @@ impl ErrorExt for Error {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Status> for Error {
|
||||
fn from(e: Status) -> Self {
|
||||
fn get_metadata_value(e: &Status, key: &str) -> Option<String> {
|
||||
e.metadata()
|
||||
.get(key)
|
||||
.and_then(|v| String::from_utf8(v.as_bytes().to_vec()).ok())
|
||||
}
|
||||
|
||||
let code = get_metadata_value(&e, INNER_ERROR_CODE)
|
||||
.and_then(|s| StatusCode::from_str(&s).ok())
|
||||
.unwrap_or(StatusCode::Unknown);
|
||||
|
||||
let msg = get_metadata_value(&e, INNER_ERROR_MSG).unwrap_or(e.to_string());
|
||||
|
||||
Self::Server { code, msg }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,8 +74,7 @@ impl DistTable {
|
||||
|
||||
let mut success = 0;
|
||||
for join in joins {
|
||||
let object_result = join.await.context(error::JoinTaskSnafu)??;
|
||||
let Output::AffectedRows(rows) = object_result else { unreachable!() };
|
||||
let rows = join.await.context(error::JoinTaskSnafu)?? as usize;
|
||||
success += rows;
|
||||
}
|
||||
Ok(Output::AffectedRows(success))
|
||||
|
||||
@@ -47,7 +47,7 @@ impl DatanodeInstance {
|
||||
Self { table, db }
|
||||
}
|
||||
|
||||
pub(crate) async fn grpc_insert(&self, request: InsertRequest) -> client::Result<Output> {
|
||||
pub(crate) async fn grpc_insert(&self, request: InsertRequest) -> client::Result<u32> {
|
||||
self.db.insert(request).await
|
||||
}
|
||||
|
||||
|
||||
@@ -125,15 +125,15 @@ pub(crate) async fn create_datanode_client(
|
||||
|
||||
// create a mock datanode grpc service, see example here:
|
||||
// https://github.com/hyperium/tonic/blob/master/examples/src/mock/mock.rs
|
||||
let datanode_service = GrpcServer::new(
|
||||
let grpc_server = GrpcServer::new(
|
||||
ServerGrpcQueryHandlerAdaptor::arc(datanode_instance),
|
||||
None,
|
||||
runtime,
|
||||
)
|
||||
.create_service();
|
||||
);
|
||||
tokio::spawn(async move {
|
||||
Server::builder()
|
||||
.add_service(datanode_service)
|
||||
.add_service(grpc_server.create_flight_service())
|
||||
.add_service(grpc_server.create_database_service())
|
||||
.serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(server)]))
|
||||
.await
|
||||
});
|
||||
|
||||
@@ -12,11 +12,14 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
mod database;
|
||||
pub mod flight;
|
||||
pub mod handler;
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use api::v1::greptime_database_server::{GreptimeDatabase, GreptimeDatabaseServer};
|
||||
use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
|
||||
use async_trait::async_trait;
|
||||
use common_runtime::Runtime;
|
||||
@@ -27,18 +30,21 @@ use tokio::net::TcpListener;
|
||||
use tokio::sync::oneshot::{self, Sender};
|
||||
use tokio::sync::Mutex;
|
||||
use tokio_stream::wrappers::TcpListenerStream;
|
||||
use tonic::Status;
|
||||
|
||||
use crate::auth::UserProviderRef;
|
||||
use crate::error::{AlreadyStartedSnafu, Result, StartGrpcSnafu, TcpBindSnafu};
|
||||
use crate::grpc::database::DatabaseService;
|
||||
use crate::grpc::flight::FlightHandler;
|
||||
use crate::grpc::handler::GreptimeRequestHandler;
|
||||
use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
|
||||
use crate::server::Server;
|
||||
|
||||
type TonicResult<T> = std::result::Result<T, Status>;
|
||||
|
||||
pub struct GrpcServer {
|
||||
query_handler: ServerGrpcQueryHandlerRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
shutdown_tx: Mutex<Option<Sender<()>>>,
|
||||
runtime: Arc<Runtime>,
|
||||
request_handler: Arc<GreptimeRequestHandler>,
|
||||
}
|
||||
|
||||
impl GrpcServer {
|
||||
@@ -47,21 +53,23 @@ impl GrpcServer {
|
||||
user_provider: Option<UserProviderRef>,
|
||||
runtime: Arc<Runtime>,
|
||||
) -> Self {
|
||||
Self {
|
||||
let request_handler = Arc::new(GreptimeRequestHandler::new(
|
||||
query_handler,
|
||||
user_provider,
|
||||
shutdown_tx: Mutex::new(None),
|
||||
runtime,
|
||||
));
|
||||
Self {
|
||||
shutdown_tx: Mutex::new(None),
|
||||
request_handler,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_service(&self) -> FlightServiceServer<impl FlightService> {
|
||||
let service = FlightHandler::new(
|
||||
self.query_handler.clone(),
|
||||
self.user_provider.clone(),
|
||||
self.runtime.clone(),
|
||||
);
|
||||
FlightServiceServer::new(service)
|
||||
pub fn create_flight_service(&self) -> FlightServiceServer<impl FlightService> {
|
||||
FlightServiceServer::new(FlightHandler::new(self.request_handler.clone()))
|
||||
}
|
||||
|
||||
pub fn create_database_service(&self) -> GreptimeDatabaseServer<impl GreptimeDatabase> {
|
||||
GreptimeDatabaseServer::new(DatabaseService::new(self.request_handler.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,7 +111,8 @@ impl Server for GrpcServer {
|
||||
|
||||
// Would block to serve requests.
|
||||
tonic::transport::Server::builder()
|
||||
.add_service(self.create_service())
|
||||
.add_service(self.create_flight_service())
|
||||
.add_service(self.create_database_service())
|
||||
.serve_with_incoming_shutdown(TcpListenerStream::new(listener), rx.map(drop))
|
||||
.await
|
||||
.context(StartGrpcSnafu)?;
|
||||
|
||||
57
src/servers/src/grpc/database.rs
Normal file
57
src/servers/src/grpc/database.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use api::v1::greptime_database_server::GreptimeDatabase;
|
||||
use api::v1::{greptime_response, AffectedRows, GreptimeRequest, GreptimeResponse};
|
||||
use async_trait::async_trait;
|
||||
use common_query::Output;
|
||||
use tonic::{Request, Response, Status};
|
||||
|
||||
use crate::grpc::handler::GreptimeRequestHandler;
|
||||
use crate::grpc::TonicResult;
|
||||
|
||||
pub(crate) struct DatabaseService {
|
||||
handler: Arc<GreptimeRequestHandler>,
|
||||
}
|
||||
|
||||
impl DatabaseService {
|
||||
pub(crate) fn new(handler: Arc<GreptimeRequestHandler>) -> Self {
|
||||
Self { handler }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl GreptimeDatabase for DatabaseService {
|
||||
async fn handle(
|
||||
&self,
|
||||
request: Request<GreptimeRequest>,
|
||||
) -> TonicResult<Response<GreptimeResponse>> {
|
||||
let request = request.into_inner();
|
||||
let output = self.handler.handle_request(request).await?;
|
||||
let response = match output {
|
||||
Output::AffectedRows(rows) => GreptimeResponse {
|
||||
header: None,
|
||||
response: Some(greptime_response::Response::AffectedRows(AffectedRows {
|
||||
value: rows as _,
|
||||
})),
|
||||
},
|
||||
Output::Stream(_) | Output::RecordBatches(_) => {
|
||||
return Err(Status::unimplemented("GreptimeDatabase::handle for query"));
|
||||
}
|
||||
};
|
||||
Ok(Response::new(response))
|
||||
}
|
||||
}
|
||||
@@ -17,8 +17,7 @@ mod stream;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use api::v1::auth_header::AuthScheme;
|
||||
use api::v1::{Basic, GreptimeRequest, RequestHeader};
|
||||
use api::v1::GreptimeRequest;
|
||||
use arrow_flight::flight_service_server::FlightService;
|
||||
use arrow_flight::{
|
||||
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
|
||||
@@ -27,40 +26,25 @@ use arrow_flight::{
|
||||
use async_trait::async_trait;
|
||||
use common_grpc::flight::{FlightEncoder, FlightMessage};
|
||||
use common_query::Output;
|
||||
use common_runtime::Runtime;
|
||||
use futures::Stream;
|
||||
use prost::Message;
|
||||
use session::context::{QueryContext, QueryContextRef};
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
use snafu::ResultExt;
|
||||
use tonic::{Request, Response, Status, Streaming};
|
||||
|
||||
use crate::auth::{Identity, UserProviderRef};
|
||||
use crate::error;
|
||||
use crate::error::Error::Auth;
|
||||
use crate::error::{NotFoundAuthHeaderSnafu, UnsupportedAuthSchemeSnafu};
|
||||
use crate::grpc::flight::stream::FlightRecordBatchStream;
|
||||
use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
|
||||
use crate::grpc::handler::GreptimeRequestHandler;
|
||||
use crate::grpc::TonicResult;
|
||||
|
||||
type TonicResult<T> = Result<T, Status>;
|
||||
type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + Sync + 'static>>;
|
||||
|
||||
pub struct FlightHandler {
|
||||
handler: ServerGrpcQueryHandlerRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
runtime: Arc<Runtime>,
|
||||
handler: Arc<GreptimeRequestHandler>,
|
||||
}
|
||||
|
||||
impl FlightHandler {
|
||||
pub fn new(
|
||||
handler: ServerGrpcQueryHandlerRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
runtime: Arc<Runtime>,
|
||||
) -> Self {
|
||||
Self {
|
||||
handler,
|
||||
user_provider,
|
||||
runtime,
|
||||
}
|
||||
pub fn new(handler: Arc<GreptimeRequestHandler>) -> Self {
|
||||
Self { handler }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,40 +89,8 @@ impl FlightService for FlightHandler {
|
||||
let request =
|
||||
GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?;
|
||||
|
||||
let query = request.request.context(error::InvalidQuerySnafu {
|
||||
reason: "Expecting non-empty GreptimeRequest.",
|
||||
})?;
|
||||
let query_ctx = create_query_context(request.header.as_ref());
|
||||
let output = self.handler.handle_request(request).await?;
|
||||
|
||||
auth(
|
||||
self.user_provider.as_ref(),
|
||||
request.header.as_ref(),
|
||||
&query_ctx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let handler = self.handler.clone();
|
||||
|
||||
// Executes requests in another runtime to
|
||||
// 1. prevent the execution from being cancelled unexpected by Tonic runtime;
|
||||
// - Refer to our blog for the rational behind it:
|
||||
// https://www.greptime.com/blogs/2023-01-12-hidden-control-flow.html
|
||||
// - Obtaining a `JoinHandle` to get the panic message (if there's any).
|
||||
// From its docs, `JoinHandle` is cancel safe. The task keeps running even it's handle been dropped.
|
||||
// 2. avoid the handler blocks the gRPC runtime incidentally.
|
||||
let handle = self
|
||||
.runtime
|
||||
.spawn(async move { handler.do_query(query, query_ctx).await });
|
||||
|
||||
let output = handle.await.map_err(|e| {
|
||||
if e.is_cancelled() {
|
||||
Status::cancelled(e.to_string())
|
||||
} else if e.is_panic() {
|
||||
Status::internal(format!("{:?}", e.into_panic()))
|
||||
} else {
|
||||
Status::unknown(e.to_string())
|
||||
}
|
||||
})??;
|
||||
let stream = to_flight_data_stream(output);
|
||||
Ok(Response::new(stream))
|
||||
}
|
||||
@@ -195,56 +147,3 @@ fn to_flight_data_stream(output: Output) -> TonicStream<FlightData> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_query_context(header: Option<&RequestHeader>) -> QueryContextRef {
|
||||
let ctx = QueryContext::arc();
|
||||
if let Some(header) = header {
|
||||
if !header.catalog.is_empty() {
|
||||
ctx.set_current_catalog(&header.catalog);
|
||||
}
|
||||
|
||||
if !header.schema.is_empty() {
|
||||
ctx.set_current_schema(&header.schema);
|
||||
}
|
||||
};
|
||||
ctx
|
||||
}
|
||||
|
||||
async fn auth(
|
||||
user_provider: Option<&UserProviderRef>,
|
||||
request_header: Option<&RequestHeader>,
|
||||
query_ctx: &QueryContextRef,
|
||||
) -> TonicResult<()> {
|
||||
let Some(user_provider) = user_provider else { return Ok(()) };
|
||||
|
||||
let user_info = match request_header
|
||||
.context(NotFoundAuthHeaderSnafu)?
|
||||
.clone()
|
||||
.authorization
|
||||
.context(NotFoundAuthHeaderSnafu)?
|
||||
.auth_scheme
|
||||
.context(NotFoundAuthHeaderSnafu)?
|
||||
{
|
||||
AuthScheme::Basic(Basic { username, password }) => user_provider
|
||||
.authenticate(
|
||||
Identity::UserId(&username, None),
|
||||
crate::auth::Password::PlainText(&password),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| Auth { source: e }),
|
||||
AuthScheme::Token(_) => UnsupportedAuthSchemeSnafu {
|
||||
name: "Token AuthScheme",
|
||||
}
|
||||
.fail(),
|
||||
}
|
||||
.map_err(|e| Status::unauthenticated(e.to_string()))?;
|
||||
|
||||
user_provider
|
||||
.authorize(
|
||||
&query_ctx.current_catalog(),
|
||||
&query_ctx.current_schema(),
|
||||
&user_info,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| Status::permission_denied(e.to_string()))
|
||||
}
|
||||
|
||||
137
src/servers/src/grpc/handler.rs
Normal file
137
src/servers/src/grpc/handler.rs
Normal file
@@ -0,0 +1,137 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use api::v1::auth_header::AuthScheme;
|
||||
use api::v1::{Basic, GreptimeRequest, RequestHeader};
|
||||
use common_query::Output;
|
||||
use common_runtime::Runtime;
|
||||
use session::context::{QueryContext, QueryContextRef};
|
||||
use snafu::OptionExt;
|
||||
use tonic::Status;
|
||||
|
||||
use crate::auth::{Identity, Password, UserProviderRef};
|
||||
use crate::error::Error::{Auth, UnsupportedAuthScheme};
|
||||
use crate::error::{InvalidQuerySnafu, NotFoundAuthHeaderSnafu};
|
||||
use crate::grpc::TonicResult;
|
||||
use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
|
||||
|
||||
pub struct GreptimeRequestHandler {
|
||||
handler: ServerGrpcQueryHandlerRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
runtime: Arc<Runtime>,
|
||||
}
|
||||
|
||||
impl GreptimeRequestHandler {
|
||||
pub fn new(
|
||||
handler: ServerGrpcQueryHandlerRef,
|
||||
user_provider: Option<UserProviderRef>,
|
||||
runtime: Arc<Runtime>,
|
||||
) -> Self {
|
||||
Self {
|
||||
handler,
|
||||
user_provider,
|
||||
runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_request(&self, request: GreptimeRequest) -> TonicResult<Output> {
|
||||
let query = request.request.context(InvalidQuerySnafu {
|
||||
reason: "Expecting non-empty GreptimeRequest.",
|
||||
})?;
|
||||
|
||||
let header = request.header.as_ref();
|
||||
let query_ctx = create_query_context(header);
|
||||
|
||||
self.auth(header, &query_ctx).await?;
|
||||
|
||||
let handler = self.handler.clone();
|
||||
|
||||
// Executes requests in another runtime to
|
||||
// 1. prevent the execution from being cancelled unexpected by Tonic runtime;
|
||||
// - Refer to our blog for the rational behind it:
|
||||
// https://www.greptime.com/blogs/2023-01-12-hidden-control-flow.html
|
||||
// - Obtaining a `JoinHandle` to get the panic message (if there's any).
|
||||
// From its docs, `JoinHandle` is cancel safe. The task keeps running even it's handle been dropped.
|
||||
// 2. avoid the handler blocks the gRPC runtime incidentally.
|
||||
let handle = self
|
||||
.runtime
|
||||
.spawn(async move { handler.do_query(query, query_ctx).await });
|
||||
|
||||
let output = handle.await.map_err(|e| {
|
||||
if e.is_cancelled() {
|
||||
Status::cancelled(e.to_string())
|
||||
} else if e.is_panic() {
|
||||
Status::internal(format!("{:?}", e.into_panic()))
|
||||
} else {
|
||||
Status::unknown(e.to_string())
|
||||
}
|
||||
})??;
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
async fn auth(
|
||||
&self,
|
||||
header: Option<&RequestHeader>,
|
||||
query_ctx: &QueryContextRef,
|
||||
) -> TonicResult<()> {
|
||||
let Some(user_provider) = self.user_provider.as_ref() else { return Ok(()) };
|
||||
|
||||
let auth_scheme = header
|
||||
.and_then(|header| {
|
||||
header
|
||||
.authorization
|
||||
.as_ref()
|
||||
.and_then(|x| x.auth_scheme.clone())
|
||||
})
|
||||
.context(NotFoundAuthHeaderSnafu)?;
|
||||
|
||||
let user_info = match auth_scheme {
|
||||
AuthScheme::Basic(Basic { username, password }) => user_provider
|
||||
.authenticate(
|
||||
Identity::UserId(&username, None),
|
||||
Password::PlainText(&password),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| Auth { source: e }),
|
||||
AuthScheme::Token(_) => Err(UnsupportedAuthScheme {
|
||||
name: "Token AuthScheme".to_string(),
|
||||
}),
|
||||
}
|
||||
.map_err(|e| Status::unauthenticated(e.to_string()))?;
|
||||
|
||||
user_provider
|
||||
.authorize(
|
||||
&query_ctx.current_catalog(),
|
||||
&query_ctx.current_schema(),
|
||||
&user_info,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| Status::permission_denied(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
fn create_query_context(header: Option<&RequestHeader>) -> QueryContextRef {
|
||||
let ctx = QueryContext::arc();
|
||||
if let Some(header) = header {
|
||||
if !header.catalog.is_empty() {
|
||||
ctx.set_current_catalog(&header.catalog);
|
||||
}
|
||||
if !header.schema.is_empty() {
|
||||
ctx.set_current_schema(&header.schema);
|
||||
}
|
||||
};
|
||||
ctx
|
||||
}
|
||||
@@ -24,6 +24,7 @@ use common_runtime::{Builder as RuntimeBuilder, Runtime};
|
||||
use servers::auth::UserProviderRef;
|
||||
use servers::error::{Result, StartGrpcSnafu, TcpBindSnafu};
|
||||
use servers::grpc::flight::FlightHandler;
|
||||
use servers::grpc::handler::GreptimeRequestHandler;
|
||||
use servers::query_handler::grpc::ServerGrpcQueryHandlerRef;
|
||||
use servers::server::Server;
|
||||
use snafu::ResultExt;
|
||||
@@ -54,11 +55,11 @@ impl MockGrpcServer {
|
||||
}
|
||||
|
||||
fn create_service(&self) -> FlightServiceServer<impl FlightService> {
|
||||
let service = FlightHandler::new(
|
||||
let service = FlightHandler::new(Arc::new(GreptimeRequestHandler::new(
|
||||
self.query_handler.clone(),
|
||||
self.user_provider.clone(),
|
||||
self.runtime.clone(),
|
||||
);
|
||||
)));
|
||||
FlightServiceServer::new(service)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -183,7 +183,7 @@ async fn insert_and_assert(db: &Database) {
|
||||
row_count: 4,
|
||||
};
|
||||
let result = db.insert(request).await;
|
||||
result.unwrap();
|
||||
assert_eq!(result.unwrap(), 4);
|
||||
|
||||
let result = db
|
||||
.sql(
|
||||
|
||||
Reference in New Issue
Block a user