feat: inject current database/schema into query context for postgres protocol (#685)

* feat: inject current database/schema into query context

* test: avoid duplicate server setup
This commit is contained in:
Ning Sun
2022-12-02 11:49:39 +08:00
committed by GitHub
parent 13d51250ba
commit 0599465685
2 changed files with 57 additions and 21 deletions

View File

@@ -42,14 +42,27 @@ impl PostgresServerHandler {
}
}
const CLIENT_METADATA_DATABASE: &str = "database";
fn query_context_from_client_info<C>(client: &C) -> Arc<QueryContext>
where
C: ClientInfo,
{
let query_context = QueryContext::new();
if let Some(current_schema) = client.metadata().get(CLIENT_METADATA_DATABASE) {
query_context.set_current_schema(current_schema);
}
Arc::new(query_context)
}
#[async_trait]
impl SimpleQueryHandler for PostgresServerHandler {
async fn do_query<C>(&self, _client: &C, query: &str) -> PgWireResult<Vec<Response>>
async fn do_query<C>(&self, client: &C, query: &str) -> PgWireResult<Vec<Response>>
where
C: ClientInfo + Unpin + Send + Sync,
{
// TODO(LFC): Sessions in pg server.
let query_ctx = Arc::new(QueryContext::new());
let query_ctx = query_context_from_client_info(client);
let output = self
.query_handler
.do_query(query, query_ctx)

View File

@@ -16,6 +16,7 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
use common_runtime::Builder as RuntimeBuilder;
use rand::rngs::StdRng;
use rand::Rng;
@@ -80,7 +81,6 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
common_telemetry::init_default_ut_logging();
let table = MemTable::default_numbers_table();
let postgres_server = create_postgres_server(table, with_pwd, Default::default())?;
let result = postgres_server.shutdown().await;
assert!(result
@@ -136,14 +136,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_query_pg_concurrently() -> Result<()> {
common_telemetry::init_default_ut_logging();
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table, false, Default::default())?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = pg_server.start(listening).await.unwrap();
let server_port = server_addr.port();
let server_port = start_test_server(Default::default()).await?;
let threads = 4;
let expect_executed_queries_per_worker = 300;
@@ -211,13 +204,7 @@ async fn test_server_secure_require_client_plain() -> Result<()> {
cert_path: "tests/ssl/server.crt".to_owned(),
key_path: "tests/ssl/server.key".to_owned(),
});
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table, false, server_tls)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = pg_server.start(listening).await.unwrap();
let server_port = server_addr.port();
let server_port = start_test_server(server_tls).await?;
let r = create_plain_connection(server_port, false).await;
assert!(r.is_err());
Ok(())
@@ -238,12 +225,35 @@ async fn test_server_secure_require_client_secure() -> Result<()> {
Ok(())
}
async fn do_simple_query(server_tls: Arc<TlsOption>, client_tls: bool) -> Result<()> {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_using_db() -> Result<()> {
let server_port = start_test_server(Arc::new(TlsOption::default())).await?;
let client = create_connection_with_given_db(server_port, "testdb")
.await
.unwrap();
let result = client.simple_query("SELECT uint32s FROM numbers").await;
assert!(result.is_err());
let client = create_connection_with_given_db(server_port, DEFAULT_SCHEMA_NAME)
.await
.unwrap();
let result = client.simple_query("SELECT uint32s FROM numbers").await;
assert!(result.is_ok());
Ok(())
}
async fn start_test_server(server_tls: Arc<TlsOption>) -> Result<u16> {
common_telemetry::init_default_ut_logging();
let table = MemTable::default_numbers_table();
let pg_server = create_postgres_server(table, false, server_tls)?;
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
let server_addr = pg_server.start(listening).await.unwrap();
let server_port = server_addr.port();
Ok(server_addr.port())
}
async fn do_simple_query(server_tls: Arc<TlsOption>, client_tls: bool) -> Result<()> {
let server_port = start_test_server(server_tls).await?;
if !client_tls {
let client = create_plain_connection(server_port, false).await.unwrap();
@@ -303,6 +313,19 @@ async fn create_plain_connection(
Ok(client)
}
async fn create_connection_with_given_db(
port: u16,
db: &str,
) -> std::result::Result<Client, PgError> {
let url = format!(
"host=127.0.0.1 port={} connect_timeout=2 dbname={}",
port, db
);
let (client, conn) = tokio_postgres::connect(&url, NoTls).await?;
tokio::spawn(conn);
Ok(client)
}
fn resolve_result(resp: &SimpleQueryMessage, col_index: usize) -> Option<&str> {
match resp {
&SimpleQueryMessage::Row(ref r) => r.get(col_index),