diff --git a/Cargo.lock b/Cargo.lock index f6d2fb4256..c733e5c41f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1537,6 +1537,7 @@ dependencies = [ "substrait 0.2.0", "substrait 0.7.5", "tokio", + "tokio-stream", "tonic 0.9.2", "tracing", "tracing-subscriber", diff --git a/src/client/Cargo.toml b/src/client/Cargo.toml index 9dab5b072b..1a937a947d 100644 --- a/src/client/Cargo.toml +++ b/src/client/Cargo.toml @@ -30,12 +30,13 @@ parking_lot = "0.12" prost.workspace = true rand.workspace = true snafu.workspace = true +tokio-stream = { version = "0.1", features = ["net"] } +tokio.workspace = true tonic.workspace = true [dev-dependencies] datanode = { path = "../datanode" } substrait = { path = "../common/substrait" } -tokio.workspace = true tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } prost.workspace = true diff --git a/src/client/src/database.rs b/src/client/src/database.rs index 6cb6e3aeef..5c28efaf86 100644 --- a/src/client/src/database.rs +++ b/src/client/src/database.rs @@ -29,6 +29,9 @@ use common_telemetry::{logging, timer}; use futures_util::{TryFutureExt, TryStreamExt}; use prost::Message; use snafu::{ensure, ResultExt}; +use tokio::sync::mpsc::Sender; +use tokio::sync::{mpsc, OnceCell}; +use tokio_stream::wrappers::ReceiverStream; use crate::error::{ ConvertFlightDataSnafu, IllegalDatabaseResponseSnafu, IllegalFlightMessagesSnafu, @@ -47,6 +50,7 @@ pub struct Database { dbname: String, client: Client, + streaming_client: OnceCell>, ctx: FlightContext, } @@ -58,6 +62,7 @@ impl Database { schema: schema.into(), dbname: "".to_string(), client, + streaming_client: OnceCell::new(), ctx: FlightContext::default(), } } @@ -75,6 +80,7 @@ impl Database { schema: "".to_string(), dbname: dbname.into(), client, + streaming_client: OnceCell::new(), ctx: FlightContext::default(), } } @@ -114,6 +120,22 @@ impl Database { self.handle(Request::Inserts(requests)).await } + pub async fn insert_to_stream(&self, requests: InsertRequests) -> Result<()> { + let streaming_client = self + .streaming_client + .get_or_try_init(|| self.client_stream()) + .await?; + + let request = self.to_rpc_request(Request::Inserts(requests)); + + streaming_client.send(request).await.map_err(|e| { + error::ClientStreamingSnafu { + err_msg: e.to_string(), + } + .build() + }) + } + pub async fn delete(&self, request: DeleteRequest) -> Result { let _timer = timer!(metrics::METRIC_GRPC_DELETE); self.handle(Request::Delete(request)).await @@ -121,15 +143,7 @@ impl Database { async fn handle(&self, request: Request) -> Result { let mut client = self.client.make_database_client()?.inner; - let request = GreptimeRequest { - header: Some(RequestHeader { - catalog: self.catalog.clone(), - schema: self.schema.clone(), - authorization: self.ctx.auth_header.clone(), - dbname: self.dbname.clone(), - }), - request: Some(request), - }; + let request = self.to_rpc_request(request); let response = client .handle(request) .await? @@ -142,6 +156,27 @@ impl Database { Ok(value) } + #[inline] + fn to_rpc_request(&self, request: Request) -> GreptimeRequest { + GreptimeRequest { + header: Some(RequestHeader { + catalog: self.catalog.clone(), + schema: self.schema.clone(), + authorization: self.ctx.auth_header.clone(), + dbname: self.dbname.clone(), + }), + request: Some(request), + } + } + + async fn client_stream(&self) -> Result> { + let mut client = self.client.make_database_client()?.inner; + let (sender, receiver) = mpsc::channel::(65536); + let receiver = ReceiverStream::new(receiver); + client.handle_requests(receiver).await?; + Ok(sender) + } + pub async fn sql(&self, sql: &str) -> Result { let _timer = timer!(metrics::METRIC_GRPC_SQL); self.do_get(Request::Query(QueryRequest { @@ -212,15 +247,7 @@ impl Database { async fn do_get(&self, request: Request) -> Result { // FIXME(paomian): should be added some labels for metrics let _timer = timer!(metrics::METRIC_GRPC_DO_GET); - let request = GreptimeRequest { - header: Some(RequestHeader { - catalog: self.catalog.clone(), - schema: self.schema.clone(), - authorization: self.ctx.auth_header.clone(), - dbname: self.dbname.clone(), - }), - request: Some(request), - }; + let request = self.to_rpc_request(request); let request = Ticket { ticket: request.encode_to_vec().into(), }; diff --git a/src/client/src/error.rs b/src/client/src/error.rs index 51608d7115..0bfb67ec0d 100644 --- a/src/client/src/error.rs +++ b/src/client/src/error.rs @@ -67,6 +67,9 @@ pub enum Error { #[snafu(display("Illegal Database response: {err_msg}"))] IllegalDatabaseResponse { err_msg: String }, + + #[snafu(display("Failed to send request with streaming: {}", err_msg))] + ClientStreaming { err_msg: String, location: Location }, } pub type Result = std::result::Result; @@ -77,7 +80,8 @@ impl ErrorExt for Error { Error::IllegalFlightMessages { .. } | Error::ColumnDataType { .. } | Error::MissingField { .. } - | Error::IllegalDatabaseResponse { .. } => StatusCode::Internal, + | Error::IllegalDatabaseResponse { .. } + | Error::ClientStreaming { .. } => StatusCode::Internal, Error::Server { code, .. } => *code, Error::FlightGet { source, .. } => source.status_code(),