From bb7913e0fe17bd4ca45c90cf21167538bcc31e34 Mon Sep 17 00:00:00 2001 From: tuna2134 Date: Wed, 11 Sep 2024 10:10:09 +0000 Subject: [PATCH 1/3] support directml --- sbv2_core/Cargo.toml | 3 ++- sbv2_core/src/model.rs | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/sbv2_core/Cargo.toml b/sbv2_core/Cargo.toml index 37203d9..b3d5ed0 100644 --- a/sbv2_core/Cargo.toml +++ b/sbv2_core/Cargo.toml @@ -21,4 +21,5 @@ tokenizers = "0.20.0" [features] cuda = ["ort/cuda"] cuda_tf32 = [] -dynamic = ["ort/load-dynamic"] \ No newline at end of file +dynamic = ["ort/load-dynamic"] +directml = ["ort/directml"] \ No newline at end of file diff --git a/sbv2_core/src/model.rs b/sbv2_core/src/model.rs index 1598e9b..485941b 100644 --- a/sbv2_core/src/model.rs +++ b/sbv2_core/src/model.rs @@ -17,6 +17,10 @@ pub fn load_model>(model_file: P) -> Result { } exp.push(cuda.build()); } + #[cfg(feature = "directml")] + { + exp.push(ort::DirectMLExecutionProvider::default().build()); + } exp.push(ort::CPUExecutionProvider::default().build()); Ok(Session::builder()? .with_execution_providers(exp)? @@ -26,6 +30,7 @@ pub fn load_model>(model_file: P) -> Result { .with_inter_threads(num_cpus::get_physical())? .commit_from_memory(model_file.as_ref())?) } + #[allow(clippy::too_many_arguments)] pub fn synthesize( session: &Session, From a92dd0731422a34393de756ac32f2dd8bcb8f04d Mon Sep 17 00:00:00 2001 From: tuna2134 Date: Wed, 11 Sep 2024 10:11:32 +0000 Subject: [PATCH 2/3] fixed --- sbv2_core/src/utils.rs | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/sbv2_core/src/utils.rs b/sbv2_core/src/utils.rs index 9cf54f9..0d38a49 100644 --- a/sbv2_core/src/utils.rs +++ b/sbv2_core/src/utils.rs @@ -2,11 +2,6 @@ pub fn intersperse(slice: &[T], sep: T) -> Vec where T: Clone, { - /* - result = [item] * (len(lst) * 2 + 1) - result[1::2] = lst - return result - */ let mut result = vec![sep.clone(); slice.len() * 2 + 1]; result .iter_mut() @@ -15,24 +10,3 @@ where .for_each(|(r, s)| *r = s.clone()); result } - -/* -fn tile(arr: &Array2, reps: (usize, usize)) -> Array2 { - let (rows, cols) = arr.dim(); - let (rep_rows, rep_cols) = reps; - - let mut result = Array::zeros((rows * rep_rows, cols * rep_cols)); - - for i in 0..rep_rows { - for j in 0..rep_cols { - let view = result.slice_mut(s![ - i * rows..(i + 1) * rows, - j * cols..(j + 1) * cols - ]); - view.assign(arr); - } - } - - result -} -*/ From 90f0b054156a856f0dc0a1aa8763b7bb3c586195 Mon Sep 17 00:00:00 2001 From: tuna2134 Date: Wed, 11 Sep 2024 10:16:08 +0000 Subject: [PATCH 3/3] add directml support --- sbv2_api/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/sbv2_api/Cargo.toml b/sbv2_api/Cargo.toml index b1ac73a..11196c8 100644 --- a/sbv2_api/Cargo.toml +++ b/sbv2_api/Cargo.toml @@ -17,3 +17,4 @@ tokio = { version = "1.40.0", features = ["full"] } cuda = ["sbv2_core/cuda"] cuda_tf32 = ["sbv2_core/cuda_tf32"] dynamic = ["sbv2_core/dynamic"] +directml = ["sbv2_core/directml"] \ No newline at end of file