From 4ef9afd8d8253ea1387784becd3ffabb32ca5bdd Mon Sep 17 00:00:00 2001 From: Weny Xu Date: Tue, 1 Apr 2025 17:17:01 +0800 Subject: [PATCH] feat: introduce read preference (#5783) * feat: introduce read preference * feat: introduce `RegionQueryHandlerFactory` * feat: extract ReadPreference from http header * test: add more tests * chore: apply suggestions from CR * chore: apply suggestions from CR --- Cargo.lock | 11 +++++ Cargo.toml | 2 + src/common/function/src/system.rs | 5 ++- src/common/function/src/system/database.rs | 29 ++++++++++++ src/common/session/Cargo.toml | 11 +++++ src/common/session/src/lib.rs | 45 +++++++++++++++++++ src/frontend/src/error.rs | 15 +++++-- src/frontend/src/instance/builder.rs | 7 ++- src/frontend/src/instance/region_query.rs | 22 ++++++--- src/operator/src/statement.rs | 4 +- src/operator/src/statement/set.rs | 34 ++++++++++++++ src/partition/Cargo.toml | 1 + src/partition/src/manager.rs | 1 + src/query/Cargo.toml | 1 + src/query/src/dist_plan/merge_scan.rs | 3 +- src/query/src/region_query.rs | 21 ++++++++- src/query/src/sql.rs | 1 + src/servers/Cargo.toml | 1 + src/servers/src/http.rs | 6 ++- src/servers/src/http/header.rs | 5 +++ src/servers/src/http/read_preference.rs | 40 +++++++++++++++++ src/session/Cargo.toml | 1 + src/session/src/context.rs | 10 ++++- src/session/src/lib.rs | 11 +++++ .../common/system/read_preference.result | 28 ++++++++++++ .../common/system/read_preference.sql | 13 ++++++ 26 files changed, 311 insertions(+), 17 deletions(-) create mode 100644 src/common/session/Cargo.toml create mode 100644 src/common/session/src/lib.rs create mode 100644 src/servers/src/http/read_preference.rs create mode 100644 tests/cases/standalone/common/system/read_preference.result create mode 100644 tests/cases/standalone/common/system/read_preference.sql diff --git a/Cargo.lock b/Cargo.lock index 3167b27a6e..b36093fb60 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2344,6 +2344,13 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "common-session" +version = "0.14.0" +dependencies = [ + "strum 0.27.1", +] + [[package]] name = "common-telemetry" version = "0.14.0" @@ -8105,6 +8112,7 @@ dependencies = [ "itertools 0.14.0", "serde", "serde_json", + "session", "snafu 0.8.5", "sql", "sqlparser 0.52.0 (git+https://github.com/GreptimeTeam/sqlparser-rs.git?rev=71dd86058d2af97b9925093d40c4e03360403170)", @@ -9117,6 +9125,7 @@ dependencies = [ "num-traits", "object-store", "once_cell", + "partition", "paste", "pretty_assertions", "prometheus", @@ -10555,6 +10564,7 @@ dependencies = [ "common-query", "common-recordbatch", "common-runtime", + "common-session", "common-telemetry", "common-test-util", "common-time", @@ -10651,6 +10661,7 @@ dependencies = [ "common-error", "common-macro", "common-recordbatch", + "common-session", "common-telemetry", "common-time", "derive_builder 0.20.1", diff --git a/Cargo.toml b/Cargo.toml index e738ff55ce..62454b5987 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ members = [ "src/common/query", "src/common/recordbatch", "src/common/runtime", + "src/common/session", "src/common/substrait", "src/common/telemetry", "src/common/test-util", @@ -248,6 +249,7 @@ common-procedure-test = { path = "src/common/procedure-test" } common-query = { path = "src/common/query" } common-recordbatch = { path = "src/common/recordbatch" } common-runtime = { path = "src/common/runtime" } +common-session = { path = "src/common/session" } common-telemetry = { path = "src/common/telemetry" } common-test-util = { path = "src/common/test-util" } common-time = { path = "src/common/time" } diff --git a/src/common/function/src/system.rs b/src/common/function/src/system.rs index b2e6c41135..dad1e4f7bf 100644 --- a/src/common/function/src/system.rs +++ b/src/common/function/src/system.rs @@ -22,7 +22,9 @@ mod version; use std::sync::Arc; use build::BuildFunction; -use database::{CurrentSchemaFunction, DatabaseFunction, SessionUserFunction}; +use database::{ + CurrentSchemaFunction, DatabaseFunction, ReadPreferenceFunction, SessionUserFunction, +}; use pg_catalog::PGCatalogFunction; use procedure_state::ProcedureStateFunction; use timezone::TimezoneFunction; @@ -39,6 +41,7 @@ impl SystemFunction { registry.register(Arc::new(CurrentSchemaFunction)); registry.register(Arc::new(DatabaseFunction)); registry.register(Arc::new(SessionUserFunction)); + registry.register(Arc::new(ReadPreferenceFunction)); registry.register(Arc::new(TimezoneFunction)); registry.register_async(Arc::new(ProcedureStateFunction)); PGCatalogFunction::register(registry); diff --git a/src/common/function/src/system/database.rs b/src/common/function/src/system/database.rs index 370bd2c8da..f1f4682de7 100644 --- a/src/common/function/src/system/database.rs +++ b/src/common/function/src/system/database.rs @@ -30,9 +30,12 @@ pub struct DatabaseFunction; pub struct CurrentSchemaFunction; pub struct SessionUserFunction; +pub struct ReadPreferenceFunction; + const DATABASE_FUNCTION_NAME: &str = "database"; const CURRENT_SCHEMA_FUNCTION_NAME: &str = "current_schema"; const SESSION_USER_FUNCTION_NAME: &str = "session_user"; +const READ_PREFERENCE_FUNCTION_NAME: &str = "read_preference"; impl Function for DatabaseFunction { fn name(&self) -> &str { @@ -94,6 +97,26 @@ impl Function for SessionUserFunction { } } +impl Function for ReadPreferenceFunction { + fn name(&self) -> &str { + READ_PREFERENCE_FUNCTION_NAME + } + + fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result { + Ok(ConcreteDataType::string_datatype()) + } + + fn signature(&self) -> Signature { + Signature::nullary(Volatility::Immutable) + } + + fn eval(&self, func_ctx: &FunctionContext, _columns: &[VectorRef]) -> Result { + let read_preference = func_ctx.query_ctx.read_preference(); + + Ok(Arc::new(StringVector::from_slice(&[read_preference.as_ref()])) as _) + } +} + impl fmt::Display for DatabaseFunction { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "DATABASE") @@ -112,6 +135,12 @@ impl fmt::Display for SessionUserFunction { } } +impl fmt::Display for ReadPreferenceFunction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "READ_PREFERENCE") + } +} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/src/common/session/Cargo.toml b/src/common/session/Cargo.toml new file mode 100644 index 0000000000..cd9bfa626b --- /dev/null +++ b/src/common/session/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "common-session" +version.workspace = true +edition.workspace = true +license.workspace = true + +[lints] +workspace = true + +[dependencies] +strum.workspace = true diff --git a/src/common/session/src/lib.rs b/src/common/session/src/lib.rs new file mode 100644 index 0000000000..51d7846532 --- /dev/null +++ b/src/common/session/src/lib.rs @@ -0,0 +1,45 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use strum::{AsRefStr, Display, EnumString}; + +/// Defines the read preference for frontend route operations, +/// determining whether to read from the region leader or follower. +#[derive(Debug, Clone, Copy, Default, EnumString, Display, AsRefStr, PartialEq, Eq)] +pub enum ReadPreference { + #[default] + // Reads all operations from the region leader. This is the default mode. + #[strum(serialize = "leader", to_string = "LEADER")] + Leader, +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + + use crate::ReadPreference; + + #[test] + fn test_read_preference() { + assert_eq!(ReadPreference::Leader.to_string(), "LEADER"); + + let read_preference = ReadPreference::from_str("LEADER").unwrap(); + assert_eq!(read_preference, ReadPreference::Leader); + + let read_preference = ReadPreference::from_str("leader").unwrap(); + assert_eq!(read_preference, ReadPreference::Leader); + + ReadPreference::from_str("follower").unwrap_err(); + } +} diff --git a/src/frontend/src/error.rs b/src/frontend/src/error.rs index 076ecf9943..06cb2e8a5e 100644 --- a/src/frontend/src/error.rs +++ b/src/frontend/src/error.rs @@ -19,7 +19,9 @@ use common_error::define_into_tonic_status; use common_error::ext::{BoxedError, ErrorExt}; use common_error::status_code::StatusCode; use common_macro::stack_trace_debug; +use session::ReadPreference; use snafu::{Location, Snafu}; +use store_api::storage::RegionId; #[derive(Snafu)] #[snafu(visibility(pub))] @@ -140,9 +142,14 @@ pub enum Error { location: Location, }, - #[snafu(display("Failed to find table route for table id {}", table_id))] - FindTableRoute { - table_id: u32, + #[snafu(display( + "Failed to find region peer for region id {}, read preference: {}", + region_id, + read_preference + ))] + FindRegionPeer { + region_id: RegionId, + read_preference: ReadPreference, #[snafu(implicit)] location: Location, source: partition::error::Error, @@ -410,7 +417,7 @@ impl ErrorExt for Error { Error::External { source, .. } | Error::InitPlugin { source, .. } => { source.status_code() } - Error::FindTableRoute { source, .. } => source.status_code(), + Error::FindRegionPeer { source, .. } => source.status_code(), Error::TableOperation { source, .. } => source.status_code(), diff --git a/src/frontend/src/instance/builder.rs b/src/frontend/src/instance/builder.rs index 4e51138bc1..52b2463503 100644 --- a/src/frontend/src/instance/builder.rs +++ b/src/frontend/src/instance/builder.rs @@ -33,6 +33,7 @@ use operator::statement::{StatementExecutor, StatementExecutorRef}; use operator::table::TableMutationOperator; use partition::manager::PartitionRuleManager; use pipeline::pipeline_operator::PipelineOperator; +use query::region_query::RegionQueryHandlerFactoryRef; use query::stats::StatementStatistics; use query::QueryEngineFactory; use snafu::OptionExt; @@ -114,7 +115,11 @@ impl FrontendBuilder { .unwrap_or_else(|| Arc::new(DummyCacheInvalidator)); let region_query_handler = - FrontendRegionQueryHandler::arc(partition_manager.clone(), node_manager.clone()); + if let Some(factory) = plugins.get::() { + factory.build(partition_manager.clone(), node_manager.clone()) + } else { + FrontendRegionQueryHandler::arc(partition_manager.clone(), node_manager.clone()) + }; let table_flownode_cache = self.layered_cache_registry diff --git a/src/frontend/src/instance/region_query.rs b/src/frontend/src/instance/region_query.rs index b6d2f00b22..56874f5777 100644 --- a/src/frontend/src/instance/region_query.rs +++ b/src/frontend/src/instance/region_query.rs @@ -22,9 +22,10 @@ use common_recordbatch::SendableRecordBatchStream; use partition::manager::PartitionRuleManagerRef; use query::error::{RegionQuerySnafu, Result as QueryResult}; use query::region_query::RegionQueryHandler; +use session::ReadPreference; use snafu::ResultExt; -use crate::error::{FindTableRouteSnafu, RequestQuerySnafu, Result}; +use crate::error::{FindRegionPeerSnafu, RequestQuerySnafu, Result}; pub(crate) struct FrontendRegionQueryHandler { partition_manager: PartitionRuleManagerRef, @@ -45,8 +46,12 @@ impl FrontendRegionQueryHandler { #[async_trait] impl RegionQueryHandler for FrontendRegionQueryHandler { - async fn do_get(&self, request: QueryRequest) -> QueryResult { - self.do_get_inner(request) + async fn do_get( + &self, + read_preference: ReadPreference, + request: QueryRequest, + ) -> QueryResult { + self.do_get_inner(read_preference, request) .await .map_err(BoxedError::new) .context(RegionQuerySnafu) @@ -54,15 +59,20 @@ impl RegionQueryHandler for FrontendRegionQueryHandler { } impl FrontendRegionQueryHandler { - async fn do_get_inner(&self, request: QueryRequest) -> Result { + async fn do_get_inner( + &self, + read_preference: ReadPreference, + request: QueryRequest, + ) -> Result { let region_id = request.region_id; let peer = &self .partition_manager .find_region_leader(region_id) .await - .context(FindTableRouteSnafu { - table_id: region_id.table_id(), + .context(FindRegionPeerSnafu { + region_id, + read_preference, })?; let client = self.node_manager.datanode(peer).await; diff --git a/src/operator/src/statement.rs b/src/operator/src/statement.rs index 6d1ab2f1e2..d786396e87 100644 --- a/src/operator/src/statement.rs +++ b/src/operator/src/statement.rs @@ -56,7 +56,7 @@ use query::parser::QueryStatement; use query::QueryEngineRef; use session::context::{Channel, QueryContextRef}; use session::table_name::table_idents_to_full_name; -use set::set_query_timeout; +use set::{set_query_timeout, set_read_preference}; use snafu::{ensure, OptionExt, ResultExt}; use sql::statements::copy::{ CopyDatabase, CopyDatabaseArgument, CopyQueryToArgument, CopyTable, CopyTableArgument, @@ -379,6 +379,8 @@ impl StatementExecutor { fn set_variables(&self, set_var: SetVariables, query_ctx: QueryContextRef) -> Result { let var_name = set_var.variable.to_string().to_uppercase(); match var_name.as_str() { + "READ_PREFERENCE" => set_read_preference(set_var.value, query_ctx)?, + "TIMEZONE" | "TIME_ZONE" => set_timezone(set_var.value, query_ctx)?, "BYTEA_OUTPUT" => set_bytea_output(set_var.value, query_ctx)?, diff --git a/src/operator/src/statement/set.rs b/src/operator/src/statement/set.rs index 6211f1a554..6df8e3e630 100644 --- a/src/operator/src/statement/set.rs +++ b/src/operator/src/statement/set.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::str::FromStr; use std::time::Duration; use common_time::Timezone; @@ -20,6 +21,7 @@ use regex::Regex; use session::context::Channel::Postgres; use session::context::QueryContextRef; use session::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle}; +use session::ReadPreference; use snafu::{ensure, OptionExt, ResultExt}; use sql::ast::{Expr, Ident, Value}; use sql::statements::set_variables::SetVariables; @@ -35,6 +37,38 @@ lazy_static! { static ref PG_TIME_INPUT_REGEX: Regex = Regex::new(r"^(\d+)(ms|s|min|h|d)$").unwrap(); } +pub fn set_read_preference(exprs: Vec, ctx: QueryContextRef) -> Result<()> { + let read_preference_expr = exprs.first().context(NotSupportedSnafu { + feat: "No read preference find in set variable statement", + })?; + + match read_preference_expr { + Expr::Value(Value::SingleQuotedString(expr)) + | Expr::Value(Value::DoubleQuotedString(expr)) => { + match ReadPreference::from_str(expr.as_str().to_lowercase().as_str()) { + Ok(read_preference) => ctx.set_read_preference(read_preference), + Err(_) => { + return NotSupportedSnafu { + feat: format!( + "Invalid read preference expr {} in set variable statement", + expr, + ), + } + .fail() + } + } + Ok(()) + } + expr => NotSupportedSnafu { + feat: format!( + "Unsupported read preference expr {} in set variable statement", + expr + ), + } + .fail(), + } +} + pub fn set_timezone(exprs: Vec, ctx: QueryContextRef) -> Result<()> { let tz_expr = exprs.first().context(NotSupportedSnafu { feat: "No timezone find in set variable statement", diff --git a/src/partition/Cargo.toml b/src/partition/Cargo.toml index 6402d2feff..ebb7d68f8d 100644 --- a/src/partition/Cargo.toml +++ b/src/partition/Cargo.toml @@ -20,6 +20,7 @@ datatypes.workspace = true itertools.workspace = true serde.workspace = true serde_json.workspace = true +session.workspace = true snafu.workspace = true sql.workspace = true sqlparser.workspace = true diff --git a/src/partition/src/manager.rs b/src/partition/src/manager.rs index 6a96da06e4..b7b0e4c76b 100644 --- a/src/partition/src/manager.rs +++ b/src/partition/src/manager.rs @@ -171,6 +171,7 @@ impl PartitionRuleManager { Ok(Arc::new(rule) as _) } + /// Find the leader of the region. pub async fn find_region_leader(&self, region_id: RegionId) -> Result { let region_routes = &self .find_physical_table_route(region_id.table_id()) diff --git a/src/query/Cargo.toml b/src/query/Cargo.toml index 50676afacb..53f66034e7 100644 --- a/src/query/Cargo.toml +++ b/src/query/Cargo.toml @@ -51,6 +51,7 @@ meter-core.workspace = true meter-macros.workspace = true object-store.workspace = true once_cell.workspace = true +partition.workspace = true prometheus.workspace = true promql.workspace = true promql-parser.workspace = true diff --git a/src/query/src/dist_plan/merge_scan.rs b/src/query/src/dist_plan/merge_scan.rs index 47a9512583..81e1d463d9 100644 --- a/src/query/src/dist_plan/merge_scan.rs +++ b/src/query/src/dist_plan/merge_scan.rs @@ -198,6 +198,7 @@ impl MergeScanExec { let dbname = context.task_id().unwrap_or_default(); let tracing_context = TracingContext::from_json(context.session_id().as_str()); let current_channel = self.query_ctx.channel(); + let read_preference = self.query_ctx.read_preference(); let stream = Box::pin(stream!({ // only report metrics once for each MergeScan @@ -226,7 +227,7 @@ impl MergeScanExec { }; let do_get_start = Instant::now(); let mut stream = region_query_handler - .do_get(request) + .do_get(read_preference, request) .await .map_err(|e| { MERGE_SCAN_ERRORS_TOTAL.inc(); diff --git a/src/query/src/region_query.rs b/src/query/src/region_query.rs index 92b2a4e2f9..2519ccc25f 100644 --- a/src/query/src/region_query.rs +++ b/src/query/src/region_query.rs @@ -15,14 +15,33 @@ use std::sync::Arc; use async_trait::async_trait; +use common_meta::node_manager::NodeManagerRef; use common_query::request::QueryRequest; use common_recordbatch::SendableRecordBatchStream; +use partition::manager::PartitionRuleManagerRef; +use session::ReadPreference; use crate::error::Result; +/// A factory to create a [`RegionQueryHandler`]. +pub trait RegionQueryHandlerFactory: Send + Sync { + /// Build a [`RegionQueryHandler`] with the given partition manager and node manager. + fn build( + &self, + partition_manager: PartitionRuleManagerRef, + node_manager: NodeManagerRef, + ) -> RegionQueryHandlerRef; +} + +pub type RegionQueryHandlerFactoryRef = Arc; + #[async_trait] pub trait RegionQueryHandler: Send + Sync { - async fn do_get(&self, request: QueryRequest) -> Result; + async fn do_get( + &self, + read_preference: ReadPreference, + request: QueryRequest, + ) -> Result; } pub type RegionQueryHandlerRef = Arc; diff --git a/src/query/src/sql.rs b/src/query/src/sql.rs index fa8ba640f6..2d82d771e8 100644 --- a/src/query/src/sql.rs +++ b/src/query/src/sql.rs @@ -724,6 +724,7 @@ pub fn show_variable(stmt: ShowVariables, query_ctx: QueryContextRef) -> Result< let value = match variable.as_str() { "SYSTEM_TIME_ZONE" | "SYSTEM_TIMEZONE" => get_timezone(None).to_string(), "TIME_ZONE" | "TIMEZONE" => query_ctx.timezone().to_string(), + "READ_PREFERENCE" => query_ctx.read_preference().to_string(), "DATESTYLE" => { let (style, order) = *query_ctx.configuration_parameter().pg_datetime_style(); format!("{}, {}", style, order) diff --git a/src/servers/Cargo.toml b/src/servers/Cargo.toml index abd3d105c0..6a6ea8d9d7 100644 --- a/src/servers/Cargo.toml +++ b/src/servers/Cargo.toml @@ -46,6 +46,7 @@ common-pprof = { workspace = true, optional = true } common-query.workspace = true common-recordbatch.workspace = true common-runtime.workspace = true +common-session.workspace = true common-telemetry.workspace = true common-time.workspace = true common-version = { workspace = true, features = ["codec"] } diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index 2251f75eca..ca3ae8a5bb 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -103,6 +103,7 @@ mod timeout; pub(crate) use timeout::DynamicTimeoutLayer; mod hints; +mod read_preference; #[cfg(any(test, feature = "testing"))] pub mod test_helpers; @@ -804,7 +805,10 @@ impl HttpServer { AuthState::new(self.user_provider.clone()), authorize::check_http_auth, )) - .layer(middleware::from_fn(hints::extract_hints)), + .layer(middleware::from_fn(hints::extract_hints)) + .layer(middleware::from_fn( + read_preference::extract_read_preference, + )), ) // Handlers for debug, we don't expect a timeout. .nest( diff --git a/src/servers/src/http/header.rs b/src/servers/src/http/header.rs index e620e7dd9b..b112743f05 100644 --- a/src/servers/src/http/header.rs +++ b/src/servers/src/http/header.rs @@ -43,6 +43,7 @@ pub mod constants { pub const GREPTIME_DB_HEADER_EXECUTION_TIME: &str = "x-greptime-execution-time"; pub const GREPTIME_DB_HEADER_METRICS: &str = "x-greptime-metrics"; pub const GREPTIME_DB_HEADER_NAME: &str = "x-greptime-db-name"; + pub const GREPTIME_DB_HEADER_READ_PREFERENCE: &str = "x-greptime-read-preference"; pub const GREPTIME_TIMEZONE_HEADER_NAME: &str = "x-greptime-timezone"; pub const GREPTIME_DB_HEADER_ERROR_CODE: &str = common_error::GREPTIME_DB_HEADER_ERROR_CODE; @@ -76,6 +77,10 @@ pub static GREPTIME_DB_HEADER_NAME: HeaderName = pub static GREPTIME_TIMEZONE_HEADER_NAME: HeaderName = HeaderName::from_static(constants::GREPTIME_TIMEZONE_HEADER_NAME); +/// Header key of query specific read preference. Example format of the header value is `leader`. +pub static GREPTIME_DB_HEADER_READ_PREFERENCE: HeaderName = + HeaderName::from_static(constants::GREPTIME_DB_HEADER_READ_PREFERENCE); + pub static CONTENT_TYPE_PROTOBUF_STR: &str = "application/x-protobuf"; pub static CONTENT_TYPE_PROTOBUF: HeaderValue = HeaderValue::from_static(CONTENT_TYPE_PROTOBUF_STR); pub static CONTENT_ENCODING_SNAPPY: HeaderValue = HeaderValue::from_static("snappy"); diff --git a/src/servers/src/http/read_preference.rs b/src/servers/src/http/read_preference.rs new file mode 100644 index 0000000000..aee56c1fe8 --- /dev/null +++ b/src/servers/src/http/read_preference.rs @@ -0,0 +1,40 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::str::FromStr; + +use axum::body::Body; +use axum::http::Request; +use axum::middleware::Next; +use axum::response::Response; +use session::context::QueryContext; +use session::ReadPreference; + +use crate::http::header::GREPTIME_DB_HEADER_READ_PREFERENCE; + +/// Extract read preference from the request headers. +pub async fn extract_read_preference(mut request: Request, next: Next) -> Response { + let read_preference = request + .headers() + .get(&GREPTIME_DB_HEADER_READ_PREFERENCE) + .and_then(|header| header.to_str().ok()) + .and_then(|s| ReadPreference::from_str(s).ok()) + .unwrap_or_default(); + + if let Some(query_ctx) = request.extensions_mut().get_mut::() { + common_telemetry::debug!("Setting read preference to {}", read_preference); + query_ctx.set_read_preference(read_preference); + } + next.run(request).await +} diff --git a/src/session/Cargo.toml b/src/session/Cargo.toml index 49a18d1f16..c263c21616 100644 --- a/src/session/Cargo.toml +++ b/src/session/Cargo.toml @@ -18,6 +18,7 @@ common-catalog.workspace = true common-error.workspace = true common-macro.workspace = true common-recordbatch.workspace = true +common-session.workspace = true common-telemetry.workspace = true common-time.workspace = true derive_builder.workspace = true diff --git a/src/session/src/context.rs b/src/session/src/context.rs index 41e3d3a37d..84087e66e7 100644 --- a/src/session/src/context.rs +++ b/src/session/src/context.rs @@ -32,7 +32,7 @@ use derive_builder::Builder; use sql::dialect::{Dialect, GenericDialect, GreptimeDbDialect, MySqlDialect, PostgreSqlDialect}; use crate::session_config::{PGByteaOutputValue, PGDateOrder, PGDateTimeStyle}; -use crate::MutableInner; +use crate::{MutableInner, ReadPreference}; pub type QueryContextRef = Arc; pub type ConnInfoRef = Arc; @@ -266,6 +266,14 @@ impl QueryContext { self.mutable_session_data.write().unwrap().timezone = timezone; } + pub fn read_preference(&self) -> ReadPreference { + self.mutable_session_data.read().unwrap().read_preference + } + + pub fn set_read_preference(&self, read_preference: ReadPreference) { + self.mutable_session_data.write().unwrap().read_preference = read_preference; + } + pub fn current_user(&self) -> UserInfoRef { self.mutable_session_data.read().unwrap().user_info.clone() } diff --git a/src/session/src/lib.rs b/src/session/src/lib.rs index c018d47ebc..fa78699774 100644 --- a/src/session/src/lib.rs +++ b/src/session/src/lib.rs @@ -25,6 +25,7 @@ use auth::UserInfoRef; use common_catalog::build_db_string; use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME}; use common_recordbatch::cursor::RecordBatchStreamCursor; +pub use common_session::ReadPreference; use common_time::timezone::get_timezone; use common_time::Timezone; use context::{ConfigurationVariables, QueryContextBuilder}; @@ -50,6 +51,7 @@ pub(crate) struct MutableInner { user_info: UserInfoRef, timezone: Timezone, query_timeout: Option, + read_preference: ReadPreference, #[debug(skip)] pub(crate) cursors: HashMap>, } @@ -61,6 +63,7 @@ impl Default for MutableInner { user_info: auth::userinfo_by_name(None), timezone: get_timezone(None).clone(), query_timeout: None, + read_preference: ReadPreference::Leader, cursors: HashMap::with_capacity(0), } } @@ -101,11 +104,19 @@ impl Session { self.mutable_inner.read().unwrap().timezone.clone() } + pub fn read_preference(&self) -> ReadPreference { + self.mutable_inner.read().unwrap().read_preference + } + pub fn set_timezone(&self, tz: Timezone) { let mut inner = self.mutable_inner.write().unwrap(); inner.timezone = tz; } + pub fn set_read_preference(&self, read_preference: ReadPreference) { + self.mutable_inner.write().unwrap().read_preference = read_preference; + } + pub fn user_info(&self) -> UserInfoRef { self.mutable_inner.read().unwrap().user_info.clone() } diff --git a/tests/cases/standalone/common/system/read_preference.result b/tests/cases/standalone/common/system/read_preference.result new file mode 100644 index 0000000000..fa6992d6bd --- /dev/null +++ b/tests/cases/standalone/common/system/read_preference.result @@ -0,0 +1,28 @@ +-- SQLNESS PROTOCOL MYSQL +SELECT read_preference(); + ++-------------------+ +| read_preference() | ++-------------------+ +| LEADER | ++-------------------+ + +-- SQLNESS PROTOCOL MYSQL +SET read_preference = 'hi'; + +Failed to execute query, err: MySqlError { ERROR 1235 (42000): (Unsupported): Not supported: Invalid read preference expr hi in set variable statement } + +-- SQLNESS PROTOCOL MYSQL +SET read_preference = 'leader'; + +affected_rows: 0 + +-- SQLNESS PROTOCOL MYSQL +SELECT read_preference(); + ++-------------------+ +| read_preference() | ++-------------------+ +| LEADER | ++-------------------+ + diff --git a/tests/cases/standalone/common/system/read_preference.sql b/tests/cases/standalone/common/system/read_preference.sql new file mode 100644 index 0000000000..e01da023fc --- /dev/null +++ b/tests/cases/standalone/common/system/read_preference.sql @@ -0,0 +1,13 @@ +-- SQLNESS PROTOCOL MYSQL +SELECT read_preference(); + +-- SQLNESS PROTOCOL MYSQL +SET read_preference = 'hi'; + +-- SQLNESS PROTOCOL MYSQL +SET read_preference = 'leader'; + +-- SQLNESS PROTOCOL MYSQL +SELECT read_preference(); + +