diff --git a/src/common/catalog/src/lib.rs b/src/common/catalog/src/lib.rs index 6363958862..1527b7f4a0 100644 --- a/src/common/catalog/src/lib.rs +++ b/src/common/catalog/src/lib.rs @@ -32,6 +32,33 @@ pub fn build_db_string(catalog: &str, schema: &str) -> String { } } +/// Attempt to parse catalog and schema from given database name +/// +/// The database name may come from different sources: +/// +/// - MySQL `schema` name in MySQL protocol login request: it's optional and user +/// and switch database using `USE` command +/// - Postgres `database` parameter in Postgres wire protocol, required +/// - HTTP RESTful API: the database parameter, optional +/// - gRPC: the dbname field in header, optional but has a higher priority than +/// original catalog/schema +/// +/// When database name is provided, we attempt to parse catalog and schema from +/// it. We assume the format `[<catalog>-]<schema>`: +/// +/// - If `[<catalog>-]` part is not provided, we use whole database name as +/// schema name +/// - if `[<catalog>-]` is provided, we split database name with `-` and use +/// `<catalog>` and `<schema>`. 
+pub fn parse_catalog_and_schema_from_db_string(db: &str) -> (&str, &str) { + let parts = db.splitn(2, '-').collect::<Vec<&str>>(); + if parts.len() == 2 { + (parts[0], parts[1]) + } else { + (DEFAULT_CATALOG_NAME, db) + } +} + #[cfg(test)] mod tests { use super::*; @@ -41,4 +68,22 @@ mod tests { assert_eq!("test", build_db_string(DEFAULT_CATALOG_NAME, "test")); assert_eq!("a0b1c2d3-test", build_db_string("a0b1c2d3", "test")); } + + #[test] + fn test_parse_catalog_and_schema() { + assert_eq!( + (DEFAULT_CATALOG_NAME, "fullschema"), + parse_catalog_and_schema_from_db_string("fullschema") + ); + + assert_eq!( + ("catalog", "schema"), + parse_catalog_and_schema_from_db_string("catalog-schema") + ); + + assert_eq!( + ("catalog", "schema1-schema2"), + parse_catalog_and_schema_from_db_string("catalog-schema1-schema2") + ); + } } diff --git a/src/datanode/src/instance.rs b/src/datanode/src/instance.rs index 3196551489..3f2def6626 100644 --- a/src/datanode/src/instance.rs +++ b/src/datanode/src/instance.rs @@ -44,7 +44,7 @@ use mito::engine::MitoEngine; use object_store::{util, ObjectStore}; use query::query_engine::{QueryEngineFactory, QueryEngineRef}; use servers::Mode; -use session::context::QueryContext; +use session::context::QueryContextBuilder; use snafu::prelude::*; use storage::compaction::{CompactionHandler, CompactionSchedulerRef}; use storage::config::EngineConfig as StorageEngineConfig; @@ -379,14 +379,14 @@ impl Instance { }) }) .collect::<Vec<_>>(); - let flush_result = futures::future::try_join_all( - flush_requests - .into_iter() - .map(|request| self.sql_handler.execute(request, QueryContext::arc())), - ) - .await - .map_err(BoxedError::new) - .context(ShutdownInstanceSnafu); + let flush_result = + futures::future::try_join_all(flush_requests.into_iter().map(|request| { + self.sql_handler + .execute(request, QueryContextBuilder::default().build()) + })) + .await + .map_err(BoxedError::new) + .context(ShutdownInstanceSnafu); info!("Flushed all tables result: {}", 
flush_result.is_ok()); let _ = flush_result?; diff --git a/src/datanode/src/instance/grpc.rs b/src/datanode/src/instance/grpc.rs index 0abbc8a296..f02ff845ea 100644 --- a/src/datanode/src/instance/grpc.rs +++ b/src/datanode/src/instance/grpc.rs @@ -191,13 +191,13 @@ impl Instance { name: "DdlRequest.expr", })?; match expr { - DdlExpr::CreateTable(expr) => self.handle_create(expr).await, - DdlExpr::Alter(expr) => self.handle_alter(expr).await, + DdlExpr::CreateTable(expr) => self.handle_create(expr, query_ctx).await, + DdlExpr::Alter(expr) => self.handle_alter(expr, query_ctx).await, DdlExpr::CreateDatabase(expr) => self.handle_create_database(expr, query_ctx).await, - DdlExpr::DropTable(expr) => self.handle_drop_table(expr).await, - DdlExpr::FlushTable(expr) => self.handle_flush_table(expr).await, - DdlExpr::CompactTable(expr) => self.handle_compact_table(expr).await, - DdlExpr::TruncateTable(expr) => self.handle_truncate_table(expr).await, + DdlExpr::DropTable(expr) => self.handle_drop_table(expr, query_ctx).await, + DdlExpr::FlushTable(expr) => self.handle_flush_table(expr, query_ctx).await, + DdlExpr::CompactTable(expr) => self.handle_compact_table(expr, query_ctx).await, + DdlExpr::TruncateTable(expr) => self.handle_truncate_table(expr, query_ctx).await, } } } diff --git a/src/datanode/src/server/grpc.rs b/src/datanode/src/server/grpc.rs index d856e947f6..4159d1e8ee 100644 --- a/src/datanode/src/server/grpc.rs +++ b/src/datanode/src/server/grpc.rs @@ -20,7 +20,7 @@ use common_catalog::format_full_table_name; use common_grpc_expr::{alter_expr_to_request, create_expr_to_request}; use common_query::Output; use common_telemetry::info; -use session::context::QueryContext; +use session::context::QueryContextRef; use snafu::prelude::*; use table::requests::{ CompactTableRequest, DropTableRequest, FlushTableRequest, TruncateTableRequest, @@ -35,7 +35,11 @@ use crate::sql::SqlRequest; impl Instance { /// Handle gRPC create table requests. 
- pub(crate) async fn handle_create(&self, expr: CreateTableExpr) -> Result { + pub(crate) async fn handle_create( + &self, + expr: CreateTableExpr, + ctx: QueryContextRef, + ) -> Result { let table_name = format!( "{}.{}.{}", expr.catalog_name, expr.schema_name, expr.table_name @@ -69,11 +73,15 @@ impl Instance { .context(CreateExprToRequestSnafu)?; self.sql_handler() - .execute(SqlRequest::CreateTable(request), QueryContext::arc()) + .execute(SqlRequest::CreateTable(request), ctx) .await } - pub(crate) async fn handle_alter(&self, expr: AlterExpr) -> Result { + pub(crate) async fn handle_alter( + &self, + expr: AlterExpr, + ctx: QueryContextRef, + ) -> Result { let table_id = match expr.table_id.as_ref() { None => { self.catalog_manager @@ -96,11 +104,15 @@ impl Instance { let request = alter_expr_to_request(table_id, expr).context(AlterExprToRequestSnafu)?; self.sql_handler() - .execute(SqlRequest::Alter(request), QueryContext::arc()) + .execute(SqlRequest::Alter(request), ctx) .await } - pub(crate) async fn handle_drop_table(&self, expr: DropTableExpr) -> Result { + pub(crate) async fn handle_drop_table( + &self, + expr: DropTableExpr, + ctx: QueryContextRef, + ) -> Result { let table = self .catalog_manager .table(&expr.catalog_name, &expr.schema_name, &expr.table_name) @@ -121,11 +133,15 @@ impl Instance { table_id: table.table_info().ident.table_id, }; self.sql_handler() - .execute(SqlRequest::DropTable(req), QueryContext::arc()) + .execute(SqlRequest::DropTable(req), ctx) .await } - pub(crate) async fn handle_flush_table(&self, expr: FlushTableExpr) -> Result { + pub(crate) async fn handle_flush_table( + &self, + expr: FlushTableExpr, + ctx: QueryContextRef, + ) -> Result { let table_name = if expr.table_name.trim().is_empty() { None } else { @@ -140,11 +156,15 @@ impl Instance { wait: None, }; self.sql_handler() - .execute(SqlRequest::FlushTable(req), QueryContext::arc()) + .execute(SqlRequest::FlushTable(req), ctx) .await } - pub(crate) async fn 
handle_compact_table(&self, expr: CompactTableExpr) -> Result { + pub(crate) async fn handle_compact_table( + &self, + expr: CompactTableExpr, + ctx: QueryContextRef, + ) -> Result { let table_name = if expr.table_name.trim().is_empty() { None } else { @@ -159,11 +179,15 @@ impl Instance { wait: None, }; self.sql_handler() - .execute(SqlRequest::CompactTable(req), QueryContext::arc()) + .execute(SqlRequest::CompactTable(req), ctx) .await } - pub(crate) async fn handle_truncate_table(&self, expr: TruncateTableExpr) -> Result { + pub(crate) async fn handle_truncate_table( + &self, + expr: TruncateTableExpr, + ctx: QueryContextRef, + ) -> Result { let table = self .catalog_manager .table(&expr.catalog_name, &expr.schema_name, &expr.table_name) @@ -184,7 +208,7 @@ impl Instance { table_id: table.table_info().ident.table_id, }; self.sql_handler() - .execute(SqlRequest::TruncateTable(req), QueryContext::arc()) + .execute(SqlRequest::TruncateTable(req), ctx) .await } } diff --git a/src/script/src/python/engine.rs b/src/script/src/python/engine.rs index 0fa23b1bae..27431c0b34 100644 --- a/src/script/src/python/engine.rs +++ b/src/script/src/python/engine.rs @@ -35,7 +35,7 @@ use datatypes::vectors::VectorRef; use futures::Stream; use query::parser::{QueryLanguageParser, QueryStatement}; use query::QueryEngineRef; -use session::context::QueryContext; +use session::context::QueryContextBuilder; use snafu::{ensure, ResultExt}; use sql::statements::statement::Statement; @@ -286,15 +286,16 @@ impl Script for PyScript { matches!(stmt, QueryStatement::Sql(Statement::Query { .. 
})), error::UnsupportedSqlSnafu { sql } ); + let ctx = QueryContextBuilder::default().build(); let plan = self .query_engine .planner() - .plan(stmt, QueryContext::arc()) + .plan(stmt, ctx.clone()) .await .context(DatabaseQuerySnafu)?; let res = self .query_engine - .execute(plan, QueryContext::arc()) + .execute(plan, ctx) .await .context(DatabaseQuerySnafu)?; let copr = self.copr.clone(); diff --git a/src/script/src/python/ffi_types/copr.rs b/src/script/src/python/ffi_types/copr.rs index 8df5e8994d..eb8cf98c92 100644 --- a/src/script/src/python/ffi_types/copr.rs +++ b/src/script/src/python/ffi_types/copr.rs @@ -34,7 +34,7 @@ use rustpython_compiler_core::CodeObject; use rustpython_vm as vm; #[cfg(test)] use serde::Deserialize; -use session::context::{QueryContext, QueryContextBuilder}; +use session::context::QueryContextBuilder; use snafu::{OptionExt, ResultExt}; use vm::convert::ToPyObject; use vm::{pyclass as rspyclass, PyObjectRef, PyPayload, PyResult, VirtualMachine}; @@ -379,14 +379,15 @@ impl PyQueryEngine { let rt = tokio::runtime::Runtime::new().map_err(|e| e.to_string())?; let handle = rt.handle().clone(); let res = handle.block_on(async { + let ctx = QueryContextBuilder::default().build(); let plan = engine .planner() - .plan(stmt, QueryContextBuilder::default().build()) + .plan(stmt, ctx.clone()) .await .map_err(|e| e.to_string())?; let res = engine .clone() - .execute(plan, QueryContext::arc()) + .execute(plan, ctx) .await .map_err(|e| e.to_string()); match res { diff --git a/src/script/src/table.rs b/src/script/src/table.rs index c9cb1becca..d734aed7b9 100644 --- a/src/script/src/table.rs +++ b/src/script/src/table.rs @@ -32,7 +32,7 @@ use datatypes::schema::{ColumnSchema, RawSchema}; use datatypes::vectors::{StringVector, TimestampMillisecondVector, Vector, VectorRef}; use query::parser::QueryLanguageParser; use query::QueryEngineRef; -use session::context::QueryContext; +use session::context::QueryContextBuilder; use snafu::{ensure, OptionExt, 
ResultExt}; use store_api::storage::ScanRequest; use table::requests::{CreateTableRequest, InsertRequest, TableOptions}; @@ -246,17 +246,18 @@ impl ScriptsTable { name ); let stmt = QueryLanguageParser::parse_sql(&sql).unwrap(); + let ctx = QueryContextBuilder::default().build(); let plan = self .query_engine .planner() - .plan(stmt, QueryContext::arc()) + .plan(stmt, ctx.clone()) .await .unwrap(); let stream = match self .query_engine - .execute(plan, QueryContext::arc()) + .execute(plan, ctx) .await .context(FindScriptSnafu { name })? { diff --git a/src/servers/src/grpc/handler.rs b/src/servers/src/grpc/handler.rs index 9216f31139..75dab865fe 100644 --- a/src/servers/src/grpc/handler.rs +++ b/src/servers/src/grpc/handler.rs @@ -19,6 +19,7 @@ use api::helper::request_type; use api::v1::auth_header::AuthScheme; use api::v1::{Basic, GreptimeRequest, RequestHeader}; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; +use common_catalog::parse_catalog_and_schema_from_db_string; use common_error::ext::ErrorExt; use common_error::status_code::StatusCode; use common_query::Output; @@ -153,7 +154,7 @@ pub(crate) fn create_query_context(header: Option<&RequestHeader>) -> QueryConte // We provide dbname field in newer versions of protos/sdks // parse dbname from header in priority if !header.dbname.is_empty() { - crate::parse_catalog_and_schema_from_client_database_name(&header.dbname) + parse_catalog_and_schema_from_db_string(&header.dbname) } else { ( if !header.catalog.is_empty() { diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index ed77848d13..8f8c7b005c 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -42,6 +42,8 @@ use axum::middleware::{self, Next}; use axum::response::{Html, IntoResponse, Json}; use axum::{routing, BoxError, Extension, Router}; use common_base::readable_size::ReadableSize; +use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; +use 
common_catalog::parse_catalog_and_schema_from_db_string; use common_error::ext::ErrorExt; use common_error::status_code::StatusCode; use common_query::Output; @@ -86,23 +88,28 @@ pub(crate) async fn query_context_from_db( query_handler: ServerSqlQueryHandlerRef, db: Option, ) -> std::result::Result, JsonResponse> { - if let Some(db) = &db { - let (catalog, schema) = super::parse_catalog_and_schema_from_client_database_name(db); + let (catalog, schema) = if let Some(db) = &db { + let (catalog, schema) = parse_catalog_and_schema_from_db_string(db); match query_handler.is_valid_schema(catalog, schema).await { - Ok(true) => Ok(QueryContext::with(catalog, schema)), - Ok(false) => Err(JsonResponse::with_error( - format!("Database not found: {db}"), - StatusCode::DatabaseNotFound, - )), - Err(e) => Err(JsonResponse::with_error( - format!("Error checking database: {db}, {e}"), - StatusCode::Internal, - )), + Ok(true) => (catalog, schema), + Ok(false) => { + return Err(JsonResponse::with_error( + format!("Database not found: {db}"), + StatusCode::DatabaseNotFound, + )) + } + Err(e) => { + return Err(JsonResponse::with_error( + format!("Error checking database: {db}, {e}"), + StatusCode::Internal, + )) + } } } else { - Ok(QueryContext::arc()) - } + (DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME) + }; + Ok(QueryContext::with(catalog, schema)) } pub const HTTP_API_VERSION: &str = "v1"; diff --git a/src/servers/src/http/admin.rs b/src/servers/src/http/admin.rs index a6e25329e4..76adf8271b 100644 --- a/src/servers/src/http/admin.rs +++ b/src/servers/src/http/admin.rs @@ -64,7 +64,9 @@ pub async fn flush( })), }); - grpc_handler.do_query(request, QueryContext::arc()).await?; + grpc_handler + .do_query(request, QueryContext::with(&catalog_name, &schema_name)) + .await?; Ok((StatusCode::NO_CONTENT, ())) } @@ -104,6 +106,8 @@ pub async fn compact( })), }); - grpc_handler.do_query(request, QueryContext::arc()).await?; + grpc_handler + .do_query(request, QueryContext::with(&catalog_name, 
&schema_name)) + .await?; Ok((StatusCode::NO_CONTENT, ())) } diff --git a/src/servers/src/http/authorize.rs b/src/servers/src/http/authorize.rs index 49e4e32e7f..3974a13ca7 100644 --- a/src/servers/src/http/authorize.rs +++ b/src/servers/src/http/authorize.rs @@ -16,6 +16,7 @@ use std::marker::PhantomData; use axum::http::{self, Request, StatusCode}; use axum::response::Response; +use common_catalog::parse_catalog_and_schema_from_db_string; use common_error::ext::ErrorExt; use common_telemetry::warn; use futures::future::BoxFuture; @@ -157,9 +158,7 @@ fn extract_catalog_and_schema( msg: "db not provided or corrupted", })?; - Ok(crate::parse_catalog_and_schema_from_client_database_name( - dbname, - )) + Ok(parse_catalog_and_schema_from_db_string(dbname)) } fn get_influxdb_credentials( diff --git a/src/servers/src/http/influxdb.rs b/src/servers/src/http/influxdb.rs index df5131188a..d2d931865d 100644 --- a/src/servers/src/http/influxdb.rs +++ b/src/servers/src/http/influxdb.rs @@ -18,13 +18,13 @@ use axum::extract::{Query, State}; use axum::http::StatusCode; use axum::response::IntoResponse; use common_catalog::consts::DEFAULT_SCHEMA_NAME; +use common_catalog::parse_catalog_and_schema_from_db_string; use common_grpc::writer::Precision; use common_telemetry::timer; use session::context::QueryContext; use crate::error::{Result, TimePrecisionSnafu}; use crate::influxdb::InfluxdbRequest; -use crate::parse_catalog_and_schema_from_client_database_name; use crate::query_handler::InfluxdbLineProtocolHandlerRef; // https://docs.influxdata.com/influxdb/v1.8/tools/api/#ping-http-endpoint @@ -86,7 +86,7 @@ pub async fn influxdb_write( &[(crate::metrics::METRIC_DB_LABEL, db.to_string())] ); - let (catalog, schema) = parse_catalog_and_schema_from_client_database_name(db); + let (catalog, schema) = parse_catalog_and_schema_from_db_string(db); let ctx = QueryContext::with(catalog, schema); let request = InfluxdbRequest { precision, lines }; diff --git 
a/src/servers/src/http/opentsdb.rs b/src/servers/src/http/opentsdb.rs index c4f8103064..5dbe70e143 100644 --- a/src/servers/src/http/opentsdb.rs +++ b/src/servers/src/http/opentsdb.rs @@ -17,7 +17,6 @@ use std::collections::HashMap; use axum::extract::{Query, RawBody, State}; use axum::http::StatusCode as HttpStatusCode; use axum::Json; -use common_catalog::consts::DEFAULT_SCHEMA_NAME; use hyper::Body; use serde::{Deserialize, Serialize}; use session::context::QueryContext; @@ -25,7 +24,6 @@ use snafu::ResultExt; use crate::error::{self, Error, Result}; use crate::opentsdb::codec::DataPoint; -use crate::parse_catalog_and_schema_from_client_database_name; use crate::query_handler::OpentsdbProtocolHandlerRef; #[derive(Serialize, Deserialize)] @@ -84,13 +82,7 @@ pub async fn put( let summary = params.contains_key("summary"); let details = params.contains_key("details"); - let db = params - .get("db") - .map(|v| v.as_str()) - .unwrap_or(DEFAULT_SCHEMA_NAME); - - let (catalog, schema) = parse_catalog_and_schema_from_client_database_name(db); - let ctx = QueryContext::with(catalog, schema); + let ctx = QueryContext::with_db_name(params.get("db")); let data_points = parse_data_points(body).await?; diff --git a/src/servers/src/http/otlp.rs b/src/servers/src/http/otlp.rs index 05d29951a2..3fca759c91 100644 --- a/src/servers/src/http/otlp.rs +++ b/src/servers/src/http/otlp.rs @@ -27,7 +27,6 @@ use snafu::prelude::*; use crate::error::{self, Result}; use crate::http::header::GreptimeDbName; -use crate::parse_catalog_and_schema_from_client_database_name; use crate::query_handler::OpenTelemetryProtocolHandlerRef; #[axum_macros::debug_handler] @@ -36,12 +35,7 @@ pub async fn metrics( TypedHeader(db): TypedHeader, RawBody(body): RawBody, ) -> Result { - let ctx = if let Some(db) = db.value() { - let (catalog, schema) = parse_catalog_and_schema_from_client_database_name(db); - QueryContext::with(catalog, schema) - } else { - QueryContext::arc() - }; + let ctx = 
QueryContext::with_db_name(db.value()); let _timer = timer!( crate::metrics::METRIC_HTTP_OPENTELEMETRY_ELAPSED, &[(crate::metrics::METRIC_DB_LABEL, ctx.get_db_string())] diff --git a/src/servers/src/http/prom_store.rs b/src/servers/src/http/prom_store.rs index 22ab0c5f5d..a80896edba 100644 --- a/src/servers/src/http/prom_store.rs +++ b/src/servers/src/http/prom_store.rs @@ -26,7 +26,6 @@ use session::context::QueryContext; use snafu::prelude::*; use crate::error::{self, Result}; -use crate::parse_catalog_and_schema_from_client_database_name; use crate::prom_store::snappy_decompress; use crate::query_handler::{PromStoreProtocolHandlerRef, PromStoreResponse}; @@ -58,14 +57,8 @@ pub async fn remote_write( params.db.clone().unwrap_or_default() )] ); - let ctx = if let Some(db) = params.db { - let (catalog, schema) = parse_catalog_and_schema_from_client_database_name(&db); - QueryContext::with(catalog, schema) - } else { - QueryContext::arc() - }; - // TODO(shuiyisong): add more error log + let ctx = QueryContext::with_db_name(params.db.as_ref()); handler.write(request, ctx).await?; Ok((StatusCode::NO_CONTENT, ())) } @@ -98,14 +91,8 @@ pub async fn remote_read( params.db.clone().unwrap_or_default() )] ); - let ctx = if let Some(db) = params.db { - let (catalog, schema) = parse_catalog_and_schema_from_client_database_name(&db); - QueryContext::with(catalog, schema) - } else { - QueryContext::arc() - }; - // TODO(shuiyisong): add more error log + let ctx = QueryContext::with_db_name(params.db.as_ref()); handler.read(request, ctx).await } diff --git a/src/servers/src/lib.rs b/src/servers/src/lib.rs index b77bb9c941..55365d9d1c 100644 --- a/src/servers/src/lib.rs +++ b/src/servers/src/lib.rs @@ -15,7 +15,6 @@ #![feature(assert_matches)] #![feature(try_blocks)] -use common_catalog::consts::DEFAULT_CATALOG_NAME; use datatypes::schema::Schema; use query::plan::LogicalPlan; use serde::{Deserialize, Serialize}; @@ -49,33 +48,6 @@ pub enum Mode { Distributed, } -/// Attempt to 
parse catalog and schema from given database name -/// -/// The database name may come from different sources: -/// -/// - MySQL `schema` name in MySQL protocol login request: it's optional and user -/// and switch database using `USE` command -/// - Postgres `database` parameter in Postgres wire protocol, required -/// - HTTP RESTful API: the database parameter, optional -/// - gRPC: the dbname field in header, optional but has a higher priority than -/// original catalog/schema -/// -/// When database name is provided, we attempt to parse catalog and schema from -/// it. We assume the format `[<catalog>-]<schema>`: -/// -/// - If `[<catalog>-]` part is not provided, we use whole database name as -/// schema name -/// - if `[<catalog>-]` is provided, we split database name with `-` and use -/// `<catalog>` and `<schema>`. -pub fn parse_catalog_and_schema_from_client_database_name(db: &str) -> (&str, &str) { - let parts = db.splitn(2, '-').collect::<Vec<&str>>(); - if parts.len() == 2 { - (parts[0], parts[1]) - } else { - (DEFAULT_CATALOG_NAME, db) - } -} - /// Cached SQL and logical plan for database interfaces #[derive(Clone)] pub struct SqlPlan { @@ -83,26 +55,3 @@ pub struct SqlPlan { plan: Option<LogicalPlan>, schema: Option<Schema>, } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_catalog_and_schema() { - assert_eq!( - (DEFAULT_CATALOG_NAME, "fullschema"), - parse_catalog_and_schema_from_client_database_name("fullschema") - ); - - assert_eq!( - ("catalog", "schema"), - parse_catalog_and_schema_from_client_database_name("catalog-schema") - ); - - assert_eq!( - ("catalog", "schema1-schema2"), - parse_catalog_and_schema_from_client_database_name("catalog-schema1-schema2") - ); - } -} diff --git a/src/servers/src/mysql/handler.rs b/src/servers/src/mysql/handler.rs index d833aca555..a89847d2c7 100644 --- a/src/servers/src/mysql/handler.rs +++ b/src/servers/src/mysql/handler.rs @@ -19,6 +19,7 @@ use std::time::{Duration, Instant}; use async_trait::async_trait; use chrono::{NaiveDate, NaiveDateTime}; +use 
common_catalog::parse_catalog_and_schema_from_db_string; use common_error::ext::ErrorExt; use common_query::Output; use common_telemetry::{error, info, logging, timer, warn}; @@ -351,7 +352,7 @@ impl AsyncMysqlShim for MysqlInstanceShi } async fn on_init<'a>(&'a mut self, database: &'a str, w: InitWriter<'a, W>) -> Result<()> { - let (catalog, schema) = crate::parse_catalog_and_schema_from_client_database_name(database); + let (catalog, schema) = parse_catalog_and_schema_from_db_string(database); if !self.query_handler.is_valid_schema(catalog, schema).await? { return w diff --git a/src/servers/src/postgres/auth_handler.rs b/src/servers/src/postgres/auth_handler.rs index 5aa543c580..17e7413011 100644 --- a/src/servers/src/postgres/auth_handler.rs +++ b/src/servers/src/postgres/auth_handler.rs @@ -15,6 +15,7 @@ use std::fmt::Debug; use async_trait::async_trait; +use common_catalog::parse_catalog_and_schema_from_db_string; use common_error::ext::ErrorExt; use futures::{Sink, SinkExt}; use metrics::increment_counter; @@ -233,7 +234,7 @@ where { let db_ref = client.metadata().get(super::METADATA_DATABASE); if let Some(db) = db_ref { - let (catalog, schema) = crate::parse_catalog_and_schema_from_client_database_name(db); + let (catalog, schema) = parse_catalog_and_schema_from_db_string(db); if query_handler .is_valid_schema(catalog, schema) .await diff --git a/src/servers/src/prometheus.rs b/src/servers/src/prometheus.rs index 705325a40e..b0cc355ef2 100644 --- a/src/servers/src/prometheus.rs +++ b/src/servers/src/prometheus.rs @@ -23,6 +23,7 @@ use axum::extract::{Path, Query, State}; use axum::{middleware, routing, Form, Json, Router}; use catalog::CatalogManagerRef; use common_catalog::consts::DEFAULT_SCHEMA_NAME; +use common_catalog::parse_catalog_and_schema_from_db_string; use common_error::ext::ErrorExt; use common_error::status_code::StatusCode; use common_query::Output; @@ -444,10 +445,7 @@ pub async fn instant_query( step: "1s".to_string(), }; - let db = 
&params.db.unwrap_or(DEFAULT_SCHEMA_NAME.to_string()); - let (catalog, schema) = crate::parse_catalog_and_schema_from_client_database_name(db); - - let query_ctx = QueryContext::with(catalog, schema); + let query_ctx = QueryContext::with_db_name(params.db.as_ref()); let result = handler.do_query(&prom_query, query_ctx).await; let (metric_name, result_type) = match retrieve_metric_name_and_result_type(&prom_query.query) { @@ -483,10 +481,7 @@ pub async fn range_query( step: params.step.or(form_params.step).unwrap_or_default(), }; - let db = &params.db.unwrap_or(DEFAULT_SCHEMA_NAME.to_string()); - let (catalog, schema) = crate::parse_catalog_and_schema_from_client_database_name(db); - - let query_ctx = QueryContext::with(catalog, schema); + let query_ctx = QueryContext::with_db_name(params.db.as_ref()); let result = handler.do_query(&prom_query, query_ctx).await; let metric_name = match retrieve_metric_name_and_result_type(&prom_query.query) { @@ -551,7 +546,7 @@ pub async fn labels_query( let _timer = timer!(crate::metrics::METRIC_HTTP_PROMQL_LABEL_QUERY_ELAPSED); let db = &params.db.unwrap_or(DEFAULT_SCHEMA_NAME.to_string()); - let (catalog, schema) = crate::parse_catalog_and_schema_from_client_database_name(db); + let (catalog, schema) = parse_catalog_and_schema_from_db_string(db); let query_ctx = QueryContext::with(catalog, schema); let mut queries = params.matches.0; @@ -815,7 +810,7 @@ pub async fn label_values_query( let _timer = timer!(crate::metrics::METRIC_HTTP_PROMQL_LABEL_VALUE_QUERY_ELAPSED); let db = &params.db.unwrap_or(DEFAULT_SCHEMA_NAME.to_string()); - let (catalog, schema) = crate::parse_catalog_and_schema_from_client_database_name(db); + let (catalog, schema) = parse_catalog_and_schema_from_db_string(db); if label_name == METRIC_NAME_LABEL { let mut table_names = match handler.catalog_manager().table_names(catalog, schema).await { @@ -955,9 +950,7 @@ pub async fn series_query( .or(form_params.end) .unwrap_or_else(current_time_rfc3339); - let db = 
&params.db.unwrap_or(DEFAULT_SCHEMA_NAME.to_string()); - let (catalog, schema) = super::parse_catalog_and_schema_from_client_database_name(db); - let query_ctx = QueryContext::with(catalog, schema); + let query_ctx = QueryContext::with_db_name(params.db.as_ref()); let mut series = Vec::new(); for query in queries { diff --git a/src/session/src/context.rs b/src/session/src/context.rs index 3159774309..54e755c353 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -17,8 +17,8 @@ use std::net::SocketAddr; use std::sync::Arc; use arc_swap::ArcSwap; -use common_catalog::build_db_string; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; +use common_catalog::{build_db_string, parse_catalog_and_schema_from_db_string}; use common_time::TimeZone; use derive_builder::Builder; use sql::dialect::{Dialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect}; @@ -61,6 +61,24 @@ impl QueryContext { .build() } + pub fn with_db_name(db_name: Option<&String>) -> QueryContextRef { + let (catalog, schema) = db_name + .map(|db| { + let (catalog, schema) = parse_catalog_and_schema_from_db_string(db); + (catalog.to_string(), schema.to_string()) + }) + .unwrap_or_else(|| { + ( + DEFAULT_CATALOG_NAME.to_string(), + DEFAULT_SCHEMA_NAME.to_string(), + ) + }); + QueryContextBuilder::default() + .current_catalog(catalog) + .current_schema(schema) + .build() + } + #[inline] pub fn current_schema(&self) -> String { self.current_schema.clone()