diff --git a/sbv2_core/Cargo.toml b/sbv2_core/Cargo.toml index 37203d9..b3d5ed0 100644 --- a/sbv2_core/Cargo.toml +++ b/sbv2_core/Cargo.toml @@ -21,4 +21,5 @@ tokenizers = "0.20.0" [features] cuda = ["ort/cuda"] cuda_tf32 = [] -dynamic = ["ort/load-dynamic"] \ No newline at end of file +dynamic = ["ort/load-dynamic"] +directml = ["ort/directml"] \ No newline at end of file diff --git a/sbv2_core/src/model.rs b/sbv2_core/src/model.rs index 1598e9b..485941b 100644 --- a/sbv2_core/src/model.rs +++ b/sbv2_core/src/model.rs @@ -17,6 +17,10 @@ pub fn load_model>(model_file: P) -> Result { } exp.push(cuda.build()); } + #[cfg(feature = "directml")] + { + exp.push(ort::DirectMLExecutionProvider::default().build()); + } exp.push(ort::CPUExecutionProvider::default().build()); Ok(Session::builder()? .with_execution_providers(exp)? @@ -26,6 +30,7 @@ pub fn load_model>(model_file: P) -> Result { .with_inter_threads(num_cpus::get_physical())? .commit_from_memory(model_file.as_ref())?) } + #[allow(clippy::too_many_arguments)] pub fn synthesize( session: &Session,