feat: implement Arrow Flight "DoPut" in Frontend (#5836)

* feat: implement Arrow Flight "DoPut" in Frontend

* support auth for "do_put"

* set request_id in DoPut requests and responses

* set "db" in request header
This commit is contained in:
LFC
2025-04-17 11:46:19 +08:00
committed by GitHub
parent fdab5d198e
commit d27b9fc3a1
26 changed files with 944 additions and 185 deletions

View File

@@ -16,6 +16,7 @@ arc-swap = "1.6"
arrow-flight.workspace = true
async-stream.workspace = true
async-trait.workspace = true
base64.workspace = true
common-catalog.workspace = true
common-error.workspace = true
common-grpc.workspace = true
@@ -25,6 +26,7 @@ common-query.workspace = true
common-recordbatch.workspace = true
common-telemetry.workspace = true
enum_dispatch = "0.3"
futures.workspace = true
futures-util.workspace = true
lazy_static.workspace = true
moka = { workspace = true, features = ["future"] }

View File

@@ -12,36 +12,49 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::pin::Pin;
use std::str::FromStr;
use api::v1::auth_header::AuthScheme;
use api::v1::ddl_request::Expr as DdlExpr;
use api::v1::greptime_database_client::GreptimeDatabaseClient;
use api::v1::greptime_request::Request;
use api::v1::query_request::Query;
use api::v1::{
AlterTableExpr, AuthHeader, CreateTableExpr, DdlRequest, GreptimeRequest, InsertRequests,
QueryRequest, RequestHeader,
AlterTableExpr, AuthHeader, Basic, CreateTableExpr, DdlRequest, GreptimeRequest,
InsertRequests, QueryRequest, RequestHeader,
};
use arrow_flight::Ticket;
use arrow_flight::{FlightData, Ticket};
use async_stream::stream;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use common_catalog::build_db_string;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_error::ext::{BoxedError, ErrorExt};
use common_grpc::flight::do_put::DoPutResponse;
use common_grpc::flight::{FlightDecoder, FlightMessage};
use common_query::Output;
use common_recordbatch::error::ExternalSnafu;
use common_recordbatch::RecordBatchStreamWrapper;
use common_telemetry::error;
use common_telemetry::tracing_context::W3cTrace;
use futures_util::StreamExt;
use futures::future;
use futures_util::{Stream, StreamExt, TryStreamExt};
use prost::Message;
use snafu::{ensure, ResultExt};
use tonic::metadata::AsciiMetadataKey;
use tonic::metadata::{AsciiMetadataKey, MetadataValue};
use tonic::transport::Channel;
use crate::error::{
ConvertFlightDataSnafu, Error, FlightGetSnafu, IllegalFlightMessagesSnafu, InvalidAsciiSnafu,
ServerSnafu,
InvalidTonicMetadataValueSnafu, ServerSnafu,
};
use crate::{from_grpc_response, Client, Result};
type FlightDataStream = Pin<Box<dyn Stream<Item = FlightData> + Send>>;
type DoPutResponseStream = Pin<Box<dyn Stream<Item = Result<DoPutResponse>>>>;
#[derive(Clone, Debug, Default)]
pub struct Database {
// The "catalog" and "schema" to be used in processing the requests at the server side.
@@ -108,16 +121,24 @@ impl Database {
self.catalog = catalog.into();
}
pub fn catalog(&self) -> &String {
&self.catalog
fn catalog_or_default(&self) -> &str {
if self.catalog.is_empty() {
DEFAULT_CATALOG_NAME
} else {
&self.catalog
}
}
pub fn set_schema(&mut self, schema: impl Into<String>) {
self.schema = schema.into();
}
pub fn schema(&self) -> &String {
&self.schema
fn schema_or_default(&self) -> &str {
if self.schema.is_empty() {
DEFAULT_SCHEMA_NAME
} else {
&self.schema
}
}
pub fn set_timezone(&mut self, timezone: impl Into<String>) {
@@ -310,6 +331,41 @@ impl Database {
}
}
}
/// Ingest a stream of [RecordBatch]es that belong to a table, using Arrow Flight's "`DoPut`"
/// method. The return value is also a stream, produces [DoPutResponse]s.
pub async fn do_put(&self, stream: FlightDataStream) -> Result<DoPutResponseStream> {
let mut request = tonic::Request::new(stream);
if let Some(AuthHeader {
auth_scheme: Some(AuthScheme::Basic(Basic { username, password })),
}) = &self.ctx.auth_header
{
let encoded = BASE64_STANDARD.encode(format!("{username}:{password}"));
let value =
MetadataValue::from_str(&encoded).context(InvalidTonicMetadataValueSnafu)?;
request.metadata_mut().insert("x-greptime-auth", value);
}
let db_to_put = if !self.dbname.is_empty() {
&self.dbname
} else {
&build_db_string(self.catalog_or_default(), self.schema_or_default())
};
request.metadata_mut().insert(
"x-greptime-db-name",
MetadataValue::from_str(db_to_put).context(InvalidTonicMetadataValueSnafu)?,
);
let mut client = self.client.make_flight_client()?;
let response = client.mut_inner().do_put(request).await?;
let response = response
.into_inner()
.map_err(Into::into)
.and_then(|x| future::ready(DoPutResponse::try_from(x).context(ConvertFlightDataSnafu)))
.boxed();
Ok(response)
}
}
#[derive(Default, Debug, Clone)]

View File

@@ -19,6 +19,7 @@ use common_error::status_code::{convert_tonic_code_to_status_code, StatusCode};
use common_error::{GREPTIME_DB_HEADER_ERROR_CODE, GREPTIME_DB_HEADER_ERROR_MSG};
use common_macro::stack_trace_debug;
use snafu::{location, Location, Snafu};
use tonic::metadata::errors::InvalidMetadataValue;
use tonic::{Code, Status};
#[derive(Snafu)]
@@ -115,6 +116,14 @@ pub enum Error {
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Invalid Tonic metadata value"))]
InvalidTonicMetadataValue {
#[snafu(source)]
error: InvalidMetadataValue,
#[snafu(implicit)]
location: Location,
},
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -135,7 +144,9 @@ impl ErrorExt for Error {
| Error::CreateTlsChannel { source, .. } => source.status_code(),
Error::IllegalGrpcClientState { .. } => StatusCode::Unexpected,
Error::InvalidAscii { .. } => StatusCode::InvalidArguments,
Error::InvalidAscii { .. } | Error::InvalidTonicMetadataValue { .. } => {
StatusCode::InvalidArguments
}
}
}

View File

@@ -16,7 +16,7 @@
mod client;
pub mod client_manager;
mod database;
pub mod database;
pub mod error;
pub mod flow;
pub mod load_balance;

View File

@@ -23,6 +23,8 @@ flatbuffers = "24"
hyper.workspace = true
lazy_static.workspace = true
prost.workspace = true
serde.workspace = true
serde_json.workspace = true
snafu.workspace = true
tokio.workspace = true
tokio-util.workspace = true

View File

@@ -97,6 +97,14 @@ pub enum Error {
#[snafu(display("Not supported: {}", feat))]
NotSupported { feat: String },
#[snafu(display("Failed to serde Json"))]
SerdeJson {
#[snafu(source)]
error: serde_json::error::Error,
#[snafu(implicit)]
location: Location,
},
}
impl ErrorExt for Error {
@@ -110,7 +118,8 @@ impl ErrorExt for Error {
Error::CreateChannel { .. }
| Error::Conversion { .. }
| Error::DecodeFlightData { .. } => StatusCode::Internal,
| Error::DecodeFlightData { .. }
| Error::SerdeJson { .. } => StatusCode::Internal,
Error::CreateRecordBatch { source, .. } => source.status_code(),
Error::ConvertArrowSchema { source, .. } => source.status_code(),

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod do_put;
use std::collections::HashMap;
use std::sync::Arc;

View File

@@ -0,0 +1,93 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use arrow_flight::PutResult;
use common_base::AffectedRows;
use serde::{Deserialize, Serialize};
use snafu::ResultExt;
use crate::error::{Error, SerdeJsonSnafu};
/// The metadata for "DoPut" requests and responses.
///
/// Currently, there's only a "request_id", for coordinating requests and responses in the streams.
/// Client can set a unique request id in this metadata, and the server will return the same id in
/// the corresponding response. In doing so, a client can know how to do with its pending requests.
#[derive(Serialize, Deserialize)]
pub struct DoPutMetadata {
request_id: i64,
}
impl DoPutMetadata {
pub fn new(request_id: i64) -> Self {
Self { request_id }
}
pub fn request_id(&self) -> i64 {
self.request_id
}
}
/// The response in the "DoPut" returned stream.
#[derive(Serialize, Deserialize)]
pub struct DoPutResponse {
/// The same "request_id" in the request; see the [DoPutMetadata].
request_id: i64,
/// The successfully ingested rows number.
affected_rows: AffectedRows,
}
impl DoPutResponse {
pub fn new(request_id: i64, affected_rows: AffectedRows) -> Self {
Self {
request_id,
affected_rows,
}
}
pub fn request_id(&self) -> i64 {
self.request_id
}
pub fn affected_rows(&self) -> AffectedRows {
self.affected_rows
}
}
impl TryFrom<PutResult> for DoPutResponse {
type Error = Error;
fn try_from(value: PutResult) -> Result<Self, Self::Error> {
serde_json::from_slice(&value.app_metadata).context(SerdeJsonSnafu)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serde_do_put_metadata() {
let serialized = r#"{"request_id":42}"#;
let metadata = serde_json::from_str::<DoPutMetadata>(serialized).unwrap();
assert_eq!(metadata.request_id(), 42);
}
#[test]
fn test_serde_do_put_response() {
let x = DoPutResponse::new(42, 88);
let serialized = serde_json::to_string(&x).unwrap();
assert_eq!(serialized, r#"{"request_id":42,"affected_rows":88}"#);
}
}

View File

@@ -18,12 +18,13 @@ use api::v1::query_request::Query;
use api::v1::{DeleteRequests, DropFlowExpr, InsertRequests, RowDeleteRequests, RowInsertRequests};
use async_trait::async_trait;
use auth::{PermissionChecker, PermissionCheckerRef, PermissionReq};
use common_base::AffectedRows;
use common_query::Output;
use common_telemetry::tracing::{self};
use datafusion::execution::SessionStateBuilder;
use query::parser::PromQuery;
use servers::interceptor::{GrpcQueryInterceptor, GrpcQueryInterceptorRef};
use servers::query_handler::grpc::GrpcQueryHandler;
use servers::query_handler::grpc::{GrpcQueryHandler, RawRecordBatch};
use servers::query_handler::sql::SqlQueryHandler;
use session::context::QueryContextRef;
use snafu::{ensure, OptionExt, ResultExt};
@@ -31,8 +32,9 @@ use substrait::{DFLogicalSubstraitConvertor, SubstraitPlan};
use table::table_name::TableName;
use crate::error::{
Error, InFlightWriteBytesExceededSnafu, IncompleteGrpcRequestSnafu, NotSupportedSnafu,
PermissionSnafu, Result, SubstraitDecodeLogicalPlanSnafu, TableOperationSnafu,
CatalogSnafu, Error, InFlightWriteBytesExceededSnafu, IncompleteGrpcRequestSnafu,
NotSupportedSnafu, PermissionSnafu, Result, SubstraitDecodeLogicalPlanSnafu,
TableNotFoundSnafu, TableOperationSnafu,
};
use crate::instance::{attach_timer, Instance};
use crate::metrics::{
@@ -203,6 +205,34 @@ impl GrpcQueryHandler for Instance {
let output = interceptor.post_execute(output, ctx)?;
Ok(output)
}
async fn put_record_batch(
&self,
table: &TableName,
record_batch: RawRecordBatch,
) -> Result<AffectedRows> {
let _table = self
.catalog_manager()
.table(
&table.catalog_name,
&table.schema_name,
&table.table_name,
None,
)
.await
.context(CatalogSnafu)?
.with_context(|| TableNotFoundSnafu {
table_name: table.to_string(),
})?;
// TODO(LFC): Implement it.
common_telemetry::debug!(
"calling put_record_batch with table: {:?} and record_batch size: {}",
table,
record_batch.len()
);
Ok(record_batch.len())
}
}
fn fill_catalog_and_schema_from_context(ddl_expr: &mut DdlExpr, ctx: &QueryContextRef) {

View File

@@ -16,6 +16,7 @@ mod stream;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use api::v1::GreptimeRequest;
use arrow_flight::flight_service_server::FlightService;
@@ -24,21 +25,33 @@ use arrow_flight::{
HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
};
use async_trait::async_trait;
use bytes::Bytes;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_catalog::parse_catalog_and_schema_from_db_string;
use common_grpc::flight::do_put::{DoPutMetadata, DoPutResponse};
use common_grpc::flight::{FlightEncoder, FlightMessage};
use common_query::{Output, OutputData};
use common_telemetry::tracing::info_span;
use common_telemetry::tracing_context::{FutureExt, TracingContext};
use futures::Stream;
use futures::{future, ready, Stream};
use futures_util::{StreamExt, TryStreamExt};
use prost::Message;
use snafu::ResultExt;
use snafu::{ensure, ResultExt};
use table::table_name::TableName;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tonic::{Request, Response, Status, Streaming};
use crate::error;
use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu};
pub use crate::grpc::flight::stream::FlightRecordBatchStream;
use crate::grpc::greptime_handler::{get_request_type, GreptimeRequestHandler};
use crate::grpc::TonicResult;
use crate::http::header::constants::GREPTIME_DB_HEADER_NAME;
use crate::http::AUTHORIZATION_HEADER;
use crate::query_handler::grpc::RawRecordBatch;
pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + Sync + 'static>>;
pub type TonicStream<T> = Pin<Box<dyn Stream<Item = TonicResult<T>> + Send + 'static>>;
/// A subset of [FlightService]
#[async_trait]
@@ -47,6 +60,14 @@ pub trait FlightCraft: Send + Sync + 'static {
&self,
request: Request<Ticket>,
) -> TonicResult<Response<TonicStream<FlightData>>>;
async fn do_put(
&self,
request: Request<Streaming<FlightData>>,
) -> TonicResult<Response<TonicStream<PutResult>>> {
let _ = request;
Err(Status::unimplemented("Not yet implemented"))
}
}
pub type FlightCraftRef = Arc<dyn FlightCraft>;
@@ -67,6 +88,13 @@ impl FlightCraft for FlightCraftRef {
) -> TonicResult<Response<TonicStream<FlightData>>> {
(**self).do_get(request).await
}
async fn do_put(
&self,
request: Request<Streaming<FlightData>>,
) -> TonicResult<Response<TonicStream<PutResult>>> {
self.as_ref().do_put(request).await
}
}
#[async_trait]
@@ -120,9 +148,9 @@ impl<T: FlightCraft> FlightService for FlightCraftWrapper<T> {
async fn do_put(
&self,
_: Request<Streaming<FlightData>>,
request: Request<Streaming<FlightData>>,
) -> TonicResult<Response<Self::DoPutStream>> {
Err(Status::unimplemented("Not yet implemented"))
self.0.do_put(request).await
}
type DoExchangeStream = TonicStream<FlightData>;
@@ -168,13 +196,164 @@ impl FlightCraft for GreptimeRequestHandler {
);
async {
let output = self.handle_request(request, Default::default()).await?;
let stream: Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send + Sync>> =
to_flight_data_stream(output, TracingContext::from_current_span());
let stream = to_flight_data_stream(output, TracingContext::from_current_span());
Ok(Response::new(stream))
}
.trace(span)
.await
}
async fn do_put(
&self,
request: Request<Streaming<FlightData>>,
) -> TonicResult<Response<TonicStream<PutResult>>> {
let (headers, _, stream) = request.into_parts();
let header = |key: &str| -> TonicResult<Option<&str>> {
let Some(v) = headers.get(key) else {
return Ok(None);
};
let Ok(v) = std::str::from_utf8(v.as_bytes()) else {
return Err(InvalidParameterSnafu {
reason: "expect valid UTF-8 value",
}
.build()
.into());
};
Ok(Some(v))
};
let username_and_password = header(AUTHORIZATION_HEADER)?;
let db = header(GREPTIME_DB_HEADER_NAME)?;
if !self.validate_auth(username_and_password, db).await? {
return Err(Status::unauthenticated("auth failed"));
}
const MAX_PENDING_RESPONSES: usize = 32;
let (tx, rx) = mpsc::channel::<TonicResult<DoPutResponse>>(MAX_PENDING_RESPONSES);
let stream = PutRecordBatchRequestStream {
flight_data_stream: stream,
state: PutRecordBatchRequestStreamState::Init(db.map(ToString::to_string)),
};
self.put_record_batches(stream, tx).await;
let response = ReceiverStream::new(rx)
.and_then(|response| {
future::ready({
serde_json::to_vec(&response)
.context(ToJsonSnafu)
.map(|x| PutResult {
app_metadata: Bytes::from(x),
})
.map_err(Into::into)
})
})
.boxed();
Ok(Response::new(response))
}
}
pub(crate) struct PutRecordBatchRequest {
pub(crate) table_name: TableName,
pub(crate) request_id: i64,
pub(crate) record_batch: RawRecordBatch,
}
impl PutRecordBatchRequest {
fn try_new(table_name: TableName, flight_data: FlightData) -> Result<Self> {
let metadata: DoPutMetadata =
serde_json::from_slice(&flight_data.app_metadata).context(ParseJsonSnafu)?;
Ok(Self {
table_name,
request_id: metadata.request_id(),
record_batch: flight_data.data_body,
})
}
}
pub(crate) struct PutRecordBatchRequestStream {
flight_data_stream: Streaming<FlightData>,
state: PutRecordBatchRequestStreamState,
}
enum PutRecordBatchRequestStreamState {
Init(Option<String>),
Started(TableName),
}
impl Stream for PutRecordBatchRequestStream {
type Item = TonicResult<PutRecordBatchRequest>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
fn extract_table_name(mut descriptor: FlightDescriptor) -> Result<String> {
ensure!(
descriptor.r#type == arrow_flight::flight_descriptor::DescriptorType::Path as i32,
InvalidParameterSnafu {
reason: "expect FlightDescriptor::type == 'Path' only",
}
);
ensure!(
descriptor.path.len() == 1,
InvalidParameterSnafu {
reason: "expect FlightDescriptor::path has only one table name",
}
);
Ok(descriptor.path.remove(0))
}
let poll = ready!(self.flight_data_stream.poll_next_unpin(cx));
let result = match &mut self.state {
PutRecordBatchRequestStreamState::Init(db) => match poll {
Some(Ok(mut flight_data)) => {
let flight_descriptor = flight_data.flight_descriptor.take();
let result = if let Some(descriptor) = flight_descriptor {
let table_name = extract_table_name(descriptor).map(|x| {
let (catalog, schema) = if let Some(db) = db {
parse_catalog_and_schema_from_db_string(db)
} else {
(
DEFAULT_CATALOG_NAME.to_string(),
DEFAULT_SCHEMA_NAME.to_string(),
)
};
TableName::new(catalog, schema, x)
});
let table_name = match table_name {
Ok(table_name) => table_name,
Err(e) => return Poll::Ready(Some(Err(e.into()))),
};
let request =
PutRecordBatchRequest::try_new(table_name.clone(), flight_data);
let request = match request {
Ok(request) => request,
Err(e) => return Poll::Ready(Some(Err(e.into()))),
};
self.state = PutRecordBatchRequestStreamState::Started(table_name);
Ok(request)
} else {
Err(Status::failed_precondition(
"table to put is not found in flight descriptor",
))
};
Some(result)
}
Some(Err(e)) => Some(Err(e)),
None => None,
},
PutRecordBatchRequestStreamState::Started(table_name) => poll.map(|x| {
x.and_then(|flight_data| {
PutRecordBatchRequest::try_new(table_name.clone(), flight_data)
.map_err(Into::into)
})
}),
};
Poll::Ready(result)
}
}
fn to_flight_data_stream(

View File

@@ -18,23 +18,33 @@ use std::time::Instant;
use api::helper::request_type;
use api::v1::auth_header::AuthScheme;
use api::v1::{Basic, GreptimeRequest, RequestHeader};
use api::v1::{AuthHeader, Basic, GreptimeRequest, RequestHeader};
use auth::{Identity, Password, UserInfoRef, UserProviderRef};
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_catalog::parse_catalog_and_schema_from_db_string;
use common_error::ext::ErrorExt;
use common_error::status_code::StatusCode;
use common_grpc::flight::do_put::DoPutResponse;
use common_query::Output;
use common_runtime::runtime::RuntimeTrait;
use common_runtime::Runtime;
use common_telemetry::tracing_context::{FutureExt, TracingContext};
use common_telemetry::{debug, error, tracing};
use common_telemetry::{debug, error, tracing, warn};
use common_time::timezone::parse_timezone;
use session::context::{QueryContextBuilder, QueryContextRef};
use futures_util::StreamExt;
use session::context::{QueryContext, QueryContextBuilder, QueryContextRef};
use snafu::{OptionExt, ResultExt};
use tokio::sync::mpsc;
use crate::error::Error::UnsupportedAuthScheme;
use crate::error::{AuthSnafu, InvalidQuerySnafu, JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result};
use crate::error::{
AuthSnafu, InvalidAuthHeaderInvalidUtf8ValueSnafu, InvalidBase64ValueSnafu, InvalidQuerySnafu,
JoinTaskSnafu, NotFoundAuthHeaderSnafu, Result,
};
use crate::grpc::flight::{PutRecordBatchRequest, PutRecordBatchRequestStream};
use crate::grpc::TonicResult;
use crate::metrics::{METRIC_AUTH_FAILURE, METRIC_SERVER_GRPC_DB_REQUEST_TIMER};
use crate::query_handler::grpc::ServerGrpcQueryHandlerRef;
@@ -118,6 +128,95 @@ impl GreptimeRequestHandler {
None => result_future.await,
}
}
pub(crate) async fn put_record_batches(
&self,
mut stream: PutRecordBatchRequestStream,
result_sender: mpsc::Sender<TonicResult<DoPutResponse>>,
) {
let handler = self.handler.clone();
let runtime = self
.runtime
.clone()
.unwrap_or_else(common_runtime::global_runtime);
runtime.spawn(async move {
while let Some(request) = stream.next().await {
let request = match request {
Ok(request) => request,
Err(e) => {
let _ = result_sender.try_send(Err(e));
break;
}
};
let PutRecordBatchRequest {
table_name,
request_id,
record_batch,
} = request;
let result = handler.put_record_batch(&table_name, record_batch).await;
let result = result
.map(|x| DoPutResponse::new(request_id, x))
.map_err(Into::into);
if result_sender.try_send(result).is_err() {
warn!(r#""DoPut" client maybe unreachable, abort handling its message"#);
break;
}
}
});
}
pub(crate) async fn validate_auth(
&self,
username_and_password: Option<&str>,
db: Option<&str>,
) -> Result<bool> {
if self.user_provider.is_none() {
return Ok(true);
}
let username_and_password = username_and_password.context(NotFoundAuthHeaderSnafu)?;
let username_and_password = BASE64_STANDARD
.decode(username_and_password)
.context(InvalidBase64ValueSnafu)
.and_then(|x| String::from_utf8(x).context(InvalidAuthHeaderInvalidUtf8ValueSnafu))?;
let mut split = username_and_password.splitn(2, ':');
let (username, password) = match (split.next(), split.next()) {
(Some(username), Some(password)) => (username, password),
(Some(username), None) => (username, ""),
(None, None) => return Ok(false),
_ => unreachable!(), // because this iterator won't yield Some after None
};
let (catalog, schema) = if let Some(db) = db {
parse_catalog_and_schema_from_db_string(db)
} else {
(
DEFAULT_CATALOG_NAME.to_string(),
DEFAULT_SCHEMA_NAME.to_string(),
)
};
let header = RequestHeader {
authorization: Some(AuthHeader {
auth_scheme: Some(AuthScheme::Basic(Basic {
username: username.to_string(),
password: password.to_string(),
})),
}),
catalog,
schema,
..Default::default()
};
Ok(auth(
self.user_provider.clone(),
Some(&header),
&QueryContext::arc(),
)
.await
.is_ok())
}
}
pub fn get_request_type(request: &GreptimeRequest) -> &'static str {

View File

@@ -1169,7 +1169,6 @@ mod test {
use std::io::Cursor;
use std::sync::Arc;
use api::v1::greptime_request::Request;
use arrow_ipc::reader::FileReader;
use arrow_schema::DataType;
use axum::handler::Handler;
@@ -1191,26 +1190,12 @@ mod test {
use super::*;
use crate::error::Error;
use crate::http::test_helpers::TestClient;
use crate::query_handler::grpc::GrpcQueryHandler;
use crate::query_handler::sql::{ServerSqlQueryHandlerAdapter, SqlQueryHandler};
struct DummyInstance {
_tx: mpsc::Sender<(String, Vec<u8>)>,
}
#[async_trait]
impl GrpcQueryHandler for DummyInstance {
type Error = Error;
async fn do_query(
&self,
_query: Request,
_ctx: QueryContextRef,
) -> std::result::Result<Output, Self::Error> {
unimplemented!()
}
}
#[async_trait]
impl SqlQueryHandler for DummyInstance {
type Error = Error;

View File

@@ -16,16 +16,20 @@ use std::sync::Arc;
use api::v1::greptime_request::Request;
use async_trait::async_trait;
use common_base::AffectedRows;
use common_error::ext::{BoxedError, ErrorExt};
use common_query::Output;
use session::context::QueryContextRef;
use snafu::ResultExt;
use table::table_name::TableName;
use crate::error::{self, Result};
pub type GrpcQueryHandlerRef<E> = Arc<dyn GrpcQueryHandler<Error = E> + Send + Sync>;
pub type ServerGrpcQueryHandlerRef = GrpcQueryHandlerRef<error::Error>;
pub type RawRecordBatch = bytes::Bytes;
#[async_trait]
pub trait GrpcQueryHandler {
type Error: ErrorExt;
@@ -35,6 +39,12 @@ pub trait GrpcQueryHandler {
query: Request,
ctx: QueryContextRef,
) -> std::result::Result<Output, Self::Error>;
async fn put_record_batch(
&self,
table: &TableName,
record_batch: RawRecordBatch,
) -> std::result::Result<AffectedRows, Self::Error>;
}
pub struct ServerGrpcQueryHandlerAdapter<E>(GrpcQueryHandlerRef<E>);
@@ -59,4 +69,16 @@ where
.map_err(BoxedError::new)
.context(error::ExecuteGrpcQuerySnafu)
}
async fn put_record_batch(
&self,
table: &TableName,
record_batch: RawRecordBatch,
) -> Result<AffectedRows> {
self.0
.put_record_batch(table, record_batch)
.await
.map_err(BoxedError::new)
.context(error::ExecuteGrpcRequestSnafu)
}
}

View File

@@ -14,7 +14,6 @@
use std::sync::Arc;
use api::v1::greptime_request::Request;
use api::v1::RowInsertRequests;
use async_trait::async_trait;
use auth::tests::{DatabaseAuthInfo, MockUserProvider};
@@ -29,7 +28,6 @@ use servers::http::header::constants::GREPTIME_DB_HEADER_NAME;
use servers::http::test_helpers::TestClient;
use servers::http::{HttpOptions, HttpServerBuilder};
use servers::influxdb::InfluxdbRequest;
use servers::query_handler::grpc::GrpcQueryHandler;
use servers::query_handler::sql::SqlQueryHandler;
use servers::query_handler::InfluxdbLineProtocolHandler;
use session::context::QueryContextRef;
@@ -39,19 +37,6 @@ struct DummyInstance {
tx: Arc<mpsc::Sender<(String, String)>>,
}
#[async_trait]
impl GrpcQueryHandler for DummyInstance {
type Error = Error;
async fn do_query(
&self,
_query: Request,
_ctx: QueryContextRef,
) -> std::result::Result<Output, Self::Error> {
unimplemented!()
}
}
#[async_trait]
impl InfluxdbLineProtocolHandler for DummyInstance {
async fn exec(&self, request: InfluxdbRequest, ctx: QueryContextRef) -> Result<Output> {

View File

@@ -14,7 +14,6 @@
use std::sync::Arc;
use api::v1::greptime_request::Request;
use async_trait::async_trait;
use axum::Router;
use common_query::Output;
@@ -26,7 +25,6 @@ use servers::error::{self, Result};
use servers::http::test_helpers::TestClient;
use servers::http::{HttpOptions, HttpServerBuilder};
use servers::opentsdb::codec::DataPoint;
use servers::query_handler::grpc::GrpcQueryHandler;
use servers::query_handler::sql::SqlQueryHandler;
use servers::query_handler::OpentsdbProtocolHandler;
use session::context::QueryContextRef;
@@ -36,19 +34,6 @@ struct DummyInstance {
tx: mpsc::Sender<String>,
}
#[async_trait]
impl GrpcQueryHandler for DummyInstance {
type Error = crate::Error;
async fn do_query(
&self,
_query: Request,
_ctx: QueryContextRef,
) -> std::result::Result<Output, Self::Error> {
unimplemented!()
}
}
#[async_trait]
impl OpentsdbProtocolHandler for DummyInstance {
async fn exec(&self, data_points: Vec<DataPoint>, _ctx: QueryContextRef) -> Result<usize> {

View File

@@ -17,7 +17,6 @@ use std::sync::Arc;
use api::prom_store::remote::{
LabelMatcher, Query, QueryResult, ReadRequest, ReadResponse, WriteRequest,
};
use api::v1::greptime_request::Request;
use api::v1::RowInsertRequests;
use async_trait::async_trait;
use axum::Router;
@@ -33,7 +32,6 @@ use servers::http::test_helpers::TestClient;
use servers::http::{HttpOptions, HttpServerBuilder};
use servers::prom_store;
use servers::prom_store::{snappy_compress, Metrics};
use servers::query_handler::grpc::GrpcQueryHandler;
use servers::query_handler::sql::SqlQueryHandler;
use servers::query_handler::{PromStoreProtocolHandler, PromStoreResponse};
use session::context::QueryContextRef;
@@ -43,19 +41,6 @@ struct DummyInstance {
tx: mpsc::Sender<(String, Vec<u8>)>,
}
#[async_trait]
impl GrpcQueryHandler for DummyInstance {
type Error = Error;
async fn do_query(
&self,
_query: Request,
_ctx: QueryContextRef,
) -> std::result::Result<Output, Self::Error> {
unimplemented!()
}
}
#[async_trait]
impl PromStoreProtocolHandler for DummyInstance {
async fn write(

View File

@@ -18,6 +18,7 @@ use api::v1::greptime_request::Request;
use api::v1::query_request::Query;
use async_trait::async_trait;
use catalog::memory::MemoryCatalogManager;
use common_base::AffectedRows;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_query::Output;
use datafusion_expr::LogicalPlan;
@@ -26,11 +27,12 @@ use query::parser::{PromQuery, QueryLanguageParser, QueryStatement};
use query::query_engine::DescribeResult;
use query::{QueryEngineFactory, QueryEngineRef};
use servers::error::{Error, NotSupportedSnafu, Result};
use servers::query_handler::grpc::{GrpcQueryHandler, ServerGrpcQueryHandlerRef};
use servers::query_handler::grpc::{GrpcQueryHandler, RawRecordBatch, ServerGrpcQueryHandlerRef};
use servers::query_handler::sql::{ServerSqlQueryHandlerRef, SqlQueryHandler};
use session::context::QueryContextRef;
use snafu::ensure;
use sql::statements::statement::Statement;
use table::table_name::TableName;
use table::TableRef;
mod grpc;
@@ -155,6 +157,16 @@ impl GrpcQueryHandler for DummyInstance {
};
Ok(output)
}
async fn put_record_batch(
&self,
table: &TableName,
record_batch: RawRecordBatch,
) -> std::result::Result<AffectedRows, Self::Error> {
let _ = table;
let _ = record_batch;
unimplemented!()
}
}
fn create_testing_instance(table: TableRef) -> DummyInstance {

View File

@@ -172,6 +172,9 @@ pub enum Error {
#[snafu(implicit)]
location: Location,
},
#[snafu(display("Invalid table name: '{s}'"))]
InvalidTableName { s: String },
}
impl ErrorExt for Error {
@@ -197,7 +200,8 @@ impl ErrorExt for Error {
Error::MissingTimeIndexColumn { .. } => StatusCode::IllegalState,
Error::InvalidTableOptionValue { .. }
| Error::SetSkippingOptions { .. }
| Error::UnsetSkippingOptions { .. } => StatusCode::InvalidArguments,
| Error::UnsetSkippingOptions { .. }
| Error::InvalidTableName { .. } => StatusCode::InvalidArguments,
}
}

View File

@@ -15,8 +15,12 @@
use std::fmt::{Display, Formatter};
use api::v1::TableName as PbTableName;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use serde::{Deserialize, Serialize};
use snafu::ensure;
use crate::error;
use crate::error::InvalidTableNameSnafu;
use crate::table_reference::TableReference;
#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize, Serialize)]
@@ -83,3 +87,37 @@ impl From<TableReference<'_>> for TableName {
Self::new(table_ref.catalog, table_ref.schema, table_ref.table)
}
}
impl TryFrom<Vec<String>> for TableName {
type Error = error::Error;
fn try_from(v: Vec<String>) -> Result<Self, Self::Error> {
ensure!(
!v.is_empty() && v.len() <= 3,
InvalidTableNameSnafu {
s: format!("{v:?}")
}
);
let mut v = v.into_iter();
match (v.next(), v.next(), v.next()) {
(Some(catalog_name), Some(schema_name), Some(table_name)) => Ok(Self {
catalog_name,
schema_name,
table_name,
}),
(Some(schema_name), Some(table_name), None) => Ok(Self {
catalog_name: DEFAULT_CATALOG_NAME.to_string(),
schema_name,
table_name,
}),
(Some(table_name), None, None) => Ok(Self {
catalog_name: DEFAULT_CATALOG_NAME.to_string(),
schema_name: DEFAULT_SCHEMA_NAME.to_string(),
table_name,
}),
// Unreachable because it's ensured that "v" is not empty,
// and its iterator will not yield `Some` after `None`.
_ => unreachable!(),
}
}
}