mirror of
https://github.com/neodyland/sbv2-api.git
synced 2025-12-22 23:49:58 +00:00
feat: docker(cpu), 1:n bert:vits2 support
This commit is contained in:
6
.dockerignore
Normal file
6
.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
target/
|
||||||
|
models/
|
||||||
|
docker/
|
||||||
|
.env*
|
||||||
|
renovate.json
|
||||||
|
*.py
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
BERT_MODEL_PATH=models/debert.onnx
|
BERT_MODEL_PATH=models/deberta.onnx
|
||||||
MAIN_MODEL_PATH=models/model_opt.onnx
|
MODEL_PATH=models/model_tsukuyomi.onnx
|
||||||
STYLE_VECTORS_PATH=models/style_vectors.json
|
STYLE_VECTORS_PATH=models/style_vectors.json
|
||||||
TOKENIZER_PATH=models/tokenizer.json
|
TOKENIZER_PATH=models/tokenizer.json
|
||||||
9
docker/cpu.Dockerfile
Normal file
9
docker/cpu.Dockerfile
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
# Build stage: compile the sbv2_api binary in release mode (`-r`).
FROM rust AS builder
WORKDIR /work
COPY . .
RUN cargo build -r --bin sbv2_api

# Runtime stage: copy only the built binary and the shared objects from the build output.
FROM gcr.io/distroless/cc-debian12
WORKDIR /work
COPY --from=builder /work/target/release/sbv2_api /work/main
# The wildcard may match several files, so the destination has to be a
# directory path ending in "/".
COPY --from=builder /work/target/release/*.so /work/
CMD ["/work/main"]
|
||||||
1
docker/run.sh
Normal file
1
docker/run.sh
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# Run the `sbv2` image interactively, publishing port 3000 and bind-mounting the
# local models directory. An absolute host path is used because older Docker
# CLIs reject relative bind-mount sources.
docker run -it --rm -p 3000:3000 --name sbv2 -v "$(pwd)/models:/work/models" --env-file .env.sample sbv2
|
||||||
@@ -5,7 +5,7 @@ use axum::{
|
|||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
Json, Router,
|
Json, Router,
|
||||||
};
|
};
|
||||||
use sbv2_core::tts::TTSModel;
|
use sbv2_core::tts::TTSModelHolder;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::env;
|
use std::env;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -18,34 +18,46 @@ use crate::error::AppResult;
|
|||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct SynthesizeRequest {
|
struct SynthesizeRequest {
|
||||||
text: String,
|
text: String,
|
||||||
|
ident: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn synthesize(
|
async fn synthesize(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
Json(SynthesizeRequest { text }): Json<SynthesizeRequest>,
|
Json(SynthesizeRequest { text, ident }): Json<SynthesizeRequest>,
|
||||||
) -> AppResult<impl IntoResponse> {
|
) -> AppResult<impl IntoResponse> {
|
||||||
let buffer = {
|
let buffer = {
|
||||||
let mut tts_model = state.tts_model.lock().await;
|
let mut tts_model = state.tts_model.lock().await;
|
||||||
let tts_model = if let Some(tts_model) = &*tts_model {
|
let tts_model = if let Some(tts_model) = &*tts_model {
|
||||||
tts_model
|
tts_model
|
||||||
} else {
|
} else {
|
||||||
*tts_model = Some(TTSModel::new(
|
let mut tts_holder = TTSModelHolder::new(
|
||||||
&fs::read(env::var("BERT_MODEL_PATH")?).await?,
|
&fs::read(env::var("BERT_MODEL_PATH")?).await?,
|
||||||
&fs::read(env::var("MAIN_MODEL_PATH")?).await?,
|
|
||||||
&fs::read(env::var("STYLE_VECTORS_PATH")?).await?,
|
|
||||||
&fs::read(env::var("TOKENIZER_PATH")?).await?,
|
&fs::read(env::var("TOKENIZER_PATH")?).await?,
|
||||||
)?);
|
)?;
|
||||||
|
tts_holder.load(
|
||||||
|
"tsukuyomi",
|
||||||
|
fs::read(env::var("STYLE_VECTORS_PATH")?).await?,
|
||||||
|
fs::read(env::var("MODEL_PATH")?).await?,
|
||||||
|
)?;
|
||||||
|
*tts_model = Some(tts_holder);
|
||||||
tts_model.as_ref().unwrap()
|
tts_model.as_ref().unwrap()
|
||||||
};
|
};
|
||||||
let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(&text)?;
|
let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(&text)?;
|
||||||
let style_vector = tts_model.get_style_vector(0, 1.0)?;
|
let style_vector = tts_model.get_style_vector(&ident, 0, 1.0)?;
|
||||||
tts_model.synthesize(bert_ori.to_owned(), phones, tones, lang_ids, style_vector)?
|
tts_model.synthesize(
|
||||||
|
ident,
|
||||||
|
bert_ori.to_owned(),
|
||||||
|
phones,
|
||||||
|
tones,
|
||||||
|
lang_ids,
|
||||||
|
style_vector,
|
||||||
|
)?
|
||||||
};
|
};
|
||||||
Ok(([(CONTENT_TYPE, "audio/wav")], buffer))
|
Ok(([(CONTENT_TYPE, "audio/wav")], buffer))
|
||||||
}
|
}
|
||||||
|
|
||||||
struct AppState {
|
struct AppState {
|
||||||
tts_model: Arc<Mutex<Option<TTSModel>>>,
|
tts_model: Arc<Mutex<Option<TTSModelHolder>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ pub enum Error {
|
|||||||
IoError(#[from] std::io::Error),
|
IoError(#[from] std::io::Error),
|
||||||
#[error("hound error: {0}")]
|
#[error("hound error: {0}")]
|
||||||
HoundError(#[from] hound::Error),
|
HoundError(#[from] hound::Error),
|
||||||
|
#[error("model not found error")]
|
||||||
|
ModelNotFoundError(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Result<T> = std::result::Result<T, Error>;
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|||||||
@@ -5,17 +5,21 @@ use sbv2_core::{error, tts};
|
|||||||
fn main() -> error::Result<()> {
|
fn main() -> error::Result<()> {
|
||||||
let text = "眠たい";
|
let text = "眠たい";
|
||||||
|
|
||||||
let tts_model = tts::TTSModel::new(
|
let mut tts_model = tts::TTSModelHolder::new(
|
||||||
fs::read("models/debert.onnx")?,
|
fs::read("models/debert.onnx")?,
|
||||||
fs::read("models/model_opt.onnx")?,
|
fs::read("models/model_opt.onnx")?,
|
||||||
|
)?;
|
||||||
|
tts_model.load(
|
||||||
|
"tsukuyomi",
|
||||||
fs::read("models/style_vectors.json")?,
|
fs::read("models/style_vectors.json")?,
|
||||||
fs::read("models/tokenizer.json")?,
|
fs::read("models/tokenizer.json")?,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(text)?;
|
let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(text)?;
|
||||||
|
|
||||||
let style_vector = tts_model.get_style_vector(0, 1.0)?;
|
let style_vector = tts_model.get_style_vector("tsukuyomi", 0, 1.0)?;
|
||||||
let data = tts_model.synthesize(
|
let data = tts_model.synthesize(
|
||||||
|
"tsukuyomi",
|
||||||
bert_ori.to_owned(),
|
bert_ori.to_owned(),
|
||||||
phones.clone(),
|
phones.clone(),
|
||||||
tones.clone(),
|
tones.clone(),
|
||||||
@@ -26,6 +30,7 @@ fn main() -> error::Result<()> {
|
|||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
for _ in 0..10 {
|
for _ in 0..10 {
|
||||||
tts_model.synthesize(
|
tts_model.synthesize(
|
||||||
|
"tsukuyomi",
|
||||||
bert_ori.to_owned(),
|
bert_ori.to_owned(),
|
||||||
phones.clone(),
|
phones.clone(),
|
||||||
tones.clone(),
|
tones.clone(),
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ pub fn load_style<P: AsRef<[u8]>>(path: P) -> Result<Array2<f32>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_style_vector(
|
pub fn get_style_vector(
|
||||||
style_vectors: Array2<f32>,
|
style_vectors: &Array2<f32>,
|
||||||
style_id: i32,
|
style_id: i32,
|
||||||
weight: f32,
|
weight: f32,
|
||||||
) -> Result<Array1<f32>> {
|
) -> Result<Array1<f32>> {
|
||||||
|
|||||||
@@ -1,37 +1,83 @@
|
|||||||
use crate::error::Result;
|
use crate::error::{Error, Result};
|
||||||
use crate::{bert, jtalk, model, nlp, norm, style, tokenizer, utils};
|
use crate::{bert, jtalk, model, nlp, norm, style, tokenizer, utils};
|
||||||
use ndarray::{concatenate, s, Array, Array1, Array2, Axis};
|
use ndarray::{concatenate, s, Array, Array1, Array2, Axis};
|
||||||
use ort::Session;
|
use ort::Session;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
/// Identifier under which a voice model is registered and looked up.
///
/// Wraps an owned `String`; two identifiers are equal when their text is equal.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
pub struct TTSIdent(String);

impl std::fmt::Display for TTSIdent {
    /// Writes the identifier text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate to `str`'s `Display` so width/alignment flags such as
        // `{:>10}` are honoured instead of being silently ignored.
        std::fmt::Display::fmt(self.0.as_str(), f)
    }
}

/// Builds an identifier from anything string-like (`&str`, `String`, ...).
///
/// The text is copied into a new `String`.
impl<S> From<S> for TTSIdent
where
    S: AsRef<str>,
{
    fn from(value: S) -> Self {
        TTSIdent(value.as_ref().to_owned())
    }
}
|
||||||
|
|
||||||
pub struct TTSModel {
|
pub struct TTSModel {
|
||||||
tokenizer: Tokenizer,
|
|
||||||
bert: Session,
|
|
||||||
vits2: Session,
|
vits2: Session,
|
||||||
style_vectors: Array2<f32>,
|
style_vectors: Array2<f32>,
|
||||||
|
ident: TTSIdent,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TTSModelHolder {
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
bert: Session,
|
||||||
|
models: Vec<TTSModel>,
|
||||||
jtalk: jtalk::JTalk,
|
jtalk: jtalk::JTalk,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TTSModel {
|
impl TTSModelHolder {
|
||||||
pub fn new<P: AsRef<[u8]>>(
|
pub fn new<P: AsRef<[u8]>>(bert_model_bytes: P, tokenizer_bytes: P) -> Result<Self> {
|
||||||
bert_model_bytes: P,
|
|
||||||
main_model_bytes: P,
|
|
||||||
style_vector_bytes: P,
|
|
||||||
tokenizer_bytes: P,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let bert = model::load_model(bert_model_bytes)?;
|
let bert = model::load_model(bert_model_bytes)?;
|
||||||
let vits2 = model::load_model(main_model_bytes)?;
|
|
||||||
let style_vectors = style::load_style(style_vector_bytes)?;
|
|
||||||
let jtalk = jtalk::JTalk::new()?;
|
let jtalk = jtalk::JTalk::new()?;
|
||||||
let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?;
|
let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?;
|
||||||
Ok(TTSModel {
|
Ok(TTSModelHolder {
|
||||||
bert,
|
bert,
|
||||||
vits2,
|
models: vec![],
|
||||||
style_vectors,
|
|
||||||
jtalk,
|
jtalk,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
pub fn load<I: Into<TTSIdent>, P: AsRef<[u8]>>(
|
||||||
|
&mut self,
|
||||||
|
ident: I,
|
||||||
|
style_vectors_bytes: P,
|
||||||
|
vits2_bytes: P,
|
||||||
|
) -> Result<()> {
|
||||||
|
let ident = ident.into();
|
||||||
|
if self.find_model(ident.clone()).is_err() {
|
||||||
|
self.models.push(TTSModel {
|
||||||
|
vits2: model::load_model(vits2_bytes)?,
|
||||||
|
style_vectors: style::load_style(style_vectors_bytes)?,
|
||||||
|
ident,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
/// Removes the model registered under `ident`.
///
/// Returns `true` if a model was found and removed, `false` if no model with
/// that identifier is registered.
pub fn unload<I: Into<TTSIdent>>(&mut self, ident: I) -> bool {
    let ident = ident.into();
    match self.models.iter().position(|m| m.ident == ident) {
        Some(index) => {
            self.models.remove(index);
            true
        }
        None => false,
    }
}
|
||||||
#[allow(clippy::type_complexity)]
|
#[allow(clippy::type_complexity)]
|
||||||
pub fn parse_text(
|
pub fn parse_text(
|
||||||
&self,
|
&self,
|
||||||
@@ -94,13 +140,26 @@ impl TTSModel {
|
|||||||
lang_ids.into(),
|
lang_ids.into(),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
fn find_model<I: Into<TTSIdent>>(&self, ident: I) -> Result<&TTSModel> {
|
||||||
pub fn get_style_vector(&self, style_id: i32, weight: f32) -> Result<Array1<f32>> {
|
let ident = ident.into();
|
||||||
style::get_style_vector(self.style_vectors.clone(), style_id, weight)
|
self.models
|
||||||
|
.iter()
|
||||||
|
.find(|m| m.ident == ident)
|
||||||
|
.ok_or(Error::ModelNotFoundError(ident.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn synthesize(
|
pub fn get_style_vector<I: Into<TTSIdent>>(
|
||||||
&self,
|
&self,
|
||||||
|
ident: I,
|
||||||
|
style_id: i32,
|
||||||
|
weight: f32,
|
||||||
|
) -> Result<Array1<f32>> {
|
||||||
|
style::get_style_vector(&self.find_model(ident)?.style_vectors, style_id, weight)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn synthesize<I: Into<TTSIdent>>(
|
||||||
|
&self,
|
||||||
|
ident: I,
|
||||||
bert_ori: Array2<f32>,
|
bert_ori: Array2<f32>,
|
||||||
phones: Array1<i64>,
|
phones: Array1<i64>,
|
||||||
tones: Array1<i64>,
|
tones: Array1<i64>,
|
||||||
@@ -108,7 +167,7 @@ impl TTSModel {
|
|||||||
style_vector: Array1<f32>,
|
style_vector: Array1<f32>,
|
||||||
) -> Result<Vec<u8>> {
|
) -> Result<Vec<u8>> {
|
||||||
let buffer = model::synthesize(
|
let buffer = model::synthesize(
|
||||||
&self.vits2,
|
&self.find_model(ident)?.vits2,
|
||||||
bert_ori.to_owned(),
|
bert_ori.to_owned(),
|
||||||
phones,
|
phones,
|
||||||
tones,
|
tones,
|
||||||
|
|||||||
1
test.py
1
test.py
@@ -2,6 +2,7 @@ import requests
|
|||||||
|
|
||||||
res = requests.post('http://localhost:3000/synthesize', json={
|
res = requests.post('http://localhost:3000/synthesize', json={
|
||||||
"text": "おはようございます",
|
"text": "おはようございます",
|
||||||
|
"ident": "tsukuyomi"
|
||||||
})
|
})
|
||||||
with open('output.wav', 'wb') as f:
|
with open('output.wav', 'wb') as f:
|
||||||
f.write(res.content)
|
f.write(res.content)
|
||||||
Reference in New Issue
Block a user