Files
neon/libs/proxy/tokio-postgres2/src/client.rs
Conrad Ludgate cd1d2d1996 fix(proxy): forward notifications from authentication (#9948)
Fixes https://github.com/neondatabase/cloud/issues/20973. 

This refactors `connect_raw` in order to return direct access to the
delayed notices.

I cannot find a way to test this with psycopg2 unfortunately, although
testing it with psql does return the expected results.
2024-12-02 12:29:57 +00:00

437 lines
15 KiB
Rust

use crate::codec::{BackendMessages, FrontendMessage};
use crate::config::Host;
use crate::config::SslMode;
use crate::connection::{Request, RequestMessages};
use crate::query::RowStream;
use crate::simple_query::SimpleQueryStream;
use crate::types::{Oid, ToSql, Type};
use crate::{
prepare, query, simple_query, slice_iter, CancelToken, Error, ReadyForQueryStatus, Row,
SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{future, ready, TryStreamExt};
use parking_lot::Mutex;
use postgres_protocol2::message::{backend::Message, frontend};
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::mpsc;
use std::time::Duration;
/// The stream of backend messages produced in reply to a single request.
///
/// Messages arrive from the connection in batches; `cur` is the batch that is
/// currently being drained and `receiver` supplies the following ones.
pub struct Responses {
    /// Channel on which further batches of backend messages arrive.
    receiver: mpsc::Receiver<BackendMessages>,
    /// The batch currently being iterated.
    cur: BackendMessages,
}
impl Responses {
pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
loop {
match self.cur.next().map_err(Error::parse)? {
Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))),
Some(message) => return Poll::Ready(Ok(message)),
None => {}
}
match ready!(self.receiver.poll_recv(cx)) {
Some(messages) => self.cur = messages,
None => return Poll::Ready(Err(Error::closed())),
}
}
}
pub async fn next(&mut self) -> Result<Message, Error> {
future::poll_fn(|cx| self.poll_next(cx)).await
}
}
/// Per-client cache of resolved types and of the prepared statements used to
/// look type information up (the queries themselves live in the [prepare]
/// module).
#[derive(Default)]
struct CachedTypeInfo {
    /// Statement for basic information about a type, looked up by OID.
    /// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its
    /// fallback).
    typeinfo: Option<Statement>,
    /// Statement for information about a composite type, looked up by OID.
    /// Corresponds to
    /// [TYPEINFO_COMPOSITE_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY).
    typeinfo_composite: Option<Statement>,
    /// Statement stored by `set_typeinfo_enum`; presumably the lookup for
    /// enum types by OID.
    // NOTE(review): the previous comment on this field was a copy of the
    // composite one and linked `prepare::TYPEINFO_COMPOSITE_QUERY`. The name
    // of the matching enum query is not visible from this file; confirm it in
    // the prepare module before adding a link.
    typeinfo_enum: Option<Statement>,
    /// Types that have already been looked up, keyed by OID.
    types: HashMap<Oid, Type>,
}
/// State shared (behind an `Arc`) by a [`Client`]: the request channel, the
/// type-info cache and a reusable encoding buffer.
pub struct InnerClient {
    /// Channel on which `send` submits requests.
    sender: mpsc::UnboundedSender<Request>,
    /// Cached type-info statements and resolved types.
    cached_typeinfo: Mutex<CachedTypeInfo>,
    /// A buffer to use when writing out postgres commands.
    buffer: Mutex<BytesMut>,
}
impl InnerClient {
    /// Submits `messages` as one request and returns the handle on which the
    /// replies are received.
    ///
    /// A fresh bounded channel (capacity 1) is created for the replies; its
    /// sending half travels with the request.
    ///
    /// # Errors
    ///
    /// Returns `Error::closed` if the request channel no longer has a
    /// receiver.
    pub fn send(&self, messages: RequestMessages) -> Result<Responses, Error> {
        let (sender, receiver) = mpsc::channel(1);
        let request = Request { messages, sender };
        self.sender.send(request).map_err(|_| Error::closed())?;

        Ok(Responses {
            receiver,
            cur: BackendMessages::empty(),
        })
    }

    /// Returns a clone of the cached `typeinfo` statement, if one was stored.
    pub fn typeinfo(&self) -> Option<Statement> {
        self.cached_typeinfo.lock().typeinfo.clone()
    }

    /// Stores a clone of `statement` as the cached `typeinfo` statement.
    pub fn set_typeinfo(&self, statement: &Statement) {
        self.cached_typeinfo.lock().typeinfo = Some(statement.clone());
    }

    /// Returns a clone of the cached `typeinfo_composite` statement, if one
    /// was stored.
    pub fn typeinfo_composite(&self) -> Option<Statement> {
        self.cached_typeinfo.lock().typeinfo_composite.clone()
    }

    /// Stores a clone of `statement` as the cached `typeinfo_composite`
    /// statement.
    pub fn set_typeinfo_composite(&self, statement: &Statement) {
        self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone());
    }

    /// Returns a clone of the cached `typeinfo_enum` statement, if one was
    /// stored.
    pub fn typeinfo_enum(&self) -> Option<Statement> {
        self.cached_typeinfo.lock().typeinfo_enum.clone()
    }

    /// Stores a clone of `statement` as the cached `typeinfo_enum` statement.
    pub fn set_typeinfo_enum(&self, statement: &Statement) {
        self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone());
    }

    /// Looks `oid` up in the cache of already resolved types.
    pub fn type_(&self, oid: Oid) -> Option<Type> {
        self.cached_typeinfo.lock().types.get(&oid).cloned()
    }

    /// Records `type_` in the cache under `oid`, replacing any previous entry.
    pub fn set_type(&self, oid: Oid, type_: &Type) {
        self.cached_typeinfo.lock().types.insert(oid, type_.clone());
    }

    /// Call the given function with a buffer to be used when writing out
    /// postgres commands.
    ///
    /// The buffer lock is held for the duration of `f`, and the buffer is
    /// cleared before the lock is released. Clearing is done by a drop guard
    /// so that it also happens when `f` panics: the `parking_lot` mutex is
    /// not poisoned by a panic, so without the guard a later caller would
    /// find the partially written bytes still in the buffer.
    pub fn with_buf<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut BytesMut) -> R,
    {
        /// Empties the wrapped buffer when dropped, on the normal return path
        /// and during unwinding alike.
        struct ClearOnDrop<'a>(&'a mut BytesMut);

        impl Drop for ClearOnDrop<'_> {
            fn drop(&mut self) {
                self.0.clear();
            }
        }

        let mut guard = self.buffer.lock();
        // Declared after `guard`, hence dropped before it: the buffer is
        // cleared while the lock is still held.
        let scratch = ClearOnDrop(&mut *guard);
        f(&mut *scratch.0)
    }
}
/// Socket-level settings kept by a [`Client`] and cloned into the tokens
/// returned by `Client::cancel_token`.
#[derive(Clone)]
pub struct SocketConfig {
    /// Host part of the address.
    pub host: Host,
    /// Port part of the address.
    pub port: u16,
    /// Optional connect timeout; `None` means no timeout is configured here.
    pub connect_timeout: Option<Duration>,
    // pub keepalive: Option<KeepaliveConfig>,
}
/// An asynchronous PostgreSQL client.
///
/// A `Client` is one of the two halves obtained when a connection is
/// established; it is the handle through which users talk to the database.
pub struct Client {
    /// Shared state: request channel, type-info cache and encoding buffer.
    inner: Arc<InnerClient>,
    /// Socket settings, copied into every cancel token.
    socket_config: SocketConfig,
    /// TLS mode, copied into every cancel token.
    ssl_mode: SslMode,
    /// Process id, exposed by `get_process_id` and copied into every cancel
    /// token.
    process_id: i32,
    /// Secret key, copied into every cancel token.
    secret_key: i32,
}
impl Client {
pub(crate) fn new(
sender: mpsc::UnboundedSender<Request>,
socket_config: SocketConfig,
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
) -> Client {
Client {
inner: Arc::new(InnerClient {
sender,
cached_typeinfo: Default::default(),
buffer: Default::default(),
}),
socket_config,
ssl_mode,
process_id,
secret_key,
}
}
/// Returns process_id.
pub fn get_process_id(&self) -> i32 {
self.process_id
}
pub(crate) fn inner(&self) -> &Arc<InnerClient> {
&self.inner
}
/// Creates a new prepared statement.
///
/// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc),
/// which are set when executed. Prepared statements can only be used with the connection that created them.
pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
self.prepare_typed(query, &[]).await
}
/// Like `prepare`, but allows the types of query parameters to be explicitly specified.
///
/// The list of types may be smaller than the number of parameters - the types of the remaining parameters will be
/// inferred. For example, `client.prepare_typed(query, &[])` is equivalent to `client.prepare(query)`.
pub async fn prepare_typed(
&self,
query: &str,
parameter_types: &[Type],
) -> Result<Statement, Error> {
prepare::prepare(&self.inner, query, parameter_types).await
}
/// Executes a statement, returning a vector of the resulting rows.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub async fn query<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Vec<Row>, Error>
where
T: ?Sized + ToStatement,
{
self.query_raw(statement, slice_iter(params))
.await?
.try_collect()
.await
}
/// The maximally flexible version of [`query`].
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
///
/// [`query`]: #method.query
pub async fn query_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let statement = statement.__convert().into_statement(self).await?;
query::query(&self.inner, statement, params).await
}
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
/// to save a roundtrip
pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
query::query_txt(&self.inner, statement, params).await
}
/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub async fn execute<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
{
self.execute_raw(statement, slice_iter(params)).await
}
/// The maximally flexible version of [`execute`].
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
///
/// [`execute`]: #method.execute
pub async fn execute_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let statement = statement.__convert().into_statement(self).await?;
query::execute(self.inner(), statement, params).await
}
/// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
///
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
/// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
/// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the
/// rows, this method returns a list of an enum which indicates either the completion of one of the commands,
/// or a row of data. This preserves the framing between the separate statements in the request.
///
/// # Warning
///
/// Prepared statements should be use for any query which contains user-specified data, as they provided the
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
self.simple_query_raw(query).await?.try_collect().await
}
pub(crate) async fn simple_query_raw(&self, query: &str) -> Result<SimpleQueryStream, Error> {
simple_query::simple_query(self.inner(), query).await
}
/// Executes a sequence of SQL statements using the simple query protocol.
///
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
/// point. This is intended for use when, for example, initializing a database schema.
///
/// # Warning
///
/// Prepared statements should be use for any query which contains user-specified data, as they provided the
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub async fn batch_execute(&self, query: &str) -> Result<ReadyForQueryStatus, Error> {
simple_query::batch_execute(self.inner(), query).await
}
/// Begins a new database transaction.
///
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
struct RollbackIfNotDone<'me> {
client: &'me Client,
done: bool,
}
impl Drop for RollbackIfNotDone<'_> {
fn drop(&mut self) {
if self.done {
return;
}
let buf = self.client.inner().with_buf(|buf| {
frontend::query("ROLLBACK", buf).unwrap();
buf.split().freeze()
});
let _ = self
.client
.inner()
.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
}
}
// This is done, as `Future` created by this method can be dropped after
// `RequestMessages` is synchronously send to the `Connection` by
// `batch_execute()`, but before `Responses` is asynchronously polled to
// completion. In that case `Transaction` won't be created and thus
// won't be rolled back.
{
let mut cleaner = RollbackIfNotDone {
client: self,
done: false,
};
self.batch_execute("BEGIN").await?;
cleaner.done = true;
}
Ok(Transaction::new(self))
}
/// Returns a builder for a transaction with custom settings.
///
/// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other
/// attributes.
pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
TransactionBuilder::new(self)
}
/// Constructs a cancellation token that can later be used to request cancellation of a query running on the
/// connection associated with this client.
pub fn cancel_token(&self) -> CancelToken {
CancelToken {
socket_config: Some(self.socket_config.clone()),
ssl_mode: self.ssl_mode,
process_id: self.process_id,
secret_key: self.secret_key,
}
}
/// Query for type information
pub async fn get_type(&self, oid: Oid) -> Result<Type, Error> {
crate::prepare::get_type(&self.inner, oid).await
}
/// Determines if the connection to the server has already closed.
///
/// In that case, all future queries will fail.
pub fn is_closed(&self) -> bool {
self.inner.sender.is_closed()
}
}
/// Opaque `Debug` output: only the type name is printed, none of the fields.
impl fmt::Debug for Client {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("Client");
        repr.finish()
    }
}