mirror of
https://github.com/GreptimeTeam/greptimedb.git
synced 2026-05-26 09:50:40 +00:00
feat: add catalog name resolution for postgres and http interface (#810)
* feat: add catalog name resolution for postgres and http interface * test: add tests for catalog resolution on http and postgres * feat: assign custom catalog for query * chore: order code for better readability
This commit is contained in:
@@ -255,7 +255,10 @@ mod test {
|
||||
let bare = ObjectName(vec![my_table.into()]);
|
||||
|
||||
let using_schema = "foo";
|
||||
let query_ctx = Arc::new(QueryContext::with_current_schema(using_schema.to_string()));
|
||||
let query_ctx = Arc::new(QueryContext::with(
|
||||
DEFAULT_CATALOG_NAME.to_owned(),
|
||||
using_schema.to_string(),
|
||||
));
|
||||
let empty_ctx = Arc::new(QueryContext::new());
|
||||
|
||||
assert_eq!(
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_query::Output;
|
||||
use common_recordbatch::util;
|
||||
use datatypes::data_type::ConcreteDataType;
|
||||
@@ -559,6 +559,9 @@ async fn execute_sql(instance: &MockInstance, sql: &str) -> Output {
|
||||
}
|
||||
|
||||
async fn execute_sql_in_db(instance: &MockInstance, sql: &str, db: &str) -> Output {
|
||||
let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string()));
|
||||
let query_ctx = Arc::new(QueryContext::with(
|
||||
DEFAULT_CATALOG_NAME.to_owned(),
|
||||
db.to_string(),
|
||||
));
|
||||
instance.inner().execute_sql(sql, query_ctx).await.unwrap()
|
||||
}
|
||||
|
||||
@@ -343,9 +343,12 @@ impl Instance {
|
||||
}
|
||||
|
||||
fn handle_use(&self, db: String, query_ctx: QueryContextRef) -> Result<Output> {
|
||||
let catalog = query_ctx.current_catalog();
|
||||
let catalog = catalog.as_deref().unwrap_or(DEFAULT_CATALOG_NAME);
|
||||
|
||||
ensure!(
|
||||
self.catalog_manager
|
||||
.schema(DEFAULT_CATALOG_NAME, &db)
|
||||
.schema(catalog, &db)
|
||||
.context(error::CatalogSnafu)?
|
||||
.is_some(),
|
||||
error::SchemaNotFoundSnafu { schema_info: &db }
|
||||
|
||||
@@ -226,6 +226,7 @@ impl DistInstance {
|
||||
/// Handles distributed database creation
|
||||
async fn handle_create_database(&self, expr: CreateDatabaseExpr) -> Result<Output> {
|
||||
let key = SchemaKey {
|
||||
// TODO(sunng87): custom catalog
|
||||
catalog_name: DEFAULT_CATALOG_NAME.to_string(),
|
||||
schema_name: expr.database_name,
|
||||
};
|
||||
|
||||
@@ -130,8 +130,11 @@ pub fn show_tables(
|
||||
.current_schema()
|
||||
.unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string())
|
||||
};
|
||||
// TODO(sunng87): move this function into query_ctx
|
||||
let catalog = query_ctx.current_catalog();
|
||||
let catalog = catalog.as_deref().unwrap_or(DEFAULT_CATALOG_NAME);
|
||||
let schema = catalog_manager
|
||||
.schema(DEFAULT_CATALOG_NAME, &schema)
|
||||
.schema(catalog, &schema)
|
||||
.context(error::CatalogSnafu)?
|
||||
.context(error::SchemaNotFoundSnafu { schema })?;
|
||||
let tables = schema.table_names().context(error::CatalogSnafu)?;
|
||||
|
||||
@@ -30,6 +30,7 @@ use axum::body::BoxBody;
|
||||
use axum::error_handling::HandleErrorLayer;
|
||||
use axum::response::{Html, Json};
|
||||
use axum::{routing, BoxError, Extension, Router};
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_error::prelude::ErrorExt;
|
||||
use common_error::status_code::StatusCode;
|
||||
use common_query::Output;
|
||||
@@ -40,6 +41,7 @@ use futures::FutureExt;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use session::context::QueryContext;
|
||||
use snafu::{ensure, ResultExt};
|
||||
use tokio::sync::oneshot::{self, Sender};
|
||||
use tokio::sync::Mutex;
|
||||
@@ -58,6 +60,38 @@ use crate::query_handler::{
|
||||
};
|
||||
use crate::server::Server;
|
||||
|
||||
/// create query context from database name information, catalog and schema are
|
||||
/// resolved from the name
|
||||
pub(crate) fn query_context_from_db(
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
db: Option<String>,
|
||||
) -> std::result::Result<Arc<QueryContext>, JsonResponse> {
|
||||
if let Some(db) = &db {
|
||||
let (catalog, schema) = super::parse_catalog_and_schema_from_client_database_name(db);
|
||||
let catalog = catalog.unwrap_or(DEFAULT_CATALOG_NAME);
|
||||
|
||||
match query_handler.is_valid_schema(catalog, schema) {
|
||||
Ok(true) => Ok(Arc::new(QueryContext::with(
|
||||
catalog.to_owned(),
|
||||
schema.to_owned(),
|
||||
))),
|
||||
Ok(false) => Err(JsonResponse::with_error(
|
||||
format!("Database not found: {db}"),
|
||||
StatusCode::DatabaseNotFound,
|
||||
)),
|
||||
Err(e) => Err(JsonResponse::with_error(
|
||||
format!("Error checking database: {db}, {e}"),
|
||||
StatusCode::Internal,
|
||||
)),
|
||||
}
|
||||
} else {
|
||||
Ok(Arc::new(QueryContext::with(
|
||||
DEFAULT_CATALOG_NAME.to_owned(),
|
||||
DEFAULT_SCHEMA_NAME.to_owned(),
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
const HTTP_API_VERSION: &str = "v1";
|
||||
|
||||
pub struct HttpServer {
|
||||
|
||||
@@ -13,18 +13,16 @@
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use aide::transform::TransformOperation;
|
||||
use axum::extract::{Json, Query, State};
|
||||
use axum::Extension;
|
||||
use common_catalog::consts::DEFAULT_CATALOG_NAME;
|
||||
use common_error::status_code::StatusCode;
|
||||
use common_telemetry::metric;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use session::context::{QueryContext, UserInfo};
|
||||
use session::context::UserInfo;
|
||||
|
||||
use crate::http::{ApiState, JsonResponse};
|
||||
|
||||
@@ -45,26 +43,12 @@ pub async fn sql(
|
||||
let sql_handler = &state.sql_handler;
|
||||
let start = Instant::now();
|
||||
let resp = if let Some(sql) = ¶ms.sql {
|
||||
let query_ctx = Arc::new(QueryContext::new());
|
||||
if let Some(db) = ¶ms.database {
|
||||
match sql_handler.is_valid_schema(DEFAULT_CATALOG_NAME, db) {
|
||||
Ok(true) => query_ctx.set_current_schema(db),
|
||||
Ok(false) => {
|
||||
return Json(JsonResponse::with_error(
|
||||
format!("Database not found: {db}"),
|
||||
StatusCode::DatabaseNotFound,
|
||||
));
|
||||
}
|
||||
Err(e) => {
|
||||
return Json(JsonResponse::with_error(
|
||||
format!("Error checking database: {db}, {e}"),
|
||||
StatusCode::Internal,
|
||||
));
|
||||
}
|
||||
match super::query_context_from_db(sql_handler.clone(), params.database) {
|
||||
Ok(query_ctx) => {
|
||||
JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await
|
||||
}
|
||||
Err(resp) => resp,
|
||||
}
|
||||
|
||||
JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await
|
||||
} else {
|
||||
JsonResponse::with_error(
|
||||
"sql parameter is required.".to_string(),
|
||||
|
||||
@@ -90,6 +90,8 @@ pub async fn run_script(
|
||||
json_err!("invalid name");
|
||||
}
|
||||
|
||||
// TODO(sunng87): query_context and db name resolution
|
||||
|
||||
let output = script_handler.execute_script(name.unwrap()).await;
|
||||
let resp = JsonResponse::from_output(vec![output]).await;
|
||||
|
||||
|
||||
@@ -39,3 +39,50 @@ pub enum Mode {
|
||||
Standalone,
|
||||
Distributed,
|
||||
}
|
||||
|
||||
/// Attempt to parse catalog and schema from given database name
|
||||
///
|
||||
/// The database name may come from different sources:
|
||||
///
|
||||
/// - MySQL `schema` name in MySQL protocol login request: it's optional and user
|
||||
/// and switch database using `USE` command
|
||||
/// - Postgres `database` parameter in Postgres wire protocol, required
|
||||
/// - HTTP RESTful API: the database parameter, optional
|
||||
///
|
||||
/// When database name is provided, we attempt to parse catalog and schema from
|
||||
/// it. We assume the format `[<catalog>-]<schema>`:
|
||||
///
|
||||
/// - If `[<catalog>-]` part is not provided, we use whole database name as
|
||||
/// schema name
|
||||
/// - if `[<catalog>-]` is provided, we split database name with `-` and use
|
||||
/// `<catalog>` and `<schema>`.
|
||||
pub(crate) fn parse_catalog_and_schema_from_client_database_name(db: &str) -> (Option<&str>, &str) {
|
||||
let parts = db.splitn(2, '-').collect::<Vec<&str>>();
|
||||
if parts.len() == 2 {
|
||||
(Some(parts[0]), parts[1])
|
||||
} else {
|
||||
(None, db)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_parse_catalog_and_schema_from_client_database_name() {
|
||||
assert_eq!(
|
||||
(None, "fullschema"),
|
||||
super::parse_catalog_and_schema_from_client_database_name("fullschema")
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
(Some("catalog"), "schema"),
|
||||
super::parse_catalog_and_schema_from_client_database_name("catalog-schema")
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
(Some("catalog"), "schema1-schema2"),
|
||||
super::parse_catalog_and_schema_from_client_database_name("catalog-schema1-schema2")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,5 +18,9 @@ mod server;
|
||||
|
||||
pub(crate) const METADATA_USER: &str = "user";
|
||||
pub(crate) const METADATA_DATABASE: &str = "database";
|
||||
/// key to store our parsed catalog
|
||||
pub(crate) const METADATA_CATALOG: &str = "catalog";
|
||||
/// key to store our parsed schema
|
||||
pub(crate) const METADATA_SCHEMA: &str = "schema";
|
||||
|
||||
pub use server::PostgresServer;
|
||||
|
||||
@@ -151,31 +151,19 @@ impl StartupHandler for PgAuthStartupHandler {
|
||||
auth::save_startup_parameters_to_metadata(client, startup);
|
||||
|
||||
// check if db is valid
|
||||
let db_ref = client.metadata().get(super::METADATA_DATABASE);
|
||||
if let Some(db) = db_ref {
|
||||
if !self
|
||||
.query_handler
|
||||
.is_valid_schema(DEFAULT_CATALOG_NAME, db)
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
|
||||
{
|
||||
send_error(
|
||||
client,
|
||||
"FATAL",
|
||||
"3D000",
|
||||
format!("Database not found: {db}"),
|
||||
)
|
||||
.await?;
|
||||
match resolve_db_info(client, self.query_handler.clone())? {
|
||||
DbResolution::Resolved(catalog, schema) => {
|
||||
client
|
||||
.metadata_mut()
|
||||
.insert(super::METADATA_CATALOG.to_owned(), catalog);
|
||||
client
|
||||
.metadata_mut()
|
||||
.insert(super::METADATA_SCHEMA.to_owned(), schema);
|
||||
}
|
||||
DbResolution::NotFound(msg) => {
|
||||
send_error(client, "FATAL", "3D000", msg).await?;
|
||||
return Ok(());
|
||||
}
|
||||
} else {
|
||||
send_error(
|
||||
client,
|
||||
"FATAL",
|
||||
"3D000",
|
||||
"Database not specified".to_owned(),
|
||||
)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if self.verifier.user_provider.is_some() {
|
||||
@@ -222,3 +210,36 @@ where
|
||||
client.close().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
enum DbResolution {
|
||||
Resolved(String, String),
|
||||
NotFound(String),
|
||||
}
|
||||
|
||||
/// A function extracted to resolve lifetime and readability issues:
|
||||
fn resolve_db_info<C>(
|
||||
client: &mut C,
|
||||
query_handler: SqlQueryHandlerRef,
|
||||
) -> PgWireResult<DbResolution>
|
||||
where
|
||||
C: ClientInfo + Unpin + Send,
|
||||
{
|
||||
let db_ref = client.metadata().get(super::METADATA_DATABASE);
|
||||
if let Some(db) = db_ref {
|
||||
let (catalog, schema) = crate::parse_catalog_and_schema_from_client_database_name(db);
|
||||
let catalog = catalog.unwrap_or(DEFAULT_CATALOG_NAME);
|
||||
if query_handler
|
||||
.is_valid_schema(catalog, schema)
|
||||
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
|
||||
{
|
||||
Ok(DbResolution::Resolved(
|
||||
catalog.to_owned(),
|
||||
schema.to_owned(),
|
||||
))
|
||||
} else {
|
||||
Ok(DbResolution::NotFound(format!("Database not found: {db}")))
|
||||
}
|
||||
} else {
|
||||
Ok(DbResolution::NotFound("Database not specified".to_owned()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,7 +47,10 @@ where
|
||||
C: ClientInfo,
|
||||
{
|
||||
let query_context = QueryContext::new();
|
||||
if let Some(current_schema) = client.metadata().get(super::METADATA_DATABASE) {
|
||||
if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
|
||||
query_context.set_current_catalog(current_catalog);
|
||||
}
|
||||
if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
|
||||
query_context.set_current_schema(current_schema);
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
|
||||
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
|
||||
use common_runtime::Builder as RuntimeBuilder;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::Rng;
|
||||
@@ -249,6 +249,24 @@ async fn test_using_db() -> Result<()> {
|
||||
.unwrap();
|
||||
let result = client.simple_query("SELECT uint32s FROM numbers").await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let client = create_connection_with_given_catalog_schema(
|
||||
server_port,
|
||||
DEFAULT_CATALOG_NAME,
|
||||
DEFAULT_SCHEMA_NAME,
|
||||
)
|
||||
.await;
|
||||
assert!(client.is_ok());
|
||||
|
||||
let client =
|
||||
create_connection_with_given_catalog_schema(server_port, "notfound", DEFAULT_SCHEMA_NAME)
|
||||
.await;
|
||||
assert!(client.is_err());
|
||||
|
||||
let client =
|
||||
create_connection_with_given_catalog_schema(server_port, DEFAULT_CATALOG_NAME, "notfound")
|
||||
.await;
|
||||
assert!(client.is_err());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -330,6 +348,17 @@ async fn create_connection_with_given_db(
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
async fn create_connection_with_given_catalog_schema(
|
||||
port: u16,
|
||||
catalog: &str,
|
||||
schema: &str,
|
||||
) -> std::result::Result<Client, PgError> {
|
||||
let url = format!("host=127.0.0.1 port={port} connect_timeout=2 dbname={catalog}-{schema}");
|
||||
let (client, conn) = tokio_postgres::connect(&url, NoTls).await?;
|
||||
tokio::spawn(conn);
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
async fn create_connection_without_db(port: u16) -> std::result::Result<Client, PgError> {
|
||||
let url = format!("host=127.0.0.1 port={port} connect_timeout=2");
|
||||
let (client, conn) = tokio_postgres::connect(&url, NoTls).await?;
|
||||
|
||||
@@ -22,6 +22,7 @@ pub type QueryContextRef = Arc<QueryContext>;
|
||||
pub type ConnInfoRef = Arc<ConnInfo>;
|
||||
|
||||
pub struct QueryContext {
|
||||
current_catalog: ArcSwapOption<String>,
|
||||
current_schema: ArcSwapOption<String>,
|
||||
}
|
||||
|
||||
@@ -38,12 +39,14 @@ impl QueryContext {
|
||||
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
current_catalog: ArcSwapOption::new(None),
|
||||
current_schema: ArcSwapOption::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_current_schema(schema: String) -> Self {
|
||||
pub fn with(catalog: String, schema: String) -> Self {
|
||||
Self {
|
||||
current_catalog: ArcSwapOption::new(Some(Arc::new(catalog))),
|
||||
current_schema: ArcSwapOption::new(Some(Arc::new(schema))),
|
||||
}
|
||||
}
|
||||
@@ -52,6 +55,10 @@ impl QueryContext {
|
||||
self.current_schema.load().as_deref().cloned()
|
||||
}
|
||||
|
||||
pub fn current_catalog(&self) -> Option<String> {
|
||||
self.current_catalog.load().as_deref().cloned()
|
||||
}
|
||||
|
||||
pub fn set_current_schema(&self, schema: &str) {
|
||||
let last = self.current_schema.swap(Some(Arc::new(schema.to_string())));
|
||||
info!(
|
||||
@@ -59,6 +66,16 @@ impl QueryContext {
|
||||
schema, last
|
||||
)
|
||||
}
|
||||
|
||||
pub fn set_current_catalog(&self, catalog: &str) {
|
||||
let last = self
|
||||
.current_catalog
|
||||
.swap(Some(Arc::new(catalog.to_string())));
|
||||
info!(
|
||||
"set new session default catalog: {:?}, swap old: {:?}",
|
||||
catalog, last
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub const DEFAULT_USERNAME: &str = "greptime";
|
||||
|
||||
@@ -221,6 +221,43 @@ pub async fn test_sql_api(store_type: StorageType) {
|
||||
let body = serde_json::from_str::<JsonResponse>(&res.text().await).unwrap();
|
||||
assert_eq!(body.code(), ErrorCode::DatabaseNotFound as u32);
|
||||
|
||||
// test catalog-schema given
|
||||
let res = client
|
||||
.get("/v1/sql?database=greptime-public&sql=select cpu, ts from demo limit 1")
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let body = serde_json::from_str::<JsonResponse>(&res.text().await).unwrap();
|
||||
assert!(body.success());
|
||||
assert!(body.execution_time_ms().is_some());
|
||||
let outputs = body.output().unwrap();
|
||||
assert_eq!(outputs.len(), 1);
|
||||
assert_eq!(
|
||||
outputs[0],
|
||||
serde_json::from_value::<JsonOutput>(json!({
|
||||
"records":{"schema":{"column_schemas":[{"name":"cpu","data_type":"Float64"},{"name":"ts","data_type":"TimestampMillisecond"}]},"rows":[[66.6,0]]}
|
||||
})).unwrap()
|
||||
);
|
||||
|
||||
// test invalid catalog
|
||||
let res = client
|
||||
.get("/v1/sql?database=notfound2-schema&sql=select cpu, ts from demo limit 1")
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = serde_json::from_str::<JsonResponse>(&res.text().await).unwrap();
|
||||
assert_eq!(body.code(), ErrorCode::Internal as u32);
|
||||
|
||||
// test invalid schema
|
||||
let res = client
|
||||
.get("/v1/sql?database=greptime-schema&sql=select cpu, ts from demo limit 1")
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = serde_json::from_str::<JsonResponse>(&res.text().await).unwrap();
|
||||
assert_eq!(body.code(), ErrorCode::DatabaseNotFound as u32);
|
||||
|
||||
guard.remove_all().await;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user