breaking: support of length_scale, sdp_ratio, /models endpoint

This commit is contained in:
Googlefan
2024-09-11 04:42:11 +00:00
parent 83b69083ca
commit 441e35b9a6
12 changed files with 243 additions and 49 deletions

View File

@@ -6,7 +6,9 @@ edition = "2021"
[dependencies]
anyhow.workspace = true
axum = "0.7.5"
dotenvy = "0.15.7"
dotenvy.workspace = true
env_logger = "0.11.5"
log = "0.4.22"
sbv2_core = { version = "0.1.0", path = "../sbv2_core" }
serde = { version = "1.0.210", features = ["derive"] }
tokio = { version = "1.40.0", features = ["full"] }
@@ -14,4 +16,4 @@ tokio = { version = "1.40.0", features = ["full"] }
[features]
cuda = ["sbv2_core/cuda"]
cuda_tf32 = ["sbv2_core/cuda_tf32"]
dynamic = ["sbv2_core/dynamic"]
dynamic = ["sbv2_core/dynamic"]

View File

@@ -15,33 +15,39 @@ use tokio::sync::Mutex;
mod error;
use crate::error::AppResult;
async fn models(State(state): State<AppState>) -> AppResult<impl IntoResponse> {
Ok(Json(state.tts_model.lock().await.models()))
}
fn sdp_default() -> f32 {
0.0
}
fn length_default() -> f32 {
1.0
}
#[derive(Deserialize)]
struct SynthesizeRequest {
text: String,
ident: String,
#[serde(default = "sdp_default")]
sdp_ratio: f32,
#[serde(default = "length_default")]
length_scale: f32,
}
async fn synthesize(
State(state): State<Arc<AppState>>,
Json(SynthesizeRequest { text, ident }): Json<SynthesizeRequest>,
State(state): State<AppState>,
Json(SynthesizeRequest {
text,
ident,
sdp_ratio,
length_scale,
}): Json<SynthesizeRequest>,
) -> AppResult<impl IntoResponse> {
log::debug!("processing request: text={text}, ident={ident}, sdp_ratio={sdp_ratio}, length_scale={length_scale}");
let buffer = {
let mut tts_model = state.tts_model.lock().await;
let tts_model = if let Some(tts_model) = &*tts_model {
tts_model
} else {
let mut tts_holder = TTSModelHolder::new(
&fs::read(env::var("BERT_MODEL_PATH")?).await?,
&fs::read(env::var("TOKENIZER_PATH")?).await?,
)?;
tts_holder.load(
"tsukuyomi",
fs::read(env::var("STYLE_VECTORS_PATH")?).await?,
fs::read(env::var("MODEL_PATH")?).await?,
)?;
*tts_model = Some(tts_holder);
tts_model.as_ref().unwrap()
};
let tts_model = state.tts_model.lock().await;
let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(&text)?;
let style_vector = tts_model.get_style_vector(&ident, 0, 1.0)?;
tts_model.synthesize(
@@ -51,26 +57,78 @@ async fn synthesize(
tones,
lang_ids,
style_vector,
sdp_ratio,
length_scale,
)?
};
Ok(([(CONTENT_TYPE, "audio/wav")], buffer))
}
#[derive(Clone)]
struct AppState {
tts_model: Arc<Mutex<Option<TTSModelHolder>>>,
tts_model: Arc<Mutex<TTSModelHolder>>,
}
impl AppState {
pub async fn new() -> anyhow::Result<Self> {
let mut tts_model = TTSModelHolder::new(
&fs::read(env::var("BERT_MODEL_PATH")?).await?,
&fs::read(env::var("TOKENIZER_PATH")?).await?,
)?;
let models = env::var("MODELS_PATH").unwrap_or("models".to_string());
let mut f = fs::read_dir(&models).await?;
let mut entries = vec![];
while let Ok(Some(e)) = f.next_entry().await {
let name = e.file_name().to_string_lossy().to_string();
if name.ends_with(".onnx") && name.starts_with("model_") {
let name_len = name.len();
let name = name.chars();
entries.push(
name.collect::<Vec<_>>()[6..name_len - 5]
.iter()
.collect::<String>(),
);
}
}
for entry in entries {
log::info!("Try loading: {entry}");
let style_vectors_bytes =
match fs::read(format!("{models}/style_vectors_{entry}.json")).await {
Ok(b) => b,
Err(e) => {
log::warn!("Error loading style_vectors_bytes from file {entry}: {e}");
continue;
}
};
let vits2_bytes = match fs::read(format!("{models}/model_{entry}.onnx")).await {
Ok(b) => b,
Err(e) => {
log::warn!("Error loading vits2_bytes from file {entry}: {e}");
continue;
}
};
if let Err(e) = tts_model.load(&entry, style_vectors_bytes, vits2_bytes) {
log::warn!("Error loading {entry}: {e}");
};
}
Ok(Self {
tts_model: Arc::new(Mutex::new(tts_model)),
})
}
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
dotenvy::dotenv().ok();
env_logger::init();
let app = Router::new()
.route("/", get(|| async { "Hello, World!" }))
.route("/synthesize", post(synthesize))
.with_state(Arc::new(AppState {
tts_model: Arc::new(Mutex::new(None)),
}));
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
.route("/models", get(models))
.with_state(AppState::new().await?);
let addr = env::var("ADDR").unwrap_or("0.0.0.0:3000".to_string());
let listener = tokio::net::TcpListener::bind(&addr).await?;
log::info!("Listening on {addr}");
axum::serve(listener, app).await?;
Ok(())