feat: add catalog name resolution for postgres and http interface (#810)

* feat: add catalog name resolution for postgres and http interface

* test: add tests for catalog resolution on http and postgres

* feat: assign custom catalog for query

* chore: order code for better readability
This commit is contained in:
Ning Sun
2023-01-09 11:43:25 +08:00
committed by GitHub
parent 777a3182c5
commit 3988770266
15 changed files with 243 additions and 52 deletions

View File

@@ -255,7 +255,10 @@ mod test {
let bare = ObjectName(vec![my_table.into()]);
let using_schema = "foo";
let query_ctx = Arc::new(QueryContext::with_current_schema(using_schema.to_string()));
let query_ctx = Arc::new(QueryContext::with(
DEFAULT_CATALOG_NAME.to_owned(),
using_schema.to_string(),
));
let empty_ctx = Arc::new(QueryContext::new());
assert_eq!(

View File

@@ -14,7 +14,7 @@
use std::sync::Arc;
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_query::Output;
use common_recordbatch::util;
use datatypes::data_type::ConcreteDataType;
@@ -559,6 +559,9 @@ async fn execute_sql(instance: &MockInstance, sql: &str) -> Output {
}
async fn execute_sql_in_db(instance: &MockInstance, sql: &str, db: &str) -> Output {
let query_ctx = Arc::new(QueryContext::with_current_schema(db.to_string()));
let query_ctx = Arc::new(QueryContext::with(
DEFAULT_CATALOG_NAME.to_owned(),
db.to_string(),
));
instance.inner().execute_sql(sql, query_ctx).await.unwrap()
}

View File

@@ -343,9 +343,12 @@ impl Instance {
}
fn handle_use(&self, db: String, query_ctx: QueryContextRef) -> Result<Output> {
let catalog = query_ctx.current_catalog();
let catalog = catalog.as_deref().unwrap_or(DEFAULT_CATALOG_NAME);
ensure!(
self.catalog_manager
.schema(DEFAULT_CATALOG_NAME, &db)
.schema(catalog, &db)
.context(error::CatalogSnafu)?
.is_some(),
error::SchemaNotFoundSnafu { schema_info: &db }

View File

@@ -226,6 +226,7 @@ impl DistInstance {
/// Handles distributed database creation
async fn handle_create_database(&self, expr: CreateDatabaseExpr) -> Result<Output> {
let key = SchemaKey {
// TODO(sunng87): custom catalog
catalog_name: DEFAULT_CATALOG_NAME.to_string(),
schema_name: expr.database_name,
};

View File

@@ -130,8 +130,11 @@ pub fn show_tables(
.current_schema()
.unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_string())
};
// TODO(sunng87): move this function into query_ctx
let catalog = query_ctx.current_catalog();
let catalog = catalog.as_deref().unwrap_or(DEFAULT_CATALOG_NAME);
let schema = catalog_manager
.schema(DEFAULT_CATALOG_NAME, &schema)
.schema(catalog, &schema)
.context(error::CatalogSnafu)?
.context(error::SchemaNotFoundSnafu { schema })?;
let tables = schema.table_names().context(error::CatalogSnafu)?;

View File

@@ -30,6 +30,7 @@ use axum::body::BoxBody;
use axum::error_handling::HandleErrorLayer;
use axum::response::{Html, Json};
use axum::{routing, BoxError, Extension, Router};
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_error::prelude::ErrorExt;
use common_error::status_code::StatusCode;
use common_query::Output;
@@ -40,6 +41,7 @@ use futures::FutureExt;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use session::context::QueryContext;
use snafu::{ensure, ResultExt};
use tokio::sync::oneshot::{self, Sender};
use tokio::sync::Mutex;
@@ -58,6 +60,38 @@ use crate::query_handler::{
};
use crate::server::Server;
/// create query context from database name information, catalog and schema are
/// resolved from the name
pub(crate) fn query_context_from_db(
query_handler: SqlQueryHandlerRef,
db: Option<String>,
) -> std::result::Result<Arc<QueryContext>, JsonResponse> {
if let Some(db) = &db {
let (catalog, schema) = super::parse_catalog_and_schema_from_client_database_name(db);
let catalog = catalog.unwrap_or(DEFAULT_CATALOG_NAME);
match query_handler.is_valid_schema(catalog, schema) {
Ok(true) => Ok(Arc::new(QueryContext::with(
catalog.to_owned(),
schema.to_owned(),
))),
Ok(false) => Err(JsonResponse::with_error(
format!("Database not found: {db}"),
StatusCode::DatabaseNotFound,
)),
Err(e) => Err(JsonResponse::with_error(
format!("Error checking database: {db}, {e}"),
StatusCode::Internal,
)),
}
} else {
Ok(Arc::new(QueryContext::with(
DEFAULT_CATALOG_NAME.to_owned(),
DEFAULT_SCHEMA_NAME.to_owned(),
)))
}
}
const HTTP_API_VERSION: &str = "v1";
pub struct HttpServer {

View File

@@ -13,18 +13,16 @@
// limitations under the License.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use aide::transform::TransformOperation;
use axum::extract::{Json, Query, State};
use axum::Extension;
use common_catalog::consts::DEFAULT_CATALOG_NAME;
use common_error::status_code::StatusCode;
use common_telemetry::metric;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use session::context::{QueryContext, UserInfo};
use session::context::UserInfo;
use crate::http::{ApiState, JsonResponse};
@@ -45,26 +43,12 @@ pub async fn sql(
let sql_handler = &state.sql_handler;
let start = Instant::now();
let resp = if let Some(sql) = &params.sql {
let query_ctx = Arc::new(QueryContext::new());
if let Some(db) = &params.database {
match sql_handler.is_valid_schema(DEFAULT_CATALOG_NAME, db) {
Ok(true) => query_ctx.set_current_schema(db),
Ok(false) => {
return Json(JsonResponse::with_error(
format!("Database not found: {db}"),
StatusCode::DatabaseNotFound,
));
}
Err(e) => {
return Json(JsonResponse::with_error(
format!("Error checking database: {db}, {e}"),
StatusCode::Internal,
));
}
match super::query_context_from_db(sql_handler.clone(), params.database) {
Ok(query_ctx) => {
JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await
}
Err(resp) => resp,
}
JsonResponse::from_output(sql_handler.do_query(sql, query_ctx).await).await
} else {
JsonResponse::with_error(
"sql parameter is required.".to_string(),

View File

@@ -90,6 +90,8 @@ pub async fn run_script(
json_err!("invalid name");
}
// TODO(sunng87): query_context and db name resolution
let output = script_handler.execute_script(name.unwrap()).await;
let resp = JsonResponse::from_output(vec![output]).await;

View File

@@ -39,3 +39,50 @@ pub enum Mode {
Standalone,
Distributed,
}
/// Attempt to parse catalog and schema from given database name
///
/// The database name may come from different sources:
///
/// - MySQL `schema` name in MySQL protocol login request: it's optional and user
/// and switch database using `USE` command
/// - Postgres `database` parameter in Postgres wire protocol, required
/// - HTTP RESTful API: the database parameter, optional
///
/// When database name is provided, we attempt to parse catalog and schema from
/// it. We assume the format `[<catalog>-]<schema>`:
///
/// - If `[<catalog>-]` part is not provided, we use whole database name as
/// schema name
/// - if `[<catalog>-]` is provided, we split database name with `-` and use
/// `<catalog>` and `<schema>`.
pub(crate) fn parse_catalog_and_schema_from_client_database_name(db: &str) -> (Option<&str>, &str) {
let parts = db.splitn(2, '-').collect::<Vec<&str>>();
if parts.len() == 2 {
(Some(parts[0]), parts[1])
} else {
(None, db)
}
}
#[cfg(test)]
mod tests {
#[test]
fn test_parse_catalog_and_schema_from_client_database_name() {
assert_eq!(
(None, "fullschema"),
super::parse_catalog_and_schema_from_client_database_name("fullschema")
);
assert_eq!(
(Some("catalog"), "schema"),
super::parse_catalog_and_schema_from_client_database_name("catalog-schema")
);
assert_eq!(
(Some("catalog"), "schema1-schema2"),
super::parse_catalog_and_schema_from_client_database_name("catalog-schema1-schema2")
);
}
}

View File

@@ -18,5 +18,9 @@ mod server;
pub(crate) const METADATA_USER: &str = "user";
pub(crate) const METADATA_DATABASE: &str = "database";
/// key to store our parsed catalog
pub(crate) const METADATA_CATALOG: &str = "catalog";
/// key to store our parsed schema
pub(crate) const METADATA_SCHEMA: &str = "schema";
pub use server::PostgresServer;

View File

@@ -151,31 +151,19 @@ impl StartupHandler for PgAuthStartupHandler {
auth::save_startup_parameters_to_metadata(client, startup);
// check if db is valid
let db_ref = client.metadata().get(super::METADATA_DATABASE);
if let Some(db) = db_ref {
if !self
.query_handler
.is_valid_schema(DEFAULT_CATALOG_NAME, db)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
{
send_error(
client,
"FATAL",
"3D000",
format!("Database not found: {db}"),
)
.await?;
match resolve_db_info(client, self.query_handler.clone())? {
DbResolution::Resolved(catalog, schema) => {
client
.metadata_mut()
.insert(super::METADATA_CATALOG.to_owned(), catalog);
client
.metadata_mut()
.insert(super::METADATA_SCHEMA.to_owned(), schema);
}
DbResolution::NotFound(msg) => {
send_error(client, "FATAL", "3D000", msg).await?;
return Ok(());
}
} else {
send_error(
client,
"FATAL",
"3D000",
"Database not specified".to_owned(),
)
.await?;
return Ok(());
}
if self.verifier.user_provider.is_some() {
@@ -222,3 +210,36 @@ where
client.close().await?;
Ok(())
}
enum DbResolution {
Resolved(String, String),
NotFound(String),
}
/// A function extracted to resolve lifetime and readability issues:
fn resolve_db_info<C>(
client: &mut C,
query_handler: SqlQueryHandlerRef,
) -> PgWireResult<DbResolution>
where
C: ClientInfo + Unpin + Send,
{
let db_ref = client.metadata().get(super::METADATA_DATABASE);
if let Some(db) = db_ref {
let (catalog, schema) = crate::parse_catalog_and_schema_from_client_database_name(db);
let catalog = catalog.unwrap_or(DEFAULT_CATALOG_NAME);
if query_handler
.is_valid_schema(catalog, schema)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?
{
Ok(DbResolution::Resolved(
catalog.to_owned(),
schema.to_owned(),
))
} else {
Ok(DbResolution::NotFound(format!("Database not found: {db}")))
}
} else {
Ok(DbResolution::NotFound("Database not specified".to_owned()))
}
}

View File

@@ -47,7 +47,10 @@ where
C: ClientInfo,
{
let query_context = QueryContext::new();
if let Some(current_schema) = client.metadata().get(super::METADATA_DATABASE) {
if let Some(current_catalog) = client.metadata().get(super::METADATA_CATALOG) {
query_context.set_current_catalog(current_catalog);
}
if let Some(current_schema) = client.metadata().get(super::METADATA_SCHEMA) {
query_context.set_current_schema(current_schema);
}

View File

@@ -16,7 +16,7 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use common_catalog::consts::DEFAULT_SCHEMA_NAME;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_runtime::Builder as RuntimeBuilder;
use rand::rngs::StdRng;
use rand::Rng;
@@ -249,6 +249,24 @@ async fn test_using_db() -> Result<()> {
.unwrap();
let result = client.simple_query("SELECT uint32s FROM numbers").await;
assert!(result.is_ok());
let client = create_connection_with_given_catalog_schema(
server_port,
DEFAULT_CATALOG_NAME,
DEFAULT_SCHEMA_NAME,
)
.await;
assert!(client.is_ok());
let client =
create_connection_with_given_catalog_schema(server_port, "notfound", DEFAULT_SCHEMA_NAME)
.await;
assert!(client.is_err());
let client =
create_connection_with_given_catalog_schema(server_port, DEFAULT_CATALOG_NAME, "notfound")
.await;
assert!(client.is_err());
Ok(())
}
@@ -330,6 +348,17 @@ async fn create_connection_with_given_db(
Ok(client)
}
async fn create_connection_with_given_catalog_schema(
port: u16,
catalog: &str,
schema: &str,
) -> std::result::Result<Client, PgError> {
let url = format!("host=127.0.0.1 port={port} connect_timeout=2 dbname={catalog}-{schema}");
let (client, conn) = tokio_postgres::connect(&url, NoTls).await?;
tokio::spawn(conn);
Ok(client)
}
async fn create_connection_without_db(port: u16) -> std::result::Result<Client, PgError> {
let url = format!("host=127.0.0.1 port={port} connect_timeout=2");
let (client, conn) = tokio_postgres::connect(&url, NoTls).await?;

View File

@@ -22,6 +22,7 @@ pub type QueryContextRef = Arc<QueryContext>;
pub type ConnInfoRef = Arc<ConnInfo>;
pub struct QueryContext {
current_catalog: ArcSwapOption<String>,
current_schema: ArcSwapOption<String>,
}
@@ -38,12 +39,14 @@ impl QueryContext {
pub fn new() -> Self {
Self {
current_catalog: ArcSwapOption::new(None),
current_schema: ArcSwapOption::new(None),
}
}
pub fn with_current_schema(schema: String) -> Self {
pub fn with(catalog: String, schema: String) -> Self {
Self {
current_catalog: ArcSwapOption::new(Some(Arc::new(catalog))),
current_schema: ArcSwapOption::new(Some(Arc::new(schema))),
}
}
@@ -52,6 +55,10 @@ impl QueryContext {
self.current_schema.load().as_deref().cloned()
}
pub fn current_catalog(&self) -> Option<String> {
self.current_catalog.load().as_deref().cloned()
}
pub fn set_current_schema(&self, schema: &str) {
let last = self.current_schema.swap(Some(Arc::new(schema.to_string())));
info!(
@@ -59,6 +66,16 @@ impl QueryContext {
schema, last
)
}
pub fn set_current_catalog(&self, catalog: &str) {
let last = self
.current_catalog
.swap(Some(Arc::new(catalog.to_string())));
info!(
"set new session default catalog: {:?}, swap old: {:?}",
catalog, last
)
}
}
pub const DEFAULT_USERNAME: &str = "greptime";

View File

@@ -221,6 +221,43 @@ pub async fn test_sql_api(store_type: StorageType) {
let body = serde_json::from_str::<JsonResponse>(&res.text().await).unwrap();
assert_eq!(body.code(), ErrorCode::DatabaseNotFound as u32);
// test catalog-schema given
let res = client
.get("/v1/sql?database=greptime-public&sql=select cpu, ts from demo limit 1")
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
let body = serde_json::from_str::<JsonResponse>(&res.text().await).unwrap();
assert!(body.success());
assert!(body.execution_time_ms().is_some());
let outputs = body.output().unwrap();
assert_eq!(outputs.len(), 1);
assert_eq!(
outputs[0],
serde_json::from_value::<JsonOutput>(json!({
"records":{"schema":{"column_schemas":[{"name":"cpu","data_type":"Float64"},{"name":"ts","data_type":"TimestampMillisecond"}]},"rows":[[66.6,0]]}
})).unwrap()
);
// test invalid catalog
let res = client
.get("/v1/sql?database=notfound2-schema&sql=select cpu, ts from demo limit 1")
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
let body = serde_json::from_str::<JsonResponse>(&res.text().await).unwrap();
assert_eq!(body.code(), ErrorCode::Internal as u32);
// test invalid schema
let res = client
.get("/v1/sql?database=greptime-schema&sql=select cpu, ts from demo limit 1")
.send()
.await;
assert_eq!(res.status(), StatusCode::OK);
let body = serde_json::from_str::<JsonResponse>(&res.text().await).unwrap();
assert_eq!(body.code(), ErrorCode::DatabaseNotFound as u32);
guard.remove_all().await;
}