From bb7913e0fe17bd4ca45c90cf21167538bcc31e34 Mon Sep 17 00:00:00 2001 From: tuna2134 Date: Wed, 11 Sep 2024 10:10:09 +0000 Subject: [PATCH] support directml --- sbv2_core/Cargo.toml | 3 ++- sbv2_core/src/model.rs | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/sbv2_core/Cargo.toml b/sbv2_core/Cargo.toml index 37203d9..b3d5ed0 100644 --- a/sbv2_core/Cargo.toml +++ b/sbv2_core/Cargo.toml @@ -21,4 +21,5 @@ tokenizers = "0.20.0" [features] cuda = ["ort/cuda"] cuda_tf32 = [] -dynamic = ["ort/load-dynamic"] \ No newline at end of file +dynamic = ["ort/load-dynamic"] +directml = ["ort/directml"] \ No newline at end of file diff --git a/sbv2_core/src/model.rs b/sbv2_core/src/model.rs index 1598e9b..485941b 100644 --- a/sbv2_core/src/model.rs +++ b/sbv2_core/src/model.rs @@ -17,6 +17,10 @@ pub fn load_model>(model_file: P) -> Result { } exp.push(cuda.build()); } + #[cfg(feature = "directml")] + { + exp.push(ort::DirectMLExecutionProvider::default().build()); + } exp.push(ort::CPUExecutionProvider::default().build()); Ok(Session::builder()? .with_execution_providers(exp)? @@ -26,6 +30,7 @@ pub fn load_model>(model_file: P) -> Result { .with_inter_threads(num_cpus::get_physical())? .commit_from_memory(model_file.as_ref())?) } + #[allow(clippy::too_many_arguments)] pub fn synthesize( session: &Session,