diff --git a/src/common/function/src/system.rs b/src/common/function/src/system.rs index 6d1c1ebb47..b543fb9523 100644 --- a/src/common/function/src/system.rs +++ b/src/common/function/src/system.rs @@ -22,7 +22,7 @@ mod version; use std::sync::Arc; use build::BuildFunction; -use database::DatabaseFunction; +use database::{CurrentSchemaFunction, DatabaseFunction}; use pg_catalog::PGCatalogFunction; use procedure_state::ProcedureStateFunction; use timezone::TimezoneFunction; @@ -37,6 +37,7 @@ impl SystemFunction { registry.register(Arc::new(BuildFunction)); registry.register(Arc::new(VersionFunction)); registry.register(Arc::new(DatabaseFunction)); + registry.register(Arc::new(CurrentSchemaFunction)); registry.register(Arc::new(TimezoneFunction)); registry.register_async(Arc::new(ProcedureStateFunction)); PGCatalogFunction::register(registry); diff --git a/src/common/function/src/system/database.rs b/src/common/function/src/system/database.rs index 56630270d4..fece862ee0 100644 --- a/src/common/function/src/system/database.rs +++ b/src/common/function/src/system/database.rs @@ -26,11 +26,35 @@ use crate::function::{Function, FunctionContext}; #[derive(Clone, Debug, Default)] pub struct DatabaseFunction; -const NAME: &str = "database"; +#[derive(Clone, Debug, Default)] +pub struct CurrentSchemaFunction; + +const DATABASE_FUNCTION_NAME: &str = "database"; +const CURRENT_SCHEMA_FUNCTION_NAME: &str = "current_schema"; impl Function for DatabaseFunction { fn name(&self) -> &str { - NAME + DATABASE_FUNCTION_NAME + } + + fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result { + Ok(ConcreteDataType::string_datatype()) + } + + fn signature(&self) -> Signature { + Signature::uniform(0, vec![], Volatility::Immutable) + } + + fn eval(&self, func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result { + let db = func_ctx.query_ctx.current_schema(); + + Ok(Arc::new(StringVector::from_slice(&[&db])) as _) + } +} + +impl Function for CurrentSchemaFunction { + fn name(&self) -> &str { + CURRENT_SCHEMA_FUNCTION_NAME } fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result { @@ -54,6 +78,12 @@ impl fmt::Display for DatabaseFunction { } } +impl fmt::Display for CurrentSchemaFunction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "CURRENT_SCHEMA") + } +} + #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/src/common/function/src/system/pg_catalog.rs b/src/common/function/src/system/pg_catalog.rs index 78726cffbe..26b7dc4f24 100644 --- a/src/common/function/src/system/pg_catalog.rs +++ b/src/common/function/src/system/pg_catalog.rs @@ -14,11 +14,13 @@ mod pg_get_userbyid; mod table_is_visible; +mod version; use std::sync::Arc; use pg_get_userbyid::PGGetUserByIdFunction; use table_is_visible::PGTableIsVisibleFunction; +use version::PGVersionFunction; use crate::function_registry::FunctionRegistry; @@ -35,5 +37,6 @@ impl PGCatalogFunction { pub fn register(registry: &FunctionRegistry) { registry.register(Arc::new(PGTableIsVisibleFunction)); registry.register(Arc::new(PGGetUserByIdFunction)); + registry.register(Arc::new(PGVersionFunction)); } } diff --git a/src/common/function/src/system/pg_catalog/version.rs b/src/common/function/src/system/pg_catalog/version.rs new file mode 100644 index 0000000000..e9511bd6e1 --- /dev/null +++ b/src/common/function/src/system/pg_catalog/version.rs @@ -0,0 +1,54 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; +use std::{env, fmt}; + +use common_query::error::Result; +use common_query::prelude::{Signature, Volatility}; +use datatypes::data_type::ConcreteDataType; +use datatypes::vectors::{StringVector, VectorRef}; + +use crate::function::{Function, FunctionContext}; + +#[derive(Clone, Debug, Default)] +pub(crate) struct PGVersionFunction; + +impl fmt::Display for PGVersionFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, crate::pg_catalog_func_fullname!("VERSION")) + } +} + +impl Function for PGVersionFunction { + fn name(&self) -> &str { + crate::pg_catalog_func_fullname!("version") + } + + fn return_type(&self, _input_types: &[ConcreteDataType]) -> Result { + Ok(ConcreteDataType::string_datatype()) + } + + fn signature(&self) -> Signature { + Signature::exact(vec![], Volatility::Immutable) + } + + fn eval(&self, _func_ctx: FunctionContext, _columns: &[VectorRef]) -> Result { + let result = StringVector::from(vec![format!( + "PostgreSQL 16.3 GreptimeDB {}", + env!("CARGO_PKG_VERSION") + )]); + Ok(Arc::new(result)) + } +} diff --git a/src/servers/src/postgres.rs b/src/servers/src/postgres.rs index c6e10ad8db..42683ff680 100644 --- a/src/servers/src/postgres.rs +++ b/src/servers/src/postgres.rs @@ -13,6 +13,7 @@ // limitations under the License. mod auth_handler; +mod fixtures; mod handler; mod server; mod types; @@ -41,13 +42,13 @@ use self::handler::DefaultQueryParser; use crate::query_handler::sql::ServerSqlQueryHandlerRef; pub(crate) struct GreptimeDBStartupParameters { - version: &'static str, + version: String, } impl GreptimeDBStartupParameters { fn new() -> GreptimeDBStartupParameters { GreptimeDBStartupParameters { - version: env!("CARGO_PKG_VERSION"), + version: format!("16.3-greptime-{}", env!("CARGO_PKG_VERSION")), } } } @@ -58,7 +59,7 @@ impl ServerParameterProvider for GreptimeDBStartupParameters { C: ClientInfo, { Some(HashMap::from([ - ("server_version".to_owned(), self.version.to_owned()), + ("server_version".to_owned(), self.version.clone()), ("server_encoding".to_owned(), "UTF8".to_owned()), ("client_encoding".to_owned(), "UTF8".to_owned()), ("DateStyle".to_owned(), "ISO YMD".to_owned()), diff --git a/src/servers/src/postgres/fixtures.rs b/src/servers/src/postgres/fixtures.rs new file mode 100644 index 0000000000..5b02480da9 --- /dev/null +++ b/src/servers/src/postgres/fixtures.rs @@ -0,0 +1,167 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; +use std::sync::Arc; + +use futures::stream; +use once_cell::sync::Lazy; +use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; +use pgwire::api::Type; +use pgwire::error::PgWireResult; +use pgwire::messages::data::DataRow; +use regex::Regex; +use session::context::QueryContextRef; + +fn build_string_data_rows( + schema: Arc>, + rows: Vec>, +) -> Vec> { + rows.iter() + .map(|row| { + let mut encoder = DataRowEncoder::new(schema.clone()); + for value in row { + encoder.encode_field(&Some(value))?; + } + encoder.finish() + }) + .collect() +} + +static VAR_VALUES: Lazy> = Lazy::new(|| { + HashMap::from([ + ("default_transaction_isolation", "read committed"), + ("transaction isolation level", "read committed"), + ("standard_conforming_strings", "on"), + ("client_encoding", "UTF8"), + ]) +}); + +static SHOW_PATTERN: Lazy = Lazy::new(|| Regex::new("(?i)^SHOW (.*?);?$").unwrap()); +static SET_TRANSACTION_PATTERN: Lazy = + Lazy::new(|| Regex::new("(?i)^SET TRANSACTION (.*?);?$").unwrap()); +static TRANSACTION_PATTERN: Lazy = + Lazy::new(|| Regex::new("(?i)^(BEGIN|ROLLBACK|COMMIT);?").unwrap()); + +/// Process unsupported SQL and return fixed result as a compatibility solution +pub(crate) fn process<'a>( + query: &str, + _query_ctx: QueryContextRef, +) -> Option>>> { + // Transaction directives: + if let Some(tx) = TRANSACTION_PATTERN.captures(query) { + let tx_tag = &tx[1]; + Some(Ok(vec![Response::Execution(Tag::new( + &tx_tag.to_uppercase(), + ))])) + } else if let Some(show_var) = SHOW_PATTERN.captures(query) { + let show_var = show_var[1].to_lowercase(); + if let Some(value) = VAR_VALUES.get(&show_var.as_ref()) { + let f1 = FieldInfo::new( + show_var.clone(), + None, + None, + Type::VARCHAR, + FieldFormat::Text, + ); + let schema = Arc::new(vec![f1]); + let data = stream::iter(build_string_data_rows( + schema.clone(), + vec![vec![value.to_string()]], + )); + + Some(Ok(vec![Response::Query(QueryResponse::new(schema, data))])) + } else { + None + } + } else if SET_TRANSACTION_PATTERN.is_match(query) { + Some(Ok(vec![Response::Execution(Tag::new("SET"))])) + } else { + None + } +} + +#[cfg(test)] +mod test { + use session::context::{QueryContext, QueryContextRef}; + + use super::*; + + fn assert_tag(q: &str, t: &str, query_context: QueryContextRef) { + if let Response::Execution(tag) = process(q, query_context.clone()) + .unwrap_or_else(|| panic!("fail to match {}", q)) + .expect("unexpected error") + .remove(0) + { + assert_eq!(Tag::new(t), tag); + } else { + panic!("Invalid response"); + } + } + + fn get_data<'a>(q: &str, query_context: QueryContextRef) -> QueryResponse<'a> { + if let Response::Query(resp) = process(q, query_context.clone()) + .unwrap_or_else(|| panic!("fail to match {}", q)) + .expect("unexpected error") + .remove(0) + { + resp + } else { + panic!("Invalid response"); + } + } + + #[test] + fn test_process() { + let query_context = QueryContext::arc(); + + assert_tag("BEGIN", "BEGIN", query_context.clone()); + assert_tag("BEGIN;", "BEGIN", query_context.clone()); + assert_tag("begin;", "BEGIN", query_context.clone()); + assert_tag("ROLLBACK", "ROLLBACK", query_context.clone()); + assert_tag("ROLLBACK;", "ROLLBACK", query_context.clone()); + assert_tag("rollback;", "ROLLBACK", query_context.clone()); + assert_tag("COMMIT", "COMMIT", query_context.clone()); + assert_tag("COMMIT;", "COMMIT", query_context.clone()); + assert_tag("commit;", "COMMIT", query_context.clone()); + assert_tag( + "SET TRANSACTION ISOLATION LEVEL READ COMMITTED", + "SET", + query_context.clone(), + ); + assert_tag( + "SET TRANSACTION ISOLATION LEVEL READ COMMITTED;", + "SET", + query_context.clone(), + ); + assert_tag( + "SET transaction isolation level READ COMMITTED;", + "SET", + query_context.clone(), + ); + + let resp = get_data("SHOW transaction isolation level", query_context.clone()); + assert_eq!(1, resp.row_schema().len()); + let resp = get_data("show client_encoding;", query_context.clone()); + assert_eq!(1, resp.row_schema().len()); + let resp = get_data("show standard_conforming_strings;", query_context.clone()); + assert_eq!(1, resp.row_schema().len()); + let resp = get_data("show default_transaction_isolation", query_context.clone()); + assert_eq!(1, resp.row_schema().len()); + + assert!(process("SELECT 1", query_context.clone()).is_none()); + assert!(process("SHOW TABLES ", query_context.clone()).is_none()); + assert!(process("SET TIME_ZONE=utc ", query_context.clone()).is_none()); + } +} diff --git a/src/servers/src/postgres/handler.rs b/src/servers/src/postgres/handler.rs index 49a596acc1..5d0c041cf2 100644 --- a/src/servers/src/postgres/handler.rs +++ b/src/servers/src/postgres/handler.rs @@ -37,7 +37,7 @@ use sql::dialect::PostgreSqlDialect; use sql::parser::{ParseOptions, ParserContext}; use super::types::*; -use super::PostgresServerHandler; +use super::{fixtures, PostgresServerHandler}; use crate::error::Result; use crate::query_handler::sql::ServerSqlQueryHandlerRef; use crate::SqlPlan; @@ -58,20 +58,26 @@ impl SimpleQueryHandler for PostgresServerHandler { let _timer = crate::metrics::METRIC_POSTGRES_QUERY_TIMER .with_label_values(&[crate::metrics::METRIC_POSTGRES_SIMPLE_QUERY, db.as_str()]) .start_timer(); - let outputs = self.query_handler.do_query(query, query_ctx.clone()).await; - let mut results = Vec::with_capacity(outputs.len()); + if let Some(resps) = fixtures::process(query, query_ctx.clone()) { + resps + } else { + let outputs = self.query_handler.do_query(query, query_ctx.clone()).await; - for output in outputs { - let resp = output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?; - results.push(resp); + let mut results = Vec::with_capacity(outputs.len()); + + for output in outputs { + let resp = + output_to_query_response(query_ctx.clone(), output, &Format::UnifiedText)?; + results.push(resp); + } + + Ok(results) } - - Ok(results) } } -fn output_to_query_response<'a>( +pub(crate) fn output_to_query_response<'a>( query_ctx: QueryContextRef, output: Result, field_format: &Format,