diff --git a/Cargo.lock b/Cargo.lock index 8c92e8d523..866d173e86 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7024,6 +7024,7 @@ dependencies = [ "catalog", "client", "common-catalog", + "common-error", "common-runtime", "common-telemetry", "datanode", diff --git a/src/common/error/src/status_code.rs b/src/common/error/src/status_code.rs index 3bfb42080e..8f8a576e8d 100644 --- a/src/common/error/src/status_code.rs +++ b/src/common/error/src/status_code.rs @@ -51,6 +51,7 @@ pub enum StatusCode { TableNotFound = 4001, TableColumnNotFound = 4002, TableColumnExists = 4003, + DatabaseNotFound = 4004, // ====== End of catalog related status code ======= // ====== Begin of storage related status code ===== diff --git a/src/datanode/src/instance/sql.rs b/src/datanode/src/instance/sql.rs index 4b0fc0ed79..41269882ec 100644 --- a/src/datanode/src/instance/sql.rs +++ b/src/datanode/src/instance/sql.rs @@ -228,6 +228,13 @@ impl SqlQueryHandler for Instance { }) .context(servers::error::ExecuteStatementSnafu) } + + fn is_valid_schema(&self, catalog: &str, schema: &str) -> servers::error::Result { + self.catalog_manager + .schema(catalog, schema) + .map(|s| s.is_some()) + .context(servers::error::CatalogSnafu) + } } #[cfg(test)] diff --git a/src/frontend/src/instance.rs b/src/frontend/src/instance.rs index babc06a76f..19d49a611d 100644 --- a/src/frontend/src/instance.rs +++ b/src/frontend/src/instance.rs @@ -44,7 +44,7 @@ use distributed::DistInstance; use meta_client::client::{MetaClient, MetaClientBuilder}; use meta_client::MetaClientOpts; use servers::query_handler::{ - CatalogHandler, GrpcAdminHandler, GrpcAdminHandlerRef, GrpcQueryHandler, GrpcQueryHandlerRef, + GrpcAdminHandler, GrpcAdminHandlerRef, GrpcQueryHandler, GrpcQueryHandlerRef, InfluxdbLineProtocolHandler, OpentsdbProtocolHandler, PrometheusProtocolHandler, ScriptHandler, ScriptHandlerRef, SqlQueryHandler, SqlQueryHandlerRef, }; @@ -79,7 +79,6 @@ pub trait FrontendInstance: + InfluxdbLineProtocolHandler + PrometheusProtocolHandler + ScriptHandler - + CatalogHandler + Send + Sync + 'static @@ -594,6 +593,13 @@ impl SqlQueryHandler for Instance { ) -> server_error::Result { self.query_statement(stmt, query_ctx).await } + + fn is_valid_schema(&self, catalog: &str, schema: &str) -> server_error::Result { + self.catalog_manager + .schema(catalog, schema) + .map(|s| s.is_some()) + .context(server_error::CatalogSnafu) + } } #[async_trait] @@ -664,15 +670,6 @@ impl GrpcAdminHandler for Instance { } } -impl CatalogHandler for Instance { - fn is_valid_schema(&self, catalog: &str, schema: &str) -> server_error::Result { - self.catalog_manager - .schema(catalog, schema) - .map(|s| s.is_some()) - .context(server_error::CatalogSnafu) - } -} - #[cfg(test)] mod tests { use std::assert_matches::assert_matches; diff --git a/src/frontend/src/instance/distributed.rs b/src/frontend/src/instance/distributed.rs index de09c1b4d3..8205a8cd0e 100644 --- a/src/frontend/src/instance/distributed.rs +++ b/src/frontend/src/instance/distributed.rs @@ -23,7 +23,7 @@ use api::v1::{ }; use async_trait::async_trait; use catalog::helper::{SchemaKey, SchemaValue, TableGlobalKey, TableGlobalValue}; -use catalog::CatalogList; +use catalog::{CatalogList, CatalogManager}; use chrono::DateTime; use client::admin::{admin_result_to_output, Admin}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; @@ -357,6 +357,13 @@ impl SqlQueryHandler for DistInstance { .map_err(BoxedError::new) .context(server_error::ExecuteStatementSnafu) } + + fn is_valid_schema(&self, catalog: &str, schema: &str) -> server_error::Result { + self.catalog_manager + .schema(catalog, schema) + .map(|s| s.is_some()) + .context(server_error::CatalogSnafu) + } } #[async_trait] diff --git a/src/frontend/src/server.rs b/src/frontend/src/server.rs index 6c04860192..d3c55b8c97 100644 --- a/src/frontend/src/server.rs +++ b/src/frontend/src/server.rs @@ -98,7 +98,6 @@ impl Services { ); let pg_server = Box::new(PostgresServer::new( - instance.clone(), instance.clone(), opts.tls.clone(), pg_io_runtime, diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 955c67c9cb..263b8b1193 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -554,6 +554,10 @@ mod test { ) -> Result { unimplemented!() } + + fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result { + Ok(true) + } } fn timeout() -> TimeoutLayer { diff --git a/src/servers/src/http/handler.rs b/src/servers/src/http/handler.rs index d148851be4..419b0f891b 100644 --- a/src/servers/src/http/handler.rs +++ b/src/servers/src/http/handler.rs @@ -19,6 +19,7 @@ use std::time::Instant; use aide::transform::TransformOperation; use axum::extract::{Json, Query, State}; use axum::Extension; +use common_catalog::consts::DEFAULT_CATALOG_NAME; use common_error::status_code::StatusCode; use common_telemetry::metric; use schemars::JsonSchema; @@ -46,8 +47,22 @@ pub async fn sql( let start = Instant::now(); let resp = if let Some(sql) = ¶ms.sql { let query_ctx = Arc::new(QueryContext::new()); - if let Some(db) = params.database { - query_ctx.set_current_schema(db.as_ref()); + if let Some(db) = ¶ms.database { + match sql_handler.is_valid_schema(DEFAULT_CATALOG_NAME, db) { + Ok(true) => query_ctx.set_current_schema(db), + Ok(false) => { + return Json(JsonResponse::with_error( + format!("Database not found: {}", db), + StatusCode::DatabaseNotFound, + )); + } + Err(e) => { + return Json(JsonResponse::with_error( + format!("Error checking database: {}, {}", db, e), + StatusCode::Internal, + )); + } + } } JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await diff --git a/src/servers/src/postgres/auth_handler.rs b/src/servers/src/postgres/auth_handler.rs index 70a30974f1..22ea5798ba 100644 --- a/src/servers/src/postgres/auth_handler.rs +++ b/src/servers/src/postgres/auth_handler.rs @@ -29,7 +29,7 @@ use snafu::ResultExt; use crate::auth::{Identity, Password, UserProviderRef}; use crate::error; use crate::error::Result; -use crate::query_handler::CatalogHandlerRef; +use crate::query_handler::SqlQueryHandlerRef; struct PgPwdVerifier { user_provider: Option, @@ -110,20 +110,20 @@ pub struct PgAuthStartupHandler { verifier: PgPwdVerifier, param_provider: GreptimeDBStartupParameters, force_tls: bool, - catalog_handler: CatalogHandlerRef, + query_handler: SqlQueryHandlerRef, } impl PgAuthStartupHandler { pub fn new( user_provider: Option, force_tls: bool, - catalog_handler: CatalogHandlerRef, + query_handler: SqlQueryHandlerRef, ) -> Self { PgAuthStartupHandler { verifier: PgPwdVerifier { user_provider }, param_provider: GreptimeDBStartupParameters::new(), force_tls, - catalog_handler, + query_handler, } } } @@ -154,7 +154,7 @@ impl StartupHandler for PgAuthStartupHandler { let db_ref = client.metadata().get(super::METADATA_DATABASE); if let Some(db) = db_ref { if !self - .catalog_handler + .query_handler .is_valid_schema(DEFAULT_CATALOG_NAME, db) .map_err(|e| PgWireError::ApiError(Box::new(e)))? { diff --git a/src/servers/src/postgres/server.rs b/src/servers/src/postgres/server.rs index 09b2854c4c..6decbe9da6 100644 --- a/src/servers/src/postgres/server.rs +++ b/src/servers/src/postgres/server.rs @@ -29,7 +29,7 @@ use crate::auth::UserProviderRef; use crate::error::Result; use crate::postgres::auth_handler::PgAuthStartupHandler; use crate::postgres::handler::PostgresServerHandler; -use crate::query_handler::{CatalogHandlerRef, SqlQueryHandlerRef}; +use crate::query_handler::SqlQueryHandlerRef; use crate::server::{AbortableStream, BaseTcpServer, Server}; use crate::tls::TlsOption; @@ -44,16 +44,15 @@ impl PostgresServer { /// Creates a new Postgres server with provided query_handler and async runtime pub fn new( query_handler: SqlQueryHandlerRef, - catalog_handler: CatalogHandlerRef, tls: TlsOption, io_runtime: Arc, user_provider: Option, ) -> PostgresServer { - let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler)); + let postgres_handler = Arc::new(PostgresServerHandler::new(query_handler.clone())); let startup_handler = Arc::new(PgAuthStartupHandler::new( user_provider, tls.should_force_tls(), - catalog_handler, + query_handler, )); PostgresServer { base_server: BaseTcpServer::create_server("Postgres", io_runtime), diff --git a/src/servers/src/query_handler.rs b/src/servers/src/query_handler.rs index 1ef549c763..2c7d47ce40 100644 --- a/src/servers/src/query_handler.rs +++ b/src/servers/src/query_handler.rs @@ -43,7 +43,6 @@ pub type OpentsdbProtocolHandlerRef = Arc; pub type PrometheusProtocolHandlerRef = Arc; pub type ScriptHandlerRef = Arc; -pub type CatalogHandlerRef = Arc; #[async_trait] pub trait SqlQueryHandler { @@ -54,6 +53,9 @@ pub trait SqlQueryHandler { stmt: Statement, query_ctx: QueryContextRef, ) -> Result; + + /// check if schema is valid + fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result; } #[async_trait] @@ -101,8 +103,3 @@ pub trait PrometheusProtocolHandler { /// Handling push gateway requests async fn ingest_metrics(&self, metrics: Metrics) -> Result<()>; } - -pub trait CatalogHandler { - /// check if schema is valid - fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result; -} diff --git a/src/servers/tests/http/influxdb_test.rs b/src/servers/tests/http/influxdb_test.rs index 547bb90900..3fffd76e6f 100644 --- a/src/servers/tests/http/influxdb_test.rs +++ b/src/servers/tests/http/influxdb_test.rs @@ -57,6 +57,10 @@ impl SqlQueryHandler for DummyInstance { ) -> Result { unimplemented!() } + + fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result { + Ok(true) + } } fn make_test_app(tx: mpsc::Sender<(String, String)>) -> Router { diff --git a/src/servers/tests/http/opentsdb_test.rs b/src/servers/tests/http/opentsdb_test.rs index 15ade25fab..ffb76d8f89 100644 --- a/src/servers/tests/http/opentsdb_test.rs +++ b/src/servers/tests/http/opentsdb_test.rs @@ -56,6 +56,10 @@ impl SqlQueryHandler for DummyInstance { ) -> Result { unimplemented!() } + + fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result { + Ok(true) + } } fn make_test_app(tx: mpsc::Sender) -> Router { diff --git a/src/servers/tests/http/prometheus_test.rs b/src/servers/tests/http/prometheus_test.rs index 9a895e3bc4..0d8e72f3c8 100644 --- a/src/servers/tests/http/prometheus_test.rs +++ b/src/servers/tests/http/prometheus_test.rs @@ -81,6 +81,10 @@ impl SqlQueryHandler for DummyInstance { ) -> Result { unimplemented!() } + + fn is_valid_schema(&self, _catalog: &str, _schema: &str) -> Result { + Ok(true) + } } fn make_test_app(tx: mpsc::Sender<(String, Vec)>) -> Router { diff --git a/src/servers/tests/mod.rs b/src/servers/tests/mod.rs index d8d9bc974d..0105f7d3f1 100644 --- a/src/servers/tests/mod.rs +++ b/src/servers/tests/mod.rs @@ -23,7 +23,7 @@ use common_query::Output; use query::{QueryEngineFactory, QueryEngineRef}; use servers::error::Result; use servers::query_handler::{ - CatalogHandler, ScriptHandler, ScriptHandlerRef, SqlQueryHandler, SqlQueryHandlerRef, + ScriptHandler, ScriptHandlerRef, SqlQueryHandler, SqlQueryHandlerRef, }; use table::test_util::MemTable; @@ -67,6 +67,10 @@ impl SqlQueryHandler for DummyInstance { ) -> Result { unimplemented!() } + + fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result { + Ok(catalog == DEFAULT_CATALOG_NAME && schema == DEFAULT_SCHEMA_NAME) + } } #[async_trait] @@ -92,12 +96,6 @@ impl ScriptHandler for DummyInstance { } } -impl CatalogHandler for DummyInstance { - fn is_valid_schema(&self, catalog: &str, schema: &str) -> Result { - Ok(catalog == DEFAULT_CATALOG_NAME && schema == DEFAULT_SCHEMA_NAME) - } -} - fn create_testing_instance(table: MemTable) -> DummyInstance { let table_name = table.table_name().to_string(); let table = Arc::new(table); diff --git a/src/servers/tests/postgres/mod.rs b/src/servers/tests/postgres/mod.rs index 8c7db2f056..556fb52875 100644 --- a/src/servers/tests/postgres/mod.rs +++ b/src/servers/tests/postgres/mod.rs @@ -55,7 +55,6 @@ fn create_postgres_server( }; Ok(Box::new(PostgresServer::new( - instance.clone(), instance, tls, io_runtime, diff --git a/tests-integration/Cargo.toml b/tests-integration/Cargo.toml index 1a7107fc8f..37a58b6f39 100644 --- a/tests-integration/Cargo.toml +++ b/tests-integration/Cargo.toml @@ -11,6 +11,7 @@ axum-test-helper = { git = "https://github.com/sunng87/axum-test-helper.git", br catalog = { path = "../src/catalog" } client = { path = "../src/client" } common-catalog = { path = "../src/common/catalog" } +common-error = { path = "../src/common/error" } common-runtime = { path = "../src/common/runtime" } common-telemetry = { path = "../src/common/telemetry" } datanode = { path = "../src/datanode" } diff --git a/tests-integration/tests/http.rs b/tests-integration/tests/http.rs index ecdfac3b62..e85ed738ef 100644 --- a/tests-integration/tests/http.rs +++ b/tests-integration/tests/http.rs @@ -14,6 +14,7 @@ use axum::http::StatusCode; use axum_test_helper::TestClient; +use common_error::status_code::StatusCode as ErrorCode; use serde_json::json; use servers::http::handler::HealthResponse; use servers::http::{JsonOutput, JsonResponse}; @@ -192,6 +193,34 @@ pub async fn test_sql_api(store_type: StorageType) { assert!(body.execution_time_ms().is_some()); assert!(body.error().unwrap().contains("not found")); + // test database given + let res = client + .get("/v1/sql?database=public&sql=select cpu, ts from demo limit 1") + .send() + .await; + assert_eq!(res.status(), StatusCode::OK); + + let body = serde_json::from_str::(&res.text().await).unwrap(); + assert!(body.success()); + assert!(body.execution_time_ms().is_some()); + let outputs = body.output().unwrap(); + assert_eq!(outputs.len(), 1); + assert_eq!( + outputs[0], + serde_json::from_value::(json!({ + "records":{"schema":{"column_schemas":[{"name":"cpu","data_type":"Float64"},{"name":"ts","data_type":"TimestampMillisecond"}]},"rows":[[66.6,0]]} + })).unwrap() + ); + + // test database not found + let res = client + .get("/v1/sql?database=notfound&sql=select cpu, ts from demo limit 1") + .send() + .await; + assert_eq!(res.status(), StatusCode::OK); + let body = serde_json::from_str::(&res.text().await).unwrap(); + assert_eq!(body.code(), ErrorCode::DatabaseNotFound as u32); + guard.remove_all().await; }