mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-06-01 12:50:40 +00:00
feat: inject current database/schema into query context for postgres protocol (#685)
* feat: inject current database/schema into query context * test: avoid duplicate server setup
This commit is contained in:
@@ -42,14 +42,27 @@ impl PostgresServerHandler {
|
||||
}
|
||||
}
|
||||
|
||||
const CLIENT_METADATA_DATABASE: &str = "database";
|
||||
|
||||
fn query_context_from_client_info<C>(client: &C) -> Arc<QueryContext>
|
||||
where
|
||||
C: ClientInfo,
|
||||
{
|
||||
let query_context = QueryContext::new();
|
||||
if let Some(current_schema) = client.metadata().get(CLIENT_METADATA_DATABASE) {
|
||||
query_context.set_current_schema(current_schema);
|
||||
}
|
||||
|
||||
Arc::new(query_context)
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SimpleQueryHandler for PostgresServerHandler {
|
||||
async fn do_query<C>(&self, _client: &C, query: &str) -> PgWireResult<Vec<Response>>
|
||||
async fn do_query<C>(&self, client: &C, query: &str) -> PgWireResult<Vec<Response>>
|
||||
where
|
||||
C: ClientInfo + Unpin + Send + Sync,
|
||||
{
|
||||
// TODO(LFC): Sessions in pg server.
|
||||
let query_ctx = Arc::new(QueryContext::new());
|
||||
let query_ctx = query_context_from_client_info(client);
|
||||
let output = self
|
||||
.query_handler
|
||||
.do_query(query, query_ctx)
|
||||
|
||||
@@ -16,6 +16,7 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
|
||||
use common_runtime::Builder as RuntimeBuilder;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::Rng;
|
||||
@@ -80,7 +81,6 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let postgres_server = create_postgres_server(table, with_pwd, Default::default())?;
|
||||
let result = postgres_server.shutdown().await;
|
||||
assert!(result
|
||||
@@ -136,14 +136,7 @@ async fn test_shutdown_pg_server(with_pwd: bool) -> Result<()> {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_query_pg_concurrently() -> Result<()> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
|
||||
let pg_server = create_postgres_server(table, false, Default::default())?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = pg_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
let server_port = start_test_server(Default::default()).await?;
|
||||
|
||||
let threads = 4;
|
||||
let expect_executed_queries_per_worker = 300;
|
||||
@@ -211,13 +204,7 @@ async fn test_server_secure_require_client_plain() -> Result<()> {
|
||||
cert_path: "tests/ssl/server.crt".to_owned(),
|
||||
key_path: "tests/ssl/server.key".to_owned(),
|
||||
});
|
||||
|
||||
let table = MemTable::default_numbers_table();
|
||||
let pg_server = create_postgres_server(table, false, server_tls)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = pg_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
|
||||
let server_port = start_test_server(server_tls).await?;
|
||||
let r = create_plain_connection(server_port, false).await;
|
||||
assert!(r.is_err());
|
||||
Ok(())
|
||||
@@ -238,12 +225,35 @@ async fn test_server_secure_require_client_secure() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn do_simple_query(server_tls: Arc<TlsOption>, client_tls: bool) -> Result<()> {
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_using_db() -> Result<()> {
|
||||
let server_port = start_test_server(Arc::new(TlsOption::default())).await?;
|
||||
|
||||
let client = create_connection_with_given_db(server_port, "testdb")
|
||||
.await
|
||||
.unwrap();
|
||||
let result = client.simple_query("SELECT uint32s FROM numbers").await;
|
||||
assert!(result.is_err());
|
||||
|
||||
let client = create_connection_with_given_db(server_port, DEFAULT_SCHEMA_NAME)
|
||||
.await
|
||||
.unwrap();
|
||||
let result = client.simple_query("SELECT uint32s FROM numbers").await;
|
||||
assert!(result.is_ok());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn start_test_server(server_tls: Arc<TlsOption>) -> Result<u16> {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let table = MemTable::default_numbers_table();
|
||||
let pg_server = create_postgres_server(table, false, server_tls)?;
|
||||
let listening = "127.0.0.1:0".parse::<SocketAddr>().unwrap();
|
||||
let server_addr = pg_server.start(listening).await.unwrap();
|
||||
let server_port = server_addr.port();
|
||||
Ok(server_addr.port())
|
||||
}
|
||||
|
||||
async fn do_simple_query(server_tls: Arc<TlsOption>, client_tls: bool) -> Result<()> {
|
||||
let server_port = start_test_server(server_tls).await?;
|
||||
|
||||
if !client_tls {
|
||||
let client = create_plain_connection(server_port, false).await.unwrap();
|
||||
@@ -303,6 +313,19 @@ async fn create_plain_connection(
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
async fn create_connection_with_given_db(
|
||||
port: u16,
|
||||
db: &str,
|
||||
) -> std::result::Result<Client, PgError> {
|
||||
let url = format!(
|
||||
"host=127.0.0.1 port={} connect_timeout=2 dbname={}",
|
||||
port, db
|
||||
);
|
||||
let (client, conn) = tokio_postgres::connect(&url, NoTls).await?;
|
||||
tokio::spawn(conn);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
fn resolve_result(resp: &SimpleQueryMessage, col_index: usize) -> Option<&str> {
|
||||
match resp {
|
||||
&SimpleQueryMessage::Row(ref r) => r.get(col_index),
|
||||
|
||||
Reference in New Issue
Block a user