From 5c840ba34e4444c95195d4a6899953708178efa8 Mon Sep 17 00:00:00 2001 From: tuna2134 Date: Thu, 12 Sep 2024 11:03:16 +0000 Subject: [PATCH 1/3] support coreml --- sbv2_core/Cargo.toml | 1 + sbv2_core/src/model.rs | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/sbv2_core/Cargo.toml b/sbv2_core/Cargo.toml index 0a8a88c..63badb6 100644 --- a/sbv2_core/Cargo.toml +++ b/sbv2_core/Cargo.toml @@ -31,3 +31,4 @@ cuda_tf32 = [] dynamic = ["ort/load-dynamic"] directml = ["ort/directml"] tensorrt = ["ort/tensorrt"] +coreml = ["ort/coreml"] \ No newline at end of file diff --git a/sbv2_core/src/model.rs b/sbv2_core/src/model.rs index 2fa75c3..2b975ca 100644 --- a/sbv2_core/src/model.rs +++ b/sbv2_core/src/model.rs @@ -35,6 +35,10 @@ pub fn load_model>(model_file: P, bert: bool) -> Result { exp.push(ort::DirectMLExecutionProvider::default().build()); } + #[cfg(feature = "coreml")] + { + exp.push(ort::CoreMLExecutionProvider::default().build()); + } exp.push(ort::CPUExecutionProvider::default().build()); Ok(Session::builder()? .with_execution_providers(exp)? From 07c1311bd2a0990e28f8af8e5fcdf485e21e5eb8 Mon Sep 17 00:00:00 2001 From: tuna2134 Date: Thu, 12 Sep 2024 11:35:48 +0000 Subject: [PATCH 2/3] clippy --- sbv2_api/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sbv2_api/src/main.rs b/sbv2_api/src/main.rs index f9cd22d..a10ec0a 100644 --- a/sbv2_api/src/main.rs +++ b/sbv2_api/src/main.rs @@ -99,7 +99,7 @@ impl AppState { continue; } }; - if let Err(e) = tts_model.load_sbv2file(&entry, sbv2_bytes) { + if let Err(e) = tts_model.load_sbv2file(entry, sbv2_bytes) { log::warn!("Error loading {entry}: {e}"); }; log::info!("Loaded: {entry}"); From e5c9a6d1d9e8aaff154aca31e581c07fa083ad5d Mon Sep 17 00:00:00 2001 From: tuna2134 Date: Thu, 12 Sep 2024 11:38:05 +0000 Subject: [PATCH 3/3] add coreml support for api --- sbv2_api/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/sbv2_api/Cargo.toml b/sbv2_api/Cargo.toml index a0fa8b6..4b0b43a 100644 --- a/sbv2_api/Cargo.toml +++ b/sbv2_api/Cargo.toml @@ -14,6 +14,7 @@ serde = { version = "1.0.210", features = ["derive"] } tokio = { version = "1.40.0", features = ["full"] } [features] +coreml = ["sbv2_core/coreml"] cuda = ["sbv2_core/cuda"] cuda_tf32 = ["sbv2_core/cuda_tf32"] dynamic = ["sbv2_core/dynamic"]