mirror of https://github.com/neondatabase/neon.git
more compact code and more compact futures
@@ -600,6 +600,7 @@ impl ParameterStatusBody {
     }
 }

+#[derive(Clone, Copy)]
 pub struct ReadyForQueryBody {
     status: u8,
 }
@@ -29,6 +29,9 @@ pub struct Responses {
     waiting: usize,
+    /// number of ReadyForQuery messages received.
     received: usize,

+    /// The last query status we received.
+    last_status: ReadyForQueryStatus,
 }

 impl Responses {
@@ -39,7 +42,8 @@ impl Responses {
         let received = self.received;

         // increase the query head if this is the last message.
-        if let Message::ReadyForQuery(_) = message {
+        if let Message::ReadyForQuery(ref status) = message {
+            self.last_status = (*status).into();
             self.received += 1;
         }
@@ -68,6 +72,15 @@ impl Responses {
     pub async fn next(&mut self) -> Result<Message, Error> {
         future::poll_fn(|cx| self.poll_next(cx)).await
     }
+
+    pub async fn wait_until_ready(&mut self) -> Result<ReadyForQueryStatus, Error> {
+        while self.received < self.waiting {
+            if let Message::ReadyForQuery(status) = self.next().await? {
+                return Ok(status.into());
+            }
+        }
+        Ok(self.last_status)
+    }
 }

 /// A cache of type info and prepared statements for fetching type info
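Note: as an illustration of the bookkeeping these Responses hunks introduce — every queued query bumps `waiting`, every ReadyForQuery message bumps `received` and records its status, and `wait_until_ready` drains messages until the counters meet — here is a minimal, synchronous sketch. It uses stand-in types, not the crate's actual async `Responses`:

use std::collections::VecDeque;

#[derive(Clone, Copy, Debug, PartialEq)]
enum ReadyForQueryStatus {
    Idle,
    InTransaction,
}

enum Message {
    Row,
    ReadyForQuery(ReadyForQueryStatus),
}

struct Responses {
    queue: VecDeque<Message>,
    /// queries sent whose final ReadyForQuery has not been consumed yet.
    waiting: usize,
    /// ReadyForQuery messages consumed so far.
    received: usize,
    /// status carried by the last ReadyForQuery seen.
    last_status: ReadyForQueryStatus,
}

impl Responses {
    fn next(&mut self) -> Option<Message> {
        let msg = self.queue.pop_front()?;
        if let Message::ReadyForQuery(status) = &msg {
            self.last_status = *status;
            self.received += 1;
        }
        Some(msg)
    }

    /// Drain messages until every outstanding query has reported ReadyForQuery.
    fn wait_until_ready(&mut self) -> ReadyForQueryStatus {
        while self.received < self.waiting {
            if self.next().is_none() {
                break;
            }
        }
        self.last_status
    }
}

fn main() {
    let mut responses = Responses {
        queue: vec![
            Message::Row,
            Message::ReadyForQuery(ReadyForQueryStatus::InTransaction),
            Message::ReadyForQuery(ReadyForQueryStatus::Idle),
        ]
        .into(),
        waiting: 2, // two queries were sent on this connection
        received: 0,
        last_status: ReadyForQueryStatus::Idle,
    };
    assert_eq!(responses.wait_until_ready(), ReadyForQueryStatus::Idle);
}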
@@ -92,13 +105,6 @@ impl InnerClient {
         Ok(PartialQuery(Some(self)))
     }

-    // pub fn send_with_sync<F>(&mut self, f: F) -> Result<&mut Responses, Error>
-    // where
-    //     F: FnOnce(&mut BytesMut) -> Result<(), Error>,
-    // {
-    //     self.start()?.send_with_sync(f)
-    // }
-
     pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> {
         self.responses.waiting += 1;
@@ -197,6 +203,8 @@ impl Client {
                 cur: BackendMessages::empty(),
                 waiting: 0,
                 received: 0,
+                // new connections are always idle.
+                last_status: ReadyForQueryStatus::Idle,
             },
             buffer: Default::default(),
         },
@@ -230,6 +238,10 @@ impl Client {
         rx
     }

+    pub async fn wait_until_ready(&mut self) -> Result<ReadyForQueryStatus, Error> {
+        self.inner_mut().responses.wait_until_ready().await
+    }
+
     /// Pass text directly to the Postgres backend to allow it to sort out typing itself and
     /// to save a roundtrip
     pub async fn query_raw_txt<S, I>(
@@ -2,8 +2,8 @@ use std::pin::pin;
 use std::sync::Arc;

 use bytes::Bytes;
-use futures::future::{Either, select, try_join};
-use futures::{StreamExt, TryFutureExt};
+use futures::future::try_join;
+use futures::{FutureExt, TryFutureExt, TryStreamExt};
 use http::Method;
 use http::header::AUTHORIZATION;
 use http_body_util::combinators::BoxBody;
@@ -495,7 +495,7 @@ async fn handle_db_inner(
             .http_conn_content_length_bytes
             .observe(HttpDirection::Request, body.len() as f64);

-        debug!(length = body.len(), "request payload read");
+        debug!(length = body.len(), "request payload read ");
         let payload: Payload = serde_json::from_slice(&body)?;
         Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
     }
@@ -566,29 +566,32 @@ async fn handle_db_inner(
         .status(StatusCode::OK)
         .header(header::CONTENT_TYPE, "application/json");

-    // Now execute the query and return the result.
-    let json_output = match payload {
-        Payload::Single(stmt) => {
-            stmt.process(&config.http_config, cancel, &mut client, parsed_headers)
-                .await?
-        }
-        Payload::Batch(statements) => {
-            if parsed_headers.txn_read_only {
-                response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
-            }
-            if parsed_headers.txn_deferrable {
-                response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
-            }
-            if let Some(txn_isolation_level) = parsed_headers
-                .txn_isolation_level
-                .and_then(map_isolation_level_to_headers)
-            {
-                response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
-            }
-
-            statements
-                .process(&config.http_config, cancel, &mut client, parsed_headers)
-                .await?
-        }
-    };
+    if let Payload::Batch(_) = payload {
+        if parsed_headers.txn_read_only {
+            response = response.header(TXN_READ_ONLY.clone(), &HEADER_VALUE_TRUE);
+        }
+        if parsed_headers.txn_deferrable {
+            response = response.header(TXN_DEFERRABLE.clone(), &HEADER_VALUE_TRUE);
+        }
+        if let Some(txn_isolation_level) = parsed_headers
+            .txn_isolation_level
+            .and_then(map_isolation_level_to_headers)
+        {
+            response = response.header(TXN_ISOLATION_LEVEL.clone(), txn_isolation_level);
+        }
+    }
+
+    // Now execute the query and return the result.
+    let json_output = match payload
+        .process(&config.http_config, cancel, &mut client, parsed_headers)
+        .await
+    {
+        Ok(json_output) => json_output,
+        Err(error) => {
+            if let SqlOverHttpError::Cancelled(_) = error {
+                cancel_query(&mut client).await;
+            }
+            return Err(error);
+        }
+    };
@@ -673,7 +676,7 @@ async fn handle_auth_broker_inner(
         .map(|b| b.boxed()))
 }

-impl QueryData {
+impl Payload {
     async fn process(
         self,
         config: &'static HttpConfig,
@@ -682,85 +685,11 @@ impl QueryData {
         parsed_headers: HttpHeaders,
     ) -> Result<String, SqlOverHttpError> {
         let (inner, mut discard) = client.inner();
-        let cancel_token = inner.cancel_token();

-        let mut json_buf = vec![];
-
-        let batch_result = match select(
-            pin!(query_to_json(
-                config,
-                &mut *inner,
-                self,
-                json::ValueSer::new(&mut json_buf),
-                parsed_headers
-            )),
-            pin!(cancel.cancelled()),
-        )
-        .await
-        {
-            Either::Left((res, __not_yet_cancelled)) => res,
-            Either::Right((_cancelled, query)) => {
-                tracing::info!("cancelling query");
-                if let Err(err) = cancel_token.cancel_query(NoTls).await {
-                    tracing::warn!(?err, "could not cancel query");
-                }
-                // wait for the query cancellation
-                match time::timeout(time::Duration::from_millis(100), query).await {
-                    // query successed before it was cancelled.
-                    Ok(Ok(status)) => Ok(status),
-                    // query failed or was cancelled.
-                    Ok(Err(error)) => {
-                        let db_error = match &error {
-                            SqlOverHttpError::ConnectCompute(
-                                HttpConnError::PostgresConnectionError(e),
-                            )
-                            | SqlOverHttpError::Postgres(e) => e.as_db_error(),
-                            _ => None,
-                        };
-
-                        // if errored for some other reason, it might not be safe to return
-                        if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
-                            discard.discard();
-                        }
-
-                        return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
-                    }
-                    Err(_timeout) => {
-                        discard.discard();
-                        return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
-                    }
-                }
-            }
-        };
-
-        match batch_result {
-            // The query successfully completed.
-            Ok(_) => {
-                let json_output = String::from_utf8(json_buf).expect("json should be valid utf8");
-                Ok(json_output)
-            }
-            // The query failed with an error
-            Err(e) => {
-                discard.discard();
-                Err(e)
-            }
-        }
-    }
-}
-
-impl BatchQueryData {
-    async fn process(
-        self,
-        config: &'static HttpConfig,
-        cancel: CancellationToken,
-        client: &mut Client,
-        parsed_headers: HttpHeaders,
-    ) -> Result<String, SqlOverHttpError> {
-        info!("starting transaction");
-        let (inner, mut discard) = client.inner();
-        let cancel_token = inner.cancel_token();
-
-        {
+        let needs_tx = matches!(self, Payload::Batch(_));
+
+        if needs_tx {
+            info!("starting transaction");
             let query = TransactionBuilder {
                 isolation_level: parsed_headers.txn_isolation_level,
                 read_only: parsed_headers.txn_read_only.then_some(true),
@@ -779,93 +708,74 @@ impl BatchQueryData {
                 .map_err(SqlOverHttpError::Postgres)?;
         }

-        let res =
-            query_batch_to_json(config, cancel.child_token(), inner, self, parsed_headers).await;
-
-        let json_output = match res {
-            Ok(json_output) => {
-                info!("commit");
-                inner
-                    .commit()
-                    .await
-                    .inspect_err(|_| {
-                        // if we cannot commit - for now don't return connection to pool
-                        // TODO: get a query status from the error
-                        discard.discard();
-                    })
-                    .map_err(SqlOverHttpError::Postgres)?;
-                json_output
-            }
-            Err(SqlOverHttpError::Cancelled(_)) => {
-                if let Err(err) = cancel_token.cancel_query(NoTls).await {
-                    tracing::warn!(?err, "could not cancel query");
-                }
-                // TODO: after cancelling, wait to see if we can get a status. maybe the connection is still safe.
-                discard.discard();
-
-                return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
-            }
-            Err(err) => {
-                return Err(err);
-            }
-        };
-
-        Ok(json_output)
-    }
-}
-
-async fn query_batch(
-    config: &'static HttpConfig,
-    cancel: CancellationToken,
-    client: &mut postgres_client::Client,
-    queries: BatchQueryData,
-    parsed_headers: HttpHeaders,
-    results: &mut json::ListSer<'_>,
-) -> Result<(), SqlOverHttpError> {
-    for stmt in queries.queries {
-        let query = pin!(query_to_json(
-            config,
-            client,
-            stmt,
-            results.entry(),
-            parsed_headers,
-        ));
-        let cancelled = pin!(cancel.cancelled());
-        let res = select(query, cancelled).await;
-        match res {
-            // TODO: maybe we should check that the transaction bit is set here
-            Either::Left((Ok(_), _cancelled)) => {}
-            Either::Left((Err(e), _cancelled)) => {
-                return Err(e);
-            }
-            Either::Right((_cancelled, _)) => {
-                return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres));
-            }
-        }
-    }
-
-    Ok(())
-}
-
-async fn query_batch_to_json(
-    config: &'static HttpConfig,
-    cancel: CancellationToken,
-    client: &mut postgres_client::Client,
-    queries: BatchQueryData,
-    headers: HttpHeaders,
-) -> Result<String, SqlOverHttpError> {
-    let json_output = json::value_to_string!(|obj| json::value_as_object!(|obj| {
-        let results = obj.key("results");
-        json::value_as_list!(|results| {
-            query_batch(config, cancel, client, queries, headers, results).await?;
-        });
-    }));
-
-    Ok(json_output)
-}
+        let json_output = json::value_to_string!(|value| match self {
+            Payload::Single(query) => {
+                query_to_json(config, &cancel, inner, query, value, parsed_headers).await?;
+            }
+            Payload::Batch(batch) => {
+                let mut obj = value.object();
+                let mut results = obj.key("results").list();
+
+                for query in batch.queries {
+                    let value = results.entry();
+                    query_to_json(config, &cancel, inner, query, value, parsed_headers).await?;
+                }
+
+                results.finish();
+                obj.finish();
+            }
+        });
+
+        if needs_tx {
+            inner
+                .commit()
+                .await
+                .inspect_err(|_| {
+                    // if we cannot commit - for now don't return connection to pool
+                    // TODO: get a query status from the error
+                    discard.discard();
+                })
+                .map_err(SqlOverHttpError::Postgres)?;
+        }
+
+        Ok(json_output)
+    }
+}
+
+async fn cancel_query(client: &mut Client) {
+    let (inner, mut discard) = client.inner();
+    let cancel_token = inner.cancel_token();
+
+    if let Err(err) = cancel_token.cancel_query(NoTls).await {
+        tracing::warn!(?err, "could not cancel query");
+
+        // couldn't reach the server. let's just throw away this conn
+        discard.discard();
+        return;
+    }
+
+    // wait for the query cancellation
+    match time::timeout(time::Duration::from_millis(100), inner.wait_until_ready()).await {
+        // we managed to cancel the query.
+        Ok(Ok(_)) => {}
+        // query failed or was cancelled.
+        Ok(Err(error)) => {
+            let db_error = error.as_db_error();
+
+            // if errored for some other reason, it might not be safe to reuse the connection.
+            if !db_error.is_some_and(|e| *e.code() == SqlState::QUERY_CANCELED) {
+                discard.discard();
+            }
+        }
+        Err(_timeout) => {
+            discard.discard();
+        }
+    }
+}

 async fn query_to_json(
     config: &'static HttpConfig,
+    cancel: &CancellationToken,
     client: &mut postgres_client::Client,
     data: QueryData,
     output: json::ValueSer<'_>,
@@ -874,10 +784,13 @@ async fn query_to_json(
     let query_start = Instant::now();

     let mut output = json::ObjectSer::new(output);
-    let mut row_stream = client
-        .query_raw_txt(&data.query, data.params)
-        .await
-        .map_err(SqlOverHttpError::Postgres)?;
+
+    let mut row_stream =
+        run_until_cancelled(client.query_raw_txt(&data.query, data.params), cancel)
+            .await
+            .ok_or(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))?
+            .map_err(SqlOverHttpError::Postgres)?;
+
     let query_acknowledged = Instant::now();

     let mut json_fields = output.key("fields").list();
@@ -903,8 +816,13 @@
     // big.
     let mut rows = 0;
    let mut json_rows = output.key("rows").list();
-    while let Some(row) = row_stream.next().await {
-        let row = row.map_err(SqlOverHttpError::Postgres)?;
+    loop {
+        let row = run_until_cancelled(row_stream.try_next(), cancel)
+            .await
+            .ok_or(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres))?
+            .map_err(SqlOverHttpError::Postgres)?;
+
+        let Some(row) = row else { break };

         // we don't have a streaming response support yet so this is to prevent OOM
         // from a malicious query (eg a cross join)
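As a side note, the row loop above follows a general pattern: race each `try_next()` against the cancellation token, treat cancellation as an error, and let end-of-stream break the loop. A standalone sketch of that pattern follows (hypothetical types; `tokio::select!` stands in for the proxy's `run_until_cancelled` helper shown in the last hunk; assumed dependencies: tokio, futures, tokio-util):

use futures::{stream, TryStreamExt};
use tokio_util::sync::CancellationToken;

#[derive(Debug)]
enum Error {
    Cancelled,
    Db(&'static str),
}

#[tokio::main]
async fn main() -> Result<(), Error> {
    let cancel = CancellationToken::new();
    // A stand-in for the Postgres row stream: three rows, any of which could fail.
    let mut row_stream = stream::iter(vec![Ok::<u64, &'static str>(1), Ok(2), Ok(3)]);

    let mut total = 0;
    loop {
        // Race fetching the next row against cancellation, like
        // run_until_cancelled(row_stream.try_next(), cancel) in the diff.
        let row = tokio::select! {
            biased;
            row = row_stream.try_next() => row.map_err(Error::Db)?,
            _ = cancel.cancelled() => return Err(Error::Cancelled),
        };

        // End of stream terminates the loop: `let Some(row) = row else { break }`.
        let Some(row) = row else { break };
        total += row;
    }

    assert_eq!(total, 6);
    Ok(())
}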
@@ -1,23 +1,50 @@
-use std::pin::pin;
+use std::{
+    pin::Pin,
+    task::{Context, Poll},
+};

-use futures::future::{Either, select};
+use futures::FutureExt;
 use tokio_util::sync::CancellationToken;

-pub async fn run_until_cancelled<F: Future>(
+pub fn run_until_cancelled<F: Future>(
     f: F,
     cancellation_token: &CancellationToken,
-) -> Option<F::Output> {
-    run_until(f, cancellation_token.cancelled()).await.ok()
+) -> impl Future<Output = Option<F::Output>> {
+    run_until(f, cancellation_token.cancelled()).map(|r| r.ok())
 }

 /// Runs the future `f` unless interrupted by future `condition`.
-pub async fn run_until<F1: Future, F2: Future>(
+pub fn run_until<F1: Future, F2: Future>(
     f: F1,
     condition: F2,
-) -> Result<F1::Output, F2::Output> {
-    match select(pin!(f), pin!(condition)).await {
-        Either::Left((f1, _)) => Ok(f1),
-        Either::Right((f2, _)) => Err(f2),
-    }
+) -> impl Future<Output = Result<F1::Output, F2::Output>> {
+    RunUntil { a: f, b: condition }
+}
+
+pin_project_lite::pin_project! {
+    struct RunUntil<A, B> {
+        #[pin] a: A,
+        #[pin] b: B,
+    }
+}
+
+impl<A, B> Future for RunUntil<A, B>
+where
+    A: Future,
+    B: Future,
+{
+    type Output = Result<A::Output, B::Output>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let this = self.project();
+
+        if let Poll::Ready(a) = this.a.poll(cx) {
+            return Poll::Ready(Ok(a));
+        }
+        if let Poll::Ready(b) = this.b.poll(cx) {
+            return Poll::Ready(Err(b));
+        }
+        Poll::Pending
+    }
 }
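A rough usage sketch of the new helper, assuming the `run_until_cancelled` defined above is in scope and a tokio runtime is available (neither is shown in the diff). The hand-rolled `RunUntil` polls the wrapped future before the interrupt condition, so a future that is already ready wins even if the token has fired; presumably this, together with dropping the `async fn` state machine and its `select`/`pin!` plumbing, is the "more compact futures" the commit message refers to:

use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    let token = CancellationToken::new();

    // The wrapped future completes before cancellation: Some(value).
    let done = run_until_cancelled(async { 21 * 2 }, &token).await;
    assert_eq!(done, Some(42));

    // The token fires while the future is still pending: the future is
    // dropped and None is returned.
    token.cancel();
    let cancelled = run_until_cancelled(std::future::pending::<u32>(), &token).await;
    assert_eq!(cancelled, None);
}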