feat: make scripts table work again (#2420)

* feat: make scripts table work again

* chore: typo

* fix: license header

* Update src/table/src/metadata.rs

Co-authored-by: Ruihang Xia <waynestxia@gmail.com>

* chore: cr comments

---------

Co-authored-by: Ruihang Xia <waynestxia@gmail.com>
This commit is contained in:
dennis zhuang
2023-09-18 19:43:21 +08:00
committed by GitHub
parent 14e6998d41
commit 5566f34bd1
17 changed files with 620 additions and 287 deletions

5
Cargo.lock generated
View File

@@ -3240,6 +3240,7 @@ name = "frontend"
version = "0.4.0-nightly"
dependencies = [
"api",
"arc-swap",
"arrow-flight",
"async-compat",
"async-stream",
@@ -8427,6 +8428,8 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
name = "script"
version = "0.4.0-nightly"
dependencies = [
"api",
"arc-swap",
"arrow",
"async-trait",
"catalog",
@@ -8452,6 +8455,7 @@ dependencies = [
"log-store",
"mito",
"once_cell",
"operator",
"paste",
"pyo3",
"query",
@@ -8466,6 +8470,7 @@ dependencies = [
"rustpython-stdlib",
"rustpython-vm",
"serde",
"servers",
"session",
"snafu",
"sql",

View File

@@ -11,6 +11,7 @@ testing = []
[dependencies]
api = { workspace = true }
arc-swap = "1.0"
arrow-flight.workspace = true
async-compat = "0.2"
async-stream.workspace = true

View File

@@ -393,7 +393,7 @@ impl FrontendInstance for Instance {
heartbeat_task.start().await?;
}
self.script_executor.start(self).await?;
self.script_executor.start(self)?;
futures::future::try_join_all(self.servers.values().map(start_server))
.await

View File

@@ -18,6 +18,7 @@ use async_trait::async_trait;
use common_query::Output;
use common_telemetry::timer;
use servers::query_handler::ScriptHandler;
use session::context::QueryContextRef;
use crate::instance::Instance;
use crate::metrics;
@@ -26,25 +27,25 @@ use crate::metrics;
impl ScriptHandler for Instance {
async fn insert_script(
&self,
schema: &str,
query_ctx: QueryContextRef,
name: &str,
script: &str,
) -> servers::error::Result<()> {
let _timer = timer!(metrics::METRIC_HANDLE_SCRIPTS_ELAPSED);
self.script_executor
.insert_script(schema, name, script)
.insert_script(query_ctx, name, script)
.await
}
async fn execute_script(
&self,
schema: &str,
query_ctx: QueryContextRef,
name: &str,
params: HashMap<String, String>,
) -> servers::error::Result<Output> {
let _timer = timer!(metrics::METRIC_RUN_SCRIPT_ELAPSED);
self.script_executor
.execute_script(schema, name, params)
.execute_script(query_ctx, name, params)
.await
}
}

View File

@@ -13,12 +13,17 @@
// limitations under the License.
use std::collections::HashMap;
use std::sync::Arc;
use catalog::CatalogManagerRef;
use common_query::Output;
use query::QueryEngineRef;
use servers::query_handler::grpc::GrpcQueryHandler;
use session::context::QueryContextRef;
use crate::error::Result;
use crate::error::{Error, Result};
type FrontendGrpcQueryHandlerRef = Arc<dyn GrpcQueryHandler<Error = Error> + Send + Sync>;
#[cfg(not(feature = "python"))]
mod dummy {
@@ -34,13 +39,13 @@ mod dummy {
Ok(Self {})
}
pub async fn start(&self) -> Result<()> {
/// No-op `start` used when the `python` feature is disabled.
///
/// The instance argument is accepted only so the signature matches the
/// `python` implementation; it is intentionally unused, hence the `_`
/// prefix (consistent with the other dummy methods' parameters).
pub fn start(&self, _instance: &Instance) -> Result<()> {
    Ok(())
}
pub async fn insert_script(
&self,
_schema: &str,
_query_ctx: QueryContextRef,
_name: &str,
_script: &str,
) -> servers::error::Result<()> {
@@ -49,7 +54,7 @@ mod dummy {
pub async fn execute_script(
&self,
_schema: &str,
_query_ctx: QueryContextRef,
_name: &str,
_params: HashMap<String, String>,
) -> servers::error::Result<Output> {
@@ -63,10 +68,11 @@ mod python {
use api::v1::ddl_request::Expr;
use api::v1::greptime_request::Request;
use api::v1::{CreateTableExpr, DdlRequest};
use arc_swap::ArcSwap;
use catalog::RegisterSystemTableRequest;
use common_error::ext::BoxedError;
use common_meta::table_name::TableName;
use common_telemetry::logging::error;
use common_telemetry::{error, info};
use operator::expr_factory;
use script::manager::ScriptManager;
use servers::query_handler::grpc::GrpcQueryHandler;
@@ -78,8 +84,33 @@ mod python {
use crate::error::{CatalogSnafu, InvalidSystemTableDefSnafu, TableNotFoundSnafu};
use crate::instance::Instance;
/// A placeholder for the real gRPC handler.
/// It is temporary and will be replaced soon.
struct DummyHandler;
impl DummyHandler {
    /// Returns a freshly allocated, reference-counted placeholder handler.
    fn arc() -> Arc<Self> {
        let placeholder = Self {};
        Arc::new(placeholder)
    }
}
// `GrpcQueryHandler` implementation for the placeholder handler.
// `ScriptExecutor::new` installs this handler, and `ScriptExecutor::start`
// later stores the real instance-backed handler in its place.
#[async_trait::async_trait]
impl GrpcQueryHandler for DummyHandler {
type Error = Error;
/// Never expected to be called: the body is `unreachable!()`, so any
/// invocation panics.
/// NOTE(review): presumably no query is issued before `start` swaps the
/// handler — confirm against callers.
async fn do_query(
&self,
_query: Request,
_ctx: QueryContextRef,
) -> std::result::Result<Output, Self::Error> {
unreachable!();
}
}
pub struct ScriptExecutor {
script_manager: ScriptManager,
script_manager: ScriptManager<Error>,
grpc_handler: ArcSwap<FrontendGrpcQueryHandlerRef>,
catalog_manager: CatalogManagerRef,
}
impl ScriptExecutor {
@@ -87,21 +118,42 @@ mod python {
catalog_manager: CatalogManagerRef,
query_engine: QueryEngineRef,
) -> Result<Self> {
let grpc_handler = DummyHandler::arc();
Ok(Self {
script_manager: ScriptManager::new(catalog_manager, query_engine)
grpc_handler: ArcSwap::new(Arc::new(grpc_handler.clone() as _)),
script_manager: ScriptManager::new(grpc_handler as _, query_engine)
.await
.context(crate::error::StartScriptManagerSnafu)?,
catalog_manager,
})
}
pub async fn start(&self, instance: &Instance) -> Result<()> {
pub fn start(&self, instance: &Instance) -> Result<()> {
let handler = Arc::new(instance.clone());
self.grpc_handler.store(Arc::new(handler.clone() as _));
self.script_manager
.start(handler)
.context(crate::error::StartScriptManagerSnafu)?;
Ok(())
}
/// Create the scripts table for the given catalog if it does not exist yet.
/// The function is idempotent and safe to call more than once for the same catalog.
async fn create_scripts_table_if_need(&self, catalog: &str) -> Result<()> {
let scripts_table = self.script_manager.get_scripts_table(catalog);
if scripts_table.is_some() {
return Ok(());
}
let RegisterSystemTableRequest {
create_table_request: request,
open_hook,
} = self.script_manager.create_table_request();
} = self.script_manager.create_table_request(catalog);
if let Some(table) = instance
.catalog_manager()
if let Some(table) = self
.catalog_manager
.table(
&request.catalog_name,
&request.schema_name,
@@ -111,9 +163,11 @@ mod python {
.context(CatalogSnafu)?
{
if let Some(open_hook) = open_hook {
(open_hook)(table).await.context(CatalogSnafu)?;
(open_hook)(table.clone()).await.context(CatalogSnafu)?;
}
self.script_manager.insert_scripts_table(catalog, table);
return Ok(());
}
@@ -125,7 +179,9 @@ mod python {
let expr = Self::create_table_expr(request)?;
let _ = instance
let _ = self
.grpc_handler
.load()
.do_query(
Request::Ddl(DdlRequest {
expr: Some(Expr::CreateTable(expr)),
@@ -134,8 +190,8 @@ mod python {
)
.await?;
let table = instance
.catalog_manager()
let table = self
.catalog_manager
.table(
&table_name.catalog_name,
&table_name.schema_name,
@@ -148,9 +204,16 @@ mod python {
})?;
if let Some(open_hook) = open_hook {
(open_hook)(table).await.context(CatalogSnafu)?;
(open_hook)(table.clone()).await.context(CatalogSnafu)?;
}
info!(
"Created scripts table {}.",
table.table_info().full_table_name()
);
self.script_manager.insert_scripts_table(catalog, table);
Ok(())
}
@@ -196,16 +259,31 @@ mod python {
pub async fn insert_script(
&self,
schema: &str,
query_ctx: QueryContextRef,
name: &str,
script: &str,
) -> servers::error::Result<()> {
let _s = self
.script_manager
.insert_and_compile(schema, name, script)
self.create_scripts_table_if_need(query_ctx.current_catalog())
.await
.map_err(|e| {
error!(e; "Instance failed to insert script");
error!(e; "Failed to create scripts table");
servers::error::InternalSnafu {
err_msg: e.to_string(),
}
.build()
})?;
let _s = self
.script_manager
.insert_and_compile(
query_ctx.current_catalog(),
query_ctx.current_schema(),
name,
script,
)
.await
.map_err(|e| {
error!(e; "Failed to insert script");
BoxedError::new(e)
})
.context(servers::error::InsertScriptSnafu { name })?;
@@ -215,15 +293,30 @@ mod python {
pub async fn execute_script(
&self,
schema: &str,
query_ctx: QueryContextRef,
name: &str,
params: HashMap<String, String>,
) -> servers::error::Result<Output> {
self.script_manager
.execute(schema, name, params)
self.create_scripts_table_if_need(query_ctx.current_catalog())
.await
.map_err(|e| {
error!(e; "Instance failed to execute script");
error!(e; "Failed to create scripts table");
servers::error::InternalSnafu {
err_msg: e.to_string(),
}
.build()
})?;
self.script_manager
.execute(
query_ctx.current_catalog(),
query_ctx.current_schema(),
name,
params,
)
.await
.map_err(|e| {
error!(e; "Failed to execute script");
BoxedError::new(e)
})
.context(servers::error::ExecuteScriptSnafu { name })

View File

@@ -24,6 +24,8 @@ python = [
]
[dependencies]
api.workspace = true
arc-swap = "1.0"
arrow.workspace = true
async-trait.workspace = true
catalog = { workspace = true }
@@ -62,6 +64,7 @@ rustpython-vm = { git = "https://github.com/discord9/RustPython", optional = tru
"default",
"codegen",
] }
servers.workspace = true
session = { workspace = true }
snafu = { version = "0.7", features = ["backtraces"] }
sql = { workspace = true }
@@ -75,6 +78,7 @@ common-test-util = { workspace = true }
criterion = { version = "0.4", features = ["html_reports", "async_tokio"] }
log-store = { workspace = true }
mito = { workspace = true }
operator.workspace = true
rayon = "1.0"
ron = "0.7"
serde = { version = "1.0", features = ["derive"] }

View File

@@ -14,28 +14,16 @@
use std::any::Any;
use common_error::ext::ErrorExt;
use common_error::ext::{BoxedError, ErrorExt};
use common_error::status_code::StatusCode;
use snafu::{Location, Snafu};
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum Error {
#[snafu(display("Failed to find scripts table, source: {}", source))]
FindScriptsTable {
location: Location,
source: catalog::error::Error,
},
#[snafu(display("Failed to find column in scripts table, name: {}", name))]
FindColumnInScriptsTable { name: String, location: Location },
#[snafu(display("Failed to register scripts table, source: {}", source))]
RegisterScriptsTable {
location: Location,
source: catalog::error::Error,
},
#[snafu(display("Scripts table not found"))]
ScriptsTableNotFound { location: Location },
@@ -47,7 +35,7 @@ pub enum Error {
InsertScript {
name: String,
location: Location,
source: table::error::Error,
source: BoxedError,
},
#[snafu(display("Failed to compile python script, name: {}, source: {}", name, source))]
@@ -67,13 +55,6 @@ pub enum Error {
#[snafu(display("Script not found, name: {}", name))]
ScriptNotFound { location: Location, name: String },
#[snafu(display("Failed to find script by name: {}", name))]
FindScript {
name: String,
location: Location,
source: query::error::Error,
},
#[snafu(display("Failed to collect record batch, source: {}", source))]
CollectRecords {
location: Location,
@@ -104,12 +85,8 @@ impl ErrorExt for Error {
match self {
FindColumnInScriptsTable { .. } | CastType { .. } => StatusCode::Unexpected,
ScriptsTableNotFound { .. } => StatusCode::TableNotFound,
RegisterScriptsTable { source, .. } | FindScriptsTable { source, .. } => {
source.status_code()
}
InsertScript { source, .. } => source.status_code(),
CompilePython { source, .. } | ExecutePython { source, .. } => source.status_code(),
FindScript { source, .. } => source.status_code(),
CollectRecords { source, .. } => source.status_code(),
ScriptNotFound { .. } => StatusCode::InvalidArguments,
BuildDfLogicalPlan { .. } => StatusCode::Internal,

View File

@@ -22,3 +22,5 @@ pub mod manager;
#[cfg(feature = "python")]
pub mod python;
pub mod table;
#[cfg(test)]
mod test;

View File

@@ -16,54 +16,70 @@
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use catalog::{CatalogManagerRef, OpenSystemTableHook, RegisterSystemTableRequest};
use common_catalog::consts::{
default_engine, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, SCRIPTS_TABLE_ID,
};
use arc_swap::ArcSwap;
use catalog::{OpenSystemTableHook, RegisterSystemTableRequest};
use common_catalog::consts::{default_engine, DEFAULT_SCHEMA_NAME, SCRIPTS_TABLE_ID};
use common_error::ext::ErrorExt;
use common_query::Output;
use common_telemetry::logging;
use futures::future::FutureExt;
use query::QueryEngineRef;
use servers::query_handler::grpc::GrpcQueryHandlerRef;
use snafu::{OptionExt, ResultExt};
use table::requests::{CreateTableRequest, TableOptions};
use table::TableRef;
use crate::engine::{CompileContext, EvalContext, Script, ScriptEngine};
use crate::error::{CompilePythonSnafu, ExecutePythonSnafu, Result, ScriptNotFoundSnafu};
use crate::error::{
CompilePythonSnafu, ExecutePythonSnafu, Result, ScriptNotFoundSnafu, ScriptsTableNotFoundSnafu,
};
use crate::python::{PyEngine, PyScript};
use crate::table::{build_scripts_schema, ScriptsTable, SCRIPTS_TABLE_NAME};
use crate::table::{
build_scripts_schema, get_primary_key_indices, ScriptsTable, ScriptsTableRef,
SCRIPTS_TABLE_NAME,
};
pub struct ScriptManager {
pub struct ScriptManager<E: ErrorExt + Send + Sync + 'static> {
compiled: RwLock<HashMap<String, Arc<PyScript>>>,
py_engine: PyEngine,
grpc_handler: ArcSwap<GrpcQueryHandlerRef<E>>,
// Catalog name -> `[ScriptsTable]`
tables: RwLock<HashMap<String, ScriptsTableRef<E>>>,
query_engine: QueryEngineRef,
table: ScriptsTable,
}
impl ScriptManager {
impl<E: ErrorExt + Send + Sync + 'static> ScriptManager<E> {
pub async fn new(
catalog_manager: CatalogManagerRef,
grpc_handler: GrpcQueryHandlerRef<E>,
query_engine: QueryEngineRef,
) -> Result<Self> {
Ok(Self {
compiled: RwLock::new(HashMap::default()),
py_engine: PyEngine::new(query_engine.clone()),
query_engine: query_engine.clone(),
table: ScriptsTable::new(catalog_manager, query_engine).await?,
query_engine,
grpc_handler: ArcSwap::new(Arc::new(grpc_handler)),
tables: RwLock::new(HashMap::default()),
})
}
pub fn create_table_request(&self) -> RegisterSystemTableRequest {
/// Stores `grpc_handler` as the manager's current gRPC handler, replacing
/// whichever handler was held before. Always returns `Ok(())`.
pub fn start(&self, grpc_handler: GrpcQueryHandlerRef<E>) -> Result<()> {
    let replacement = Arc::new(grpc_handler);
    self.grpc_handler.store(replacement);
    Ok(())
}
pub fn create_table_request(&self, catalog: &str) -> RegisterSystemTableRequest {
let request = CreateTableRequest {
id: SCRIPTS_TABLE_ID,
catalog_name: DEFAULT_CATALOG_NAME.to_string(),
catalog_name: catalog.to_string(),
// TODO(dennis): put the scripts table into `system` schema?
// We always put the scripts table into `public` schema right now.
schema_name: DEFAULT_SCHEMA_NAME.to_string(),
table_name: SCRIPTS_TABLE_NAME.to_string(),
desc: Some("GreptimeDB scripts table for Python".to_string()),
schema: build_scripts_schema(),
region_numbers: vec![0],
// 'schema' and 'name' are primary keys
primary_key_indices: vec![0, 1],
primary_key_indices: get_primary_key_indices(),
create_if_not_exists: true,
table_options: TableOptions::default(),
engine: default_engine().to_string(),
@@ -73,7 +89,7 @@ impl ScriptManager {
let hook: OpenSystemTableHook = Box::new(move |table: TableRef| {
let query_engine = query_engine.clone();
async move { ScriptsTable::recompile_register_udf(table, query_engine.clone()).await }
async move { ScriptsTable::<E>::recompile_register_udf(table, query_engine.clone()).await }
.boxed()
});
@@ -112,19 +128,48 @@ impl ScriptManager {
.context(CompilePythonSnafu { name })
}
/// Looks up the scripts table registered for `catalog`.
///
/// Returns a clone of the registered reference, or `None` when no table
/// has been inserted for that catalog.
///
/// # Panics
/// Panics if the `tables` lock is poisoned.
pub fn get_scripts_table(&self, catalog: &str) -> Option<ScriptsTableRef<E>> {
    let registered = self.tables.read().unwrap();
    registered.get(catalog).cloned()
}
/// Registers `table` as the scripts table of `catalog`.
///
/// If a scripts table is already registered for the catalog, the existing
/// entry is kept and `table` is dropped. Otherwise a new `ScriptsTable` is
/// built from the table, the currently loaded gRPC handler and the query
/// engine.
///
/// # Panics
/// Panics if the `tables` lock is poisoned.
pub fn insert_scripts_table(&self, catalog: &str, table: TableRef) {
    let mut registry = self.tables.write().unwrap();

    if !registry.contains_key(catalog) {
        let handler = self.grpc_handler.load().as_ref().clone();
        let scripts_table = ScriptsTable::new(table, handler, self.query_engine.clone());
        registry.insert(catalog.to_string(), Arc::new(scripts_table));
    }
}
pub async fn insert_and_compile(
&self,
catalog: &str,
schema: &str,
name: &str,
script: &str,
) -> Result<Arc<PyScript>> {
let compiled_script = self.compile(name, script).await?;
self.table.insert(schema, name, script).await?;
self.get_scripts_table(catalog)
.context(ScriptsTableNotFoundSnafu)?
.insert(schema, name, script)
.await?;
Ok(compiled_script)
}
pub async fn execute(
&self,
catalog: &str,
schema: &str,
name: &str,
params: HashMap<String, String>,
@@ -135,7 +180,8 @@ impl ScriptManager {
if s.is_some() {
s
} else {
self.try_find_script_and_compile(schema, name).await?
self.try_find_script_and_compile(catalog, schema, name)
.await?
}
};
@@ -149,10 +195,15 @@ impl ScriptManager {
async fn try_find_script_and_compile(
&self,
catalog: &str,
schema: &str,
name: &str,
) -> Result<Option<Arc<PyScript>>> {
let script = self.table.find_script_by_name(schema, name).await?;
let script = self
.get_scripts_table(catalog)
.context(ScriptsTableNotFoundSnafu)?
.find_script_by_name(schema, name)
.await?;
Ok(Some(self.compile(name, &script).await?))
}
@@ -160,50 +211,67 @@ impl ScriptManager {
#[cfg(test)]
mod tests {
use catalog::memory::MemoryCatalogManager;
use query::QueryEngineFactory;
use super::*;
use crate::test::setup_scripts_manager;
#[ignore = "script engine is temporary disabled"]
#[tokio::test]
async fn test_insert_find_compile_script() {
common_telemetry::init_default_ut_logging();
let catalog_manager = MemoryCatalogManager::new();
let factory = QueryEngineFactory::new(catalog_manager.clone(), None, false);
let query_engine = factory.query_engine();
let mgr = ScriptManager::new(catalog_manager.clone(), query_engine)
.await
.unwrap();
let catalog = "greptime";
let schema = "schema";
let name = "test";
mgr.table
.insert(
schema,
name,
r#"
@copr(sql='select number from numbers limit 10', args=['number'], returns=['n'])
def test(n):
return n + 1;
"#,
)
.await
.unwrap();
let script = r#"
@copr(returns=['n'])
def test() -> vector[str]:
return 'hello';
"#;
let mgr = setup_scripts_manager(catalog, schema, name, script).await;
{
let cached = mgr.compiled.read().unwrap();
assert!(cached.get(name).is_none());
}
mgr.insert_and_compile(catalog, schema, name, script)
.await
.unwrap();
{
let cached = mgr.compiled.read().unwrap();
assert!(cached.get(name).is_some());
}
// try to find and compile
let script = mgr.try_find_script_and_compile(schema, name).await.unwrap();
let script = mgr
.try_find_script_and_compile(catalog, schema, name)
.await
.unwrap();
let _ = script.unwrap();
{
let cached = mgr.compiled.read().unwrap();
let _ = cached.get(name).unwrap();
}
// execute script
let output = mgr
.execute(catalog, schema, name, HashMap::new())
.await
.unwrap();
match output {
Output::RecordBatches(batches) => {
let expected = "\
+-------+
| n |
+-------+
| hello |
+-------+";
assert_eq!(expected, batches.pretty_print().unwrap());
}
_ => unreachable!(),
}
}
}

View File

@@ -75,6 +75,7 @@ impl PyUDF {
fn register_as_udf(zelf: Arc<Self>) {
FUNCTION_REGISTRY.register(zelf)
}
fn register_to_query_engine(zelf: Arc<Self>, engine: QueryEngineRef) {
engine.register_function(zelf)
}
@@ -138,10 +139,12 @@ impl Function for PyUDF {
}
}
}
// The volatility must be `Volatile`: evaluating the script may return a different value each time.
if know_all_types {
Signature::variadic(arg_types, Volatility::Immutable)
Signature::variadic(arg_types, Volatility::Volatile)
} else {
Signature::any(self.copr.arg_types.len(), Volatility::Immutable)
Signature::any(self.copr.arg_types.len(), Volatility::Volatile)
}
}

View File

@@ -13,49 +13,68 @@
// limitations under the License.
//! Scripts table
use std::collections::HashMap;
use std::sync::Arc;
use api::helper::ColumnDataTypeWrapper;
use api::v1::greptime_request::Request;
use api::v1::value::ValueData;
use api::v1::{
ColumnDataType, ColumnSchema as PbColumnSchema, Row, RowInsertRequest, RowInsertRequests, Rows,
SemanticType,
};
use catalog::error::CompileScriptInternalSnafu;
use catalog::CatalogManagerRef;
use common_catalog::consts::{DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME};
use common_catalog::format_full_table_name;
use common_error::ext::BoxedError;
use common_error::ext::{BoxedError, ErrorExt};
use common_query::Output;
use common_recordbatch::{util as record_util, RecordBatch, SendableRecordBatchStream};
use common_telemetry::logging;
use common_time::util;
use datafusion::datasource::DefaultTableSource;
use datafusion::logical_expr::{and, col, lit};
use datafusion_common::TableReference;
use datafusion_expr::LogicalPlanBuilder;
use datatypes::prelude::{ConcreteDataType, ScalarVector};
use datatypes::prelude::ScalarVector;
use datatypes::schema::{ColumnSchema, RawSchema};
use datatypes::vectors::{StringVector, TimestampMillisecondVector, Vector, VectorRef};
use query::parser::QueryLanguageParser;
use datatypes::vectors::{StringVector, Vector};
use query::plan::LogicalPlan;
use query::QueryEngineRef;
use session::context::QueryContextBuilder;
use servers::query_handler::grpc::GrpcQueryHandlerRef;
use session::context::{QueryContextBuilder, QueryContextRef};
use snafu::{ensure, OptionExt, ResultExt};
use table::requests::InsertRequest;
use table::metadata::TableInfo;
use table::table::adapter::DfTableProviderAdapter;
use table::TableRef;
use crate::error::{
BuildDfLogicalPlanSnafu, CastTypeSnafu, CollectRecordsSnafu, ExecuteInternalStatementSnafu,
FindColumnInScriptsTableSnafu, FindScriptSnafu, FindScriptsTableSnafu, InsertScriptSnafu,
Result, ScriptNotFoundSnafu, ScriptsTableNotFoundSnafu,
FindColumnInScriptsTableSnafu, InsertScriptSnafu, Result, ScriptNotFoundSnafu,
};
use crate::python::PyScript;
pub const SCRIPTS_TABLE_NAME: &str = "scripts";
pub struct ScriptsTable {
catalog_manager: CatalogManagerRef,
pub type ScriptsTableRef<E> = Arc<ScriptsTable<E>>;
/// The scripts table that keeps the script content etc.
pub struct ScriptsTable<E: ErrorExt + Send + Sync + 'static> {
table: TableRef,
grpc_handler: GrpcQueryHandlerRef<E>,
query_engine: QueryEngineRef,
name: String,
}
impl ScriptsTable {
impl<E: ErrorExt + Send + Sync + 'static> ScriptsTable<E> {
/// Create a new `[ScriptsTable]` based on the table.
pub fn new(
table: TableRef,
grpc_handler: GrpcQueryHandlerRef<E>,
query_engine: QueryEngineRef,
) -> Self {
Self {
table,
grpc_handler,
query_engine,
}
}
fn get_str_col_by_name<'a>(record: &'a RecordBatch, name: &str) -> Result<&'a StringVector> {
let column = record
.column_by_name(name)
@@ -79,6 +98,8 @@ impl ScriptsTable {
table: TableRef,
query_engine: QueryEngineRef,
) -> catalog::error::Result<()> {
let table_info = table.table_info();
let rbs = Self::table_full_scan(table, &query_engine)
.await
.map_err(BoxedError::new)
@@ -108,6 +129,12 @@ impl ScriptsTable {
script_list.extend(part_of_scripts_list);
}
logging::info!(
"Found {} scripts in {}",
script_list.len(),
table_info.full_table_name()
);
for (name, script) in script_list {
match PyScript::from_script(&script, query_engine.clone()) {
Ok(script) => {
@@ -129,123 +156,86 @@ impl ScriptsTable {
Ok(())
}
pub fn new_empty(
catalog_manager: CatalogManagerRef,
query_engine: QueryEngineRef,
) -> Result<Self> {
Ok(Self {
catalog_manager,
query_engine,
name: format_full_table_name(
DEFAULT_CATALOG_NAME,
DEFAULT_SCHEMA_NAME,
SCRIPTS_TABLE_NAME,
),
})
}
pub async fn new(
catalog_manager: CatalogManagerRef,
query_engine: QueryEngineRef,
) -> Result<Self> {
Ok(Self {
catalog_manager,
query_engine,
name: format_full_table_name(
DEFAULT_CATALOG_NAME,
DEFAULT_SCHEMA_NAME,
SCRIPTS_TABLE_NAME,
),
})
}
pub async fn insert(&self, schema: &str, name: &str, script: &str) -> Result<()> {
let now = util::current_time_millis();
let columns_values: HashMap<String, VectorRef> = HashMap::from([
(
"schema".to_string(),
Arc::new(StringVector::from(vec![schema])) as VectorRef,
),
("name".to_string(), Arc::new(StringVector::from(vec![name]))),
(
"script".to_string(),
Arc::new(StringVector::from(vec![script])) as VectorRef,
),
(
"engine".to_string(),
// TODO(dennis): we only supports python right now.
Arc::new(StringVector::from(vec!["python"])) as VectorRef,
),
(
"timestamp".to_string(),
// Timestamp in key part is intentionally left to 0
Arc::new(TimestampMillisecondVector::from_slice([0])) as VectorRef,
),
(
"gmt_created".to_string(),
Arc::new(TimestampMillisecondVector::from_slice([now])) as VectorRef,
),
(
"gmt_modified".to_string(),
Arc::new(TimestampMillisecondVector::from_slice([now])) as VectorRef,
),
]);
let table = self
.catalog_manager
.table(
DEFAULT_CATALOG_NAME,
DEFAULT_SCHEMA_NAME,
SCRIPTS_TABLE_NAME,
)
.await
.context(FindScriptsTableSnafu)?
.context(ScriptsTableNotFoundSnafu)?;
let _ = table
.insert(InsertRequest {
catalog_name: DEFAULT_CATALOG_NAME.to_string(),
schema_name: DEFAULT_SCHEMA_NAME.to_string(),
table_name: SCRIPTS_TABLE_NAME.to_string(),
columns_values,
region_number: 0,
})
let table_info = self.table.table_info();
let insert = RowInsertRequest {
table_name: SCRIPTS_TABLE_NAME.to_string(),
rows: Some(Rows {
schema: build_insert_column_schemas(),
rows: vec![Row {
values: vec![
ValueData::StringValue(schema.to_string()).into(),
ValueData::StringValue(name.to_string()).into(),
// TODO(dennis): we only support python right now.
ValueData::StringValue("python".to_string()).into(),
ValueData::StringValue(script.to_string()).into(),
// Timestamp in key part is intentionally left to 0
ValueData::TimestampMillisecondValue(0).into(),
ValueData::TimestampMillisecondValue(now).into(),
],
}],
}),
};
let requests = RowInsertRequests {
inserts: vec![insert],
};
let output = self
.grpc_handler
.do_query(Request::RowInserts(requests), query_ctx(&table_info))
.await
.map_err(BoxedError::new)
.context(InsertScriptSnafu { name })?;
logging::info!("Inserted script: name={} into scripts table.", name);
logging::info!(
"Inserted script: {} into scripts table: {}, output: {:?}.",
name,
table_info.full_table_name(),
output
);
Ok(())
}
pub async fn find_script_by_name(&self, schema: &str, name: &str) -> Result<String> {
// FIXME(dennis): SQL injection
// TODO(dennis): we use sql to find the script, the better way is use a function
// such as `find_record_by_primary_key` in table_engine.
let sql = format!(
"select script from {} where schema='{}' and name='{}'",
self.name(),
schema,
name
let table_info = self.table.table_info();
let table_name = TableReference::full(
table_info.catalog_name.clone(),
table_info.schema_name.clone(),
table_info.name.clone(),
);
let stmt = QueryLanguageParser::parse_sql(&sql).unwrap();
let ctx = QueryContextBuilder::default().build();
let plan = self
.query_engine
.planner()
.plan(stmt, ctx.clone())
.await
.unwrap();
let table_provider = Arc::new(DfTableProviderAdapter::new(self.table.clone()));
let table_source = Arc::new(DefaultTableSource::new(table_provider));
let stream = match self
let plan = LogicalPlanBuilder::scan(table_name, table_source, None)
.context(BuildDfLogicalPlanSnafu)?
.filter(and(
col("schema").eq(lit(schema)),
col("name").eq(lit(name)),
))
.context(BuildDfLogicalPlanSnafu)?
.project(vec![col("script")])
.context(BuildDfLogicalPlanSnafu)?
.build()
.context(BuildDfLogicalPlanSnafu)?;
let output = self
.query_engine
.execute(plan, ctx)
.execute(LogicalPlan::DfPlan(plan), query_ctx(&table_info))
.await
.context(FindScriptSnafu { name })?
{
.context(ExecuteInternalStatementSnafu)?;
let stream = match output {
Output::Stream(stream) => stream,
Output::RecordBatches(record_batches) => record_batches.as_stream(),
_ => unreachable!(),
};
let records = record_util::collect(stream)
.await
.context(CollectRecordsSnafu)?;
@@ -267,12 +257,9 @@ impl ScriptsTable {
})?;
assert_eq!(script_column.len(), 1);
Ok(script_column.get_data(0).unwrap().to_string())
}
#[inline]
pub fn name(&self) -> &str {
&self.name
// Safety: asserted above
Ok(script_column.get_data(0).unwrap().to_string())
}
async fn table_full_scan(
@@ -295,10 +282,7 @@ impl ScriptsTable {
.context(BuildDfLogicalPlanSnafu)?;
let output = query_engine
.execute(
LogicalPlan::DfPlan(plan),
QueryContextBuilder::default().build(),
)
.execute(LogicalPlan::DfPlan(plan), query_ctx(&table_info))
.await
.context(ExecuteInternalStatementSnafu)?;
let stream = match output {
@@ -310,46 +294,80 @@ impl ScriptsTable {
}
}
/// Build the inserted column schemas
fn build_insert_column_schemas() -> Vec<PbColumnSchema> {
vec![
// The schema that script belongs to.
PbColumnSchema {
column_name: "schema".to_string(),
datatype: ColumnDataType::String.into(),
semantic_type: SemanticType::Tag.into(),
},
PbColumnSchema {
column_name: "name".to_string(),
datatype: ColumnDataType::String.into(),
semantic_type: SemanticType::Tag.into(),
},
PbColumnSchema {
column_name: "engine".to_string(),
datatype: ColumnDataType::String.into(),
semantic_type: SemanticType::Tag.into(),
},
PbColumnSchema {
column_name: "script".to_string(),
datatype: ColumnDataType::String.into(),
semantic_type: SemanticType::Field.into(),
},
PbColumnSchema {
column_name: "greptime_timestamp".to_string(),
datatype: ColumnDataType::TimestampMillisecond.into(),
semantic_type: SemanticType::Timestamp.into(),
},
PbColumnSchema {
column_name: "gmt_modified".to_string(),
datatype: ColumnDataType::TimestampMillisecond.into(),
semantic_type: SemanticType::Field.into(),
},
]
}
/// Builds a query context whose current catalog and schema are the ones
/// the given table belongs to.
fn query_ctx(table_info: &TableInfo) -> QueryContextRef {
    let catalog = table_info.catalog_name.to_string();
    let schema = table_info.schema_name.to_string();

    QueryContextBuilder::default()
        .current_catalog(catalog)
        .current_schema(schema)
        .build()
}
/// Returns the positions of the primary key columns of the scripts table,
/// i.e. the indices of every `Tag` column produced by
/// `build_insert_column_schemas`, in ascending order.
pub fn get_primary_key_indices() -> Vec<usize> {
    build_insert_column_schemas()
        .into_iter()
        .enumerate()
        .filter(|(_, column)| column.semantic_type == (SemanticType::Tag as i32))
        .map(|(position, _)| position)
        .collect()
}
/// Build scripts table
pub fn build_scripts_schema() -> RawSchema {
let cols = vec![
ColumnSchema::new(
"schema".to_string(),
ConcreteDataType::string_datatype(),
false,
),
ColumnSchema::new(
"name".to_string(),
ConcreteDataType::string_datatype(),
false,
),
ColumnSchema::new(
"script".to_string(),
ConcreteDataType::string_datatype(),
false,
),
ColumnSchema::new(
"engine".to_string(),
ConcreteDataType::string_datatype(),
false,
),
ColumnSchema::new(
"timestamp".to_string(),
ConcreteDataType::timestamp_millisecond_datatype(),
false,
)
.with_time_index(true),
ColumnSchema::new(
"gmt_created".to_string(),
ConcreteDataType::timestamp_millisecond_datatype(),
false,
),
ColumnSchema::new(
"gmt_modified".to_string(),
ConcreteDataType::timestamp_millisecond_datatype(),
false,
),
];
let cols = build_insert_column_schemas()
.into_iter()
.map(|c| {
let cs = ColumnSchema::new(
c.column_name,
// Safety: the type always exists
ColumnDataTypeWrapper::try_new(c.datatype).unwrap().into(),
false,
);
if c.semantic_type == SemanticType::Timestamp as i32 {
cs.with_time_index(true)
} else {
cs
}
})
.collect();
RawSchema::new(cols)
}

78
src/script/src/test.rs Normal file
View File

@@ -0,0 +1,78 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use api::v1::greptime_request::Request;
use async_trait::async_trait;
use catalog::memory::MemoryCatalogManager;
use common_query::Output;
use common_recordbatch::RecordBatch;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::{StringVector, VectorRef};
use query::QueryEngineFactory;
use servers::query_handler::grpc::GrpcQueryHandler;
use session::context::QueryContextRef;
use table::test_util::MemTable;
use crate::error::{Error, Result};
use crate::manager::ScriptManager;
/// Builds a [`ScriptManager`] for tests on top of an in-memory `scripts` table.
///
/// The table contains a single row made of the given `script` source, `schema`
/// and `name`, stored in three non-nullable string columns. The table is
/// registered in a [`MemoryCatalogManager`] that backs the query engine, and is
/// additionally handed to the manager through `insert_scripts_table` under
/// `catalog`.
///
/// # Panics
///
/// Panics if the record batch or the script manager cannot be created.
pub async fn setup_scripts_manager(
    catalog: &str,
    schema: &str,
    name: &str,
    script: &str,
) -> ScriptManager<Error> {
    // Single source of truth for the column layout: the schemas and the values
    // derived below stay aligned because both come from this list, in order.
    let string_columns = [("script", script), ("schema", schema), ("name", name)];

    let column_schemas = string_columns
        .iter()
        .map(|(column, _)| ColumnSchema::new(*column, ConcreteDataType::string_datatype(), false))
        .collect::<Vec<_>>();
    let column_values = string_columns
        .iter()
        .map(|(_, value)| Arc::new(StringVector::from(vec![*value])) as VectorRef)
        .collect::<Vec<_>>();

    let batch = RecordBatch::new(Arc::new(Schema::new(column_schemas)), column_values).unwrap();
    let scripts_table = MemTable::table("scripts", batch);

    let catalog_manager = MemoryCatalogManager::new_with_table(scripts_table.clone());
    let query_engine = QueryEngineFactory::new(catalog_manager.clone(), None, false).query_engine();

    let manager = ScriptManager::new(Arc::new(MockGrpcQueryHandler {}) as _, query_engine)
        .await
        .unwrap();
    manager.insert_scripts_table(catalog, scripts_table);
    manager
}
/// Stateless stand-in for a gRPC query handler, passed to `ScriptManager::new`
/// by the test setup in this file.
struct MockGrpcQueryHandler {}
#[async_trait]
impl GrpcQueryHandler for MockGrpcQueryHandler {
type Error = Error;
/// Ignores both the request and the query context and always reports that
/// exactly one row was affected.
async fn do_query(&self, _query: Request, _ctx: QueryContextRef) -> Result<Output> {
Ok(Output::AffectedRows(1))
}
}

View File

@@ -15,9 +15,11 @@ use std::collections::HashMap;
use std::time::Instant;
use axum::extract::{Json, Query, RawBody, State};
use common_catalog::consts::DEFAULT_CATALOG_NAME;
use common_error::ext::ErrorExt;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use session::context::QueryContext;
use crate::http::{ApiState, JsonResponse};
@@ -51,6 +53,9 @@ pub async fn scripts(
RawBody(body): RawBody,
) -> Json<JsonResponse> {
if let Some(script_handler) = &state.script_handler {
let catalog = params
.catalog
.unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string());
let schema = params.db.as_ref();
if schema.is_none() || schema.unwrap().is_empty() {
@@ -67,8 +72,10 @@ pub async fn scripts(
let script = unwrap_or_json_err!(String::from_utf8(bytes.to_vec()));
// Safety: schema and name are already checked above.
let query_ctx = QueryContext::with(&catalog, schema.unwrap());
let body = match script_handler
.insert_script(schema.unwrap(), name.unwrap(), &script)
.insert_script(query_ctx, name.unwrap(), &script)
.await
{
Ok(()) => JsonResponse::with_output(None),
@@ -83,6 +90,7 @@ pub async fn scripts(
#[derive(Debug, Serialize, Deserialize, JsonSchema, Default)]
pub struct ScriptQuery {
pub catalog: Option<String>,
pub db: Option<String>,
pub name: Option<String>,
#[serde(flatten)]
@@ -96,6 +104,9 @@ pub async fn run_script(
Query(params): Query<ScriptQuery>,
) -> Json<JsonResponse> {
if let Some(script_handler) = &state.script_handler {
let catalog = params
.catalog
.unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string());
let start = Instant::now();
let schema = params.db.as_ref();
@@ -109,10 +120,10 @@ pub async fn run_script(
json_err!("invalid name");
}
// TODO(sunng87): query_context and db name resolution
// Safety: schema and name are already checked above.
let query_ctx = QueryContext::with(&catalog, schema.unwrap());
let output = script_handler
.execute_script(schema.unwrap(), name.unwrap(), params.params)
.execute_script(query_ctx, name.unwrap(), params.params)
.await;
let resp = JsonResponse::from_output(vec![output]).await;

View File

@@ -49,10 +49,15 @@ pub type ScriptHandlerRef = Arc<dyn ScriptHandler + Send + Sync>;
#[async_trait]
pub trait ScriptHandler {
async fn insert_script(&self, schema: &str, name: &str, script: &str) -> Result<()>;
async fn insert_script(
&self,
query_ctx: QueryContextRef,
name: &str,
script: &str,
) -> Result<()>;
async fn execute_script(
&self,
schema: &str,
query_ctx: QueryContextRef,
name: &str,
params: HashMap<String, String>,
) -> Result<Output>;

View File

@@ -116,7 +116,15 @@ impl SqlQueryHandler for DummyInstance {
#[async_trait]
impl ScriptHandler for DummyInstance {
async fn insert_script(&self, schema: &str, name: &str, script: &str) -> Result<()> {
async fn insert_script(
&self,
query_ctx: QueryContextRef,
name: &str,
script: &str,
) -> Result<()> {
let catalog = query_ctx.current_catalog();
let schema = query_ctx.current_schema();
let script = self
.py_engine
.compile(script, CompileContext::default())
@@ -127,18 +135,20 @@ impl ScriptHandler for DummyInstance {
.scripts
.write()
.unwrap()
.insert(format!("{schema}_{name}"), Arc::new(script));
.insert(format!("{catalog}_{schema}_{name}"), Arc::new(script));
Ok(())
}
async fn execute_script(
&self,
schema: &str,
query_ctx: QueryContextRef,
name: &str,
params: HashMap<String, String>,
) -> Result<Output> {
let key = format!("{schema}_{name}");
let catalog = query_ctx.current_catalog();
let schema = query_ctx.current_schema();
let key = format!("{catalog}_{schema}_{name}");
let py_script = self.scripts.read().unwrap().get(&key).unwrap().clone();

View File

@@ -12,30 +12,79 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::sync::Arc;
use common_query::Output;
use common_recordbatch::RecordBatch;
use datatypes::prelude::ConcreteDataType;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::{StringVector, VectorRef};
use servers::error::Result;
use servers::query_handler::sql::SqlQueryHandler;
use servers::query_handler::ScriptHandler;
use session::context::QueryContext;
use session::context::QueryContextBuilder;
use table::test_util::MemTable;
use crate::create_testing_instance;
#[tokio::test]
async fn test_insert_py_udf_and_query() -> Result<()> {
let query_ctx = QueryContext::arc();
let table = MemTable::default_numbers_table();
let catalog = "greptime";
let schema = "test";
let name = "hello";
let script = r#"
@copr(returns=['n'])
def hello() -> vector[str]:
return 'hello';
"#;
let column_schemas = vec![
ColumnSchema::new("script", ConcreteDataType::string_datatype(), false),
ColumnSchema::new("schema", ConcreteDataType::string_datatype(), false),
ColumnSchema::new("name", ConcreteDataType::string_datatype(), false),
];
let columns: Vec<VectorRef> = vec![
Arc::new(StringVector::from(vec![script])),
Arc::new(StringVector::from(vec![schema])),
Arc::new(StringVector::from(vec![name])),
];
let raw_schema = Arc::new(Schema::new(column_schemas));
let recordbatch = RecordBatch::new(raw_schema, columns).unwrap();
let table = MemTable::table("scripts", recordbatch);
let query_ctx = QueryContextBuilder::default()
.current_catalog(catalog.to_string())
.current_schema(schema.to_string())
.build();
let instance = create_testing_instance(table);
let src = r#"
@coprocessor(args=["uint32s"], returns = ["ret"])
def double_that(col) -> vector[u32]:
return col*2
"#;
instance
.insert_script("schema_test", "double_that", src)
.insert_script(query_ctx.clone(), name, script)
.await?;
let output = instance
.execute_script(query_ctx.clone(), name, HashMap::new())
.await?;
match output {
Output::RecordBatches(batches) => {
let expected = "\
+-------+
| n |
+-------+
| hello |
+-------+";
assert_eq!(expected, batches.pretty_print().unwrap());
}
_ => unreachable!(),
}
let res = instance
.do_query("select double_that(uint32s) from numbers", query_ctx)
.do_query("select hello()", query_ctx)
.await
.remove(0)
.unwrap();
@@ -46,12 +95,16 @@ def double_that(col) -> vector[u32]:
}
common_query::Output::Stream(s) => {
let batches = common_recordbatch::util::collect_batches(s).await.unwrap();
assert_eq!(batches.iter().count(), 1);
let first = batches.iter().next().unwrap();
let col = first.column(0);
let val = col.get(1);
assert_eq!(val, datatypes::value::Value::UInt32(2));
let expected = "\
+---------+
| hello() |
+---------+
| hello |
+---------+";
assert_eq!(expected, batches.pretty_print().unwrap());
}
}
Ok(())
}

View File

@@ -479,6 +479,10 @@ impl TableInfo {
.map(|id| RegionId::new(self.table_id(), *id))
.collect()
}
/// Returns the full table name in the form of `{catalog}.{schema}.{table}`.
pub fn full_table_name(&self) -> String {
common_catalog::format_full_table_name(&self.catalog_name, &self.schema_name, &self.name)
}
}
impl TableInfoBuilder {