mirror of https://github.com/neodyland/sbv2-api.git
synced 2026-01-06 14:32:57 +00:00

Merge branch 'main' of https://github.com/tuna2134/sbv2-api into python
@@ -1,6 +1,6 @@
 BERT_MODEL_PATH=models/deberta.onnx
-MODEL_PATH=models/model_tsukuyomi.onnx
+MODEL_PATH=models/tsukuyomi.sbv2
 MODELS_PATH=models
-STYLE_VECTORS_PATH=models/style_vectors.json
 TOKENIZER_PATH=models/tokenizer.json
 ADDR=localhost:3000
+RUST_LOG=warn
.gitignore (vendored)

@@ -1,6 +1,7 @@
 target
-models/*.onnx
-models/*.json
+models/
+!models/.gitkeep
 venv/
 .env
 output.wav
+node_modules
Cargo.lock (generated)

@@ -228,6 +228,8 @@ version = "1.1.18"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b62ac837cdb5cb22e10a256099b4fc502b1dfe560cb282963a974d7abd80e476"
 dependencies = [
+ "jobserver",
+ "libc",
  "shlex",
 ]

@@ -854,6 +856,15 @@ dependencies = [
  "thiserror",
 ]

+[[package]]
+name = "jobserver"
+version = "0.1.32"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "jpreprocess"
 version = "0.10.0"

@@ -1829,10 +1840,11 @@ dependencies = [

 [[package]]
 name = "sbv2_core"
-version = "0.1.0"
+version = "0.1.1"
 dependencies = [
  "anyhow",
  "dotenvy",
+ "env_logger",
  "hound",
  "jpreprocess",
  "ndarray",

@@ -1842,8 +1854,10 @@ dependencies = [
  "regex",
  "serde",
  "serde_json",
+ "tar",
  "thiserror",
  "tokenizers",
+ "zstd",
 ]

 [[package]]

@@ -2468,3 +2482,31 @@ name = "zeroize"
 version = "1.8.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
+
+[[package]]
+name = "zstd"
+version = "0.13.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9"
+dependencies = [
+ "zstd-safe",
+]
+
+[[package]]
+name = "zstd-safe"
+version = "7.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059"
+dependencies = [
+ "zstd-sys",
+]
+
+[[package]]
+name = "zstd-sys"
+version = "2.0.13+zstd.1.5.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa"
+dependencies = [
+ "cc",
+ "pkg-config",
+]
@@ -5,3 +5,4 @@ members = ["sbv2_api", "sbv2_core", "sbv2_bindings"]
 [workspace.dependencies]
 anyhow = "1.0.86"
 dotenvy = "0.15.7"
+env_logger = "0.11.5"
README.md

@@ -1,52 +1,98 @@
-# sbv2-api
-This project aims to run ONNX-converted Style-Bert-ViTS2 models in Rust.
-
-If you want to train a model, searching for something like "Style-Bert-ViTS2 training" may help.
-
-Only JP-Extra is supported (and there are basically no plans to support anything else).
-
-## How to convert to ONNX
-```sh
-cd convert
-# (create a venv in whatever way you like (recommended))
-pip install -r requirements.txt
-python convert_deberta.py
-python convert_model.py --style_file ../../style-bert-vits2/model_assets/something/style_vectors.npy --config_file ../../style-bert-vits2/model_assets/something/config.json --model_file ../../style-bert-vits2/model_assets/something/something_eXXX_sXXXX.safetensors
-```
+# SBV2-API
+
+## For those not familiar with programming
+
+See [here](https://github.com/tuna2134/sbv2-gui?tab=readme-ov-file).
+
+It is a version that can be used without command-line or Python knowledge (it can do roughly the same things).
+
+## About this project
+
+This project is a library whose goal is to run ONNX-converted Style-Bert-ViTS2 models in Rust.
+
+Only JP-Extra is supported (and there are basically no plans to support anything else).
+
+## How to convert models
+
+See [here](https://github.com/tuna2134/sbv2-api/tree/main/convert).

 ## Todo
-- [x] Web API implementation
-- [x] Rust library implementation
-- [ ] Make it usable from Python via PyO3 if time allows
-- [x] GPU support (CUDA prioritized)
-- [ ] WASM conversion (cancelled because ort dropped support for it)
+- [x] REST API implementation
+- [x] Rust library implementation
+- [x] Development of the `.sbv2` format
+- [ ] Make it usable from Python via PyO3
+- [x] GPU support (CUDA)
+- [x] GPU support (DirectML)
+- [ ] WASM conversion (currently not possible because of a dependency)

 ## Structure

 - `sbv2_api` - inference REST API
 - `sbv2_core` - inference core
 - `docker` - docker build scripts
+- `convert` - scripts for converting to the onnx and sbv2 formats
+
+## Starting the REST API (for those with some programming experience)
+
+### Install the models
+
+Download `tokenizer.json`, `debert.onnx`, and `tsukuyomi.sbv2` from
+https://huggingface.co/googlefan/sbv2_onnx_models/tree/main
+and place them in the models folder.
+
+### Create the .env file
+
-## How to start the API
 ```sh
-cargo run -p sbv2_api -r
+cp .env.sample .env
 ```

-### Starting with CUDA
+### Start
+
+For CPU:
 ```sh
-cargo run -p sbv2_api -r -F cuda,cuda_tf32
+docker run -it --rm -p 3000:3000 --name sbv2 \
+    -v ./models:/work/models --env-file .env \
+    ghcr.io/tuna2134/sbv2-api:cpu
 ```

-### Dynamic link support
+For CUDA:
 ```sh
-ORT_DYLIB_PATH=./libonnxruntime.dll cargo run -p sbv2_api -r -F dynamic
+docker run -it --rm -p 3000:3000 --name sbv2 \
+    -v ./models:/work/models --env-file .env \
+    --gpus all \
+    ghcr.io/tuna2134/sbv2-api:cuda
 ```

-### Test commands
+### Check that it is running

 ```sh
-curl -XPOST -H "Content-type: application/json" -d '{"text": "こんにちは","ident": "something"}' 'http://localhost:3000/synthesize'
+curl -XPOST -H "Content-type: application/json" -d '{"text": "こんにちは","ident": "tsukuyomi"}' 'http://localhost:3000/synthesize' --output "output.wav"
 curl http://localhost:3000/models
 ```

+## Developer guide
+
+### Feature flags
+
+For both `sbv2_api` and `sbv2_core`:
+
+- the `cuda` feature enables CUDA
+- the `cuda_tf32` feature enables CUDA's TF32 mode
+- the `tensorrt` feature uses TensorRT for the BERT part
+- the `dynamic` feature uses a local onnxruntime shared library (point to it with something like `ORT_DYLIB_PATH=./libonnxruntime.dll`)
+- the `directml` feature enables DirectML
+
+### Environment variables
+
+The following environment variables are not applied on the library side.
+
+For the library API, see `https://docs.rs/sbv2_core`.
+
+- `ADDR` controls the address the server listens on, e.g. `localhost:3000`.
+- `MODELS_PATH` specifies the folder that contains the sbv2 models.
+- `RUST_LOG` is the usual log level.
+
 ## Acknowledgements

 - [litagin02/Style-Bert-VITS2](https://github.com/litagin02/Style-Bert-VITS2) - We used it as a reference for the base of this code.
 - [Googlefan](https://github.com/Googlefan256) - He taught us how to convert the models to ONNX and optimize them.
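As a supplement to the README changes above, here is a minimal sketch of how the new library-side API from this commit could be driven. It is pieced together only from the `sbv2_core` example binary and `tts.rs` changes further down in this diff; the file paths are placeholders, and the exact `synthesize` argument list is not shown in this diff, so it is only referenced in a comment.

```rust
use std::fs;

fn main() -> anyhow::Result<()> {
    // Placeholder paths; see the .env sample hunk at the top of this commit.
    let bert = fs::read("models/deberta.onnx")?;
    let tokenizer = fs::read("models/tokenizer.json")?;
    let mut holder = sbv2_core::tts::TTSModelHolder::new(&bert, &tokenizer)?;

    // A .sbv2 file bundles model.onnx and style_vectors.json, so a single
    // load_sbv2file call replaces the old load(ident, style_vectors, model) call.
    holder.load_sbv2file("tsukuyomi", fs::read("models/tsukuyomi.sbv2")?)?;

    // parse_text produces the BERT features and the phone/tone/language-id sequences.
    let (bert_ori, phones, tones, lang_ids) = holder.parse_text("こんにちは")?;

    // These are then fed to holder.synthesize(...) together with style and length
    // parameters; see https://docs.rs/sbv2_core for the full signature.
    let _ = (bert_ori, phones, tones, lang_ids);
    Ok(())
}
```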
convert/README.md (new file)

@@ -0,0 +1,36 @@
+# How to convert
+
+## Setup for beginners
+
+Skip this if you already know what you are doing.
+
+1. Install Python. This has been tested with 3.11.8, but most recent versions should work.
+
+2. `cd convert`
+
+3. `python -m venv venv`
+
+4. `source venv/bin/activate`
+
+5. `pip install -r requirements.txt`
+
+## Model conversion
+
+1. Locate the file ending in `.safetensors` for the model you want to convert.
+
+2. Likewise, find the files named `config.json` and `style_vectors.npy`.
+
+3. Run the following command:
+```sh
+python convert_model.py --style_file "path to style_vectors.npy" --config_file "likewise, path to config.json" --model_file "likewise, path to the file ending in .safetensors"
+```
+
+4. A file named `models/<name>.sbv2` is produced. You can use it by putting it into the GUI version's model folder.
+
+## DeBERTa conversion
+
+If this means nothing to you, you most likely do not need to convert it.
+
+Just set up a venv, install the requirements, and run `python convert_deberta.py`.
+
+If `models/deberta.onnx` and `models/tokenizer.json` are produced, it succeeded.
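The `.sbv2` file written by `convert_model.py` below is a zstd-compressed tar archive containing `version.txt`, `model.onnx`, and `style_vectors.json`. As a rough illustration (a hypothetical helper, not part of the crate), mirroring the `load_sbv2file` logic added to `tts.rs` later in this diff, such an archive can be inspected in Rust like this:

```rust
use std::io::{Cursor, Read};
use tar::Archive;
use zstd::decode_all;

// Hypothetical helper: list the entries of a .sbv2 archive
// (a zstd-compressed tar of version.txt, model.onnx and style_vectors.json).
fn list_sbv2_entries(sbv2_bytes: &[u8]) -> anyhow::Result<Vec<(String, usize)>> {
    let decompressed = decode_all(Cursor::new(sbv2_bytes))?;
    let mut archive = Archive::new(Cursor::new(decompressed));
    let mut out = Vec::new();
    for entry in archive.entries()? {
        let mut entry = entry?;
        let name = String::from_utf8_lossy(&entry.path_bytes()).to_string();
        let mut buf = Vec::with_capacity(entry.size() as usize);
        entry.read_to_end(&mut buf)?;
        out.push((name, buf.len()));
    }
    Ok(out)
}
```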
@@ -1,5 +1,6 @@
 import numpy as np
 import json
+from io import BytesIO
 from style_bert_vits2.nlp import bert_models
 from style_bert_vits2.constants import Languages
 from style_bert_vits2.models.infer import get_net_g, get_text

@@ -11,6 +12,9 @@ from style_bert_vits2.constants import (
     DEFAULT_STYLE_WEIGHT,
     Languages,
 )
+import os
+from tarfile import open as taropen, TarInfo
+from zstandard import ZstdCompressor
 from style_bert_vits2.tts_model import TTSModel
 import numpy as np
 from argparse import ArgumentParser

@@ -141,3 +145,23 @@ torch.onnx.export(
     ],
     output_names=["output"],
 )
+os.system(f"onnxsim ../models/model_{out_name}.onnx ../models/model_{out_name}.onnx")
+onnxfile = open(f"../models/model_{out_name}.onnx", "rb").read()
+stylefile = open(f"../models/style_vectors_{out_name}.json", "rb").read()
+version = bytes("1", "utf8")
+with taropen(f"../models/tmp_{out_name}.sbv2tar", "w") as w:
+
+    def add_tar(f, b):
+        t = TarInfo(f)
+        t.size = len(b)
+        w.addfile(t, BytesIO(b))
+
+    add_tar("version.txt", version)
+    add_tar("model.onnx", onnxfile)
+    add_tar("style_vectors.json", stylefile)
+open(f"../models/{out_name}.sbv2", "wb").write(
+    ZstdCompressor(threads=-1, level=22).compress(
+        open(f"../models/tmp_{out_name}.sbv2tar", "rb").read()
+    )
+)
+os.unlink(f"../models/tmp_{out_name}.sbv2tar")
@@ -1,3 +1,4 @@
 style-bert-vits2
 onnxsim
-numpy<3
+numpy<2
+zstandard
@@ -2,9 +2,9 @@ FROM rust AS builder
 WORKDIR /work
 COPY . .
 RUN cargo build -r --bin sbv2_api -F cuda,cuda_tf32
-FROM nvidia/cuda:12.6.1-cudnn-runtime-ubuntu24.04
+FROM nvidia/cuda:12.3.2-cudnn9-runtime-ubuntu22.04
 WORKDIR /work
 COPY --from=builder /work/target/release/sbv2_api /work/main
 COPY --from=builder /work/target/release/*.so /work
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/work
 CMD ["/work/main"]
@@ -1 +0,0 @@
-docker run -it --rm -p 3000:3000 --name sbv2 -v ./models:/work/models --env-file .env sbv2
docker/run_cpu.sh (new executable file)

@@ -0,0 +1,3 @@
+docker run -it --rm -p 3000:3000 --name sbv2 \
+    -v ./models:/work/models --env-file .env \
+    ghcr.io/tuna2134/sbv2-api:cpu
docker/run_cuda.sh (new executable file)

@@ -0,0 +1,4 @@
+docker run -it --rm -p 3000:3000 --name sbv2 \
+    -v ./models:/work/models --env-file .env \
+    --gpus all \
+    ghcr.io/tuna2134/sbv2-api:cuda
@@ -7,13 +7,16 @@ edition = "2021"
 anyhow.workspace = true
 axum = "0.7.5"
 dotenvy.workspace = true
-env_logger = "0.11.5"
+env_logger.workspace = true
 log = "0.4.22"
-sbv2_core = { version = "0.1.0", path = "../sbv2_core" }
+sbv2_core = { version = "0.1.1", path = "../sbv2_core" }
 serde = { version = "1.0.210", features = ["derive"] }
 tokio = { version = "1.40.0", features = ["full"] }

 [features]
+coreml = ["sbv2_core/coreml"]
 cuda = ["sbv2_core/cuda"]
 cuda_tf32 = ["sbv2_core/cuda_tf32"]
 dynamic = ["sbv2_core/dynamic"]
+directml = ["sbv2_core/directml"]
+tensorrt = ["sbv2_core/tensorrt"]
@@ -26,6 +26,7 @@ fn sdp_default() -> f32 {
 fn length_default() -> f32 {
     1.0
 }
+
 #[derive(Deserialize)]
 struct SynthesizeRequest {
     text: String,

@@ -88,6 +89,20 @@ impl AppState {
                     .iter()
                     .collect::<String>(),
             );
+        } else if name.ends_with(".sbv2") {
+            let entry = &name[..name.len() - 5];
+            log::info!("Try loading: {entry}");
+            let sbv2_bytes = match fs::read(format!("{models}/{entry}.sbv2")).await {
+                Ok(b) => b,
+                Err(e) => {
+                    log::warn!("Error loading sbv2_bytes from file {entry}: {e}");
+                    continue;
+                }
+            };
+            if let Err(e) = tts_model.load_sbv2file(entry, sbv2_bytes) {
+                log::warn!("Error loading {entry}: {e}");
+            };
+            log::info!("Loaded: {entry}");
         }
     }
     for entry in entries {

@@ -110,6 +125,7 @@ impl AppState {
         if let Err(e) = tts_model.load(&entry, style_vectors_bytes, vits2_bytes) {
             log::warn!("Error loading {entry}: {e}");
         };
+        log::info!("Loaded: {entry}");
     }
     Ok(Self {
         tts_model: Arc::new(Mutex::new(tts_model)),

@@ -119,7 +135,7 @@ impl AppState {

 #[tokio::main]
 async fn main() -> anyhow::Result<()> {
-    dotenvy::dotenv().ok();
+    dotenvy::dotenv_override().ok();
     env_logger::init();
     let app = Router::new()
         .route("/", get(|| async { "Hello, World!" }))
@@ -1,11 +1,16 @@
 [package]
 name = "sbv2_core"
-version = "0.1.0"
+description = "Style-Bert-VITSの推論ライブラリ"
+version = "0.1.1"
 edition = "2021"
+license = "MIT"
+readme = "../README.md"
+repository = "https://github.com/tuna2134/sbv2-api"

 [dependencies]
 anyhow.workspace = true
 dotenvy.workspace = true
+env_logger.workspace = true
 hound = "3.5.1"
 jpreprocess = { version = "0.10.0", features = ["naist-jdic"] }
 ndarray = "0.16.1"

@@ -15,10 +20,15 @@ ort = { git = "https://github.com/pykeio/ort.git", version = "2.0.0-rc.6" }
 regex = "1.10.6"
 serde = { version = "1.0.210", features = ["derive"] }
 serde_json = "1.0.128"
+tar = "0.4.41"
 thiserror = "1.0.63"
 tokenizers = "0.20.0"
+zstd = "0.13.2"

 [features]
 cuda = ["ort/cuda"]
 cuda_tf32 = []
 dynamic = ["ort/load-dynamic"]
+directml = ["ort/directml"]
+tensorrt = ["ort/tensorrt"]
+coreml = ["ort/coreml"]
@@ -4,18 +4,15 @@ use sbv2_core::tts;
 use std::env;

 fn main() -> anyhow::Result<()> {
-    dotenvy::dotenv().ok();
+    dotenvy::dotenv_override().ok();
+    env_logger::init();
     let text = "眠たい";
     let ident = "aaa";
     let mut tts_holder = tts::TTSModelHolder::new(
         &fs::read(env::var("BERT_MODEL_PATH")?)?,
         &fs::read(env::var("TOKENIZER_PATH")?)?,
     )?;
-    tts_holder.load(
-        ident,
-        fs::read(env::var("STYLE_VECTORS_PATH")?)?,
-        fs::read(env::var("MODEL_PATH")?)?,
-    )?;
+    tts_holder.load_sbv2file(ident, fs::read(env::var("MODEL_PATH")?)?)?;

     let (bert_ori, phones, tones, lang_ids) = tts_holder.parse_text(text)?;

@@ -32,6 +29,14 @@ fn main() -> anyhow::Result<()> {
     )?;
     std::fs::write("output.wav", data)?;
     let now = Instant::now();
+    for _ in 0..10 {
+        tts_holder.parse_text(text)?;
+    }
+    println!(
+        "Time taken(parse_text): {}ms/it",
+        now.elapsed().as_millis() / 10
+    );
+    let now = Instant::now();
     for _ in 0..10 {
         tts_holder.synthesize(
             ident,

@@ -44,6 +49,9 @@ fn main() -> anyhow::Result<()> {
             1.0,
         )?;
     }
-    println!("Time taken: {}", now.elapsed().as_millis());
+    println!(
+        "Time taken(synthesize): {}ms/it",
+        now.elapsed().as_millis() / 10
+    );
     Ok(())
 }
@@ -4,11 +4,25 @@ use ndarray::{array, s, Array1, Array2, Axis};
 use ort::{GraphOptimizationLevel, Session};
 use std::io::Cursor;

-#[allow(clippy::vec_init_then_push)]
-pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
+#[allow(clippy::vec_init_then_push, unused_variables)]
+pub fn load_model<P: AsRef<[u8]>>(model_file: P, bert: bool) -> Result<Session> {
     let mut exp = Vec::new();
+    #[cfg(feature = "tensorrt")]
+    {
+        if bert {
+            exp.push(
+                ort::TensorRTExecutionProvider::default()
+                    .with_fp16(true)
+                    .with_profile_min_shapes("input_ids:1x1,attention_mask:1x1")
+                    .with_profile_max_shapes("input_ids:1x100,attention_mask:1x100")
+                    .with_profile_opt_shapes("input_ids:1x25,attention_mask:1x25")
+                    .build(),
+            );
+        }
+    }
     #[cfg(feature = "cuda")]
     {
+        #[allow(unused_mut)]
         let mut cuda = ort::CUDAExecutionProvider::default()
             .with_conv_algorithm_search(ort::CUDAExecutionProviderCuDNNConvAlgoSearch::Default);
         #[cfg(feature = "cuda_tf32")]

@@ -17,6 +31,14 @@ pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
         }
         exp.push(cuda.build());
     }
+    #[cfg(feature = "directml")]
+    {
+        exp.push(ort::DirectMLExecutionProvider::default().build());
+    }
+    #[cfg(feature = "coreml")]
+    {
+        exp.push(ort::CoreMLExecutionProvider::default().build());
+    }
     exp.push(ort::CPUExecutionProvider::default().build());
     Ok(Session::builder()?
         .with_execution_providers(exp)?

@@ -26,6 +48,7 @@ pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
         .with_inter_threads(num_cpus::get_physical())?
         .commit_from_memory(model_file.as_ref())?)
 }
+
 #[allow(clippy::too_many_arguments)]
 pub fn synthesize(
     session: &Session,
@@ -2,7 +2,10 @@ use crate::error::{Error, Result};
 use crate::{bert, jtalk, model, nlp, norm, style, tokenizer, utils};
 use ndarray::{concatenate, s, Array, Array1, Array2, Axis};
 use ort::Session;
+use std::io::{Cursor, Read};
+use tar::Archive;
 use tokenizers::Tokenizer;
+use zstd::decode_all;

 #[derive(PartialEq, Eq, Clone)]
 pub struct TTSIdent(String);

@@ -38,7 +41,7 @@ pub struct TTSModelHolder {

 impl TTSModelHolder {
     pub fn new<P: AsRef<[u8]>>(bert_model_bytes: P, tokenizer_bytes: P) -> Result<Self> {
-        let bert = model::load_model(bert_model_bytes)?;
+        let bert = model::load_model(bert_model_bytes, true)?;
         let jtalk = jtalk::JTalk::new()?;
         let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?;
         Ok(TTSModelHolder {

@@ -53,6 +56,35 @@ impl TTSModelHolder {
         self.models.iter().map(|m| m.ident.to_string()).collect()
     }

+    pub fn load_sbv2file<I: Into<TTSIdent>, P: AsRef<[u8]>>(
+        &mut self,
+        ident: I,
+        sbv2_bytes: P,
+    ) -> Result<()> {
+        let mut arc = Archive::new(Cursor::new(decode_all(Cursor::new(sbv2_bytes.as_ref()))?));
+        let mut vits2 = None;
+        let mut style_vectors = None;
+        let mut et = arc.entries()?;
+        while let Some(Ok(mut e)) = et.next() {
+            let pth = String::from_utf8_lossy(&e.path_bytes()).to_string();
+            let mut b = Vec::with_capacity(e.size() as usize);
+            e.read_to_end(&mut b)?;
+            match pth.as_str() {
+                "model.onnx" => vits2 = Some(b),
+                "style_vectors.json" => style_vectors = Some(b),
+                _ => continue,
+            }
+        }
+        if style_vectors.is_none() {
+            return Err(Error::ModelNotFoundError("style_vectors".to_string()));
+        }
+        if vits2.is_none() {
+            return Err(Error::ModelNotFoundError("vits2".to_string()));
+        }
+        self.load(ident, style_vectors.unwrap(), vits2.unwrap())?;
+        Ok(())
+    }
+
     pub fn load<I: Into<TTSIdent>, P: AsRef<[u8]>>(
         &mut self,
         ident: I,

@@ -62,7 +94,7 @@ impl TTSModelHolder {
         let ident = ident.into();
         if self.find_model(ident.clone()).is_err() {
             self.models.push(TTSModel {
-                vits2: model::load_model(vits2_bytes)?,
+                vits2: model::load_model(vits2_bytes, false)?,
                 style_vectors: style::load_style(style_vectors_bytes)?,
                 ident,
             })
@@ -2,11 +2,6 @@ pub fn intersperse<T>(slice: &[T], sep: T) -> Vec<T>
 where
     T: Clone,
 {
-    /*
-    result = [item] * (len(lst) * 2 + 1)
-    result[1::2] = lst
-    return result
-    */
     let mut result = vec![sep.clone(); slice.len() * 2 + 1];
     result
         .iter_mut()

@@ -15,24 +10,3 @@ where
         .for_each(|(r, s)| *r = s.clone());
     result
 }
-
-/*
-fn tile<T: Clone>(arr: &Array2<T>, reps: (usize, usize)) -> Array2<T> {
-    let (rows, cols) = arr.dim();
-    let (rep_rows, rep_cols) = reps;
-
-    let mut result = Array::zeros((rows * rep_rows, cols * rep_cols));
-
-    for i in 0..rep_rows {
-        for j in 0..rep_cols {
-            let view = result.slice_mut(s![
-                i * rows..(i + 1) * rows,
-                j * cols..(j + 1) * cols
-            ]);
-            view.assign(arr);
-        }
-    }
-
-    result
-}
-*/
test.py

@@ -1,7 +1,7 @@
 import requests

 res = requests.post(
-    "http://localhost:3001/synthesize",
+    "http://localhost:3000/synthesize",
     json={"text": "おはようございます", "ident": "tsukuyomi"},
 )
 with open("output.wav", "wb") as f:
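For completeness, the same smoke test as `test.py` could be written in Rust. This is only a sketch and assumes the `reqwest` crate (with the `blocking` and `json` features) plus `serde_json`, which are not dependencies of this repository.

```rust
use std::fs;

fn main() -> anyhow::Result<()> {
    // POST /synthesize returns WAV bytes, mirroring test.py above.
    let client = reqwest::blocking::Client::new();
    let audio = client
        .post("http://localhost:3000/synthesize")
        .json(&serde_json::json!({ "text": "おはようございます", "ident": "tsukuyomi" }))
        .send()?
        .error_for_status()?
        .bytes()?;
    fs::write("output.wav", &audio)?;

    // GET /models lists the loaded model identifiers.
    let models = client.get("http://localhost:3000/models").send()?.text()?;
    println!("{models}");
    Ok(())
}
```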