diff --git a/Cargo.lock b/Cargo.lock index 0e085b4..f53accb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1208,6 +1208,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "number_prefix" version = "0.4.0" @@ -1634,6 +1644,7 @@ dependencies = [ "hound", "jpreprocess", "ndarray", + "num_cpus", "once_cell", "ort", "regex", diff --git a/sbv2_core/Cargo.toml b/sbv2_core/Cargo.toml index 2fc6818..c8f7974 100644 --- a/sbv2_core/Cargo.toml +++ b/sbv2_core/Cargo.toml @@ -8,6 +8,7 @@ anyhow.workspace = true hound = "3.5.1" jpreprocess = { version = "0.10.0", features = ["naist-jdic"] } ndarray = "0.16.1" +num_cpus = "1.16.0" once_cell = "1.19.0" ort = { git = "https://github.com/pykeio/ort.git", version = "2.0.0-rc.5" } regex = "1.10.6" diff --git a/sbv2_core/src/model.rs b/sbv2_core/src/model.rs index ae6c535..a44b6f5 100644 --- a/sbv2_core/src/model.rs +++ b/sbv2_core/src/model.rs @@ -8,6 +8,9 @@ pub fn load_model(model_file: &str) -> Result { let session = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level3)? .with_intra_threads(1)? + .with_intra_threads(num_cpus::get_physical())? + .with_parallel_execution(true)? + .with_inter_threads(num_cpus::get_physical())? .commit_from_file(model_file)?; Ok(session) }