feat: add auth to grpc handler (#1051)

* chore: get header in grpc & temp save

* chore: change authscheme to include data str

* chore: add auth to grpc flight handler

* chore: add unit test & hold for now since grpc api doesnt accept req input

* chore: minor change

* chore: minor change

* chore: add flight context to database interface

* chore: add test

* chore: update proto version & fix cr issue

* chore: add test

* chore: minor update
This commit is contained in:
shuiyisong
2023-02-22 15:20:10 +08:00
committed by GitHub
parent 390e9095f6
commit fb2e0c7cf3
15 changed files with 357 additions and 45 deletions

3
Cargo.lock generated
View File

@@ -2975,7 +2975,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b"
[[package]]
name = "greptime-proto"
version = "0.1.0"
source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=3e6349be127b65a8b42a38cda9d527ec423ca77d#3e6349be127b65a8b42a38cda9d527ec423ca77d"
source = "git+https://github.com/GreptimeTeam/greptime-proto.git?rev=1599ae2a0d1d8f42ee23ed26e4ad7a7b34134c60#1599ae2a0d1d8f42ee23ed26e4ad7a7b34134c60"
dependencies = [
"prost 0.11.6",
"tonic",
@@ -6685,6 +6685,7 @@ dependencies = [
"bytes",
"catalog",
"chrono",
"client",
"common-base",
"common-catalog",
"common-error",

View File

@@ -10,7 +10,7 @@ common-base = { path = "../common/base" }
common-error = { path = "../common/error" }
common-time = { path = "../common/time" }
datatypes = { path = "../datatypes" }
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "3e6349be127b65a8b42a38cda9d527ec423ca77d" }
greptime-proto = { git = "https://github.com/GreptimeTeam/greptime-proto.git", rev = "1599ae2a0d1d8f42ee23ed26e4ad7a7b34134c60" }
prost.workspace = true
snafu = { version = "0.7", features = ["backtraces"] }
tonic.workspace = true

View File

@@ -14,12 +14,13 @@
use std::str::FromStr;
use api::v1::auth_header::AuthScheme;
use api::v1::ddl_request::Expr as DdlExpr;
use api::v1::greptime_request::Request;
use api::v1::query_request::Query;
use api::v1::{
AlterExpr, CreateTableExpr, DdlRequest, DropTableExpr, GreptimeRequest, InsertRequest,
QueryRequest, RequestHeader,
AlterExpr, AuthHeader, CreateTableExpr, DdlRequest, DropTableExpr, GreptimeRequest,
InsertRequest, QueryRequest, RequestHeader,
};
use arrow_flight::{FlightData, Ticket};
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
@@ -42,6 +43,7 @@ pub struct Database {
schema: String,
client: Client,
ctx: FlightContext,
}
impl Database {
@@ -50,6 +52,7 @@ impl Database {
catalog: catalog.into(),
schema: schema.into(),
client,
ctx: FlightContext::default(),
}
}
@@ -61,6 +64,12 @@ impl Database {
self.schema = schema.into();
}
pub fn set_auth(&mut self, auth: AuthScheme) {
self.ctx.auth_header = Some(AuthHeader {
auth_scheme: Some(auth),
});
}
pub async fn insert(&self, request: InsertRequest) -> Result<Output> {
self.do_get(Request::Insert(request)).await
}
@@ -105,6 +114,7 @@ impl Database {
header: Some(RequestHeader {
catalog: self.catalog.clone(),
schema: self.schema.clone(),
authorization: self.ctx.auth_header.clone(),
}),
request: Some(request),
};
@@ -164,12 +174,18 @@ fn get_metadata_value(e: &tonic::Status, key: &str) -> Option<String> {
.and_then(|v| String::from_utf8(v.as_bytes().to_vec()).ok())
}
#[derive(Default, Debug, Clone)]
pub struct FlightContext {
auth_header: Option<AuthHeader>,
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use api::helper::ColumnDataTypeWrapper;
use api::v1::Column;
use api::v1::auth_header::AuthScheme;
use api::v1::{AuthHeader, Basic, Column};
use common_grpc::select::{null_mask, values};
use common_grpc_expr::column_to_vector;
use datatypes::prelude::{Vector, VectorRef};
@@ -179,6 +195,8 @@ mod tests {
UInt32Vector, UInt64Vector, UInt8Vector,
};
use crate::database::FlightContext;
#[test]
fn test_column_to_vector() {
let mut column = create_test_column(Arc::new(BooleanVector::from(vec![true])));
@@ -262,4 +280,26 @@ mod tests {
datatype: wrapper.datatype() as i32,
}
}
#[test]
fn test_flight_ctx() {
let mut ctx = FlightContext::default();
assert!(ctx.auth_header.is_none());
let basic = AuthScheme::Basic(Basic {
username: "u".to_string(),
password: "p".to_string(),
});
ctx.auth_header = Some(AuthHeader {
auth_scheme: Some(basic),
});
assert!(matches!(
ctx.auth_header,
Some(AuthHeader {
auth_scheme: Some(AuthScheme::Basic(_)),
})
))
}
}

View File

@@ -89,6 +89,7 @@ impl Services {
Ok(Self {
grpc_server: GrpcServer::new(
ServerGrpcQueryHandlerAdaptor::arc(instance),
None,
grpc_runtime,
),
mysql_server,

View File

@@ -66,6 +66,7 @@ impl Services {
let grpc_server = GrpcServer::new(
ServerGrpcQueryHandlerAdaptor::arc(instance.clone()),
user_provider.clone(),
grpc_runtime,
);

View File

@@ -114,6 +114,7 @@ pub(crate) async fn create_datanode_client(
// https://github.com/hyperium/tonic/blob/master/examples/src/mock/mock.rs
let datanode_service = GrpcServer::new(
ServerGrpcQueryHandlerAdaptor::arc(datanode_instance),
None,
runtime,
)
.create_service();

View File

@@ -18,11 +18,11 @@ pub mod kv;
pub mod memory;
use api::v1::meta::{
store_server, BatchPutRequest, BatchPutResponse, CompareAndPutRequest, CompareAndPutResponse,
DeleteRangeRequest, DeleteRangeResponse, MoveValueRequest, MoveValueResponse, PutRequest,
PutResponse, RangeRequest, RangeResponse,
store_server, BatchGetRequest, BatchGetResponse, BatchPutRequest, BatchPutResponse,
CompareAndPutRequest, CompareAndPutResponse, DeleteRangeRequest, DeleteRangeResponse,
MoveValueRequest, MoveValueResponse, PutRequest, PutResponse, RangeRequest, RangeResponse,
};
use tonic::{Request, Response};
use tonic::{Request, Response, Status};
use crate::metasrv::MetaSrv;
use crate::service::GrpcResult;
@@ -43,6 +43,14 @@ impl store_server::Store for MetaSrv {
Ok(Response::new(res))
}
async fn batch_get(
&self,
_request: Request<BatchGetRequest>,
) -> Result<Response<BatchGetResponse>, Status> {
// TODO(fys): please fix this
unimplemented!()
}
async fn batch_put(&self, req: Request<BatchPutRequest>) -> GrpcResult<BatchPutResponse> {
let req = req.into_inner();
let res = self.kv_store().batch_put(req).await?;

View File

@@ -69,6 +69,7 @@ tower-http = { version = "0.3", features = ["full"] }
[dev-dependencies]
axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", branch = "patch-1" }
client = { path = "../client" }
common-base = { path = "../common/base" }
mysql_async = { version = "0.31", default-features = false, features = [
"default-rustls",

View File

@@ -22,7 +22,6 @@ use axum::Json;
use base64::DecodeError;
use catalog;
use common_error::prelude::*;
use hyper::header::ToStrError;
use serde_json::json;
use tonic::codegen::http::{HeaderMap, HeaderValue};
use tonic::metadata::MetadataMap;
@@ -150,7 +149,7 @@ pub enum Error {
#[snafu(display("Invalid OpenTSDB line, source: {}", source))]
InvalidOpentsdbLine {
source: std::string::FromUtf8Error,
source: FromUtf8Error,
backtrace: Backtrace,
},
@@ -216,7 +215,7 @@ pub enum Error {
source: auth::Error,
},
#[snafu(display("Not found http authorization header"))]
#[snafu(display("Not found http or grpc authorization header"))]
NotFoundAuthHeader {},
#[snafu(display("Not found influx http authorization info"))]
@@ -224,7 +223,7 @@ pub enum Error {
#[snafu(display("Invalid visibility ASCII chars, source: {}", source))]
InvisibleASCII {
source: ToStrError,
source: hyper::header::ToStrError,
backtrace: Backtrace,
},

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
mod flight;
pub mod flight;
use std::net::SocketAddr;
use std::sync::Arc;
@@ -28,6 +28,7 @@ use tokio::sync::oneshot::{self, Sender};
use tokio::sync::Mutex;
use tokio_stream::wrappers::TcpListenerStream;
use crate::auth::UserProviderRef;
use crate::error::{AlreadyStartedSnafu, Result, StartGrpcSnafu, TcpBindSnafu};
use crate::grpc::flight::FlightHandler;
use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
@@ -35,21 +36,31 @@ use crate::server::Server;
pub struct GrpcServer {
query_handler: ServerGrpcQueryHandlerRef,
user_provider: Option<UserProviderRef>,
shutdown_tx: Mutex<Option<Sender<()>>>,
runtime: Arc<Runtime>,
}
impl GrpcServer {
pub fn new(query_handler: ServerGrpcQueryHandlerRef, runtime: Arc<Runtime>) -> Self {
pub fn new(
query_handler: ServerGrpcQueryHandlerRef,
user_provider: Option<UserProviderRef>,
runtime: Arc<Runtime>,
) -> Self {
Self {
query_handler,
user_provider,
shutdown_tx: Mutex::new(None),
runtime,
}
}
pub fn create_service(&self) -> FlightServiceServer<impl FlightService> {
let service = FlightHandler::new(self.query_handler.clone(), self.runtime.clone());
let service = FlightHandler::new(
self.query_handler.clone(),
self.user_provider.clone(),
self.runtime.clone(),
);
FlightServiceServer::new(service)
}
}

View File

@@ -17,7 +17,8 @@ mod stream;
use std::pin::Pin;
use std::sync::Arc;
use api::v1::{GreptimeRequest, RequestHeader};
use api::v1::auth_header::AuthScheme;
use api::v1::{Basic, GreptimeRequest, RequestHeader};
use arrow_flight::flight_service_server::FlightService;
use arrow_flight::{
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
@@ -33,21 +34,33 @@ use session::context::{QueryContext, QueryContextRef};
use snafu::{OptionExt, ResultExt};
use tonic::{Request, Response, Status, Streaming};
use crate::auth::{Identity, UserProviderRef};
use crate::error;
use crate::error::Error::Auth;
use crate::error::{NotFoundAuthHeaderSnafu, UnsupportedAuthSchemeSnafu};
use crate::grpc::flight::stream::FlightRecordBatchStream;
use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
type TonicResult<T> = Result<T, Status>;
type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + Sync + 'static>>;
pub(crate) struct FlightHandler {
pub struct FlightHandler {
handler: ServerGrpcQueryHandlerRef,
user_provider: Option<UserProviderRef>,
runtime: Arc<Runtime>,
}
impl FlightHandler {
pub(crate) fn new(handler: ServerGrpcQueryHandlerRef, runtime: Arc<Runtime>) -> Self {
Self { handler, runtime }
pub fn new(
handler: ServerGrpcQueryHandlerRef,
user_provider: Option<UserProviderRef>,
runtime: Arc<Runtime>,
) -> Self {
Self {
handler,
user_provider,
runtime,
}
}
}
@@ -97,6 +110,13 @@ impl FlightService for FlightHandler {
})?;
let query_ctx = create_query_context(request.header.as_ref());
auth(
self.user_provider.as_ref(),
request.header.as_ref(),
&query_ctx,
)
.await?;
let handler = self.handler.clone();
// Executes requests in another runtime to
@@ -189,3 +209,42 @@ fn create_query_context(header: Option<&RequestHeader>) -> QueryContextRef {
};
ctx
}
async fn auth(
user_provider: Option<&UserProviderRef>,
request_header: Option<&RequestHeader>,
query_ctx: &QueryContextRef,
) -> TonicResult<()> {
let Some(user_provider) = user_provider else { return Ok(()) };
let user_info = match request_header
.context(NotFoundAuthHeaderSnafu)?
.clone()
.authorization
.context(NotFoundAuthHeaderSnafu)?
.auth_scheme
.context(NotFoundAuthHeaderSnafu)?
{
AuthScheme::Basic(Basic { username, password }) => user_provider
.authenticate(
Identity::UserId(&username, None),
crate::auth::Password::PlainText(&password),
)
.await
.map_err(|e| Auth { source: e }),
AuthScheme::Token(_) => UnsupportedAuthSchemeSnafu {
name: "Token AuthScheme",
}
.fail(),
}
.map_err(|e| Status::unauthenticated(e.to_string()))?;
user_provider
.authorize(
&query_ctx.current_catalog(),
&query_ctx.current_schema(),
&user_info,
)
.await
.map_err(|e| Status::permission_denied(e.to_string()))
}

View File

@@ -194,9 +194,9 @@ async fn authenticate<B: Send + Sync + 'static>(
get_influxdb_credentials(request)?.context(NotFoundInfluxAuthSnafu)?
} else {
// normal http auth
let (scheme, credential) = auth_header(request)?;
let scheme = auth_header(request)?;
match scheme {
AuthScheme::Basic => decode_basic(credential)?,
AuthScheme::Basic(username, password) => (username, password),
}
};
@@ -219,15 +219,27 @@ where
#[derive(Debug)]
pub enum AuthScheme {
Basic,
Basic(Username, Password),
}
type Username = String;
type Password = String;
impl TryFrom<&str> for AuthScheme {
type Error = error::Error;
fn try_from(value: &str) -> Result<Self> {
match value.to_lowercase().as_str() {
"basic" => Ok(AuthScheme::Basic),
let (scheme, encoded_credentials) = value
.split_once(' ')
.context(InvalidAuthorizationHeaderSnafu)?;
ensure!(
!encoded_credentials.contains(' '),
InvalidAuthorizationHeaderSnafu
);
match scheme.to_lowercase().as_str() {
"basic" => decode_basic(encoded_credentials)
.map(|(username, password)| AuthScheme::Basic(username, password)),
other => UnsupportedAuthSchemeSnafu { name: other }.fail(),
}
}
@@ -235,7 +247,7 @@ impl TryFrom<&str> for AuthScheme {
type Credential<'a> = &'a str;
fn auth_header<B>(req: &Request<B>) -> Result<(AuthScheme, Credential)> {
fn auth_header<B>(req: &Request<B>) -> Result<AuthScheme> {
let auth_header = req
.headers()
.get(http::header::AUTHORIZATION)
@@ -243,20 +255,9 @@ fn auth_header<B>(req: &Request<B>) -> Result<(AuthScheme, Credential)> {
.to_str()
.context(InvisibleASCIISnafu)?;
let (auth_scheme, encoded_credentials) = auth_header
.split_once(' ')
.context(InvalidAuthorizationHeaderSnafu)?;
if encoded_credentials.contains(' ') {
return InvalidAuthorizationHeaderSnafu {}.fail();
}
Ok((auth_scheme.try_into()?, encoded_credentials))
auth_header.try_into()
}
type Username = String;
type Password = String;
fn decode_basic(credential: Credential) -> Result<(Username, Password)> {
let decoded = base64::decode(credential).context(error::InvalidBase64ValueSnafu)?;
let as_utf8 = String::from_utf8(decoded).context(error::InvalidUtf8ValueSnafu)?;
@@ -324,8 +325,12 @@ mod tests {
#[test]
fn test_try_into_auth_scheme() {
let auth_scheme_str = "basic";
let auth_scheme: AuthScheme = auth_scheme_str.try_into().unwrap();
matches!(auth_scheme, AuthScheme::Basic);
let re: Result<AuthScheme> = auth_scheme_str.try_into();
assert!(re.is_err());
let auth_scheme_str = "basic dGVzdDp0ZXN0";
let scheme: AuthScheme = auth_scheme_str.try_into().unwrap();
matches!(scheme, AuthScheme::Basic(username, pwd) if username == "test" && pwd == "test");
let unsupported = "digest";
let auth_scheme: Result<AuthScheme> = unsupported.try_into();
@@ -337,9 +342,8 @@ mod tests {
// base64encode("username:password") == "dXNlcm5hbWU6cGFzc3dvcmQ="
let req = mock_http_request(Some("Basic dXNlcm5hbWU6cGFzc3dvcmQ="), None).unwrap();
let (auth_scheme, credential) = auth_header(&req).unwrap();
matches!(auth_scheme, AuthScheme::Basic);
assert_eq!("dXNlcm5hbWU6cGFzc3dvcmQ=", credential);
let auth_scheme = auth_header(&req).unwrap();
matches!(auth_scheme, AuthScheme::Basic(username, pwd) if username == "username" && pwd == "password");
let wrong_req = mock_http_request(Some("Basic dXNlcm5hbWU6 cGFzc3dvcmQ="), None).unwrap();
let res = auth_header(&wrong_req);

View File

@@ -0,0 +1,141 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::net::SocketAddr;
use std::sync::Arc;
use api::v1::auth_header::AuthScheme;
use api::v1::Basic;
use arrow_flight::flight_service_server::{FlightService, FlightServiceServer};
use async_trait::async_trait;
use client::{Client, Database};
use common_runtime::{Builder as RuntimeBuilder, Runtime};
use servers::auth::UserProviderRef;
use servers::error::{Result, StartGrpcSnafu, TcpBindSnafu};
use servers::grpc::flight::FlightHandler;
use servers::query_handler::grpc::ServerGrpcQueryHandlerRef;
use servers::server::Server;
use snafu::ResultExt;
use table::test_util::MemTable;
use tokio::net::TcpListener;
use tokio_stream::wrappers::TcpListenerStream;
use crate::auth::MockUserProvider;
use crate::{create_testing_grpc_query_handler, LOCALHOST_WITH_0};
struct MockGrpcServer {
query_handler: ServerGrpcQueryHandlerRef,
user_provider: Option<UserProviderRef>,
runtime: Arc<Runtime>,
}
impl MockGrpcServer {
fn new(
query_handler: ServerGrpcQueryHandlerRef,
user_provider: Option<UserProviderRef>,
runtime: Arc<Runtime>,
) -> Self {
Self {
query_handler,
user_provider,
runtime,
}
}
fn create_service(&self) -> FlightServiceServer<impl FlightService> {
let service = FlightHandler::new(
self.query_handler.clone(),
self.user_provider.clone(),
self.runtime.clone(),
);
FlightServiceServer::new(service)
}
}
#[async_trait]
impl Server for MockGrpcServer {
async fn shutdown(&self) -> Result<()> {
Ok(())
}
async fn start(&self, addr: SocketAddr) -> Result<SocketAddr> {
let (listener, addr) = {
let listener = TcpListener::bind(addr)
.await
.context(TcpBindSnafu { addr })?;
let addr = listener.local_addr().context(TcpBindSnafu { addr })?;
(listener, addr)
};
let service = self.create_service();
// Would block to serve requests.
tokio::spawn(async move {
tonic::transport::Server::builder()
.add_service(service)
.serve_with_incoming(TcpListenerStream::new(listener))
.await
.context(StartGrpcSnafu)
.unwrap()
});
Ok(addr)
}
}
fn create_grpc_server(table: MemTable) -> Result<Arc<dyn Server>> {
let query_handler = create_testing_grpc_query_handler(table);
let io_runtime = Arc::new(
RuntimeBuilder::default()
.worker_threads(4)
.thread_name("grpc-io-handlers")
.build()
.unwrap(),
);
let provider = MockUserProvider::default();
Ok(Arc::new(MockGrpcServer::new(
query_handler,
Some(Arc::new(provider)),
io_runtime,
)))
}
#[tokio::test]
async fn test_grpc_server_startup() {
let server = create_grpc_server(MemTable::default_numbers_table()).unwrap();
let re = server.start(LOCALHOST_WITH_0.parse().unwrap()).await;
assert!(re.is_ok());
}
#[tokio::test]
async fn test_grpc_query() {
let server = create_grpc_server(MemTable::default_numbers_table()).unwrap();
let re = server.start(LOCALHOST_WITH_0.parse().unwrap()).await;
assert!(re.is_ok());
let grpc_client = Client::with_urls(vec![re.unwrap().to_string()]);
let mut db = Database::with_client(grpc_client);
let re = db.sql("select * from numbers").await;
assert!(re.is_err());
let greptime = "greptime".to_string();
db.set_auth(AuthScheme::Basic(Basic {
username: greptime.clone(),
password: greptime.clone(),
}));
let re = db.sql("select * from numbers").await;
assert!(re.is_ok());
}

View File

@@ -15,6 +15,8 @@
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use api::v1::greptime_request::{Request as GreptimeRequest, Request};
use api::v1::query_request::Query;
use async_trait::async_trait;
use catalog::local::{MemoryCatalogManager, MemoryCatalogProvider, MemorySchemaProvider};
use catalog::{CatalogList, CatalogProvider, SchemaProvider};
@@ -25,14 +27,17 @@ use query::parser::{PromQuery, QueryLanguageParser, QueryStatement};
use query::{QueryEngineFactory, QueryEngineRef};
use script::engine::{CompileContext, EvalContext, Script, ScriptEngine};
use script::python::{PyEngine, PyScript};
use servers::error::{Error, Result};
use servers::error::{Error, NotSupportedSnafu, Result};
use servers::query_handler::grpc::{GrpcQueryHandler, ServerGrpcQueryHandlerRef};
use servers::query_handler::sql::{ServerSqlQueryHandlerRef, SqlQueryHandler};
use servers::query_handler::{ScriptHandler, ScriptHandlerRef};
use session::context::QueryContextRef;
use snafu::ensure;
use sql::statements::statement::Statement;
use table::test_util::MemTable;
mod auth;
mod grpc;
mod http;
mod interceptor;
mod mysql;
@@ -40,6 +45,8 @@ mod opentsdb;
mod postgres;
mod py_script;
const LOCALHOST_WITH_0: &str = "127.0.0.1:0";
struct DummyInstance {
query_engine: QueryEngineRef,
py_engine: Arc<PyEngine>,
@@ -80,7 +87,7 @@ impl SqlQueryHandler for DummyInstance {
async fn do_statement_query(
&self,
_stmt: sql::statements::statement::Statement,
_stmt: Statement,
_query_ctx: QueryContextRef,
) -> Result<Output> {
unimplemented!()
@@ -137,6 +144,39 @@ impl ScriptHandler for DummyInstance {
}
}
#[async_trait]
impl GrpcQueryHandler for DummyInstance {
type Error = Error;
async fn do_query(
&self,
request: GreptimeRequest,
ctx: QueryContextRef,
) -> std::result::Result<Output, Self::Error> {
let output = match request {
Request::Insert(_) => unimplemented!(),
Request::Query(query_request) => {
let query = query_request.query.unwrap();
match query {
Query::Sql(sql) => {
let mut result = SqlQueryHandler::do_query(self, &sql, ctx).await;
ensure!(
result.len() == 1,
NotSupportedSnafu {
feat: "execute multiple statements in SQL query string through GRPC interface"
}
);
result.remove(0)?
}
Query::LogicalPlan(_) => unimplemented!(),
}
}
Request::Ddl(_) => unimplemented!(),
};
Ok(output)
}
}
fn create_testing_instance(table: MemTable) -> DummyInstance {
let table_name = table.table_name().to_string();
let table = Arc::new(table);
@@ -164,3 +204,7 @@ fn create_testing_script_handler(table: MemTable) -> ScriptHandlerRef {
fn create_testing_sql_query_handler(table: MemTable) -> ServerSqlQueryHandlerRef {
Arc::new(create_testing_instance(table)) as _
}
fn create_testing_grpc_query_handler(table: MemTable) -> ServerGrpcQueryHandlerRef {
Arc::new(create_testing_instance(table)) as _
}

View File

@@ -350,6 +350,7 @@ pub async fn setup_grpc_server(
let fe_instance_ref = Arc::new(fe_instance);
let fe_grpc_server = Arc::new(GrpcServer::new(
ServerGrpcQueryHandlerAdaptor::arc(fe_instance_ref),
None,
runtime,
));
let grpc_server_clone = fe_grpc_server.clone();