diff --git a/sbv2_api/Cargo.toml b/sbv2_api/Cargo.toml index b1ac73a..11196c8 100644 --- a/sbv2_api/Cargo.toml +++ b/sbv2_api/Cargo.toml @@ -17,3 +17,4 @@ tokio = { version = "1.40.0", features = ["full"] } cuda = ["sbv2_core/cuda"] cuda_tf32 = ["sbv2_core/cuda_tf32"] dynamic = ["sbv2_core/dynamic"] +directml = ["sbv2_core/directml"] \ No newline at end of file diff --git a/sbv2_core/Cargo.toml b/sbv2_core/Cargo.toml index 37203d9..b3d5ed0 100644 --- a/sbv2_core/Cargo.toml +++ b/sbv2_core/Cargo.toml @@ -21,4 +21,5 @@ tokenizers = "0.20.0" [features] cuda = ["ort/cuda"] cuda_tf32 = [] -dynamic = ["ort/load-dynamic"] \ No newline at end of file +dynamic = ["ort/load-dynamic"] +directml = ["ort/directml"] \ No newline at end of file diff --git a/sbv2_core/src/model.rs b/sbv2_core/src/model.rs index 1598e9b..485941b 100644 --- a/sbv2_core/src/model.rs +++ b/sbv2_core/src/model.rs @@ -17,6 +17,10 @@ pub fn load_model>(model_file: P) -> Result { } exp.push(cuda.build()); } + #[cfg(feature = "directml")] + { + exp.push(ort::DirectMLExecutionProvider::default().build()); + } exp.push(ort::CPUExecutionProvider::default().build()); Ok(Session::builder()? .with_execution_providers(exp)? @@ -26,6 +30,7 @@ pub fn load_model>(model_file: P) -> Result { .with_inter_threads(num_cpus::get_physical())? .commit_from_memory(model_file.as_ref())?) } + #[allow(clippy::too_many_arguments)] pub fn synthesize( session: &Session, diff --git a/sbv2_core/src/utils.rs b/sbv2_core/src/utils.rs index 9cf54f9..0d38a49 100644 --- a/sbv2_core/src/utils.rs +++ b/sbv2_core/src/utils.rs @@ -2,11 +2,6 @@ pub fn intersperse(slice: &[T], sep: T) -> Vec where T: Clone, { - /* - result = [item] * (len(lst) * 2 + 1) - result[1::2] = lst - return result - */ let mut result = vec![sep.clone(); slice.len() * 2 + 1]; result .iter_mut() @@ -15,24 +10,3 @@ where .for_each(|(r, s)| *r = s.clone()); result } - -/* -fn tile(arr: &Array2, reps: (usize, usize)) -> Array2 { - let (rows, cols) = arr.dim(); - let (rep_rows, rep_cols) = reps; - - let mut result = Array::zeros((rows * rep_rows, cols * rep_cols)); - - for i in 0..rep_rows { - for j in 0..rep_cols { - let view = result.slice_mut(s![ - i * rows..(i + 1) * rows, - j * cols..(j + 1) * cols - ]); - view.assign(arr); - } - } - - result -} -*/