diff --git a/.gitignore b/.gitignore index b2bf1a3..15abba6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ target models/*.onnx models/*.json -venv \ No newline at end of file +venv +.env \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 75e2966..0e085b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -357,6 +357,12 @@ dependencies = [ "crypto-common", ] +[[package]] +name = "dotenvy" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" + [[package]] name = "either" version = "1.13.0" @@ -1614,6 +1620,7 @@ version = "0.1.0" dependencies = [ "anyhow", "axum", + "dotenvy", "sbv2_core", "serde", "tokio", diff --git a/output.wav b/output.wav index 18a19de..f1cd61e 100644 Binary files a/output.wav and b/output.wav differ diff --git a/sbv2_api/Cargo.toml b/sbv2_api/Cargo.toml index c32e4d0..82b1656 100644 --- a/sbv2_api/Cargo.toml +++ b/sbv2_api/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] anyhow.workspace = true axum = "0.7.5" +dotenvy = "0.15.7" sbv2_core = { version = "0.1.0", path = "../sbv2_core" } serde = { version = "1.0.210", features = ["derive"] } tokio = { version = "1.40.0", features = ["full"] } diff --git a/sbv2_api/src/main.rs b/sbv2_api/src/main.rs index 793dab6..6da5b7d 100644 --- a/sbv2_api/src/main.rs +++ b/sbv2_api/src/main.rs @@ -7,6 +7,7 @@ use axum::{ }; use sbv2_core::tts::TTSModel; use serde::Deserialize; +use std::env; use std::sync::Arc; use tokio::sync::Mutex; @@ -28,9 +29,9 @@ async fn synthesize( tts_model } else { *tts_model = Some(TTSModel::new( - "models/debert.onnx", - "models/model_opt.onnx", - "models/style_vectors.json", + env::var("BERT_MODEL_PATH")?, + env::var("MAIN_MODEL_PATH")?, + env::var("STYLE_VECTORS_PATH")?, )?); tts_model.as_ref().unwrap() }; @@ -47,6 +48,7 @@ struct AppState { #[tokio::main] async fn main() -> anyhow::Result<()> { + dotenvy::dotenv().ok(); let app = Router::new() .route("/", get(|| async { "Hello, World!" })) .route("/synthesize", post(synthesize)) diff --git a/test.py b/test.py index b587447..418e11d 100644 --- a/test.py +++ b/test.py @@ -2,7 +2,7 @@ import requests res = requests.post('http://localhost:3000/synthesize', json={ - "text": "眠たい" + "text": "おはよう" }) res.raise_for_status() with open('output.wav', 'wb') as f: