From 39887702660568ab9f856a7cdafb58e7c51463c3 Mon Sep 17 00:00:00 2001 From: Ning Sun Date: Mon, 9 Jan 2023 11:43:25 +0800 Subject: [PATCH] feat: add catalog name resolution for postgres and http interface (#810) * feat: add catalog name resolution for postgres and http interface * test: add tests for catalog resolution on http and postgres * feat: assign custom catalog for query * chore: order code for better readability --- src/datanode/src/instance/sql.rs | 5 +- src/datanode/src/tests/instance_test.rs | 7 ++- src/frontend/src/instance.rs | 5 +- src/frontend/src/instance/distributed.rs | 1 + src/query/src/sql.rs | 5 +- src/servers/src/http.rs | 34 ++++++++++++ src/servers/src/http/handler.rs | 26 ++------- src/servers/src/http/script.rs | 2 + src/servers/src/lib.rs | 47 +++++++++++++++++ src/servers/src/postgres.rs | 4 ++ src/servers/src/postgres/auth_handler.rs | 67 ++++++++++++++++-------- src/servers/src/postgres/handler.rs | 5 +- src/servers/tests/postgres/mod.rs | 31 ++++++++++- src/session/src/context.rs | 19 ++++++- tests-integration/tests/http.rs | 37 +++++++++++++ 15 files changed, 243 insertions(+), 52 deletions(-) diff --git a/src/datanode/src/instance/sql.rs b/src/datanode/src/instance/sql.rs index c7d49d8d56..83532e1f13 100644 --- a/src/datanode/src/instance/sql.rs +++ b/src/datanode/src/instance/sql.rs @@ -255,7 +255,10 @@ mod test { let bare = ObjectName(vec![my_table.into()]); let using_schema = "foo"; - let query_ctx = Arc::new(QueryContext::with_current_schema(using_schema.to_string())); + let query_ctx = Arc::new(QueryContext::with( + DEFAULT_CATALOG_NAME.to_owned(), + using_schema.to_string(), + )); let empty_ctx = Arc::new(QueryContext::new()); assert_eq!( diff --git a/src/datanode/src/tests/instance_test.rs b/src/datanode/src/tests/instance_test.rs index 7a3f1c5895..0596324b71 100644 --- a/src/datanode/src/tests/instance_test.rs +++ b/src/datanode/src/tests/instance_test.rs @@ -14,7 +14,7 @@ use std::sync::Arc; -use 
common_catalog::consts::DEFAULT_SCHEMA_NAME; +use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_query::Output; use common_recordbatch::util; use datatypes::data_type::ConcreteDataType; @@ -559,6 +559,9 @@ async fn execute_sql(instance: &MockInstance, sql: &str) -> Output { } async fn execute_sql_in_db(instance: &MockInstance, sql: &str, db: &str) -> Output { - let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string())); + let query_ctx = Arc::new(QueryContext::with( + DEFAULT_CATALOG_NAME.to_owned(), + db.to_string(), + )); instance.inner().execute_sql(sql, query_ctx).await.unwrap() } diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index 2fa7068100..972dc0b771 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -343,9 +343,12 @@ impl Instance { } fn handle_use(&self, db: String, query_ctx: QueryContextRef) -> Result { + let catalog = query_ctx.current_catalog(); + let catalog = catalog.as_deref().unwrap_or(DEFAULT_CATALOG_NAME); + ensure!( self.catalog_manager - .schema(DEFAULT_CATALOG_NAME, &db) + .schema(catalog, &db) .context(error::CatalogSnafu)? 
.is_some(), error::SchemaNotFoundSnafu { schema_info: &db } diff --git a/src/frontend/src/instance/distributed.rs b/src/frontend/src/instance/distributed.rs index c801d36c15..a093ad7965 100644 --- a/src/frontend/src/instance/distributed.rs +++ b/src/frontend/src/instance/distributed.rs @@ -226,6 +226,7 @@ impl DistInstance { /// Handles distributed database creation async fn handle_create_database(&self, expr: CreateDatabaseExpr) -> Result { let key = SchemaKey { + // TODO(sunng87): custom catalog catalog_name: DEFAULT_CATALOG_NAME.to_string(), schema_name: expr.database_name, }; diff --git a/src/query/src/sql.rs b/src/query/src/sql.rs index e8e95f1328..996a139d3d 100644 --- a/src/query/src/sql.rs +++ b/src/query/src/sql.rs @@ -130,8 +130,11 @@ pub fn show_tables( .current_schema() .unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string()) }; + // TODO(sunng87): move this function into query_ctx + let catalog = query_ctx.current_catalog(); + let catalog = catalog.as_deref().unwrap_or(DEFAULT_CATALOG_NAME); let schema = catalog_manager - .schema(DEFAULT_CATALOG_NAME, &schema) + .schema(catalog, &schema) .context(error::CatalogSnafu)? 
.context(error::SchemaNotFoundSnafu { schema })?; let tables = schema.table_names().context(error::CatalogSnafu)?; diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 559ef2f8f9..2dd64d760a 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -30,6 +30,7 @@ use axum::body::BoxBody; use axum::error_handling::HandleErrorLayer; use axum::response::{Html, Json}; use axum::{routing, BoxError, Extension, Router}; +use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_error::prelude::ErrorExt; use common_error::status_code::StatusCode; use common_query::Output; @@ -40,6 +41,7 @@ use futures::FutureExt; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::Value; +use session::context::QueryContext; use snafu::{ensure, ResultExt}; use tokio::sync::oneshot::{self, Sender}; use tokio::sync::Mutex; @@ -58,6 +60,38 @@ use crate::query_handler::{ }; use crate::server::Server; +/// create query context from database name information, catalog and schema are +/// resolved from the name +pub(crate) fn query_context_from_db( + query_handler: SqlQueryHandlerRef, + db: Option, +) -> std::result::Result, JsonResponse> { + if let Some(db) = &db { + let (catalog, schema) = super::parse_catalog_and_schema_from_client_database_name(db); + let catalog = catalog.unwrap_or(DEFAULT_CATALOG_NAME); + + match query_handler.is_valid_schema(catalog, schema) { + Ok(true) => Ok(Arc::new(QueryContext::with( + catalog.to_owned(), + schema.to_owned(), + ))), + Ok(false) => Err(JsonResponse::with_error( + format!("Database not found: {db}"), + StatusCode::DatabaseNotFound, + )), + Err(e) => Err(JsonResponse::with_error( + format!("Error checking database: {db}, {e}"), + StatusCode::Internal, + )), + } + } else { + Ok(Arc::new(QueryContext::with( + DEFAULT_CATALOG_NAME.to_owned(), + DEFAULT_SCHEMA_NAME.to_owned(), + ))) + } +} + const HTTP_API_VERSION: &str = "v1"; pub struct HttpServer { diff --git 
a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs index 14ecf80284..0bd481e1a2 100644 --- a/src/servers/src/http/handler.rs +++ b/src/servers/src/http/handler.rs @@ -13,18 +13,16 @@ // limitations under the License. use std::collections::HashMap; -use std::sync::Arc; use std::time::Instant; use aide::transform::TransformOperation; use axum::extract::{Json, Query, State}; use axum::Extension; -use common_catalog::consts::DEFAULT_CATALOG_NAME; use common_error::status_code::StatusCode; use common_telemetry::metric; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use session::context::{QueryContext, UserInfo}; +use session::context::UserInfo; use crate::http::{ApiState, JsonResponse}; @@ -45,26 +43,12 @@ pub async fn sql( let sql_handler = &state.sql_handler; let start = Instant::now(); let resp = if let Some(sql) = ¶ms.sql { - let query_ctx = Arc::new(QueryContext::new()); - if let Some(db) = ¶ms.database { - match sql_handler.is_valid_schema(DEFAULT_CATALOG_NAME, db) { - Ok(true) => query_ctx.set_current_schema(db), - Ok(false) => { - return Json(JsonResponse::with_error( - format!("Database not found: {db}"), - StatusCode::DatabaseNotFound, - )); - } - Err(e) => { - return Json(JsonResponse::with_error( - format!("Error checking database: {db}, {e}"), - StatusCode::Internal, - )); - } + match super::query_context_from_db(sql_handler.clone(), params.database) { + Ok(query_ctx) => { + JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await } + Err(resp) => resp, } - - JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await } else { JsonResponse::with_error( "sql parameter is required.".to_string(), diff --git a/src/servers/src/http/script.rs b/src/servers/src/http/script.rs index 97b6341f09..25738e065c 100644 --- a/src/servers/src/http/script.rs +++ b/src/servers/src/http/script.rs @@ -90,6 +90,8 @@ pub async fn run_script( json_err!("invalid name"); } + // TODO(sunng87): query_context and db 
name resolution + + let output = script_handler.execute_script(name.unwrap()).await; let resp = JsonResponse::from_output(vec![output]).await; diff --git a/src/servers/src/lib.rs b/src/servers/src/lib.rs index 9efba81a29..d34cd73951 100644 --- a/src/servers/src/lib.rs +++ b/src/servers/src/lib.rs @@ -39,3 +39,50 @@ pub enum Mode { Standalone, Distributed, } + +/// Attempt to parse catalog and schema from given database name +/// +/// The database name may come from different sources: +/// +/// - MySQL `schema` name in MySQL protocol login request: it's optional and user +/// can switch database using `USE` command +/// - Postgres `database` parameter in Postgres wire protocol, required +/// - HTTP RESTful API: the database parameter, optional +/// +/// When database name is provided, we attempt to parse catalog and schema from +/// it. We assume the format `[<catalog>-]<schema>`: +/// +/// - If `[<catalog>-]` part is not provided, we use whole database name as +/// schema name +/// - If `[<catalog>-]` is provided, we split database name with `-` and use +/// `<catalog>` and `<schema>`. 
+pub(crate) fn parse_catalog_and_schema_from_client_database_name(db: &str) -> (Option<&str>, &str) { + let parts = db.splitn(2, '-').collect::>(); + if parts.len() == 2 { + (Some(parts[0]), parts[1]) + } else { + (None, db) + } +} + +#[cfg(test)] +mod tests { + + #[test] + fn test_parse_catalog_and_schema_from_client_database_name() { + assert_eq!( + (None, "fullschema"), + super::parse_catalog_and_schema_from_client_database_name("fullschema") + ); + + assert_eq!( + (Some("catalog"), "schema"), + super::parse_catalog_and_schema_from_client_database_name("catalog-schema") + ); + + assert_eq!( + (Some("catalog"), "schema1-schema2"), + super::parse_catalog_and_schema_from_client_database_name("catalog-schema1-schema2") + ); + } +} diff --git a/src/servers/src/postgres.rs b/src/servers/src/postgres.rs index 4679b41da3..87507e998e 100644 --- a/src/servers/src/postgres.rs +++ b/src/servers/src/postgres.rs @@ -18,5 +18,9 @@ mod server; pub(crate) const METADATA_USER: &str = "user"; pub(crate) const METADATA_DATABASE: &str = "database"; +/// key to store our parsed catalog +pub(crate) const METADATA_CATALOG: &str = "catalog"; +/// key to store our parsed schema +pub(crate) const METADATA_SCHEMA: &str = "schema"; pub use server::PostgresServer; diff --git a/src/servers/src/postgres/auth_handler.rs b/src/servers/src/postgres/auth_handler.rs index ba5d316013..8259ed662e 100644 --- a/src/servers/src/postgres/auth_handler.rs +++ b/src/servers/src/postgres/auth_handler.rs @@ -151,31 +151,19 @@ impl StartupHandler for PgAuthStartupHandler { auth::save_startup_parameters_to_metadata(client, startup); // check if db is valid - let db_ref = client.metadata().get(super::METADATA_DATABASE); - if let Some(db) = db_ref { - if !self - .query_handler - .is_valid_schema(DEFAULT_CATALOG_NAME, db) - .map_err(|e| PgWireError::ApiError(Box::new(e)))? 
- { - send_error( - client, - "FATAL", - "3D000", - format!("Database not found: {db}"), - ) - .await?; + match resolve_db_info(client, self.query_handler.clone())? { + DbResolution::Resolved(catalog, schema) => { + client + .metadata_mut() + .insert(super::METADATA_CATALOG.to_owned(), catalog); + client + .metadata_mut() + .insert(super::METADATA_SCHEMA.to_owned(), schema); + } + DbResolution::NotFound(msg) => { + send_error(client, "FATAL", "3D000", msg).await?; return Ok(()); } - } else { - send_error( - client, - "FATAL", - "3D000", - "Database not specified".to_owned(), - ) - .await?; - return Ok(()); } if self.verifier.user_provider.is_some() { @@ -222,3 +210,36 @@ where client.close().await?; Ok(()) } + +enum DbResolution { + Resolved(String, String), + NotFound(String), +} + +/// A function extracted to resolve lifetime and readability issues: +fn resolve_db_info( + client: &mut C, + query_handler: SqlQueryHandlerRef, +) -> PgWireResult +where + C: ClientInfo + Unpin + Send, +{ + let db_ref = client.metadata().get(super::METADATA_DATABASE); + if let Some(db) = db_ref { + let (catalog, schema) = crate::parse_catalog_and_schema_from_client_database_name(db); + let catalog = catalog.unwrap_or(DEFAULT_CATALOG_NAME); + if query_handler + .is_valid_schema(catalog, schema) + .map_err(|e| PgWireError::ApiError(Box::new(e)))? 
+ { + Ok(DbResolution::Resolved( + catalog.to_owned(), + schema.to_owned(), + )) + } else { + Ok(DbResolution::NotFound(format!("Database not found: {db}"))) + } + } else { + Ok(DbResolution::NotFound("Database not specified".to_owned())) + } +} diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index c09b967644..5d65bac618 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -47,7 +47,10 @@ where C: ClientInfo, { let query_context = QueryContext::new(); - if let Some(current_schema) = client.metadata().get(super::METADATA_DATABASE) { + if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) { + query_context.set_current_catalog(current_catalog); + } + if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) { query_context.set_current_schema(current_schema); } diff --git a/src/servers/tests/postgres/mod.rs b/src/servers/tests/postgres/mod.rs index 8759452384..4dac19c4da 100644 --- a/src/servers/tests/postgres/mod.rs +++ b/src/servers/tests/postgres/mod.rs @@ -16,7 +16,7 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::{Duration, SystemTime}; -use common_catalog::consts::DEFAULT_SCHEMA_NAME; +use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_runtime::Builder as RuntimeBuilder; use rand::rngs::StdRng; use rand::Rng; @@ -249,6 +249,24 @@ async fn test_using_db() -> Result<()> { .unwrap(); let result = client.simple_query("SELECT uint32s FROM numbers").await; assert!(result.is_ok()); + + let client = create_connection_with_given_catalog_schema( + server_port, + DEFAULT_CATALOG_NAME, + DEFAULT_SCHEMA_NAME, + ) + .await; + assert!(client.is_ok()); + + let client = + create_connection_with_given_catalog_schema(server_port, "notfound", DEFAULT_SCHEMA_NAME) + .await; + assert!(client.is_err()); + + let client = + create_connection_with_given_catalog_schema(server_port, DEFAULT_CATALOG_NAME, "notfound") + 
.await; + assert!(client.is_err()); Ok(()) } @@ -330,6 +348,17 @@ async fn create_connection_with_given_db( Ok(client) } +async fn create_connection_with_given_catalog_schema( + port: u16, + catalog: &str, + schema: &str, +) -> std::result::Result { + let url = format!("host=127.0.0.1 port={port} connect_timeout=2 dbname={catalog}-{schema}"); + let (client, conn) = tokio_postgres::connect(&url, NoTls).await?; + tokio::spawn(conn); + Ok(client) +} + async fn create_connection_without_db(port: u16) -> std::result::Result { let url = format!("host=127.0.0.1 port={port} connect_timeout=2"); let (client, conn) = tokio_postgres::connect(&url, NoTls).await?; diff --git a/src/session/src/context.rs b/src/session/src/context.rs index a1c2086d47..1abe63f6b1 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -22,6 +22,7 @@ pub type QueryContextRef = Arc; pub type ConnInfoRef = Arc; pub struct QueryContext { + current_catalog: ArcSwapOption, current_schema: ArcSwapOption, } @@ -38,12 +39,14 @@ impl QueryContext { pub fn new() -> Self { Self { + current_catalog: ArcSwapOption::new(None), current_schema: ArcSwapOption::new(None), } } - pub fn with_current_schema(schema: String) -> Self { + pub fn with(catalog: String, schema: String) -> Self { Self { + current_catalog: ArcSwapOption::new(Some(Arc::new(catalog))), current_schema: ArcSwapOption::new(Some(Arc::new(schema))), } } @@ -52,6 +55,10 @@ impl QueryContext { self.current_schema.load().as_deref().cloned() } + pub fn current_catalog(&self) -> Option { + self.current_catalog.load().as_deref().cloned() + } + pub fn set_current_schema(&self, schema: &str) { let last = self.current_schema.swap(Some(Arc::new(schema.to_string()))); info!( @@ -59,6 +66,16 @@ impl QueryContext { schema, last ) } + + pub fn set_current_catalog(&self, catalog: &str) { + let last = self + .current_catalog + .swap(Some(Arc::new(catalog.to_string()))); + info!( + "set new session default catalog: {:?}, swap old: {:?}", + catalog, 
last + ) + } } pub const DEFAULT_USERNAME: &str = "greptime"; diff --git a/tests-integration/tests/http.rs b/tests-integration/tests/http.rs index f508863df4..7c34810d89 100644 --- a/tests-integration/tests/http.rs +++ b/tests-integration/tests/http.rs @@ -221,6 +221,43 @@ pub async fn test_sql_api(store_type: StorageType) { let body = serde_json::from_str::(&res.text().await).unwrap(); assert_eq!(body.code(), ErrorCode::DatabaseNotFound as u32); + // test catalog-schema given + let res = client + .get("/v1/sql?database=greptime-public&sql=select cpu, ts from demo limit 1") + .send() + .await; + assert_eq!(res.status(), StatusCode::OK); + + let body = serde_json::from_str::(&res.text().await).unwrap(); + assert!(body.success()); + assert!(body.execution_time_ms().is_some()); + let outputs = body.output().unwrap(); + assert_eq!(outputs.len(), 1); + assert_eq!( + outputs[0], + serde_json::from_value::(json!({ + "records":{"schema":{"column_schemas":[{"name":"cpu","data_type":"Float64"},{"name":"ts","data_type":"TimestampMillisecond"}]},"rows":[[66.6,0]]} + })).unwrap() + ); + + // test invalid catalog + let res = client + .get("/v1/sql?database=notfound2-schema&sql=select cpu, ts from demo limit 1") + .send() + .await; + assert_eq!(res.status(), StatusCode::OK); + let body = serde_json::from_str::(&res.text().await).unwrap(); + assert_eq!(body.code(), ErrorCode::Internal as u32); + + // test invalid schema + let res = client + .get("/v1/sql?database=greptime-schema&sql=select cpu, ts from demo limit 1") + .send() + .await; + assert_eq!(res.status(), StatusCode::OK); + let body = serde_json::from_str::(&res.text().await).unwrap(); + assert_eq!(body.code(), ErrorCode::DatabaseNotFound as u32); + guard.remove_all().await; }