diff --git a/.env.sample b/.env.sample
index d75c186..a23ae89 100644
--- a/.env.sample
+++ b/.env.sample
@@ -1,6 +1,6 @@
 BERT_MODEL_PATH=models/deberta.onnx
-MODEL_PATH=models/model_tsukuyomi.onnx
+MODEL_PATH=models/tsukuyomi.sbv2
 MODELS_PATH=models
-STYLE_VECTORS_PATH=models/style_vectors_tsukuyomi.json
 TOKENIZER_PATH=models/tokenizer.json
-ADDR=localhost:3000
\ No newline at end of file
+ADDR=localhost:3000
+RUST_LOG=warn
\ No newline at end of file
diff --git a/Cargo.lock b/Cargo.lock
index f74f29b..b09045f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1756,6 +1756,7 @@ version = "0.1.1"
 dependencies = [
  "anyhow",
  "dotenvy",
+ "env_logger",
  "hound",
  "jpreprocess",
  "ndarray",
diff --git a/Cargo.toml b/Cargo.toml
index e6acc62..beb944f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -5,3 +5,4 @@ members = ["sbv2_api", "sbv2_core"]
 [workspace.dependencies]
 anyhow = "1.0.86"
 dotenvy = "0.15.7"
+env_logger = "0.11.5"
diff --git a/sbv2_api/Cargo.toml b/sbv2_api/Cargo.toml
index 479642a..a0fa8b6 100644
--- a/sbv2_api/Cargo.toml
+++ b/sbv2_api/Cargo.toml
@@ -7,7 +7,7 @@ edition = "2021"
 anyhow.workspace = true
 axum = "0.7.5"
 dotenvy.workspace = true
-env_logger = "0.11.5"
+env_logger.workspace = true
 log = "0.4.22"
 sbv2_core = { version = "0.1.1", path = "../sbv2_core" }
 serde = { version = "1.0.210", features = ["derive"] }
@@ -17,4 +17,5 @@ tokio = { version = "1.40.0", features = ["full"] }
 cuda = ["sbv2_core/cuda"]
 cuda_tf32 = ["sbv2_core/cuda_tf32"]
 dynamic = ["sbv2_core/dynamic"]
-directml = ["sbv2_core/directml"]
\ No newline at end of file
+directml = ["sbv2_core/directml"]
+tensorrt = ["sbv2_core/tensorrt"]
\ No newline at end of file
diff --git a/sbv2_api/src/main.rs b/sbv2_api/src/main.rs
index e8f8d8a..f9cd22d 100644
--- a/sbv2_api/src/main.rs
+++ b/sbv2_api/src/main.rs
@@ -135,7 +135,7 @@ impl AppState {
 
 #[tokio::main]
 async fn main() -> anyhow::Result<()> {
-    dotenvy::dotenv().ok();
+    dotenvy::dotenv_override().ok();
     env_logger::init();
     let app = Router::new()
         .route("/", get(|| async { "Hello, World!" }))
diff --git a/sbv2_core/Cargo.toml b/sbv2_core/Cargo.toml
index 80f2e7a..0a8a88c 100644
--- a/sbv2_core/Cargo.toml
+++ b/sbv2_core/Cargo.toml
@@ -10,6 +10,7 @@ repository = "https://github.com/tuna2134/sbv2-api"
 [dependencies]
 anyhow.workspace = true
 dotenvy.workspace = true
+env_logger.workspace = true
 hound = "3.5.1"
 jpreprocess = { version = "0.10.0", features = ["naist-jdic"] }
 ndarray = "0.16.1"
@@ -29,3 +30,4 @@ cuda = ["ort/cuda"]
 cuda_tf32 = []
 dynamic = ["ort/load-dynamic"]
 directml = ["ort/directml"]
+tensorrt = ["ort/tensorrt"]
diff --git a/sbv2_core/src/main.rs b/sbv2_core/src/main.rs
index 5d27c65..925e735 100644
--- a/sbv2_core/src/main.rs
+++ b/sbv2_core/src/main.rs
@@ -4,18 +4,15 @@ use sbv2_core::tts;
 use std::env;
 
 fn main() -> anyhow::Result<()> {
-    dotenvy::dotenv().ok();
+    dotenvy::dotenv_override().ok();
+    env_logger::init();
     let text = "眠たい";
     let ident = "aaa";
     let mut tts_holder = tts::TTSModelHolder::new(
         &fs::read(env::var("BERT_MODEL_PATH")?)?,
         &fs::read(env::var("TOKENIZER_PATH")?)?,
     )?;
-    tts_holder.load(
-        ident,
-        fs::read(env::var("STYLE_VECTORS_PATH")?)?,
-        fs::read(env::var("MODEL_PATH")?)?,
-    )?;
+    tts_holder.load_sbv2file(ident, fs::read(env::var("MODEL_PATH")?)?)?;
 
     let (bert_ori, phones, tones, lang_ids) = tts_holder.parse_text(text)?;
 
@@ -32,6 +29,14 @@ fn main() -> anyhow::Result<()> {
     )?;
     std::fs::write("output.wav", data)?;
     let now = Instant::now();
+    for _ in 0..10 {
+        tts_holder.parse_text(text)?;
+    }
+    println!(
+        "Time taken(parse_text): {}ms/it",
+        now.elapsed().as_millis() / 10
+    );
+    let now = Instant::now();
     for _ in 0..10 {
         tts_holder.synthesize(
             ident,
@@ -44,6 +49,9 @@ fn main() -> anyhow::Result<()> {
             1.0,
         )?;
     }
-    println!("Time taken: {}", now.elapsed().as_millis());
+    println!(
+        "Time taken(synthesize): {}ms/it",
+        now.elapsed().as_millis() / 10
+    );
     Ok(())
 }
diff --git a/sbv2_core/src/model.rs b/sbv2_core/src/model.rs
index 485941b..2fa75c3 100644
--- a/sbv2_core/src/model.rs
+++ b/sbv2_core/src/model.rs
@@ -4,11 +4,25 @@ use ndarray::{array, s, Array1, Array2, Axis};
 use ort::{GraphOptimizationLevel, Session};
 use std::io::Cursor;
 
-#[allow(clippy::vec_init_then_push)]
-pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
+#[allow(clippy::vec_init_then_push, unused_variables)]
+pub fn load_model<P: AsRef<[u8]>>(model_file: P, bert: bool) -> Result<Session> {
     let mut exp = Vec::new();
+    #[cfg(feature = "tensorrt")]
+    {
+        if bert {
+            exp.push(
+                ort::TensorRTExecutionProvider::default()
+                    .with_fp16(true)
+                    .with_profile_min_shapes("input_ids:1x1,attention_mask:1x1")
+                    .with_profile_max_shapes("input_ids:1x100,attention_mask:1x100")
+                    .with_profile_opt_shapes("input_ids:1x25,attention_mask:1x25")
+                    .build(),
+            );
+        }
+    }
     #[cfg(feature = "cuda")]
     {
+        #[allow(unused_mut)]
         let mut cuda = ort::CUDAExecutionProvider::default()
             .with_conv_algorithm_search(ort::CUDAExecutionProviderCuDNNConvAlgoSearch::Default);
         #[cfg(feature = "cuda_tf32")]
diff --git a/sbv2_core/src/tts.rs b/sbv2_core/src/tts.rs
index 1b012e8..0d50096 100644
--- a/sbv2_core/src/tts.rs
+++ b/sbv2_core/src/tts.rs
@@ -41,7 +41,7 @@ pub struct TTSModelHolder {
 
 impl TTSModelHolder {
     pub fn new<P: AsRef<[u8]>>(bert_model_bytes: P, tokenizer_bytes: P) -> Result<Self> {
-        let bert = model::load_model(bert_model_bytes)?;
+        let bert = model::load_model(bert_model_bytes, true)?;
         let jtalk = jtalk::JTalk::new()?;
         let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?;
         Ok(TTSModelHolder {
@@ -55,7 +55,7 @@ impl TTSModelHolder {
     pub fn models(&self) -> Vec<String> {
         self.models.iter().map(|m| m.ident.to_string()).collect()
     }
-    
+
     pub fn load_sbv2file<I: Into<String>, P: AsRef<[u8]>>(
         &mut self,
         ident: I,
@@ -94,7 +94,7 @@ impl TTSModelHolder {
         let ident = ident.into();
         if self.find_model(ident.clone()).is_err() {
             self.models.push(TTSModel {
-                vits2: model::load_model(vits2_bytes)?,
+                vits2: model::load_model(vits2_bytes, false)?,
                 style_vectors: style::load_style(style_vectors_bytes)?,
                 ident,
             })