mirror of https://github.com/GreptimeTeam/greptimedb.git
synced 2026-01-07 05:42:57 +00:00
feat!: Remove script crate and python feature (#5321)
* feat: exclude script crate
* chore: simplify feature
* feat: remove the script crate
* chore: remove python feature and some comments
* chore: fix warning
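This is a breaking change (`feat!`) for anyone building with the `python` feature enabled: the feature and the `script` crate are removed outright. The diff below deletes the same shape of code over and over: a real implementation compiled only when the feature is on, and a stub compiled otherwise. A minimal, self-contained sketch of that pattern (illustrative names, not the exact GreptimeDB API):

```rust
// Sketch of the feature-gated pattern this commit deletes (illustrative).
// With `--features python` the real path compiles; without it, a stub
// returns a "not supported" error, matching the dummy module in the diff.

#[cfg(feature = "python")]
fn execute_script(name: &str) -> Result<String, String> {
    // The real implementation lived in the removed `script` crate.
    Ok(format!("executed script {name}"))
}

#[cfg(not(feature = "python"))]
fn execute_script(name: &str) -> Result<String, String> {
    // Stub path compiled when the feature was off.
    Err(format!("script support not compiled in: {name}"))
}

fn main() {
    match execute_script("hello") {
        Ok(out) => println!("{out}"),
        Err(err) => eprintln!("{err}"),
    }
}
```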
@@ -9,8 +9,8 @@ runs:
  steps:
    # Download artifacts from previous jobs, the artifacts will be downloaded to:
    # ${WORKING_DIR}
    # |- greptime-darwin-amd64-pyo3-v0.5.0/greptime-darwin-amd64-pyo3-v0.5.0.tar.gz
    # |- greptime-darwin-amd64-pyo3-v0.5.0.sha256sum/greptime-darwin-amd64-pyo3-v0.5.0.sha256sum
    # |- greptime-darwin-amd64-v0.5.0/greptime-darwin-amd64-v0.5.0.tar.gz
    # |- greptime-darwin-amd64-v0.5.0.sha256sum/greptime-darwin-amd64-v0.5.0.sha256sum
    # |- greptime-darwin-amd64-v0.5.0/greptime-darwin-amd64-v0.5.0.tar.gz
    # |- greptime-darwin-amd64-v0.5.0.sha256sum/greptime-darwin-amd64-v0.5.0.sha256sum
    # ...
6 .github/actions/upload-artifacts/action.yml vendored
@@ -30,9 +30,9 @@ runs:
        done

        # The compressed artifacts will use the following layout:
        # greptime-linux-amd64-pyo3-v0.3.0.sha256sum
        # greptime-linux-amd64-pyo3-v0.3.0.tar.gz
        # greptime-linux-amd64-pyo3-v0.3.0
        # greptime-linux-amd64-v0.3.0.sha256sum
        # greptime-linux-amd64-v0.3.0.tar.gz
        # greptime-linux-amd64-v0.3.0
        # └── greptime
    - name: Compress artifacts and calculate checksum
      working-directory: ${{ inputs.working-dir }}
8 .github/scripts/upload-artifacts-to-s3.sh vendored
@@ -27,11 +27,11 @@ function upload_artifacts() {
    # ├── latest-version.txt
    # ├── latest-nightly-version.txt
    # ├── v0.1.0
    # │   ├── greptime-darwin-amd64-pyo3-v0.1.0.sha256sum
    # │   └── greptime-darwin-amd64-pyo3-v0.1.0.tar.gz
    # │   ├── greptime-darwin-amd64-v0.1.0.sha256sum
    # │   └── greptime-darwin-amd64-v0.1.0.tar.gz
    # └── v0.2.0
    #     ├── greptime-darwin-amd64-pyo3-v0.2.0.sha256sum
    #     └── greptime-darwin-amd64-pyo3-v0.2.0.tar.gz
    #     ├── greptime-darwin-amd64-v0.2.0.sha256sum
    #     └── greptime-darwin-amd64-v0.2.0.tar.gz
    find "$ARTIFACTS_DIR" -type f \( -name "*.tar.gz" -o -name "*.sha256sum" \) | while IFS= read -r file; do
        aws s3 cp \
            "$file" "s3://$AWS_S3_BUCKET/$RELEASE_DIRS/$VERSION/$(basename "$file")"
1219 Cargo.lock generated
File diff suppressed because it is too large
@@ -55,7 +55,6 @@ members = [
"src/promql",
"src/puffin",
"src/query",
"src/script",
"src/servers",
"src/session",
"src/sql",
@@ -79,8 +78,6 @@ clippy.dbg_macro = "warn"
clippy.implicit_clone = "warn"
clippy.readonly_write_lock = "allow"
rust.unknown_lints = "deny"
# Remove this after https://github.com/PyO3/pyo3/issues/4094
rust.non_local_definitions = "allow"
rust.unexpected_cfgs = { level = "warn", check-cfg = ['cfg(tokio_unstable)'] }

[workspace.dependencies]
@@ -258,7 +255,6 @@ plugins = { path = "src/plugins" }
promql = { path = "src/promql" }
puffin = { path = "src/puffin" }
query = { path = "src/query" }
script = { path = "src/script" }
servers = { path = "src/servers" }
session = { path = "src/session" }
sql = { path = "src/sql" }
@@ -139,7 +139,7 @@ Check the prerequisite:
* [Rust toolchain](https://www.rust-lang.org/tools/install) (nightly)
* [Protobuf compiler](https://grpc.io/docs/protoc-installation/) (>= 3.15)
* C/C++ building essentials, including `gcc`/`g++`/`autoconf` and the glibc library (e.g. `libc6-dev` on Ubuntu and `glibc-devel` on Fedora)
* Python toolchain (optional): Required only if built with PyO3 backend. More details for compiling with PyO3 can be found in its [documentation](https://pyo3.rs/v0.18.1/building_and_distribution#configuring-the-python-version).
* Python toolchain (optional): Required only if using some test scripts.

Build GreptimeDB binary:
@@ -122,13 +122,6 @@ pub enum Error {
        source: BoxedError,
    },

    #[snafu(display("Failed to re-compile script due to internal error"))]
    CompileScriptInternal {
        #[snafu(implicit)]
        location: Location,
        source: BoxedError,
    },

    #[snafu(display("Failed to create table, table info: {}", table_info))]
    CreateTable {
        table_info: String,
@@ -343,9 +336,7 @@ impl ErrorExt for Error {
            Error::DecodePlan { source, .. } => source.status_code(),
            Error::InvalidTableInfoInCatalog { source, .. } => source.status_code(),

            Error::CompileScriptInternal { source, .. } | Error::Internal { source, .. } => {
                source.status_code()
            }
            Error::Internal { source, .. } => source.status_code(),

            Error::QueryAccessDenied { .. } => StatusCode::AccessDenied,
            Error::Datafusion { error, .. } => datafusion_status_code::<Self>(error, None),

@@ -10,9 +10,8 @@ name = "greptime"
path = "src/bin/greptime.rs"

[features]
default = ["python", "servers/pprof", "servers/mem-prof"]
default = ["servers/pprof", "servers/mem-prof"]
tokio-console = ["common-telemetry/tokio-console"]
python = ["frontend/python"]

[lints]
workspace = true
@@ -18,7 +18,6 @@ use arrow::error::ArrowError;
use common_error::ext::{BoxedError, ErrorExt};
use common_error::status_code::StatusCode;
use common_macro::stack_trace_debug;
use common_recordbatch::error::Error as RecordbatchError;
use datafusion_common::DataFusionError;
use datatypes::arrow;
use datatypes::arrow::datatypes::DataType as ArrowDatatype;
@@ -31,21 +30,6 @@ use statrs::StatsError;
#[snafu(visibility(pub))]
#[stack_trace_debug]
pub enum Error {
    #[snafu(display("Failed to execute Python UDF: {}", msg))]
    PyUdf {
        // TODO(discord9): find a way that prevent circle depend(query<-script<-query) and can use script's error type
        msg: String,
        #[snafu(implicit)]
        location: Location,
    },

    #[snafu(display("Failed to create temporary recordbatch when eval Python UDF"))]
    UdfTempRecordBatch {
        #[snafu(implicit)]
        location: Location,
        source: RecordbatchError,
    },

    #[snafu(display("Failed to execute function"))]
    ExecuteFunction {
        #[snafu(source)]
@@ -260,9 +244,7 @@ pub type Result<T> = std::result::Result<T, Error>;
impl ErrorExt for Error {
    fn status_code(&self) -> StatusCode {
        match self {
            Error::UdfTempRecordBatch { .. }
            | Error::PyUdf { .. }
            | Error::CreateAccumulator { .. }
            Error::CreateAccumulator { .. }
            | Error::DowncastVector { .. }
            | Error::InvalidInputState { .. }
            | Error::InvalidInputCol { .. }
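Both error-enum hunks above follow the same recipe: drop the Python-specific variants, then shrink the multi-pattern arm in `status_code` that mapped them to a status. A minimal sketch of that delegation pattern (simplified stand-ins, not the real `ErrorExt`/`StatusCode` types):

```rust
// Simplified sketch of the status-code delegation seen in the hunks above.
#[derive(Debug, Clone, Copy)]
enum StatusCode {
    EngineExecuteQuery,
    Internal,
}

#[derive(Debug)]
enum Error {
    CreateAccumulator,
    DowncastVector,
    ExecuteFunction { source_code: StatusCode },
}

impl Error {
    fn status_code(&self) -> StatusCode {
        match self {
            // Variants with a fixed code share one multi-pattern arm; deleting
            // PyUdf/UdfTempRecordBatch in the diff just shortens this arm.
            Error::CreateAccumulator | Error::DowncastVector => StatusCode::EngineExecuteQuery,
            // Wrapper variants delegate to their source's code instead.
            Error::ExecuteFunction { source_code } => *source_code,
        }
    }
}

fn main() {
    println!("{:?}", Error::CreateAccumulator.status_code());
    let wrapped = Error::ExecuteFunction { source_code: StatusCode::Internal };
    println!("{:?}", wrapped.status_code());
}
```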
@@ -5,8 +5,6 @@ edition.workspace = true
license.workspace = true

[features]
default = ["python"]
python = ["dep:script"]
testing = []

[lints]
@@ -52,7 +50,6 @@ prometheus.workspace = true
prost.workspace = true
query.workspace = true
raft-engine.workspace = true
script = { workspace = true, features = ["python"], optional = true }
serde.workspace = true
servers.workspace = true
session.workspace = true

@@ -238,14 +238,6 @@ pub enum Error {
        location: Location,
    },

    #[cfg(feature = "python")]
    #[snafu(display("Failed to start script manager"))]
    StartScriptManager {
        #[snafu(implicit)]
        location: Location,
        source: script::error::Error,
    },

    #[snafu(display("Failed to insert value into table: {}", table_name))]
    Insert {
        table_name: String,
@@ -394,9 +386,6 @@ impl ErrorExt for Error {
            }
            Error::FindTableRoute { source, .. } => source.status_code(),

            #[cfg(feature = "python")]
            Error::StartScriptManager { source, .. } => source.status_code(),

            Error::TableOperation { source, .. } => source.status_code(),

            Error::InFlightWriteBytesExceeded { .. } => StatusCode::RateLimited,
@@ -21,7 +21,6 @@ mod opentsdb;
mod otlp;
mod prom_store;
mod region_query;
mod script;
pub mod standalone;

use std::sync::Arc;
@@ -66,7 +65,7 @@ use servers::query_handler::grpc::GrpcQueryHandler;
use servers::query_handler::sql::SqlQueryHandler;
use servers::query_handler::{
    InfluxdbLineProtocolHandler, LogQueryHandler, OpenTelemetryProtocolHandler,
    OpentsdbProtocolHandler, PipelineHandler, PromStoreProtocolHandler, ScriptHandler,
    OpentsdbProtocolHandler, PipelineHandler, PromStoreProtocolHandler,
};
use servers::server::ServerHandlers;
use session::context::QueryContextRef;
@@ -88,7 +87,6 @@ use crate::error::{
use crate::frontend::FrontendOptions;
use crate::heartbeat::HeartbeatTask;
use crate::limiter::LimiterRef;
use crate::script::ScriptExecutor;

#[async_trait]
pub trait FrontendInstance:
@@ -98,7 +96,6 @@ pub trait FrontendInstance:
    + InfluxdbLineProtocolHandler
    + PromStoreProtocolHandler
    + OpenTelemetryProtocolHandler
    + ScriptHandler
    + PrometheusHandler
    + PipelineHandler
    + LogQueryHandler
@@ -115,7 +112,6 @@ pub type FrontendInstanceRef = Arc<dyn FrontendInstance>;
pub struct Instance {
    options: FrontendOptions,
    catalog_manager: CatalogManagerRef,
    script_executor: Arc<ScriptExecutor>,
    pipeline_operator: Arc<PipelineOperator>,
    statement_executor: Arc<StatementExecutor>,
    query_engine: QueryEngineRef,
@@ -205,8 +201,6 @@ impl FrontendInstance for Instance {
            heartbeat_task.start().await?;
        }

        self.script_executor.start(self)?;

        if let Some(t) = self.export_metrics_task.as_ref() {
            if t.send_by_handler {
                let handler = ExportMetricHandler::new_handler(
@@ -44,7 +44,6 @@ use crate::heartbeat::HeartbeatTask;
use crate::instance::region_query::FrontendRegionQueryHandler;
use crate::instance::Instance;
use crate::limiter::Limiter;
use crate::script::ScriptExecutor;

/// The frontend [`Instance`] builder.
pub struct FrontendBuilder {
@@ -174,10 +173,6 @@ impl FrontendBuilder {
            )
            .query_engine();

        let script_executor = Arc::new(
            ScriptExecutor::new(self.catalog_manager.clone(), query_engine.clone()).await?,
        );

        let statement_executor = Arc::new(StatementExecutor::new(
            self.catalog_manager.clone(),
            query_engine.clone(),
@@ -208,7 +203,6 @@ impl FrontendBuilder {
        Ok(Instance {
            options: self.options,
            catalog_manager: self.catalog_manager,
            script_executor,
            pipeline_operator,
            statement_executor,
            query_engine,
@@ -1,58 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;

use async_trait::async_trait;
use common_query::Output;
use servers::error::Error;
use servers::interceptor::{ScriptInterceptor, ScriptInterceptorRef};
use servers::query_handler::ScriptHandler;
use session::context::QueryContextRef;

use crate::instance::Instance;
use crate::metrics;

#[async_trait]
impl ScriptHandler for Instance {
    async fn insert_script(
        &self,
        query_ctx: QueryContextRef,
        name: &str,
        script: &str,
    ) -> servers::error::Result<()> {
        let interceptor_ref = self.plugins.get::<ScriptInterceptorRef<Error>>();
        interceptor_ref.pre_execute(name, query_ctx.clone())?;

        let _timer = metrics::INSERT_SCRIPTS_ELAPSED.start_timer();
        self.script_executor
            .insert_script(query_ctx, name, script)
            .await
    }

    async fn execute_script(
        &self,
        query_ctx: QueryContextRef,
        name: &str,
        params: HashMap<String, String>,
    ) -> servers::error::Result<Output> {
        let interceptor_ref = self.plugins.get::<ScriptInterceptorRef<Error>>();
        interceptor_ref.pre_execute(name, query_ctx.clone())?;

        let _timer = metrics::EXECUTE_SCRIPT_ELAPSED.start_timer();
        self.script_executor
            .execute_script(query_ctx, name, params)
            .await
    }
}
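The deleted `ScriptHandler` implementation above layers two cross-cutting steps in a fixed order: a plugin interceptor runs first, then a metrics timer covers the actual work. A self-contained sketch of that ordering (a plain function and `Instant` stand in for the plugin registry and the Prometheus histogram):

```rust
use std::time::Instant;

// Stand-in for the plugin interceptor's pre_execute hook.
fn pre_execute(name: &str) -> Result<(), String> {
    if name.is_empty() {
        return Err("empty script name rejected by interceptor".to_string());
    }
    Ok(())
}

fn insert_script(name: &str, script: &str) -> Result<(), String> {
    // Interceptor runs first, so rejected requests are never timed or executed.
    pre_execute(name)?;

    // Plays the role of INSERT_SCRIPTS_ELAPSED.start_timer() in the deleted code.
    let timer = Instant::now();
    println!("stored script {name}: {script}");
    println!("insert took {:?}", timer.elapsed());
    Ok(())
}

fn main() {
    insert_script("test", "return 1").unwrap();
    assert!(insert_script("", "x").is_err());
}
```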
@@ -20,6 +20,5 @@ pub mod heartbeat;
pub mod instance;
pub(crate) mod limiter;
pub(crate) mod metrics;
mod script;
pub mod server;
pub mod service_config;
@@ -29,19 +29,6 @@ lazy_static! {
    pub static ref GRPC_HANDLE_PROMQL_ELAPSED: Histogram = GRPC_HANDLE_QUERY_ELAPSED
        .with_label_values(&["promql"]);

    /// Timer of handling scripts in the script handler.
    pub static ref HANDLE_SCRIPT_ELAPSED: HistogramVec = register_histogram_vec!(
        "greptime_frontend_handle_script_elapsed",
        "Elapsed time of handling scripts in the script handler",
        &["type"],
        vec![0.005, 0.01, 0.05, 0.1, 0.5, 1.0, 5.0, 10.0, 60.0, 300.0]
    )
    .unwrap();
    pub static ref INSERT_SCRIPTS_ELAPSED: Histogram = HANDLE_SCRIPT_ELAPSED
        .with_label_values(&["insert"]);
    pub static ref EXECUTE_SCRIPT_ELAPSED: Histogram = HANDLE_SCRIPT_ELAPSED
        .with_label_values(&["execute"]);

    /// The number of OpenTelemetry metrics send by frontend node.
    pub static ref OTLP_METRICS_ROWS: IntCounter = register_int_counter!(
        "greptime_frontend_otlp_metrics_rows",
@@ -1,294 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;

use catalog::CatalogManagerRef;
use common_query::Output;
use query::QueryEngineRef;
use session::context::QueryContextRef;

use crate::error::Result;
use crate::instance::Instance;

#[cfg(not(feature = "python"))]
mod dummy {
    use super::*;

    pub struct ScriptExecutor;

    impl ScriptExecutor {
        pub async fn new(
            _catalog_manager: CatalogManagerRef,
            _query_engine: QueryEngineRef,
        ) -> Result<Self> {
            Ok(Self {})
        }

        pub fn start(&self, _instance: &Instance) -> Result<()> {
            Ok(())
        }

        pub async fn insert_script(
            &self,
            _query_ctx: QueryContextRef,
            _name: &str,
            _script: &str,
        ) -> servers::error::Result<()> {
            servers::error::NotSupportedSnafu { feat: "script" }.fail()
        }

        pub async fn execute_script(
            &self,
            _query_ctx: QueryContextRef,
            _name: &str,
            _params: HashMap<String, String>,
        ) -> servers::error::Result<Output> {
            servers::error::NotSupportedSnafu { feat: "script" }.fail()
        }
    }
}

#[cfg(feature = "python")]
mod python {
    use std::sync::Arc;

    use api::v1::ddl_request::Expr;
    use api::v1::greptime_request::Request;
    use api::v1::DdlRequest;
    use arc_swap::ArcSwap;
    use catalog::RegisterSystemTableRequest;
    use common_error::ext::{BoxedError, ErrorExt};
    use common_telemetry::{error, info};
    use script::manager::ScriptManager;
    use servers::query_handler::grpc::GrpcQueryHandler;
    use session::context::QueryContext;
    use snafu::{OptionExt, ResultExt};
    use table::table_name::TableName;

    use super::*;
    use crate::error::{CatalogSnafu, Error, TableNotFoundSnafu};

    type FrontendGrpcQueryHandlerRef = Arc<dyn GrpcQueryHandler<Error = Error> + Send + Sync>;

    /// A placeholder for the real gRPC handler.
    /// It is temporary and will be replaced soon.
    struct DummyHandler;

    impl DummyHandler {
        fn arc() -> Arc<Self> {
            Arc::new(Self {})
        }
    }

    #[async_trait::async_trait]
    impl GrpcQueryHandler for DummyHandler {
        type Error = Error;

        async fn do_query(
            &self,
            _query: Request,
            _ctx: QueryContextRef,
        ) -> std::result::Result<Output, Self::Error> {
            unreachable!();
        }
    }

    pub struct ScriptExecutor {
        script_manager: ScriptManager<Error>,
        grpc_handler: ArcSwap<FrontendGrpcQueryHandlerRef>,
        catalog_manager: CatalogManagerRef,
    }

    impl ScriptExecutor {
        pub async fn new(
            catalog_manager: CatalogManagerRef,
            query_engine: QueryEngineRef,
        ) -> Result<Self> {
            let grpc_handler = DummyHandler::arc();
            Ok(Self {
                grpc_handler: ArcSwap::new(Arc::new(grpc_handler.clone() as _)),
                script_manager: ScriptManager::new(grpc_handler as _, query_engine)
                    .await
                    .context(crate::error::StartScriptManagerSnafu)?,
                catalog_manager,
            })
        }

        pub fn start(&self, instance: &Instance) -> Result<()> {
            let handler = Arc::new(instance.clone());
            self.grpc_handler.store(Arc::new(handler.clone() as _));
            self.script_manager
                .start(handler)
                .context(crate::error::StartScriptManagerSnafu)?;

            Ok(())
        }

        /// Create scripts table for the specific catalog if it's not exists.
        /// The function is idempotent and safe to be called more than once for the same catalog
        async fn create_scripts_table_if_need(&self, catalog: &str) -> Result<()> {
            let scripts_table = self.script_manager.get_scripts_table(catalog);

            if scripts_table.is_some() {
                return Ok(());
            }

            let RegisterSystemTableRequest {
                create_table_expr: expr,
                open_hook,
            } = self.script_manager.create_table_request(catalog);

            if let Some(table) = self
                .catalog_manager
                .table(
                    &expr.catalog_name,
                    &expr.schema_name,
                    &expr.table_name,
                    None,
                )
                .await
                .context(CatalogSnafu)?
            {
                if let Some(open_hook) = open_hook {
                    (open_hook)(table.clone()).await.context(CatalogSnafu)?;
                }

                self.script_manager.insert_scripts_table(catalog, table);

                return Ok(());
            }

            let table_name =
                TableName::new(&expr.catalog_name, &expr.schema_name, &expr.table_name);

            let _ = self
                .grpc_handler
                .load()
                .do_query(
                    Request::Ddl(DdlRequest {
                        expr: Some(Expr::CreateTable(expr)),
                    }),
                    QueryContext::arc(),
                )
                .await?;

            let table = self
                .catalog_manager
                .table(
                    &table_name.catalog_name,
                    &table_name.schema_name,
                    &table_name.table_name,
                    None,
                )
                .await
                .context(CatalogSnafu)?
                .with_context(|| TableNotFoundSnafu {
                    table_name: table_name.to_string(),
                })?;

            if let Some(open_hook) = open_hook {
                (open_hook)(table.clone()).await.context(CatalogSnafu)?;
            }

            info!(
                "Created scripts table {}.",
                table.table_info().full_table_name()
            );

            self.script_manager.insert_scripts_table(catalog, table);

            Ok(())
        }

        pub async fn insert_script(
            &self,
            query_ctx: QueryContextRef,
            name: &str,
            script: &str,
        ) -> servers::error::Result<()> {
            self.create_scripts_table_if_need(query_ctx.current_catalog())
                .await
                .map_err(|e| {
                    if e.status_code().should_log_error() {
                        error!(e; "Failed to create scripts table");
                    }

                    servers::error::InternalSnafu {
                        err_msg: e.to_string(),
                    }
                    .build()
                })?;

            let _s = self
                .script_manager
                .insert_and_compile(
                    query_ctx.current_catalog(),
                    &query_ctx.current_schema(),
                    name,
                    script,
                )
                .await
                .map_err(|e| {
                    if e.status_code().should_log_error() {
                        error!(e; "Failed to insert script");
                    }

                    BoxedError::new(e)
                })
                .context(servers::error::InsertScriptSnafu { name })?;

            Ok(())
        }

        pub async fn execute_script(
            &self,
            query_ctx: QueryContextRef,
            name: &str,
            params: HashMap<String, String>,
        ) -> servers::error::Result<Output> {
            self.create_scripts_table_if_need(query_ctx.current_catalog())
                .await
                .map_err(|e| {
                    error!(e; "Failed to create scripts table");
                    servers::error::InternalSnafu {
                        err_msg: e.to_string(),
                    }
                    .build()
                })?;

            self.script_manager
                .execute(
                    query_ctx.current_catalog(),
                    &query_ctx.current_schema(),
                    name,
                    params,
                )
                .await
                .map_err(|e| {
                    if e.status_code().should_log_error() {
                        error!(e; "Failed to execute script");
                    }

                    BoxedError::new(e)
                })
                .context(servers::error::ExecuteScriptSnafu { name })
        }
    }
}

#[cfg(not(feature = "python"))]
pub use self::dummy::*;
#[cfg(feature = "python")]
pub use self::python::*;
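The deleted module above is the core of the feature gate: one public name, `ScriptExecutor`, with two compile-time-selected definitions that share a signature, so callers never write `cfg` themselves. A condensed, runnable sketch of that facade (signatures simplified from the original):

```rust
// Condensed from the deleted script-executor module: one public type,
// two cfg-selected definitions with identical signatures.
mod script {
    #[cfg(not(feature = "python"))]
    mod dummy {
        pub struct ScriptExecutor;

        impl ScriptExecutor {
            pub fn execute_script(&self, name: &str) -> Result<String, String> {
                Err(format!("script {name}: feature \"python\" not enabled"))
            }
        }
    }

    #[cfg(feature = "python")]
    mod python {
        pub struct ScriptExecutor;

        impl ScriptExecutor {
            pub fn execute_script(&self, name: &str) -> Result<String, String> {
                Ok(format!("script {name}: executed"))
            }
        }
    }

    // Callers import script::ScriptExecutor; the cfg decides which definition.
    #[cfg(not(feature = "python"))]
    pub use self::dummy::ScriptExecutor;
    #[cfg(feature = "python")]
    pub use self::python::ScriptExecutor;
}

fn main() {
    let executor = script::ScriptExecutor;
    println!("{:?}", executor.execute_script("demo"));
}
```

Because both definitions expose the same API, deleting the feature (as this commit does) only requires removing call sites, not rewriting every caller's `cfg` logic.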
@@ -78,10 +78,8 @@ where
    }

    pub fn http_server_builder(&self, opts: &FrontendOptions) -> HttpServerBuilder {
        let mut builder = HttpServerBuilder::new(opts.http.clone()).with_sql_handler(
            ServerSqlQueryHandlerAdapter::arc(self.instance.clone()),
            Some(self.instance.clone()),
        );
        let mut builder = HttpServerBuilder::new(opts.http.clone())
            .with_sql_handler(ServerSqlQueryHandlerAdapter::arc(self.instance.clone()));

        let validator = self.plugins.get::<LogValidatorRef>();
        let ingest_interceptor = self.plugins.get::<LogIngestInterceptorRef<ServerError>>();
@@ -2037,18 +2037,19 @@ mod test {
        }
    }

    #[test]
    fn test_find_successive_runs() {
        impl From<(i32, i32, Option<i32>, Option<i32>)> for SucRun<i32> {
            fn from((offset, len, min_val, max_val): (i32, i32, Option<i32>, Option<i32>)) -> Self {
                Self {
                    offset: offset as usize,
                    len: len as usize,
                    first_val: min_val,
                    last_val: max_val,
                }
    impl From<(i32, i32, Option<i32>, Option<i32>)> for SucRun<i32> {
        fn from((offset, len, min_val, max_val): (i32, i32, Option<i32>, Option<i32>)) -> Self {
            Self {
                offset: offset as usize,
                len: len as usize,
                first_val: min_val,
                last_val: max_val,
            }
        }
    }

    #[test]
    fn test_find_successive_runs() {
        let testcases = vec![
            (
                vec![Some(1), Some(1), Some(2), Some(1), Some(3)],
@@ -1,90 +0,0 @@
[package]
name = "script"
edition.workspace = true
version.workspace = true
license.workspace = true

[features]
default = ["python"]
pyo3_backend = ["dep:pyo3", "arrow/pyarrow"]
python = [
    "dep:datafusion",
    "dep:datafusion-common",
    "dep:datafusion-expr",
    "dep:datafusion-functions",
    "dep:datafusion-physical-expr",
    "dep:rustpython-vm",
    "dep:rustpython-parser",
    "dep:rustpython-compiler",
    "dep:rustpython-compiler-core",
    "dep:rustpython-codegen",
    "dep:rustpython-pylib",
    "dep:rustpython-stdlib",
    "dep:paste",
]

[lints]
workspace = true

[dependencies]
api.workspace = true
arc-swap = "1.0"
arrow.workspace = true
async-trait.workspace = true
catalog.workspace = true
common-catalog.workspace = true
common-error.workspace = true
common-function.workspace = true
common-macro.workspace = true
common-query.workspace = true
common-recordbatch.workspace = true
common-runtime.workspace = true
common-telemetry.workspace = true
common-time.workspace = true
console = "0.15"
crossbeam-utils.workspace = true
datafusion = { workspace = true, optional = true }
datafusion-common = { workspace = true, optional = true }
datafusion-expr = { workspace = true, optional = true }
datafusion-functions = { workspace = true, optional = true }
datafusion-physical-expr = { workspace = true, optional = true }
datatypes.workspace = true
futures.workspace = true
lazy_static.workspace = true
once_cell.workspace = true
paste = { workspace = true, optional = true }
prometheus.workspace = true
query.workspace = true
# TODO(discord9): This is a forked and tweaked version of RustPython, please update it to newest original RustPython After RustPython support GC
pyo3 = { version = "0.20", optional = true, features = ["abi3", "abi3-py37"] }
rustpython-codegen = { git = "https://github.com/discord9/RustPython", optional = true, rev = "9ed5137412" }
rustpython-compiler = { git = "https://github.com/discord9/RustPython", optional = true, rev = "9ed5137412" }
rustpython-compiler-core = { git = "https://github.com/discord9/RustPython", optional = true, rev = "9ed5137412" }
rustpython-parser = { git = "https://github.com/discord9/RustPython", optional = true, rev = "9ed5137412" }
rustpython-pylib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "9ed5137412", features = [
    "freeze-stdlib",
] }
rustpython-stdlib = { git = "https://github.com/discord9/RustPython", optional = true, rev = "9ed5137412" }
rustpython-vm = { git = "https://github.com/discord9/RustPython", optional = true, rev = "9ed5137412", features = [
    "default",
    "codegen",
] }
servers.workspace = true
session.workspace = true
snafu.workspace = true
sql.workspace = true
table.workspace = true
tokio.workspace = true

[dev-dependencies]
catalog = { workspace = true, features = ["testing"] }
criterion = { version = "0.4", features = ["html_reports", "async_tokio"] }
operator.workspace = true
rayon = "1.0"
ron = "0.7"
serde = { version = "1.0", features = ["derive"] }
session = { workspace = true, features = ["testing"] }

[[bench]]
name = "py_benchmark"
harness = false
@@ -1,210 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;

use catalog::memory::MemoryCatalogManager;
use common_catalog::consts::NUMBERS_TABLE_ID;
use common_query::OutputData;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use futures::Future;
use once_cell::sync::{Lazy, OnceCell};
use query::QueryEngineFactory;
use rayon::ThreadPool;
use script::engine::{CompileContext, EvalContext, Script, ScriptEngine};
use script::python::{PyEngine, PyScript};
use table::table::numbers::NumbersTable;
use tokio::runtime::Runtime;

static SCRIPT_ENGINE: Lazy<PyEngine> = Lazy::new(sample_script_engine);
static LOCAL_RUNTIME: OnceCell<tokio::runtime::Runtime> = OnceCell::new();
fn get_local_runtime() -> std::thread::Result<&'static Runtime> {
    let rt = LOCAL_RUNTIME.get_or_try_init(|| {
        tokio::runtime::Runtime::new().map_err(|e| Box::new(e) as Box<dyn Any + Send + 'static>)
    })?;
    Ok(rt)
}
/// a terrible hack to call async from sync by:
/// TODO(discord9): find a better way
/// 1. spawn a new thread
/// 2. create a new runtime in new thread and call `block_on` on it
pub fn block_on_async<T, F>(f: F) -> std::thread::Result<T>
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    let rt = get_local_runtime()?;

    std::thread::spawn(move || rt.block_on(f)).join()
}

pub(crate) fn sample_script_engine() -> PyEngine {
    let catalog_manager =
        MemoryCatalogManager::new_with_table(NumbersTable::table(NUMBERS_TABLE_ID));
    let query_engine =
        QueryEngineFactory::new(catalog_manager, None, None, None, None, false).query_engine();

    PyEngine::new(query_engine.clone())
}

async fn compile_script(script: &str) -> PyScript {
    SCRIPT_ENGINE
        .compile(script, CompileContext::default())
        .await
        .unwrap()
}
async fn run_compiled(script: &PyScript) {
    let output = script
        .execute(HashMap::default(), EvalContext::default())
        .await
        .unwrap();
    let _res = match output.data {
        OutputData::Stream(s) => common_recordbatch::util::collect_batches(s).await.unwrap(),
        OutputData::RecordBatches(rbs) => rbs,
        _ => unreachable!(),
    };
}

async fn fibonacci(n: u64, backend: &str) {
    let source = format!(
        r#"
@copr(returns=["value"], backend="{backend}")
def entry() -> vector[i64]:
    def fibonacci(n):
        if n <2:
            return 1
        else:
            return fibonacci(n-1) + fibonacci(n-2)
    return fibonacci({n})
"#
    );
    let compiled = compile_script(&source).await;
    for _ in 0..10 {
        run_compiled(&compiled).await;
    }
}

/// TODO(discord9): use a better way to benchmark in parallel
async fn parallel_fibonacci(n: u64, backend: &str, pool: &ThreadPool) {
    let source = format!(
        r#"
@copr(returns=["value"], backend="{backend}")
def entry() -> vector[i64]:
    def fibonacci(n):
        if n <2:
            return 1
        else:
            return fibonacci(n-1) + fibonacci(n-2)
    return fibonacci({n})
"#
    );
    let source = Arc::new(source);
    // execute the script in parallel for every thread in the pool
    let _ = pool.broadcast(|_| {
        let source = source.clone();
        let rt = get_local_runtime().unwrap();
        rt.block_on(async move {
            let compiled = compile_script(&source).await;
            for _ in 0..10 {
                run_compiled(&compiled).await;
            }
        });
    });
}

async fn loop_1_million(backend: &str) {
    let source = format!(
        r#"
@copr(returns=["value"], backend="{backend}")
def entry() -> vector[i64]:
    for i in range(1000000):
        pass
    return 1
"#
    );
    let compiled = compile_script(&source).await;
    for _ in 0..10 {
        run_compiled(&compiled).await;
    }
}

async fn api_heavy(backend: &str) {
    let source = format!(
        r#"
from greptime import vector
@copr(args=["number"], sql="select number from numbers", returns=["value"], backend="{backend}")
def entry(number) -> vector[i64]:
    for i in range(1000):
        n2 = number + number
        n_mul = n2 * n2
        n_mask = n_mul[n_mul>2]
    return 1
"#
    );
    let compiled = compile_script(&source).await;
    for _ in 0..10 {
        run_compiled(&compiled).await;
    }
}

fn criterion_benchmark(c: &mut Criterion) {
    // TODO(discord9): Prime Number,
    // and database-local computation/remote download python script comparison
    // which require a local mock library
    // TODO(discord9): revisit once mock library is ready

    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(16)
        .build()
        .unwrap();

    let _ = c
        .bench_function("fib 20 rspy", |b| {
            b.to_async(tokio::runtime::Runtime::new().unwrap())
                .iter(|| fibonacci(black_box(20), "rspy"))
        })
        .bench_function("fib 20 pyo3", |b| {
            b.to_async(tokio::runtime::Runtime::new().unwrap())
                .iter(|| fibonacci(black_box(20), "pyo3"))
        })
        .bench_function("par fib 20 rspy", |b| {
            b.to_async(tokio::runtime::Runtime::new().unwrap())
                .iter(|| parallel_fibonacci(black_box(20), "rspy", &pool))
        })
        .bench_function("par fib 20 pyo3", |b| {
            b.to_async(tokio::runtime::Runtime::new().unwrap())
                .iter(|| parallel_fibonacci(black_box(20), "pyo3", &pool))
        })
        .bench_function("loop 1M rspy", |b| {
            b.to_async(tokio::runtime::Runtime::new().unwrap())
                .iter(|| loop_1_million(black_box("rspy")))
        })
        .bench_function("loop 1M pyo3", |b| {
            b.to_async(tokio::runtime::Runtime::new().unwrap())
                .iter(|| loop_1_million(black_box("pyo3")))
        })
        .bench_function("api heavy rspy", |b| {
            b.to_async(tokio::runtime::Runtime::new().unwrap())
                .iter(|| api_heavy(black_box("rspy")))
        })
        .bench_function("api heavy pyo3", |b| {
            b.to_async(tokio::runtime::Runtime::new().unwrap())
                .iter(|| api_heavy(black_box("pyo3")))
        });
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
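The deleted benchmark relies on the `block_on_async` helper it documents as "a terrible hack": spawn a plain thread and `block_on` a shared runtime there, so a synchronous caller can wait on a future without nesting runtimes. A minimal sketch of the same idea (assumes the `tokio` crate; uses `std::sync::OnceLock` instead of `once_cell`):

```rust
use std::sync::OnceLock;

use tokio::runtime::Runtime;

// One lazily created runtime shared by all callers.
static RT: OnceLock<Runtime> = OnceLock::new();

// Run a future to completion from sync code by blocking on it in a
// freshly spawned thread, never on the caller's (possibly async) thread.
fn block_on_async<T, F>(fut: F) -> std::thread::Result<T>
where
    F: std::future::Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    let rt = RT.get_or_init(|| Runtime::new().expect("failed to create runtime"));
    std::thread::spawn(move || rt.block_on(fut)).join()
}

fn main() {
    let v = block_on_async(async { 1 + 1 }).unwrap();
    assert_eq!(v, 2);
}
```

The thread-per-call cost is why the original comment flags it as a hack to be replaced.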
@@ -1,76 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Script engine

use std::any::Any;
use std::collections::HashMap;

use async_trait::async_trait;
use common_error::ext::ErrorExt;
use common_query::Output;
use session::context::{QueryContext, QueryContextRef};

#[async_trait]
pub trait Script {
    type Error: ErrorExt + Send + Sync;

    /// Returns the script engine name such as `python` etc.
    fn engine_name(&self) -> &str;

    fn as_any(&self) -> &dyn Any;

    /// Execute the script and returns the output.
    async fn execute(
        &self,
        params: HashMap<String, String>,
        ctx: EvalContext,
    ) -> std::result::Result<Output, Self::Error>;
}

#[async_trait]
pub trait ScriptEngine {
    type Error: ErrorExt + Send + Sync;
    type Script: Script<Error = Self::Error>;

    /// Returns the script engine name such as `python` etc.
    fn name(&self) -> &str;

    fn as_any(&self) -> &dyn Any;

    /// Compile a script text into a script instance.
    async fn compile(
        &self,
        script: &str,
        ctx: CompileContext,
    ) -> std::result::Result<Self::Script, Self::Error>;
}

/// Evaluate script context
#[derive(Debug)]
pub struct EvalContext {
    pub query_ctx: QueryContextRef,
}

impl Default for EvalContext {
    fn default() -> Self {
        Self {
            query_ctx: QueryContext::arc(),
        }
    }
}

/// Compile script context
#[derive(Debug, Default)]
pub struct CompileContext {}
@@ -1,120 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;

use common_error::ext::{BoxedError, ErrorExt};
use common_error::status_code::StatusCode;
use common_macro::stack_trace_debug;
use snafu::{Location, Snafu};

#[derive(Snafu)]
#[snafu(visibility(pub))]
#[stack_trace_debug]
pub enum Error {
    #[snafu(display("Failed to find column in scripts table, name: {}", name))]
    FindColumnInScriptsTable {
        name: String,
        #[snafu(implicit)]
        location: Location,
    },

    #[snafu(display("Scripts table not found"))]
    ScriptsTableNotFound {
        #[snafu(implicit)]
        location: Location,
    },

    #[snafu(display("Failed to insert script to scripts table, name: {}", name))]
    InsertScript {
        name: String,
        #[snafu(implicit)]
        location: Location,
        source: BoxedError,
    },

    #[snafu(display("Failed to compile python script, name: {}", name))]
    CompilePython {
        name: String,
        #[snafu(implicit)]
        location: Location,
        source: crate::python::error::Error,
    },

    #[snafu(display("Failed to execute python script {}", name))]
    ExecutePython {
        name: String,
        #[snafu(implicit)]
        location: Location,
        source: crate::python::error::Error,
    },

    #[snafu(display("Script not found, name: {}", name))]
    ScriptNotFound {
        #[snafu(implicit)]
        location: Location,
        name: String,
    },

    #[snafu(display("Failed to collect record batch"))]
    CollectRecords {
        #[snafu(implicit)]
        location: Location,
        source: common_recordbatch::error::Error,
    },

    #[snafu(display("Failed to cast type, msg: {}", msg))]
    CastType {
        msg: String,
        #[snafu(implicit)]
        location: Location,
    },

    #[snafu(display("Failed to build DataFusion logical plan"))]
    BuildDfLogicalPlan {
        #[snafu(source)]
        error: datafusion_common::DataFusionError,
        #[snafu(implicit)]
        location: Location,
    },

    #[snafu(display("Failed to execute internal statement"))]
    ExecuteInternalStatement {
        source: query::error::Error,
        #[snafu(implicit)]
        location: Location,
    },
}

pub type Result<T> = std::result::Result<T, Error>;

impl ErrorExt for Error {
    fn status_code(&self) -> StatusCode {
        use Error::*;
        match self {
            FindColumnInScriptsTable { .. } | CastType { .. } => StatusCode::Unexpected,
            ScriptsTableNotFound { .. } => StatusCode::TableNotFound,
            InsertScript { source, .. } => source.status_code(),
            CompilePython { source, .. } | ExecutePython { source, .. } => source.status_code(),
            CollectRecords { source, .. } => source.status_code(),
            ScriptNotFound { .. } => StatusCode::InvalidArguments,
            BuildDfLogicalPlan { .. } => StatusCode::Internal,
            ExecuteInternalStatement { source, .. } => source.status_code(),
        }
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
@@ -1,26 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// TODO(discord9): spawn new process for executing python script(if hit gil limit) and use shared memory to communicate
#![deny(clippy::implicit_clone)]

pub mod engine;
pub mod error;
#[cfg(feature = "python")]
pub mod manager;
#[cfg(feature = "python")]
pub mod python;
pub mod table;
#[cfg(test)]
mod test;
@@ -1,278 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Scripts manager
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

use api::v1::CreateTableExpr;
use arc_swap::ArcSwap;
use catalog::{OpenSystemTableHook, RegisterSystemTableRequest};
use common_catalog::consts::{default_engine, DEFAULT_SCHEMA_NAME};
use common_error::ext::ErrorExt;
use common_query::Output;
use common_telemetry::info;
use futures::future::FutureExt;
use query::QueryEngineRef;
use servers::query_handler::grpc::GrpcQueryHandlerRef;
use snafu::{OptionExt, ResultExt};
use table::TableRef;

use crate::engine::{CompileContext, EvalContext, Script, ScriptEngine};
use crate::error::{
    CompilePythonSnafu, ExecutePythonSnafu, Result, ScriptNotFoundSnafu, ScriptsTableNotFoundSnafu,
};
use crate::python::{PyEngine, PyScript};
use crate::table::{build_scripts_schema, ScriptsTable, ScriptsTableRef, SCRIPTS_TABLE_NAME};

pub struct ScriptManager<E: ErrorExt + Send + Sync + 'static> {
    compiled: RwLock<HashMap<String, Arc<PyScript>>>,
    py_engine: PyEngine,
    grpc_handler: ArcSwap<GrpcQueryHandlerRef<E>>,
    // Catalog name -> `[ScriptsTable]`
    tables: RwLock<HashMap<String, ScriptsTableRef<E>>>,
    query_engine: QueryEngineRef,
}

impl<E: ErrorExt + Send + Sync + 'static> ScriptManager<E> {
    pub async fn new(
        grpc_handler: GrpcQueryHandlerRef<E>,
        query_engine: QueryEngineRef,
    ) -> Result<Self> {
        Ok(Self {
            compiled: RwLock::new(HashMap::default()),
            py_engine: PyEngine::new(query_engine.clone()),
            query_engine,
            grpc_handler: ArcSwap::new(Arc::new(grpc_handler)),
            tables: RwLock::new(HashMap::default()),
        })
    }

    pub fn start(&self, grpc_handler: GrpcQueryHandlerRef<E>) -> Result<()> {
        self.grpc_handler.store(Arc::new(grpc_handler));

        Ok(())
    }

    pub fn create_table_request(&self, catalog: &str) -> RegisterSystemTableRequest {
        let (time_index, primary_keys, column_defs) = build_scripts_schema();

        let create_table_expr = CreateTableExpr {
            catalog_name: catalog.to_string(),
            // TODO(dennis): put the scripts table into `system` schema?
            // We always put the scripts table into `public` schema right now.
            schema_name: DEFAULT_SCHEMA_NAME.to_string(),
            table_name: SCRIPTS_TABLE_NAME.to_string(),
            desc: "GreptimeDB scripts table for Python".to_string(),
            column_defs,
            time_index,
            primary_keys,
            create_if_not_exists: true,
            table_options: Default::default(),
            table_id: None, // Should and will be assigned by Meta.
            engine: default_engine().to_string(),
        };

        let query_engine = self.query_engine.clone();

        let hook: OpenSystemTableHook = Box::new(move |table: TableRef| {
            let query_engine = query_engine.clone();
            async move { ScriptsTable::<E>::recompile_register_udf(table, query_engine.clone()).await }
                .boxed()
        });

        RegisterSystemTableRequest {
            create_table_expr,
            open_hook: Some(hook),
        }
    }

    /// compile script, and register them to the query engine and UDF registry
    async fn compile(&self, name: &str, script: &str) -> Result<Arc<PyScript>> {
        let script = Arc::new(Self::compile_without_cache(&self.py_engine, name, script).await?);

        {
            let mut compiled = self.compiled.write().unwrap();
            let _ = compiled.insert(name.to_string(), script.clone());
        }
        info!("Compiled and cached script: {}", name);

        script.as_ref().register_udf().await;

        info!("Script register as UDF: {}", name);

        Ok(script)
    }

    /// compile script to PyScript, but not register them to the query engine and UDF registry nor caching in `compiled`
    async fn compile_without_cache(
        py_engine: &PyEngine,
        name: &str,
        script: &str,
    ) -> Result<PyScript> {
        py_engine
            .compile(script, CompileContext::default())
            .await
            .context(CompilePythonSnafu { name })
    }

    /// Get the scripts table in the catalog
    pub fn get_scripts_table(&self, catalog: &str) -> Option<ScriptsTableRef<E>> {
        self.tables.read().unwrap().get(catalog).cloned()
    }

    /// Insert a scripts table.
    pub fn insert_scripts_table(&self, catalog: &str, table: TableRef) {
        let mut tables = self.tables.write().unwrap();

        if tables.get(catalog).is_some() {
            return;
        }

        tables.insert(
            catalog.to_string(),
            Arc::new(ScriptsTable::new(
                table,
                self.grpc_handler.load().as_ref().clone(),
                self.query_engine.clone(),
            )),
        );
    }

    pub async fn insert_and_compile(
        &self,
        catalog: &str,
        schema: &str,
        name: &str,
        script: &str,
    ) -> Result<Arc<PyScript>> {
        let compiled_script = self.compile(name, script).await?;
        self.get_scripts_table(catalog)
            .context(ScriptsTableNotFoundSnafu)?
            .insert(schema, name, script)
            .await?;

        Ok(compiled_script)
    }

    pub async fn execute(
        &self,
        catalog: &str,
        schema: &str,
        name: &str,
        params: HashMap<String, String>,
    ) -> Result<Output> {
        let script = {
            let s = self.compiled.read().unwrap().get(name).cloned();

            if s.is_some() {
                s
            } else {
                self.try_find_script_and_compile(catalog, schema, name)
                    .await?
            }
        };

        let script = script.context(ScriptNotFoundSnafu { name })?;

        script
            .execute(params, EvalContext::default())
            .await
            .context(ExecutePythonSnafu { name })
    }

    async fn try_find_script_and_compile(
        &self,
        catalog: &str,
        schema: &str,
        name: &str,
    ) -> Result<Option<Arc<PyScript>>> {
        let script = self
            .get_scripts_table(catalog)
            .context(ScriptsTableNotFoundSnafu)?
            .find_script_by_name(schema, name)
            .await?;

        Ok(Some(self.compile(name, &script).await?))
    }
}

#[cfg(test)]
mod tests {
    use common_query::OutputData;

    use super::*;
    use crate::test::setup_scripts_manager;

    #[tokio::test]
    async fn test_insert_find_compile_script() {
        common_telemetry::init_default_ut_logging();

        let catalog = "greptime";
        let schema = "schema";
        let name = "test";
        let script = r#"
@copr(returns=['n'])
def test() -> vector[str]:
    return 'hello';
"#;

        let mgr = setup_scripts_manager(catalog, schema, name, script).await;

        {
            let cached = mgr.compiled.read().unwrap();
            assert!(cached.get(name).is_none());
        }

        mgr.insert_and_compile(catalog, schema, name, script)
            .await
            .unwrap();

        {
            let cached = mgr.compiled.read().unwrap();
            assert!(cached.get(name).is_some());
        }

        // try to find and compile
        let script = mgr
            .try_find_script_and_compile(catalog, schema, name)
            .await
            .unwrap();
        let _ = script.unwrap();

        {
            let cached = mgr.compiled.read().unwrap();
            let _ = cached.get(name).unwrap();
        }

        // execute script
        let output = mgr
            .execute(catalog, schema, name, HashMap::new())
            .await
            .unwrap();

        match output.data {
            OutputData::RecordBatches(batches) => {
                let expected = "\
+-------+
| n     |
+-------+
| hello |
+-------+";
                assert_eq!(expected, batches.pretty_print().unwrap());
            }
            _ => unreachable!(),
        }
    }
}
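The deleted `ScriptManager` above keeps a read-through cache of compiled scripts: `execute` first checks the `compiled` map under a read lock, and only on a miss fetches the source from the scripts table and compiles it. A simplified sketch of that caching shape (a `String` stands in for `PyScript`, and the table lookup is omitted):

```rust
// Simplified read-through compile cache, modeled on the deleted manager.
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

struct Manager {
    // name -> compiled artifact (PyScript in the original)
    compiled: RwLock<HashMap<String, Arc<String>>>,
}

impl Manager {
    fn execute(&self, name: &str, source: &str) -> Arc<String> {
        // Fast path: serve a cached artifact under a read lock.
        if let Some(hit) = self.compiled.read().unwrap().get(name).cloned() {
            return hit;
        }
        // Miss: "compile", then cache under a write lock.
        let artifact = Arc::new(format!("compiled({source})"));
        self.compiled
            .write()
            .unwrap()
            .insert(name.to_string(), artifact.clone());
        artifact
    }
}

fn main() {
    let m = Manager { compiled: RwLock::new(HashMap::new()) };
    let a = m.execute("test", "return 1");
    let b = m.execute("test", "return 1");
    assert!(Arc::ptr_eq(&a, &b)); // second call is served from the cache
}
```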
@@ -1,28 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Python script coprocessor

mod engine;
pub mod error;
pub(crate) mod metric;
pub(crate) mod utils;

pub use self::engine::{PyEngine, PyScript};

mod ffi_types;

#[cfg(feature = "pyo3_backend")]
mod pyo3;
mod rspython;
@@ -1,579 +0,0 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Python script engine

use std::any::Any;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use async_trait::async_trait;
use common_error::ext::BoxedError;
use common_function::function::Function;
use common_function::function_registry::FUNCTION_REGISTRY;
use common_query::error::{PyUdfSnafu, UdfTempRecordBatchSnafu};
use common_query::prelude::Signature;
use common_query::{Output, OutputData};
use common_recordbatch::adapter::RecordBatchMetrics;
use common_recordbatch::error::{ExternalSnafu, Result as RecordBatchResult};
use common_recordbatch::{
    OrderOption, RecordBatch, RecordBatchStream, RecordBatches, SendableRecordBatchStream,
};
use datafusion_expr::Volatility;
use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
use datatypes::vectors::VectorRef;
use futures::Stream;
use query::parser::{QueryLanguageParser, QueryStatement};
use query::QueryEngineRef;
use snafu::{ensure, ResultExt};
use sql::statements::statement::Statement;

use crate::engine::{CompileContext, EvalContext, Script, ScriptEngine};
use crate::python::error::{self, DatabaseQuerySnafu, PyRuntimeSnafu, Result, TokioJoinSnafu};
use crate::python::ffi_types::copr::{exec_parsed, parse, AnnotationInfo, CoprocessorRef};
use crate::python::utils::spawn_blocking_script;

const PY_ENGINE: &str = "python";

#[derive(Debug)]
pub struct PyUDF {
    copr: CoprocessorRef,
}

impl std::fmt::Display for PyUDF {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}({})->",
            &self.copr.name,
            self.copr
                .deco_args
                .arg_names
                .as_ref()
                .unwrap_or(&vec![])
                .join(",")
        )
    }
}

impl PyUDF {
    fn from_copr(copr: CoprocessorRef) -> Arc<Self> {
        Arc::new(Self { copr })
    }

    /// Register to `FUNCTION_REGISTRY`
    fn register_as_udf(zelf: Arc<Self>) {
        FUNCTION_REGISTRY.register(zelf)
    }

    fn register_to_query_engine(zelf: Arc<Self>, engine: QueryEngineRef) {
        engine.register_function(zelf)
    }

    /// Fake a schema; should only be used when dynamically evaluating a Python UDF.
    fn fake_schema(&self, columns: &[VectorRef]) -> SchemaRef {
        // Try to give the schema the right names from `args` so the script can run as a UDF
        // without modification, because when running as a PyUDF the incoming columns need
        // matching names to make sense to the coprocessor.
        let args = self.copr.deco_args.arg_names.clone();
        let try_get_name = |i: usize| {
            if let Some(arg_name) = args.as_ref().and_then(|args| args.get(i)) {
                arg_name.clone()
            } else {
                format!("name_{i}")
            }
        };
        let col_sch: Vec<_> = columns
            .iter()
            .enumerate()
            .map(|(i, col)| ColumnSchema::new(try_get_name(i), col.data_type(), true))
            .collect();
        let schema = datatypes::schema::Schema::new(col_sch);
        Arc::new(schema)
    }
}
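
// Illustration (not in the original file): for a coprocessor declared with
// `args=["cpu", "mem"]` that is evaluated against three input columns, `fake_schema`
// above names them "cpu", "mem" and "name_2" (the positional fallback), all nullable.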

impl Function for PyUDF {
    fn name(&self) -> &str {
        &self.copr.name
    }

    fn return_type(
        &self,
        _input_types: &[datatypes::prelude::ConcreteDataType],
    ) -> common_query::error::Result<datatypes::prelude::ConcreteDataType> {
        // TODO(discord9): use the correct return annotation if it exists
        match self.copr.return_types.first() {
            Some(Some(AnnotationInfo {
                datatype: Some(ty), ..
            })) => Ok(ty.clone()),
            _ => PyUdfSnafu {
                msg: format!("Can't find return type for python UDF {self}"),
            }
            .fail(),
        }
    }

    fn signature(&self) -> common_query::prelude::Signature {
        if self.copr.arg_types.is_empty() {
            return Signature::any(0, Volatility::Volatile);
        }

        // try our best to get a type signature
        let mut arg_types = Vec::with_capacity(self.copr.arg_types.len());
        let mut know_all_types = true;
        for ty in self.copr.arg_types.iter() {
            match ty {
                Some(AnnotationInfo {
                    datatype: Some(ty), ..
                }) => arg_types.push(ty.clone()),
                _ => {
                    know_all_types = false;
                    break;
                }
            }
        }

        // The Volatility should be volatile: the value returned by evaluation may change.
        if know_all_types {
            Signature::variadic(arg_types, Volatility::Volatile)
        } else {
            Signature::any(self.copr.arg_types.len(), Volatility::Volatile)
        }
    }

    fn eval(
        &self,
        func_ctx: common_function::function::FunctionContext,
        columns: &[datatypes::vectors::VectorRef],
    ) -> common_query::error::Result<datatypes::vectors::VectorRef> {
        // FIXME(discord9): exec_parsed requires a RecordBatch (basically a Vector + Schema),
        // and the schema can't pop out of nowhere, right?
        let schema = self.fake_schema(columns);
        let columns = columns.to_vec();
        let rb = Some(RecordBatch::new(schema, columns).context(UdfTempRecordBatchSnafu)?);

        let res = exec_parsed(
            &self.copr,
            &rb,
            &HashMap::new(),
            &EvalContext {
                query_ctx: func_ctx.query_ctx.clone(),
            },
        )
        .map_err(BoxedError::new)
        .context(common_query::error::ExecuteSnafu)?;

        let len = res.columns().len();
        if len == 0 {
            return PyUdfSnafu {
                msg: "Python UDF should return exactly one column, found zero columns".to_string(),
            }
            .fail();
        } // if there is more than one column, just return the first one

        // TODO(discord9): more error handling
        let res0 = res.column(0);
        Ok(res0.clone())
    }
}

pub struct PyScript {
    query_engine: QueryEngineRef,
    pub(crate) copr: CoprocessorRef,
}

impl PyScript {
    pub fn from_script(script: &str, query_engine: QueryEngineRef) -> Result<Self> {
        let copr = Arc::new(parse::parse_and_compile_copr(
            script,
            Some(query_engine.clone()),
        )?);

        Ok(PyScript { copr, query_engine })
    }

    /// Register the current script as a UDF; the registered name is the same as the script name.
    /// FIXME(discord9): possible injection attack?
    pub async fn register_udf(&self) {
        let udf = PyUDF::from_copr(self.copr.clone());
        PyUDF::register_as_udf(udf.clone());
        PyUDF::register_to_query_engine(udf, self.query_engine.clone());
    }
}
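// Usage sketch (for illustration; it mirrors the tests at the bottom of this file,
// not a separate public API):
//
//   let script = PyScript::from_script(source, query_engine)?;
//   script.register_udf().await;
//   // afterwards the coprocessor is callable by name from SQL via the query engine
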
pub struct CoprStream {
    stream: SendableRecordBatchStream,
    copr: CoprocessorRef,
    ret_schema: SchemaRef,
    params: HashMap<String, String>,
    eval_ctx: EvalContext,
}

impl CoprStream {
    fn try_new(
        stream: SendableRecordBatchStream,
        copr: CoprocessorRef,
        params: HashMap<String, String>,
        eval_ctx: EvalContext,
    ) -> Result<Self> {
        let mut schema = vec![];
        for (ty, name) in copr.return_types.iter().zip(&copr.deco_args.ret_names) {
            let ty = ty.clone().ok_or_else(|| {
                PyRuntimeSnafu {
                    msg: "return type not annotated, can't generate schema",
                }
                .build()
            })?;
            let is_nullable = ty.is_nullable;
            let ty = ty.datatype.ok_or_else(|| {
                PyRuntimeSnafu {
                    msg: "return type not annotated, can't generate schema",
                }
                .build()
            })?;
            let col_schema = ColumnSchema::new(name, ty, is_nullable);
            schema.push(col_schema);
        }
        let ret_schema = Arc::new(Schema::new(schema));
        Ok(Self {
            stream,
            copr,
            ret_schema,
            params,
            eval_ctx,
        })
    }
}

impl RecordBatchStream for CoprStream {
    fn schema(&self) -> SchemaRef {
        // FIXME(discord9): use copr returns for schema
        self.ret_schema.clone()
    }

    fn output_ordering(&self) -> Option<&[OrderOption]> {
        None
    }

    fn metrics(&self) -> Option<RecordBatchMetrics> {
        None
    }
}
impl Stream for CoprStream {
    type Item = RecordBatchResult<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.stream).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(Ok(recordbatch))) => {
                let batch =
                    exec_parsed(&self.copr, &Some(recordbatch), &self.params, &self.eval_ctx)
                        .map_err(BoxedError::new)
                        .context(ExternalSnafu)?;
                Poll::Ready(Some(Ok(batch)))
            }
            Poll::Ready(other) => Poll::Ready(other),
        }
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.stream.size_hint()
    }
}
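// Note (added for clarity): `CoprStream` wraps the stream produced by the decorated
// `sql` query and runs the coprocessor over each upstream `RecordBatch` lazily in
// `poll_next`, so script execution is streamed batch by batch rather than collected first.
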
#[async_trait]
impl Script for PyScript {
    type Error = error::Error;

    fn engine_name(&self) -> &str {
        PY_ENGINE
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    async fn execute(&self, params: HashMap<String, String>, ctx: EvalContext) -> Result<Output> {
        if let Some(sql) = &self.copr.deco_args.sql {
            let stmt = QueryLanguageParser::parse_sql(sql, &ctx.query_ctx).unwrap();
            ensure!(
                matches!(stmt, QueryStatement::Sql(Statement::Query { .. })),
                error::UnsupportedSqlSnafu { sql }
            );
            let plan = self
                .query_engine
                .planner()
                .plan(&stmt, ctx.query_ctx.clone())
                .await
                .context(DatabaseQuerySnafu)?;
            let res = self
                .query_engine
                .execute(plan, ctx.query_ctx.clone())
                .await
                .context(DatabaseQuerySnafu)?;
            let copr = self.copr.clone();
            match res.data {
                OutputData::Stream(stream) => Ok(Output::new_with_stream(Box::pin(
                    CoprStream::try_new(stream, copr, params, ctx)?,
                ))),
                _ => unreachable!(),
            }
        } else {
            let copr = self.copr.clone();
            let params = params.clone();
            let batch = spawn_blocking_script(move || exec_parsed(&copr, &None, &params, &ctx))
                .await
                .context(TokioJoinSnafu)??;
            let batches = RecordBatches::try_new(batch.schema.clone(), vec![batch]).unwrap();
            Ok(Output::new_with_record_batches(batches))
        }
    }
}
pub struct PyEngine {
    query_engine: QueryEngineRef,
}

impl PyEngine {
    pub fn new(query_engine: QueryEngineRef) -> Self {
        Self { query_engine }
    }
}

#[async_trait]
impl ScriptEngine for PyEngine {
    type Error = error::Error;
    type Script = PyScript;

    fn name(&self) -> &str {
        PY_ENGINE
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    async fn compile(&self, script: &str, _ctx: CompileContext) -> Result<PyScript> {
        let copr = Arc::new(parse::parse_and_compile_copr(
            script,
            Some(self.query_engine.clone()),
        )?);

        Ok(PyScript {
            copr,
            query_engine: self.query_engine.clone(),
        })
    }
}
#[cfg(test)]
pub(crate) use tests::sample_script_engine;

#[cfg(test)]
mod tests {
    use catalog::memory::MemoryCatalogManager;
    use common_catalog::consts::NUMBERS_TABLE_ID;
    use common_recordbatch::util;
    use datatypes::prelude::ScalarVector;
    use datatypes::value::Value;
    use datatypes::vectors::{Float64Vector, Int64Vector};
    use query::QueryEngineFactory;
    use table::table::numbers::NumbersTable;

    use super::*;

    pub(crate) fn sample_script_engine() -> PyEngine {
        let catalog_manager =
            MemoryCatalogManager::new_with_table(NumbersTable::table(NUMBERS_TABLE_ID));
        let query_engine =
            QueryEngineFactory::new(catalog_manager, None, None, None, None, false).query_engine();

        PyEngine::new(query_engine.clone())
    }

    #[tokio::test]
    async fn test_sql_in_py() {
        let script_engine = sample_script_engine();

        let script = r#"
import greptime as gt

@copr(args=["number"], returns = ["number"], sql = "select * from numbers")
def test(number) -> vector[u32]:
    from greptime import query
    return query().sql("select * from numbers")[0]
"#;
        let script = script_engine
            .compile(script, CompileContext::default())
            .await
            .unwrap();
        let output = script
            .execute(HashMap::default(), EvalContext::default())
            .await
            .unwrap();
        let res = common_recordbatch::util::collect_batches(match output.data {
            OutputData::Stream(s) => s,
            _ => unreachable!(),
        })
        .await
        .unwrap();
        let rb = res.iter().next().expect("One and only one recordbatch");
        assert_eq!(rb.column(0).len(), 100);
    }

    #[tokio::test]
    async fn test_user_params_in_py() {
        let script_engine = sample_script_engine();

        let script = r#"
@copr(returns = ["number"])
def test(**params) -> vector[i64]:
    return int(params['a']) + int(params['b'])
"#;
        let script = script_engine
            .compile(script, CompileContext::default())
            .await
            .unwrap();
        let params = HashMap::from([
            ("a".to_string(), "30".to_string()),
            ("b".to_string(), "12".to_string()),
        ]);
        let output = script
            .execute(params, EvalContext::default())
            .await
            .unwrap();
        let res = match output.data {
            OutputData::RecordBatches(s) => s,
            data => unreachable!("data: {data:?}"),
        };
        let rb = res.iter().next().expect("One and only one recordbatch");
        assert_eq!(rb.column(0).len(), 1);
        let result = rb.column(0).get(0);
        assert!(matches!(result, Value::Int64(42)));
    }

    #[tokio::test]
    async fn test_data_frame_in_py() {
        let script_engine = sample_script_engine();

        let script = r#"
from greptime import col

@copr(args=["number"], returns = ["number"], sql = "select * from numbers")
def test(number) -> vector[u32]:
    from greptime import PyDataFrame
    return PyDataFrame.from_sql("select * from numbers").filter(col("number")==col("number")).collect()[0][0]
"#;
        let script = script_engine
            .compile(script, CompileContext::default())
            .await
            .unwrap();
        let output = script
            .execute(HashMap::new(), EvalContext::default())
            .await
            .unwrap();
        let res = common_recordbatch::util::collect_batches(match output.data {
            OutputData::Stream(s) => s,
            data => unreachable!("data: {data:?}"),
        })
        .await
        .unwrap();
        let rb = res.iter().next().expect("One and only one recordbatch");
        assert_eq!(rb.column(0).len(), 100);
    }

    #[tokio::test]
    async fn test_compile_execute() {
        let script_engine = sample_script_engine();

        // To avoid divide by zero, the script divides `add(a, b)` by `g.sqrt(c + 1)` instead of `g.sqrt(c)`
        let script = r#"
import greptime as g
def add(a, b):
    return a + b;

@copr(args=["a", "b", "c"], returns = ["r"], sql="select number as a,number as b,number as c from numbers limit 100")
def test(a, b, c) -> vector[f64]:
    return add(a, b) / g.sqrt(c + 1)
"#;
        let script = script_engine
            .compile(script, CompileContext::default())
            .await
            .unwrap();
        let output = script
            .execute(HashMap::new(), EvalContext::default())
            .await
            .unwrap();
        match output.data {
            OutputData::Stream(stream) => {
                let numbers = util::collect(stream).await.unwrap();

                assert_eq!(1, numbers.len());
                let number = &numbers[0];
                assert_eq!(number.num_columns(), 1);
                assert_eq!("r", number.schema.column_schemas()[0].name);

                assert_eq!(1, number.num_columns());
                assert_eq!(100, number.column(0).len());
                let rows = number
                    .column(0)
                    .as_any()
                    .downcast_ref::<Float64Vector>()
                    .unwrap();
                assert_eq!(0f64, rows.get_data(0).unwrap());
                assert_eq!((99f64 + 99f64) / 100f64.sqrt(), rows.get_data(99).unwrap())
            }
            _ => unreachable!(),
        }

        // test list comprehension
        let script = r#"
import greptime as gt

@copr(args=["number"], returns = ["r"], sql="select number from numbers limit 100")
def test(a) -> vector[i64]:
    return gt.vector([x for x in a if x % 2 == 0])
"#;
        let script = script_engine
            .compile(script, CompileContext::default())
            .await
            .unwrap();
        let output = script
            .execute(HashMap::new(), EvalContext::default())
            .await
            .unwrap();
        match output.data {
            OutputData::Stream(stream) => {
                let numbers = util::collect(stream).await.unwrap();

                assert_eq!(1, numbers.len());
                let number = &numbers[0];
                assert_eq!(number.num_columns(), 1);
                assert_eq!("r", number.schema.column_schemas()[0].name);

                assert_eq!(1, number.num_columns());
                assert_eq!(50, number.column(0).len());
                let rows = number
                    .column(0)
                    .as_any()
                    .downcast_ref::<Int64Vector>()
                    .unwrap();
                assert_eq!(0, rows.get_data(0).unwrap());
                assert_eq!(98, rows.get_data(49).unwrap())
            }
            _ => unreachable!(),
        }
    }
}
@@ -1,238 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use common_error::ext::ErrorExt;
use common_error::status_code::StatusCode;
use common_macro::stack_trace_debug;
use console::{style, Style};
use datafusion::error::DataFusionError;
use datatypes::arrow::error::ArrowError;
use datatypes::error::Error as DataTypeError;
use query::error::Error as QueryError;
use rustpython_codegen::error::CodegenError;
use rustpython_parser::ast::Location;
use rustpython_parser::ParseError;
pub use snafu::ensure;
use snafu::prelude::Snafu;
use snafu::Location as SnafuLocation;

pub type Result<T> = std::result::Result<T, Error>;

pub(crate) fn ret_other_error_with(reason: String) -> OtherSnafu<String> {
    OtherSnafu { reason }
}

#[derive(Snafu)]
#[snafu(visibility(pub(crate)))]
#[stack_trace_debug]
pub enum Error {
    #[snafu(display("Datatype error"))]
    TypeCast {
        #[snafu(implicit)]
        location: SnafuLocation,
        source: DataTypeError,
    },

    #[snafu(display("Failed to query"))]
    DatabaseQuery {
        #[snafu(implicit)]
        location: SnafuLocation,
        source: QueryError,
    },

    #[snafu(display("Failed to parse script"))]
    PyParse {
        #[snafu(implicit)]
        location: SnafuLocation,
        #[snafu(source)]
        error: ParseError,
    },

    #[snafu(display("Failed to compile script"))]
    PyCompile {
        #[snafu(implicit)]
        location: SnafuLocation,
        #[snafu(source)]
        error: CodegenError,
    },

    /// A rustpython problem; use the python virtual machine's backtrace instead
    #[snafu(display("Python Runtime error, error: {}", msg))]
    PyRuntime {
        msg: String,
        #[snafu(implicit)]
        location: SnafuLocation,
    },

    #[snafu(display("Arrow error"))]
    Arrow {
        #[snafu(implicit)]
        location: SnafuLocation,
        #[snafu(source)]
        error: ArrowError,
    },

    #[snafu(display("DataFusion error"))]
    DataFusion {
        #[snafu(implicit)]
        location: SnafuLocation,
        #[snafu(source)]
        error: DataFusionError,
    },

    /// Errors from the coprocessor's parse checks, e.g. for types
    #[snafu(display("Coprocessor error: {} {}.", reason,
        if let Some(loc) = loc {
            format!("at {loc:?}")
        } else {
            "".into()
        }))]
    CoprParse {
        #[snafu(implicit)]
        location: SnafuLocation,
        reason: String,
        // the location is an `Option` because some errors may not have a clear location
        loc: Option<Location>,
    },

    /// Any other type of error that isn't covered above
    #[snafu(display("Coprocessor's Internal error: {}", reason))]
    Other {
        #[snafu(implicit)]
        location: SnafuLocation,
        reason: String,
    },

    #[snafu(display("Unsupported sql in coprocessor: {}", sql))]
    UnsupportedSql {
        sql: String,
        #[snafu(implicit)]
        location: SnafuLocation,
    },

    #[snafu(display("Failed to retrieve record batches"))]
    RecordBatch {
        #[snafu(implicit)]
        location: SnafuLocation,
        source: common_recordbatch::error::Error,
    },

    #[snafu(display("Failed to create record batch"))]
    NewRecordBatch {
        #[snafu(implicit)]
        location: SnafuLocation,
        source: common_recordbatch::error::Error,
    },

    #[snafu(display("Failed to create tokio task"))]
    TokioJoin {
        #[snafu(source)]
        error: tokio::task::JoinError,
    },
}

impl ErrorExt for Error {
    fn status_code(&self) -> StatusCode {
        match self {
            Error::DataFusion { .. }
            | Error::Arrow { .. }
            | Error::PyRuntime { .. }
            | Error::TokioJoin { .. }
            | Error::Other { .. } => StatusCode::Internal,

            Error::RecordBatch { source, .. } | Error::NewRecordBatch { source, .. } => {
                source.status_code()
            }
            Error::DatabaseQuery { source, .. } => source.status_code(),
            Error::TypeCast { source, .. } => source.status_code(),

            Error::PyParse { .. }
            | Error::PyCompile { .. }
            | Error::CoprParse { .. }
            | Error::UnsupportedSql { .. } => StatusCode::InvalidArguments,
        }
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

/// Pretty print an [`Error`] in the given script;
/// basically print an arrow pointing to where the error occurred (if a location is available)
pub fn pretty_print_error_in_src(
    script: &str,
    err: &Error,
    ln_offset: usize,
    filename: &str,
) -> String {
    let (reason, loc) = get_error_reason_loc(err);
    if let Some(loc) = loc {
        visualize_loc(script, &loc, &err.to_string(), &reason, ln_offset, filename)
    } else {
        // No location provided
        format!("\n{}: {}", style("error").red().bold(), err)
    }
}

/// Pretty print a location in the script with a description.
///
/// `ln_offset` is the line offset added to `loc`'s `row`; `filename` is the file name
/// displayed along with the row and column info.
pub fn visualize_loc(
    script: &str,
    loc: &Location,
    err_ty: &str,
    desc: &str,
    ln_offset: usize,
    filename: &str,
) -> String {
    let lines: Vec<&str> = script.split('\n').collect();
    let (row, col) = (loc.row(), loc.column());
    let red_bold = Style::new().red().bold();
    let blue_bold = Style::new().blue().bold();
    let col_space = (ln_offset + row).to_string().len().max(1);
    let space: String = " ".repeat(col_space - 1);
    let indicate = format!(
        "
{error}: {err_ty}
{space}{r_arrow}{filename}:{row}:{col}
{prow:col_space$}{ln_pad} {line}
{space} {ln_pad} {arrow:>pad$} {desc}
",
        error = red_bold.apply_to("error"),
        err_ty = style(err_ty).bold(),
        r_arrow = blue_bold.apply_to("-->"),
        filename = filename,
        row = ln_offset + row,
        col = col,
        line = lines[loc.row() - 1],
        pad = loc.column(),
        arrow = red_bold.apply_to("^"),
        desc = red_bold.apply_to(desc),
        ln_pad = blue_bold.apply_to("|"),
        prow = blue_bold.apply_to(ln_offset + row),
        space = space
    );
    indicate
}
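
// For illustration (not part of the original file): with `ln_offset = 0` and
// `filename = "copr.py"`, an error at row 3, column 5 renders roughly like:
//
//   error: Failed to parse script
//    -->copr.py:3:5
//   3 |     return a +
//     |     ^ invalid syntax
//
// The exact spacing depends on `col_space` computed above.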

/// Extract a reason string for an [`Error`]; also return a location if available
pub fn get_error_reason_loc(err: &Error) -> (String, Option<Location>) {
    match err {
        Error::CoprParse { reason, loc, .. } => (reason.clone(), *loc),
        Error::Other { reason, .. } => (reason.clone(), None),
        Error::PyRuntime { msg, .. } => (msg.clone(), None),
        Error::PyParse { error, .. } => (error.error.to_string(), Some(error.location)),
        Error::PyCompile { error, .. } => (error.error.to_string(), Some(error.location)),
        _ => (format!("Unknown error: {err:?}"), None),
    }
}
@@ -1,22 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

pub(crate) mod copr;
pub(crate) mod py_recordbatch;
pub(crate) mod utils;
pub(crate) mod vector;

pub(crate) use copr::{check_args_anno_real_type, select_from_rb, Coprocessor};
pub(crate) use vector::{PyVector, PyVectorRef};

#[cfg(test)]
mod pair_tests;
@@ -1,533 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

pub mod compile;
pub mod parse;

use std::collections::HashMap;
use std::result::Result as StdResult;
use std::sync::{Arc, Weak};

use common_query::OutputData;
use common_recordbatch::{RecordBatch, RecordBatches};
use datafusion_common::ScalarValue;
use datatypes::arrow::compute;
use datatypes::data_type::{ConcreteDataType, DataType};
use datatypes::prelude::Value;
use datatypes::schema::{ColumnSchema, Schema, SchemaRef};
use datatypes::vectors::{Helper, VectorRef};
// use crate::python::builtins::greptime_builtin;
use parse::DecoratorArgs;
#[cfg(feature = "pyo3_backend")]
use pyo3::pyclass as pyo3class;
use query::parser::QueryLanguageParser;
use query::QueryEngine;
use rustpython_compiler_core::CodeObject;
use rustpython_vm as vm;
#[cfg(test)]
use serde::Deserialize;
use session::context::{QueryContextBuilder, QueryContextRef};
use snafu::{OptionExt, ResultExt};
use vm::convert::ToPyObject;
use vm::{pyclass as rspyclass, PyObjectRef, PyPayload, PyResult, VirtualMachine};

use super::py_recordbatch::PyRecordBatch;
use crate::engine::EvalContext;
use crate::python::error::{
    ensure, ArrowSnafu, DataFusionSnafu, OtherSnafu, Result, TypeCastSnafu,
};
use crate::python::ffi_types::PyVector;
#[cfg(feature = "pyo3_backend")]
use crate::python::pyo3::pyo3_exec_parsed;
use crate::python::rspython::rspy_exec_parsed;

#[cfg_attr(test, derive(Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AnnotationInfo {
    /// if None, use the types inferred by PyVector
    // TODO(yingwen): We should use our data type, i.e. ConcreteDataType.
    pub datatype: Option<ConcreteDataType>,
    pub is_nullable: bool,
}

#[cfg_attr(test, derive(Deserialize))]
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub enum BackendType {
    #[default]
    RustPython,
    // TODO(discord9): integration test
    #[allow(unused)]
    CPython,
}

pub type CoprocessorRef = Arc<Coprocessor>;

#[cfg_attr(test, derive(Deserialize))]
#[derive(Debug, Clone)]
pub struct Coprocessor {
    pub name: String,
    pub deco_args: DecoratorArgs,
    /// from the python function args' annotations: first is the type, second is is_nullable
    pub arg_types: Vec<Option<AnnotationInfo>>,
    /// from the python function returns' annotations: first is the type, second is is_nullable
    pub return_types: Vec<Option<AnnotationInfo>>,
    /// kwargs in the coprocessor function's signature
    pub kwarg: Option<String>,
    /// stores the corresponding script; also skip serde in `cfg(test)` to reduce comparison work
    #[cfg_attr(test, serde(skip))]
    pub script: String,
    // We must use an option here, because we use `serde` to deserialize a coprocessor
    // from a ron file and `Deserialize` requires Coprocessor to implement the `Default` trait,
    // but CodeObject doesn't.
    #[cfg_attr(test, serde(skip))]
    pub code_obj: Option<CodeObject>,
    #[cfg_attr(test, serde(skip))]
    pub query_engine: Option<QueryEngineWeakRef>,
    /// Which backend to use to run this script.
    /// Ideally both backends should be tested in tests, so skip this.
    #[cfg_attr(test, serde(skip))]
    pub backend: BackendType,
}

#[derive(Clone)]
pub struct QueryEngineWeakRef(pub Weak<dyn QueryEngine>);

impl From<Weak<dyn QueryEngine>> for QueryEngineWeakRef {
    fn from(value: Weak<dyn QueryEngine>) -> Self {
        Self(value)
    }
}

impl From<&Arc<dyn QueryEngine>> for QueryEngineWeakRef {
    fn from(value: &Arc<dyn QueryEngine>) -> Self {
        Self(Arc::downgrade(value))
    }
}

impl std::fmt::Debug for QueryEngineWeakRef {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("QueryEngineWeakRef")
            .field(&self.0.upgrade().map(|f| f.name().to_string()))
            .finish()
    }
}

impl PartialEq for Coprocessor {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
            && self.deco_args == other.deco_args
            && self.arg_types == other.arg_types
            && self.return_types == other.return_types
            && self.script == other.script
    }
}

impl Eq for Coprocessor {}

impl Coprocessor {
    /// Generate a [`Schema`] according to the return names and types;
    /// if there is no annotation,
    /// the datatypes of the actual columns are used directly
    pub(crate) fn gen_schema(&self, cols: &[VectorRef]) -> Result<SchemaRef> {
        let names = &self.deco_args.ret_names;
        let anno = &self.return_types;
        ensure!(
            cols.len() == names.len() && names.len() == anno.len(),
            OtherSnafu {
                reason: format!(
                    "Unmatched length for cols({}), names({}) and annotation({})",
                    cols.len(),
                    names.len(),
                    anno.len()
                )
            }
        );

        let column_schemas = names
            .iter()
            .enumerate()
            .map(|(idx, name)| {
                let real_ty = cols[idx].data_type();
                let AnnotationInfo {
                    datatype: ty,
                    is_nullable,
                } = anno[idx].clone().unwrap_or_else(|| {
                    // default to not nullable and use the DataType inferred by PyVector itself
                    AnnotationInfo {
                        datatype: Some(real_ty.clone()),
                        is_nullable: false,
                    }
                });
                let column_type = match ty {
                    Some(anno_type) => anno_type,
                    // if the type is like `_` or `_ | None`
                    None => real_ty,
                };
                Ok(ColumnSchema::new(name, column_type, is_nullable))
            })
            .collect::<Result<Vec<_>>>()?;

        Ok(Arc::new(Schema::new(column_schemas)))
    }

    /// Check if the real types and the annotation types (if any) are the same;
    /// if not, try to cast the columns to the annotated types
    pub(crate) fn check_and_cast_type(&self, cols: &mut [VectorRef]) -> Result<()> {
        for col in cols.iter_mut() {
            if let ConcreteDataType::List(x) = col.data_type() {
                let values =
                    ScalarValue::convert_array_to_scalar_vec(col.to_arrow_array().as_ref())
                        .context(DataFusionSnafu)?
                        .into_iter()
                        .flatten()
                        .map(Value::try_from)
                        .collect::<std::result::Result<Vec<_>, _>>()
                        .context(TypeCastSnafu)?;

                let mut builder = x.item_type().create_mutable_vector(values.len());
                for v in values.iter() {
                    builder.push_value_ref(v.as_value_ref());
                }
                *col = builder.to_vector();
            }
        }

        let return_types = &self.return_types;
        // allow omitting the return type annotation
        if return_types.is_empty() {
            return Ok(());
        }
        ensure!(
            cols.len() == return_types.len(),
            OtherSnafu {
                reason: format!(
                    "The number of return Vector is wrong, expect {}, found {}",
                    return_types.len(),
                    cols.len()
                )
            }
        );
        for (col, anno) in cols.iter_mut().zip(return_types) {
            if let Some(AnnotationInfo {
                datatype: Some(datatype),
                is_nullable: _,
            }) = anno
            {
                let real_ty = col.data_type();
                let anno_ty = datatype;
                if real_ty != *anno_ty {
                    let array = col.to_arrow_array();
                    let array =
                        compute::cast(&array, &anno_ty.as_arrow_type()).context(ArrowSnafu)?;
                    *col = Helper::try_into_vector(array).context(TypeCastSnafu)?;
                }
            }
        }
        Ok(())
    }
}

/// Select columns according to `fetch_names` from `rb`
/// and cast them into a Vec of PyVector
pub(crate) fn select_from_rb(rb: &RecordBatch, fetch_names: &[String]) -> Result<Vec<PyVector>> {
    fetch_names
        .iter()
        .map(|name| {
            let vector = rb.column_by_name(name).with_context(|| OtherSnafu {
                reason: format!("Can't find field name {name} in all columns in {rb:?}"),
            })?;
            Ok(PyVector::from(vector.clone()))
        })
        .collect()
}

/// Match the arguments' real types against the annotation types;
/// if a type annotation is `vector[_]`, the real type (from the RecordBatch's schema) is used
pub(crate) fn check_args_anno_real_type(
    arg_names: &[String],
    args: &[PyVector],
    copr: &Coprocessor,
    rb: &RecordBatch,
) -> Result<()> {
    ensure!(
        arg_names.len() == args.len(),
        OtherSnafu {
            reason: format!("arg_names:{arg_names:?} and args{args:?}'s length is different")
        }
    );
    for (idx, arg) in args.iter().enumerate() {
        let anno_ty = copr.arg_types[idx].clone();
        let real_ty = arg.data_type();
        let arg_name = arg_names[idx].clone();
        let col_idx = rb.schema.column_index_by_name(&arg_name).ok_or_else(|| {
            OtherSnafu {
                reason: format!("Can't find column by name {arg_name}"),
            }
            .build()
        })?;
        let is_nullable: bool = rb.schema.column_schemas()[col_idx].is_nullable();
        ensure!(
            anno_ty
                .clone()
                .map(|v| v.datatype.is_none() // like a vector[_]
                    || v.datatype == Some(real_ty.clone()) && v.is_nullable == is_nullable)
                .unwrap_or(true),
            OtherSnafu {
                reason: format!(
                    "column {}'s Type annotation is {:?}, but actual type is {:?} with nullable=={}",
                    // It's safe to unwrap here; we already ensured the args and types numbers are the same when parsing
                    copr.deco_args.arg_names.as_ref().unwrap()[idx],
                    anno_ty,
                    real_ty,
                    is_nullable
                )
            }
        )
    }
    Ok(())
}

/// The coprocessor function accepts a python script and a RecordBatch:
/// ## What it does
/// 1. it takes a python script and a [`RecordBatch`], extracting columns and annotation info according to `args` given in the decorator in the python script
/// 2. it executes the python code and returns a vector or a tuple of vectors,
/// 3. the returned vector(s) are assembled into a new [`RecordBatch`] according to `returns` in the python decorator and returned to the caller
///
/// # Example
///
/// ```ignore
/// use std::sync::Arc;
/// use common_recordbatch::RecordBatch;
/// use datatypes::prelude::*;
/// use datatypes::schema::{ColumnSchema, Schema};
/// use datatypes::vectors::{Float32Vector, Float64Vector};
/// use common_function::scalars::python::exec_coprocessor;
/// let python_source = r#"
/// @copr(args=["cpu", "mem"], returns=["perf", "what"])
/// def a(cpu, mem):
///     return cpu + mem, cpu - mem
/// "#;
/// let cpu_array = Float32Vector::from_slice([0.9f32, 0.8, 0.7, 0.6]);
/// let mem_array = Float64Vector::from_slice([0.1f64, 0.2, 0.3, 0.4]);
/// let schema = Arc::new(Schema::new(vec![
///     ColumnSchema::new("cpu", ConcreteDataType::float32_datatype(), false),
///     ColumnSchema::new("mem", ConcreteDataType::float64_datatype(), false),
/// ]));
/// let rb =
///     RecordBatch::new(schema, vec![Arc::new(cpu_array), Arc::new(mem_array)]).unwrap();
/// let ret = exec_coprocessor(python_source, &rb).unwrap();
/// assert_eq!(ret.column(0).len(), 4);
/// ```
///
/// # Type Annotation
/// You can use type annotations in `args` and `returns` to designate types, so the coprocessor will check for the corresponding types.
///
/// Currently supported types are `u8`, `u16`, `u32`, `u64`, `i8`, `i16`, `i32`, `i64` and `f16`, `f32`, `f64`.
///
/// Use `f64 | None` to mark a returned column as nullable, like the `is_nullable` of a [`ColumnSchema`] in a [`RecordBatch`]'s schema.
///
/// You can also use a single underscore `_` to let the coprocessor infer the type, so `_` and `_ | None` are both valid in type annotations.
/// Note: using `_` means a not-nullable column, while `_ | None` means a nullable column.
///
/// An example (Python script) is given below:
/// ```python
/// @copr(args=["cpu", "mem"], returns=["perf", "minus", "mul", "div"])
/// def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None], vector[f64], vector[_], vector[_ | None]):
///     return cpu + mem, cpu - mem, cpu * mem, cpu / mem
/// ```
///
/// # Return Constant columns
/// You can return constants in python code like `return 1, 1.0, True`,
/// which creates a constant array (with the same value; currently int, float and bool are supported) as the returned column
#[cfg(test)]
pub fn exec_coprocessor(
    script: &str,
    rb: &Option<RecordBatch>,
    eval_ctx: &EvalContext,
) -> Result<RecordBatch> {
    // 1. parse the script and check if it's only a function with a `@coprocessor` decorator, and get `args` and `returns`,
    // 2. also check that `args` exist in `rb`; if not found, return an error
    // cache the result of parse_copr
    let copr = parse::parse_and_compile_copr(script, None)?;
    exec_parsed(&copr, rb, &HashMap::new(), eval_ctx)
}

#[cfg_attr(feature = "pyo3_backend", pyo3class(name = "query_engine"))]
#[rspyclass(module = false, name = "query_engine")]
#[derive(Debug, PyPayload, Clone)]
pub struct PyQueryEngine {
    inner: QueryEngineWeakRef,
    query_ctx: QueryContextRef,
}

pub(crate) enum Either {
    Rb(RecordBatches),
    AffectedRows(usize),
}

impl PyQueryEngine {
    pub(crate) fn sql_to_rb(&self, sql: String) -> StdResult<RecordBatch, String> {
        let res = self.query_with_new_thread(sql.clone())?;
        match res {
            Either::Rb(rbs) => {
                let rb = compute::concat_batches(
                    rbs.schema().arrow_schema(),
                    rbs.iter().map(|r| r.df_record_batch()),
                )
                .map_err(|e| format!("Concat batches failed for query {sql}: {e}"))?;

                RecordBatch::try_from_df_record_batch(rbs.schema(), rb)
                    .map_err(|e| format!("Convert datafusion record batch to record batch failed for query {sql}: {e}"))
            }
            Either::AffectedRows(_) => Err(format!("Expect actual results from query {sql}")),
        }
    }
}

#[rspyclass]
impl PyQueryEngine {
    pub(crate) fn from_weakref(inner: QueryEngineWeakRef, query_ctx: QueryContextRef) -> Self {
        Self { inner, query_ctx }
    }
    pub(crate) fn query_with_new_thread(&self, s: String) -> StdResult<Either, String> {
        let query = self.inner.0.upgrade();
        let query_ctx = self.query_ctx.clone();
        let thread_handle = std::thread::spawn(move || -> std::result::Result<_, String> {
            if let Some(engine) = query {
                let stmt =
                    QueryLanguageParser::parse_sql(&s, &query_ctx).map_err(|e| e.to_string())?;

                // To prevent the error of creating a nested Runtime: if nested, the parent runtime should be used instead

                let rt = tokio::runtime::Runtime::new().map_err(|e| e.to_string())?;
                let handle = rt.handle().clone();
                let res = handle.block_on(async {
                    let ctx = Arc::new(QueryContextBuilder::default().build());
                    let plan = engine
                        .planner()
                        .plan(&stmt, ctx.clone())
                        .await
                        .map_err(|e| e.to_string())?;
                    let res = engine
                        .clone()
                        .execute(plan, ctx)
                        .await
                        .map_err(|e| e.to_string());
                    match res {
                        Ok(o) => match o.data {
                            OutputData::AffectedRows(cnt) => Ok(Either::AffectedRows(cnt)),
                            OutputData::RecordBatches(rbs) => Ok(Either::Rb(rbs)),
                            OutputData::Stream(s) => Ok(Either::Rb(
                                common_recordbatch::util::collect_batches(s).await.unwrap(),
                            )),
                        },

                        Err(e) => Err(e),
                    }
                })?;
                Ok(res)
            } else {
                Err("Query Engine is already dropped".to_string())
            }
        });
        thread_handle
            .join()
            .map_err(|e| format!("Dedicated thread for sql query panicked: {e:?}"))?
    }
    // TODO(discord9): find a better way to call the sql query api; for now we don't know whether we are in an async context or not
    /// - return sql query results as a `PyRecordBatch`, or
    /// - an empty `PyDict` if the query results are empty,
    /// - or the number of AffectedRows
    #[pymethod]
    fn sql(&self, s: String, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        self.query_with_new_thread(s)
            .map_err(|e| vm.new_system_error(e))
            .map(|rbs| match rbs {
                Either::Rb(rbs) => {
                    let rb = compute::concat_batches(
                        rbs.schema().arrow_schema(),
                        rbs.iter().map(|rb| rb.df_record_batch()),
                    )
                    .map_err(|e| {
                        vm.new_runtime_error(format!("Failed to concat batches: {e:#?}"))
                    })?;
                    let rb =
                        RecordBatch::try_from_df_record_batch(rbs.schema(), rb).map_err(|e| {
                            vm.new_runtime_error(format!("Failed to cast recordbatch: {e:#?}"))
                        })?;
                    let rb = PyRecordBatch::new(rb);

                    Ok(rb.to_pyobject(vm))
                }
                Either::AffectedRows(cnt) => Ok(vm.ctx.new_int(cnt).to_pyobject(vm)),
            })?
    }
}
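
// Illustrative only (it mirrors the engine tests; not part of the original file):
// from inside a coprocessor script, the engine above is reachable roughly as
//
//   from greptime import query
//   rb = query().sql("select * from numbers")   # a PyRecordBatch, or an int for AffectedRows
//   first_column = rb[0]
//
// `query()` is assumed here to return the `query_engine` object exposed by this module.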

/// Use a parsed `Coprocessor` struct as input to execute the python code
pub fn exec_parsed(
    copr: &Coprocessor,
    rb: &Option<RecordBatch>,
    params: &HashMap<String, String>,
    eval_ctx: &EvalContext,
) -> Result<RecordBatch> {
    match copr.backend {
        BackendType::RustPython => rspy_exec_parsed(copr, rb, params, eval_ctx),
        BackendType::CPython => {
            #[cfg(feature = "pyo3_backend")]
            {
                pyo3_exec_parsed(copr, rb, params, eval_ctx)
            }
            #[cfg(not(feature = "pyo3_backend"))]
            {
                OtherSnafu {
                    reason: "`pyo3` feature is disabled, therefore can't run scripts in cpython"
                        .to_string(),
                }
                .fail()
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::python::ffi_types::copr::parse::parse_and_compile_copr;

    #[test]
    fn test_parse_copr() {
        let script = r#"
def add(a, b):
    return a + b

@copr(args=["a", "b", "c"], returns = ["r"], sql="select number as a,number as b,number as c from numbers limit 100")
def test(a, b, c, **params):
    import greptime as g
    return ( a + b ) / g.sqrt(c)
"#;

        let copr = parse_and_compile_copr(script, None).unwrap();
        assert_eq!(copr.name, "test");
        let deco_args = copr.deco_args.clone();
        assert_eq!(
            deco_args.sql.unwrap(),
            "select number as a,number as b,number as c from numbers limit 100"
        );
        assert_eq!(deco_args.ret_names, vec!["r"]);
        assert_eq!(deco_args.arg_names.unwrap(), vec!["a", "b", "c"]);
        assert_eq!(copr.arg_types, vec![None, None, None]);
        assert_eq!(copr.return_types, vec![None]);
        assert_eq!(copr.kwarg, Some("params".to_string()));
        assert_eq!(copr.script, script);
        let _ = copr.code_obj.unwrap();
    }
}
@@ -1,165 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! compile script to code object
use rustpython_codegen::compile::compile_top;
use rustpython_compiler::{CompileOpts, Mode};
use rustpython_compiler_core::CodeObject;
use rustpython_parser::ast::{ArgData, Located, Location};
use rustpython_parser::{ast, parse, Mode as ParseMode};
use snafu::ResultExt;

use crate::fail_parse_error;
use crate::python::error::{PyCompileSnafu, PyParseSnafu, Result};
use crate::python::ffi_types::copr::parse::{ret_parse_error, DecoratorArgs};

fn create_located<T>(node: T, loc: Location) -> Located<T> {
    Located::new(loc, loc, node)
}

/// Generate a call to the coprocessor function
/// with the arguments given in the decorator's `args` list,
/// and set its location in the source code to `loc`
fn gen_call(
    name: &str,
    deco_args: &DecoratorArgs,
    kwarg: &Option<String>,
    loc: &Location,
) -> ast::Stmt<()> {
    let mut loc = *loc;
    // Add a line to avoid confusion if any error occurs when calling the function:
    // the pretty print will then point to the last line in the code
    // instead of pointing to any of the existing code written by the user.
    loc.newline();
    let mut args: Vec<Located<ast::ExprKind>> = if let Some(arg_names) = &deco_args.arg_names {
        arg_names
            .iter()
            .map(|v| {
                let node = ast::ExprKind::Name {
                    id: v.clone(),
                    ctx: ast::ExprContext::Load,
                };
                create_located(node, loc)
            })
            .collect()
    } else {
        vec![]
    };

    if let Some(kwarg) = kwarg {
        let node = ast::ExprKind::Name {
            id: kwarg.clone(),
            ctx: ast::ExprContext::Load,
        };
        args.push(create_located(node, loc));
    }

    let func = ast::ExprKind::Call {
        func: Box::new(create_located(
            ast::ExprKind::Name {
                id: name.to_string(),
                ctx: ast::ExprContext::Load,
            },
            loc,
        )),
        args,
        keywords: Vec::new(),
    };
    let stmt = ast::StmtKind::Expr {
        value: Box::new(create_located(func, loc)),
    };
    create_located(stmt, loc)
}

/// Strip the decorator (`@xxxx`) and the type annotations (type checking is done in Rust functions),
/// add one line to the AST to call the function with the given parameters, and compile it into a `CodeObject`.
///
/// The rationale is that rustpython's vm is not very efficient according to the [official benchmark](https://rustpython.github.io/benchmarks),
/// so we should avoid running too much Python bytecode. Hence in this function we delete the `@` decorator (instead of actually writing a decorator in python),
/// add a function call at the end, and also
/// strip the type annotations.
pub fn compile_script(
    name: &str,
    deco_args: &DecoratorArgs,
    kwarg: &Option<String>,
    script: &str,
) -> Result<CodeObject> {
    // note that it's important to use `parser::Mode::Interactive` so the AST can be compiled to return a result instead of returning None as in eval mode
    let mut top = parse(script, ParseMode::Interactive, "<embedded>").context(PyParseSnafu)?;
    // erase decorator
    if let ast::Mod::Interactive { body } = &mut top {
        let stmts = body;
        let mut loc = None;
        for stmt in stmts.iter_mut() {
            if let ast::StmtKind::FunctionDef {
                name: _,
                args,
                body: _,
                decorator_list,
                returns,
                type_comment: __main__,
            } = &mut stmt.node
            {
                // Rewrite kwargs in the coprocessor, making it a positional argument
                if !decorator_list.is_empty() {
                    if let Some(kwarg) = kwarg {
                        args.kwarg = None;
                        let node = ArgData {
                            arg: kwarg.clone(),
                            annotation: None,
                            type_comment: Some("kwargs".to_string()),
                        };
                        let kwarg = create_located(node, stmt.location);
                        args.args.push(kwarg);
                    }
                }

                *decorator_list = Vec::new();
                // strip type annotations:
                // def a(b: int, c: int) -> int
                // will become
                // def a(b, c)
                *returns = None;
                for arg in &mut args.args {
                    arg.node.annotation = None;
                }
            } else if matches!(
                stmt.node,
                ast::StmtKind::Import { .. } | ast::StmtKind::ImportFrom { .. }
            ) {
                // import statements are allowed.
            } else {
                // already checked in the parser
                unreachable!()
            }
            loc = Some(stmt.location);

            // This manually constructed AST has no corresponding code
            // in the script, so just give it a location that doesn't exist in the original script
            // (which doesn't matter because Location is usually only used to pretty print errors)
        }
        // Append a statement calling the coprocessor function.
        // It's safe to unwrap loc; it always exists.
        stmts.push(gen_call(name, deco_args, kwarg, &loc.unwrap()));
    } else {
        return fail_parse_error!(format!("Expect statement in script, found: {top:?}"), None);
    }
    // use `compile::Mode::BlockExpr` so it returns the result of the statement
    compile_top(
        &top,
        "<embedded>".to_string(),
        Mode::BlockExpr,
        CompileOpts { optimize: 0 },
    )
    .context(PyCompileSnafu)
}
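
// Illustration (not in the original file): for a script like
//
//   @copr(args=["a"], returns=["r"])
//   def test(a) -> vector[f64]:
//       return a
//
// `compile_script("test", ..)` effectively compiles the equivalent of
//
//   def test(a):
//       return a
//   test(a)
//
// i.e. the decorator and annotations are stripped and a trailing call is appended.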
@@ -1,552 +0,0 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use query::QueryEngineRef;
|
||||
use rustpython_parser::ast::{Arguments, Location};
|
||||
use rustpython_parser::{ast, parse_program};
|
||||
#[cfg(test)]
|
||||
use serde::Deserialize;
|
||||
use snafu::{OptionExt, ResultExt};
|
||||
|
||||
use crate::python::error::{ensure, CoprParseSnafu, PyParseSnafu, Result};
|
||||
use crate::python::ffi_types::copr::{compile, AnnotationInfo, BackendType, Coprocessor};
|
||||
#[cfg_attr(test, derive(Deserialize))]
|
||||
#[derive(Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct DecoratorArgs {
|
||||
pub arg_names: Option<Vec<String>>,
|
||||
pub ret_names: Vec<String>,
|
||||
pub sql: Option<String>,
|
||||
#[cfg_attr(test, serde(skip))]
|
||||
pub backend: BackendType, // maybe add a URL for connecting or what?
|
||||
// also predicate for timed triggered or conditional triggered?
|
||||
}
|
||||
|
||||
/// Return a CoprParseSnafu for you to chain fail() to return correct err Result type
|
||||
pub(crate) fn ret_parse_error(
|
||||
reason: String,
|
||||
loc: Option<Location>,
|
||||
) -> CoprParseSnafu<String, Option<Location>> {
|
||||
CoprParseSnafu { reason, loc }
|
||||
}
|
||||
|
||||
/// append a `.fail()` after `ret_parse_error`, so compiler can return a Err(this error)
|
||||
#[macro_export]
|
||||
macro_rules! fail_parse_error {
|
||||
($reason:expr, $loc:expr $(,)*) => {
|
||||
ret_parse_error($reason, $loc).fail()
|
||||
};
|
||||
}
|
||||
|
||||
fn py_str_to_string(s: &ast::Expr<()>) -> Result<String> {
|
||||
if let ast::ExprKind::Constant {
|
||||
value: ast::Constant::Str(v),
|
||||
kind: _,
|
||||
} = &s.node
|
||||
{
|
||||
Ok(v.clone())
|
||||
} else {
|
||||
fail_parse_error!(
|
||||
format!(
|
||||
"Expect a list of String, found one element to be: \n{:#?}",
|
||||
&s.node
|
||||
),
|
||||
Some(s.location)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// turn a python list of string in ast form(a `ast::Expr`) of string into a `Vec<String>`
|
||||
fn pylist_to_vec(lst: &ast::Expr<()>) -> Result<Vec<String>> {
|
||||
if let ast::ExprKind::List { elts, ctx: _ } = &lst.node {
|
||||
let ret = elts.iter().map(py_str_to_string).collect::<Result<_>>()?;
|
||||
Ok(ret)
|
||||
} else {
|
||||
fail_parse_error!(
|
||||
format!("Expect a list, found \n{:#?}", &lst.node),
|
||||
Some(lst.location)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn try_into_datatype(ty: &str, loc: &Location) -> Result<Option<ConcreteDataType>> {
|
||||
match ty {
|
||||
"bool" => Ok(Some(ConcreteDataType::boolean_datatype())),
|
||||
"u8" => Ok(Some(ConcreteDataType::uint8_datatype())),
|
||||
"u16" => Ok(Some(ConcreteDataType::uint16_datatype())),
|
||||
"u32" => Ok(Some(ConcreteDataType::uint32_datatype())),
|
||||
"u64" => Ok(Some(ConcreteDataType::uint64_datatype())),
|
||||
"i8" => Ok(Some(ConcreteDataType::int8_datatype())),
|
||||
"i16" => Ok(Some(ConcreteDataType::int16_datatype())),
|
||||
"i32" => Ok(Some(ConcreteDataType::int32_datatype())),
|
||||
"i64" => Ok(Some(ConcreteDataType::int64_datatype())),
|
||||
"f32" => Ok(Some(ConcreteDataType::float32_datatype())),
|
||||
"f64" => Ok(Some(ConcreteDataType::float64_datatype())),
|
||||
"str" => Ok(Some(ConcreteDataType::string_datatype())),
|
||||
// for any datatype
|
||||
"_" => Ok(None),
|
||||
// note the different between "_" and _
|
||||
_ => fail_parse_error!(format!("Unknown datatype: {ty} at {loc:?}"), Some(*loc)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Item => NativeType
|
||||
/// default to be not nullable
|
||||
fn parse_native_type(sub: &ast::Expr<()>) -> Result<AnnotationInfo> {
|
||||
match &sub.node {
|
||||
ast::ExprKind::Name { id, .. } => Ok(AnnotationInfo {
|
||||
datatype: try_into_datatype(id, &sub.location)?,
|
||||
is_nullable: false,
|
||||
}),
|
||||
_ => fail_parse_error!(
|
||||
format!("Expect types' name, found \n{:#?}", &sub.node),
|
||||
Some(sub.location)
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// check if binary op expr is legal(with one typename and one `None`)
|
||||
fn check_bin_op(bin_op: &ast::Expr<()>) -> Result<()> {
|
||||
if let ast::ExprKind::BinOp { left, op: _, right } = &bin_op.node {
|
||||
// 1. first check if this BinOp is legal(Have one typename and(optional) a None)
|
||||
let is_none = |node: &ast::Expr<()>| -> bool {
|
||||
matches!(
|
||||
&node.node,
|
||||
ast::ExprKind::Constant {
|
||||
value: ast::Constant::None,
|
||||
kind: _,
|
||||
}
|
||||
)
|
||||
};
|
||||
let is_type = |node: &ast::Expr<()>| {
|
||||
if let ast::ExprKind::Name { id, ctx: _ } = &node.node {
|
||||
try_into_datatype(id, &node.location).is_ok()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
let left_is_ty = is_type(left);
|
||||
let left_is_none = is_none(left);
|
||||
let right_is_ty = is_type(right);
|
||||
let right_is_none = is_none(right);
|
||||
if left_is_ty && right_is_ty || left_is_none && right_is_none {
|
||||
fail_parse_error!(
|
||||
"Expect one typenames and one `None`".to_string(),
|
||||
Some(bin_op.location)
|
||||
)?;
|
||||
} else if !(left_is_none && right_is_ty || left_is_ty && right_is_none) {
|
||||
fail_parse_error!(
|
||||
format!(
|
||||
"Expect a type name and a `None`, found left: \n{:#?} \nand right: \n{:#?}",
|
||||
&left.node, &right.node
|
||||
),
|
||||
Some(bin_op.location)
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
} else {
|
||||
fail_parse_error!(
|
||||
format!(
|
||||
"Expect binary ops like `DataType | None`, found \n{:#?}",
|
||||
bin_op.node
|
||||
),
|
||||
Some(bin_op.location)
|
||||
)
|
||||
}
|
||||
}

/// parse a `DataType | None` or a single `DataType`
fn parse_bin_op(bin_op: &ast::Expr<()>) -> Result<AnnotationInfo> {
    // 1. first check if this BinOp is legal (has one type name and, optionally, a None)
    check_bin_op(bin_op)?;
    if let ast::ExprKind::BinOp { left, op: _, right } = &bin_op.node {
        // then get types from this BinOp
        let left_ty = parse_native_type(left).ok();
        let right_ty = parse_native_type(right).ok();
        let mut ty_anno = if let Some(left_ty) = left_ty {
            left_ty
        } else if let Some(right_ty) = right_ty {
            right_ty
        } else {
            // handle the error anyway, in case the code above changes but this branch is not updated
            return fail_parse_error!(
                "Expect a type name, not two `None`".into(),
                Some(bin_op.location),
            );
        };
        // check_bin_op ensures a `None` exists
        ty_anno.is_nullable = true;
        return Ok(ty_anno);
    }
    unreachable!()
}

/// check the grammar correctness of an annotation, and return the subscript slice for further parsing
fn check_annotation_ret_slice(sub: &ast::Expr<()>) -> Result<&ast::Expr<()>> {
    // TODO(discord9): allow a single annotation like `vector`
    if let ast::ExprKind::Subscript {
        value,
        slice,
        ctx: _,
    } = &sub.node
    {
        if let ast::ExprKind::Name { id, ctx: _ } = &value.node {
            ensure!(
                id == "vector",
                ret_parse_error(
                    format!("Wrong type annotation, expect `vector[...]`, found `{id}`"),
                    Some(value.location)
                )
            );
        } else {
            return fail_parse_error!(
                format!("Expect \"vector\", found \n{:#?}", &value.node),
                Some(value.location)
            );
        }
        Ok(slice)
    } else {
        fail_parse_error!(
            format!("Expect type annotation, found \n{:#?}", &sub),
            Some(sub.location)
        )
    }
}

/// where:
///
/// Start => vector`[`TYPE`]`
///
/// TYPE => Item | Item `|` None
///
/// Item => NativeType
fn parse_annotation(sub: &ast::Expr<()>) -> Result<AnnotationInfo> {
    let slice = check_annotation_ret_slice(sub)?;

    {
        // i.e: vector[f64]
        match &slice.node {
            ast::ExprKind::Name { .. } => parse_native_type(slice),
            ast::ExprKind::BinOp {
                left: _,
                op: _,
                right: _,
            } => parse_bin_op(slice),
            _ => {
                fail_parse_error!(
                    format!("Expect type in `vector[...]`, found \n{:#?}", &slice.node),
                    Some(slice.location),
                )
            }
        }
    }
}
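// Annotations accepted by the grammar above, shown as an illustrative Python
// snippet embedded in a raw string (a sketch; the variable names are made up):
const ANNOTATION_EXAMPLES: &str = r#"
v1: vector[f64]          # non-nullable float64 column
v2: vector[i64 | None]   # nullable int64 column (the `| None` sets is_nullable)
v3: vector[_]            # any datatype
"#;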

/// parse a list of keywords and return the args and returns lists extracted from them
fn parse_keywords(keywords: &Vec<ast::Keyword<()>>) -> Result<DecoratorArgs> {
    // more keys may be added to this list of `avail_key` (like `sql` for querying, or maybe a config for connecting to the database?); a `HashSet` is used here for easier extension
    let avail_key = HashSet::from(["args", "returns", "sql", "backend"]);
    let opt_keys = HashSet::from(["sql", "args", "backend"]);
    let mut visited_key = HashSet::new();
    let len_min = avail_key.len() - opt_keys.len();
    let len_max = avail_key.len();
    ensure!(
        // "sql" is optional (for now)
        keywords.len() >= len_min && keywords.len() <= len_max,
        CoprParseSnafu {
            reason: format!(
                "Expect between {len_min} and {len_max} keyword arguments, found {}.",
                keywords.len()
            ),
            loc: keywords.first().map(|s| s.location)
        }
    );
    let mut ret_args = DecoratorArgs::default();
    for kw in keywords {
        match &kw.node.arg {
            Some(s) => {
                let s = s.as_str();
                if visited_key.contains(s) {
                    return fail_parse_error!(
                        format!("`{s}` occurs multiple times in the decorator's argument list."),
                        Some(kw.location),
                    );
                }
                if !avail_key.contains(s) {
                    return fail_parse_error!(
                        format!("Expect one of {:?}, found `{}`", &avail_key, s),
                        Some(kw.location),
                    );
                } else {
                    let _ = visited_key.insert(s);
                }
                match s {
                    "args" => ret_args.arg_names = Some(pylist_to_vec(&kw.node.value)?),
                    "returns" => ret_args.ret_names = pylist_to_vec(&kw.node.value)?,
                    "sql" => ret_args.sql = Some(py_str_to_string(&kw.node.value)?),
                    "backend" => {
                        let value = py_str_to_string(&kw.node.value)?;
                        match value.as_str() {
                            // RustPython is the default interpreter for now,
                            // but that could change in the future
                            "rspy" => ret_args.backend = BackendType::RustPython,
                            "pyo3" => ret_args.backend = BackendType::CPython,
                            _ => {
                                return fail_parse_error!(
                                    format!(
                                        "backend type can only be `rspy` or `pyo3`, found {value}"
                                    ),
                                    Some(kw.location),
                                )
                            }
                        }
                    }
                    _ => unreachable!(),
                }
            }
            None => {
                return fail_parse_error!(
                    format!(
                        "Expect both `args` and `returns` to be set explicitly, found \n{:#?}",
                        &kw.node
                    ),
                    Some(kw.location),
                )
            }
        }
    }
    let loc = keywords[0].location;
    for key in avail_key {
        if !visited_key.contains(key) && !opt_keys.contains(key) {
            return fail_parse_error!(format!("Expect `{key}` keyword"), Some(loc));
        }
    }
    Ok(ret_args)
}
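// Keyword rules enforced above, as a quick reference: `returns` is required,
// while `args`, `sql`, and `backend` are optional, and `backend` accepts only
// "rspy" or "pyo3". A minimal decorated function might look like this
// (illustrative sketch only, not taken from the original tests):
const MINIMAL_DECORATOR_EXAMPLE: &str = r#"
@copr(returns=["n"])
def n() -> vector[i64]:
    return 1
"#;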

/// returns args and returns as Vec of String
fn parse_decorator(decorator: &ast::Expr<()>) -> Result<DecoratorArgs> {
    //check_decorator(decorator)?;
    if let ast::ExprKind::Call {
        func,
        args: _,
        keywords,
    } = &decorator.node
    {
        ensure!(
            func.node
                == ast::ExprKind::Name {
                    id: "copr".to_string(),
                    ctx: ast::ExprContext::Load
                }
                || func.node
                    == ast::ExprKind::Name {
                        id: "coprocessor".to_string(),
                        ctx: ast::ExprContext::Load
                    },
            CoprParseSnafu {
                reason: format!(
                    "Expect decorator with name `copr` or `coprocessor`, found \n{:#?}",
                    &func.node
                ),
                loc: Some(func.location)
            }
        );
        parse_keywords(keywords)
    } else {
        fail_parse_error!(
            format!(
                "Expect decorator to be a function call (like `@copr(...)`), found \n{:#?}",
                decorator.node
            ),
            Some(decorator.location),
        )
    }
}

// get the type annotations of the arguments
fn get_arg_annotations(args: &Arguments) -> Result<Vec<Option<AnnotationInfo>>> {
    // get arg types from type annotations
    args.args
        .iter()
        .map(|arg| {
            if let Some(anno) = &arg.node.annotation {
                // parse_annotation handles its own errors
                parse_annotation(anno).map(Some)
            } else {
                Ok(None)
            }
        })
        .collect::<Result<Vec<Option<_>>>>()
}

fn get_return_annotations(rets: &ast::Expr<()>) -> Result<Vec<Option<AnnotationInfo>>> {
    let mut return_types = Vec::with_capacity(match &rets.node {
        ast::ExprKind::Tuple { elts, ctx: _ } => elts.len(),
        ast::ExprKind::Subscript {
            value: _,
            slice: _,
            ctx: _,
        } => 1,
        _ => {
            return fail_parse_error!(
                format!(
                    "Expect `(vector[...], vector[...], ...)` or `vector[...]`, found \n{:#?}",
                    &rets.node
                ),
                Some(rets.location),
            )
        }
    });
    match &rets.node {
        // python: ->(vector[...], vector[...], ...)
        ast::ExprKind::Tuple { elts, .. } => {
            for elem in elts {
                return_types.push(Some(parse_annotation(elem)?))
            }
        }
        // python: -> vector[...]
        ast::ExprKind::Subscript {
            value: _,
            slice: _,
            ctx: _,
        } => return_types.push(Some(parse_annotation(rets)?)),
        _ => {
            return fail_parse_error!(
                format!(
                    "Expect one or more type annotations for the return type, found \n{:#?}",
                    &rets.node
                ),
                Some(rets.location),
            )
        }
    }
    Ok(return_types)
}

/// parse the script and return a `Coprocessor` struct with info extracted from the ast
pub fn parse_and_compile_copr(
    script: &str,
    query_engine: Option<QueryEngineRef>,
) -> Result<Coprocessor> {
    let python_ast = parse_program(script, "<embedded>").context(PyParseSnafu)?;

    let mut coprocessor = None;

    for stmt in python_ast {
        if let ast::StmtKind::FunctionDef {
            name,
            args: fn_args,
            body: _,
            decorator_list,
            returns,
            type_comment: _,
        } = &stmt.node
        {
            if !decorator_list.is_empty() {
                ensure!(
                    coprocessor.is_none(),
                    CoprParseSnafu {
                        reason: "Expect one and only one Python function with `@coprocessor` or `@copr` decorator",
                        loc: stmt.location,
                    }
                );
                ensure!(
                    decorator_list.len() == 1,
                    CoprParseSnafu {
                        reason: "Expect one decorator",
                        loc: decorator_list.first().map(|s| s.location)
                    }
                );

                let decorator = &decorator_list[0];
                let deco_args = parse_decorator(decorator)?;

                // get arg types from type annotations
                let arg_types = get_arg_annotations(fn_args)?;

                // get return types from type annotations
                let return_types = if let Some(rets) = returns {
                    get_return_annotations(rets)?
                } else {
                    // if there is no annotation at all, set them all to None
                    std::iter::repeat(None)
                        .take(deco_args.ret_names.len())
                        .collect()
                };

                // make sure the arguments and returns in the function
                // and in the decorator have the same length
                if let Some(arg_names) = &deco_args.arg_names {
                    ensure!(
                        arg_names.len() == arg_types.len(),
                        CoprParseSnafu {
                            reason: format!(
                                "args number in decorator({}) and function({}) doesn't match",
                                arg_names.len(),
                                arg_types.len()
                            ),
                            loc: None
                        }
                    );
                }
                ensure!(
                    deco_args.ret_names.len() == return_types.len(),
                    CoprParseSnafu {
                        reason: format!(
                            "returns number in decorator( {} ) and function annotation( {} ) doesn't match",
                            deco_args.ret_names.len(),
                            return_types.len()
                        ),
                        loc: None
                    }
                );

                let backend = deco_args.backend.clone();
                let kwarg = fn_args.kwarg.as_ref().map(|arg| arg.node.arg.clone());
                coprocessor = Some(Coprocessor {
                    code_obj: Some(compile::compile_script(name, &deco_args, &kwarg, script)?),
                    name: name.to_string(),
                    deco_args,
                    arg_types,
                    return_types,
                    kwarg,
                    script: script.to_string(),
                    query_engine: query_engine.as_ref().map(|e| Arc::downgrade(e).into()),
                    backend,
                });
            }
        } else if matches!(
            stmt.node,
            ast::StmtKind::Import { .. } | ast::StmtKind::ImportFrom { .. }
        ) {
            // import statements are allowed.
        } else {
            return fail_parse_error!(
                format!(
                    "Expect a function definition, but found a \n{:#?}",
                    &stmt.node
                ),
                Some(stmt.location),
            );
        }
    }

    coprocessor.context(CoprParseSnafu {
        reason: "Coprocessor not found in script",
        loc: None,
    })
}
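// End-to-end sketch: a script that `parse_and_compile_copr` should accept
// given the rules above (exactly one decorated function; imports are allowed).
// The names and SQL here are illustrative, not taken from the original file:
const COPR_SCRIPT: &str = r#"
@coprocessor(args=["number"], returns=["doubled"],
             sql="select number from numbers limit 5")
def double(number) -> vector[i64]:
    return number * 2
"#;
// `parse_and_compile_copr(COPR_SCRIPT, None)` would then yield a `Coprocessor`
// whose `deco_args`, `arg_types`, and `return_types` reflect the annotations.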
@@ -1,214 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

mod sample_testcases;

use std::collections::HashMap;
use std::sync::Arc;

use arrow::compute::kernels::numeric;
use common_query::OutputData;
use common_recordbatch::RecordBatch;
use datafusion::arrow::array::Float64Array;
use datafusion::arrow::compute;
use datatypes::arrow::datatypes::DataType as ArrowDataType;
use datatypes::schema::{ColumnSchema, Schema};
use datatypes::vectors::VectorRef;
#[cfg(feature = "pyo3_backend")]
use pyo3::{types::PyDict, Python};
use rustpython_compiler::Mode;

use crate::engine::{CompileContext, EvalContext, Script, ScriptEngine};
use crate::python::engine::sample_script_engine;
use crate::python::ffi_types::pair_tests::sample_testcases::{
    generate_copr_intgrate_tests, sample_test_case,
};
use crate::python::ffi_types::PyVector;
#[cfg(feature = "pyo3_backend")]
use crate::python::pyo3::{init_cpython_interpreter, vector_impl::into_pyo3_cell};
use crate::python::rspython::init_interpreter;

// TODO(discord9): paired tests for slicing Vector
// & slice tests & a lit() function for dataframe & tests with the full coprocessor & query engine ability
/// a testcase that should be run in pairs in both RustPython and CPython
#[derive(Debug, Clone)]
struct CodeBlockTestCase {
    input: HashMap<String, VectorRef>,
    script: String,
    expect: VectorRef,
}

/// TODO(discord9): input a simple recordbatch, set a query engine, and such,
/// so that it will work for a full Coprocessor
#[derive(Debug, Clone, Default)]
struct CoprTestCase {
    // will be built into a RecordBatch and fed to the coprocessor
    script: String,
    expect: Option<HashMap<String, VectorRef>>,
}

#[allow(unused)]
fn into_recordbatch(input: HashMap<String, VectorRef>) -> RecordBatch {
    let mut schema = Vec::new();
    let mut columns = Vec::new();
    for (name, v) in input {
        schema.push(ColumnSchema::new(name, v.data_type(), false));
        columns.push(v);
    }
    let schema = Arc::new(Schema::new(schema));

    RecordBatch::new(schema, columns).unwrap()
}

#[tokio::test]
#[allow(clippy::print_stdout)]
async fn integrated_py_copr_test() {
    let testcases = generate_copr_intgrate_tests();
    let script_engine = sample_script_engine();
    for (idx, case) in testcases.into_iter().enumerate() {
        println!("Testcase {idx}:\n script: {}", case.script);
        let script = case.script;
        let script = script_engine
            .compile(&script, CompileContext::default())
            .await
            .unwrap();
        let output = script
            .execute(HashMap::default(), EvalContext::default())
            .await
            .unwrap();
        let res = match output.data {
            OutputData::Stream(s) => common_recordbatch::util::collect_batches(s).await.unwrap(),
            OutputData::RecordBatches(rbs) => rbs,
            _ => unreachable!(),
        };
        let rb = res.iter().next().expect("One and only one recordbatch");
        if let Some(expect_result) = case.expect {
            let mut actual_result = HashMap::new();
            for col_sch in rb.schema.column_schemas() {
                let col = rb.column_by_name(&col_sch.name).unwrap();
                let _ = actual_result.insert(col_sch.name.clone(), col.clone());
            }
            for (name, col) in expect_result {
                let actual_col = actual_result.get(&name).unwrap_or_else(|| {
                    panic!("Expect column with name: {name} in {actual_result:?}")
                });
                if !check_equal(col.clone(), actual_col.clone()) {
                    panic!("Column {name} doesn't match, expect {col:?}, found {actual_col:?}")
                }
            }
        }
        println!(".. Ok");
    }
}

#[allow(clippy::print_stdout)]
#[test]
fn pyo3_rspy_test_in_pairs() {
    let testcases = sample_test_case();
    for case in testcases {
        println!("Testcase: {}", case.script);
        eval_rspy(case.clone());
        #[cfg(feature = "pyo3_backend")]
        eval_pyo3(case);
    }
}

fn check_equal(v0: VectorRef, v1: VectorRef) -> bool {
    let v0 = v0.to_arrow_array();
    let v1 = v1.to_arrow_array();
    if v0.len() != v1.len() {
        return false;
    }
    fn is_float(ty: &ArrowDataType) -> bool {
        use ArrowDataType::*;
        matches!(ty, Float16 | Float32 | Float64)
    }
    if is_float(v0.data_type()) || is_float(v1.data_type()) {
        let v0 = compute::cast(&v0, &ArrowDataType::Float64).unwrap();
        let v0 = v0.as_any().downcast_ref::<Float64Array>().unwrap();

        let v1 = compute::cast(&v1, &ArrowDataType::Float64).unwrap();
        let v1 = v1.as_any().downcast_ref::<Float64Array>().unwrap();

        let res = numeric::sub(v0, v1).unwrap();
        let res = res.as_any().downcast_ref::<Float64Array>().unwrap();
        res.iter().all(|v| {
            if let Some(v) = v {
                v.abs() <= 2.0 * f32::EPSILON as f64
            } else {
                false
            }
        })
    } else {
        *v0 == *v1
    }
}
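// A minimal standalone restatement of the float comparison rule used above:
// two values count as equal when their difference is within 2 * f32::EPSILON
// (sketch only; `nearly_equal` is a hypothetical helper):
fn nearly_equal(a: f64, b: f64) -> bool {
    (a - b).abs() <= 2.0 * f32::EPSILON as f64
}
// e.g. nearly_equal(0.1 + 0.2, 0.3) holds even though 0.1 + 0.2 != 0.3 in f64.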

/// will panic if something is wrong; used in tests only
fn eval_rspy(case: CodeBlockTestCase) {
    let interpreter = init_interpreter();
    interpreter.enter(|vm| {
        let scope = vm.new_scope_with_builtins();
        for (k, v) in case.input {
            let v = PyVector::from(v);
            scope.locals.set_item(&k, vm.new_pyobj(v), vm).unwrap();
        }
        let code_obj = vm
            .compile(&case.script, Mode::BlockExpr, "<embedded>".to_owned())
            .map_err(|err| vm.new_syntax_error(&err))
            .unwrap();
        let result_vector = vm
            .run_code_obj(code_obj, scope)
            .unwrap()
            .downcast::<PyVector>()
            .unwrap();

        if !check_equal(result_vector.as_vector_ref(), case.expect.clone()) {
            panic!(
                "(RsPy)code:{}\nReal: {:?}!=Expected: {:?}",
                case.script, result_vector, case.expect
            )
        }
    });
}

#[cfg(feature = "pyo3_backend")]
fn eval_pyo3(case: CodeBlockTestCase) {
    init_cpython_interpreter().unwrap();
    Python::with_gil(|py| {
        let locals = {
            let locals_dict = PyDict::new(py);
            for (k, v) in case.input {
                let v = PyVector::from(v);
                locals_dict
                    .set_item(k, into_pyo3_cell(py, v).unwrap())
                    .unwrap();
            }
            locals_dict
        };
        py.run(&case.script, None, Some(locals)).unwrap();
        let res_vec = locals
            .get_item("ret")
            .unwrap()
            .unwrap()
            .extract::<PyVector>()
            .unwrap();
        if !check_equal(res_vec.as_vector_ref(), case.expect.clone()) {
            panic!(
                "(PyO3)code:{}\nReal: {:?}!=Expected: {:?}",
                case.script, res_vec, case.expect
            )
        }
    })
}
File diff suppressed because it is too large
@@ -1,136 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! PyRecordBatch is a Python class that wraps a RecordBatch
//! and provides a PyMapping protocol to
//! access the columns of the RecordBatch.

use common_recordbatch::RecordBatch;
use crossbeam_utils::atomic::AtomicCell;
#[cfg(feature = "pyo3_backend")]
use pyo3::{
    exceptions::{PyKeyError, PyRuntimeError},
    pyclass as pyo3class, pymethods, PyObject, PyResult, Python,
};
use rustpython_vm::builtins::PyStr;
use rustpython_vm::protocol::PyMappingMethods;
use rustpython_vm::types::AsMapping;
use rustpython_vm::{
    atomic_func, pyclass as rspyclass, PyObject as RsPyObject, PyPayload, PyResult as RsPyResult,
    VirtualMachine,
};

use crate::python::ffi_types::PyVector;

/// This is a wrapper around a RecordBatch that implements the PyMapping protocol, so both `a[0]` and `a["number"]` can be used to retrieve a column.
#[cfg_attr(feature = "pyo3_backend", pyo3class(name = "PyRecordBatch"))]
#[rspyclass(module = false, name = "PyRecordBatch")]
#[derive(Debug, PyPayload)]
pub(crate) struct PyRecordBatch {
    record_batch: RecordBatch,
}

impl PyRecordBatch {
    pub fn new(record_batch: RecordBatch) -> Self {
        Self { record_batch }
    }
}

impl From<RecordBatch> for PyRecordBatch {
    fn from(record_batch: RecordBatch) -> Self {
        Self::new(record_batch)
    }
}

#[cfg(feature = "pyo3_backend")]
#[pymethods]
impl PyRecordBatch {
    fn __repr__(&self) -> String {
        // TODO(discord9): a better pretty print
        format!("{:#?}", &self.record_batch.df_record_batch())
    }
    fn __getitem__(&self, py: Python, key: PyObject) -> PyResult<PyVector> {
        let column = if let Ok(key) = key.extract::<String>(py) {
            self.record_batch.column_by_name(&key)
        } else if let Ok(key) = key.extract::<usize>(py) {
            Some(self.record_batch.column(key))
        } else {
            return Err(PyRuntimeError::new_err(format!(
                "Expect either str or int, found {key:?}"
            )));
        }
        .ok_or_else(|| PyKeyError::new_err(format!("Column {} not found", key)))?;
        let v = PyVector::from(column.clone());
        Ok(v)
    }
    fn __iter__(&self) -> PyResult<Vec<PyVector>> {
        let iter: Vec<_> = self
            .record_batch
            .columns()
            .iter()
            .map(|i| PyVector::from(i.clone()))
            .collect();
        Ok(iter)
    }
    fn __len__(&self) -> PyResult<usize> {
        Ok(self.len())
    }
}

impl PyRecordBatch {
    fn len(&self) -> usize {
        self.record_batch.num_rows()
    }
    fn get_item(&self, needle: &RsPyObject, vm: &VirtualMachine) -> RsPyResult {
        if let Ok(index) = needle.try_to_value::<usize>(vm) {
            let column = self.record_batch.column(index);
            let v = PyVector::from(column.clone());
            Ok(v.into_pyobject(vm))
        } else if let Ok(index) = needle.try_to_value::<String>(vm) {
            let key = index.as_str();

            let v = self.record_batch.column_by_name(key).ok_or_else(|| {
                vm.new_key_error(PyStr::from(format!("Column {} not found", key)).into_pyobject(vm))
            })?;
            let v: PyVector = v.clone().into();
            Ok(v.into_pyobject(vm))
        } else {
            Err(vm.new_key_error(
                PyStr::from(format!("Expect either str or int, found {needle:?}"))
                    .into_pyobject(vm),
            ))
        }
    }
}

#[rspyclass(with(AsMapping))]
impl PyRecordBatch {
    #[pymethod(name = "__repr__")]
    fn rspy_repr(&self) -> String {
        format!("{:#?}", &self.record_batch.df_record_batch())
    }
}

impl AsMapping for PyRecordBatch {
    fn as_mapping() -> &'static PyMappingMethods {
        static AS_MAPPING: PyMappingMethods = PyMappingMethods {
            length: atomic_func!(|mapping, _vm| Ok(PyRecordBatch::mapping_downcast(mapping).len())),
            subscript: atomic_func!(
                |mapping, needle, vm| PyRecordBatch::mapping_downcast(mapping).get_item(needle, vm)
            ),
            ass_subscript: AtomicCell::new(None),
        };
        &AS_MAPPING
    }
}
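// With the mapping protocol above, a coprocessor body can treat a
// PyRecordBatch like a read-only dict/list hybrid (illustrative Python only;
// `rb` is a hypothetical variable bound to a PyRecordBatch):
const RECORD_BATCH_USAGE: &str = r#"
col_a = rb[0]         # column by position
col_b = rb["number"]  # column by name
n_rows = len(rb)      # number of rows
"#;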
@@ -1,81 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// to avoid putting too many #[cfg] attributes for the pyo3 feature flag
#![allow(unused)]
use datafusion::arrow::compute;
use datafusion::arrow::datatypes::Field;
use datafusion_common::ScalarValue;
use datafusion_expr::ColumnarValue;
use datatypes::arrow::datatypes::DataType as ArrowDataType;

pub fn new_item_field(data_type: ArrowDataType) -> Field {
    Field::new("item", data_type, false)
}

/// Generate a friendly error message when the type of the input `values` differs from `ty`
/// # Example
/// if `values` is [Int64(1), Float64(1.0), Int64(2)] and `ty` is Int64,
/// then the error message will be: " Float64 at 2th location\n"
pub(crate) fn collect_diff_types_string(values: &[ScalarValue], ty: &ArrowDataType) -> String {
    values
        .iter()
        .enumerate()
        .filter_map(|(idx, val)| {
            if val.data_type() != *ty {
                Some((idx, val.data_type()))
            } else {
                None
            }
        })
        .map(|(idx, ty)| format!(" {:?} at {}th location\n", ty, idx + 1))
        .reduce(|mut acc, item| {
            acc.push_str(&item);
            acc
        })
        .unwrap_or_else(|| "Nothing".to_string())
}

/// Because most of DataFusion's UDFs only support f32/f64, cast everything to f64 before using DataFusion's UDFs
pub fn all_to_f64(col: ColumnarValue) -> Result<ColumnarValue, String> {
    match col {
        ColumnarValue::Array(arr) => {
            let res = compute::cast(&arr, &ArrowDataType::Float64).map_err(|err| {
                format!(
                    "Arrow Type Cast Fail(from {:#?} to {:#?}): {err:#?}",
                    arr.data_type(),
                    ArrowDataType::Float64
                )
            })?;
            Ok(ColumnarValue::Array(res))
        }
        ColumnarValue::Scalar(val) => {
            let val_in_f64 = match val {
                ScalarValue::Float64(Some(v)) => v,
                ScalarValue::Int64(Some(v)) => v as f64,
                ScalarValue::Boolean(Some(v)) => v as i64 as f64,
                _ => {
                    return Err(format!(
                        "Can't cast type {:#?} to {:#?}",
                        val.data_type(),
                        ArrowDataType::Float64
                    ))
                }
            };
            Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(
                val_in_f64,
            ))))
        }
    }
}
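// A minimal sketch of the array branch above, using the arrow cast kernel
// directly (assumes the `arrow` crate; the function and values are
// illustrative, not part of the original file):
fn demo_cast_to_f64() {
    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int64Array};
    use arrow::compute::cast;
    use arrow::datatypes::DataType;

    let ints: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
    // Like `all_to_f64`, widen to Float64 before handing the data to UDFs.
    let floats = cast(&ints, &DataType::Float64).expect("Int64 -> Float64 cast");
    assert_eq!(floats.data_type(), &DataType::Float64);
}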
@@ -1,546 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#[cfg(test)]
mod tests;
use std::ops::Deref;
use std::sync::Arc;

use arrow::array::Datum;
use arrow::compute::kernels::{cmp, numeric};
use datatypes::arrow::array::{
    Array, ArrayRef, BooleanArray, Float64Array, Int64Array, UInt64Array,
};
use datatypes::arrow::compute;
use datatypes::arrow::datatypes::DataType as ArrowDataType;
use datatypes::arrow::error::Result as ArrowResult;
use datatypes::data_type::DataType;
use datatypes::prelude::{ConcreteDataType, Value};
use datatypes::value::{self, OrderedFloat};
use datatypes::vectors::{Helper, NullVector, VectorRef};
#[cfg(feature = "pyo3_backend")]
use pyo3::pyclass as pyo3class;
use rustpython_vm::builtins::{PyBaseExceptionRef, PyBool, PyFloat, PyInt, PyNone, PyStr};
use rustpython_vm::sliceable::{SaturatedSlice, SequenceIndex, SequenceIndexOp};
use rustpython_vm::types::PyComparisonOp;
use rustpython_vm::{
    pyclass as rspyclass, AsObject, PyObject, PyObjectRef, PyPayload, PyRef, PyResult,
    VirtualMachine,
};

use crate::python::rspython::utils::is_instance;

/// The main FFI type `PyVector`, used in both RustPython and PyO3
#[cfg_attr(feature = "pyo3_backend", pyo3class(name = "vector"))]
#[rspyclass(module = false, name = "vector")]
#[repr(transparent)]
#[derive(PyPayload, Debug, Clone)]
pub struct PyVector {
    pub(crate) vector: VectorRef,
}

pub(crate) type PyVectorRef = PyRef<PyVector>;

impl From<VectorRef> for PyVector {
    fn from(vector: VectorRef) -> Self {
        Self { vector }
    }
}

fn to_type_error(vm: &'_ VirtualMachine) -> impl FnOnce(String) -> PyBaseExceptionRef + '_ {
    |msg: String| vm.new_type_error(msg)
}

/// Performs `val - arr`.
pub(crate) fn arrow_rsub(arr: &dyn Datum, val: &dyn Datum) -> Result<ArrayRef, String> {
    numeric::sub(val, arr).map_err(|e| format!("rsub error: {e}"))
}

/// Performs `val / arr`
pub(crate) fn arrow_rtruediv(arr: &dyn Datum, val: &dyn Datum) -> Result<ArrayRef, String> {
    numeric::div(val, arr).map_err(|e| format!("rtruediv error: {e}"))
}

/// Performs `val / arr`, but cast to i64.
pub(crate) fn arrow_rfloordiv(arr: &dyn Datum, val: &dyn Datum) -> Result<ArrayRef, String> {
    let array = numeric::div(val, arr).map_err(|e| format!("rfloordiv divide error: {e}"))?;
    compute::cast(&array, &ArrowDataType::Int64).map_err(|e| format!("rfloordiv cast error: {e}"))
}

pub(crate) fn wrap_result<F>(f: F) -> impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, String>
where
    F: Fn(&dyn Datum, &dyn Datum) -> ArrowResult<ArrayRef>,
{
    move |left, right| f(left, right).map_err(|e| format!("arithmetic error {e}"))
}

#[cfg(feature = "pyo3_backend")]
pub(crate) fn wrap_bool_result<F>(
    op_bool_arr: F,
) -> impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, String>
where
    F: Fn(&dyn Datum, &dyn Datum) -> ArrowResult<BooleanArray>,
{
    move |a: &dyn Datum, b: &dyn Datum| -> Result<ArrayRef, String> {
        let array = op_bool_arr(a, b).map_err(|e| format!("logical op error: {e}"))?;
        Ok(Arc::new(array))
    }
}

#[inline]
fn is_float(datatype: &ArrowDataType) -> bool {
    matches!(
        datatype,
        ArrowDataType::Float16 | ArrowDataType::Float32 | ArrowDataType::Float64
    )
}

#[inline]
fn is_signed(datatype: &ArrowDataType) -> bool {
    matches!(
        datatype,
        ArrowDataType::Int8 | ArrowDataType::Int16 | ArrowDataType::Int32 | ArrowDataType::Int64
    )
}

#[inline]
fn is_unsigned(datatype: &ArrowDataType) -> bool {
    matches!(
        datatype,
        ArrowDataType::UInt8
            | ArrowDataType::UInt16
            | ArrowDataType::UInt32
            | ArrowDataType::UInt64
    )
}

fn cast(array: ArrayRef, target_type: &ArrowDataType) -> Result<ArrayRef, String> {
    compute::cast(&array, target_type).map_err(|e| e.to_string())
}

impl AsRef<PyVector> for PyVector {
    fn as_ref(&self) -> &PyVector {
        self
    }
}

impl PyVector {
    #[inline]
    pub(crate) fn data_type(&self) -> ConcreteDataType {
        self.vector.data_type()
    }

    #[inline]
    pub(crate) fn arrow_data_type(&self) -> ArrowDataType {
        self.vector.data_type().as_arrow_type()
    }

    pub(crate) fn vector_and(left: &Self, right: &Self) -> Result<Self, String> {
        let left = left.to_arrow_array();
        let right = right.to_arrow_array();
        let left = left
            .as_any()
            .downcast_ref::<BooleanArray>()
            .ok_or_else(|| format!("Can't cast {left:#?} as a Boolean Array"))?;
        let right = right
            .as_any()
            .downcast_ref::<BooleanArray>()
            .ok_or_else(|| format!("Can't cast {right:#?} as a Boolean Array"))?;
        let res =
            Arc::new(compute::kernels::boolean::and(left, right).map_err(|err| err.to_string())?)
                as ArrayRef;
        let ret = Helper::try_into_vector(res.clone()).map_err(|err| err.to_string())?;
        Ok(ret.into())
    }
    pub(crate) fn vector_or(left: &Self, right: &Self) -> Result<Self, String> {
        let left = left.to_arrow_array();
        let right = right.to_arrow_array();
        let left = left
            .as_any()
            .downcast_ref::<BooleanArray>()
            .ok_or_else(|| format!("Can't cast {left:#?} as a Boolean Array"))?;
        let right = right
            .as_any()
            .downcast_ref::<BooleanArray>()
            .ok_or_else(|| format!("Can't cast {right:#?} as a Boolean Array"))?;
        let res =
            Arc::new(compute::kernels::boolean::or(left, right).map_err(|err| err.to_string())?)
                as ArrayRef;
        let ret = Helper::try_into_vector(res.clone()).map_err(|err| err.to_string())?;
        Ok(ret.into())
    }
    pub(crate) fn vector_invert(left: &Self) -> Result<Self, String> {
        let zelf = left.to_arrow_array();
        let zelf = zelf
            .as_any()
            .downcast_ref::<BooleanArray>()
            .ok_or_else(|| format!("Can't cast {left:#?} as a Boolean Array"))?;
        let res = Arc::new(compute::kernels::boolean::not(zelf).map_err(|err| err.to_string())?)
            as ArrayRef;
        let ret = Helper::try_into_vector(res.clone()).map_err(|err| err.to_string())?;
        Ok(ret.into())
    }
    /// creates a ref to the inner vector
    #[inline]
    pub fn as_vector_ref(&self) -> VectorRef {
        self.vector.clone()
    }

    #[inline]
    pub fn to_arrow_array(&self) -> ArrayRef {
        self.vector.to_arrow_array()
    }

    pub(crate) fn scalar_arith_op<F>(
        &self,
        right: value::Value,
        target_type: Option<ArrowDataType>,
        op: F,
    ) -> Result<Self, String>
    where
        F: Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, String>,
    {
        let right_type = right.data_type().as_arrow_type();
        // assume 64-bit types where possible
        let left = self.to_arrow_array();

        let left_type = left.data_type();
        let right_type = &right_type;
        let target_type = Self::coerce_types(left_type, right_type, &target_type);
        let left = cast(left, &target_type)?;
        let left_len = left.len();

        // Convert `right` to an array of `target_type`.
        let right: Box<dyn Array> = if is_float(&target_type) {
            match right {
                value::Value::Int64(v) => Box::new(Float64Array::from_value(v as f64, left_len)),
                value::Value::UInt64(v) => Box::new(Float64Array::from_value(v as f64, left_len)),
                value::Value::Float64(v) => {
                    Box::new(Float64Array::from_value(f64::from(v), left_len))
                }
                _ => unreachable!(),
            }
        } else if is_signed(&target_type) {
            match right {
                value::Value::Int64(v) => Box::new(Int64Array::from_value(v, left_len)),
                value::Value::UInt64(v) => Box::new(Int64Array::from_value(v as i64, left_len)),
                value::Value::Float64(v) => Box::new(Int64Array::from_value(v.0 as i64, left_len)),
                _ => unreachable!(),
            }
        } else if is_unsigned(&target_type) {
            match right {
                value::Value::Int64(v) => Box::new(UInt64Array::from_value(v as u64, left_len)),
                value::Value::UInt64(v) => Box::new(UInt64Array::from_value(v, left_len)),
                value::Value::Float64(v) => Box::new(UInt64Array::from_value(v.0 as u64, left_len)),
                _ => unreachable!(),
            }
        } else {
            return Err(format!(
                "Can't cast source operand of type {:?} into target type of {:?}",
                right_type, &target_type
            ));
        };

        let result = op(&left, &right.as_ref())?;

        Ok(Helper::try_into_vector(result.clone())
            .map_err(|e| format!("Can't cast result into vector, result: {result:?}, err: {e:?}",))?
            .into())
    }

    pub(crate) fn rspy_scalar_arith_op<F>(
        &self,
        other: PyObjectRef,
        target_type: Option<ArrowDataType>,
        op: F,
        vm: &VirtualMachine,
    ) -> PyResult<PyVector>
    where
        F: Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, String>,
    {
        // the right operand only supports PyInt or PyFloat
        let right = {
            if is_instance::<PyInt>(&other, vm) {
                other.try_into_value::<i64>(vm).map(value::Value::Int64)?
            } else if is_instance::<PyFloat>(&other, vm) {
                other
                    .try_into_value::<f64>(vm)
                    .map(|v| (value::Value::Float64(OrderedFloat(v))))?
            } else {
                return Err(vm.new_type_error(format!(
                    "Can't cast right operand into Scalar of Int or Float, actual: {}",
                    other.class().name()
                )));
            }
        };
        self.scalar_arith_op(right, target_type, op)
            .map_err(to_type_error(vm))
    }

    /// Returns the type that should be used for the result of an arithmetic operation
    fn coerce_types(
        left_type: &ArrowDataType,
        right_type: &ArrowDataType,
        target_type: &Option<ArrowDataType>,
    ) -> ArrowDataType {
        // TODO(discord9): find a better way to cast between signed and unsigned types
        target_type.clone().unwrap_or_else(|| {
            if is_signed(left_type) && is_signed(right_type) {
                ArrowDataType::Int64
            } else if is_unsigned(left_type) && is_unsigned(right_type) {
                ArrowDataType::UInt64
            } else {
                ArrowDataType::Float64
            }
        })
    }
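
    // A standalone restatement of the default coercion rule above (sketch
    // only; `default_coerced_type` is a hypothetical helper, not part of the
    // original impl): both signed -> Int64, both unsigned -> UInt64, anything
    // else (floats included) -> Float64.
    #[allow(dead_code)]
    fn default_coerced_type(left: &ArrowDataType, right: &ArrowDataType) -> ArrowDataType {
        if is_signed(left) && is_signed(right) {
            ArrowDataType::Int64
        } else if is_unsigned(left) && is_unsigned(right) {
            ArrowDataType::UInt64
        } else {
            ArrowDataType::Float64
        }
    }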

    pub(crate) fn vector_arith_op<F>(
        &self,
        right: &Self,
        target_type: Option<ArrowDataType>,
        op: F,
    ) -> Result<PyVector, String>
    where
        F: Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, String>,
    {
        let left = self.to_arrow_array();
        let right = right.to_arrow_array();

        let left_type = &left.data_type();
        let right_type = &right.data_type();

        let target_type = Self::coerce_types(left_type, right_type, &target_type);

        let left = cast(left, &target_type)?;
        let right = cast(right, &target_type)?;

        let result = op(&left, &right)?;

        Ok(Helper::try_into_vector(result.clone())
            .map_err(|e| format!("Can't cast result into vector, result: {result:?}, err: {e:?}",))?
            .into())
    }

    pub(crate) fn rspy_vector_arith_op<F>(
        &self,
        other: PyObjectRef,
        target_type: Option<ArrowDataType>,
        op: F,
        vm: &VirtualMachine,
    ) -> PyResult<PyVector>
    where
        F: Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, String>,
    {
        let right = other.downcast_ref::<PyVector>().ok_or_else(|| {
            vm.new_type_error(format!(
                "Can't cast right operand into PyVector, actual type: {}",
                other.class().name()
            ))
        })?;
        self.vector_arith_op(right, target_type, op)
            .map_err(to_type_error(vm))
    }

    pub(crate) fn _getitem(&self, needle: &PyObject, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        if let Some(seq) = needle.payload::<PyVector>() {
            let mask = seq.to_arrow_array();
            let mask = mask
                .as_any()
                .downcast_ref::<BooleanArray>()
                .ok_or_else(|| {
                    vm.new_type_error(format!("Can't cast {seq:#?} as a Boolean Array"))
                })?;
            let res = compute::filter(self.to_arrow_array().as_ref(), mask)
                .map_err(|err| vm.new_runtime_error(format!("Arrow Error: {err:#?}")))?;
            let ret = Helper::try_into_vector(res.clone()).map_err(|e| {
                vm.new_type_error(format!("Can't cast result into vector, err: {e:?}"))
            })?;
            Ok(Self::from(ret).into_pyobject(vm))
        } else {
            match SequenceIndex::try_from_borrowed_object(vm, needle, "vector")? {
                SequenceIndex::Int(i) => self.getitem_by_index(i, vm),
                SequenceIndex::Slice(slice) => self.getitem_by_slice(&slice, vm),
            }
        }
    }

    pub(crate) fn getitem_by_index(&self, i: isize, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
        // in the newest version of rustpython_vm, wrapped_at for isize is replaced by wrap_index(i, len)
        let i = i.wrapped_at(self.len()).ok_or_else(|| {
            vm.new_index_error(format!("PyVector index {i} out of range {}", self.len()))
        })?;
        val_to_pyobj(self.as_vector_ref().get(i), vm)
    }

    /// Return a `PyVector` in `PyObjectRef`
    fn getitem_by_slice(
        &self,
        slice: &SaturatedSlice,
        vm: &VirtualMachine,
    ) -> PyResult<PyObjectRef> {
        // adjust_indices transforms negative indices into usize
        let (mut range, step, slice_len) = slice.adjust_indices(self.len());
        let vector = self.as_vector_ref();
        let mut buf = vector.data_type().create_mutable_vector(slice_len);
        if slice_len == 0 {
            let v: PyVector = buf.to_vector().into();
            Ok(v.into_pyobject(vm))
        } else if step == 1 {
            let v: PyVector = vector.slice(range.next().unwrap_or(0), slice_len).into();
            Ok(v.into_pyobject(vm))
        } else if step.is_negative() {
            // Negative steps require special treatment:
            // range.start > range.stop when the slice is non-empty
            for i in range.rev().step_by(step.unsigned_abs()) {
                // Safety: This mutable vector is created from the vector's data type.
                buf.push_value_ref(vector.get_ref(i));
            }
            let v: PyVector = buf.to_vector().into();
            Ok(v.into_pyobject(vm))
        } else {
            for i in range.step_by(step.unsigned_abs()) {
                // Safety: This mutable vector is created from the vector's data type.
                buf.push_value_ref(vector.get_ref(i));
            }
            let v: PyVector = buf.to_vector().into();
            Ok(v.into_pyobject(vm))
        }
    }
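
    // A pure-std sketch of the negative-step branch above: Python's `v[::-2]`
    // walks the adjusted range in reverse, skipping |step| - 1 items each time
    // (`demo_negative_step` is illustrative, not part of the original impl):
    #[allow(dead_code)]
    fn demo_negative_step() {
        let v: Vec<i64> = (0..6).collect();
        let out: Vec<i64> = (0..v.len()).rev().step_by(2).map(|i| v[i]).collect();
        assert_eq!(out, vec![5, 3, 1]); // i.e. v[::-2]
    }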

    /// Unsupported
    /// TODO(discord9): make it work
    #[allow(unused)]
    fn setitem_by_index(
        zelf: PyRef<Self>,
        i: isize,
        value: PyObjectRef,
        vm: &VirtualMachine,
    ) -> PyResult<()> {
        Err(vm.new_not_implemented_error("setitem_by_index unimplemented".to_string()))
    }

    /// rich compare; returns a boolean array; accepted operand pairs are vector-vector and vector-scalar
    pub(crate) fn richcompare(
        &self,
        other: PyObjectRef,
        op: PyComparisonOp,
        vm: &VirtualMachine,
    ) -> PyResult<PyVector> {
        if rspy_is_pyobj_scalar(&other, vm) {
            let scalar_op = get_arrow_scalar_op(op);
            self.rspy_scalar_arith_op(other, None, scalar_op, vm)
        } else {
            let arr_op = get_arrow_op(op);
            self.rspy_vector_arith_op(other, None, wrap_result(arr_op), vm)
        }
    }

    pub(crate) fn len(&self) -> usize {
        self.as_vector_ref().len()
    }
}

/// get the corresponding arrow op function for the given PyComparisonOp
fn get_arrow_op(op: PyComparisonOp) -> impl Fn(&dyn Datum, &dyn Datum) -> ArrowResult<ArrayRef> {
    let op_bool_arr = match op {
        PyComparisonOp::Eq => cmp::eq,
        PyComparisonOp::Ne => cmp::neq,
        PyComparisonOp::Gt => cmp::gt,
        PyComparisonOp::Lt => cmp::lt,
        PyComparisonOp::Ge => cmp::gt_eq,
        PyComparisonOp::Le => cmp::lt_eq,
    };

    move |a: &dyn Datum, b: &dyn Datum| -> ArrowResult<ArrayRef> {
        let array = op_bool_arr(a, b)?;
        Ok(Arc::new(array))
    }
}

/// get the corresponding arrow scalar op function for the given PyComparisonOp
fn get_arrow_scalar_op(
    op: PyComparisonOp,
) -> impl Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, String> {
    let op_bool_arr = match op {
        PyComparisonOp::Eq => cmp::eq,
        PyComparisonOp::Ne => cmp::neq,
        PyComparisonOp::Gt => cmp::gt,
        PyComparisonOp::Lt => cmp::lt,
        PyComparisonOp::Ge => cmp::gt_eq,
        PyComparisonOp::Le => cmp::lt_eq,
    };

    move |a: &dyn Datum, b: &dyn Datum| -> Result<ArrayRef, String> {
        let array = op_bool_arr(a, b).map_err(|e| format!("scalar op error: {e}"))?;
        Ok(Arc::new(array))
    }
}

/// whether this pyobj can be cast to a scalar value (i.e. Null/Int/Float/Bool/Str)
#[inline]
pub(crate) fn rspy_is_pyobj_scalar(obj: &PyObjectRef, vm: &VirtualMachine) -> bool {
    is_instance::<PyNone>(obj, vm)
        || is_instance::<PyInt>(obj, vm)
        || is_instance::<PyFloat>(obj, vm)
        || is_instance::<PyBool>(obj, vm)
        || is_instance::<PyStr>(obj, vm)
}

/// convert a DataType `Value` into a `PyObjectRef`
pub fn val_to_pyobj(val: value::Value, vm: &VirtualMachine) -> PyResult {
    Ok(match val {
        // This comes from: https://github.com/RustPython/RustPython/blob/8ab4e770351d451cfdff5dc2bf8cce8df76a60ab/vm/src/builtins/singletons.rs#L37
        // None in Python is a universal singleton, so
        // using `vm.ctx.new_int` and `new_***` in general is more idiomatic, since certain optimizations (like the small-int pool) apply this way
        value::Value::Null => vm.ctx.none(),
        value::Value::Boolean(v) => vm.ctx.new_bool(v).into(),
        value::Value::UInt8(v) => vm.ctx.new_int(v).into(),
        value::Value::UInt16(v) => vm.ctx.new_int(v).into(),
        value::Value::UInt32(v) => vm.ctx.new_int(v).into(),
        value::Value::UInt64(v) => vm.ctx.new_int(v).into(),
        value::Value::Int8(v) => vm.ctx.new_int(v).into(),
        value::Value::Int16(v) => vm.ctx.new_int(v).into(),
        value::Value::Int32(v) => vm.ctx.new_int(v).into(),
        value::Value::Int64(v) => vm.ctx.new_int(v).into(),
        value::Value::Float32(v) => vm.ctx.new_float(v.0 as f64).into(),
        value::Value::Float64(v) => vm.ctx.new_float(v.0).into(),
        value::Value::String(s) => vm.ctx.new_str(s.as_utf8()).into(),
        // is this copy necessary?
        value::Value::Binary(b) => vm.ctx.new_bytes(b.deref().to_vec()).into(),
        // TODO(dennis): are `Date` and `DateTime` supported yet? For now just ad hoc into PyInt, but it would be better to cast into Python Date, DateTime objects etc.
        value::Value::Date(v) => vm.ctx.new_int(v.val()).into(),
        value::Value::DateTime(v) => vm.ctx.new_int(v.val()).into(),
        // FIXME(dennis): loses the timestamp unit here
        Value::Timestamp(v) => vm.ctx.new_int(v.value()).into(),
        value::Value::List(list) => {
            let list: Vec<_> = list
                .items()
                .iter()
                .map(|v| val_to_pyobj(v.clone(), vm))
                .collect::<Result<_, _>>()?;
            vm.ctx.new_list(list).into()
        }
        #[allow(unreachable_patterns)]
        _ => return Err(vm.new_type_error(format!("Convert from {val:?} is not supported yet"))),
    })
}

impl Default for PyVector {
    fn default() -> PyVector {
        PyVector {
            vector: Arc::new(NullVector::new(0)),
        }
    }
}
@@ -1,186 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Here are paired tests for vector types in both RustPython and CPython
//!

// TODO: sample record batch

use std::collections::HashMap;
use std::sync::Arc;

use datatypes::scalars::ScalarVector;
use datatypes::vectors::{BooleanVector, Float64Vector, Int64Vector, VectorRef};
#[cfg(feature = "pyo3_backend")]
use pyo3::{types::PyDict, Python};
use rustpython_compiler::Mode;
use rustpython_vm::AsObject;

use crate::python::ffi_types::PyVector;
#[cfg(feature = "pyo3_backend")]
use crate::python::pyo3::{init_cpython_interpreter, vector_impl::into_pyo3_cell};
use crate::python::rspython::init_interpreter;

#[derive(Debug, Clone)]
struct TestCase {
    eval: String,
    result: VectorRef,
}

#[test]
fn test_eval_py_vector_in_pairs() {
    let locals: HashMap<_, _> = sample_py_vector()
        .into_iter()
        .map(|(k, v)| (k, PyVector::from(v)))
        .collect();

    let testcases = get_test_cases();

    for testcase in testcases {
        #[cfg(feature = "pyo3_backend")]
        eval_pyo3(testcase.clone(), locals.clone());
        eval_rspy(testcase, locals.clone())
    }
}

fn sample_py_vector() -> HashMap<String, VectorRef> {
    let b1 = Arc::new(BooleanVector::from_slice(&[false, false, true, true])) as VectorRef;
    let b2 = Arc::new(BooleanVector::from_slice(&[false, true, false, true])) as VectorRef;
    let f1 = Arc::new(Float64Vector::from_slice([0.0f64, 2.0, 10.0, 42.0])) as VectorRef;
    let f2 = Arc::new(Float64Vector::from_slice([-0.1f64, -42.0, 2., 7.0])) as VectorRef;
    let f3 = Arc::new(Float64Vector::from_slice([1.0f64, -42.0, 2., 7.0])) as VectorRef;
    HashMap::from([
        ("b1".to_owned(), b1),
        ("b2".to_owned(), b2),
        ("f1".to_owned(), f1),
        ("f2".to_owned(), f2),
        ("f3".to_owned(), f3),
    ])
}

/// test cases for basic operations;
/// this is more powerful and flexible than a standalone test-case configuration file
fn get_test_cases() -> Vec<TestCase> {
    let testcases = [
        TestCase {
            eval: "b1 & b2".to_string(),
            result: Arc::new(BooleanVector::from_slice(&[false, false, false, true])) as VectorRef,
        },
        TestCase {
            eval: "b1 | b2".to_string(),
            result: Arc::new(BooleanVector::from_slice(&[false, true, true, true])) as VectorRef,
        },
        TestCase {
            eval: "~b1".to_string(),
            result: Arc::new(BooleanVector::from_slice(&[true, true, false, false])) as VectorRef,
        },
        TestCase {
            eval: "f1+f2".to_string(),
            result: Arc::new(Float64Vector::from_slice([-0.1f64, -40.0, 12., 49.0])) as VectorRef,
        },
        TestCase {
            eval: "f1-f2".to_string(),
            result: Arc::new(Float64Vector::from_slice([0.1f64, 44.0, 8., 35.0])) as VectorRef,
        },
        TestCase {
            eval: "f1*f2".to_string(),
            result: Arc::new(Float64Vector::from_slice([-0.0f64, -84.0, 20., 42.0 * 7.0]))
                as VectorRef,
        },
        TestCase {
            eval: "f1/f2".to_string(),
            result: Arc::new(Float64Vector::from_slice([
                0.0 / -0.1f64,
                2. / -42.,
                10. / 2.,
                42. / 7.,
            ])) as VectorRef,
        },
        TestCase {
            eval: "f2.__rtruediv__(f1)".to_string(),
            result: Arc::new(Float64Vector::from_slice([
                0.0 / -0.1f64,
                2. / -42.,
                10. / 2.,
                42. / 7.,
            ])) as VectorRef,
        },
        TestCase {
            eval: "f2.__floordiv__(f3)".to_string(),
            result: Arc::new(Int64Vector::from_slice([0, 1, 1, 1])) as VectorRef,
        },
        TestCase {
            eval: "f3.__rfloordiv__(f2)".to_string(),
            result: Arc::new(Int64Vector::from_slice([0, 1, 1, 1])) as VectorRef,
        },
        TestCase {
            eval: "f3.filter(b1)".to_string(),
            result: Arc::new(Float64Vector::from_slice([2.0, 7.0])) as VectorRef,
        },
    ];
    Vec::from(testcases)
}
#[cfg(feature = "pyo3_backend")]
fn eval_pyo3(testcase: TestCase, locals: HashMap<String, PyVector>) {
    init_cpython_interpreter().unwrap();
    Python::with_gil(|py| {
        let locals = {
            let locals_dict = PyDict::new(py);
            for (k, v) in locals {
                locals_dict
                    .set_item(k, into_pyo3_cell(py, v).unwrap())
                    .unwrap();
            }
            locals_dict
        };
        let res = py.eval(&testcase.eval, None, Some(locals)).unwrap();
        let res_vec = res.extract::<PyVector>().unwrap();
        let raw_arr = res_vec.as_vector_ref().to_arrow_array();
        let expect_arr = testcase.result.to_arrow_array();
        if *raw_arr != *expect_arr {
            panic!("{raw_arr:?}!={expect_arr:?}")
        }
    })
}

fn eval_rspy(testcase: TestCase, locals: HashMap<String, PyVector>) {
    init_interpreter().enter(|vm| {
        let scope = vm.new_scope_with_builtins();
        locals.into_iter().for_each(|(k, v)| {
            scope
                .locals
                .as_object()
                .set_item(&k, vm.new_pyobj(v), vm)
                .unwrap();
        });
        let code_obj = vm
            .compile(&testcase.eval, Mode::Eval, "<embedded>".to_string())
            .map_err(|err| vm.new_syntax_error(&err))
            .unwrap();
        let obj = vm
            .run_code_obj(code_obj, scope)
            .map_err(|e| {
                let mut output = String::new();
                vm.write_exception(&mut output, &e).unwrap();
                (e, output)
            })
            .unwrap();
        let v = obj.downcast::<PyVector>().unwrap();
        let result_arr = v.to_arrow_array();
        let expect_arr = testcase.result.to_arrow_array();
        if *result_arr != *expect_arr {
            panic!("{result_arr:?}!={expect_arr:?}")
        }
    });
}
@@ -1,53 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use lazy_static::lazy_static;
use prometheus::*;

lazy_static! {
    pub static ref METRIC_RSPY_INIT_ELAPSED: Histogram = register_histogram!(
        "greptime_script_rspy_init_elapsed",
        "script rspy init elapsed"
    )
    .unwrap();
    pub static ref METRIC_RSPY_EXEC_ELAPSED: Histogram = register_histogram!(
        "greptime_script_rspy_exec_elapsed",
        "script rspy exec elapsed"
    )
    .unwrap();
    pub static ref METRIC_RSPY_EXEC_TOTAL_ELAPSED: Histogram = register_histogram!(
        "greptime_script_rspy_exec_total_elapsed",
        "script rspy exec total elapsed"
    )
    .unwrap();
}
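// Typical call-site pattern for these histograms (a sketch; the wrapper
// function is hypothetical): `start_timer` records the elapsed seconds into
// the histogram when the returned guard is dropped.
#[allow(dead_code)]
fn run_with_rspy_metrics(run_script: impl FnOnce()) {
    let _timer = METRIC_RSPY_EXEC_ELAPSED.start_timer();
    run_script();
    // `_timer` drops here and observes into greptime_script_rspy_exec_elapsed
}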

#[cfg(feature = "pyo3_backend")]
lazy_static! {
    pub static ref METRIC_PYO3_EXEC_ELAPSED: Histogram = register_histogram!(
        "greptime_script_pyo3_exec_elapsed",
        "script pyo3 exec elapsed"
    )
    .unwrap();
    pub static ref METRIC_PYO3_INIT_ELAPSED: Histogram = register_histogram!(
        "greptime_script_pyo3_init_elapsed",
        "script pyo3 init elapsed"
    )
    .unwrap();
    pub static ref METRIC_PYO3_EXEC_TOTAL_ELAPSED: Histogram = register_histogram!(
        "greptime_script_pyo3_exec_total_elapsed",
        "script pyo3 exec total elapsed"
    )
    .unwrap();
}
|
||||
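
These timers are consumed through the prometheus crate's RAII guard pattern; a minimal sketch of the call-site usage (the same pattern appears verbatim in the coprocessor code removed further below, e.g. `metric::METRIC_PYO3_EXEC_TOTAL_ELAPSED.start_timer()`):

use prometheus::Histogram;

fn timed_section(histogram: &Histogram) {
    // `start_timer()` returns a `HistogramTimer` guard; when it is dropped at
    // the end of the scope, the elapsed seconds are observed into the histogram.
    let _t = histogram.start_timer();
    // ... run the scripted workload here ...
}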
@@ -1,24 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

mod builtins;
pub(crate) mod copr_impl;
mod dataframe_impl;
mod utils;
pub(crate) mod vector_impl;

#[cfg(feature = "pyo3_backend")]
pub(crate) use copr_impl::pyo3_exec_parsed;
#[cfg(test)]
pub(crate) use utils::init_cpython_interpreter;
@@ -1,410 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use common_function::function::FunctionRef;
use common_function::function_registry::FUNCTION_REGISTRY;
use datafusion::arrow::array::{ArrayRef, NullArray};
use datafusion::physical_plan::expressions;
use datafusion_expr::ColumnarValue;
use datafusion_physical_expr::AggregateExpr;
use datatypes::vectors::VectorRef;
use pyo3::exceptions::{PyKeyError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::PyDict;

use super::dataframe_impl::PyDataFrame;
use super::utils::scalar_value_to_py_any;
use crate::python::ffi_types::copr::PyQueryEngine;
use crate::python::ffi_types::utils::all_to_f64;
use crate::python::ffi_types::PyVector;
use crate::python::pyo3::dataframe_impl::{col, lit};
use crate::python::pyo3::utils::{
    columnar_value_to_py_any, try_into_columnar_value, val_to_py_any,
};

/// Try to extract a `PyVector`, or convert from a `pyarrow.array` object
#[inline]
fn try_into_py_vector(py: Python, obj: PyObject) -> PyResult<PyVector> {
    if let Ok(v) = obj.extract::<PyVector>(py) {
        Ok(v)
    } else {
        PyVector::from_pyarrow(obj.as_ref(py).get_type(), py, obj.clone())
    }
}

#[inline]
fn to_array_of_py_vec(py: Python, obj: &[&PyObject]) -> PyResult<Vec<PyVector>> {
    obj.iter()
        .map(|v| try_into_py_vector(py, v.to_object(py)))
        .collect::<PyResult<_>>()
}

macro_rules! batch_import {
    ($m: ident, [$($fn_name: ident),*]) => {
        $($m.add_function(wrap_pyfunction!($fn_name, $m)?)?;)*
    };
}

#[pymodule]
#[pyo3(name = "greptime")]
pub(crate) fn greptime_builtins(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    m.add_class::<PyVector>()?;
    m.add_class::<PyDataFrame>()?;
    use self::query_engine;
    batch_import!(
        m,
        [
            dataframe,
            query_engine,
            lit,
            col,
            pow,
            clip,
            diff,
            mean,
            polyval,
            argmax,
            argmin,
            percentile,
            scipy_stats_norm_cdf,
            scipy_stats_norm_pdf,
            sqrt,
            sin,
            cos,
            tan,
            asin,
            acos,
            atan,
            floor,
            ceil,
            round,
            // trunc,
            abs,
            signum,
            exp,
            ln,
            log2,
            log10,
            random,
            approx_distinct,
            // median,
            approx_percentile_cont,
            array_agg,
            avg,
            correlation,
            count,
            // covariance,
            // covariance_pop,
            max,
            min,
            stddev,
            stddev_pop,
            sum,
            variance,
            variance_pop
        ]
    );
    Ok(())
}

fn get_globals(py: Python) -> PyResult<&PyDict> {
    // TODO(discord9): check if this is sound (in Python)
    let py_main = PyModule::import(py, "__main__")?;
    let globals = py_main.dict();
    Ok(globals)
}

/// In case one doesn't want to repeat the same SQL statement,
/// this function is still useful even though we already have `PyDataFrame::from_sql()`
#[pyfunction]
fn dataframe(py: Python) -> PyResult<PyDataFrame> {
    let globals = get_globals(py)?;
    let df = globals
        .get_item("__dataframe__")?
        .ok_or_else(|| PyKeyError::new_err("No __dataframe__ variable is found"))?
        .extract::<PyDataFrame>()?;
    Ok(df)
}

#[pyfunction]
#[pyo3(name = "query")]
pub(crate) fn query_engine(py: Python) -> PyResult<PyQueryEngine> {
    let globals = get_globals(py)?;
    let query = globals
        .get_item("__query__")?
        .ok_or_else(|| PyKeyError::new_err("No __query__ variable is found"))?
        .extract::<PyQueryEngine>()?;
    Ok(query)
}

fn eval_func(py: Python<'_>, name: &str, v: &[&PyObject]) -> PyResult<PyVector> {
    let v = to_array_of_py_vec(py, v)?;
    py.allow_threads(|| {
        let v: Vec<VectorRef> = v.iter().map(|v| v.as_vector_ref()).collect();
        let func: Option<FunctionRef> = FUNCTION_REGISTRY.get_function(name);
        let res = match func {
            Some(f) => f.eval(Default::default(), &v),
            None => return Err(PyValueError::new_err(format!("Can't find function {name}"))),
        };
        match res {
            Ok(v) => Ok(v.into()),
            Err(err) => Err(PyValueError::new_err(format!(
                "Failed to evaluate the function: {err}"
            ))),
        }
    })
}

fn eval_aggr_func(py: Python<'_>, name: &str, args: &[&PyVector]) -> PyResult<PyObject> {
    let res = py.allow_threads(|| {
        let v: Vec<VectorRef> = args.iter().map(|v| v.as_vector_ref()).collect();
        let func = FUNCTION_REGISTRY.get_aggr_function(name);
        let f = match func {
            Some(f) => f.create().creator(),
            None => return Err(PyValueError::new_err(format!("Can't find function {name}"))),
        };
        let types: Vec<_> = v.iter().map(|v| v.data_type()).collect();
        let acc = f(&types);
        let mut acc = match acc {
            Ok(acc) => acc,
            Err(err) => {
                return Err(PyValueError::new_err(format!(
                    "Failed to create accumulator: {err}"
                )))
            }
        };
        match acc.update_batch(&v) {
            Ok(_) => (),
            Err(err) => {
                return Err(PyValueError::new_err(format!(
                    "Failed to update batch: {err}"
                )))
            }
        };
        let res = match acc.evaluate() {
            Ok(r) => r,
            Err(err) => {
                return Err(PyValueError::new_err(format!(
                    "Failed to evaluate accumulator: {err}"
                )))
            }
        };
        Ok(res)
    })?;
    val_to_py_any(py, res)
}

/// evaluate an Aggregate Expr using its backing accumulator
/// TODO(discord9): cast to f64 before use / provide a cast-to-f64 function?
fn eval_df_aggr_expr<T: AggregateExpr>(
    py: Python<'_>,
    aggr: T,
    values: &[ArrayRef],
) -> PyResult<PyObject> {
    let res = py.allow_threads(|| -> PyResult<_> {
        // acquire the accumulator, where the actual implementation of the aggregate expr lies
        let mut acc = aggr
            .create_accumulator()
            .map_err(|e| PyValueError::new_err(format!("{e:?}")))?;
        acc.update_batch(values)
            .map_err(|e| PyValueError::new_err(format!("{e:?}")))?;
        let res = acc
            .evaluate()
            .map_err(|e| PyValueError::new_err(format!("{e:?}")))?;
        Ok(res)
    })?;
    scalar_value_to_py_any(py, res)
}

/// used to bind to DataFusion's UDF functions
macro_rules! bind_call_unary_math_function {
    ($($DF_FUNC: ident),*) => {
        $(
            #[pyfunction]
            fn $DF_FUNC(py: Python<'_>, val: PyObject) -> PyResult<PyObject> {
                let args =
                    &[all_to_f64(try_into_columnar_value(py, val)?).map_err(PyValueError::new_err)?];
                let res = datafusion_functions::math::$DF_FUNC()
                    .invoke(args)
                    .map_err(|e| PyValueError::new_err(format!("{e:?}")))?;
                columnar_value_to_py_any(py, res)
            }
        )*
    };
}

macro_rules! simple_vector_fn {
    ($name: ident, $name_str: tt, [$($arg:ident),*]) => {
        #[pyfunction]
        fn $name(py: Python<'_>, $($arg: PyObject),*) -> PyResult<PyVector> {
            eval_func(py, $name_str, &[$(&$arg),*])
        }
    };
    ($name: ident, $name_str: tt, AGG[$($arg:ident),*]) => {
        #[pyfunction]
        fn $name(py: Python<'_>, $($arg: &PyVector),*) -> PyResult<PyObject> {
            eval_aggr_func(py, $name_str, &[$($arg),*])
        }
    };
}

// TODO(discord9): more aggregate functions & allow threads
simple_vector_fn!(pow, "pow", [v0, v1]);
simple_vector_fn!(clip, "clip", [v0, v1, v2]);
simple_vector_fn!(diff, "diff", AGG[v0]);
simple_vector_fn!(mean, "mean", AGG[v0]);
simple_vector_fn!(polyval, "polyval", AGG[v0, v1]);
simple_vector_fn!(argmax, "argmax", AGG[v0]);
simple_vector_fn!(argmin, "argmin", AGG[v0]);
simple_vector_fn!(percentile, "percentile", AGG[v0, v1]);
simple_vector_fn!(scipy_stats_norm_cdf, "scipystatsnormcdf", AGG[v0, v1]);
simple_vector_fn!(scipy_stats_norm_pdf, "scipystatsnormpdf", AGG[v0, v1]);

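// A sketch (hand-expanded here, not compiler output) of what the first
// `simple_vector_fn!` invocation above expands to:
//
// #[pyfunction]
// fn pow(py: Python<'_>, v0: PyObject, v1: PyObject) -> PyResult<PyVector> {
//     eval_func(py, "pow", &[&v0, &v1])
// }
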
/*
This macro basically expands to the code below:
```rust
fn sqrt(py: Python<'_>, val: PyObject) -> PyResult<PyObject> {
    let args = &[all_to_f64(try_into_columnar_value(py, val)?).map_err(PyValueError::new_err)?];
    let res = math_expressions::sqrt(args).map_err(|e| PyValueError::new_err(format!("{e:?}")))?;
    columnar_value_to_py_any(py, res)
}
```
*/
bind_call_unary_math_function!(
    sqrt, sin, cos, tan, asin, acos, atan, floor, ceil, abs, signum, exp, ln, log2,
    log10 // trunc,
);

/// return a random vector with values ranging from 0 to 1 and length `len`
#[pyfunction]
fn random(py: Python<'_>, len: usize) -> PyResult<PyObject> {
    // This is in a proc macro so using the full path to avoid strange things
    // more info at: https://doc.rust-lang.org/reference/procedural-macros.html#procedural-macro-hygiene
    let arg = NullArray::new(len);
    let args = &[ColumnarValue::Array(std::sync::Arc::new(arg) as _)];
    let res = datafusion_functions::math::random()
        .invoke(args)
        .map_err(|e| PyValueError::new_err(format!("{e:?}")))?;
    columnar_value_to_py_any(py, res)
}

#[pyfunction]
fn round(py: Python<'_>, val: PyObject) -> PyResult<PyObject> {
    let value = try_into_columnar_value(py, val)?;
    let result = datafusion_functions::math::round()
        .invoke(&[value])
        .and_then(|x| x.into_array(1))
        .map_err(|e| PyValueError::new_err(format!("{e:?}")))?;
    columnar_value_to_py_any(py, ColumnarValue::Array(result))
}

/// The macro for binding functions in `datafusion_physical_expr::expressions` (most of them are aggregate functions)
macro_rules! bind_aggr_expr {
    ($FUNC_NAME:ident, $AGGR_FUNC: ident, [$($ARG: ident),*], $ARG_TY: ident, $($EXPR:ident => $idx: literal),*) => {
        #[pyfunction]
        fn $FUNC_NAME(py: Python<'_>, $($ARG: &PyVector),*) -> PyResult<PyObject> {
            // just a placeholder, we only want the inner `XXXAccumulator`'s function,
            // so its expr is irrelevant
            return eval_df_aggr_expr(
                py,
                expressions::$AGGR_FUNC::new(
                    $(
                        Arc::new(expressions::Column::new(stringify!($EXPR), $idx)) as _,
                    )*
                    stringify!($AGGR_FUNC),
                    $ARG_TY.arrow_data_type().to_owned()),
                &[$($ARG.to_arrow_array()),*]
            )
        }
    };
}
/*
`bind_aggr_expr!(approx_distinct, ApproxDistinct,[v0], v0, expr0=>0);`
expands into:
```
fn approx_distinct(py: Python<'_>, v0: &PyVector) -> PyResult<PyObject> {
    return eval_df_aggr_expr(
        py,
        expressions::ApproxDistinct::new(
            Arc::new(expressions::Column::new("expr0", 0)) as _,
            "ApproxDistinct",
            v0.arrow_data_type().to_owned(),
        ),
        &[v0.to_arrow_array()],
    );
}
```

*/
bind_aggr_expr!(approx_distinct, ApproxDistinct,[v0], v0, expr0=>0);

// bind_aggr_expr!(median, Median,[v0], v0, expr0=>0);

#[pyfunction]
fn approx_percentile_cont(py: Python<'_>, values: &PyVector, percent: f64) -> PyResult<PyObject> {
    let percent = expressions::Literal::new(datafusion_common::ScalarValue::Float64(Some(percent)));
    eval_df_aggr_expr(
        py,
        expressions::ApproxPercentileCont::new(
            vec![
                Arc::new(expressions::Column::new("expr0", 0)) as _,
                Arc::new(percent) as _,
            ],
            "ApproxPercentileCont",
            values.arrow_data_type(),
        )
        .map_err(|e| PyValueError::new_err(format!("{e:?}")))?,
        &[values.to_arrow_array()],
    )
}

#[pyfunction]
fn array_agg(py: Python<'_>, v: &PyVector) -> PyResult<PyObject> {
    eval_df_aggr_expr(
        py,
        expressions::ArrayAgg::new(
            Arc::new(expressions::Column::new("expr0", 0)) as _,
            "ArrayAgg",
            v.arrow_data_type(),
            true,
        ),
        &[v.to_arrow_array()],
    )
}

bind_aggr_expr!(avg, Avg,[v0], v0, expr0=>0);

bind_aggr_expr!(correlation, Correlation,[v0, v1], v0, expr0=>0, expr1=>1);

bind_aggr_expr!(count, Count,[v0], v0, expr0=>0);

// bind_aggr_expr!(covariance, Covariance,[v0, v1], v0, expr0=>0, expr1=>1);

// bind_aggr_expr!(covariance_pop, CovariancePop,[v0, v1], v0, expr0=>0, expr1=>1);

bind_aggr_expr!(max, Max,[v0], v0, expr0=>0);

bind_aggr_expr!(min, Min,[v0], v0, expr0=>0);

bind_aggr_expr!(stddev, Stddev,[v0], v0, expr0=>0);

bind_aggr_expr!(stddev_pop, StddevPop,[v0], v0, expr0=>0);

bind_aggr_expr!(sum, Sum,[v0], v0, expr0=>0);

bind_aggr_expr!(variance, Variance,[v0], v0, expr0=>0);

bind_aggr_expr!(variance_pop, VariancePop,[v0], v0, expr0=>0);
@@ -1,354 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;

use arrow::compute;
use common_recordbatch::RecordBatch;
use datafusion_common::ScalarValue;
use datatypes::prelude::ConcreteDataType;
use datatypes::vectors::{Helper, VectorRef};
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::types::{PyBool, PyDict, PyFloat, PyInt, PyList, PyModule, PyString, PyTuple};
use pyo3::{pymethods, IntoPy, PyAny, PyCell, PyObject, PyResult, Python, ToPyObject};
use snafu::{ensure, ResultExt};

use crate::engine::EvalContext;
use crate::python::error::{self, NewRecordBatchSnafu, OtherSnafu, Result};
use crate::python::ffi_types::copr::PyQueryEngine;
use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
use crate::python::ffi_types::{check_args_anno_real_type, select_from_rb, Coprocessor, PyVector};
use crate::python::metric;
use crate::python::pyo3::dataframe_impl::PyDataFrame;
use crate::python::pyo3::utils::{init_cpython_interpreter, pyo3_obj_try_to_typed_val};

#[pymethods]
impl PyQueryEngine {
    #[pyo3(name = "sql")]
    pub(crate) fn sql_pyo3(&self, py: Python<'_>, s: String) -> PyResult<PyObject> {
        let res = self
            .query_with_new_thread(s.clone())
            .map_err(PyValueError::new_err)?;
        match res {
            crate::python::ffi_types::copr::Either::Rb(rbs) => {
                let rb = compute::concat_batches(
                    rbs.schema().arrow_schema(),
                    rbs.iter().map(|rb| rb.df_record_batch()),
                )
                .map_err(|e| PyRuntimeError::new_err(format!("{e:?}")))?;
                let rb = RecordBatch::try_from_df_record_batch(rbs.schema(), rb).map_err(|e| {
                    PyRuntimeError::new_err(format!(
                        "Convert datafusion record batch to record batch failed for query {s}: {e}"
                    ))
                })?;
                let rb = PyRecordBatch::new(rb);
                Ok(rb.into_py(py))
            }
            crate::python::ffi_types::copr::Either::AffectedRows(count) => Ok(count.to_object(py)),
        }
    }
    // TODO: put this into the greptime module
}

/// Execute a `Coprocessor` with the given `RecordBatch`
pub(crate) fn pyo3_exec_parsed(
    copr: &Coprocessor,
    rb: &Option<RecordBatch>,
    params: &HashMap<String, String>,
    eval_ctx: &EvalContext,
) -> Result<RecordBatch> {
    let _t = metric::METRIC_PYO3_EXEC_TOTAL_ELAPSED.start_timer();
    // i.e. params, or use `vector(..)` to construct a PyVector
    let arg_names = &copr.deco_args.arg_names.clone().unwrap_or_default();
    let args: Vec<PyVector> = if let Some(rb) = rb {
        let args = select_from_rb(rb, arg_names)?;
        check_args_anno_real_type(arg_names, &args, copr, rb)?;
        args
    } else {
        Vec::new()
    };
    // Just in case cpython is not inited
    init_cpython_interpreter().unwrap();
    Python::with_gil(|py| -> Result<_> {
        let _t = metric::METRIC_PYO3_EXEC_ELAPSED.start_timer();

        let mut cols = (|| -> PyResult<_> {
            let dummy_decorator = "
# Postponed evaluation of annotations (PEP 563) so annotations can be set freely
# This is needed for Python < 3.9
from __future__ import annotations
# A dummy decorator, the actual implementation is in Rust code
def copr(*dummy, **kwdummy):
    def inner(func):
        return func
    return inner
coprocessor = copr
";
            let gen_call = format!("\n_return_from_coprocessor = {}(*_args_for_coprocessor, **_kwargs_for_coprocessor)", copr.name);
            let script = format!("{}{}{}", dummy_decorator, copr.script, gen_call);
            let args = args
                .clone()
                .into_iter()
                .map(|v| PyCell::new(py, v))
                .collect::<PyResult<Vec<_>>>()?;
            let args = PyTuple::new(py, args);

            let kwargs = PyDict::new(py);
            if let Some(_copr_kwargs) = &copr.kwarg {
                for (k, v) in params {
                    kwargs.set_item(k, v)?;
                }
            }

            let py_main = PyModule::import(py, "__main__")?;
            let globals = py_main.dict();

            let locals = py_main.dict();

            if let Some(engine) = &copr.query_engine {
                let query_engine = PyQueryEngine::from_weakref(engine.clone(), eval_ctx.query_ctx.clone());
                let query_engine = PyCell::new(py, query_engine)?;
                globals.set_item("__query__", query_engine)?;
            }

            // TODO(discord9): find out why `dataframe` is not in scope
            if let Some(rb) = rb {
                let dataframe = PyDataFrame::from_record_batch(rb.df_record_batch())
                    .map_err(|err| {
                        PyValueError::new_err(format!(
                            "Can't create dataframe from record batch: {}",
                            err
                        ))
                    })?;
                let dataframe = PyCell::new(py, dataframe)?;
                globals.set_item("__dataframe__", dataframe)?;
            }

            locals.set_item("_args_for_coprocessor", args)?;
            locals.set_item("_kwargs_for_coprocessor", kwargs)?;
            // `greptime` is already imported when the interpreter is initialized, so no need to set it here

            // TODO(discord9): find a better way to set `dataframe` and `query` in scope, or set them into a module (the latter might be impossible and not idiomatic even in Python)
            // set `dataframe` and `query` in scope, or set them into a module
            // could generate a call in Python code and use Python::run to run it, just like in RustPython
            // Expect either a PyVector or a List/Tuple of PyVector
            py.run(&script, Some(globals), Some(locals))?;
            let result = locals.get_item("_return_from_coprocessor")?.ok_or_else(|| {
                PyValueError::new_err(format!("cannot find the return value of script '{script}'"))
            })?;

            let col_len = rb.as_ref().map(|rb| rb.num_rows()).unwrap_or(1);
            py_any_to_vec(result, col_len)
        })()
        .map_err(|err| error::Error::PyRuntime {
            msg: err.into_value(py).to_string(),
            location: snafu::location!(),
        })?;
        ensure!(
            cols.len() == copr.deco_args.ret_names.len(),
            OtherSnafu {
                reason: format!(
                    "The number of returned vectors is wrong, expect {}, found {}",
                    copr.deco_args.ret_names.len(),
                    cols.len()
                )
            }
        );
        copr.check_and_cast_type(&mut cols)?;
        let schema = copr.gen_schema(&cols)?;
        RecordBatch::new(schema, cols).context(NewRecordBatchSnafu)
    })
}
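
// To make the script assembly above concrete, a hand-assembled sketch (not
// captured output) of what `script` looks like for a coprocessor named `a`:
// the dummy-decorator prelude, then the user script verbatim, then the
// generated call.
//
// from __future__ import annotations
// def copr(*dummy, **kwdummy):
//     def inner(func):
//         return func
//     return inner
// coprocessor = copr
// # ... user script defining `a(cpu, mem, **kwargs)` goes here ...
// _return_from_coprocessor = a(*_args_for_coprocessor, **_kwargs_for_coprocessor)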

/// Cast the return value of the Python script to `Vec<VectorRef>`;
/// constants will be broadcast to length `col_len`.
/// Accepts and converts `obj` if it is one of the following:
/// 1. a tuple of PyVectors/PyLists of literals/single literals of the same type,
///    or a mixed tuple of PyVectors and PyLists of same-type literals
/// 2. a single PyVector
/// 3. a PyList of same-type literals
/// 4. a single constant, which will be expanded to a PyVector of length `col_len`
fn py_any_to_vec(obj: &PyAny, col_len: usize) -> PyResult<Vec<VectorRef>> {
    let is_literal = |obj: &PyAny| -> PyResult<bool> {
        Ok(obj.is_instance_of::<PyInt>()
            || obj.is_instance_of::<PyFloat>()
            || obj.is_instance_of::<PyString>()
            || obj.is_instance_of::<PyBool>())
    };
    let check = if obj.is_instance_of::<PyTuple>() {
        let tuple = obj.downcast::<PyTuple>()?;
        (0..tuple.len())
            .map(|idx| {
                tuple.get_item(idx).map(|i| -> PyResult<bool> {
                    Ok(i.is_instance_of::<PyVector>()
                        || i.is_instance_of::<PyList>()
                        || is_literal(i)?)
                })
            })
            .all(|i| matches!(i, Ok(Ok(true))))
    } else {
        obj.is_instance_of::<PyVector>() || obj.is_instance_of::<PyList>() || is_literal(obj)?
    };
    if !check {
        return Err(PyRuntimeError::new_err(format!(
            "Expect a tuple of vectors (or lists), one single vector, or a list of same-type literals, found {obj}"
        )));
    }

    if let Ok(tuple) = obj.downcast::<PyTuple>() {
        let len = tuple.len();
        let v = (0..len)
            .map(|idx| tuple.get_item(idx))
            .map(|elem| {
                elem.map(|any| {
                    if let Ok(list) = any.downcast::<PyList>() {
                        py_list_to_vec(list)
                    } else {
                        py_obj_broadcast_to_vec(any, col_len)
                    }
                })
                .and_then(|v| v)
            })
            .collect::<PyResult<Vec<_>>>()?;
        Ok(v)
    } else if let Ok(list) = obj.downcast::<PyList>() {
        let ret = py_list_to_vec(list)?;
        Ok(vec![ret])
    } else {
        let ret = py_obj_broadcast_to_vec(obj, col_len)?;
        Ok(vec![ret])
    }
}

/// Convert a Python list to a [`VectorRef`], with all elements of the same type: bool/int/float/string
fn py_list_to_vec(list: &PyList) -> PyResult<VectorRef> {
    /// make sure elements of the list are all of the same type: bool/int/float/string
    #[derive(PartialEq, Eq, Debug, Copy, Clone)]
    enum ExpectType {
        Bool,
        Int,
        Float,
        String,
    }
    let mut expected_type = None;
    let mut v = Vec::with_capacity(list.len());
    for (idx, elem) in list.iter().enumerate() {
        let (elem_ty, con_type) = if elem.is_instance_of::<PyBool>() {
            (ExpectType::Bool, ConcreteDataType::boolean_datatype())
        } else if elem.is_instance_of::<PyInt>() {
            (ExpectType::Int, ConcreteDataType::int64_datatype())
        } else if elem.is_instance_of::<PyFloat>() {
            (ExpectType::Float, ConcreteDataType::float64_datatype())
        } else if elem.is_instance_of::<PyString>() {
            (ExpectType::String, ConcreteDataType::string_datatype())
        } else {
            return Err(PyRuntimeError::new_err(format!(
                "Expect list to contain bool or int or float or string, found <{list}>"
            )));
        };
        if let Some(ty) = expected_type {
            if ty != elem_ty {
                return Err(PyRuntimeError::new_err(format!(
                    "Expect a list of same-type elements, found {list} in position {idx} in list"
                )));
            }
        } else {
            expected_type = Some(elem_ty);
        }
        // push into a vector buffer
        let val = pyo3_obj_try_to_typed_val(elem, Some(con_type))?;
        let scalar = val.try_to_scalar_value(&val.data_type()).map_err(|err| {
            PyRuntimeError::new_err(format!("Can't convert value to scalar value: {}", err))
        })?;
        v.push(scalar);
    }
    let array = ScalarValue::iter_to_array(v).map_err(|err| {
        PyRuntimeError::new_err(format!("Can't convert scalar value list to array: {}", err))
    })?;
    let ret = Helper::try_into_vector(array).map_err(|err| {
        PyRuntimeError::new_err(format!("Can't convert array to vector: {}", err))
    })?;
    Ok(ret)
}

/// broadcast a single Python object to a vector of the same object with length `col_len`;
/// obj is either:
/// 1. a PyVector
/// 2. a single literal
fn py_obj_broadcast_to_vec(obj: &PyAny, col_len: usize) -> PyResult<VectorRef> {
    if let Ok(v) = obj.extract::<PyVector>() {
        Ok(v.as_vector_ref())
    } else {
        let val = pyo3_obj_try_to_typed_val(obj, None)?;
        let handler = |e: datatypes::Error| PyValueError::new_err(e.to_string());
        let v = Helper::try_from_scalar_value(
            val.try_to_scalar_value(&val.data_type()).map_err(handler)?,
            col_len,
        )
        .map_err(handler)?;
        Ok(v)
    }
}

#[cfg(test)]
mod copr_test {
    use std::collections::HashMap;
    use std::sync::Arc;

    use common_recordbatch::RecordBatch;
    use datatypes::data_type::ConcreteDataType;
    use datatypes::schema::{ColumnSchema, Schema};
    use datatypes::vectors::{Float32Vector, Float64Vector, VectorRef};

    use super::*;
    use crate::python::ffi_types::copr::{exec_parsed, parse, BackendType};

    #[test]
    #[allow(unused_must_use)]
    fn simple_test_pyo3_copr() {
        let python_source = r#"
@copr(args=["cpu", "mem"], returns=["ref"], backend="pyo3")
def a(cpu, mem, **kwargs):
    import greptime as gt
    from greptime import vector, log2, sum, pow, col, lit, dataframe
    for k, v in kwargs.items():
        print("%s == %s" % (k, v))
    print(dataframe().select([col("cpu")<lit(0.3)]).collect())
    return (0.5 < cpu) & ~(cpu >= 0.75)
"#;
        let cpu_array = Float32Vector::from_slice([0.9f32, 0.8, 0.7, 0.3]);
        let mem_array = Float64Vector::from_slice([0.1f64, 0.2, 0.3, 0.4]);
        let schema = Arc::new(Schema::new(vec![
            ColumnSchema::new("cpu", ConcreteDataType::float32_datatype(), false),
            ColumnSchema::new("mem", ConcreteDataType::float64_datatype(), false),
        ]));
        let rb = RecordBatch::new(
            schema,
            [
                Arc::new(cpu_array) as VectorRef,
                Arc::new(mem_array) as VectorRef,
            ],
        )
        .unwrap();
        let copr = parse::parse_and_compile_copr(python_source, None).unwrap();
        assert_eq!(copr.backend, BackendType::CPython);
        let ret = exec_parsed(
            &copr,
            &Some(rb),
            &HashMap::from([("a".to_string(), "1".to_string())]),
            &EvalContext::default(),
        );
        let _ = ret.unwrap();
    }
}
@@ -1,309 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::ops::Not;

use arrow::compute;
use common_recordbatch::{DfRecordBatch, RecordBatch};
use datafusion::dataframe::DataFrame as DfDataFrame;
use datafusion_expr::Expr as DfExpr;
use datatypes::schema::Schema;
use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::types::{PyDict, PyType};
use snafu::ResultExt;

use crate::python::error::DataFusionSnafu;
use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
use crate::python::pyo3::builtins::query_engine;
use crate::python::pyo3::utils::pyo3_obj_try_to_typed_scalar_value;
use crate::python::utils::block_on_async;
type PyExprRef = Py<PyExpr>;

#[derive(Debug, Clone)]
#[pyclass]
pub(crate) struct PyDataFrame {
    inner: DfDataFrame,
}

impl From<DfDataFrame> for PyDataFrame {
    fn from(inner: DfDataFrame) -> Self {
        Self { inner }
    }
}

impl PyDataFrame {
    pub(crate) fn from_record_batch(rb: &DfRecordBatch) -> crate::python::error::Result<Self> {
        let ctx = datafusion::execution::context::SessionContext::new();
        let inner = ctx.read_batch(rb.clone()).context(DataFusionSnafu)?;
        Ok(Self { inner })
    }
}

#[pymethods]
impl PyDataFrame {
    #[classmethod]
    fn from_sql(_cls: &PyType, py: Python, sql: String) -> PyResult<Self> {
        let query = query_engine(py)?;
        let rb = query.sql_to_rb(sql).map_err(PyRuntimeError::new_err)?;
        let ctx = datafusion::execution::context::SessionContext::new();
        ctx.read_batch(rb.df_record_batch().clone())
            .map_err(|e| PyRuntimeError::new_err(format!("{e:?}")))
            .map(Self::from)
    }
    fn __call__(&self) -> PyResult<Self> {
        Ok(self.clone())
    }
    fn select_columns(&self, columns: Vec<String>) -> PyResult<Self> {
        Ok(self
            .inner
            .clone()
            .select_columns(&columns.iter().map(AsRef::as_ref).collect::<Vec<&str>>())
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .into())
    }
    fn select(&self, py: Python<'_>, expr_list: Vec<PyExprRef>) -> PyResult<Self> {
        Ok(self
            .inner
            .clone()
            .select(
                expr_list
                    .iter()
                    .map(|e| e.borrow(py).inner.clone())
                    .collect(),
            )
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .into())
    }
    fn filter(&self, predicate: &PyExpr) -> PyResult<Self> {
        Ok(self
            .inner
            .clone()
            .filter(predicate.inner.clone())
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .into())
    }
    fn aggregate(
        &self,
        py: Python<'_>,
        group_expr: Vec<PyExprRef>,
        aggr_expr: Vec<PyExprRef>,
    ) -> PyResult<Self> {
        let ret = self.inner.clone().aggregate(
            group_expr
                .iter()
                .map(|i| i.borrow(py).inner.clone())
                .collect(),
            aggr_expr
                .iter()
                .map(|i| i.borrow(py).inner.clone())
                .collect(),
        );
        Ok(ret
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .into())
    }
    fn limit(&self, skip: usize, fetch: Option<usize>) -> PyResult<Self> {
        Ok(self
            .inner
            .clone()
            .limit(skip, fetch)
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .into())
    }
    fn union(&self, df: &PyDataFrame) -> PyResult<Self> {
        Ok(self
            .inner
            .clone()
            .union(df.inner.clone())
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .into())
    }
    fn union_distinct(&self, df: &PyDataFrame) -> PyResult<Self> {
        Ok(self
            .inner
            .clone()
            .union_distinct(df.inner.clone())
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .into())
    }
    fn distinct(&self) -> PyResult<Self> {
        Ok(self
            .inner
            .clone()
            .distinct()
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .into())
    }
    fn sort(&self, py: Python<'_>, expr: Vec<PyExprRef>) -> PyResult<Self> {
        Ok(self
            .inner
            .clone()
            .sort(expr.iter().map(|e| e.borrow(py).inner.clone()).collect())
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .into())
    }
    fn join(
        &self,
        py: Python<'_>,
        right: &PyDataFrame,
        join_type: String,
        left_cols: Vec<String>,
        right_cols: Vec<String>,
        filter: Option<PyExprRef>,
    ) -> PyResult<Self> {
        use datafusion::prelude::JoinType;
        let join_type = match join_type.as_str() {
            "inner" | "Inner" => JoinType::Inner,
            "left" | "Left" => JoinType::Left,
            "right" | "Right" => JoinType::Right,
            "full" | "Full" => JoinType::Full,
            "leftSemi" | "LeftSemi" => JoinType::LeftSemi,
            "rightSemi" | "RightSemi" => JoinType::RightSemi,
            "leftAnti" | "LeftAnti" => JoinType::LeftAnti,
            "rightAnti" | "RightAnti" => JoinType::RightAnti,
            _ => {
                return Err(PyValueError::new_err(format!(
                    "Unknown join type: {join_type}"
                )))
            }
        };
        let left_cols: Vec<&str> = left_cols.iter().map(AsRef::as_ref).collect();
        let right_cols: Vec<&str> = right_cols.iter().map(AsRef::as_ref).collect();
        let filter = filter.map(|f| f.borrow(py).inner.clone());
        Ok(self
            .inner
            .clone()
            .join(
                right.inner.clone(),
                join_type,
                &left_cols,
                &right_cols,
                filter,
            )
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .into())
    }
    fn intersect(&self, py: Python<'_>, df: &PyDataFrame) -> PyResult<Self> {
        py.allow_threads(|| {
            Ok(self
                .inner
                .clone()
                .intersect(df.inner.clone())
                .map_err(|e| PyValueError::new_err(e.to_string()))?
                .into())
        })
    }
    fn except(&self, df: &PyDataFrame) -> PyResult<Self> {
        Ok(self
            .inner
            .clone()
            .except(df.inner.clone())
            .map_err(|e| PyValueError::new_err(e.to_string()))?
            .into())
    }
    /// collect `DataFrame` results into a `PyRecordBatch` that impls the Mapping Protocol
    fn collect(&self, py: Python) -> PyResult<PyObject> {
        let inner = self.inner.clone();
        let res = block_on_async(async { inner.collect().await });
        let res = res
            .map_err(|e| PyValueError::new_err(format!("{e:?}")))?
            .map_err(|e| PyValueError::new_err(e.to_string()))?;
        if res.is_empty() {
            return Ok(PyDict::new(py).into());
        }
        let concat_rb = compute::concat_batches(&res[0].schema(), res.iter()).map_err(|e| {
            PyRuntimeError::new_err(format!("Concat batches failed for dataframe {self:?}: {e}"))
        })?;

        let schema = Schema::try_from(concat_rb.schema()).map_err(|e| {
            PyRuntimeError::new_err(format!(
                "Convert to Schema failed for dataframe {self:?}: {e}"
            ))
        })?;
        let rb = RecordBatch::try_from_df_record_batch(schema.into(), concat_rb).map_err(|e| {
            PyRuntimeError::new_err(format!(
                "Convert to RecordBatch failed for dataframe {self:?}: {e}"
            ))
        })?;
        let rb = PyRecordBatch::new(rb);
        Ok(rb.into_py(py))
    }
}

/// Convert a Python object into an `Expr` for use in constructing literals, e.g. `col("number") < lit(42)`
#[pyfunction]
pub(crate) fn lit(py: Python<'_>, value: PyObject) -> PyResult<PyExpr> {
    let value = pyo3_obj_try_to_typed_scalar_value(value.as_ref(py), None)?;
    let expr: PyExpr = DfExpr::Literal(value).into();
    Ok(expr)
}

#[derive(Clone)]
#[pyclass]
pub(crate) struct PyExpr {
    inner: DfExpr,
}

impl From<datafusion_expr::Expr> for PyExpr {
    fn from(value: DfExpr) -> Self {
        Self { inner: value }
    }
}

#[pyfunction]
pub(crate) fn col(name: String) -> PyExpr {
    let expr: PyExpr = DfExpr::Column(datafusion_common::Column::from_name(name)).into();
    expr
}

#[pymethods]
impl PyExpr {
    fn __call__(&self) -> PyResult<Self> {
        Ok(self.clone())
    }
    fn __richcmp__(&self, py: Python<'_>, other: PyObject, op: CompareOp) -> PyResult<Self> {
        let other = other.extract::<Self>(py).or_else(|_| lit(py, other))?;
        let op = match op {
            CompareOp::Lt => DfExpr::lt,
            CompareOp::Le => DfExpr::lt_eq,
            CompareOp::Eq => DfExpr::eq,
            CompareOp::Ne => DfExpr::not_eq,
            CompareOp::Gt => DfExpr::gt,
            CompareOp::Ge => DfExpr::gt_eq,
        };
        py.allow_threads(|| Ok(op(self.inner.clone(), other.inner.clone()).into()))
    }
    fn alias(&self, name: String) -> PyResult<PyExpr> {
        Ok(self.inner.clone().alias(name).into())
    }
    fn __and__(&self, py: Python<'_>, other: PyExprRef) -> PyResult<PyExpr> {
        let other = other.borrow(py).inner.clone();
        py.allow_threads(|| Ok(self.inner.clone().and(other).into()))
    }
    fn __or__(&self, py: Python<'_>, other: PyExprRef) -> PyResult<PyExpr> {
        let other = other.borrow(py).inner.clone();
        py.allow_threads(|| Ok(self.inner.clone().or(other).into()))
    }
    fn __invert__(&self) -> PyResult<PyExpr> {
        Ok(self.inner.clone().not().into())
    }
    fn sort(&self, asc: bool, nulls_first: bool) -> PyExpr {
        self.inner.clone().sort(asc, nulls_first).into()
    }
    fn __repr__(&self) -> String {
        format!("{:#?}", &self.inner)
    }
}
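
For reference, a minimal sketch (using DataFusion's own `col`/`lit` builders, which the bindings above wrap) of the expression that the Python snippet `col("cpu") < lit(0.3)` from the coprocessor test resolves to: `__richcmp__` maps `CompareOp::Lt` to `DfExpr::lt`, so both sides build the same plan node.

use datafusion_expr::{col, lit, Expr};

// The DataFusion-level equivalent of the Python-side `col("cpu") < lit(0.3)`
// after __richcmp__ dispatch.
fn cpu_predicate() -> Expr {
    col("cpu").lt(lit(0.3))
}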
@@ -1,311 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Mutex;

use arrow::pyarrow::PyArrowException;
use common_telemetry::info;
use datafusion_common::ScalarValue;
use datafusion_expr::ColumnarValue;
use datatypes::arrow::datatypes::DataType as ArrowDataType;
use datatypes::prelude::ConcreteDataType;
use datatypes::value::{OrderedFloat, Value};
use datatypes::vectors::Helper;
use once_cell::sync::Lazy;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyFloat, PyInt, PyList, PyTuple};

use crate::python::ffi_types::utils::collect_diff_types_string;
use crate::python::ffi_types::PyVector;
use crate::python::metric;
use crate::python::pyo3::builtins::greptime_builtins;

/// prevent race conditions when initializing CPython
static START_PYO3: Lazy<Mutex<bool>> = Lazy::new(|| Mutex::new(false));
pub(crate) fn to_py_err(err: impl ToString) -> PyErr {
    PyArrowException::new_err(err.to_string())
}

/// init the CPython interpreter with the `greptime` builtins; if already inited, do nothing
pub(crate) fn init_cpython_interpreter() -> PyResult<()> {
    let _t = metric::METRIC_PYO3_INIT_ELAPSED.start_timer();
    let mut start = START_PYO3.lock().unwrap();
    if !*start {
        pyo3::append_to_inittab!(greptime_builtins);
        pyo3::prepare_freethreaded_python();
        let version = Python::with_gil(|py| -> PyResult<String> {
            let builtins = PyModule::import(py, "sys")?;
            builtins.getattr("version")?.extract()
        })?;
        *start = true;
        info!("Started CPython Interpreter {version}");
    }
    Ok(())
}

pub fn val_to_py_any(py: Python<'_>, val: Value) -> PyResult<PyObject> {
    Ok(match val {
        Value::Null => py.None(),
        Value::Boolean(val) => val.to_object(py),
        Value::UInt8(val) => val.to_object(py),
        Value::UInt16(val) => val.to_object(py),
        Value::UInt32(val) => val.to_object(py),
        Value::UInt64(val) => val.to_object(py),
        Value::Int8(val) => val.to_object(py),
        Value::Int16(val) => val.to_object(py),
        Value::Int32(val) => val.to_object(py),
        Value::Int64(val) => val.to_object(py),
        Value::Float32(val) => val.0.to_object(py),
        Value::Float64(val) => val.0.to_object(py),
        Value::String(val) => val.as_utf8().to_object(py),
        Value::Binary(val) => val.to_object(py),
        Value::Date(val) => val.val().to_object(py),
        Value::DateTime(val) => val.val().to_object(py),
        Value::Timestamp(val) => val.value().to_object(py),
        Value::List(val) => {
            let list = val
                .items()
                .iter()
                .cloned()
                .map(|v| val_to_py_any(py, v))
                .collect::<PyResult<Vec<_>>>()?;
            list.to_object(py)
        }
        #[allow(unreachable_patterns)]
        _ => {
            return Err(PyValueError::new_err(format!(
                "Convert from {val:?} is not supported yet"
            )))
        }
    })
}

macro_rules! to_con_type {
    ($dtype:ident,$obj:ident, $($cty:ident => $rty:ty),*$(,)?) => {
        match $dtype {
            $(
                ConcreteDataType::$cty(_) => $obj.extract::<$rty>().map(Value::$cty),
            )*
            _ => unreachable!(),
        }
    };
    ($dtype:ident,$obj:ident, $($cty:ident =ord=> $rty:ty),*$(,)?) => {
        match $dtype {
            $(
                ConcreteDataType::$cty(_) => $obj.extract::<$rty>()
                    .map(OrderedFloat)
                    .map(Value::$cty),
            )*
            _ => unreachable!(),
        }
    };
}

/// Convert PyAny to [`ScalarValue`]
pub(crate) fn pyo3_obj_try_to_typed_scalar_value(
    obj: &PyAny,
    dtype: Option<ConcreteDataType>,
) -> PyResult<ScalarValue> {
    let val = pyo3_obj_try_to_typed_val(obj, dtype)?;
    val.try_to_scalar_value(&val.data_type())
        .map_err(|e| PyValueError::new_err(e.to_string()))
}
/// convert to int/float/boolean; if dtype is None, convert to the highest-precision type
pub(crate) fn pyo3_obj_try_to_typed_val(
    obj: &PyAny,
    dtype: Option<ConcreteDataType>,
) -> PyResult<Value> {
    if let Ok(b) = obj.downcast::<PyBool>() {
        if let Some(ConcreteDataType::Boolean(_)) = dtype {
            let dtype = ConcreteDataType::boolean_datatype();
            let ret = to_con_type!(dtype, b,
                Boolean => bool
            )?;
            Ok(ret)
        } else {
            Err(PyValueError::new_err(format!(
                "Can't cast num to {dtype:?}"
            )))
        }
    } else if let Ok(num) = obj.downcast::<PyInt>() {
        if let Some(dtype) = dtype {
            if dtype.is_signed() || dtype.is_unsigned() {
                let ret = to_con_type!(dtype, num,
                    Int8 => i8,
                    Int16 => i16,
                    Int32 => i32,
                    Int64 => i64,
                    UInt8 => u8,
                    UInt16 => u16,
                    UInt32 => u32,
                    UInt64 => u64,
                )?;
                Ok(ret)
            } else {
                Err(PyValueError::new_err(format!(
                    "Can't cast num to {dtype:?}"
                )))
            }
        } else {
            num.extract::<i64>().map(Value::Int64)
        }
    } else if let Ok(num) = obj.downcast::<PyFloat>() {
        if let Some(dtype) = dtype {
            if dtype.is_float() {
                let ret = to_con_type!(dtype, num,
                    Float32 =ord=> f32,
                    Float64 =ord=> f64,
                )?;
                Ok(ret)
            } else {
                Err(PyValueError::new_err(format!(
                    "Can't cast num to {dtype:?}"
                )))
            }
        } else {
            num.extract::<f64>()
                .map(|v| Value::Float64(OrderedFloat(v)))
        }
    } else if let Ok(s) = obj.extract::<String>() {
        Ok(Value::String(s.into()))
    } else {
        Err(PyValueError::new_err(format!(
            "Can't cast {obj} to {dtype:?}"
        )))
    }
}

/// cast a columnar value into a Python object
///
/// | Rust   | Python             |
/// | ------ | ------------------ |
/// | Array  | PyVector           |
/// | Scalar | int/float/bool/str |
pub fn columnar_value_to_py_any(py: Python<'_>, val: ColumnarValue) -> PyResult<PyObject> {
    match val {
        ColumnarValue::Array(arr) => {
            let v = PyVector::from(
                Helper::try_into_vector(arr).map_err(|e| PyValueError::new_err(format!("{e}")))?,
            );
            Ok(PyCell::new(py, v)?.into())
        }
        ColumnarValue::Scalar(scalar) => scalar_value_to_py_any(py, scalar),
    }
}

/// turn a ScalarValue into a Python object; currently supports scalars of primitive types and lists of them
pub fn scalar_value_to_py_any(py: Python<'_>, val: ScalarValue) -> PyResult<PyObject> {
    macro_rules! to_py_any {
        ($val:ident, [$($scalar_ty:ident),*]) => {
            match val {
                ScalarValue::Null => Ok(py.None()),
                $(ScalarValue::$scalar_ty(Some(v)) => Ok(v.to_object(py)),)*
                ScalarValue::List(array) => {
                    let col = ScalarValue::convert_array_to_scalar_vec(array.as_ref()).map_err(|e|
                        PyValueError::new_err(format!("{e}"))
                    )?;
                    let list: Vec<PyObject> = col
                        .into_iter()
                        .flatten()
                        .map(|v| scalar_value_to_py_any(py, v))
                        .collect::<PyResult<_>>()?;
                    let list = PyList::new(py, list);
                    Ok(list.into())
                }
                _ => Err(PyValueError::new_err(format!(
                    "Can't cast a Scalar Value `{:#?}` of type {:#?} to a Python Object",
                    $val, $val.data_type()
                )))
            }
        };
    }
    to_py_any!(
        val,
        [
            Boolean, Float32, Float64, Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64,
            Utf8, LargeUtf8
        ]
    )
}

pub fn try_into_columnar_value(py: Python<'_>, obj: PyObject) -> PyResult<ColumnarValue> {
    macro_rules! to_rust_types {
        ($obj: ident, $($ty: ty => $scalar_ty: ident),*) => {
            $(
                if let Ok(val) = $obj.extract::<$ty>(py) {
                    Ok(ColumnarValue::Scalar(ScalarValue::$scalar_ty(Some(val))))
                }
            ) else* else {
                Err(PyValueError::new_err(format!("Can't cast {} into Columnar Value", $obj)))
            }
        };
    }
    if let Ok(v) = obj.extract::<PyVector>(py) {
        Ok(ColumnarValue::Array(v.to_arrow_array()))
    } else if obj.as_ref(py).is_instance_of::<PyList>()
        || obj.as_ref(py).is_instance_of::<PyTuple>()
    {
        let ret: Vec<ScalarValue> = {
            if let Ok(val) = obj.downcast::<PyList>(py) {
                val.iter().map(|v| -> PyResult<ScalarValue> {
                    let val = try_into_columnar_value(py, v.into())?;
                    match val {
                        ColumnarValue::Array(arr) => Err(PyValueError::new_err(format!(
                            "Expect only scalar values in a list, found a vector of type {:?} nested in the list", arr.data_type()
                        ))),
                        ColumnarValue::Scalar(val) => Ok(val),
                    }
                }).collect::<PyResult<_>>()?
            } else if let Ok(val) = obj.downcast::<PyTuple>(py) {
                val.iter().map(|v| -> PyResult<ScalarValue> {
                    let val = try_into_columnar_value(py, v.into())?;
                    match val {
                        ColumnarValue::Array(arr) => Err(PyValueError::new_err(format!(
                            "Expect only scalar values in a tuple, found a vector of type {:?} nested in the tuple", arr.data_type()
                        ))),
                        ColumnarValue::Scalar(val) => Ok(val),
                    }
                }).collect::<PyResult<_>>()?
            } else {
                unreachable!()
            }
        };

        if ret.is_empty() {
            return Ok(ColumnarValue::Scalar(ScalarValue::List(
                ScalarValue::new_list(&[], &ArrowDataType::Null),
            )));
        }
        let ty = ret[0].data_type();

        if ret.iter().any(|i| i.data_type() != ty) {
            return Err(PyValueError::new_err(format!(
                "All elements in a list should be of the same type to cast to a DataFusion list!\nExpect {ty:?}, found {}",
                collect_diff_types_string(&ret, &ty)
            )));
        }
        Ok(ColumnarValue::Scalar(ScalarValue::List(
            ScalarValue::new_list(ret.as_slice(), &ty),
        )))
    } else {
        to_rust_types!(obj,
            bool => Boolean,
            i64 => Int64,
            f64 => Float64,
            String => Utf8
        )
    }
}
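
A quick sketch (assuming an already-initialized interpreter and the module paths of this removed crate) of the round trip these last two helpers provide between Python objects and DataFusion `ColumnarValue`s:

use pyo3::prelude::*;

fn roundtrip_demo() -> PyResult<()> {
    Python::with_gil(|py| {
        // A Python float becomes ColumnarValue::Scalar(ScalarValue::Float64(..)):
        // the PyVector and list/tuple branches don't match, so the scalar
        // fallback in `to_rust_types!` picks it up as an f64.
        let obj = 0.5f64.to_object(py);
        let columnar = try_into_columnar_value(py, obj)?;
        // ...and converting back yields a plain Python float again.
        let _obj = columnar_value_to_py_any(py, columnar)?;
        Ok(())
    })
}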
@@ -1,455 +0,0 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use arrow::array::{make_array, ArrayData, Datum};
|
||||
use arrow::compute::kernels::{cmp, numeric};
|
||||
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
|
||||
use datafusion::arrow::array::BooleanArray;
|
||||
use datafusion::arrow::compute;
|
||||
use datatypes::arrow::array::{Array, ArrayRef};
|
||||
use datatypes::arrow::datatypes::DataType as ArrowDataType;
|
||||
use datatypes::prelude::{ConcreteDataType, DataType};
|
||||
use datatypes::vectors::Helper;
|
||||
use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::pyclass::CompareOp;
|
||||
use pyo3::types::{PyBool, PyFloat, PyInt, PySequence, PySlice, PyString, PyType};
|
||||
|
||||
use super::utils::val_to_py_any;
|
||||
use crate::python::ffi_types::vector::{arrow_rtruediv, wrap_bool_result, wrap_result, PyVector};
|
||||
use crate::python::pyo3::utils::{pyo3_obj_try_to_typed_val, to_py_err};
|
||||
|
||||
macro_rules! get_con_type {
|
||||
($obj:ident, $($pyty:ident => $con_ty:ident),*$(,)?) => {
|
||||
$(
|
||||
if $obj.is_instance_of::<$pyty>() {
|
||||
Ok(ConcreteDataType::$con_ty())
|
||||
}
|
||||
) else* else{
|
||||
Err(PyValueError::new_err("Unsupported pyobject type: {obj:?}"))
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn get_py_type(obj: &PyAny) -> PyResult<ConcreteDataType> {
|
||||
// Bool need to precede Int because `PyBool` is also a instance of `PyInt`
|
||||
get_con_type!(obj,
|
||||
PyBool => boolean_datatype,
|
||||
PyInt => int64_datatype,
|
||||
PyFloat => float64_datatype,
|
||||
PyString => string_datatype
|
||||
)
|
||||
}
|
||||
|
||||
fn pyo3_is_obj_scalar(obj: &PyAny) -> bool {
|
||||
get_py_type(obj).is_ok()
|
||||
}
|
||||
|
||||
impl PyVector {
|
||||
fn pyo3_scalar_arith_op<F>(
|
||||
&self,
|
||||
py: Python<'_>,
|
||||
right: PyObject,
|
||||
target_type: Option<ArrowDataType>,
|
||||
op: F,
|
||||
) -> PyResult<Self>
|
||||
where
|
||||
F: Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, String> + Send,
|
||||
{
|
||||
let right = pyo3_obj_try_to_typed_val(right.as_ref(py), None)?;
|
||||
py.allow_threads(|| {
|
||||
self.scalar_arith_op(right, target_type, op)
|
||||
.map_err(PyValueError::new_err)
|
||||
})
|
||||
}
|
||||
fn pyo3_vector_arith_op<F>(
|
||||
&self,
|
||||
py: Python<'_>,
|
||||
right: PyObject,
|
||||
target_type: Option<ArrowDataType>,
|
||||
op: F,
|
||||
) -> PyResult<Self>
|
||||
where
|
||||
F: Fn(&dyn Datum, &dyn Datum) -> Result<ArrayRef, String> + Send,
|
||||
{
|
||||
let right = right.extract::<PyVector>(py)?;
|
||||
py.allow_threads(|| {
|
||||
self.vector_arith_op(&right, target_type, op)
|
||||
.map_err(PyValueError::new_err)
|
||||
})
|
||||
}
|
||||
}

#[pymethods]
impl PyVector {
    /// convert from numpy array to [`PyVector`]
    #[classmethod]
    fn from_numpy(cls: &PyType, py: Python<'_>, obj: PyObject) -> PyResult<PyObject> {
        let pa = py.import("pyarrow")?;
        let obj = pa.call_method1("array", (obj,))?;
        let zelf = Self::from_pyarrow(cls, py, obj.into())?;
        Ok(zelf.into_py(py))
    }

    fn numpy(&self, py: Python<'_>) -> PyResult<PyObject> {
        let pa_arrow = self.to_arrow_array().to_data().to_pyarrow(py)?;
        let ndarray = pa_arrow.call_method0(py, "to_numpy")?;
        Ok(ndarray)
    }

    /// create a `PyVector` with a `PyList` that contains only elements of the same type
    #[new]
    pub(crate) fn py_new(iterable: PyObject, py: Python<'_>) -> PyResult<Self> {
        let iterable = iterable.downcast::<PySequence>(py)?;
        let dtype = get_py_type(iterable.get_item(0)?)?;
        let mut buf = dtype.create_mutable_vector(iterable.len()?);
        for i in 0..iterable.len()? {
            let element = iterable.get_item(i)?;
            let val = pyo3_obj_try_to_typed_val(element, Some(dtype.clone()))?;
            buf.push_value_ref(val.as_value_ref());
        }
        Ok(buf.to_vector().into())
    }
    fn __richcmp__(&self, py: Python<'_>, other: PyObject, op: CompareOp) -> PyResult<Self> {
        let op_fn = match op {
            CompareOp::Lt => cmp::lt,
            CompareOp::Le => cmp::lt_eq,
            CompareOp::Eq => cmp::eq,
            CompareOp::Ne => cmp::neq,
            CompareOp::Gt => cmp::gt,
            CompareOp::Ge => cmp::gt_eq,
        };
        if pyo3_is_obj_scalar(other.as_ref(py)) {
            self.pyo3_scalar_arith_op(py, other, None, wrap_bool_result(op_fn))
        } else {
            self.pyo3_vector_arith_op(py, other, None, wrap_bool_result(op_fn))
        }
    }

    fn __add__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
        if pyo3_is_obj_scalar(other.as_ref(py)) {
            self.pyo3_scalar_arith_op(py, other, None, wrap_result(numeric::add))
        } else {
            self.pyo3_vector_arith_op(py, other, None, wrap_result(numeric::add))
        }
    }
    fn __radd__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
        self.__add__(py, other)
    }

    fn __sub__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
        if pyo3_is_obj_scalar(other.as_ref(py)) {
            self.pyo3_scalar_arith_op(py, other, None, wrap_result(numeric::sub))
        } else {
            self.pyo3_vector_arith_op(py, other, None, wrap_result(numeric::sub))
        }
    }
    fn __rsub__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
        if pyo3_is_obj_scalar(other.as_ref(py)) {
            self.pyo3_scalar_arith_op(py, other, None, wrap_result(|a, b| numeric::sub(b, a)))
        } else {
            self.pyo3_vector_arith_op(py, other, None, wrap_result(|a, b| numeric::sub(b, a)))
        }
    }
    fn __mul__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
        if pyo3_is_obj_scalar(other.as_ref(py)) {
            self.pyo3_scalar_arith_op(py, other, None, wrap_result(numeric::mul))
        } else {
            self.pyo3_vector_arith_op(py, other, None, wrap_result(numeric::mul))
        }
    }
    fn __rmul__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
        self.__mul__(py, other)
    }
    fn __truediv__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
        if pyo3_is_obj_scalar(other.as_ref(py)) {
            self.pyo3_scalar_arith_op(
                py,
                other,
                Some(ArrowDataType::Float64),
                wrap_result(numeric::div),
            )
        } else {
            self.pyo3_vector_arith_op(
                py,
                other,
                Some(ArrowDataType::Float64),
                wrap_result(numeric::div),
            )
        }
    }
    #[allow(unused)]
    fn __rtruediv__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
        if pyo3_is_obj_scalar(other.as_ref(py)) {
            self.pyo3_scalar_arith_op(py, other, Some(ArrowDataType::Float64), arrow_rtruediv)
        } else {
            self.pyo3_vector_arith_op(
                py,
                other,
                Some(ArrowDataType::Float64),
                wrap_result(|a, b| numeric::div(b, a)),
            )
        }
    }
    #[allow(unused)]
    fn __floordiv__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
        if pyo3_is_obj_scalar(other.as_ref(py)) {
            self.pyo3_scalar_arith_op(
                py,
                other,
                Some(ArrowDataType::Int64),
                wrap_result(numeric::div),
            )
        } else {
            self.pyo3_vector_arith_op(
                py,
                other,
                Some(ArrowDataType::Int64),
                wrap_result(numeric::div),
            )
        }
    }
    #[allow(unused)]
    fn __rfloordiv__(&self, py: Python<'_>, other: PyObject) -> PyResult<Self> {
        if pyo3_is_obj_scalar(other.as_ref(py)) {
            self.pyo3_scalar_arith_op(py, other, Some(ArrowDataType::Int64), arrow_rtruediv)
        } else {
            self.pyo3_vector_arith_op(
                py,
                other,
                Some(ArrowDataType::Int64),
                wrap_result(|a, b| numeric::div(b, a)),
            )
        }
    }
    fn __and__(&self, other: &Self) -> PyResult<Self> {
        Self::vector_and(self, other).map_err(PyValueError::new_err)
    }
    fn __or__(&self, other: &Self) -> PyResult<Self> {
        Self::vector_or(self, other).map_err(PyValueError::new_err)
    }
    fn __invert__(&self) -> PyResult<Self> {
        Self::vector_invert(self).map_err(PyValueError::new_err)
    }
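    // NOTE: `__and__`/`__or__`/`__invert__` give boolean vectors the usual mask
    // algebra, so a script can build a filter like `mask = ~bv1 & bv2` and then
    // apply it through `filter` or the boolean-mask path of `__getitem__` below.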

    #[pyo3(name = "concat")]
    fn pyo3_concat(&self, py: Python<'_>, other: &Self) -> PyResult<Self> {
        py.allow_threads(|| {
            let left = self.to_arrow_array();
            let right = other.to_arrow_array();

            let res = compute::concat(&[left.as_ref(), right.as_ref()]);
            let res = res.map_err(|err| PyValueError::new_err(format!("Arrow Error: {err:#?}")))?;
            let ret = Helper::try_into_vector(res.clone()).map_err(|e| {
                PyValueError::new_err(format!(
                    "Can't cast result into vector, result: {res:?}, err: {e:?}",
                ))
            })?;
            Ok(ret.into())
        })
    }

    /// takes a boolean array and filters the Array, returning elements matching the filter (i.e. where the values are true)
    #[pyo3(name = "filter")]
    fn pyo3_filter(&self, py: Python<'_>, other: &Self) -> PyResult<Self> {
        py.allow_threads(|| {
            let left = self.to_arrow_array();
            let right = other.to_arrow_array();
            if let Some(filter) = right.as_any().downcast_ref::<BooleanArray>() {
                let res = compute::filter(left.as_ref(), filter);
                let res =
                    res.map_err(|err| PyValueError::new_err(format!("Arrow Error: {err:#?}")))?;
                let ret = Helper::try_into_vector(res.clone()).map_err(|e| {
                    PyValueError::new_err(format!(
                        "Can't cast result into vector, result: {res:?}, err: {e:?}",
                    ))
                })?;
                Ok(ret.into())
            } else {
                Err(PyValueError::new_err(format!(
                    "Can't cast operand into a Boolean Array, which is {right:#?}"
                )))
            }
        })
    }
    fn __len__(&self) -> usize {
        self.len()
    }
    fn __doc__(&self) -> PyResult<String> {
        Ok("PyVector is like a Python array, a compact array of elem of same datatype, but Readonly for now".to_string())
    }
    fn __repr__(&self) -> PyResult<String> {
        Ok(format!("{self:#?}"))
    }
    /// Convert to `pyarrow`'s array
    pub(crate) fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
        self.to_arrow_array().to_data().to_pyarrow(py)
    }
    /// Convert from `pyarrow`'s array
    #[classmethod]
    pub(crate) fn from_pyarrow(_cls: &PyType, py: Python, obj: PyObject) -> PyResult<PyVector> {
        let array = make_array(ArrayData::from_pyarrow(obj.as_ref(py))?);
        let v = Helper::try_into_vector(array).map_err(to_py_err)?;
        Ok(v.into())
    }

    /// PyO3's magic method for slicing and indexing
    fn __getitem__(&self, py: Python, needle: PyObject) -> PyResult<PyObject> {
        if let Ok(needle) = needle.extract::<PyVector>(py) {
            let mask = needle.to_arrow_array();
            let mask = mask
                .as_any()
                .downcast_ref::<BooleanArray>()
                .ok_or_else(|| {
                    PyValueError::new_err(format!(
                        "A Boolean Array is requested for slicing, found {mask:?}"
                    ))
                })?;
            let result = compute::filter(&self.to_arrow_array(), mask)
                .map_err(|err| PyRuntimeError::new_err(format!("Arrow Error: {err:#?}")))?;
            let ret = Helper::try_into_vector(result.clone()).map_err(|e| {
                PyRuntimeError::new_err(format!("Can't cast result into vector, err: {e:?}"))
            })?;
            let ret = Self::from(ret).into_py(py);
            Ok(ret)
        } else if let Ok(slice) = needle.downcast::<PySlice>(py) {
            let indices = slice.indices(self.len() as _)?;
            let (start, stop, step, _slicelength) = (
                indices.start,
                indices.stop,
                indices.step,
                indices.slicelength,
            );
            if start < 0 {
                return Err(PyValueError::new_err(format!(
                    "Negative start is not supported, found {start} in {indices:?}"
                )));
            } // negative stop is supported: it means from "indices.start" back to the actual start of the vector
            let vector = self.as_vector_ref();

            let mut buf = vector
                .data_type()
                .create_mutable_vector(indices.slicelength as usize);
            let v = if indices.slicelength == 0 {
                buf.to_vector()
            } else {
                if indices.step > 0 {
                    let range = if stop == -1 {
                        start as usize..start as usize
                    } else {
                        start as usize..stop as usize
                    };
                    for i in range.step_by(step.unsigned_abs()) {
                        buf.push_value_ref(vector.get_ref(i));
                    }
                } else {
                    // if non-empty, then stop < start
                    // note: start..stop is empty if start >= stop
                    // stop >= -1
                    let range = { (stop + 1) as usize..=start as usize };
                    for i in range.rev().step_by(step.unsigned_abs()) {
                        buf.push_value_ref(vector.get_ref(i));
                    }
                }
                buf.to_vector()
            };
            let v: PyVector = v.into();
            Ok(v.into_py(py))
        } else if let Ok(index) = needle.extract::<isize>(py) {
            // deal with negative index
            let len = self.len() as isize;
            let index = if index < 0 { len + index } else { index };
            if index < 0 || index >= len {
                return Err(PyIndexError::new_err(format!(
                    "Index out of bound, index: {index}, len: {len}",
                    index = index,
                    len = len
                )));
            }
            let val = self.as_vector_ref().get(index as usize);
            val_to_py_any(py, val)
        } else {
            Err(PyValueError::new_err(format!(
                "{needle:?} is neither a Vector nor an int, can't use for slicing or indexing"
            )))
        }
    }
}
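// NOTE: a short sketch of the `__getitem__` semantics above, as seen from a
// hypothetical Python session:
//   v = vector([1, 2, 3, 4])
//   v[v > 2]     # boolean-mask path -> vector([3, 4])
//   v[1:3]       # slice path -> vector([2, 3])
//   v[3:0:-2]    # negative step walks the range in reverse -> vector([4, 2])
//   v[-1]        # index path, negative indices count from the end -> 4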

#[cfg(test)]
pub(crate) fn into_pyo3_cell(py: Python, val: PyVector) -> PyResult<&PyCell<PyVector>> {
    PyCell::new(py, val)
}

#[cfg(test)]
mod test {

    use std::collections::HashMap;
    use std::sync::Arc;

    use datatypes::scalars::ScalarVector;
    use datatypes::vectors::{BooleanVector, Float64Vector, VectorRef};
    use pyo3::types::{PyDict, PyModule};
    use pyo3::{PyCell, Python};

    use crate::python::ffi_types::vector::PyVector;
    use crate::python::pyo3::init_cpython_interpreter;
    fn sample_vector() -> HashMap<String, PyVector> {
        let mut locals = HashMap::new();
        let b = BooleanVector::from_slice(&[true, false, true, true]);
        let b: PyVector = (Arc::new(b) as VectorRef).into();
        locals.insert("bv1".to_string(), b);
        let b = BooleanVector::from_slice(&[false, false, false, true]);
        let b: PyVector = (Arc::new(b) as VectorRef).into();
        locals.insert("bv2".to_string(), b);

        let f = Float64Vector::from_slice([0.0f64, 1.0, 42.0, 3.0]);
        let f: PyVector = (Arc::new(f) as VectorRef).into();
        locals.insert("fv1".to_string(), f);
        let f = Float64Vector::from_slice([1919.810f64, 0.114, 51.4, 3.0]);
        let f: PyVector = (Arc::new(f) as VectorRef).into();
        locals.insert("fv2".to_string(), f);
        locals
    }
    #[test]
    fn test_py_vector_api() {
        init_cpython_interpreter().unwrap();
        Python::with_gil(|py| {
            let module = PyModule::new(py, "gt").unwrap();
            module.add_class::<PyVector>().unwrap();
            // Import and get sys.modules
            let sys = PyModule::import(py, "sys").unwrap();
            let py_modules: &PyDict = sys.getattr("modules").unwrap().downcast().unwrap();

            // Insert the `gt` module into sys.modules
            py_modules.set_item("gt", module).unwrap();

            let locals = PyDict::new(py);
            for (k, v) in sample_vector() {
                locals.set_item(k, PyCell::new(py, v).unwrap()).unwrap();
            }
            // ~bool_v1&bool_v2
            py.run(
                r#"
from gt import vector
print(vector([1,2]))
print(fv1+fv2)
"#,
                None,
                Some(locals),
            )
            .unwrap();
        });
    }
}
@@ -1,26 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

mod copr_impl;
#[cfg(test)]
mod test;
pub(crate) mod vector_impl;

pub(crate) mod builtins;
mod dataframe_impl;
pub(crate) mod utils;

#[cfg(test)]
pub(crate) use copr_impl::init_interpreter;
pub(crate) use copr_impl::rspy_exec_parsed;
File diff suppressed because it is too large
Load Diff
@@ -1,461 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::fs::File;
use std::io::Read;
use std::path::Path;
use std::sync::Arc;

use arrow::array::Array;
use common_telemetry::{error, info};
use datatypes::arrow::array::{Float64Array, Int64Array};
use datatypes::arrow::compute;
use datatypes::arrow::datatypes::{DataType as ArrowDataType, Field};
use datatypes::vectors::{Float64Vector, Int64Vector, VectorRef};
use ron::from_str as from_ron_string;
use rustpython_vm::builtins::{PyFloat, PyInt, PyList};
use rustpython_vm::class::PyClassImpl;
use rustpython_vm::convert::ToPyObject;
use rustpython_vm::scope::Scope;
use rustpython_vm::{AsObject, PyObjectRef, VirtualMachine};
use serde::{Deserialize, Serialize};

use super::*;
use crate::python::ffi_types::PyVector;
use crate::python::rspython::utils::is_instance;
use crate::python::utils::format_py_error;
#[test]
fn convert_scalar_to_py_obj_and_back() {
    rustpython_vm::Interpreter::with_init(Default::default(), |vm| {
        // this can be in `.enter()` closure, but for clarity, put it in the `with_init()`
        let _ = PyVector::make_class(&vm.ctx);
    })
    .enter(|vm| {
        let col = DFColValue::Scalar(ScalarValue::Float64(Some(1.0)));
        let to = try_into_py_obj(col, vm).unwrap();
        let back = try_into_columnar_value(to, vm).unwrap();
        if let DFColValue::Scalar(ScalarValue::Float64(Some(v))) = back {
            if (v - 1.0).abs() > 2.0 * f64::EPSILON {
                panic!("Expect 1.0, found {v}")
            }
        } else {
            panic!("Convert errors, expect 1.0")
        }
        let col = DFColValue::Scalar(ScalarValue::Int64(Some(1)));
        let to = try_into_py_obj(col, vm).unwrap();
        let back = try_into_columnar_value(to, vm).unwrap();
        if let DFColValue::Scalar(ScalarValue::Int64(Some(v))) = back {
            assert_eq!(v, 1);
        } else {
            panic!("Convert errors, expect 1")
        }
        let col = DFColValue::Scalar(ScalarValue::UInt64(Some(1)));
        let to = try_into_py_obj(col, vm).unwrap();
        let back = try_into_columnar_value(to, vm).unwrap();
        if let DFColValue::Scalar(ScalarValue::Int64(Some(v))) = back {
            assert_eq!(v, 1);
        } else {
            panic!("Convert errors, expect 1")
        }
        let col = DFColValue::Scalar(ScalarValue::List(ScalarValue::new_list(
            &[ScalarValue::Int64(Some(1)), ScalarValue::Int64(Some(2))],
            &ArrowDataType::Int64,
        )));
        let to = try_into_py_obj(col, vm).unwrap();
        let back = try_into_columnar_value(to, vm).unwrap();
        if let DFColValue::Scalar(ScalarValue::List(list)) = back {
            assert_eq!(list.len(), 1);
            assert_eq!(
                list.data_type(),
                &ArrowDataType::List(Arc::new(Field::new_list_field(ArrowDataType::Int64, true)))
            );
        }
        let list: Vec<PyObjectRef> = vec![vm.ctx.new_int(1).into(), vm.ctx.new_int(2).into()];
        let nested_list: Vec<PyObjectRef> =
            vec![vm.ctx.new_list(list).into(), vm.ctx.new_int(3).into()];
        let list_obj = vm.ctx.new_list(nested_list).into();
        let col = try_into_columnar_value(list_obj, vm);
        if let Err(err) = col {
            let reason = format_py_error(err, vm);
            assert!(format!("{reason}").contains(
                "TypeError: All elements in a list should be same type to cast to Datafusion list!"
            ));
        }

        let list: PyVector =
            PyVector::from(
                Arc::new(Float64Vector::from_slice([0.1f64, 0.2, 0.3, 0.4])) as VectorRef
            );
        let nested_list: Vec<PyObjectRef> = vec![list.into_pyobject(vm), vm.ctx.new_int(3).into()];
        let list_obj = vm.ctx.new_list(nested_list).into();
        let expect_err = try_into_columnar_value(list_obj, vm);
        assert!(expect_err.is_err());
    })
}

#[derive(Debug, Serialize, Deserialize)]
struct TestCase {
    input: HashMap<String, Var>,
    script: String,
    expect: Result<Var, String>,
}

#[derive(Debug, Serialize, Deserialize)]
struct Var {
    value: PyValue,
    ty: ArrowDataType,
}

/// for floating-point comparison
const EPS: f64 = 2.0 * f64::EPSILON;

/// Null elements are not supported for now, to keep writing test cases simple
#[derive(Debug, Serialize, Deserialize)]
enum PyValue {
    FloatVec(Vec<f64>),
    FloatVecWithNull(Vec<Option<f64>>),
    IntVec(Vec<i64>),
    IntVecWithNull(Vec<Option<i64>>),
    Int(i64),
    Float(f64),
    Bool(bool),
    Str(String),
    /// tests whether the length of a FloatVec equals `LenFloatVec.0`
    LenFloatVec(usize),
    /// tests whether the length of an IntVec equals `LenIntVec.0`
    LenIntVec(usize),
    /// tests whether the result is within the error bound, using the formula:
    /// `(res - value).abs() < (value.abs() * error_percent)`
    FloatWithError {
        value: f64,
        error_percent: f64,
    },
}

impl PyValue {
    /// compare whether the result is as expected; not using PartialEq because this relation is not transitive, e.g. [1,2,3] == len(3) == [4,5,6]
    fn just_as_expect(&self, other: &Self) -> bool {
        match (self, other) {
            (PyValue::FloatVec(a), PyValue::FloatVec(b)) => a
                .iter()
                .zip(b)
                .fold(true, |acc, (x, y)| acc && (x - y).abs() <= EPS),

            (Self::FloatVecWithNull(a), Self::FloatVecWithNull(b)) => a == b,

            (PyValue::IntVec(a), PyValue::IntVec(b)) => a == b,

            (PyValue::Float(a), PyValue::Float(b)) => (a - b).abs() <= EPS,

            (PyValue::Int(a), PyValue::Int(b)) => a == b,

            // only compare the length of the vector
            (PyValue::LenFloatVec(len), PyValue::FloatVec(v)) => *len == v.len(),

            (PyValue::LenIntVec(len), PyValue::IntVec(v)) => *len == v.len(),

            (PyValue::FloatVec(v), PyValue::LenFloatVec(len)) => *len == v.len(),

            (PyValue::IntVec(v), PyValue::LenIntVec(len)) => *len == v.len(),

            (
                Self::Float(v),
                Self::FloatWithError {
                    value,
                    error_percent,
                },
            )
            | (
                Self::FloatWithError {
                    value,
                    error_percent,
                },
                Self::Float(v),
            ) => (v - value).abs() < (value.abs() * error_percent),
            (_, _) => false,
        }
    }
}
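// NOTE: `just_as_expect` is deliberately looser than equality: LenFloatVec(3)
// matches both FloatVec(vec![1.0, 2.0, 3.0]) and FloatVec(vec![4.0, 5.0, 6.0]),
// even though those two vectors are not equal to each other - hence a custom
// method rather than a PartialEq impl.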

fn is_float(ty: &ArrowDataType) -> bool {
    matches!(
        ty,
        ArrowDataType::Float16 | ArrowDataType::Float32 | ArrowDataType::Float64
    )
}

/// unsigned included
fn is_int(ty: &ArrowDataType) -> bool {
    matches!(
        ty,
        ArrowDataType::UInt8
            | ArrowDataType::UInt16
            | ArrowDataType::UInt32
            | ArrowDataType::UInt64
            | ArrowDataType::Int8
            | ArrowDataType::Int16
            | ArrowDataType::Int32
            | ArrowDataType::Int64
    )
}

impl PyValue {
    fn to_py_obj(&self, vm: &VirtualMachine) -> Result<PyObjectRef, String> {
        let v: VectorRef = match self {
            PyValue::FloatVec(v) => {
                Arc::new(datatypes::vectors::Float64Vector::from_vec(v.clone()))
            }
            PyValue::IntVec(v) => Arc::new(Int64Vector::from_vec(v.clone())),
            PyValue::Int(v) => return Ok(vm.ctx.new_int(*v).into()),
            PyValue::Float(v) => return Ok(vm.ctx.new_float(*v).into()),
            Self::Bool(v) => return Ok(vm.ctx.new_bool(*v).into()),
            Self::Str(s) => return Ok(vm.ctx.new_str(s.as_str()).into()),
            _ => return Err(format!("Unsupported type:{self:#?}")),
        };
        let v = PyVector::from(v).to_pyobject(vm);
        Ok(v)
    }

    fn from_py_obj(obj: &PyObjectRef, vm: &VirtualMachine) -> Result<Self, String> {
        if is_instance::<PyVector>(obj, vm) {
            let res = obj.payload::<PyVector>().unwrap();
            let res = res.to_arrow_array();
            let ty = res.data_type();
            if is_float(ty) {
                let vec_f64 = compute::cast(&res, &ArrowDataType::Float64)
                    .map_err(|err| format!("{err:#?}"))?;
                assert_eq!(vec_f64.data_type(), &ArrowDataType::Float64);
                let vec_f64 = vec_f64
                    .as_any()
                    .downcast_ref::<Float64Array>()
                    .ok_or_else(|| format!("Can't cast {vec_f64:#?} to Float64Array!"))?;
                let ret = vec_f64.into_iter().collect::<Vec<_>>();
                if ret.iter().all(|x| x.is_some()) {
                    Ok(Self::FloatVec(
                        ret.into_iter().map(|i| i.unwrap()).collect(),
                    ))
                } else {
                    Ok(Self::FloatVecWithNull(ret))
                }
            } else if is_int(ty) {
                let vec_int = compute::cast(&res, &ArrowDataType::Int64)
                    .map_err(|err| format!("{err:#?}"))?;
                assert_eq!(vec_int.data_type(), &ArrowDataType::Int64);
                let vec_i64 = vec_int
                    .as_any()
                    .downcast_ref::<Int64Array>()
                    .ok_or_else(|| format!("Can't cast {vec_int:#?} to Int64Array!"))?;
                let ret: Vec<i64> = vec_i64
                    .into_iter()
                    .enumerate()
                    .map(|(idx, v)| {
                        v.ok_or_else(|| {
                            format!("No null element expected, found one in {idx} position")
                        })
                    })
                    .collect::<Result<_, String>>()?;
                Ok(Self::IntVec(ret))
            } else {
                Err(format!("unsupported ArrowDataType:{ty:#?}"))
            }
        } else if is_instance::<PyInt>(obj, vm) {
            let res = obj
                .clone()
                .try_into_value::<i64>(vm)
                .map_err(|err| format_py_error(err, vm).to_string())?;
            Ok(Self::Int(res))
        } else if is_instance::<PyFloat>(obj, vm) {
            let res = obj
                .clone()
                .try_into_value::<f64>(vm)
                .map_err(|err| format_py_error(err, vm).to_string())?;
            Ok(Self::Float(res))
        } else if is_instance::<PyList>(obj, vm) {
            let res = obj.payload::<PyList>().unwrap();
            let res: Vec<f64> = res
                .borrow_vec()
                .iter()
                .map(|obj| {
                    let res = Self::from_py_obj(obj, vm).unwrap();
                    assert!(matches!(res, Self::Float(_) | Self::Int(_)));
                    match res {
                        Self::Float(v) => Ok(v),
                        Self::Int(v) => Ok(v as f64),
                        _ => Err(format!("Expect only int/float in list, found {res:#?}")),
                    }
                })
                .collect::<Result<_, _>>()?;
            Ok(Self::FloatVec(res))
        } else {
            todo!()
        }
    }
}

#[test]
fn run_builtin_fn_testcases() {
    common_telemetry::init_default_ut_logging();

    let loc = Path::new("src/python/rspython/builtins/testcases.ron");
    let loc = loc.to_str().expect("Fail to parse path");
    let mut file = File::open(loc).expect("Fail to open file");
    let mut buf = String::new();
    let _ = file.read_to_string(&mut buf).unwrap();
    let testcases: Vec<TestCase> = from_ron_string(&buf).expect("Fail to convert to testcases");
    let cached_vm = rustpython_vm::Interpreter::with_init(Default::default(), |vm| {
        vm.add_native_module("greptime", Box::new(greptime_builtin::make_module));
        let _ = PyVector::make_class(&vm.ctx);
    });
    for (idx, case) in testcases.into_iter().enumerate() {
        info!("Testcase {idx} ...");
        cached_vm
            .enter(|vm| {
                let scope = vm.new_scope_with_builtins();
                case.input
                    .iter()
                    .try_for_each(|(k, v)| -> Result<(), String> {
                        let v = PyValue::to_py_obj(&v.value, vm).unwrap();
                        set_item_into_scope(&scope, vm, k, v)
                    })
                    .unwrap();
                let code_obj = vm
                    .compile(
                        &case.script,
                        rustpython_compiler_core::Mode::BlockExpr,
                        "<embedded>".to_string(),
                    )
                    .map_err(|err| vm.new_syntax_error(&err))
                    .unwrap();
                let res = vm.run_code_obj(code_obj, scope);
                match res {
                    Err(e) => {
                        let err_res = format_py_error(e, vm).to_string();
                        match case.expect {
                            Ok(v) => {
                                error!("\nError:\n{err_res}");
                                panic!("Expect Ok: {v:?}, found Error in case {}", case.script);
                            },
                            Err(err) => {
                                if !err_res.contains(&err) {
                                    panic!("Error message mismatch, expect {err}, found {err_res}")
                                }
                            }
                        }
                    }
                    Ok(obj) => {
                        let ser = PyValue::from_py_obj(&obj, vm);
                        match (ser, case.expect) {
                            (Ok(real), Ok(expect)) => {
                                if !(real.just_as_expect(&expect.value)) {
                                    panic!("Not as Expected for code:\n{}\n Real Value is {real:#?}, but expect {expect:#?}", case.script)
                                }
                            },
                            (Err(real), Err(expect)) => {
                                if !expect.contains(&real) {
                                    panic!("Expect Err(\"{expect}\"), found {real}")
                                }
                            },
                            (Ok(real), Err(expect)) => panic!("Expect Err({expect}), found Ok({real:?})"),
                            (Err(real), Ok(expect)) => panic!("Expect Ok({expect:?}), found Err({real})"),
                        };
                    }
                };
            });
    }
}

fn set_item_into_scope(
    scope: &Scope,
    vm: &VirtualMachine,
    name: &str,
    value: impl ToPyObject,
) -> Result<(), String> {
    scope
        .locals
        .as_object()
        .set_item(&name.to_string(), vm.new_pyobj(value), vm)
        .map_err(|err| {
            format!(
                "Error in setting var {name} in scope: \n{}",
                format_py_error(err, vm)
            )
        })
}

fn set_lst_of_vecs_in_scope(
    scope: &Scope,
    vm: &VirtualMachine,
    arg_names: &[&str],
    args: Vec<PyVector>,
) -> Result<(), String> {
    let res = arg_names.iter().zip(args).try_for_each(|(name, vector)| {
        scope
            .locals
            .as_object()
            .set_item(&name.to_string(), vm.new_pyobj(vector), vm)
            .map_err(|err| {
                format!(
                    "Error in setting var {name} in scope: \n{}",
                    format_py_error(err, vm)
                )
            })
    });
    res
}

#[allow(unused_must_use)]
#[test]
fn test_vm() {
    common_telemetry::init_default_ut_logging();

    rustpython_vm::Interpreter::with_init(Default::default(), |vm| {
        vm.add_native_module("udf_builtins", Box::new(greptime_builtin::make_module));
        // this can be in `.enter()` closure, but for clarity, put it in the `with_init()`
        let _ = PyVector::make_class(&vm.ctx);
    })
    .enter(|vm| {
        let values = vec![1.0, 2.0, 3.0];
        let pows = vec![0i8, -1i8, 3i8];

        let args: Vec<VectorRef> = vec![
            Arc::new(datatypes::vectors::Float32Vector::from_vec(values)),
            Arc::new(datatypes::vectors::Int8Vector::from_vec(pows)),
        ];
        let args: Vec<PyVector> = args.into_iter().map(PyVector::from).collect();

        let scope = vm.new_scope_with_builtins();
        set_lst_of_vecs_in_scope(&scope, vm, &["values", "pows"], args).unwrap();
        let code_obj = vm
            .compile(
                r#"
from udf_builtins import *
sin(values)"#,
                rustpython_compiler_core::Mode::BlockExpr,
                "<embedded>".to_string(),
            )
            .map_err(|err| vm.new_syntax_error(&err))
            .unwrap();
        let res = vm.run_code_obj(code_obj, scope);
        match res {
            Err(e) => {
                let err_res = format_py_error(e, vm).to_string();
                error!("Error:\n{err_res}");
            }
            Ok(obj) => {
                let _ser = PyValue::from_py_obj(&obj, vm);
            }
        }
    });
}
File diff suppressed because it is too large
Load Diff
@@ -1,240 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::rc::Rc;
use std::result::Result as StdResult;

use common_recordbatch::RecordBatch;
use common_telemetry::info;
use datatypes::vectors::VectorRef;
use rustpython_vm::builtins::{PyBaseExceptionRef, PyDict, PyStr, PyTuple};
use rustpython_vm::class::PyClassImpl;
use rustpython_vm::convert::ToPyObject;
use rustpython_vm::scope::Scope;
use rustpython_vm::{vm, AsObject, Interpreter, PyObjectRef, PyPayload, VirtualMachine};
use snafu::{OptionExt, ResultExt};

use crate::engine::EvalContext;
use crate::python::error::{ensure, ret_other_error_with, NewRecordBatchSnafu, OtherSnafu, Result};
use crate::python::ffi_types::copr::PyQueryEngine;
use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
use crate::python::ffi_types::{check_args_anno_real_type, select_from_rb, Coprocessor, PyVector};
use crate::python::metric;
use crate::python::rspython::builtins::init_greptime_builtins;
use crate::python::rspython::dataframe_impl::data_frame::set_dataframe_in_scope;
use crate::python::rspython::dataframe_impl::init_data_frame;
use crate::python::rspython::utils::{format_py_error, is_instance, py_obj_to_vec};

thread_local!(static INTERPRETER: RefCell<Option<Rc<Interpreter>>> = const { RefCell::new(None) });
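// NOTE: the interpreter is cached per thread (an `Rc` is not `Send`), so
// `init_interpreter` below pays the RustPython start-up cost once per thread
// and every later coprocessor run on that thread reuses the same instance.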

/// Uses `RustPython` to run a parsed `Coprocessor` struct and execute its python code
pub(crate) fn rspy_exec_parsed(
    copr: &Coprocessor,
    rb: &Option<RecordBatch>,
    params: &HashMap<String, String>,
    eval_ctx: &EvalContext,
) -> Result<RecordBatch> {
    let _t = metric::METRIC_RSPY_EXEC_TOTAL_ELAPSED.start_timer();
    // 3. get args from `rb`, and cast them into PyVector
    let args: Vec<PyVector> = if let Some(rb) = rb {
        let arg_names = copr.deco_args.arg_names.clone().unwrap_or_default();
        let args = select_from_rb(rb, &arg_names)?;
        check_args_anno_real_type(&arg_names, &args, copr, rb)?;
        args
    } else {
        vec![]
    };
    let interpreter = init_interpreter();
    // 4. then set args in scope, compile and run the `CodeObject` which already appends a new `Call` node
    exec_with_cached_vm(copr, rb, args, params, &interpreter, eval_ctx)
}

/// set arguments with given names and values in python scopes
fn set_items_in_scope(
    scope: &Scope,
    vm: &VirtualMachine,
    arg_names: &[String],
    args: Vec<PyVector>,
) -> Result<()> {
    let _ = arg_names
        .iter()
        .zip(args)
        .map(|(name, vector)| {
            scope
                .locals
                .as_object()
                .set_item(name, vm.new_pyobj(vector), vm)
        })
        .collect::<StdResult<Vec<()>, PyBaseExceptionRef>>()
        .map_err(|e| format_py_error(e, vm))?;
    Ok(())
}

fn set_query_engine_in_scope(
    scope: &Scope,
    vm: &VirtualMachine,
    name: &str,
    query_engine: PyQueryEngine,
) -> Result<()> {
    scope
        .locals
        .as_object()
        .set_item(name, query_engine.to_pyobject(vm), vm)
        .map_err(|e| format_py_error(e, vm))
}

pub(crate) fn exec_with_cached_vm(
    copr: &Coprocessor,
    rb: &Option<RecordBatch>,
    args: Vec<PyVector>,
    params: &HashMap<String, String>,
    vm: &Rc<Interpreter>,
    eval_ctx: &EvalContext,
) -> Result<RecordBatch> {
    vm.enter(|vm| -> Result<RecordBatch> {
        let _t = metric::METRIC_RSPY_EXEC_ELAPSED.start_timer();

        // set arguments with given names and values
        let scope = vm.new_scope_with_builtins();
        if let Some(rb) = rb {
            set_dataframe_in_scope(&scope, vm, "__dataframe__", rb)?;
        }

        if let Some(arg_names) = &copr.deco_args.arg_names {
            assert_eq!(arg_names.len(), args.len());
            set_items_in_scope(&scope, vm, arg_names, args)?;
        }

        if let Some(engine) = &copr.query_engine {
            let query_engine =
                PyQueryEngine::from_weakref(engine.clone(), eval_ctx.query_ctx.clone());

            // put an object named `__query__` of class PyQueryEngine in scope
            set_query_engine_in_scope(&scope, vm, "__query__", query_engine)?;
        }

        if let Some(kwarg) = &copr.kwarg {
            let dict = PyDict::new_ref(&vm.ctx);
            for (k, v) in params {
                dict.set_item(k, PyStr::from(v.clone()).into_pyobject(vm), vm)
                    .map_err(|e| format_py_error(e, vm))?;
            }
            scope
                .locals
                .as_object()
                .set_item(kwarg, vm.new_pyobj(dict), vm)
                .map_err(|e| format_py_error(e, vm))?;
        }

        // It's safe to unwrap code_object, it's already compiled before.
        let code_obj = vm.ctx.new_code(copr.code_obj.clone().unwrap());
        let ret = vm
            .run_code_obj(code_obj, scope)
            .map_err(|e| format_py_error(e, vm))?;

        // 5. get returns as either a PyVector or a PyTuple, and name the schema according to `returns`
        let col_len = rb.as_ref().map(|rb| rb.num_rows()).unwrap_or(1);
        let mut cols = try_into_columns(&ret, vm, col_len)?;
        ensure!(
            cols.len() == copr.deco_args.ret_names.len(),
            OtherSnafu {
                reason: format!(
                    "The number of return Vector is wrong, expect {}, found {}",
                    copr.deco_args.ret_names.len(),
                    cols.len()
                )
            }
        );

        // if the columns' data types do not match the schema, try to coerce them to the annotated type (if any), propagating errors with the question mark
        copr.check_and_cast_type(&mut cols)?;

        // 6. return an assembled DfRecordBatch
        let schema = copr.gen_schema(&cols)?;
        RecordBatch::new(schema, cols).context(NewRecordBatchSnafu)
    })
}

/// convert a tuple of `PyVector`s or one `PyVector` (wrapped in a Python object ref [`PyObjectRef`])
/// to a `Vec<VectorRef>`;
/// by default, a constant (int/float/bool) gives a constant array of the same length as the input args
fn try_into_columns(
    obj: &PyObjectRef,
    vm: &VirtualMachine,
    col_len: usize,
) -> Result<Vec<VectorRef>> {
    if is_instance::<PyTuple>(obj, vm) {
        let tuple = obj
            .payload::<PyTuple>()
            .with_context(|| ret_other_error_with(format!("can't cast obj {obj:?} to PyTuple")))?;
        let cols = tuple
            .iter()
            .map(|obj| py_obj_to_vec(obj, vm, col_len))
            .collect::<Result<Vec<VectorRef>>>()?;
        Ok(cols)
    } else {
        let col = py_obj_to_vec(obj, vm, col_len)?;
        Ok(vec![col])
    }
}
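// NOTE: a sketch of the shapes `try_into_columns` accepts, from a hypothetical
// coprocessor body:
//   return v1, v2   # PyTuple path -> vec![col1, col2]
//   return v1       # single-vector path -> vec![col1]
//   return 1.0      # a constant is broadcast to a column of length `col_len`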

/// init the interpreter with the type PyVector and module: greptime
pub(crate) fn init_interpreter() -> Rc<Interpreter> {
    let _t = metric::METRIC_RSPY_INIT_ELAPSED.start_timer();
    INTERPRETER.with(|i| {
        i.borrow_mut()
            .get_or_insert_with(|| {
                // we limit stdlib imports for safety reasons, e.g. `fcntl` is not allowed here
                let native_module_allow_list = HashSet::from([
                    "array", "cmath", "gc", "hashlib", "_json", "_random", "math",
                ]);
                // edge case: can't use "..Default::default" because Settings is `#[non_exhaustive]`,
                // see more here: https://internals.rust-lang.org/t/allow-constructing-non-exhaustive-structs-using-default-default/13868
                let mut settings = vm::Settings::default();
                // disable the SIGINT handler so our own binary can install its ctrl-c handler
                settings.no_sig_int = true;
                let interpreter = Rc::new(vm::Interpreter::with_init(settings, |vm| {
                    // not using the full stdlib to prevent security issues; instead allow only a few simple util modules
                    vm.add_native_modules(
                        rustpython_stdlib::get_module_inits()
                            .filter(|(k, _)| native_module_allow_list.contains(k.as_ref())),
                    );

                    // We are freezing the stdlib to include the standard library inside the binary,
                    // so according to this issue:
                    // https://github.com/RustPython/RustPython/issues/4292
                    // add this line for the stdlib, so rustpython can find the stdlib's python part in bytecode format
                    vm.add_frozen(rustpython_pylib::FROZEN_STDLIB);
                    // add our own custom datatypes and modules
                    let _ = PyVector::make_class(&vm.ctx);
                    let _ = PyQueryEngine::make_class(&vm.ctx);
                    let _ = PyRecordBatch::make_class(&vm.ctx);
                    init_greptime_builtins("greptime", vm);
                    init_data_frame("data_frame", vm);
                }));
                interpreter
                    .enter(|vm| {
                        let sys = vm.sys_module.clone();
                        let version = sys.get_attr("version", vm)?.str(vm)?;
                        info!("Initialized RustPython interpreter {version}");
                        Ok::<(), PyBaseExceptionRef>(())
                    })
                    .expect("fail to display RustPython interpreter version");
                interpreter
            })
            .clone()
    })
}
@@ -1,386 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use rustpython_vm::class::PyClassImpl;
use rustpython_vm::{pymodule as rspymodule, VirtualMachine};

use crate::python::rspython::builtins::greptime_builtin::PyDataFrame;
pub(crate) fn init_data_frame(module_name: &str, vm: &mut VirtualMachine) {
    let _ = PyDataFrame::make_class(&vm.ctx);
    let _ = data_frame::PyExpr::make_class(&vm.ctx);
    vm.add_native_module(module_name.to_owned(), Box::new(data_frame::make_module));
}
/// a module that wraps the DataFrame API; record batches are registered with `register_batch`
#[rspymodule]
pub(crate) mod data_frame {
    use common_recordbatch::{DfRecordBatch, RecordBatch};
    use datafusion::dataframe::DataFrame as DfDataFrame;
    use datafusion::execution::context::SessionContext;
    use datafusion_expr::Expr as DfExpr;
    use rustpython_vm::convert::ToPyResult;
    use rustpython_vm::function::PyComparisonValue;
    use rustpython_vm::protocol::PyNumberMethods;
    use rustpython_vm::types::{AsNumber, Comparable, PyComparisonOp};
    use rustpython_vm::{
        pyclass as rspyclass, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
    };
    use snafu::ResultExt;

    use crate::python::error::DataFusionSnafu;
    use crate::python::ffi_types::py_recordbatch::PyRecordBatch;
    use crate::python::rspython::builtins::greptime_builtin::{
        lit, query as get_query_engine, PyDataFrame,
    };
    use crate::python::rspython::utils::obj_cast_to;
    use crate::python::utils::block_on_async;

    impl From<DfDataFrame> for PyDataFrame {
        fn from(inner: DfDataFrame) -> Self {
            Self { inner }
        }
    }
    /// set a DataFrame instance into the current scope with the given name
    pub fn set_dataframe_in_scope(
        scope: &rustpython_vm::scope::Scope,
        vm: &VirtualMachine,
        name: &str,
        rb: &RecordBatch,
    ) -> crate::python::error::Result<()> {
        let df = PyDataFrame::from_record_batch(rb.df_record_batch())?;
        scope
            .locals
            .set_item(name, vm.new_pyobj(df), vm)
            .map_err(|e| crate::python::utils::format_py_error(e, vm))
    }
    #[rspyclass]
    impl PyDataFrame {
        #[pymethod]
        fn from_sql(sql: String, vm: &VirtualMachine) -> PyResult<Self> {
            let query_engine = get_query_engine(vm)?;
            let rb = query_engine.sql_to_rb(sql.clone()).map_err(|e| {
                vm.new_runtime_error(format!("failed to execute sql: {:?}, error: {:?}", sql, e))
            })?;
            let ctx = SessionContext::new();
            ctx.read_batch(rb.df_record_batch().clone())
                .map_err(|e| vm.new_runtime_error(format!("{e:?}")))
                .map(|df| df.into())
        }
        /// TODO(discord9): error handling
        fn from_record_batch(rb: &DfRecordBatch) -> crate::python::error::Result<Self> {
            let ctx = SessionContext::new();
            let inner = ctx.read_batch(rb.clone()).context(DataFusionSnafu)?;
            Ok(Self { inner })
        }

        #[pymethod]
        fn select_columns(&self, columns: Vec<String>, vm: &VirtualMachine) -> PyResult<Self> {
            Ok(self
                .inner
                .clone()
                .select_columns(&columns.iter().map(AsRef::as_ref).collect::<Vec<&str>>())
                .map_err(|e| vm.new_runtime_error(e.to_string()))?
                .into())
        }

        #[pymethod]
        fn select(&self, expr_list: Vec<PyExprRef>, vm: &VirtualMachine) -> PyResult<Self> {
            Ok(self
                .inner
                .clone()
                .select(expr_list.iter().map(|e| e.inner.clone()).collect())
                .map_err(|e| vm.new_runtime_error(e.to_string()))?
                .into())
        }

        #[pymethod]
        fn filter(&self, predicate: PyExprRef, vm: &VirtualMachine) -> PyResult<Self> {
            Ok(self
                .inner
                .clone()
                .filter(predicate.inner.clone())
                .map_err(|e| vm.new_runtime_error(e.to_string()))?
                .into())
        }

        #[pymethod]
        fn aggregate(
            &self,
            group_expr: Vec<PyExprRef>,
            aggr_expr: Vec<PyExprRef>,
            vm: &VirtualMachine,
        ) -> PyResult<Self> {
            let ret = self.inner.clone().aggregate(
                group_expr.iter().map(|i| i.inner.clone()).collect(),
                aggr_expr.iter().map(|i| i.inner.clone()).collect(),
            );
            Ok(ret.map_err(|e| vm.new_runtime_error(e.to_string()))?.into())
        }

        #[pymethod]
        fn limit(&self, skip: usize, fetch: Option<usize>, vm: &VirtualMachine) -> PyResult<Self> {
            Ok(self
                .inner
                .clone()
                .limit(skip, fetch)
                .map_err(|e| vm.new_runtime_error(e.to_string()))?
                .into())
        }

        #[pymethod]
        fn union(&self, df: PyRef<PyDataFrame>, vm: &VirtualMachine) -> PyResult<Self> {
            Ok(self
                .inner
                .clone()
                .union(df.inner.clone())
                .map_err(|e| vm.new_runtime_error(e.to_string()))?
                .into())
        }

        #[pymethod]
        fn union_distinct(&self, df: PyRef<PyDataFrame>, vm: &VirtualMachine) -> PyResult<Self> {
            Ok(self
                .inner
                .clone()
                .union_distinct(df.inner.clone())
                .map_err(|e| vm.new_runtime_error(e.to_string()))?
                .into())
        }

        #[pymethod]
        fn distinct(&self, vm: &VirtualMachine) -> PyResult<Self> {
            Ok(self
                .inner
                .clone()
                .distinct()
                .map_err(|e| vm.new_runtime_error(e.to_string()))?
                .into())
        }

        #[pymethod]
        fn sort(&self, expr: Vec<PyExprRef>, vm: &VirtualMachine) -> PyResult<Self> {
            Ok(self
                .inner
                .clone()
                .sort(expr.iter().map(|e| e.inner.clone()).collect())
                .map_err(|e| vm.new_runtime_error(e.to_string()))?
                .into())
        }

        #[pymethod]
        fn join(
            &self,
            right: PyRef<PyDataFrame>,
            join_type: String,
            left_cols: Vec<String>,
            right_cols: Vec<String>,
            filter: Option<PyExprRef>,
            vm: &VirtualMachine,
        ) -> PyResult<Self> {
            use datafusion::prelude::JoinType;
            let join_type = match join_type.as_str() {
                "inner" | "Inner" => JoinType::Inner,
                "left" | "Left" => JoinType::Left,
                "right" | "Right" => JoinType::Right,
                "full" | "Full" => JoinType::Full,
                "leftSemi" | "LeftSemi" => JoinType::LeftSemi,
                "rightSemi" | "RightSemi" => JoinType::RightSemi,
                "leftAnti" | "LeftAnti" => JoinType::LeftAnti,
                "rightAnti" | "RightAnti" => JoinType::RightAnti,
                _ => return Err(vm.new_runtime_error(format!("Unknown join type: {join_type}"))),
            };
            let left_cols: Vec<&str> = left_cols.iter().map(AsRef::as_ref).collect();
            let right_cols: Vec<&str> = right_cols.iter().map(AsRef::as_ref).collect();
            let filter = filter.map(|f| f.inner.clone());
            Ok(self
                .inner
                .clone()
                .join(
                    right.inner.clone(),
                    join_type,
                    &left_cols,
                    &right_cols,
                    filter,
                )
                .map_err(|e| vm.new_runtime_error(e.to_string()))?
                .into())
        }

        #[pymethod]
        fn intersect(&self, df: PyRef<PyDataFrame>, vm: &VirtualMachine) -> PyResult<Self> {
            Ok(self
                .inner
                .clone()
                .intersect(df.inner.clone())
                .map_err(|e| vm.new_runtime_error(e.to_string()))?
                .into())
        }

        #[pymethod]
        fn except(&self, df: PyRef<PyDataFrame>, vm: &VirtualMachine) -> PyResult<Self> {
            Ok(self
                .inner
                .clone()
                .except(df.inner.clone())
                .map_err(|e| vm.new_runtime_error(e.to_string()))?
                .into())
        }

        #[pymethod]
        /// collect `DataFrame` results into a `PyRecordBatch` that impls the Mapping Protocol
        fn collect(&self, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
            let inner = self.inner.clone();
            let res = block_on_async(async { inner.collect().await });
            let res = res
                .map_err(|e| vm.new_runtime_error(format!("{e:?}")))?
                .map_err(|e| vm.new_runtime_error(e.to_string()))?;
            if res.is_empty() {
                return Ok(vm.ctx.new_dict().into());
            }
            let concat_rb =
                arrow::compute::concat_batches(&res[0].schema(), res.iter()).map_err(|e| {
                    vm.new_runtime_error(format!(
                        "Concat batches failed for dataframe {self:?}: {e}"
                    ))
                })?;

            // we are inside a macro, so use the full path
            let schema = datatypes::schema::Schema::try_from(concat_rb.schema()).map_err(|e| {
                vm.new_runtime_error(format!(
                    "Convert to Schema failed for dataframe {self:?}: {e}"
                ))
            })?;
            let rb =
                RecordBatch::try_from_df_record_batch(schema.into(), concat_rb).map_err(|e| {
                    vm.new_runtime_error(format!(
                        "Convert to RecordBatch failed for dataframe {self:?}: {e}"
                    ))
                })?;

            let rb = PyRecordBatch::new(rb);
            Ok(rb.into_pyobject(vm))
        }
    }

    #[rspyclass(module = "data_frame", name = "PyExpr")]
    #[derive(PyPayload, Debug, Clone)]
    pub struct PyExpr {
        pub inner: DfExpr,
    }

    // TODO(discord9): a lit function that takes a PyObject and turns it into a ScalarValue

    pub(crate) type PyExprRef = PyRef<PyExpr>;

    impl From<datafusion_expr::Expr> for PyExpr {
        fn from(value: DfExpr) -> Self {
            Self { inner: value }
        }
    }

    impl Comparable for PyExpr {
        fn slot_richcompare(
            zelf: &PyObject,
            other: &PyObject,
            op: PyComparisonOp,
            vm: &VirtualMachine,
        ) -> PyResult<rustpython_vm::function::Either<PyObjectRef, PyComparisonValue>> {
            if let Some(zelf) = zelf.downcast_ref::<Self>() {
                let ret = zelf.richcompare(other.to_owned(), op, vm)?;
                let ret = ret.into_pyobject(vm);
                Ok(rustpython_vm::function::Either::A(ret))
            } else {
                Err(vm.new_type_error(format!(
                    "unexpected payload {zelf:?} and {other:?} for op {}",
                    op.method_name(&vm.ctx).as_str()
                )))
            }
        }
        fn cmp(
            _zelf: &rustpython_vm::Py<Self>,
            _other: &PyObject,
            _op: PyComparisonOp,
            _vm: &VirtualMachine,
        ) -> PyResult<PyComparisonValue> {
            Ok(PyComparisonValue::NotImplemented)
        }
    }

    impl AsNumber for PyExpr {
        fn as_number() -> &'static PyNumberMethods {
            static AS_NUMBER: PyNumberMethods = PyNumberMethods {
                and: Some(|a, b, vm| PyExpr::and(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)),
                or: Some(|a, b, vm| PyExpr::or(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)),
                invert: Some(|a, vm| PyExpr::invert((*a).to_owned(), vm).to_pyresult(vm)),

                ..PyNumberMethods::NOT_IMPLEMENTED
            };
            &AS_NUMBER
        }
    }

    #[rspyclass(with(Comparable, AsNumber))]
    impl PyExpr {
        fn richcompare(
            &self,
            other: PyObjectRef,
            op: PyComparisonOp,
            vm: &VirtualMachine,
        ) -> PyResult<Self> {
            let other = if let Some(other) = other.downcast_ref::<Self>() {
                other.to_owned()
            } else {
                lit(other, vm)?
            };
            let f = match op {
                PyComparisonOp::Eq => DfExpr::eq,
                PyComparisonOp::Ne => DfExpr::not_eq,
                PyComparisonOp::Gt => DfExpr::gt,
                PyComparisonOp::Lt => DfExpr::lt,
                PyComparisonOp::Ge => DfExpr::gt_eq,
                PyComparisonOp::Le => DfExpr::lt_eq,
            };
            Ok(f(self.inner.clone(), other.inner.clone()).into())
        }
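        // NOTE: this is what makes comparisons on expressions build DataFusion
        // predicates: in a hypothetical script, `col("cpu") > 0.5` goes through
        // `richcompare`, with the scalar first lifted to an expression by `lit`,
        // and the result can be fed to `PyDataFrame::filter` above.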
|
||||
#[pymethod]
|
||||
fn alias(&self, name: String) -> PyResult<PyExpr> {
|
||||
Ok(self.inner.clone().alias(name).into())
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn and(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyExpr> {
|
||||
let zelf = obj_cast_to::<Self>(zelf, vm)?;
|
||||
let other = obj_cast_to::<Self>(other, vm)?;
|
||||
Ok(zelf.inner.clone().and(other.inner.clone()).into())
|
||||
}
|
||||
#[pymethod(magic)]
|
||||
fn or(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyExpr> {
|
||||
let zelf = obj_cast_to::<Self>(zelf, vm)?;
|
||||
let other = obj_cast_to::<Self>(other, vm)?;
|
||||
Ok(zelf.inner.clone().or(other.inner.clone()).into())
|
||||
}
|
||||
|
||||
/// `~` operator, return `!self`
|
||||
#[pymethod(magic)]
|
||||
fn invert(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyExpr> {
|
||||
let zelf = obj_cast_to::<Self>(zelf, vm)?;
|
||||
Ok((!zelf.inner.clone()).into())
|
||||
}
|
||||
|
||||
/// sort ascending&nulls_first
|
||||
#[pymethod]
|
||||
fn sort(&self) -> PyExpr {
|
||||
self.inner.clone().sort(true, true).into()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,310 +0,0 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::fs::File;
|
||||
use std::io::prelude::*;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_recordbatch::RecordBatch;
|
||||
use common_telemetry::{error, info};
|
||||
use console::style;
|
||||
use datatypes::arrow::datatypes::DataType as ArrowDataType;
|
||||
use datatypes::data_type::{ConcreteDataType, DataType};
|
||||
use datatypes::schema::{ColumnSchema, Schema};
|
||||
use datatypes::vectors::{Float32Vector, Float64Vector, Int64Vector, VectorRef};
|
||||
use ron::from_str as from_ron_string;
|
||||
use rustpython_parser::{parse, Mode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::engine::EvalContext;
|
||||
use crate::python::error::{get_error_reason_loc, pretty_print_error_in_src, visualize_loc, Error};
|
||||
use crate::python::ffi_types::copr::parse::parse_and_compile_copr;
|
||||
use crate::python::ffi_types::copr::{exec_coprocessor, AnnotationInfo};
|
||||
use crate::python::ffi_types::Coprocessor;
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct TestCase {
|
||||
name: String,
|
||||
code: String,
|
||||
predicate: Predicate,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
enum Predicate {
|
||||
ParseIsOk {
|
||||
result: Box<Coprocessor>,
|
||||
},
|
||||
ParseIsErr {
|
||||
/// used to check if after serialize [`Error`] into a String, that string contains `reason`
|
||||
reason: String,
|
||||
},
|
||||
ExecIsOk {
|
||||
fields: Vec<AnnotationInfo>,
|
||||
columns: Vec<ColumnInfo>,
|
||||
},
|
||||
ExecIsErr {
|
||||
/// used to check if after serialize [`Error`] into a String, that string contains `reason`
|
||||
reason: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
struct ColumnInfo {
|
||||
pub ty: ArrowDataType,
|
||||
pub len: usize,
|
||||
}
|
||||
|
||||
fn create_sample_recordbatch() -> RecordBatch {
|
||||
let cpu_array = Float32Vector::from_slice([0.9f32, 0.8, 0.7, 0.6]);
|
||||
let mem_array = Float64Vector::from_slice([0.1f64, 0.2, 0.3, 0.4]);
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
ColumnSchema::new("cpu", ConcreteDataType::float32_datatype(), false),
|
||||
ColumnSchema::new("mem", ConcreteDataType::float64_datatype(), false),
|
||||
]));
|
||||
|
||||
RecordBatch::new(
|
||||
schema,
|
||||
[
|
||||
Arc::new(cpu_array) as VectorRef,
|
||||
Arc::new(mem_array) as VectorRef,
|
||||
],
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// test cases which read from a .ron file, deser,
|
||||
///
|
||||
/// and exec/parse (depending on the type of predicate) then decide if result is as expected
|
||||
#[test]
fn run_ron_testcases() {
    common_telemetry::init_default_ut_logging();

    let loc = Path::new("src/python/rspython/testcases.ron");
    let loc = loc.to_str().expect("Fail to parse path");
    let mut file = File::open(loc).expect("Fail to open file");
    let mut buf = String::new();
    let _ = file.read_to_string(&mut buf).unwrap();
    let testcases: Vec<TestCase> = from_ron_string(&buf).expect("Fail to convert to testcases");
    info!("Read {} testcases from {}", testcases.len(), loc);
    for testcase in testcases {
        info!(".ron test {}", testcase.name);
        match testcase.predicate {
            Predicate::ParseIsOk { result } => {
                let copr = parse_and_compile_copr(&testcase.code, None);
                let mut copr = copr.unwrap();
                copr.script = "".into();
                assert_eq!(copr, *result);
            }
            Predicate::ParseIsErr { reason } => {
                let copr = parse_and_compile_copr(&testcase.code, None);
                assert!(copr.is_err(), "Expect to be err, actual {copr:#?}");

                let res = &copr.unwrap_err();
                error!(
                    "{}",
                    pretty_print_error_in_src(&testcase.code, res, 0, "<embedded>")
                );
                let (res, _) = get_error_reason_loc(res);
                assert!(
                    res.contains(&reason),
                    "{} Parse Error, expect \"{reason}\" in \"{res}\", actual not found.",
                    testcase.code,
                );
            }
            Predicate::ExecIsOk { fields, columns } => {
                let rb = create_sample_recordbatch();
                let res =
                    exec_coprocessor(&testcase.code, &Some(rb), &EvalContext::default()).unwrap();
                fields
                    .iter()
                    .zip(res.schema.column_schemas())
                    .for_each(|(anno, real)| {
                        assert!(
                            anno.datatype.as_ref().unwrap() == &real.data_type
                                && anno.is_nullable == real.is_nullable(),
                            "Fields expected to be {anno:#?}, actual {real:#?}"
                        );
                    });
                columns.iter().zip(res.columns()).for_each(|(anno, real)| {
                    assert!(
                        anno.ty == real.data_type().as_arrow_type() && anno.len == real.len(),
                        "Type or length not match! Expect [{:#?}; {}], actual [{:#?}; {}]",
                        anno.ty,
                        anno.len,
                        real.data_type(),
                        real.len()
                    );
                });
            }
            Predicate::ExecIsErr {
                reason: part_reason,
            } => {
                let rb = create_sample_recordbatch();
                let res = exec_coprocessor(&testcase.code, &Some(rb), &EvalContext::default());
                assert!(res.is_err(), "{res:#?}\nExpect Err(...), actual Ok(...)");
                if let Err(res) = res {
                    error!(
                        "{}",
                        pretty_print_error_in_src(&testcase.code, &res, 1120, "<embedded>")
                    );
                    let (reason, _) = get_error_reason_loc(&res);
                    assert!(
                        reason.contains(&part_reason),
                        "{}\nExecute error, expect \"{reason}\" in \"{res}\", actual not found.",
                        testcase.code,
                        reason = style(reason).green(),
                        res = style(res).red()
                    )
                }
            }
        }
        info!(" ... {}", style("ok✅").green());
    }
}

#[test]
#[allow(unused)]
fn test_type_anno() {
    let python_source = r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu, mem: vector[f64])->(vector[f64|None], vector[f64], vector[_], vector[ _ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#;
    let pyast = parse(python_source, Mode::Interactive, "<embedded>").unwrap();
    let copr = parse_and_compile_copr(python_source, None);
}

#[test]
#[allow(clippy::print_stdout, unused_must_use)]
// allow print in test functions for debug purposes (like quickly testing syntax & ideas)
fn test_calc_rvs() {
    let python_source = r#"
@coprocessor(args=["open_time", "close"], returns=[
    "rv_7d",
    "rv_15d",
    "rv_30d",
    "rv_60d",
    "rv_90d",
    "rv_180d"
])
def calc_rvs(open_time, close):
    from greptime import vector, log, prev, sqrt, datetime, pow, sum, last
    import greptime as g
    def calc_rv(close, open_time, time, interval):
        mask = (open_time < time) & (open_time > time - interval)
        close = close[mask]
        open_time = open_time[mask]
        close = g.interval(open_time, close, datetime("10m"), lambda x:last(x))

        avg_time_interval = (open_time[-1] - open_time[0])/(len(open_time)-1)
        ref = log(close/prev(close))
        var = sum(pow(ref, 2)/(len(ref)-1))
        return sqrt(var/avg_time_interval)

    # how to get env var,
    # maybe through accessing scope and serde then send to remote?
    timepoint = open_time[-1]
    rv_7d = vector([calc_rv(close, open_time, timepoint, datetime("7d"))])
    rv_15d = vector([calc_rv(close, open_time, timepoint, datetime("15d"))])
    rv_30d = vector([calc_rv(close, open_time, timepoint, datetime("30d"))])
    rv_60d = vector([calc_rv(close, open_time, timepoint, datetime("60d"))])
    rv_90d = vector([calc_rv(close, open_time, timepoint, datetime("90d"))])
    rv_180d = vector([calc_rv(close, open_time, timepoint, datetime("180d"))])
    return rv_7d, rv_15d, rv_30d, rv_60d, rv_90d, rv_180d
"#;
    let close_array = Float32Vector::from_slice([
        10106.79f32,
        10106.09,
        10108.73,
        10106.38,
        10106.95,
        10107.55,
        10104.68,
        10108.8,
        10115.96,
        10117.08,
        10120.43,
    ]);
    let open_time_array = Int64Vector::from_slice([
        300i64, 900i64, 1200i64, 1800i64, 2400i64, 3000i64, 3600i64, 4200i64, 4800i64, 5400i64,
        6000i64,
    ]);
    let schema = Arc::new(Schema::new(vec![
        ColumnSchema::new("close", ConcreteDataType::float32_datatype(), false),
        ColumnSchema::new("open_time", ConcreteDataType::int64_datatype(), false),
    ]));
    let rb = RecordBatch::new(
        schema,
        [
            Arc::new(close_array) as VectorRef,
            Arc::new(open_time_array) as VectorRef,
        ],
    )
    .unwrap();
    let ret = exec_coprocessor(python_source, &Some(rb), &EvalContext::default());
    if let Err(Error::PyParse { location: _, error }) = ret {
        let res = visualize_loc(
            python_source,
            &error.location,
            "unknown tokens",
            error.error.to_string().as_str(),
            0,
            "copr.py",
        );
        info!("{res}");
    }
}

#[test]
#[allow(clippy::print_stdout, unused_must_use)]
// allow print in test functions for debug purposes (like quickly testing syntax & ideas)
fn test_coprocessor() {
    let python_source = r#"
@copr(args=["cpu", "mem"], returns=["ref"])
def a(cpu, mem):
    import greptime as gt
    from greptime import vector, log2, prev, sum, pow, sqrt, datetime
    abc = vector([v[0] > v[1] for v in zip(cpu, mem)])
    fed = cpu.filter(abc)
    ref = log2(fed/prev(fed))
    return cpu[(cpu > 0.5) & ~( cpu >= 0.75)]
"#;
    let cpu_array = Float32Vector::from_slice([0.9f32, 0.8, 0.7, 0.3]);
    let mem_array = Float64Vector::from_slice([0.1f64, 0.2, 0.3, 0.4]);
    let schema = Arc::new(Schema::new(vec![
        ColumnSchema::new("cpu", ConcreteDataType::float32_datatype(), false),
        ColumnSchema::new("mem", ConcreteDataType::float64_datatype(), false),
    ]));
    let rb = RecordBatch::new(
        schema,
        [
            Arc::new(cpu_array) as VectorRef,
            Arc::new(mem_array) as VectorRef,
        ],
    )
    .unwrap();
    let ret = exec_coprocessor(python_source, &Some(rb), &EvalContext::default());
    if let Err(Error::PyParse { location: _, error }) = ret {
        let res = visualize_loc(
            python_source,
            &error.location,
            "unknown tokens",
            error.error.to_string().as_str(),
            0,
            "copr.py",
        );
        info!("{res}");
    }
}
@@ -1,656 +0,0 @@
// This is the file for python coprocessor's testcases,
// including coprocessor parsing test and execute test
// check src/script/python/test.rs::run_ron_testcases() for more information
[
    (
        name: "correct_parse",
        code: r#"
import greptime as gt
from greptime import pow
def add(a, b):
    return a + b
def sub(a, b):
    return a - b
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vector[_], vector[_ | None]):
    return add(cpu, mem), sub(cpu, mem), cpu * mem, cpu / mem
"#,
        predicate: ParseIsOk(
            result: (
                name: "a",
                deco_args: (
                    arg_names: Some(["cpu", "mem"]),
                    ret_names: ["perf", "what", "how", "why"],
                ),
                arg_types: [
                    Some((
                        datatype: Some(Float32(())),
                        is_nullable: false
                    )),
                    Some((
                        datatype: Some(Float64(())),
                        is_nullable: false
                    )),
                ],
                return_types: [
                    Some((
                        datatype: Some(Float64(())),
                        is_nullable: false
                    )),
                    Some((
                        datatype: Some(Float64(())),
                        is_nullable: true
                    )),
                    Some((
                        datatype: None,
                        is_nullable: false
                    )),
                    Some((
                        datatype: None,
                        is_nullable: true
                    )),
                ],
                kwarg: None,
            )
        )
    ),
    (
        name: "correct_parse_params",
        code: r#"
import greptime as gt
from greptime import pow
def add(a, b):
    return a + b
def sub(a, b):
    return a - b
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[f32], mem: vector[f64], **params) -> (vector[f64], vector[f64|None], vector[_], vector[_ | None]):
    for key, value in params.items():
        print("%s == %s" % (key, value))
    return add(cpu, mem), sub(cpu, mem), cpu * mem, cpu / mem
"#,
        predicate: ParseIsOk(
            result: (
                name: "a",
                deco_args: (
                    arg_names: Some(["cpu", "mem"]),
                    ret_names: ["perf", "what", "how", "why"],
                ),
                arg_types: [
                    Some((
                        datatype: Some(Float32(())),
                        is_nullable: false
                    )),
                    Some((
                        datatype: Some(Float64(())),
                        is_nullable: false
                    )),
                ],
                return_types: [
                    Some((
                        datatype: Some(Float64(())),
                        is_nullable: false
                    )),
                    Some((
                        datatype: Some(Float64(())),
                        is_nullable: true
                    )),
                    Some((
                        datatype: None,
                        is_nullable: false
                    )),
                    Some((
                        datatype: None,
                        is_nullable: true
                    )),
                ],
                kwarg: Some("params"),
            )
        )
    ),
    (
        name: "missing_decorator",
        code: r#"
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Coprocessor not found in script"
        )
    ),
    (
        name: "too_many_decorators",
        code: r#"
@copr(args=["a"], returns=["r"])
def test1(a):
    return a;
@copr(args=["a"], returns=["r"])
def test2(a):
    return a;
"#,
        predicate: ParseIsErr(
            reason: "Expect one and only one python function with `@coprocessor` or `@cpor` decorator"
        )
    ),
    (
        name: "not_a_list_of_string",
        code: r#"
@copr(args=["cpu", 3], returns=["perf", "what", "how", "why"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Expect a list of String, found"
        )
    ),
    (
        name: "not_even_a_list",
        code: r#"
@copr(args=42, returns=["perf", "what", "how", "why"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Expect a list, found"
        )
    ),
    (
        // unknown type names
        name: "unknown_type_names",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[g32], mem: vector[f64])->(vector[f64], vector[f64|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Unknown datatype:"
        )
    ),
    (
        // two type names
        name: "two_type_names",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[f32 | f64], mem: vector[f64])->(vector[f64], vector[f64|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Expect one typenames and one `None`"
        )
    ),
    (
        name: "two_none",
        // two `None`
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[None | None], mem: vector[f64])->(vector[f64], vector[None|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Expect one typenames and one `None`"
        )
    ),
    (
        // expect a type name
        name: "unknown_type_names_in_ret",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[f64|None], mem: vector[f64])->(vector[g64], vector[f64|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Unknown datatype:"
        )
    ),
    (
        // no more `into`
        name: "call_deprecated_for_cast_into",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[cast(f64)], mem: vector[f64])->(vector[f64], vector[f64|None], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Expect type in `vector[...]`, found "
        )
    ),
    (
        // Expect `vector` not `vec`
        name: "vector_not_vec",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vec[f64], mem: vector[f64])->(vector[f64|None], vector[f64], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Wrong type annotation, expect `vector[...]`, found"
        )
    ),
    (
        // Expect `None`
        name: "expect_none",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[f64|1], mem: vector[f64])->(vector[f64|None], vector[f64], vector[_], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason: "Expect a type name and a `None`, found left: "
        )
    ),
    (
        // more than one statement
        name: "two_stmt",
        code: r#"
print("hello world")
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[f64], mem: vector[f64])->(vector[None|None], vector[into(f64)], vector[f64], vector[f64 | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason:
            "Expect a function definition, but found a"
        )
    ),
    (
        // wrong decorator name
        name: "typo_copr",
        code: r#"
@corp(args=["cpu", "mem"], returns=["perf", "what", "how", "why"])
def a(cpu: vector[f64], mem: vector[f64])->(vector[None|None], vector[into(f64)], vector[f64], vector[f64 | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason:
            "Expect decorator with name `copr` or `coprocessor`, found"
        )
    ),
    (
        name: "extra_keywords",
        code: r#"
@copr(args=["cpu", "mem"], sql=3,psql = 4,rets=5)
def a(cpu: vector[f64], mem: vector[f64])->(vector[f64|None], vector[into(f64)], vector[f64], vector[f64 | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason:
            "Expect a list of String, found one element to be"
        )
    ),
    (
        name: "missing_keywords",
        code: r#"
@copr(args=["cpu", "mem"])
def a(cpu: vector[f64], mem: vector[f64])->(vector[f64|None], vector[into(f64)], vector[f64], vector[f64 | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem
"#,
        predicate: ParseIsErr(
            reason:
            "Expect `returns` keyword"
        )
    ),
    (
        // exec_coprocessor
        name: "correct_exec",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
vector[f32]):
    return cpu + mem, cpu - mem
"#,
        predicate: ExecIsOk(
            fields: [
                (
                    datatype: Some(Float64(())),
                    is_nullable: true
                ),
                (
                    datatype: Some(Float32(())),
                    is_nullable: false
                ),
            ],
            columns: [
                (
                    ty: Float64,
                    len: 4
                ),
                (
                    ty: Float32,
                    len: 4
                )
            ]
        )
    ),
    (
        // constant column(float)
        name: "constant_float_col",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
vector[f32]):
    return cpu + mem, 1.0
"#,
        predicate: ExecIsOk(
            fields: [
                (
                    datatype: Some(Float64(())),
                    is_nullable: true
                ),
                (
                    datatype: Some(Float32(())),
                    is_nullable: false
                ),
            ],
            columns: [
                (
                    ty: Float64,
                    len: 4
                ),
                (
                    ty: Float32,
                    len: 4
                )
            ]
        )
    ),
    (
        // constant column(int)
        name: "constant_int_col",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
vector[f32]):
    return cpu + mem, 1
"#,
        predicate: ExecIsOk(
            fields: [
                (
                    datatype: Some(Float64(())),
                    is_nullable: true
                ),
                (
                    datatype: Some(Float32(())),
                    is_nullable: false
                ),
            ],
            columns: [
                (
                    ty: Float64,
                    len: 4
                ),
                (
                    ty: Float32,
                    len: 4
                )
            ]
        )
    ),
    (
        // constant column(bool)
        name: "constant_bool_col",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
vector[f32]):
    return cpu + mem, True
"#,
        predicate: ExecIsOk(
            fields: [
                (
                    datatype: Some(Float64(())),
                    is_nullable: true
                ),
                (
                    datatype: Some(Float32(())),
                    is_nullable: false
                ),
            ],
            columns: [
                (
                    ty: Float64,
                    len: 4
                ),
                (
                    ty: Float32,
                    len: 4
                )
            ]
        )
    ),
    (
        name: "constant_list",
        code: r#"
@copr(args=["cpu", "mem"], returns=["what"])
def a(cpu: vector[f32], mem: vector[f64]):
    return ["apple" ,"banana", "cherry"]
"#,
        predicate: ExecIsOk(
            fields: [
                (
                    datatype: Some(String(())),
                    is_nullable: false,
                ),
            ],
            columns: [
                (
                    ty: Utf8,
                    len: 3
                ),
            ]
        )
    ),
    (
        name: "constant_list_different_type",
        code: r#"
@copr(args=["cpu", "mem"], returns=["what"])
def a(cpu: vector[f32], mem: vector[f64]):
    return ["apple" ,3, "cherry"]
"#,
        predicate: ExecIsErr(
            reason: "All elements in a list should be same type to cast to Datafusion list!",
        )
    ),
    (
        // declares 6 return vectors but only returns 5
        name: "ret_nums_wrong",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what", "how", "why", "whatever", "nihilism"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None], vector[f64], vector[f64], vector[f64 | None], vector[bool], vector[_ | None]):
    return cpu + mem, cpu - mem, cpu * mem, cpu / mem, cpu
"#,
        predicate: ExecIsErr(
            reason: "The number of return Vector is wrong, expect"
        )
    ),
    (
        name: "div_by_zero",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
vector[f32]):
    return cpu + mem, cpu - mem*(1/0)
"#,
        predicate: ExecIsErr(
            reason: "ZeroDivisionError: division by zero"
        )
    ),
    (
        name: "unexpected_token",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
vector[f32]):
    return cpu + mem, cpu - mem***
"#,
        predicate: ParseIsErr(
            reason: "invalid syntax. Got unexpected token "
        )
    ),
    (
        name: "wrong_return_anno",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->f32:
    return cpu + mem, cpu - mem
"#,
        predicate: ParseIsErr(
            reason: "Expect `(vector[...], vector[...], ...)` or `vector[...]`, found "
        )
    ),
    (
        name: "break_outside_loop",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64], vector[f64]):
    break
    return cpu + mem, cpu - mem
"#,
        predicate: ExecIsErr(
            reason: "'break' outside loop"
        )
    ),
    (
        name: "not_even_wrong",
        code: r#"
42
"#,
        predicate: ParseIsErr(
            reason: "Expect a function definition, but found a"
        )
    ),
    (
        // import modules from the stdlib allow list
        name: "test_import_stdlib",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
vector[f32]):
    # test if using allow list for stdlib damage unrelated module
    from collections import deque
    import math
    math.ceil(0.2)
    import string
    return cpu + mem, 1
"#,
        predicate: ExecIsOk(
            fields: [
                (
                    datatype: Some(Float64(())),
                    is_nullable: true
                ),
                (
                    datatype: Some(Float32(())),
                    is_nullable: false
                ),
            ],
            columns: [
                (
                    ty: Float64,
                    len: 4
                ),
                (
                    ty: Float32,
                    len: 4
                )
            ]
        )
    ),
    (
        // dataframe select API
        name: "test_data_frame",
        code: r#"
from greptime import col, dataframe
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
vector[f32]):
    ret = dataframe().select([col("cpu"), col("mem")]).collect()[0]
    return ret[0], ret[1]
"#,
        predicate: ExecIsOk(
            fields: [
                (
                    datatype: Some(Float64(())),
                    is_nullable: true
                ),
                (
                    datatype: Some(Float32(())),
                    is_nullable: false
                ),
            ],
            columns: [
                (
                    ty: Float64,
                    len: 4
                ),
                (
                    ty: Float32,
                    len: 4
                )
            ]
        )
    ),
    (
        // dataframe filter API
        name: "test_data_frame",
        code: r#"
from greptime import col, dataframe
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
vector[f32]):
    ret = dataframe().filter(col("cpu")>col("mem")).collect()[0]
    return ret[0], ret[1]
"#,
        predicate: ExecIsOk(
            fields: [
                (
                    datatype: Some(Float64(())),
                    is_nullable: true
                ),
                (
                    datatype: Some(Float32(())),
                    is_nullable: false
                ),
            ],
            columns: [
                (
                    ty: Float64,
                    len: 4
                ),
                (
                    ty: Float32,
                    len: 4
                )
            ]
        )
    ),
    (
        // module not in the stdlib allow list
        name: "test_neg_import_stdlib",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
vector[f32]):
    # test if module not in allow list can't be imported
    import fcntl
    return cpu + mem, 1
"#,
        predicate: ExecIsErr(
            reason: "No module named 'fcntl'"
        )
    ),
    (
        // module depending on a module not in the allow list
        name: "test_neg_import_depend_stdlib",
        code: r#"
@copr(args=["cpu", "mem"], returns=["perf", "what"])
def a(cpu: vector[f32], mem: vector[f64])->(vector[f64|None],
vector[f32]):
    # test if module not in allow list can't be imported
    import mailbox
    return cpu + mem, 1
"#,
        predicate: ExecIsErr(
            reason: "ModuleNotFoundError: No module named"
        )
    ),
]
@@ -1,146 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use arrow::array::ArrayRef;
use datafusion_common::ScalarValue;
use datafusion_expr::ColumnarValue as DFColValue;
use datatypes::prelude::ScalarVector;
use datatypes::value::Value;
use datatypes::vectors::{
    BooleanVector, Float64Vector, Helper, Int64Vector, StringVector, VectorRef,
};
use rustpython_vm::builtins::{PyBaseExceptionRef, PyBool, PyFloat, PyInt, PyList, PyStr};
use rustpython_vm::object::PyObjectPayload;
use rustpython_vm::{PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine};
use snafu::{OptionExt, ResultExt};

use crate::python::error;
use crate::python::error::ret_other_error_with;
use crate::python::ffi_types::PyVector;
use crate::python::rspython::builtins::try_into_columnar_value;

/// Use `rustpython`'s `is_instance` method to check if a PyObject is an instance of a class.
/// If the `PyResult` is `Err`, this function returns `false`.
pub fn is_instance<T: PyPayload>(obj: &PyObjectRef, vm: &VirtualMachine) -> bool {
    obj.is_instance(T::class(&vm.ctx).into(), vm)
        .unwrap_or(false)
}

pub fn obj_cast_to<T: PyObjectPayload>(
    obj: PyObjectRef,
    vm: &VirtualMachine,
) -> PyResult<PyRef<T>> {
    obj.downcast::<T>().map_err(|e| {
        vm.new_type_error(format!(
            "Can't cast object into {}, actual type: {}",
            std::any::type_name::<T>(),
            e.class().name()
        ))
    })
}

pub fn format_py_error(excep: PyBaseExceptionRef, vm: &VirtualMachine) -> error::Error {
    let mut msg = String::new();
    if let Err(e) = vm.write_exception(&mut msg, &excep) {
        return error::PyRuntimeSnafu {
            msg: format!("Failed to write exception msg, err: {e}"),
        }
        .build();
    }

    error::PyRuntimeSnafu { msg }.build()
}

pub(crate) fn py_obj_to_value(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<Value> {
    macro_rules! obj2val {
        ($OBJ: ident, $($PY_TYPE: ident => $RS_TYPE: ident => $VARIANT: ident),*) => {
            $(
                if is_instance::<$PY_TYPE>($OBJ, vm) {
                    let val = $OBJ
                        .to_owned()
                        .try_into_value::<$RS_TYPE>(vm)?;
                    Ok(Value::$VARIANT(val.into()))
                }
            )else*
            else {
                Err(vm.new_runtime_error(format!("can't convert obj {obj:?} to Value")))
            }
        };
    }
    obj2val!(obj,
        PyBool => bool => Boolean,
        PyInt => i64 => Int64,
        PyFloat => f64 => Float64,
        PyStr => String => String
    )
}

/// Convert a single `PyVector` or a number (a constant, wrapped in a `PyObjectRef`) into an Array (or a constant array).
pub fn py_obj_to_vec(
    obj: &PyObjectRef,
    vm: &VirtualMachine,
    col_len: usize,
) -> Result<VectorRef, error::Error> {
    // It's ugly, but we can't find a better way right now.
    if is_instance::<PyVector>(obj, vm) {
        let pyv = obj
            .payload::<PyVector>()
            .with_context(|| ret_other_error_with(format!("can't cast obj {obj:?} to PyVector")))?;
        Ok(pyv.as_vector_ref())
    } else if is_instance::<PyInt>(obj, vm) {
        let val = obj
            .clone()
            .try_into_value::<i64>(vm)
            .map_err(|e| format_py_error(e, vm))?;
        let ret = Int64Vector::from_iterator(std::iter::repeat(val).take(col_len));
        Ok(Arc::new(ret) as _)
    } else if is_instance::<PyFloat>(obj, vm) {
        let val = obj
            .clone()
            .try_into_value::<f64>(vm)
            .map_err(|e| format_py_error(e, vm))?;
        let ret = Float64Vector::from_iterator(std::iter::repeat(val).take(col_len));
        Ok(Arc::new(ret) as _)
    } else if is_instance::<PyBool>(obj, vm) {
        let val = obj
            .clone()
            .try_into_value::<bool>(vm)
            .map_err(|e| format_py_error(e, vm))?;

        let ret = BooleanVector::from_iterator(std::iter::repeat(val).take(col_len));
        Ok(Arc::new(ret) as _)
    } else if is_instance::<PyStr>(obj, vm) {
        let val = obj
            .clone()
            .try_into_value::<String>(vm)
            .map_err(|e| format_py_error(e, vm))?;

        let ret = StringVector::from_iterator(std::iter::repeat(val.as_str()).take(col_len));
        Ok(Arc::new(ret) as _)
    } else if is_instance::<PyList>(obj, vm) {
        let columnar_value =
            try_into_columnar_value(obj.clone(), vm).map_err(|e| format_py_error(e, vm))?;

        match columnar_value {
            DFColValue::Scalar(ScalarValue::List(array)) => {
                Helper::try_into_vector(array as ArrayRef).context(error::TypeCastSnafu)
            }
            _ => unreachable!(),
        }
    } else {
        ret_other_error_with(format!("Expect a vector or a constant, found {obj:?}")).fail()
    }
}
@@ -1,567 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! PyVector's rustpython-specific methods

use arrow::compute::kernels::numeric;
use common_time::date::Date;
use common_time::datetime::DateTime;
use common_time::timestamp::Timestamp;
use crossbeam_utils::atomic::AtomicCell;
use datatypes::arrow::array::{Array, BooleanArray};
use datatypes::arrow::compute;
use datatypes::arrow::datatypes::DataType as ArrowDataType;
use datatypes::data_type::{ConcreteDataType, DataType};
use datatypes::value::{self, OrderedFloat};
use datatypes::vectors::Helper;
use once_cell::sync::Lazy;
use rustpython_vm::builtins::{PyBaseExceptionRef, PyBool, PyBytes, PyFloat, PyInt, PyNone, PyStr};
use rustpython_vm::convert::ToPyResult;
use rustpython_vm::function::{Either, OptionalArg, PyComparisonValue};
use rustpython_vm::protocol::{PyMappingMethods, PyNumberMethods, PySequenceMethods};
use rustpython_vm::types::{
    AsMapping, AsNumber, AsSequence, Comparable, PyComparisonOp, Representable,
};
use rustpython_vm::{
    atomic_func, pyclass as rspyclass, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult,
    VirtualMachine,
};

use crate::python::ffi_types::vector::{
    arrow_rfloordiv, arrow_rsub, arrow_rtruediv, rspy_is_pyobj_scalar, wrap_result, PyVector,
};
use crate::python::rspython::utils::{is_instance, obj_cast_to};

fn to_type_error(vm: &'_ VirtualMachine) -> impl FnOnce(String) -> PyBaseExceptionRef + '_ {
    |msg: String| vm.new_type_error(msg)
}

pub(crate) type PyVectorRef = PyRef<PyVector>;
/// The `PyVector` type wraps a greptime vector and implements multiply/div/add/sub operators etc.
#[rspyclass(with(AsMapping, AsSequence, Comparable, AsNumber, Representable))]
impl PyVector {
    #[pymethod]
    pub(crate) fn new(
        iterable: OptionalArg<PyObjectRef>,
        vm: &VirtualMachine,
    ) -> PyResult<PyVector> {
        if let OptionalArg::Present(iterable) = iterable {
            let mut elements: Vec<PyObjectRef> = iterable.try_to_value(vm)?;

            if elements.is_empty() {
                return Ok(PyVector::default());
            }

            let datatype = get_concrete_type(&elements[0], vm)?;
            let mut buf = datatype.create_mutable_vector(elements.len());

            for obj in elements.drain(..) {
                let val = if let Some(v) =
                    pyobj_try_to_typed_val(obj.clone(), vm, Some(datatype.clone()))
                {
                    v
                } else {
                    return Err(vm.new_type_error(format!(
                        "Can't cast pyobject {obj:?} into concrete type {datatype:?}",
                    )));
                };
                // Safety: `pyobj_try_to_typed_val()` has checked the data type.
                buf.push_value_ref(val.as_value_ref());
            }

            Ok(PyVector {
                vector: buf.to_vector(),
            })
        } else {
            Ok(PyVector::default())
        }
    }

    #[pymethod(name = "__radd__")]
    #[pymethod(magic)]
    fn add(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
        let zelf = obj_cast_to::<PyVector>(zelf, vm)?;
        if rspy_is_pyobj_scalar(&other, vm) {
            zelf.rspy_scalar_arith_op(other, None, wrap_result(numeric::add), vm)
        } else {
            zelf.rspy_vector_arith_op(other, None, wrap_result(numeric::add), vm)
        }
    }

    #[pymethod(magic)]
    fn sub(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
        let zelf = obj_cast_to::<PyVector>(zelf, vm)?;
        if rspy_is_pyobj_scalar(&other, vm) {
            zelf.rspy_scalar_arith_op(other, None, wrap_result(numeric::sub), vm)
        } else {
            zelf.rspy_vector_arith_op(other, None, wrap_result(numeric::sub), vm)
        }
    }

    #[pymethod(magic)]
    fn rsub(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
        let zelf = obj_cast_to::<PyVector>(zelf, vm)?;
        if rspy_is_pyobj_scalar(&other, vm) {
            zelf.rspy_scalar_arith_op(other, None, arrow_rsub, vm)
        } else {
            zelf.rspy_vector_arith_op(other, None, wrap_result(|a, b| numeric::sub(b, a)), vm)
        }
    }

    #[pymethod(name = "__rmul__")]
    #[pymethod(magic)]
    fn mul(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
        let zelf = obj_cast_to::<PyVector>(zelf, vm)?;
        if rspy_is_pyobj_scalar(&other, vm) {
            zelf.rspy_scalar_arith_op(other, None, wrap_result(numeric::mul), vm)
        } else {
            zelf.rspy_vector_arith_op(other, None, wrap_result(numeric::mul), vm)
        }
    }

    #[pymethod(magic)]
    fn truediv(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
        let zelf = obj_cast_to::<PyVector>(zelf, vm)?;
        if rspy_is_pyobj_scalar(&other, vm) {
            zelf.rspy_scalar_arith_op(
                other,
                Some(ArrowDataType::Float64),
                wrap_result(numeric::div),
                vm,
            )
        } else {
            zelf.rspy_vector_arith_op(
                other,
                Some(ArrowDataType::Float64),
                wrap_result(numeric::div),
                vm,
            )
        }
    }

    #[pymethod(magic)]
    fn rtruediv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
        if rspy_is_pyobj_scalar(&other, vm) {
            self.rspy_scalar_arith_op(other, Some(ArrowDataType::Float64), arrow_rtruediv, vm)
        } else {
            self.rspy_vector_arith_op(
                other,
                Some(ArrowDataType::Float64),
                wrap_result(|a, b| numeric::div(b, a)),
                vm,
            )
        }
    }

    #[pymethod(magic)]
    fn floordiv(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
        let zelf = obj_cast_to::<PyVector>(zelf, vm)?;
        if rspy_is_pyobj_scalar(&other, vm) {
            zelf.rspy_scalar_arith_op(
                other,
                Some(ArrowDataType::Int64),
                wrap_result(numeric::div),
                vm,
            )
        } else {
            zelf.rspy_vector_arith_op(
                other,
                Some(ArrowDataType::Int64),
                wrap_result(numeric::div),
                vm,
            )
        }
    }

    #[pymethod(magic)]
    fn rfloordiv(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
        if rspy_is_pyobj_scalar(&other, vm) {
            // FIXME: DataType convert problem, target_type should be inferred?
            self.rspy_scalar_arith_op(other, Some(ArrowDataType::Int64), arrow_rfloordiv, vm)
        } else {
            self.rspy_vector_arith_op(
                other,
                Some(ArrowDataType::Int64),
                wrap_result(|a, b| numeric::div(b, a)),
                vm,
            )
        }
    }

    fn obj_to_vector(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<PyVector>> {
        obj.downcast::<PyVector>().map_err(|e| {
            vm.new_type_error(format!(
                "Can't cast right operand into PyVector, actual type: {}",
                e.class().name()
            ))
        })
    }

    #[pymethod(magic)]
    fn and(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
        let zelf = Self::obj_to_vector(zelf, vm)?;
        let other = Self::obj_to_vector(other, vm)?;
        Self::vector_and(&zelf, &other).map_err(to_type_error(vm))
    }

    #[pymethod(magic)]
    fn or(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
        let zelf = Self::obj_to_vector(zelf, vm)?;
        let other = Self::obj_to_vector(other, vm)?;
        Self::vector_or(&zelf, &other).map_err(to_type_error(vm))
    }

    #[pymethod(magic)]
    fn invert(zelf: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyVector> {
        let zelf = Self::obj_to_vector(zelf, vm)?;
        Self::vector_invert(&zelf).map_err(to_type_error(vm))
    }

    #[pymethod(name = "__len__")]
    fn len_rspy(&self) -> usize {
        self.len()
    }

    #[pymethod(name = "concat")]
    fn concat(&self, other: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyVector> {
        let left = self.to_arrow_array();
        let right = other.to_arrow_array();

        let res = compute::concat(&[left.as_ref(), right.as_ref()]);
        let res = res.map_err(|err| vm.new_runtime_error(format!("Arrow Error: {err:#?}")))?;
        let ret = Helper::try_into_vector(res.clone()).map_err(|e| {
            vm.new_type_error(format!(
                "Can't cast result into vector, result: {res:?}, err: {e:?}",
            ))
        })?;
        Ok(ret.into())
    }

    /// Takes a boolean array and filters the Array, returning elements matching the filter (i.e. where the values are true).
#[pymethod(name = "filter")]
|
||||
fn filter(&self, other: PyVectorRef, vm: &VirtualMachine) -> PyResult<PyVector> {
|
||||
let left = self.to_arrow_array();
|
||||
let right = other.to_arrow_array();
|
||||
let filter = right.as_any().downcast_ref::<BooleanArray>();
|
||||
match filter {
|
||||
Some(filter) => {
|
||||
let res = compute::filter(left.as_ref(), filter);
|
||||
|
||||
let res =
|
||||
res.map_err(|err| vm.new_runtime_error(format!("Arrow Error: {err:#?}")))?;
|
||||
let ret = Helper::try_into_vector(res.clone()).map_err(|e| {
|
||||
vm.new_type_error(format!(
|
||||
"Can't cast result into vector, result: {res:?}, err: {e:?}",
|
||||
))
|
||||
})?;
|
||||
Ok(ret.into())
|
||||
}
|
||||
None => Err(vm.new_runtime_error(format!(
|
||||
"Can't cast operand into a Boolean Array, which is {right:#?}"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethod(magic)]
|
||||
fn doc(&self) -> PyResult<PyStr> {
|
||||
Ok(PyStr::from(
|
||||
"PyVector is like a Python array, a compact array of elem of same datatype, but Readonly for now",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl Representable for PyVector {
|
||||
#[inline]
|
||||
fn repr_str(zelf: &Py<Self>, _vm: &VirtualMachine) -> PyResult<String> {
|
||||
Ok(format!("{:#?}", *zelf))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsNumber for PyVector {
|
||||
fn as_number() -> &'static PyNumberMethods {
|
||||
// FIXME(discord9): have to use `&PyObject.to_owned()` here
|
||||
// because it seems to be the only way to convert a `&PyObject` to `PyObjectRef`.
|
||||
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
|
||||
and: Some(|a, b, vm| PyVector::and(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)),
|
||||
or: Some(|a, b, vm| PyVector::or(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)),
|
||||
invert: Some(|a, vm| PyVector::invert((*a).to_owned(), vm).to_pyresult(vm)),
|
||||
add: Some(|a, b, vm| PyVector::add(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)),
|
||||
subtract: Some(|a, b, vm| {
|
||||
PyVector::sub(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)
|
||||
}),
|
||||
multiply: Some(|a, b, vm| {
|
||||
PyVector::mul(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)
|
||||
}),
|
||||
true_divide: Some(|a, b, vm| {
|
||||
PyVector::truediv(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)
|
||||
}),
|
||||
floor_divide: Some(|a, b, vm| {
|
||||
PyVector::floordiv(a.to_owned(), b.to_owned(), vm).to_pyresult(vm)
|
||||
}),
|
||||
..PyNumberMethods::NOT_IMPLEMENTED
|
||||
};
|
||||
&AS_NUMBER
|
||||
}
|
||||
}
|
||||
|
||||
impl AsMapping for PyVector {
|
||||
fn as_mapping() -> &'static PyMappingMethods {
|
||||
static AS_MAPPING: PyMappingMethods = PyMappingMethods {
|
||||
length: atomic_func!(|mapping, _vm| Ok(PyVector::mapping_downcast(mapping).len())),
|
||||
subscript: atomic_func!(
|
||||
|mapping, needle, vm| PyVector::mapping_downcast(mapping)._getitem(needle, vm)
|
||||
),
|
||||
ass_subscript: AtomicCell::new(None),
|
||||
};
|
||||
&AS_MAPPING
|
||||
}
|
||||
}
|
||||
|
||||
impl AsSequence for PyVector {
|
||||
fn as_sequence() -> &'static PySequenceMethods {
|
||||
static AS_SEQUENCE: Lazy<PySequenceMethods> = Lazy::new(|| PySequenceMethods {
|
||||
length: atomic_func!(|seq, _vm| Ok(PyVector::sequence_downcast(seq).len())),
|
||||
item: atomic_func!(|seq, i, vm| {
|
||||
let zelf = PyVector::sequence_downcast(seq);
|
||||
zelf.getitem_by_index(i, vm)
|
||||
}),
|
||||
ass_item: atomic_func!(|_seq, _i, _value, vm| {
|
||||
Err(vm.new_type_error("PyVector object doesn't support item assigns".to_string()))
|
||||
}),
|
||||
..PySequenceMethods::NOT_IMPLEMENTED
|
||||
});
|
||||
&AS_SEQUENCE
|
||||
}
|
||||
}
|
||||
|
||||
impl Comparable for PyVector {
|
||||
fn slot_richcompare(
|
||||
zelf: &PyObject,
|
||||
other: &PyObject,
|
||||
op: PyComparisonOp,
|
||||
vm: &VirtualMachine,
|
||||
) -> PyResult<Either<PyObjectRef, PyComparisonValue>> {
|
||||
if let Some(zelf) = zelf.downcast_ref::<Self>() {
|
||||
let ret: PyVector = zelf.richcompare(other.to_owned(), op, vm)?;
|
||||
let ret = ret.into_pyobject(vm);
|
||||
Ok(Either::A(ret))
|
||||
} else {
|
||||
Err(vm.new_type_error(format!(
|
||||
"unexpected payload {:?} for {}",
|
||||
zelf,
|
||||
op.method_name(&vm.ctx).as_str()
|
||||
)))
|
||||
}
|
||||
}
|
||||
fn cmp(
|
||||
_zelf: &rustpython_vm::Py<Self>,
|
||||
_other: &PyObject,
|
||||
_op: PyComparisonOp,
|
||||
_vm: &VirtualMachine,
|
||||
) -> PyResult<PyComparisonValue> {
|
||||
Ok(PyComparisonValue::NotImplemented)
|
||||
}
|
||||
}
|
||||
|
||||
fn get_concrete_type(obj: &PyObjectRef, vm: &VirtualMachine) -> PyResult<ConcreteDataType> {
|
||||
if is_instance::<PyNone>(obj, vm) {
|
||||
Ok(ConcreteDataType::null_datatype())
|
||||
} else if is_instance::<PyBool>(obj, vm) {
|
||||
Ok(ConcreteDataType::boolean_datatype())
|
||||
} else if is_instance::<PyInt>(obj, vm) {
|
||||
Ok(ConcreteDataType::int64_datatype())
|
||||
} else if is_instance::<PyFloat>(obj, vm) {
|
||||
Ok(ConcreteDataType::float64_datatype())
|
||||
} else if is_instance::<PyStr>(obj, vm) {
|
||||
Ok(ConcreteDataType::string_datatype())
|
||||
} else {
|
||||
Err(vm.new_type_error(format!("Unsupported pyobject type: {obj:?}")))
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a `PyObjectRef` into a `datatypes::Value` (is that ok?).
/// If `obj` can be converted to the given `ConcreteDataType`, return the inner `Value`, else return `None`.
/// If `dtype` is `None`, return the type with the highest precision.
/// Not used for now but may be used in the future.
pub(crate) fn pyobj_try_to_typed_val(
    obj: PyObjectRef,
    vm: &VirtualMachine,
    dtype: Option<ConcreteDataType>,
) -> Option<value::Value> {
    // TODO(discord9): use `PyResult` instead of `Option` for better error handling
    if let Some(dtype) = dtype {
        match dtype {
            ConcreteDataType::Null(_) => {
                if is_instance::<PyNone>(&obj, vm) {
                    Some(value::Value::Null)
                } else {
                    None
                }
            }
            ConcreteDataType::Boolean(_) => {
                if is_instance::<PyBool>(&obj, vm) || is_instance::<PyInt>(&obj, vm) {
                    Some(value::Value::Boolean(
                        obj.try_into_value::<bool>(vm).unwrap_or(false),
                    ))
                } else {
                    None
                }
            }
            ConcreteDataType::Int8(_)
            | ConcreteDataType::Int16(_)
            | ConcreteDataType::Int32(_)
            | ConcreteDataType::Int64(_) => {
                if is_instance::<PyInt>(&obj, vm) {
                    match dtype {
                        ConcreteDataType::Int8(_) => {
                            obj.try_into_value::<i8>(vm).ok().map(value::Value::Int8)
                        }
                        ConcreteDataType::Int16(_) => {
                            obj.try_into_value::<i16>(vm).ok().map(value::Value::Int16)
                        }
                        ConcreteDataType::Int32(_) => {
                            obj.try_into_value::<i32>(vm).ok().map(value::Value::Int32)
                        }
                        ConcreteDataType::Int64(_) => {
                            obj.try_into_value::<i64>(vm).ok().map(value::Value::Int64)
                        }
                        _ => unreachable!(),
                    }
                } else {
                    None
                }
            }
            ConcreteDataType::UInt8(_)
            | ConcreteDataType::UInt16(_)
            | ConcreteDataType::UInt32(_)
            | ConcreteDataType::UInt64(_) => {
                if is_instance::<PyInt>(&obj, vm)
                    && obj.clone().try_into_value::<i64>(vm).unwrap_or(-1) >= 0
                {
                    match dtype {
                        ConcreteDataType::UInt8(_) => {
                            obj.try_into_value::<u8>(vm).ok().map(value::Value::UInt8)
                        }
                        ConcreteDataType::UInt16(_) => {
                            obj.try_into_value::<u16>(vm).ok().map(value::Value::UInt16)
                        }
                        ConcreteDataType::UInt32(_) => {
                            obj.try_into_value::<u32>(vm).ok().map(value::Value::UInt32)
                        }
                        ConcreteDataType::UInt64(_) => {
                            obj.try_into_value::<u64>(vm).ok().map(value::Value::UInt64)
                        }
                        _ => unreachable!(),
                    }
                } else {
                    None
                }
            }
            ConcreteDataType::Float32(_) | ConcreteDataType::Float64(_) => {
                if is_instance::<PyFloat>(&obj, vm) {
                    match dtype {
                        ConcreteDataType::Float32(_) => obj
                            .try_into_value::<f32>(vm)
                            .ok()
                            .map(|v| value::Value::Float32(OrderedFloat(v))),
                        ConcreteDataType::Float64(_) => obj
                            .try_into_value::<f64>(vm)
                            .ok()
                            .map(|v| value::Value::Float64(OrderedFloat(v))),
                        _ => unreachable!(),
                    }
                } else {
                    None
                }
            }

            ConcreteDataType::String(_) => {
                if is_instance::<PyStr>(&obj, vm) {
                    obj.try_into_value::<String>(vm)
                        .ok()
                        .map(|v| value::Value::String(v.into()))
                } else {
                    None
                }
            }
            ConcreteDataType::Binary(_) => {
                if is_instance::<PyBytes>(&obj, vm) {
                    obj.try_into_value::<Vec<u8>>(vm).ok().and_then(|v| {
                        String::from_utf8(v)
                            .ok()
                            .map(|v| value::Value::String(v.into()))
                    })
                } else {
                    None
                }
            }
            ConcreteDataType::Date(_)
            | ConcreteDataType::DateTime(_)
            | ConcreteDataType::Timestamp(_) => {
                if is_instance::<PyInt>(&obj, vm) {
                    match dtype {
                        ConcreteDataType::Date(_) => obj
                            .try_into_value::<i32>(vm)
                            .ok()
                            .map(Date::new)
                            .map(value::Value::Date),
                        ConcreteDataType::DateTime(_) => obj
                            .try_into_value::<i64>(vm)
                            .ok()
                            .map(DateTime::new)
                            .map(value::Value::DateTime),
                        ConcreteDataType::Timestamp(_) => {
                            // FIXME(dennis): we always consider the timestamp unit is millis, it's not correct if user define timestamp column with other units.
                            obj.try_into_value::<i64>(vm)
                                .ok()
                                .map(Timestamp::new_millisecond)
                                .map(value::Value::Timestamp)
                        }
                        _ => unreachable!(),
                    }
                } else {
                    None
                }
            }
            _ => None,
        }
    } else if is_instance::<PyNone>(&obj, vm) {
        // if Untyped then by default return types with highest precision
        Some(value::Value::Null)
    } else if is_instance::<PyBool>(&obj, vm) {
        Some(value::Value::Boolean(
            obj.try_into_value::<bool>(vm).unwrap_or(false),
        ))
    } else if is_instance::<PyInt>(&obj, vm) {
        obj.try_into_value::<i64>(vm).ok().map(value::Value::Int64)
    } else if is_instance::<PyFloat>(&obj, vm) {
        obj.try_into_value::<f64>(vm)
            .ok()
            .map(|v| value::Value::Float64(OrderedFloat(v)))
    } else if is_instance::<PyStr>(&obj, vm) {
        obj.try_into_value::<Vec<u8>>(vm).ok().and_then(|v| {
            String::from_utf8(v)
                .ok()
                .map(|v| value::Value::String(v.into()))
        })
    } else if is_instance::<PyBytes>(&obj, vm) {
        obj.try_into_value::<Vec<u8>>(vm).ok().and_then(|v| {
            String::from_utf8(v)
                .ok()
                .map(|v| value::Value::String(v.into()))
        })
    } else {
        None
    }
}
@@ -1,58 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use common_runtime::runtime::RuntimeTrait;
use common_runtime::JoinHandle;
use futures::Future;
use rustpython_vm::builtins::PyBaseExceptionRef;
use rustpython_vm::VirtualMachine;

use crate::python::error;

pub fn format_py_error(excep: PyBaseExceptionRef, vm: &VirtualMachine) -> error::Error {
    let mut msg = String::new();
    if let Err(e) = vm.write_exception(&mut msg, &excep) {
        return error::PyRuntimeSnafu {
            msg: format!("Failed to write exception msg, err: {e}"),
        }
        .build();
    }
    error::PyRuntimeSnafu { msg }.build()
}

/// Just like [`tokio::task::spawn_blocking`] but using a dedicated runtime (runtime `bg`) used by the `scripts` crate.
pub fn spawn_blocking_script<F, R>(f: F) -> JoinHandle<R>
where
    F: FnOnce() -> R + Send + 'static,
    R: Send + 'static,
{
    common_runtime::spawn_blocking_global(f)
}

/// Please only use this method when calling from sync code (which may itself have been
/// called from async code first) into async code.
///
/// A terrible hack to call async from sync by:
/// 1. using a cached runtime
/// 2. blocking on that runtime
///
/// TODO(discord9): find a better way
pub fn block_on_async<T, F>(f: F) -> std::thread::Result<T>
where
    F: Future<Output = T> + Send + 'static,
    T: Send + 'static,
{
    let rt = common_runtime::global_runtime();
    // spawn a thread to block on the runtime, which also prevents the `start a runtime inside of runtime` error
    // it's ok to block here, assuming the call from async into sync goes through a `spawn_blocking_*` call
    std::thread::spawn(move || rt.block_on(f)).join()
}
@@ -1,376 +0,0 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Scripts table
use std::sync::Arc;

use api::v1::greptime_request::Request;
use api::v1::value::ValueData;
use api::v1::{
    ColumnDataType, ColumnDef, ColumnSchema as PbColumnSchema, Row, RowInsertRequest,
    RowInsertRequests, Rows, SemanticType,
};
use catalog::error::CompileScriptInternalSnafu;
use common_error::ext::{BoxedError, ErrorExt};
use common_query::OutputData;
use common_recordbatch::{util as record_util, RecordBatch, SendableRecordBatchStream};
use common_telemetry::{debug, info, warn};
use common_time::util;
use datafusion::datasource::DefaultTableSource;
use datafusion::logical_expr::{and, col, lit};
use datafusion_common::TableReference;
use datafusion_expr::LogicalPlanBuilder;
use datatypes::prelude::ScalarVector;
use datatypes::vectors::{StringVector, Vector};
use query::QueryEngineRef;
use servers::query_handler::grpc::GrpcQueryHandlerRef;
use session::context::{QueryContextBuilder, QueryContextRef};
use snafu::{ensure, OptionExt, ResultExt};
use table::metadata::TableInfo;
use table::table::adapter::DfTableProviderAdapter;
use table::TableRef;

use crate::error::{
    BuildDfLogicalPlanSnafu, CastTypeSnafu, CollectRecordsSnafu, ExecuteInternalStatementSnafu,
    FindColumnInScriptsTableSnafu, InsertScriptSnafu, Result, ScriptNotFoundSnafu,
};
use crate::python::PyScript;

pub const SCRIPTS_TABLE_NAME: &str = "scripts";

pub type ScriptsTableRef<E> = Arc<ScriptsTable<E>>;

/// The scripts table that keeps the script content etc.
pub struct ScriptsTable<E: ErrorExt + Send + Sync + 'static> {
    table: TableRef,
    grpc_handler: GrpcQueryHandlerRef<E>,
    query_engine: QueryEngineRef,
}

impl<E: ErrorExt + Send + Sync + 'static> ScriptsTable<E> {
    /// Create a new [`ScriptsTable`] based on the table.
    pub fn new(
        table: TableRef,
        grpc_handler: GrpcQueryHandlerRef<E>,
        query_engine: QueryEngineRef,
    ) -> Self {
        Self {
            table,
            grpc_handler,
            query_engine,
        }
    }

    fn get_str_col_by_name<'a>(record: &'a RecordBatch, name: &str) -> Result<&'a StringVector> {
        let column = record
            .column_by_name(name)
            .with_context(|| FindColumnInScriptsTableSnafu { name })?;
        let column = column
            .as_any()
            .downcast_ref::<StringVector>()
            .with_context(|| CastTypeSnafu {
                msg: format!(
                    "can't downcast {:?} array into string vector",
                    column.data_type()
                ),
            })?;
        Ok(column)
    }
    /// This is used as a callback function when the scripts table is created. `table` should be the `scripts` table.
    /// The function will try its best to register all scripts, ignoring any errors
    /// in parsing and registering scripts; it just emits a warning.
    /// TODO(discord9): rethink error handling here
pub async fn recompile_register_udf(
|
||||
table: TableRef,
|
||||
query_engine: QueryEngineRef,
|
||||
) -> catalog::error::Result<()> {
|
||||
let table_info = table.table_info();
|
||||
|
||||
let rbs = Self::table_full_scan(table, &query_engine)
|
||||
.await
|
||||
.map_err(BoxedError::new)
|
||||
.context(CompileScriptInternalSnafu)?;
|
||||
let records = record_util::collect(rbs)
|
||||
.await
|
||||
.map_err(BoxedError::new)
|
||||
.context(CompileScriptInternalSnafu)?;
|
||||
|
||||
let mut script_list: Vec<(String, String)> = Vec::new();
|
||||
for record in records {
|
||||
let names = Self::get_str_col_by_name(&record, "name")
|
||||
.map_err(BoxedError::new)
|
||||
.context(CompileScriptInternalSnafu)?;
|
||||
let scripts = Self::get_str_col_by_name(&record, "script")
|
||||
.map_err(BoxedError::new)
|
||||
.context(CompileScriptInternalSnafu)?;
|
||||
|
||||
let part_of_scripts_list =
|
||||
names
|
||||
.iter_data()
|
||||
.zip(scripts.iter_data())
|
||||
.filter_map(|i| match i {
|
||||
(Some(a), Some(b)) => Some((a.to_string(), b.to_string())),
|
||||
_ => None,
|
||||
});
|
||||
script_list.extend(part_of_scripts_list);
|
||||
}
|
||||
|
||||
info!(
|
||||
"Found {} scripts in {}",
|
||||
script_list.len(),
|
||||
table_info.full_table_name()
|
||||
);
|
||||
|
||||
for (name, script) in script_list {
|
||||
match PyScript::from_script(&script, query_engine.clone()) {
|
||||
Ok(script) => {
|
||||
script.register_udf().await;
|
||||
debug!(
|
||||
"Script in `scripts` system table re-register as UDF: {}",
|
||||
name
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
r#"Failed to compile script "{}"" in `scripts` table: {:?}"#,
|
||||
name, err
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn insert(&self, schema: &str, name: &str, script: &str) -> Result<()> {
|
||||
let now = util::current_time_millis();
|
||||
|
||||
let table_info = self.table.table_info();
|
||||
|
||||
let insert = RowInsertRequest {
|
||||
table_name: SCRIPTS_TABLE_NAME.to_string(),
|
||||
rows: Some(Rows {
|
||||
schema: build_insert_column_schemas(),
|
||||
rows: vec![Row {
|
||||
values: vec![
|
||||
ValueData::StringValue(schema.to_string()).into(),
|
||||
ValueData::StringValue(name.to_string()).into(),
|
||||
// TODO(dennis): we only support python right now.
|
||||
ValueData::StringValue("python".to_string()).into(),
|
||||
ValueData::StringValue(script.to_string()).into(),
|
||||
// Timestamp in the key part is intentionally left as 0
|
||||
ValueData::TimestampMillisecondValue(0).into(),
|
||||
ValueData::TimestampMillisecondValue(now).into(),
|
||||
],
|
||||
}],
|
||||
}),
|
||||
};
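// Illustrative only: the request built above is roughly the gRPC equivalent of
//   INSERT INTO scripts (schema, name, engine, script, greptime_timestamp, gmt_modified)
//   VALUES ($schema, $name, 'python', $script, 0, $now);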
|
||||
|
||||
let requests = RowInsertRequests {
|
||||
inserts: vec![insert],
|
||||
};
|
||||
|
||||
let output = self
|
||||
.grpc_handler
|
||||
.do_query(Request::RowInserts(requests), query_ctx(&table_info))
|
||||
.await
|
||||
.map_err(BoxedError::new)
|
||||
.context(InsertScriptSnafu { name })?;
|
||||
|
||||
info!(
|
||||
"Inserted script: {} into scripts table: {}, output: {:?}.",
|
||||
name,
|
||||
table_info.full_table_name(),
|
||||
output
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn find_script_by_name(&self, schema: &str, name: &str) -> Result<String> {
|
||||
let table_info = self.table.table_info();
|
||||
|
||||
let table_name = TableReference::full(
|
||||
table_info.catalog_name.clone(),
|
||||
table_info.schema_name.clone(),
|
||||
table_info.name.clone(),
|
||||
);
|
||||
|
||||
let table_provider = Arc::new(DfTableProviderAdapter::new(self.table.clone()));
|
||||
let table_source = Arc::new(DefaultTableSource::new(table_provider));
|
||||
|
||||
let plan = LogicalPlanBuilder::scan(table_name, table_source, None)
|
||||
.context(BuildDfLogicalPlanSnafu)?
|
||||
.filter(and(
|
||||
col("schema").eq(lit(schema)),
|
||||
col("name").eq(lit(name)),
|
||||
))
|
||||
.context(BuildDfLogicalPlanSnafu)?
|
||||
.project(vec![col("script")])
|
||||
.context(BuildDfLogicalPlanSnafu)?
|
||||
.build()
|
||||
.context(BuildDfLogicalPlanSnafu)?;
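// The plan built above corresponds to (illustrative):
//   SELECT script FROM scripts WHERE schema = $schema AND name = $name;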
|
||||
|
||||
let output = self
|
||||
.query_engine
|
||||
.execute(plan, query_ctx(&table_info))
|
||||
.await
|
||||
.context(ExecuteInternalStatementSnafu)?;
|
||||
let stream = match output.data {
|
||||
OutputData::Stream(stream) => stream,
|
||||
OutputData::RecordBatches(record_batches) => record_batches.as_stream(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let records = record_util::collect(stream)
|
||||
.await
|
||||
.context(CollectRecordsSnafu)?;
|
||||
|
||||
ensure!(!records.is_empty(), ScriptNotFoundSnafu { name });
|
||||
|
||||
assert_eq!(records.len(), 1);
|
||||
assert_eq!(records[0].num_columns(), 1);
|
||||
|
||||
let script_column = records[0].column(0);
|
||||
let script_column = script_column
|
||||
.as_any()
|
||||
.downcast_ref::<StringVector>()
|
||||
.with_context(|| CastTypeSnafu {
|
||||
msg: format!(
|
||||
"can't downcast {:?} array into string vector",
|
||||
script_column.data_type()
|
||||
),
|
||||
})?;
|
||||
|
||||
assert_eq!(script_column.len(), 1);
|
||||
|
||||
// Safety: asserted above
|
||||
Ok(script_column.get_data(0).unwrap().to_string())
|
||||
}
|
||||
|
||||
async fn table_full_scan(
|
||||
table: TableRef,
|
||||
query_engine: &QueryEngineRef,
|
||||
) -> Result<SendableRecordBatchStream> {
|
||||
let table_info = table.table_info();
|
||||
let table_name = TableReference::full(
|
||||
table_info.catalog_name.clone(),
|
||||
table_info.schema_name.clone(),
|
||||
table_info.name.clone(),
|
||||
);
|
||||
|
||||
let table_provider = Arc::new(DfTableProviderAdapter::new(table));
|
||||
let table_source = Arc::new(DefaultTableSource::new(table_provider));
|
||||
|
||||
let plan = LogicalPlanBuilder::scan(table_name, table_source, None)
|
||||
.context(BuildDfLogicalPlanSnafu)?
|
||||
.build()
|
||||
.context(BuildDfLogicalPlanSnafu)?;
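// Illustrative: an unfiltered, unprojected scan, i.e. SELECT * FROM scripts;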
|
||||
|
||||
let output = query_engine
|
||||
.execute(plan, query_ctx(&table_info))
|
||||
.await
|
||||
.context(ExecuteInternalStatementSnafu)?;
|
||||
let stream = match output.data {
|
||||
OutputData::Stream(stream) => stream,
|
||||
OutputData::RecordBatches(record_batches) => record_batches.as_stream(),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
Ok(stream)
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds the column schemas for rows inserted into the scripts table.
|
||||
fn build_insert_column_schemas() -> Vec<PbColumnSchema> {
|
||||
vec![
|
||||
// The schema that script belongs to.
|
||||
PbColumnSchema {
|
||||
column_name: "schema".to_string(),
|
||||
datatype: ColumnDataType::String.into(),
|
||||
semantic_type: SemanticType::Tag.into(),
|
||||
..Default::default()
|
||||
},
|
||||
PbColumnSchema {
|
||||
column_name: "name".to_string(),
|
||||
datatype: ColumnDataType::String.into(),
|
||||
semantic_type: SemanticType::Tag.into(),
|
||||
..Default::default()
|
||||
},
|
||||
PbColumnSchema {
|
||||
column_name: "engine".to_string(),
|
||||
datatype: ColumnDataType::String.into(),
|
||||
semantic_type: SemanticType::Tag.into(),
|
||||
..Default::default()
|
||||
},
|
||||
PbColumnSchema {
|
||||
column_name: "script".to_string(),
|
||||
datatype: ColumnDataType::String.into(),
|
||||
semantic_type: SemanticType::Field.into(),
|
||||
..Default::default()
|
||||
},
|
||||
PbColumnSchema {
|
||||
column_name: "greptime_timestamp".to_string(),
|
||||
datatype: ColumnDataType::TimestampMillisecond.into(),
|
||||
semantic_type: SemanticType::Timestamp.into(),
|
||||
..Default::default()
|
||||
},
|
||||
PbColumnSchema {
|
||||
column_name: "gmt_modified".to_string(),
|
||||
datatype: ColumnDataType::TimestampMillisecond.into(),
|
||||
semantic_type: SemanticType::Field.into(),
|
||||
..Default::default()
|
||||
},
|
||||
]
|
||||
}
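// For reference, these column schemas describe a table shaped roughly like the
// following DDL (a sketch; the real table is created via gRPC, and the time
// index and primary keys are derived in build_scripts_schema below):
//   CREATE TABLE scripts (
//     schema STRING,                               -- Tag
//     name STRING,                                 -- Tag
//     engine STRING,                               -- Tag
//     script STRING,                               -- Field
//     greptime_timestamp TIMESTAMP(3) TIME INDEX,
//     gmt_modified TIMESTAMP(3),                   -- Field
//     PRIMARY KEY (schema, name, engine)
//   );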
|
||||
|
||||
fn query_ctx(table_info: &TableInfo) -> QueryContextRef {
|
||||
QueryContextBuilder::default()
|
||||
.current_catalog(table_info.catalog_name.to_string())
|
||||
.current_schema(table_info.schema_name.to_string())
|
||||
.build()
|
||||
.into()
|
||||
}
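// Illustrative: this produces the same scoping as the QueryContext::with(&catalog, &schema)
// calls used by the HTTP handlers below, just built via the builder API.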
|
||||
|
||||
/// Builds the scripts table schema, returning (time index, primary keys, column defs).
|
||||
pub fn build_scripts_schema() -> (String, Vec<String>, Vec<ColumnDef>) {
|
||||
let cols = build_insert_column_schemas();
|
||||
|
||||
let time_index = cols
|
||||
.iter()
|
||||
.find_map(|c| {
|
||||
(c.semantic_type == (SemanticType::Timestamp as i32)).then(|| c.column_name.clone())
|
||||
})
|
||||
.unwrap(); // Safety: the column always exists
|
||||
|
||||
let primary_keys = cols
|
||||
.iter()
|
||||
.filter(|c| (c.semantic_type == (SemanticType::Tag as i32)))
|
||||
.map(|c| c.column_name.clone())
|
||||
.collect();
|
||||
|
||||
let column_defs = cols
|
||||
.into_iter()
|
||||
.map(|c| ColumnDef {
|
||||
name: c.column_name,
|
||||
data_type: c.datatype,
|
||||
is_nullable: false,
|
||||
default_constraint: vec![],
|
||||
semantic_type: c.semantic_type,
|
||||
comment: "".to_string(),
|
||||
datatype_extension: None,
|
||||
options: c.options,
|
||||
})
|
||||
.collect();
|
||||
|
||||
(time_index, primary_keys, column_defs)
|
||||
}
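// With the column schemas above this returns, roughly (illustrative):
//   ("greptime_timestamp", ["schema", "name", "engine"], <six ColumnDefs>)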
|
||||
@@ -1,78 +0,0 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use api::v1::greptime_request::Request;
|
||||
use async_trait::async_trait;
|
||||
use catalog::memory::MemoryCatalogManager;
|
||||
use common_query::Output;
|
||||
use common_recordbatch::RecordBatch;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::schema::{ColumnSchema, Schema};
|
||||
use datatypes::vectors::{StringVector, VectorRef};
|
||||
use query::QueryEngineFactory;
|
||||
use servers::query_handler::grpc::GrpcQueryHandler;
|
||||
use session::context::QueryContextRef;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use crate::manager::ScriptManager;
|
||||
|
||||
/// Sets up the scripts table and creates a script manager.
|
||||
pub async fn setup_scripts_manager(
|
||||
catalog: &str,
|
||||
schema: &str,
|
||||
name: &str,
|
||||
script: &str,
|
||||
) -> ScriptManager<Error> {
|
||||
let column_schemas = vec![
|
||||
ColumnSchema::new("script", ConcreteDataType::string_datatype(), false),
|
||||
ColumnSchema::new("schema", ConcreteDataType::string_datatype(), false),
|
||||
ColumnSchema::new("name", ConcreteDataType::string_datatype(), false),
|
||||
];
|
||||
|
||||
let columns: Vec<VectorRef> = vec![
|
||||
Arc::new(StringVector::from(vec![script])),
|
||||
Arc::new(StringVector::from(vec![schema])),
|
||||
Arc::new(StringVector::from(vec![name])),
|
||||
];
|
||||
|
||||
let schema = Arc::new(Schema::new(column_schemas));
|
||||
let recordbatch = RecordBatch::new(schema, columns).unwrap();
|
||||
|
||||
let table = MemTable::table("scripts", recordbatch);
|
||||
|
||||
let catalog_manager = MemoryCatalogManager::new_with_table(table.clone());
|
||||
|
||||
let factory = QueryEngineFactory::new(catalog_manager.clone(), None, None, None, None, false);
|
||||
let query_engine = factory.query_engine();
|
||||
let mgr = ScriptManager::new(Arc::new(MockGrpcQueryHandler {}) as _, query_engine)
|
||||
.await
|
||||
.unwrap();
|
||||
mgr.insert_scripts_table(catalog, table);
|
||||
|
||||
mgr
|
||||
}
|
||||
|
||||
struct MockGrpcQueryHandler {}
|
||||
|
||||
#[async_trait]
|
||||
impl GrpcQueryHandler for MockGrpcQueryHandler {
|
||||
type Error = Error;
|
||||
|
||||
async fn do_query(&self, _query: Request, _ctx: QueryContextRef) -> Result<Output> {
|
||||
Ok(Output::new_with_affected_rows(1))
|
||||
}
|
||||
}
|
||||
@@ -130,7 +130,6 @@ mysql_async = { version = "0.33", default-features = false, features = [
|
||||
] }
|
||||
permutation = "0.4"
|
||||
rand.workspace = true
|
||||
script = { workspace = true, features = ["python"] }
|
||||
serde_json.workspace = true
|
||||
session = { workspace = true, features = ["testing"] }
|
||||
table.workspace = true
|
||||
|
||||
@@ -149,14 +149,6 @@ pub enum Error {
|
||||
#[snafu(display("Failed to describe statement"))]
|
||||
DescribeStatement { source: BoxedError },
|
||||
|
||||
#[snafu(display("Failed to insert script with name: {}", name))]
|
||||
InsertScript {
|
||||
name: String,
|
||||
#[snafu(implicit)]
|
||||
location: Location,
|
||||
source: BoxedError,
|
||||
},
|
||||
|
||||
#[snafu(display("Pipeline management api error"))]
|
||||
Pipeline {
|
||||
#[snafu(source)]
|
||||
@@ -165,14 +157,6 @@ pub enum Error {
|
||||
location: Location,
|
||||
},
|
||||
|
||||
#[snafu(display("Failed to execute script by name: {}", name))]
|
||||
ExecuteScript {
|
||||
name: String,
|
||||
#[snafu(implicit)]
|
||||
location: Location,
|
||||
source: BoxedError,
|
||||
},
|
||||
|
||||
#[snafu(display("Not supported: {}", feat))]
|
||||
NotSupported { feat: String },
|
||||
|
||||
@@ -633,9 +617,7 @@ impl ErrorExt for Error {
|
||||
|
||||
CollectRecordbatch { .. } => StatusCode::EngineExecuteQuery,
|
||||
|
||||
InsertScript { source, .. }
|
||||
| ExecuteScript { source, .. }
|
||||
| ExecuteQuery { source, .. }
|
||||
ExecuteQuery { source, .. }
|
||||
| ExecutePlan { source, .. }
|
||||
| ExecuteGrpcQuery { source, .. }
|
||||
| ExecuteGrpcRequest { source, .. }
|
||||
|
||||
@@ -67,7 +67,7 @@ use crate::prometheus_handler::PrometheusHandlerRef;
|
||||
use crate::query_handler::sql::ServerSqlQueryHandlerRef;
|
||||
use crate::query_handler::{
|
||||
InfluxdbLineProtocolHandlerRef, LogQueryHandlerRef, OpenTelemetryProtocolHandlerRef,
|
||||
OpentsdbProtocolHandlerRef, PipelineHandlerRef, PromStoreProtocolHandlerRef, ScriptHandlerRef,
|
||||
OpentsdbProtocolHandlerRef, PipelineHandlerRef, PromStoreProtocolHandlerRef,
|
||||
};
|
||||
use crate::server::Server;
|
||||
|
||||
@@ -89,7 +89,6 @@ pub mod pprof;
|
||||
pub mod prom_store;
|
||||
pub mod prometheus;
|
||||
pub mod result;
|
||||
pub mod script;
|
||||
mod timeout;
|
||||
|
||||
pub(crate) use timeout::DynamicTimeoutLayer;
|
||||
@@ -464,7 +463,6 @@ impl From<JsonResponse> for HttpResponse {
|
||||
#[derive(Clone)]
|
||||
pub struct ApiState {
|
||||
pub sql_handler: ServerSqlQueryHandlerRef,
|
||||
pub script_handler: Option<ScriptHandlerRef>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -490,15 +488,8 @@ impl HttpServerBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_sql_handler(
|
||||
self,
|
||||
sql_handler: ServerSqlQueryHandlerRef,
|
||||
script_handler: Option<ScriptHandlerRef>,
|
||||
) -> Self {
|
||||
let sql_router = HttpServer::route_sql(ApiState {
|
||||
sql_handler,
|
||||
script_handler,
|
||||
});
|
||||
pub fn with_sql_handler(self, sql_handler: ServerSqlQueryHandlerRef) -> Self {
|
||||
let sql_router = HttpServer::route_sql(ApiState { sql_handler });
|
||||
|
||||
Self {
|
||||
router: self
|
||||
@@ -783,8 +774,6 @@ impl HttpServer {
|
||||
"/promql",
|
||||
routing::get(handler::promql).post(handler::promql),
|
||||
)
|
||||
.route("/scripts", routing::post(script::scripts))
|
||||
.route("/run-script", routing::post(script::run_script))
|
||||
.with_state(api_state)
|
||||
}
|
||||
|
||||
@@ -1065,7 +1054,7 @@ mod test {
|
||||
let instance = Arc::new(DummyInstance { _tx: tx });
|
||||
let sql_instance = ServerSqlQueryHandlerAdapter::arc(instance.clone());
|
||||
let server = HttpServerBuilder::new(HttpOptions::default())
|
||||
.with_sql_handler(sql_instance, None)
|
||||
.with_sql_handler(sql_instance)
|
||||
.build();
|
||||
server.build(server.make_app()).route(
|
||||
"/test/timeout",
|
||||
|
||||
@@ -1,143 +0,0 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use axum::extract::{Query, RawBody, State};
|
||||
use common_catalog::consts::DEFAULT_CATALOG_NAME;
|
||||
use common_error::ext::ErrorExt;
|
||||
use common_error::status_code::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use session::context::QueryContext;
|
||||
use snafu::ResultExt;
|
||||
|
||||
use crate::error::{HyperSnafu, InvalidUtf8ValueSnafu};
|
||||
use crate::http::result::error_result::ErrorResponse;
|
||||
use crate::http::{ApiState, GreptimedbV1Response, HttpResponse};
|
||||
|
||||
macro_rules! json_err {
|
||||
($e: expr) => {{
|
||||
return HttpResponse::Error(ErrorResponse::from_error($e));
|
||||
}};
|
||||
|
||||
($msg: expr, $code: expr) => {{
|
||||
return HttpResponse::Error(ErrorResponse::from_error_message($code, $msg.to_string()));
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! unwrap_or_json_err {
|
||||
($result: expr) => {
|
||||
match $result {
|
||||
Ok(result) => result,
|
||||
Err(e) => json_err!(e),
|
||||
}
|
||||
};
|
||||
}
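// Usage sketch: `let bytes = unwrap_or_json_err!(fallible_read().await);`
// evaluates to the Ok value, or returns early from the handler with an
// ErrorResponse (fallible_read is a hypothetical stand-in).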
|
||||
|
||||
/// Handler to insert and compile a script.
|
||||
#[axum_macros::debug_handler]
|
||||
pub async fn scripts(
|
||||
State(state): State<ApiState>,
|
||||
Query(params): Query<ScriptQuery>,
|
||||
RawBody(body): RawBody,
|
||||
) -> HttpResponse {
|
||||
if let Some(script_handler) = &state.script_handler {
|
||||
let catalog = params
|
||||
.catalog
|
||||
.unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string());
|
||||
let schema = params.db.as_ref();
|
||||
|
||||
if schema.is_none() || schema.unwrap().is_empty() {
|
||||
json_err!("invalid schema", StatusCode::InvalidArguments)
|
||||
}
|
||||
|
||||
let name = params.name.as_ref();
|
||||
|
||||
if name.is_none() || name.unwrap().is_empty() {
|
||||
json_err!("invalid name", StatusCode::InvalidArguments);
|
||||
}
|
||||
|
||||
let bytes = unwrap_or_json_err!(hyper::body::to_bytes(body).await.context(HyperSnafu));
|
||||
|
||||
let script =
|
||||
unwrap_or_json_err!(String::from_utf8(bytes.to_vec()).context(InvalidUtf8ValueSnafu));
|
||||
|
||||
// Safety: schema and name are already checked above.
|
||||
let query_ctx = Arc::new(QueryContext::with(&catalog, schema.unwrap()));
|
||||
match script_handler
|
||||
.insert_script(query_ctx, name.unwrap(), &script)
|
||||
.await
|
||||
{
|
||||
Ok(()) => GreptimedbV1Response::from_output(vec![]).await,
|
||||
Err(e) => json_err!(
|
||||
format!("Insert script error: {}", e.output_msg()),
|
||||
e.status_code()
|
||||
),
|
||||
}
|
||||
} else {
|
||||
json_err!(
|
||||
"Script execution not supported, missing script handler",
|
||||
StatusCode::Unsupported
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Default)]
|
||||
pub struct ScriptQuery {
|
||||
pub catalog: Option<String>,
|
||||
pub db: Option<String>,
|
||||
pub name: Option<String>,
|
||||
#[serde(flatten)]
|
||||
pub params: HashMap<String, String>,
|
||||
}
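// Request shapes accepted by these handlers (a sketch; the /v1 prefix matches
// the integration tests, and `a=42` stands for any extra script parameter):
//   POST /v1/scripts?db=public&name=square      (body: the Python script)
//   POST /v1/run-script?db=public&name=square&a=42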
|
||||
|
||||
/// Handler to execute a script.
|
||||
#[axum_macros::debug_handler]
|
||||
pub async fn run_script(
|
||||
State(state): State<ApiState>,
|
||||
Query(params): Query<ScriptQuery>,
|
||||
) -> HttpResponse {
|
||||
if let Some(script_handler) = &state.script_handler {
|
||||
let catalog = params
|
||||
.catalog
|
||||
.unwrap_or_else(|| DEFAULT_CATALOG_NAME.to_string());
|
||||
let start = Instant::now();
|
||||
let schema = params.db.as_ref();
|
||||
|
||||
if schema.is_none() || schema.unwrap().is_empty() {
|
||||
json_err!("invalid schema", StatusCode::InvalidArguments)
|
||||
}
|
||||
|
||||
let name = params.name.as_ref();
|
||||
|
||||
if name.is_none() || name.unwrap().is_empty() {
|
||||
json_err!("invalid name", StatusCode::InvalidArguments);
|
||||
}
|
||||
|
||||
// Safety: schema and name are already checked above.
|
||||
let query_ctx = Arc::new(QueryContext::with(&catalog, schema.unwrap()));
|
||||
let output = script_handler
|
||||
.execute_script(query_ctx, name.unwrap(), params.params)
|
||||
.await;
|
||||
let resp = GreptimedbV1Response::from_output(vec![output]).await;
|
||||
resp.with_execution_time(start.elapsed().as_millis() as u64)
|
||||
} else {
|
||||
json_err!(
|
||||
"Script execution not supported, missing script handler",
|
||||
StatusCode::Unsupported
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -254,31 +254,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// ScriptInterceptor can track the life cycle of a script request and customize or
|
||||
/// abort its execution at a given point.
|
||||
pub trait ScriptInterceptor {
|
||||
type Error: ErrorExt;
|
||||
|
||||
/// Called before a script request is actually executed.
|
||||
fn pre_execute(&self, _name: &str, _query_ctx: QueryContextRef) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub type ScriptInterceptorRef<E> = Arc<dyn ScriptInterceptor<Error = E> + Send + Sync + 'static>;
|
||||
|
||||
impl<E: ErrorExt> ScriptInterceptor for Option<ScriptInterceptorRef<E>> {
|
||||
type Error = E;
|
||||
|
||||
fn pre_execute(&self, name: &str, query_ctx: QueryContextRef) -> Result<(), Self::Error> {
|
||||
if let Some(this) = self {
|
||||
this.pre_execute(name, query_ctx)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
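// Illustrative custom interceptor (a sketch; the name check and the
// NotSupported error are hypothetical choices):
//   struct DenyInternal;
//   impl ScriptInterceptor for DenyInternal {
//       type Error = crate::error::Error;
//       fn pre_execute(&self, name: &str, _ctx: QueryContextRef) -> Result<(), Self::Error> {
//           if name.starts_with("__") {
//               return crate::error::NotSupportedSnafu { feat: format!("script name {name}") }.fail();
//           }
//           Ok(())
//       }
//   }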
|
||||
|
||||
/// LineProtocolInterceptor can track the life cycle of a line protocol request
|
||||
/// and customize or abort its execution at a given point.
|
||||
#[async_trait]
|
||||
|
||||
@@ -51,26 +51,9 @@ pub type OpentsdbProtocolHandlerRef = Arc<dyn OpentsdbProtocolHandler + Send + S
|
||||
pub type InfluxdbLineProtocolHandlerRef = Arc<dyn InfluxdbLineProtocolHandler + Send + Sync>;
|
||||
pub type PromStoreProtocolHandlerRef = Arc<dyn PromStoreProtocolHandler + Send + Sync>;
|
||||
pub type OpenTelemetryProtocolHandlerRef = Arc<dyn OpenTelemetryProtocolHandler + Send + Sync>;
|
||||
pub type ScriptHandlerRef = Arc<dyn ScriptHandler + Send + Sync>;
|
||||
pub type PipelineHandlerRef = Arc<dyn PipelineHandler + Send + Sync>;
|
||||
pub type LogQueryHandlerRef = Arc<dyn LogQueryHandler + Send + Sync>;
|
||||
|
||||
#[async_trait]
|
||||
pub trait ScriptHandler {
|
||||
async fn insert_script(
|
||||
&self,
|
||||
query_ctx: QueryContextRef,
|
||||
name: &str,
|
||||
script: &str,
|
||||
) -> Result<()>;
|
||||
async fn execute_script(
|
||||
&self,
|
||||
query_ctx: QueryContextRef,
|
||||
name: &str,
|
||||
params: HashMap<String, String>,
|
||||
) -> Result<Output>;
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait InfluxdbLineProtocolHandler {
|
||||
/// A successful request will not return a response.
|
||||
|
||||
@@ -14,38 +14,32 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use axum::body::{Body, Bytes};
|
||||
use axum::extract::{Json, Query, RawBody, State};
|
||||
use axum::extract::{Json, Query, State};
|
||||
use axum::http::header;
|
||||
use axum::response::IntoResponse;
|
||||
use axum::Form;
|
||||
use bytes::Bytes;
|
||||
use headers::HeaderValue;
|
||||
use http_body::combinators::UnsyncBoxBody;
|
||||
use hyper::Response;
|
||||
use mime_guess::mime;
|
||||
use servers::http::GreptimeQueryOutput::Records;
|
||||
use servers::http::{
|
||||
handler as http_handler, script as script_handler, ApiState, GreptimeOptionsConfigState,
|
||||
GreptimeQueryOutput, HttpResponse,
|
||||
handler as http_handler, ApiState, GreptimeOptionsConfigState, GreptimeQueryOutput,
|
||||
HttpResponse,
|
||||
};
|
||||
use servers::metrics_handler::MetricsHandler;
|
||||
use session::context::QueryContext;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
use crate::{
|
||||
create_testing_script_handler, create_testing_sql_query_handler, ScriptHandlerRef,
|
||||
ServerSqlQueryHandlerRef,
|
||||
};
|
||||
use crate::create_testing_sql_query_handler;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sql_not_provided() {
|
||||
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
|
||||
let ctx = QueryContext::with_db_name(None);
|
||||
ctx.set_current_user(auth::userinfo_by_name(None));
|
||||
let api_state = ApiState {
|
||||
sql_handler,
|
||||
script_handler: None,
|
||||
};
|
||||
let api_state = ApiState { sql_handler };
|
||||
|
||||
for format in ["greptimedb_v1", "influxdb_v1", "csv", "table"] {
|
||||
let query = http_handler::SqlQuery {
|
||||
@@ -76,10 +70,7 @@ async fn test_sql_output_rows() {
|
||||
|
||||
let ctx = QueryContext::with_db_name(None);
|
||||
ctx.set_current_user(auth::userinfo_by_name(None));
|
||||
let api_state = ApiState {
|
||||
sql_handler,
|
||||
script_handler: None,
|
||||
};
|
||||
let api_state = ApiState { sql_handler };
|
||||
|
||||
let query_sql = "select sum(uint32s) from numbers limit 20";
|
||||
for format in ["greptimedb_v1", "influxdb_v1", "csv", "table"] {
|
||||
@@ -182,10 +173,7 @@ async fn test_dashboard_sql_limit() {
|
||||
let sql_handler = create_testing_sql_query_handler(MemTable::specified_numbers_table(2000));
|
||||
let ctx = QueryContext::with_db_name(None);
|
||||
ctx.set_current_user(auth::userinfo_by_name(None));
|
||||
let api_state = ApiState {
|
||||
sql_handler,
|
||||
script_handler: None,
|
||||
};
|
||||
let api_state = ApiState { sql_handler };
|
||||
for format in ["greptimedb_v1", "csv", "table"] {
|
||||
let query = create_query(format, "select * from numbers", Some(1000));
|
||||
let sql_response = http_handler::sql(
|
||||
@@ -228,10 +216,7 @@ async fn test_sql_form() {
|
||||
|
||||
let ctx = QueryContext::with_db_name(None);
|
||||
ctx.set_current_user(auth::userinfo_by_name(None));
|
||||
let api_state = ApiState {
|
||||
sql_handler,
|
||||
script_handler: None,
|
||||
};
|
||||
let api_state = ApiState { sql_handler };
|
||||
|
||||
for format in ["greptimedb_v1", "influxdb_v1", "csv", "table"] {
|
||||
let form = create_form(format);
|
||||
@@ -341,196 +326,6 @@ async fn test_metrics() {
|
||||
assert!(text.contains("test_metrics counter"));
|
||||
}
|
||||
|
||||
async fn insert_script(
|
||||
script: String,
|
||||
script_handler: ScriptHandlerRef,
|
||||
sql_handler: ServerSqlQueryHandlerRef,
|
||||
) {
|
||||
let body = RawBody(Body::from(script.clone()));
|
||||
let invalid_query = create_invalid_script_query();
|
||||
let json = script_handler::scripts(
|
||||
State(ApiState {
|
||||
sql_handler: sql_handler.clone(),
|
||||
script_handler: Some(script_handler.clone()),
|
||||
}),
|
||||
invalid_query,
|
||||
body,
|
||||
)
|
||||
.await;
|
||||
let HttpResponse::Error(json) = json else {
|
||||
unreachable!()
|
||||
};
|
||||
assert_eq!(json.error(), "invalid schema");
|
||||
|
||||
let body = RawBody(Body::from(script.clone()));
|
||||
let exec = create_script_query();
|
||||
// Insert the script
|
||||
let json = script_handler::scripts(
|
||||
State(ApiState {
|
||||
sql_handler: sql_handler.clone(),
|
||||
script_handler: Some(script_handler.clone()),
|
||||
}),
|
||||
exec,
|
||||
body,
|
||||
)
|
||||
.await;
|
||||
let HttpResponse::GreptimedbV1(json) = json else {
|
||||
unreachable!()
|
||||
};
|
||||
assert!(json.output().is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scripts() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let script = r#"
|
||||
@copr(sql='select uint32s as number from numbers limit 5', args=['number'], returns=['n'])
|
||||
def test(n) -> vector[i64]:
|
||||
return n;
|
||||
"#
|
||||
.to_string();
|
||||
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
|
||||
let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
|
||||
|
||||
insert_script(script.clone(), script_handler.clone(), sql_handler.clone()).await;
|
||||
// Run the script
|
||||
let exec = create_script_query();
|
||||
let json = script_handler::run_script(
|
||||
State(ApiState {
|
||||
sql_handler,
|
||||
script_handler: Some(script_handler),
|
||||
}),
|
||||
exec,
|
||||
)
|
||||
.await;
|
||||
let HttpResponse::GreptimedbV1(json) = json else {
|
||||
unreachable!()
|
||||
};
|
||||
match &json.output()[0] {
|
||||
GreptimeQueryOutput::Records(records) => {
|
||||
let json = serde_json::to_string_pretty(&records).unwrap();
|
||||
assert_eq!(5, records.num_rows());
|
||||
assert_eq!(
|
||||
json,
|
||||
r#"{
|
||||
"schema": {
|
||||
"column_schemas": [
|
||||
{
|
||||
"name": "n",
|
||||
"data_type": "Int64"
|
||||
}
|
||||
]
|
||||
},
|
||||
"rows": [
|
||||
[
|
||||
0
|
||||
],
|
||||
[
|
||||
1
|
||||
],
|
||||
[
|
||||
2
|
||||
],
|
||||
[
|
||||
3
|
||||
],
|
||||
[
|
||||
4
|
||||
]
|
||||
],
|
||||
"total_rows": 5
|
||||
}"#
|
||||
);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_scripts_with_params() {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
|
||||
let script = r#"
|
||||
@copr(sql='select uint32s as number from numbers limit 5', args=['number'], returns=['n'])
|
||||
def test(n, **params) -> vector[i64]:
|
||||
return n + int(params['a'])
|
||||
"#
|
||||
.to_string();
|
||||
let sql_handler = create_testing_sql_query_handler(MemTable::default_numbers_table());
|
||||
let script_handler = create_testing_script_handler(MemTable::default_numbers_table());
|
||||
|
||||
insert_script(script.clone(), script_handler.clone(), sql_handler.clone()).await;
|
||||
// Run the script
|
||||
let mut exec = create_script_query();
|
||||
let _ = exec.0.params.insert("a".to_string(), "42".to_string());
|
||||
let json = script_handler::run_script(
|
||||
State(ApiState {
|
||||
sql_handler,
|
||||
script_handler: Some(script_handler),
|
||||
}),
|
||||
exec,
|
||||
)
|
||||
.await;
|
||||
let HttpResponse::GreptimedbV1(json) = json else {
|
||||
unreachable!()
|
||||
};
|
||||
match &json.output()[0] {
|
||||
GreptimeQueryOutput::Records(records) => {
|
||||
let json = serde_json::to_string_pretty(&records).unwrap();
|
||||
assert_eq!(5, records.num_rows());
|
||||
assert_eq!(
|
||||
json,
|
||||
r#"{
|
||||
"schema": {
|
||||
"column_schemas": [
|
||||
{
|
||||
"name": "n",
|
||||
"data_type": "Int64"
|
||||
}
|
||||
]
|
||||
},
|
||||
"rows": [
|
||||
[
|
||||
42
|
||||
],
|
||||
[
|
||||
43
|
||||
],
|
||||
[
|
||||
44
|
||||
],
|
||||
[
|
||||
45
|
||||
],
|
||||
[
|
||||
46
|
||||
]
|
||||
],
|
||||
"total_rows": 5
|
||||
}"#
|
||||
);
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn create_script_query() -> Query<script_handler::ScriptQuery> {
|
||||
Query(script_handler::ScriptQuery {
|
||||
db: Some("test".to_string()),
|
||||
name: Some("test".to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
fn create_invalid_script_query() -> Query<script_handler::ScriptQuery> {
|
||||
Query(script_handler::ScriptQuery {
|
||||
db: None,
|
||||
name: None,
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
fn create_query(format: &str, sql: &str, limit: Option<usize>) -> Query<http_handler::SqlQuery> {
|
||||
Query(http_handler::SqlQuery {
|
||||
sql: Some(sql.to_string()),
|
||||
|
||||
@@ -117,7 +117,7 @@ fn make_test_app(tx: Arc<mpsc::Sender<(String, String)>>, db_name: Option<&str>)
|
||||
})
|
||||
}
|
||||
let server = HttpServerBuilder::new(http_opts)
|
||||
.with_sql_handler(instance.clone(), None)
|
||||
.with_sql_handler(instance.clone())
|
||||
.with_user_provider(Arc::new(user_provider))
|
||||
.with_influxdb_handler(instance)
|
||||
.build();
|
||||
|
||||
@@ -109,7 +109,7 @@ fn make_test_app(tx: mpsc::Sender<String>) -> Router {
|
||||
|
||||
let instance = Arc::new(DummyInstance { tx });
|
||||
let server = HttpServerBuilder::new(http_opts)
|
||||
.with_sql_handler(instance.clone(), None)
|
||||
.with_sql_handler(instance.clone())
|
||||
.with_opentsdb_handler(instance)
|
||||
.build();
|
||||
server.build(server.make_app())
|
||||
|
||||
@@ -138,7 +138,7 @@ fn make_test_app(tx: mpsc::Sender<(String, Vec<u8>)>) -> Router {
|
||||
let is_strict_mode = false;
|
||||
let instance = Arc::new(DummyInstance { tx });
|
||||
let server = HttpServerBuilder::new(http_opts)
|
||||
.with_sql_handler(instance.clone(), None)
|
||||
.with_sql_handler(instance.clone())
|
||||
.with_prom_handler(instance, true, is_strict_mode)
|
||||
.build();
|
||||
server.build(server.make_app())
|
||||
|
||||
@@ -12,8 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::sync::Arc;
|
||||
|
||||
use api::v1::greptime_request::Request;
|
||||
use api::v1::query_request::Query;
|
||||
@@ -25,12 +24,9 @@ use datafusion_expr::LogicalPlan;
|
||||
use query::parser::{PromQuery, QueryLanguageParser, QueryStatement};
|
||||
use query::query_engine::DescribeResult;
|
||||
use query::{QueryEngineFactory, QueryEngineRef};
|
||||
use script::engine::{CompileContext, EvalContext, Script, ScriptEngine};
|
||||
use script::python::{PyEngine, PyScript};
|
||||
use servers::error::{Error, NotSupportedSnafu, Result};
|
||||
use servers::query_handler::grpc::{GrpcQueryHandler, ServerGrpcQueryHandlerRef};
|
||||
use servers::query_handler::sql::{ServerSqlQueryHandlerRef, SqlQueryHandler};
|
||||
use servers::query_handler::{ScriptHandler, ScriptHandlerRef};
|
||||
use session::context::QueryContextRef;
|
||||
use snafu::ensure;
|
||||
use sql::statements::statement::Statement;
|
||||
@@ -41,23 +37,16 @@ mod http;
|
||||
mod interceptor;
|
||||
mod mysql;
|
||||
mod postgres;
|
||||
mod py_script;
|
||||
|
||||
const LOCALHOST_WITH_0: &str = "127.0.0.1:0";
|
||||
|
||||
pub struct DummyInstance {
|
||||
query_engine: QueryEngineRef,
|
||||
py_engine: Arc<PyEngine>,
|
||||
scripts: RwLock<HashMap<String, Arc<PyScript>>>,
|
||||
}
|
||||
|
||||
impl DummyInstance {
|
||||
fn new(query_engine: QueryEngineRef) -> Self {
|
||||
Self {
|
||||
py_engine: Arc::new(PyEngine::new(query_engine.clone())),
|
||||
scripts: RwLock::new(HashMap::new()),
|
||||
query_engine,
|
||||
}
|
||||
Self { query_engine }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,51 +102,6 @@ impl SqlQueryHandler for DummyInstance {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ScriptHandler for DummyInstance {
|
||||
async fn insert_script(
|
||||
&self,
|
||||
query_ctx: QueryContextRef,
|
||||
name: &str,
|
||||
script: &str,
|
||||
) -> Result<()> {
|
||||
let catalog = query_ctx.current_catalog();
|
||||
let schema = query_ctx.current_schema();
|
||||
|
||||
let script = self
|
||||
.py_engine
|
||||
.compile(script, CompileContext::default())
|
||||
.await
|
||||
.unwrap();
|
||||
script.register_udf().await;
|
||||
let _ = self
|
||||
.scripts
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(format!("{catalog}_{schema}_{name}"), Arc::new(script));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn execute_script(
|
||||
&self,
|
||||
query_ctx: QueryContextRef,
|
||||
name: &str,
|
||||
params: HashMap<String, String>,
|
||||
) -> Result<Output> {
|
||||
let catalog = query_ctx.current_catalog();
|
||||
let schema = query_ctx.current_schema();
|
||||
let key = format!("{catalog}_{schema}_{name}");
|
||||
|
||||
let py_script = self.scripts.read().unwrap().get(&key).unwrap().clone();
|
||||
|
||||
Ok(py_script
|
||||
.execute(params, EvalContext::default())
|
||||
.await
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl GrpcQueryHandler for DummyInstance {
|
||||
type Error = Error;
|
||||
@@ -219,10 +163,6 @@ fn create_testing_instance(table: TableRef) -> DummyInstance {
|
||||
DummyInstance::new(query_engine)
|
||||
}
|
||||
|
||||
fn create_testing_script_handler(table: TableRef) -> ScriptHandlerRef {
|
||||
Arc::new(create_testing_instance(table)) as _
|
||||
}
|
||||
|
||||
fn create_testing_sql_query_handler(table: TableRef) -> ServerSqlQueryHandlerRef {
|
||||
Arc::new(create_testing_instance(table)) as _
|
||||
}
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
// Copyright 2023 Greptime Team
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use common_query::OutputData;
|
||||
use common_recordbatch::RecordBatch;
|
||||
use datatypes::prelude::ConcreteDataType;
|
||||
use datatypes::schema::{ColumnSchema, Schema};
|
||||
use datatypes::vectors::{StringVector, VectorRef};
|
||||
use servers::error::Result;
|
||||
use servers::query_handler::sql::SqlQueryHandler;
|
||||
use servers::query_handler::ScriptHandler;
|
||||
use session::context::QueryContextBuilder;
|
||||
use table::test_util::MemTable;
|
||||
|
||||
use crate::create_testing_instance;
|
||||
|
||||
#[ignore = "rust-python backend is not active support at present"]
|
||||
#[tokio::test]
|
||||
async fn test_insert_py_udf_and_query() -> Result<()> {
|
||||
let catalog = "greptime";
|
||||
let schema = "test";
|
||||
let name = "hello";
|
||||
let script = r#"
|
||||
@copr(returns=['n'])
|
||||
def hello() -> vector[str]:
|
||||
return 'hello';
|
||||
"#;
|
||||
|
||||
let column_schemas = vec![
|
||||
ColumnSchema::new("script", ConcreteDataType::string_datatype(), false),
|
||||
ColumnSchema::new("schema", ConcreteDataType::string_datatype(), false),
|
||||
ColumnSchema::new("name", ConcreteDataType::string_datatype(), false),
|
||||
];
|
||||
|
||||
let columns: Vec<VectorRef> = vec![
|
||||
Arc::new(StringVector::from(vec![script])),
|
||||
Arc::new(StringVector::from(vec![schema])),
|
||||
Arc::new(StringVector::from(vec![name])),
|
||||
];
|
||||
|
||||
let raw_schema = Arc::new(Schema::new(column_schemas));
|
||||
let recordbatch = RecordBatch::new(raw_schema, columns).unwrap();
|
||||
|
||||
let table = MemTable::table("scripts", recordbatch);
|
||||
|
||||
let query_ctx = Arc::new(
|
||||
QueryContextBuilder::default()
|
||||
.current_catalog(catalog.to_string())
|
||||
.current_schema(schema.to_string())
|
||||
.build(),
|
||||
);
|
||||
|
||||
let instance = create_testing_instance(table);
|
||||
instance
|
||||
.insert_script(query_ctx.clone(), name, script)
|
||||
.await?;
|
||||
|
||||
let output = instance
|
||||
.execute_script(query_ctx.clone(), name, HashMap::new())
|
||||
.await?;
|
||||
|
||||
match output.data {
|
||||
OutputData::RecordBatches(batches) => {
|
||||
let expected = "\
|
||||
+-------+
|
||||
| n |
|
||||
+-------+
|
||||
| hello |
|
||||
+-------+";
|
||||
assert_eq!(expected, batches.pretty_print().unwrap());
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
let res = instance
|
||||
.do_query("select hello()", query_ctx)
|
||||
.await
|
||||
.remove(0)
|
||||
.unwrap();
|
||||
match res.data {
|
||||
OutputData::AffectedRows(_) => (),
|
||||
OutputData::RecordBatches(_) => {
|
||||
unreachable!()
|
||||
}
|
||||
OutputData::Stream(s) => {
|
||||
let batches = common_recordbatch::util::collect_batches(s).await.unwrap();
|
||||
let expected = "\
|
||||
+---------+
|
||||
| hello() |
|
||||
+---------+
|
||||
| hello |
|
||||
+---------+";
|
||||
|
||||
assert_eq!(expected, batches.pretty_print().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -91,7 +91,6 @@ partition.workspace = true
|
||||
paste.workspace = true
|
||||
prost.workspace = true
|
||||
rand.workspace = true
|
||||
script.workspace = true
|
||||
session = { workspace = true, features = ["testing"] }
|
||||
store-api.workspace = true
|
||||
tokio-postgres = { workspace = true }
|
||||
|
||||
@@ -392,10 +392,7 @@ pub async fn setup_test_http_app(store_type: StorageType, name: &str) -> (Router
|
||||
..Default::default()
|
||||
};
|
||||
let http_server = HttpServerBuilder::new(http_opts)
|
||||
.with_sql_handler(
|
||||
ServerSqlQueryHandlerAdapter::arc(instance.instance.clone()),
|
||||
None,
|
||||
)
|
||||
.with_sql_handler(ServerSqlQueryHandlerAdapter::arc(instance.instance.clone()))
|
||||
.with_metrics_handler(MetricsHandler)
|
||||
.with_greptime_config_options(instance.opts.datanode_options().to_toml().unwrap())
|
||||
.build();
|
||||
@@ -426,10 +423,7 @@ pub async fn setup_test_http_app_with_frontend_and_user_provider(
|
||||
let mut http_server = HttpServerBuilder::new(http_opts);
|
||||
|
||||
http_server = http_server
|
||||
.with_sql_handler(
|
||||
ServerSqlQueryHandlerAdapter::arc(instance.instance.clone()),
|
||||
Some(instance.instance.clone()),
|
||||
)
|
||||
.with_sql_handler(ServerSqlQueryHandlerAdapter::arc(instance.instance.clone()))
|
||||
.with_log_ingest_handler(instance.instance.clone(), None, None)
|
||||
.with_otlp_handler(instance.instance.clone())
|
||||
.with_greptime_config_options(instance.opts.to_toml().unwrap());
|
||||
@@ -465,10 +459,7 @@ pub async fn setup_test_prom_app_with_frontend(
|
||||
let frontend_ref = instance.instance.clone();
|
||||
let is_strict_mode = true;
|
||||
let http_server = HttpServerBuilder::new(http_opts)
|
||||
.with_sql_handler(
|
||||
ServerSqlQueryHandlerAdapter::arc(frontend_ref.clone()),
|
||||
Some(frontend_ref.clone()),
|
||||
)
|
||||
.with_sql_handler(ServerSqlQueryHandlerAdapter::arc(frontend_ref.clone()))
|
||||
.with_prom_handler(frontend_ref.clone(), true, is_strict_mode)
|
||||
.with_prometheus_handler(frontend_ref)
|
||||
.with_greptime_config_options(instance.opts.datanode_options().to_toml().unwrap())
|
||||
|
||||
@@ -80,7 +80,6 @@ macro_rules! http_tests {
|
||||
test_prometheus_promql_api,
|
||||
test_prom_http_api,
|
||||
test_metrics_api,
|
||||
test_scripts_api,
|
||||
test_health_api,
|
||||
test_status_api,
|
||||
test_config_api,
|
||||
@@ -721,46 +720,6 @@ pub async fn test_metrics_api(store_type: StorageType) {
|
||||
guard.remove_all().await;
|
||||
}
|
||||
|
||||
pub async fn test_scripts_api(store_type: StorageType) {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let (app, mut guard) = setup_test_http_app_with_frontend(store_type, "script_api").await;
|
||||
let client = TestClient::new(app);
|
||||
|
||||
let res = client
|
||||
.post("/v1/scripts?db=schema_test&name=test")
|
||||
.body(
|
||||
r#"
|
||||
@copr(sql='select number from numbers limit 10', args=['number'], returns=['n'])
|
||||
def test(n) -> vector[f64]:
|
||||
return n + 1;
|
||||
"#,
|
||||
)
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
|
||||
let body = serde_json::from_str::<GreptimedbV1Response>(&res.text().await).unwrap();
|
||||
assert!(body.output().is_empty());
|
||||
|
||||
// call script
|
||||
let res = client
|
||||
.post("/v1/run-script?db=schema_test&name=test")
|
||||
.send()
|
||||
.await;
|
||||
assert_eq!(res.status(), StatusCode::OK);
|
||||
let body = serde_json::from_str::<GreptimedbV1Response>(&res.text().await).unwrap();
|
||||
let output = body.output();
|
||||
assert_eq!(output.len(), 1);
|
||||
assert_eq!(
|
||||
output[0],
|
||||
serde_json::from_value::<GreptimeQueryOutput>(json!({
|
||||
"records":{"schema":{"column_schemas":[{"name":"n","data_type":"Float64"}]},"rows":[[1.0],[2.0],[3.0],[4.0],[5.0],[6.0],[7.0],[8.0],[9.0],[10.0]],"total_rows": 10}
|
||||
})).unwrap()
|
||||
);
|
||||
|
||||
guard.remove_all().await;
|
||||
}
|
||||
|
||||
pub async fn test_health_api(store_type: StorageType) {
|
||||
common_telemetry::init_default_ut_logging();
|
||||
let (app, _guard) = setup_test_http_app_with_frontend(store_type, "health_api").await;
|
||||
|
||||
Reference in New Issue
Block a user