feat: insert with stream (#1703)

* feat: insert with stream

* chore: by CR
This commit is contained in:
JeremyHi
2023-06-03 11:58:00 +08:00
committed by GitHub
parent 5004cf6d9a
commit 3d7185749d
4 changed files with 53 additions and 20 deletions

1
Cargo.lock generated
View File

@@ -1537,6 +1537,7 @@ dependencies = [
"substrait 0.2.0",
"substrait 0.7.5",
"tokio",
"tokio-stream",
"tonic 0.9.2",
"tracing",
"tracing-subscriber",

View File

@@ -30,12 +30,13 @@ parking_lot = "0.12"
prost.workspace = true
rand.workspace = true
snafu.workspace = true
tokio-stream = { version = "0.1", features = ["net"] }
tokio.workspace = true
tonic.workspace = true
[dev-dependencies]
datanode = { path = "../datanode" }
substrait = { path = "../common/substrait" }
tokio.workspace = true
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
prost.workspace = true

View File

@@ -29,6 +29,9 @@ use common_telemetry::{logging, timer};
use futures_util::{TryFutureExt, TryStreamExt};
use prost::Message;
use snafu::{ensure, ResultExt};
use tokio::sync::mpsc::Sender;
use tokio::sync::{mpsc, OnceCell};
use tokio_stream::wrappers::ReceiverStream;
use crate::error::{
ConvertFlightDataSnafu, IllegalDatabaseResponseSnafu, IllegalFlightMessagesSnafu,
@@ -47,6 +50,7 @@ pub struct Database {
dbname: String,
client: Client,
streaming_client: OnceCell<Sender<GreptimeRequest>>,
ctx: FlightContext,
}
@@ -58,6 +62,7 @@ impl Database {
schema: schema.into(),
dbname: "".to_string(),
client,
streaming_client: OnceCell::new(),
ctx: FlightContext::default(),
}
}
@@ -75,6 +80,7 @@ impl Database {
schema: "".to_string(),
dbname: dbname.into(),
client,
streaming_client: OnceCell::new(),
ctx: FlightContext::default(),
}
}
@@ -114,6 +120,22 @@ impl Database {
self.handle(Request::Inserts(requests)).await
}
pub async fn insert_to_stream(&self, requests: InsertRequests) -> Result<()> {
let streaming_client = self
.streaming_client
.get_or_try_init(|| self.client_stream())
.await?;
let request = self.to_rpc_request(Request::Inserts(requests));
streaming_client.send(request).await.map_err(|e| {
error::ClientStreamingSnafu {
err_msg: e.to_string(),
}
.build()
})
}
pub async fn delete(&self, request: DeleteRequest) -> Result<u32> {
let _timer = timer!(metrics::METRIC_GRPC_DELETE);
self.handle(Request::Delete(request)).await
@@ -121,15 +143,7 @@ impl Database {
async fn handle(&self, request: Request) -> Result<u32> {
let mut client = self.client.make_database_client()?.inner;
let request = GreptimeRequest {
header: Some(RequestHeader {
catalog: self.catalog.clone(),
schema: self.schema.clone(),
authorization: self.ctx.auth_header.clone(),
dbname: self.dbname.clone(),
}),
request: Some(request),
};
let request = self.to_rpc_request(request);
let response = client
.handle(request)
.await?
@@ -142,6 +156,27 @@ impl Database {
Ok(value)
}
#[inline]
fn to_rpc_request(&self, request: Request) -> GreptimeRequest {
GreptimeRequest {
header: Some(RequestHeader {
catalog: self.catalog.clone(),
schema: self.schema.clone(),
authorization: self.ctx.auth_header.clone(),
dbname: self.dbname.clone(),
}),
request: Some(request),
}
}
async fn client_stream(&self) -> Result<Sender<GreptimeRequest>> {
let mut client = self.client.make_database_client()?.inner;
let (sender, receiver) = mpsc::channel::<GreptimeRequest>(65536);
let receiver = ReceiverStream::new(receiver);
client.handle_requests(receiver).await?;
Ok(sender)
}
pub async fn sql(&self, sql: &str) -> Result<Output> {
let _timer = timer!(metrics::METRIC_GRPC_SQL);
self.do_get(Request::Query(QueryRequest {
@@ -212,15 +247,7 @@ impl Database {
async fn do_get(&self, request: Request) -> Result<Output> {
// FIXME(paomian): should be added some labels for metrics
let _timer = timer!(metrics::METRIC_GRPC_DO_GET);
let request = GreptimeRequest {
header: Some(RequestHeader {
catalog: self.catalog.clone(),
schema: self.schema.clone(),
authorization: self.ctx.auth_header.clone(),
dbname: self.dbname.clone(),
}),
request: Some(request),
};
let request = self.to_rpc_request(request);
let request = Ticket {
ticket: request.encode_to_vec().into(),
};

View File

@@ -67,6 +67,9 @@ pub enum Error {
#[snafu(display("Illegal Database response: {err_msg}"))]
IllegalDatabaseResponse { err_msg: String },
#[snafu(display("Failed to send request with streaming: {}", err_msg))]
ClientStreaming { err_msg: String, location: Location },
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -77,7 +80,8 @@ impl ErrorExt for Error {
Error::IllegalFlightMessages { .. }
| Error::ColumnDataType { .. }
| Error::MissingField { .. }
| Error::IllegalDatabaseResponse { .. } => StatusCode::Internal,
| Error::IllegalDatabaseResponse { .. }
| Error::ClientStreaming { .. } => StatusCode::Internal,
Error::Server { code, .. } => *code,
Error::FlightGet { source, .. } => source.status_code(),