diff --git a/Cargo.lock b/Cargo.lock index bf54d32985..42c00bda63 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3085,7 +3085,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "greptime-proto" version = "0.1.0" -source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=3a715150563b89d5dfc81a5838eac1f66a5658a1#3a715150563b89d5dfc81a5838eac1f66a5658a1" +source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=0a7b790ed41364b5599dff806d1080bd59c5c9f6#0a7b790ed41364b5599dff806d1080bd59c5c9f6" dependencies = [ "prost", "tonic", diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml index dbb7b8ee9d..4e280c72e5 100644 --- a/src/api/Cargo.toml +++ b/src/api/Cargo.toml @@ -10,7 +10,7 @@ common-base = { path = "../common/base" } common-error = { path = "../common/error" } common-time = { path = "../common/time" } datatypes = { path = "../datatypes" } -greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "3a715150563b89d5dfc81a5838eac1f66a5658a1" } +greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "0a7b790ed41364b5599dff806d1080bd59c5c9f6" } prost.workspace = true snafu = { version = "0.7", features = ["backtraces"] } tonic.workspace = true diff --git a/src/client/src/client.rs b/src/client/src/client.rs index ca09fe047d..0341a26f5a 100644 --- a/src/client/src/client.rs +++ b/src/client/src/client.rs @@ -14,6 +14,7 @@ use std::sync::Arc; +use api::v1::greptime_database_client::GreptimeDatabaseClient; use arrow_flight::flight_service_client::FlightServiceClient; use common_grpc::channel_manager::ChannelManager; use parking_lot::RwLock; @@ -23,6 +24,10 @@ use tonic::transport::Channel; use crate::load_balance::{LoadBalance, Loadbalancer}; use crate::{error, Result}; +pub(crate) struct DatabaseClient { + pub(crate) inner: GreptimeDatabaseClient, +} + pub(crate) struct FlightClient { addr: String, client: FlightServiceClient, @@ -118,7 +123,7 @@ impl Client { self.inner.set_peers(urls); } - pub(crate) fn make_client(&self) -> Result { + fn find_channel(&self) -> Result<(String, Channel)> { let addr = self .inner .get_peer() @@ -131,11 +136,23 @@ impl Client { .channel_manager .get(&addr) .context(error::CreateChannelSnafu { addr: &addr })?; + Ok((addr, channel)) + } + + pub(crate) fn make_flight_client(&self) -> Result { + let (addr, channel) = self.find_channel()?; Ok(FlightClient { addr, client: FlightServiceClient::new(channel), }) } + + pub(crate) fn make_database_client(&self) -> Result { + let (_, channel) = self.find_channel()?; + Ok(DatabaseClient { + inner: GreptimeDatabaseClient::new(channel), + }) + } } #[cfg(test)] diff --git a/src/client/src/database.rs b/src/client/src/database.rs index de6295f643..c0be177408 100644 --- a/src/client/src/database.rs +++ b/src/client/src/database.rs @@ -12,15 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::str::FromStr; - use api::v1::auth_header::AuthScheme; use api::v1::ddl_request::Expr as DdlExpr; use api::v1::greptime_request::Request; use api::v1::query_request::Query; use api::v1::{ - AlterExpr, AuthHeader, CreateTableExpr, DdlRequest, DropTableExpr, FlushTableExpr, - GreptimeRequest, InsertRequest, PromRangeQuery, QueryRequest, RequestHeader, + greptime_response, AffectedRows, AlterExpr, AuthHeader, CreateTableExpr, DdlRequest, + DropTableExpr, FlushTableExpr, GreptimeRequest, InsertRequest, PromRangeQuery, QueryRequest, + RequestHeader, }; use arrow_flight::{FlightData, Ticket}; use common_error::prelude::*; @@ -31,7 +30,9 @@ use futures_util::{TryFutureExt, TryStreamExt}; use prost::Message; use snafu::{ensure, ResultExt}; -use crate::error::{ConvertFlightDataSnafu, IllegalFlightMessagesSnafu}; +use crate::error::{ + ConvertFlightDataSnafu, IllegalDatabaseResponseSnafu, IllegalFlightMessagesSnafu, +}; use crate::{error, Client, Result}; #[derive(Clone, Debug)] @@ -78,8 +79,26 @@ impl Database { }); } - pub async fn insert(&self, request: InsertRequest) -> Result { - self.do_get(Request::Insert(request)).await + pub async fn insert(&self, request: InsertRequest) -> Result { + let mut client = self.client.make_database_client()?.inner; + let request = GreptimeRequest { + header: Some(RequestHeader { + catalog: self.catalog.clone(), + schema: self.schema.clone(), + authorization: self.ctx.auth_header.clone(), + }), + request: Some(Request::Insert(request)), + }; + let response = client + .handle(request) + .await? + .into_inner() + .response + .context(IllegalDatabaseResponseSnafu { + err_msg: "GreptimeResponse is empty", + })?; + let greptime_response::Response::AffectedRows(AffectedRows { value }) = response; + Ok(value) } pub async fn sql(&self, sql: &str) -> Result { @@ -155,7 +174,7 @@ impl Database { ticket: request.encode_to_vec().into(), }; - let mut client = self.client.make_client()?; + let mut client = self.client.make_flight_client()?; // TODO(LFC): Streaming get flight data. let flight_data: Vec = client @@ -164,22 +183,22 @@ impl Database { .and_then(|response| response.into_inner().try_collect()) .await .map_err(|e| { - let code = get_metadata_value(&e, INNER_ERROR_CODE) - .and_then(|s| StatusCode::from_str(&s).ok()) - .unwrap_or(StatusCode::Unknown); - let msg = get_metadata_value(&e, INNER_ERROR_MSG).unwrap_or(e.to_string()); - error::ExternalSnafu { code, msg } + let tonic_code = e.code(); + let e: error::Error = e.into(); + let code = e.status_code(); + let msg = e.to_string(); + error::ServerSnafu { code, msg } .fail::<()>() .map_err(BoxedError::new) .context(error::FlightGetSnafu { - tonic_code: e.code(), + tonic_code, addr: client.addr(), }) .map_err(|error| { logging::error!( "Failed to do Flight get, addr: {}, code: {}, source: {}", client.addr(), - e.code(), + tonic_code, error ); error @@ -210,12 +229,6 @@ impl Database { } } -fn get_metadata_value(e: &tonic::Status, key: &str) -> Option { - e.metadata() - .get(key) - .and_then(|v| String::from_utf8(v.as_bytes().to_vec()).ok()) -} - #[derive(Default, Debug, Clone)] pub struct FlightContext { auth_header: Option, diff --git a/src/client/src/error.rs b/src/client/src/error.rs index aae7f0866a..3c28b753f0 100644 --- a/src/client/src/error.rs +++ b/src/client/src/error.rs @@ -13,9 +13,10 @@ // limitations under the License. use std::any::Any; +use std::str::FromStr; use common_error::prelude::*; -use tonic::Code; +use tonic::{Code, Status}; #[derive(Debug, Snafu)] #[snafu(visibility(pub))] @@ -68,6 +69,13 @@ pub enum Error { /// Error deserialized from gRPC metadata #[snafu(display("{}", msg))] ExternalError { code: StatusCode, msg: String }, + + // Server error carried in Tonic Status's metadata. + #[snafu(display("{}", msg))] + Server { code: StatusCode, msg: String }, + + #[snafu(display("Illegal Database response: {err_msg}"))] + IllegalDatabaseResponse { err_msg: String }, } pub type Result = std::result::Result; @@ -77,7 +85,10 @@ impl ErrorExt for Error { match self { Error::IllegalFlightMessages { .. } | Error::ColumnDataType { .. } - | Error::MissingField { .. } => StatusCode::Internal, + | Error::MissingField { .. } + | Error::IllegalDatabaseResponse { .. } => StatusCode::Internal, + + Error::Server { code, .. } => *code, Error::FlightGet { source, .. } => source.status_code(), Error::CreateChannel { source, .. } | Error::ConvertFlightData { source } => { source.status_code() @@ -95,3 +106,21 @@ impl ErrorExt for Error { self } } + +impl From for Error { + fn from(e: Status) -> Self { + fn get_metadata_value(e: &Status, key: &str) -> Option { + e.metadata() + .get(key) + .and_then(|v| String::from_utf8(v.as_bytes().to_vec()).ok()) + } + + let code = get_metadata_value(&e, INNER_ERROR_CODE) + .and_then(|s| StatusCode::from_str(&s).ok()) + .unwrap_or(StatusCode::Unknown); + + let msg = get_metadata_value(&e, INNER_ERROR_MSG).unwrap_or(e.to_string()); + + Self::Server { code, msg } + } +} diff --git a/src/frontend/src/table/insert.rs b/src/frontend/src/table/insert.rs index 8919838cef..b134b16018 100644 --- a/src/frontend/src/table/insert.rs +++ b/src/frontend/src/table/insert.rs @@ -74,8 +74,7 @@ impl DistTable { let mut success = 0; for join in joins { - let object_result = join.await.context(error::JoinTaskSnafu)??; - let Output::AffectedRows(rows) = object_result else { unreachable!() }; + let rows = join.await.context(error::JoinTaskSnafu)?? as usize; success += rows; } Ok(Output::AffectedRows(success)) diff --git a/src/frontend/src/table/scan.rs b/src/frontend/src/table/scan.rs index 5b9da34e71..a43b40a665 100644 --- a/src/frontend/src/table/scan.rs +++ b/src/frontend/src/table/scan.rs @@ -47,7 +47,7 @@ impl DatanodeInstance { Self { table, db } } - pub(crate) async fn grpc_insert(&self, request: InsertRequest) -> client::Result { + pub(crate) async fn grpc_insert(&self, request: InsertRequest) -> client::Result { self.db.insert(request).await } diff --git a/src/frontend/src/tests.rs b/src/frontend/src/tests.rs index ad79cd261a..84358e7570 100644 --- a/src/frontend/src/tests.rs +++ b/src/frontend/src/tests.rs @@ -125,15 +125,15 @@ pub(crate) async fn create_datanode_client( // create a mock datanode grpc service, see example here: // https://github.com/hyperium/tonic/blob/master/examples/src/mock/mock.rs - let datanode_service = GrpcServer::new( + let grpc_server = GrpcServer::new( ServerGrpcQueryHandlerAdaptor::arc(datanode_instance), None, runtime, - ) - .create_service(); + ); tokio::spawn(async move { Server::builder() - .add_service(datanode_service) + .add_service(grpc_server.create_flight_service()) + .add_service(grpc_server.create_database_service()) .serve_with_incoming(futures::stream::iter(vec![Ok::<_, std::io::Error>(server)])) .await }); diff --git a/src/servers/src/grpc.rs b/src/servers/src/grpc.rs index 1e09f302ec..4edfb1817a 100644 --- a/src/servers/src/grpc.rs +++ b/src/servers/src/grpc.rs @@ -12,11 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod database; pub mod flight; +pub mod handler; use std::net::SocketAddr; use std::sync::Arc; +use api::v1::greptime_database_server::{GreptimeDatabase, GreptimeDatabaseServer}; use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; use async_trait::async_trait; use common_runtime::Runtime; @@ -27,18 +30,21 @@ use tokio::net::TcpListener; use tokio::sync::oneshot::{self, Sender}; use tokio::sync::Mutex; use tokio_stream::wrappers::TcpListenerStream; +use tonic::Status; use crate::auth::UserProviderRef; use crate::error::{AlreadyStartedSnafu, Result, StartGrpcSnafu, TcpBindSnafu}; +use crate::grpc::database::DatabaseService; use crate::grpc::flight::FlightHandler; +use crate::grpc::handler::GreptimeRequestHandler; use crate::query_handler::grpc::ServerGrpcQueryHandlerRef; use crate::server::Server; +type TonicResult = std::result::Result; + pub struct GrpcServer { - query_handler: ServerGrpcQueryHandlerRef, - user_provider: Option, shutdown_tx: Mutex>>, - runtime: Arc, + request_handler: Arc, } impl GrpcServer { @@ -47,21 +53,23 @@ impl GrpcServer { user_provider: Option, runtime: Arc, ) -> Self { - Self { + let request_handler = Arc::new(GreptimeRequestHandler::new( query_handler, user_provider, - shutdown_tx: Mutex::new(None), runtime, + )); + Self { + shutdown_tx: Mutex::new(None), + request_handler, } } - pub fn create_service(&self) -> FlightServiceServer { - let service = FlightHandler::new( - self.query_handler.clone(), - self.user_provider.clone(), - self.runtime.clone(), - ); - FlightServiceServer::new(service) + pub fn create_flight_service(&self) -> FlightServiceServer { + FlightServiceServer::new(FlightHandler::new(self.request_handler.clone())) + } + + pub fn create_database_service(&self) -> GreptimeDatabaseServer { + GreptimeDatabaseServer::new(DatabaseService::new(self.request_handler.clone())) } } @@ -103,7 +111,8 @@ impl Server for GrpcServer { // Would block to serve requests. tonic::transport::Server::builder() - .add_service(self.create_service()) + .add_service(self.create_flight_service()) + .add_service(self.create_database_service()) .serve_with_incoming_shutdown(TcpListenerStream::new(listener), rx.map(drop)) .await .context(StartGrpcSnafu)?; diff --git a/src/servers/src/grpc/database.rs b/src/servers/src/grpc/database.rs new file mode 100644 index 0000000000..350827c023 --- /dev/null +++ b/src/servers/src/grpc/database.rs @@ -0,0 +1,57 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use api::v1::greptime_database_server::GreptimeDatabase; +use api::v1::{greptime_response, AffectedRows, GreptimeRequest, GreptimeResponse}; +use async_trait::async_trait; +use common_query::Output; +use tonic::{Request, Response, Status}; + +use crate::grpc::handler::GreptimeRequestHandler; +use crate::grpc::TonicResult; + +pub(crate) struct DatabaseService { + handler: Arc, +} + +impl DatabaseService { + pub(crate) fn new(handler: Arc) -> Self { + Self { handler } + } +} + +#[async_trait] +impl GreptimeDatabase for DatabaseService { + async fn handle( + &self, + request: Request, + ) -> TonicResult> { + let request = request.into_inner(); + let output = self.handler.handle_request(request).await?; + let response = match output { + Output::AffectedRows(rows) => GreptimeResponse { + header: None, + response: Some(greptime_response::Response::AffectedRows(AffectedRows { + value: rows as _, + })), + }, + Output::Stream(_) | Output::RecordBatches(_) => { + return Err(Status::unimplemented("GreptimeDatabase::handle for query")); + } + }; + Ok(Response::new(response)) + } +} diff --git a/src/servers/src/grpc/flight.rs b/src/servers/src/grpc/flight.rs index 72bcd632df..0b793d9855 100644 --- a/src/servers/src/grpc/flight.rs +++ b/src/servers/src/grpc/flight.rs @@ -17,8 +17,7 @@ mod stream; use std::pin::Pin; use std::sync::Arc; -use api::v1::auth_header::AuthScheme; -use api::v1::{Basic, GreptimeRequest, RequestHeader}; +use api::v1::GreptimeRequest; use arrow_flight::flight_service_server::FlightService; use arrow_flight::{ Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, @@ -27,40 +26,25 @@ use arrow_flight::{ use async_trait::async_trait; use common_grpc::flight::{FlightEncoder, FlightMessage}; use common_query::Output; -use common_runtime::Runtime; use futures::Stream; use prost::Message; -use session::context::{QueryContext, QueryContextRef}; -use snafu::{OptionExt, ResultExt}; +use snafu::ResultExt; use tonic::{Request, Response, Status, Streaming}; -use crate::auth::{Identity, UserProviderRef}; use crate::error; -use crate::error::Error::Auth; -use crate::error::{NotFoundAuthHeaderSnafu, UnsupportedAuthSchemeSnafu}; use crate::grpc::flight::stream::FlightRecordBatchStream; -use crate::query_handler::grpc::ServerGrpcQueryHandlerRef; +use crate::grpc::handler::GreptimeRequestHandler; +use crate::grpc::TonicResult; -type TonicResult = Result; type TonicStream = Pin> + Send + Sync + 'static>>; pub struct FlightHandler { - handler: ServerGrpcQueryHandlerRef, - user_provider: Option, - runtime: Arc, + handler: Arc, } impl FlightHandler { - pub fn new( - handler: ServerGrpcQueryHandlerRef, - user_provider: Option, - runtime: Arc, - ) -> Self { - Self { - handler, - user_provider, - runtime, - } + pub fn new(handler: Arc) -> Self { + Self { handler } } } @@ -105,40 +89,8 @@ impl FlightService for FlightHandler { let request = GreptimeRequest::decode(ticket.as_ref()).context(error::InvalidFlightTicketSnafu)?; - let query = request.request.context(error::InvalidQuerySnafu { - reason: "Expecting non-empty GreptimeRequest.", - })?; - let query_ctx = create_query_context(request.header.as_ref()); + let output = self.handler.handle_request(request).await?; - auth( - self.user_provider.as_ref(), - request.header.as_ref(), - &query_ctx, - ) - .await?; - - let handler = self.handler.clone(); - - // Executes requests in another runtime to - // 1. prevent the execution from being cancelled unexpected by Tonic runtime; - // - Refer to our blog for the rational behind it: - // https://www.greptime.com/blogs/2023-01-12-hidden-control-flow.html - // - Obtaining a `JoinHandle` to get the panic message (if there's any). - // From its docs, `JoinHandle` is cancel safe. The task keeps running even it's handle been dropped. - // 2. avoid the handler blocks the gRPC runtime incidentally. - let handle = self - .runtime - .spawn(async move { handler.do_query(query, query_ctx).await }); - - let output = handle.await.map_err(|e| { - if e.is_cancelled() { - Status::cancelled(e.to_string()) - } else if e.is_panic() { - Status::internal(format!("{:?}", e.into_panic())) - } else { - Status::unknown(e.to_string()) - } - })??; let stream = to_flight_data_stream(output); Ok(Response::new(stream)) } @@ -195,56 +147,3 @@ fn to_flight_data_stream(output: Output) -> TonicStream { } } } - -fn create_query_context(header: Option<&RequestHeader>) -> QueryContextRef { - let ctx = QueryContext::arc(); - if let Some(header) = header { - if !header.catalog.is_empty() { - ctx.set_current_catalog(&header.catalog); - } - - if !header.schema.is_empty() { - ctx.set_current_schema(&header.schema); - } - }; - ctx -} - -async fn auth( - user_provider: Option<&UserProviderRef>, - request_header: Option<&RequestHeader>, - query_ctx: &QueryContextRef, -) -> TonicResult<()> { - let Some(user_provider) = user_provider else { return Ok(()) }; - - let user_info = match request_header - .context(NotFoundAuthHeaderSnafu)? - .clone() - .authorization - .context(NotFoundAuthHeaderSnafu)? - .auth_scheme - .context(NotFoundAuthHeaderSnafu)? - { - AuthScheme::Basic(Basic { username, password }) => user_provider - .authenticate( - Identity::UserId(&username, None), - crate::auth::Password::PlainText(&password), - ) - .await - .map_err(|e| Auth { source: e }), - AuthScheme::Token(_) => UnsupportedAuthSchemeSnafu { - name: "Token AuthScheme", - } - .fail(), - } - .map_err(|e| Status::unauthenticated(e.to_string()))?; - - user_provider - .authorize( - &query_ctx.current_catalog(), - &query_ctx.current_schema(), - &user_info, - ) - .await - .map_err(|e| Status::permission_denied(e.to_string())) -} diff --git a/src/servers/src/grpc/handler.rs b/src/servers/src/grpc/handler.rs new file mode 100644 index 0000000000..da97b77b2c --- /dev/null +++ b/src/servers/src/grpc/handler.rs @@ -0,0 +1,137 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use api::v1::auth_header::AuthScheme; +use api::v1::{Basic, GreptimeRequest, RequestHeader}; +use common_query::Output; +use common_runtime::Runtime; +use session::context::{QueryContext, QueryContextRef}; +use snafu::OptionExt; +use tonic::Status; + +use crate::auth::{Identity, Password, UserProviderRef}; +use crate::error::Error::{Auth, UnsupportedAuthScheme}; +use crate::error::{InvalidQuerySnafu, NotFoundAuthHeaderSnafu}; +use crate::grpc::TonicResult; +use crate::query_handler::grpc::ServerGrpcQueryHandlerRef; + +pub struct GreptimeRequestHandler { + handler: ServerGrpcQueryHandlerRef, + user_provider: Option, + runtime: Arc, +} + +impl GreptimeRequestHandler { + pub fn new( + handler: ServerGrpcQueryHandlerRef, + user_provider: Option, + runtime: Arc, + ) -> Self { + Self { + handler, + user_provider, + runtime, + } + } + + pub(crate) async fn handle_request(&self, request: GreptimeRequest) -> TonicResult { + let query = request.request.context(InvalidQuerySnafu { + reason: "Expecting non-empty GreptimeRequest.", + })?; + + let header = request.header.as_ref(); + let query_ctx = create_query_context(header); + + self.auth(header, &query_ctx).await?; + + let handler = self.handler.clone(); + + // Executes requests in another runtime to + // 1. prevent the execution from being cancelled unexpected by Tonic runtime; + // - Refer to our blog for the rational behind it: + // https://www.greptime.com/blogs/2023-01-12-hidden-control-flow.html + // - Obtaining a `JoinHandle` to get the panic message (if there's any). + // From its docs, `JoinHandle` is cancel safe. The task keeps running even it's handle been dropped. + // 2. avoid the handler blocks the gRPC runtime incidentally. + let handle = self + .runtime + .spawn(async move { handler.do_query(query, query_ctx).await }); + + let output = handle.await.map_err(|e| { + if e.is_cancelled() { + Status::cancelled(e.to_string()) + } else if e.is_panic() { + Status::internal(format!("{:?}", e.into_panic())) + } else { + Status::unknown(e.to_string()) + } + })??; + Ok(output) + } + + async fn auth( + &self, + header: Option<&RequestHeader>, + query_ctx: &QueryContextRef, + ) -> TonicResult<()> { + let Some(user_provider) = self.user_provider.as_ref() else { return Ok(()) }; + + let auth_scheme = header + .and_then(|header| { + header + .authorization + .as_ref() + .and_then(|x| x.auth_scheme.clone()) + }) + .context(NotFoundAuthHeaderSnafu)?; + + let user_info = match auth_scheme { + AuthScheme::Basic(Basic { username, password }) => user_provider + .authenticate( + Identity::UserId(&username, None), + Password::PlainText(&password), + ) + .await + .map_err(|e| Auth { source: e }), + AuthScheme::Token(_) => Err(UnsupportedAuthScheme { + name: "Token AuthScheme".to_string(), + }), + } + .map_err(|e| Status::unauthenticated(e.to_string()))?; + + user_provider + .authorize( + &query_ctx.current_catalog(), + &query_ctx.current_schema(), + &user_info, + ) + .await + .map_err(|e| Status::permission_denied(e.to_string())) + } +} + +fn create_query_context(header: Option<&RequestHeader>) -> QueryContextRef { + let ctx = QueryContext::arc(); + if let Some(header) = header { + if !header.catalog.is_empty() { + ctx.set_current_catalog(&header.catalog); + } + if !header.schema.is_empty() { + ctx.set_current_schema(&header.schema); + } + }; + ctx +} diff --git a/src/servers/tests/grpc/mod.rs b/src/servers/tests/grpc/mod.rs index 38da7fb48f..76edaee9f2 100644 --- a/src/servers/tests/grpc/mod.rs +++ b/src/servers/tests/grpc/mod.rs @@ -24,6 +24,7 @@ use common_runtime::{Builder as RuntimeBuilder, Runtime}; use servers::auth::UserProviderRef; use servers::error::{Result, StartGrpcSnafu, TcpBindSnafu}; use servers::grpc::flight::FlightHandler; +use servers::grpc::handler::GreptimeRequestHandler; use servers::query_handler::grpc::ServerGrpcQueryHandlerRef; use servers::server::Server; use snafu::ResultExt; @@ -54,11 +55,11 @@ impl MockGrpcServer { } fn create_service(&self) -> FlightServiceServer { - let service = FlightHandler::new( + let service = FlightHandler::new(Arc::new(GreptimeRequestHandler::new( self.query_handler.clone(), self.user_provider.clone(), self.runtime.clone(), - ); + ))); FlightServiceServer::new(service) } } diff --git a/tests-integration/tests/grpc.rs b/tests-integration/tests/grpc.rs index 0dc2fa71d6..3abe9a8d5b 100644 --- a/tests-integration/tests/grpc.rs +++ b/tests-integration/tests/grpc.rs @@ -183,7 +183,7 @@ async fn insert_and_assert(db: &Database) { row_count: 4, }; let result = db.insert(request).await; - result.unwrap(); + assert_eq!(result.unwrap(), 4); let result = db .sql(