diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..883236f --- /dev/null +++ b/.dockerignore @@ -0,0 +1,6 @@ +target/ +models/ +docker/ +.env* +renovate.json +*.py \ No newline at end of file diff --git a/.env.sample b/.env.sample index 9cc540c..300cd96 100644 --- a/.env.sample +++ b/.env.sample @@ -1,4 +1,4 @@ -BERT_MODEL_PATH=models/debert.onnx -MAIN_MODEL_PATH=models/model_opt.onnx +BERT_MODEL_PATH=models/deberta.onnx +MODEL_PATH=models/model_tsukuyomi.onnx STYLE_VECTORS_PATH=models/style_vectors.json TOKENIZER_PATH=models/tokenizer.json \ No newline at end of file diff --git a/docker/cpu.Dockerfile b/docker/cpu.Dockerfile new file mode 100644 index 0000000..e009023 --- /dev/null +++ b/docker/cpu.Dockerfile @@ -0,0 +1,9 @@ +FROM rust AS builder +WORKDIR /work +COPY . . +RUN cargo build -r --bin sbv2_api +FROM gcr.io/distroless/cc-debian12 +WORKDIR /work +COPY --from=builder /work/target/release/sbv2_api /work/main +COPY --from=builder /work/target/release/*.so /work +CMD ["/work/main"] \ No newline at end of file diff --git a/docker/run.sh b/docker/run.sh new file mode 100644 index 0000000..40f384a --- /dev/null +++ b/docker/run.sh @@ -0,0 +1 @@ +docker run -it --rm -p 3000:3000 --name sbv2 -v ./models:/work/models --env-file .env.sample sbv2 \ No newline at end of file diff --git a/sbv2_api/src/main.rs b/sbv2_api/src/main.rs index cc7c8e2..0602acf 100644 --- a/sbv2_api/src/main.rs +++ b/sbv2_api/src/main.rs @@ -5,7 +5,7 @@ use axum::{ routing::{get, post}, Json, Router, }; -use sbv2_core::tts::TTSModel; +use sbv2_core::tts::TTSModelHolder; use serde::Deserialize; use std::env; use std::sync::Arc; @@ -18,34 +18,46 @@ use crate::error::AppResult; #[derive(Deserialize)] struct SynthesizeRequest { text: String, + ident: String, } async fn synthesize( State(state): State>, - Json(SynthesizeRequest { text }): Json, + Json(SynthesizeRequest { text, ident }): Json, ) -> AppResult { let buffer = { let mut tts_model = state.tts_model.lock().await; let tts_model = if let Some(tts_model) = &*tts_model { tts_model } else { - *tts_model = Some(TTSModel::new( + let mut tts_holder = TTSModelHolder::new( &fs::read(env::var("BERT_MODEL_PATH")?).await?, - &fs::read(env::var("MAIN_MODEL_PATH")?).await?, - &fs::read(env::var("STYLE_VECTORS_PATH")?).await?, &fs::read(env::var("TOKENIZER_PATH")?).await?, - )?); + )?; + tts_holder.load( + "tsukuyomi", + fs::read(env::var("STYLE_VECTORS_PATH")?).await?, + fs::read(env::var("MODEL_PATH")?).await?, + )?; + *tts_model = Some(tts_holder); tts_model.as_ref().unwrap() }; let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(&text)?; - let style_vector = tts_model.get_style_vector(0, 1.0)?; - tts_model.synthesize(bert_ori.to_owned(), phones, tones, lang_ids, style_vector)? + let style_vector = tts_model.get_style_vector(&ident, 0, 1.0)?; + tts_model.synthesize( + ident, + bert_ori.to_owned(), + phones, + tones, + lang_ids, + style_vector, + )? }; Ok(([(CONTENT_TYPE, "audio/wav")], buffer)) } struct AppState { - tts_model: Arc>>, + tts_model: Arc>>, } #[tokio::main] diff --git a/sbv2_core/src/error.rs b/sbv2_core/src/error.rs index 554c7cc..99aadd9 100644 --- a/sbv2_core/src/error.rs +++ b/sbv2_core/src/error.rs @@ -18,6 +18,8 @@ pub enum Error { IoError(#[from] std::io::Error), #[error("hound error: {0}")] HoundError(#[from] hound::Error), + #[error("model not found error")] + ModelNotFoundError(String), } pub type Result = std::result::Result; diff --git a/sbv2_core/src/main.rs b/sbv2_core/src/main.rs index 87893bc..8db83e9 100644 --- a/sbv2_core/src/main.rs +++ b/sbv2_core/src/main.rs @@ -5,17 +5,21 @@ use sbv2_core::{error, tts}; fn main() -> error::Result<()> { let text = "眠たい"; - let tts_model = tts::TTSModel::new( + let mut tts_model = tts::TTSModelHolder::new( fs::read("models/debert.onnx")?, fs::read("models/model_opt.onnx")?, + )?; + tts_model.load( + "tsukuyomi", fs::read("models/style_vectors.json")?, fs::read("models/tokenizer.json")?, )?; let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(text)?; - let style_vector = tts_model.get_style_vector(0, 1.0)?; + let style_vector = tts_model.get_style_vector("tsukuyomi", 0, 1.0)?; let data = tts_model.synthesize( + "tsukuyomi", bert_ori.to_owned(), phones.clone(), tones.clone(), @@ -26,6 +30,7 @@ fn main() -> error::Result<()> { let now = Instant::now(); for _ in 0..10 { tts_model.synthesize( + "tsukuyomi", bert_ori.to_owned(), phones.clone(), tones.clone(), diff --git a/sbv2_core/src/style.rs b/sbv2_core/src/style.rs index b26eeb9..5b832fc 100644 --- a/sbv2_core/src/style.rs +++ b/sbv2_core/src/style.rs @@ -17,7 +17,7 @@ pub fn load_style>(path: P) -> Result> { } pub fn get_style_vector( - style_vectors: Array2, + style_vectors: &Array2, style_id: i32, weight: f32, ) -> Result> { diff --git a/sbv2_core/src/tts.rs b/sbv2_core/src/tts.rs index 31cd793..28e7d93 100644 --- a/sbv2_core/src/tts.rs +++ b/sbv2_core/src/tts.rs @@ -1,37 +1,83 @@ -use crate::error::Result; +use crate::error::{Error, Result}; use crate::{bert, jtalk, model, nlp, norm, style, tokenizer, utils}; use ndarray::{concatenate, s, Array, Array1, Array2, Axis}; use ort::Session; use tokenizers::Tokenizer; +#[derive(PartialEq, Eq, Clone)] +pub struct TTSIdent(String); + +impl std::fmt::Display for TTSIdent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.0)?; + Ok(()) + } +} + +impl From for TTSIdent +where + S: AsRef, +{ + fn from(value: S) -> Self { + TTSIdent(value.as_ref().to_string()) + } +} + pub struct TTSModel { - tokenizer: Tokenizer, - bert: Session, vits2: Session, style_vectors: Array2, + ident: TTSIdent, +} + +pub struct TTSModelHolder { + tokenizer: Tokenizer, + bert: Session, + models: Vec, jtalk: jtalk::JTalk, } -impl TTSModel { - pub fn new>( - bert_model_bytes: P, - main_model_bytes: P, - style_vector_bytes: P, - tokenizer_bytes: P, - ) -> Result { +impl TTSModelHolder { + pub fn new>(bert_model_bytes: P, tokenizer_bytes: P) -> Result { let bert = model::load_model(bert_model_bytes)?; - let vits2 = model::load_model(main_model_bytes)?; - let style_vectors = style::load_style(style_vector_bytes)?; let jtalk = jtalk::JTalk::new()?; let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?; - Ok(TTSModel { + Ok(TTSModelHolder { bert, - vits2, - style_vectors, + models: vec![], jtalk, tokenizer, }) } + pub fn load, P: AsRef<[u8]>>( + &mut self, + ident: I, + style_vectors_bytes: P, + vits2_bytes: P, + ) -> Result<()> { + let ident = ident.into(); + if self.find_model(ident.clone()).is_err() { + self.models.push(TTSModel { + vits2: model::load_model(vits2_bytes)?, + style_vectors: style::load_style(style_vectors_bytes)?, + ident, + }) + } + Ok(()) + } + pub fn unload>(&mut self, ident: I) -> bool { + let ident = ident.into(); + if let Some((i, _)) = self + .models + .iter() + .enumerate() + .find(|(_, m)| m.ident == ident) + { + self.models.remove(i); + true + } else { + false + } + } #[allow(clippy::type_complexity)] pub fn parse_text( &self, @@ -94,13 +140,26 @@ impl TTSModel { lang_ids.into(), )) } - - pub fn get_style_vector(&self, style_id: i32, weight: f32) -> Result> { - style::get_style_vector(self.style_vectors.clone(), style_id, weight) + fn find_model>(&self, ident: I) -> Result<&TTSModel> { + let ident = ident.into(); + self.models + .iter() + .find(|m| m.ident == ident) + .ok_or(Error::ModelNotFoundError(ident.to_string())) } - pub fn synthesize( + pub fn get_style_vector>( &self, + ident: I, + style_id: i32, + weight: f32, + ) -> Result> { + style::get_style_vector(&self.find_model(ident)?.style_vectors, style_id, weight) + } + + pub fn synthesize>( + &self, + ident: I, bert_ori: Array2, phones: Array1, tones: Array1, @@ -108,7 +167,7 @@ impl TTSModel { style_vector: Array1, ) -> Result> { let buffer = model::synthesize( - &self.vits2, + &self.find_model(ident)?.vits2, bert_ori.to_owned(), phones, tones, diff --git a/test.py b/test.py index 5cc33b5..2ec36c9 100644 --- a/test.py +++ b/test.py @@ -2,6 +2,7 @@ import requests res = requests.post('http://localhost:3000/synthesize', json={ "text": "おはようございます", + "ident": "tsukuyomi" }) with open('output.wav', 'wb') as f: f.write(res.content) \ No newline at end of file