diff --git a/Cargo.lock b/Cargo.lock index 792a891564..c2c15a6a99 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2975,7 +2975,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "greptime-proto" version = "0.1.0" -source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=3e6349be127b65a8b42a38cda9d527ec423ca77d#3e6349be127b65a8b42a38cda9d527ec423ca77d" +source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=1599ae2a0d1d8f42ee23ed26e4ad7a7b34134c60#1599ae2a0d1d8f42ee23ed26e4ad7a7b34134c60" dependencies = [ "prost 0.11.6", "tonic", @@ -6685,6 +6685,7 @@ dependencies = [ "bytes", "catalog", "chrono", + "client", "common-base", "common-catalog", "common-error", diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml index 384f1d7ddf..14c14546df 100644 --- a/src/api/Cargo.toml +++ b/src/api/Cargo.toml @@ -10,7 +10,7 @@ common-base = { path = "../common/base" } common-error = { path = "../common/error" } common-time = { path = "../common/time" } datatypes = { path = "../datatypes" } -greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "3e6349be127b65a8b42a38cda9d527ec423ca77d" } +greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "1599ae2a0d1d8f42ee23ed26e4ad7a7b34134c60" } prost.workspace = true snafu = { version = "0.7", features = ["backtraces"] } tonic.workspace = true diff --git a/src/client/src/database.rs b/src/client/src/database.rs index bdf63b748e..efc1eacfa8 100644 --- a/src/client/src/database.rs +++ b/src/client/src/database.rs @@ -14,12 +14,13 @@ use std::str::FromStr; +use api::v1::auth_header::AuthScheme; use api::v1::ddl_request::Expr as DdlExpr; use api::v1::greptime_request::Request; use api::v1::query_request::Query; use api::v1::{ - AlterExpr, CreateTableExpr, DdlRequest, DropTableExpr, GreptimeRequest, InsertRequest, - QueryRequest, RequestHeader, + AlterExpr, AuthHeader, CreateTableExpr, DdlRequest, DropTableExpr, GreptimeRequest, + InsertRequest, QueryRequest, RequestHeader, }; use arrow_flight::{FlightData, Ticket}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; @@ -42,6 +43,7 @@ pub struct Database { schema: String, client: Client, + ctx: FlightContext, } impl Database { @@ -50,6 +52,7 @@ impl Database { catalog: catalog.into(), schema: schema.into(), client, + ctx: FlightContext::default(), } } @@ -61,6 +64,12 @@ impl Database { self.schema = schema.into(); } + pub fn set_auth(&mut self, auth: AuthScheme) { + self.ctx.auth_header = Some(AuthHeader { + auth_scheme: Some(auth), + }); + } + pub async fn insert(&self, request: InsertRequest) -> Result { self.do_get(Request::Insert(request)).await } @@ -105,6 +114,7 @@ impl Database { header: Some(RequestHeader { catalog: self.catalog.clone(), schema: self.schema.clone(), + authorization: self.ctx.auth_header.clone(), }), request: Some(request), }; @@ -164,12 +174,18 @@ fn get_metadata_value(e: &tonic::Status, key: &str) -> Option { .and_then(|v| String::from_utf8(v.as_bytes().to_vec()).ok()) } +#[derive(Default, Debug, Clone)] +pub struct FlightContext { + auth_header: Option, +} + #[cfg(test)] mod tests { use std::sync::Arc; use api::helper::ColumnDataTypeWrapper; - use api::v1::Column; + use api::v1::auth_header::AuthScheme; + use api::v1::{AuthHeader, Basic, Column}; use common_grpc::select::{null_mask, values}; use common_grpc_expr::column_to_vector; use datatypes::prelude::{Vector, VectorRef}; @@ -179,6 +195,8 @@ mod tests { UInt32Vector, UInt64Vector, UInt8Vector, }; + use crate::database::FlightContext; + #[test] fn test_column_to_vector() { let mut column = create_test_column(Arc::new(BooleanVector::from(vec![true]))); @@ -262,4 +280,26 @@ mod tests { datatype: wrapper.datatype() as i32, } } + + #[test] + fn test_flight_ctx() { + let mut ctx = FlightContext::default(); + assert!(ctx.auth_header.is_none()); + + let basic = AuthScheme::Basic(Basic { + username: "u".to_string(), + password: "p".to_string(), + }); + + ctx.auth_header = Some(AuthHeader { + auth_scheme: Some(basic), + }); + + assert!(matches!( + ctx.auth_header, + Some(AuthHeader { + auth_scheme: Some(AuthScheme::Basic(_)), + }) + )) + } } diff --git a/src/datanode/src/server.rs b/src/datanode/src/server.rs index 3827138fb3..482e4f1a10 100644 --- a/src/datanode/src/server.rs +++ b/src/datanode/src/server.rs @@ -89,6 +89,7 @@ impl Services { Ok(Self { grpc_server: GrpcServer::new( ServerGrpcQueryHandlerAdaptor::arc(instance), + None, grpc_runtime, ), mysql_server, diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index 2b07fc6fd0..a6ef49bddf 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -66,6 +66,7 @@ impl Services { let grpc_server = GrpcServer::new( ServerGrpcQueryHandlerAdaptor::arc(instance.clone()), + user_provider.clone(), grpc_runtime, ); diff --git a/src/frontend/src/tests.rs b/src/frontend/src/tests.rs index d4a54b7886..b67e68cee8 100644 --- a/src/frontend/src/tests.rs +++ b/src/frontend/src/tests.rs @@ -114,6 +114,7 @@ pub(crate) async fn create_datanode_client( // https://github.com/hyperium/tonic/blob/master/examples/src/mock/mock.rs let datanode_service = GrpcServer::new( ServerGrpcQueryHandlerAdaptor::arc(datanode_instance), + None, runtime, ) .create_service(); diff --git a/src/meta-srv/src/service/store.rs b/src/meta-srv/src/service/store.rs index a972c1ff59..ae7d1d22a2 100644 --- a/src/meta-srv/src/service/store.rs +++ b/src/meta-srv/src/service/store.rs @@ -18,11 +18,11 @@ pub mod kv; pub mod memory; use api::v1::meta::{ - store_server, BatchPutRequest, BatchPutResponse, CompareAndPutRequest, CompareAndPutResponse, - DeleteRangeRequest, DeleteRangeResponse, MoveValueRequest, MoveValueResponse, PutRequest, - PutResponse, RangeRequest, RangeResponse, + store_server, BatchGetRequest, BatchGetResponse, BatchPutRequest, BatchPutResponse, + CompareAndPutRequest, CompareAndPutResponse, DeleteRangeRequest, DeleteRangeResponse, + MoveValueRequest, MoveValueResponse, PutRequest, PutResponse, RangeRequest, RangeResponse, }; -use tonic::{Request, Response}; +use tonic::{Request, Response, Status}; use crate::metasrv::MetaSrv; use crate::service::GrpcResult; @@ -43,6 +43,14 @@ impl store_server::Store for MetaSrv { Ok(Response::new(res)) } + async fn batch_get( + &self, + _request: Request, + ) -> Result, Status> { + // TODO(fys): please fix this + unimplemented!() + } + async fn batch_put(&self, req: Request) -> GrpcResult { let req = req.into_inner(); let res = self.kv_store().batch_put(req).await?; diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index 611a207731..04b092cd5d 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -69,6 +69,7 @@ tower-http = { version = "0.3", features = ["full"] } [dev-dependencies] axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" } +client = { path = "../client" } common-base = { path = "../common/base" } mysql_async = { version = "0.31", default-features = false, features = [ "default-rustls", diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index a17a8eb5d4..d20c201d43 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -22,7 +22,6 @@ use axum::Json; use base64::DecodeError; use catalog; use common_error::prelude::*; -use hyper::header::ToStrError; use serde_json::json; use tonic::codegen::http::{HeaderMap, HeaderValue}; use tonic::metadata::MetadataMap; @@ -150,7 +149,7 @@ pub enum Error { #[snafu(display("Invalid OpenTSDB line, source: {}", source))] InvalidOpentsdbLine { - source: std::string::FromUtf8Error, + source: FromUtf8Error, backtrace: Backtrace, }, @@ -216,7 +215,7 @@ pub enum Error { source: auth::Error, }, - #[snafu(display("Not found http authorization header"))] + #[snafu(display("Not found http or grpc authorization header"))] NotFoundAuthHeader {}, #[snafu(display("Not found influx http authorization info"))] @@ -224,7 +223,7 @@ pub enum Error { #[snafu(display("Invalid visibility ASCII chars, source: {}", source))] InvisibleASCII { - source: ToStrError, + source: hyper::header::ToStrError, backtrace: Backtrace, }, diff --git a/src/servers/src/grpc.rs b/src/servers/src/grpc.rs index afb111629f..39b76018b9 100644 --- a/src/servers/src/grpc.rs +++ b/src/servers/src/grpc.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -mod flight; +pub mod flight; use std::net::SocketAddr; use std::sync::Arc; @@ -28,6 +28,7 @@ use tokio::sync::oneshot::{self, Sender}; use tokio::sync::Mutex; use tokio_stream::wrappers::TcpListenerStream; +use crate::auth::UserProviderRef; use crate::error::{AlreadyStartedSnafu, Result, StartGrpcSnafu, TcpBindSnafu}; use crate::grpc::flight::FlightHandler; use crate::query_handler::grpc::ServerGrpcQueryHandlerRef; @@ -35,21 +36,31 @@ use crate::server::Server; pub struct GrpcServer { query_handler: ServerGrpcQueryHandlerRef, + user_provider: Option, shutdown_tx: Mutex>>, runtime: Arc, } impl GrpcServer { - pub fn new(query_handler: ServerGrpcQueryHandlerRef, runtime: Arc) -> Self { + pub fn new( + query_handler: ServerGrpcQueryHandlerRef, + user_provider: Option, + runtime: Arc, + ) -> Self { Self { query_handler, + user_provider, shutdown_tx: Mutex::new(None), runtime, } } pub fn create_service(&self) -> FlightServiceServer { - let service = FlightHandler::new(self.query_handler.clone(), self.runtime.clone()); + let service = FlightHandler::new( + self.query_handler.clone(), + self.user_provider.clone(), + self.runtime.clone(), + ); FlightServiceServer::new(service) } } diff --git a/src/servers/src/grpc/flight.rs b/src/servers/src/grpc/flight.rs index 196f07960e..dcbfb6fe6c 100644 --- a/src/servers/src/grpc/flight.rs +++ b/src/servers/src/grpc/flight.rs @@ -17,7 +17,8 @@ mod stream; use std::pin::Pin; use std::sync::Arc; -use api::v1::{GreptimeRequest, RequestHeader}; +use api::v1::auth_header::AuthScheme; +use api::v1::{Basic, GreptimeRequest, RequestHeader}; use arrow_flight::flight_service_server::FlightService; use arrow_flight::{ Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, @@ -33,21 +34,33 @@ use session::context::{QueryContext, QueryContextRef}; use snafu::{OptionExt, ResultExt}; use tonic::{Request, Response, Status, Streaming}; +use crate::auth::{Identity, UserProviderRef}; use crate::error; +use crate::error::Error::Auth; +use crate::error::{NotFoundAuthHeaderSnafu, UnsupportedAuthSchemeSnafu}; use crate::grpc::flight::stream::FlightRecordBatchStream; use crate::query_handler::grpc::ServerGrpcQueryHandlerRef; type TonicResult = Result; type TonicStream = Pin> + Send + Sync + 'static>>; -pub(crate) struct FlightHandler { +pub struct FlightHandler { handler: ServerGrpcQueryHandlerRef, + user_provider: Option, runtime: Arc, } impl FlightHandler { - pub(crate) fn new(handler: ServerGrpcQueryHandlerRef, runtime: Arc) -> Self { - Self { handler, runtime } + pub fn new( + handler: ServerGrpcQueryHandlerRef, + user_provider: Option, + runtime: Arc, + ) -> Self { + Self { + handler, + user_provider, + runtime, + } } } @@ -97,6 +110,13 @@ impl FlightService for FlightHandler { })?; let query_ctx = create_query_context(request.header.as_ref()); + auth( + self.user_provider.as_ref(), + request.header.as_ref(), + &query_ctx, + ) + .await?; + let handler = self.handler.clone(); // Executes requests in another runtime to @@ -189,3 +209,42 @@ fn create_query_context(header: Option<&RequestHeader>) -> QueryContextRef { }; ctx } + +async fn auth( + user_provider: Option<&UserProviderRef>, + request_header: Option<&RequestHeader>, + query_ctx: &QueryContextRef, +) -> TonicResult<()> { + let Some(user_provider) = user_provider else { return Ok(()) }; + + let user_info = match request_header + .context(NotFoundAuthHeaderSnafu)? + .clone() + .authorization + .context(NotFoundAuthHeaderSnafu)? + .auth_scheme + .context(NotFoundAuthHeaderSnafu)? + { + AuthScheme::Basic(Basic { username, password }) => user_provider + .authenticate( + Identity::UserId(&username, None), + crate::auth::Password::PlainText(&password), + ) + .await + .map_err(|e| Auth { source: e }), + AuthScheme::Token(_) => UnsupportedAuthSchemeSnafu { + name: "Token AuthScheme", + } + .fail(), + } + .map_err(|e| Status::unauthenticated(e.to_string()))?; + + user_provider + .authorize( + &query_ctx.current_catalog(), + &query_ctx.current_schema(), + &user_info, + ) + .await + .map_err(|e| Status::permission_denied(e.to_string())) +} diff --git a/src/servers/src/http/authorize.rs b/src/servers/src/http/authorize.rs index eef589ee4d..d438fe4e1a 100644 --- a/src/servers/src/http/authorize.rs +++ b/src/servers/src/http/authorize.rs @@ -194,9 +194,9 @@ async fn authenticate( get_influxdb_credentials(request)?.context(NotFoundInfluxAuthSnafu)? } else { // normal http auth - let (scheme, credential) = auth_header(request)?; + let scheme = auth_header(request)?; match scheme { - AuthScheme::Basic => decode_basic(credential)?, + AuthScheme::Basic(username, password) => (username, password), } }; @@ -219,15 +219,27 @@ where #[derive(Debug)] pub enum AuthScheme { - Basic, + Basic(Username, Password), } +type Username = String; +type Password = String; + impl TryFrom<&str> for AuthScheme { type Error = error::Error; fn try_from(value: &str) -> Result { - match value.to_lowercase().as_str() { - "basic" => Ok(AuthScheme::Basic), + let (scheme, encoded_credentials) = value + .split_once(' ') + .context(InvalidAuthorizationHeaderSnafu)?; + ensure!( + !encoded_credentials.contains(' '), + InvalidAuthorizationHeaderSnafu + ); + + match scheme.to_lowercase().as_str() { + "basic" => decode_basic(encoded_credentials) + .map(|(username, password)| AuthScheme::Basic(username, password)), other => UnsupportedAuthSchemeSnafu { name: other }.fail(), } } @@ -235,7 +247,7 @@ impl TryFrom<&str> for AuthScheme { type Credential<'a> = &'a str; -fn auth_header(req: &Request) -> Result<(AuthScheme, Credential)> { +fn auth_header(req: &Request) -> Result { let auth_header = req .headers() .get(http::header::AUTHORIZATION) @@ -243,20 +255,9 @@ fn auth_header(req: &Request) -> Result<(AuthScheme, Credential)> { .to_str() .context(InvisibleASCIISnafu)?; - let (auth_scheme, encoded_credentials) = auth_header - .split_once(' ') - .context(InvalidAuthorizationHeaderSnafu)?; - - if encoded_credentials.contains(' ') { - return InvalidAuthorizationHeaderSnafu {}.fail(); - } - - Ok((auth_scheme.try_into()?, encoded_credentials)) + auth_header.try_into() } -type Username = String; -type Password = String; - fn decode_basic(credential: Credential) -> Result<(Username, Password)> { let decoded = base64::decode(credential).context(error::InvalidBase64ValueSnafu)?; let as_utf8 = String::from_utf8(decoded).context(error::InvalidUtf8ValueSnafu)?; @@ -324,8 +325,12 @@ mod tests { #[test] fn test_try_into_auth_scheme() { let auth_scheme_str = "basic"; - let auth_scheme: AuthScheme = auth_scheme_str.try_into().unwrap(); - matches!(auth_scheme, AuthScheme::Basic); + let re: Result = auth_scheme_str.try_into(); + assert!(re.is_err()); + + let auth_scheme_str = "basic dGVzdDp0ZXN0"; + let scheme: AuthScheme = auth_scheme_str.try_into().unwrap(); + matches!(scheme, AuthScheme::Basic(username, pwd) if username == "test" && pwd == "test"); let unsupported = "digest"; let auth_scheme: Result = unsupported.try_into(); @@ -337,9 +342,8 @@ mod tests { // base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ=" let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap(); - let (auth_scheme, credential) = auth_header(&req).unwrap(); - matches!(auth_scheme, AuthScheme::Basic); - assert_eq!("dXNlcm5hbWU6cGFzc3dvcmQ=", credential); + let auth_scheme = auth_header(&req).unwrap(); + matches!(auth_scheme, AuthScheme::Basic(username, pwd) if username == "username" && pwd == "password"); let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6 cGFzc3dvcmQ="), None).unwrap(); let res = auth_header(&wrong_req); diff --git a/src/servers/tests/grpc/mod.rs b/src/servers/tests/grpc/mod.rs new file mode 100644 index 0000000000..c2b943dfba --- /dev/null +++ b/src/servers/tests/grpc/mod.rs @@ -0,0 +1,141 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::net::SocketAddr; +use std::sync::Arc; + +use api::v1::auth_header::AuthScheme; +use api::v1::Basic; +use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; +use async_trait::async_trait; +use client::{Client, Database}; +use common_runtime::{Builder as RuntimeBuilder, Runtime}; +use servers::auth::UserProviderRef; +use servers::error::{Result, StartGrpcSnafu, TcpBindSnafu}; +use servers::grpc::flight::FlightHandler; +use servers::query_handler::grpc::ServerGrpcQueryHandlerRef; +use servers::server::Server; +use snafu::ResultExt; +use table::test_util::MemTable; +use tokio::net::TcpListener; +use tokio_stream::wrappers::TcpListenerStream; + +use crate::auth::MockUserProvider; +use crate::{create_testing_grpc_query_handler, LOCALHOST_WITH_0}; + +struct MockGrpcServer { + query_handler: ServerGrpcQueryHandlerRef, + user_provider: Option, + runtime: Arc, +} + +impl MockGrpcServer { + fn new( + query_handler: ServerGrpcQueryHandlerRef, + user_provider: Option, + runtime: Arc, + ) -> Self { + Self { + query_handler, + user_provider, + runtime, + } + } + + fn create_service(&self) -> FlightServiceServer { + let service = FlightHandler::new( + self.query_handler.clone(), + self.user_provider.clone(), + self.runtime.clone(), + ); + FlightServiceServer::new(service) + } +} + +#[async_trait] +impl Server for MockGrpcServer { + async fn shutdown(&self) -> Result<()> { + Ok(()) + } + + async fn start(&self, addr: SocketAddr) -> Result { + let (listener, addr) = { + let listener = TcpListener::bind(addr) + .await + .context(TcpBindSnafu { addr })?; + let addr = listener.local_addr().context(TcpBindSnafu { addr })?; + (listener, addr) + }; + + let service = self.create_service(); + // Would block to serve requests. + tokio::spawn(async move { + tonic::transport::Server::builder() + .add_service(service) + .serve_with_incoming(TcpListenerStream::new(listener)) + .await + .context(StartGrpcSnafu) + .unwrap() + }); + + Ok(addr) + } +} + +fn create_grpc_server(table: MemTable) -> Result> { + let query_handler = create_testing_grpc_query_handler(table); + let io_runtime = Arc::new( + RuntimeBuilder::default() + .worker_threads(4) + .thread_name("grpc-io-handlers") + .build() + .unwrap(), + ); + + let provider = MockUserProvider::default(); + + Ok(Arc::new(MockGrpcServer::new( + query_handler, + Some(Arc::new(provider)), + io_runtime, + ))) +} + +#[tokio::test] +async fn test_grpc_server_startup() { + let server = create_grpc_server(MemTable::default_numbers_table()).unwrap(); + let re = server.start(LOCALHOST_WITH_0.parse().unwrap()).await; + assert!(re.is_ok()); +} + +#[tokio::test] +async fn test_grpc_query() { + let server = create_grpc_server(MemTable::default_numbers_table()).unwrap(); + let re = server.start(LOCALHOST_WITH_0.parse().unwrap()).await; + assert!(re.is_ok()); + + let grpc_client = Client::with_urls(vec![re.unwrap().to_string()]); + let mut db = Database::with_client(grpc_client); + + let re = db.sql("select * from numbers").await; + assert!(re.is_err()); + + let greptime = "greptime".to_string(); + db.set_auth(AuthScheme::Basic(Basic { + username: greptime.clone(), + password: greptime.clone(), + })); + let re = db.sql("select * from numbers").await; + assert!(re.is_ok()); +} diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index 055f95ff2f..ac181cd256 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -15,6 +15,8 @@ use std::collections::HashMap; use std::sync::{Arc, RwLock}; +use api::v1::greptime_request::{Request as GreptimeRequest, Request}; +use api::v1::query_request::Query; use async_trait::async_trait; use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider}; use catalog::{CatalogList, CatalogProvider, SchemaProvider}; @@ -25,14 +27,17 @@ use query::parser::{PromQuery, QueryLanguageParser, QueryStatement}; use query::{QueryEngineFactory, QueryEngineRef}; use script::engine::{CompileContext, EvalContext, Script, ScriptEngine}; use script::python::{PyEngine, PyScript}; -use servers::error::{Error, Result}; +use servers::error::{Error, NotSupportedSnafu, Result}; +use servers::query_handler::grpc::{GrpcQueryHandler, ServerGrpcQueryHandlerRef}; use servers::query_handler::sql::{ServerSqlQueryHandlerRef, SqlQueryHandler}; use servers::query_handler::{ScriptHandler, ScriptHandlerRef}; use session::context::QueryContextRef; +use snafu::ensure; use sql::statements::statement::Statement; use table::test_util::MemTable; mod auth; +mod grpc; mod http; mod interceptor; mod mysql; @@ -40,6 +45,8 @@ mod opentsdb; mod postgres; mod py_script; +const LOCALHOST_WITH_0: &str = "127.0.0.1:0"; + struct DummyInstance { query_engine: QueryEngineRef, py_engine: Arc, @@ -80,7 +87,7 @@ impl SqlQueryHandler for DummyInstance { async fn do_statement_query( &self, - _stmt: sql::statements::statement::Statement, + _stmt: Statement, _query_ctx: QueryContextRef, ) -> Result { unimplemented!() @@ -137,6 +144,39 @@ impl ScriptHandler for DummyInstance { } } +#[async_trait] +impl GrpcQueryHandler for DummyInstance { + type Error = Error; + + async fn do_query( + &self, + request: GreptimeRequest, + ctx: QueryContextRef, + ) -> std::result::Result { + let output = match request { + Request::Insert(_) => unimplemented!(), + Request::Query(query_request) => { + let query = query_request.query.unwrap(); + match query { + Query::Sql(sql) => { + let mut result = SqlQueryHandler::do_query(self, &sql, ctx).await; + ensure!( + result.len() == 1, + NotSupportedSnafu { + feat: "execute multiple statements in SQL query string through GRPC interface" + } + ); + result.remove(0)? + } + Query::LogicalPlan(_) => unimplemented!(), + } + } + Request::Ddl(_) => unimplemented!(), + }; + Ok(output) + } +} + fn create_testing_instance(table: MemTable) -> DummyInstance { let table_name = table.table_name().to_string(); let table = Arc::new(table); @@ -164,3 +204,7 @@ fn create_testing_script_handler(table: MemTable) -> ScriptHandlerRef { fn create_testing_sql_query_handler(table: MemTable) -> ServerSqlQueryHandlerRef { Arc::new(create_testing_instance(table)) as _ } + +fn create_testing_grpc_query_handler(table: MemTable) -> ServerGrpcQueryHandlerRef { + Arc::new(create_testing_instance(table)) as _ +} diff --git a/tests-integration/src/test_util.rs b/tests-integration/src/test_util.rs index 6b5ee9e1c4..517050e1eb 100644 --- a/tests-integration/src/test_util.rs +++ b/tests-integration/src/test_util.rs @@ -350,6 +350,7 @@ pub async fn setup_grpc_server( let fe_instance_ref = Arc::new(fe_instance); let fe_grpc_server = Arc::new(GrpcServer::new( ServerGrpcQueryHandlerAdaptor::arc(fe_instance_ref), + None, runtime, )); let grpc_server_clone = fe_grpc_server.clone();