mirror of
https://github.com/neodyland/sbv2-api.git
synced 2026-01-06 14:32:57 +00:00
Merge branch 'main' of https://github.com/tuna2134/sbv2-api into python
This commit is contained in:
6
.dockerignore
Normal file
6
.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
target/
|
||||||
|
models/
|
||||||
|
docker/
|
||||||
|
.env*
|
||||||
|
renovate.json
|
||||||
|
*.py
|
||||||
6
.env.sample
Normal file
6
.env.sample
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
BERT_MODEL_PATH=models/deberta.onnx
|
||||||
|
MODEL_PATH=models/model_tsukuyomi.onnx
|
||||||
|
MODELS_PATH=models
|
||||||
|
STYLE_VECTORS_PATH=models/style_vectors.json
|
||||||
|
TOKENIZER_PATH=models/tokenizer.json
|
||||||
|
ADDR=localhost:3000
|
||||||
36
.github/workflows/build.yml
vendored
Normal file
36
.github/workflows/build.yml
vendored
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
name: Push to github container register
|
||||||
|
|
||||||
|
on:
|
||||||
|
release:
|
||||||
|
types: [created]
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
push-docker:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
tag: [cpu, cuda]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Set up QEMU
|
||||||
|
uses: docker/setup-qemu-action@v3
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
- name: Login to GitHub Container Registry
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: ghcr.io
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
- name: Build and push image
|
||||||
|
uses: docker/build-push-action@v6
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
push: true
|
||||||
|
tags: |
|
||||||
|
ghcr.io/${{ github.repository }}:${{ matrix.tag }}
|
||||||
|
file: docker/${{ matrix.tag }}.Dockerfile
|
||||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -1,5 +1,6 @@
|
|||||||
target
|
target
|
||||||
models/*.onnx
|
models/*.onnx
|
||||||
models/*.json
|
models/*.json
|
||||||
venv
|
venv/
|
||||||
.env
|
.env
|
||||||
|
output.wav
|
||||||
2360
Cargo.lock
generated
2360
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,11 @@
|
|||||||
[workspace]
|
[workspace]
|
||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
<<<<<<< HEAD
|
||||||
members = [ "sbv2_api","sbv2_core", "sbv2_bindings"]
|
members = [ "sbv2_api","sbv2_core", "sbv2_bindings"]
|
||||||
|
=======
|
||||||
|
members = ["sbv2_api", "sbv2_core"]
|
||||||
|
>>>>>>> cfea5d735aeb7d0abad5b913a3dda3810d8e59f8
|
||||||
|
|
||||||
[workspace.dependencies]
|
[workspace.dependencies]
|
||||||
anyhow = "1.0.86"
|
anyhow = "1.0.86"
|
||||||
|
dotenvy = "0.15.7"
|
||||||
2
LICENSE
2
LICENSE
@@ -1,6 +1,6 @@
|
|||||||
MIT License
|
MIT License
|
||||||
|
|
||||||
Copyright (c) 2024 コマリン親衛隊
|
Copyright (c) 2024 tuna2134
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
|||||||
51
README.md
51
README.md
@@ -1,25 +1,52 @@
|
|||||||
# sbv2-api
|
# sbv2-api
|
||||||
このプロジェクトはStyle-Bert-ViTS2をONNX化したものをRustで実行するのを目的としています。つまり推論しか行いません。
|
このプロジェクトはStyle-Bert-ViTS2をONNX化したものをRustで実行するのを目的としています。
|
||||||
|
|
||||||
学習したいのであれば、Style-Bert-ViT2で調べてやってください。
|
学習したい場合は、Style-Bert-ViTS2 学習方法 などで調べるとよいかもしれません。
|
||||||
|
|
||||||
注意:JP-Extraしか対応していません。
|
JP-Extraしか対応していません。(基本的に対応する予定もありません)
|
||||||
|
|
||||||
## ONNX化する方法
|
## ONNX化する方法
|
||||||
dabertaとstbv2本体をonnx化する必要があります。
|
```sh
|
||||||
|
cd convert
|
||||||
あくまで推奨ですが、onnxsimを使うことをお勧めします。
|
# (何かしらの方法でvenv作成(推奨))
|
||||||
onnxsim使うことでモデルのサイズを軽くすることができます。
|
pip install -r requirements.txt
|
||||||
|
python convert_deberta.py
|
||||||
## onnxモデルの配置方法
|
python convert_model.py --style_file ../../style-bert-vits2/model_assets/something/style_vectors.npy --config_file ../../style-bert-vits2/model_assets/something/config.json --model_file ../../style-bert-vits2/model_assets/something/something_eXXX_sXXXX.safetensors
|
||||||
- `models/daberta.onnx` - DaBertaのonnxモデル
|
```
|
||||||
- `models/sbv2.onnx` - `Style-Bert-ViT2`の本体
|
|
||||||
|
|
||||||
## Todo
|
## Todo
|
||||||
- [x] WebAPIの実装
|
- [x] WebAPIの実装
|
||||||
- [x] Rustライブラリの実装
|
- [x] Rustライブラリの実装
|
||||||
- [ ] 余裕があればPyO3使ってPythonで利用可能にする
|
- [ ] 余裕があればPyO3使ってPythonで利用可能にする
|
||||||
|
- [x] GPU対応(優先的にCUDA)
|
||||||
|
- [ ] WASM変換(ortがサポートやめたので、中止)
|
||||||
|
|
||||||
|
## 構造説明
|
||||||
|
- `sbv2_api` - 推論用 REST API
|
||||||
|
- `sbv2_core` - 推論コア部分
|
||||||
|
- `docker` - dockerビルドスクリプト
|
||||||
|
|
||||||
|
## APIの起動方法
|
||||||
|
```sh
|
||||||
|
cargo run -p sbv2_api -r
|
||||||
|
```
|
||||||
|
|
||||||
|
### CUDAでの起動
|
||||||
|
```sh
|
||||||
|
cargo run -p sbv2_api -r -F cuda,cuda_tf32
|
||||||
|
```
|
||||||
|
|
||||||
|
### Dynamic Linkサポート
|
||||||
|
```sh
|
||||||
|
ORT_DYLIB_PATH=./libonnxruntime.dll cargo run -p sbv2_api -r -F dynamic
|
||||||
|
```
|
||||||
|
|
||||||
|
### テストコマンド
|
||||||
|
```sh
|
||||||
|
curl -XPOST -H "Content-type: application/json" -d '{"text": "こんにちは","ident": "something"}' 'http://localhost:3000/synthesize'
|
||||||
|
curl http://localhost:3000/models
|
||||||
|
```
|
||||||
|
|
||||||
## 謝辞
|
## 謝辞
|
||||||
- [litagin02/Style-Bert-VITS2](https://github.com/litagin02/Style-Bert-VITS2) - このコードの書くにあたり、ベースとなる部分を参考にさせていただきました。
|
- [litagin02/Style-Bert-VITS2](https://github.com/litagin02/Style-Bert-VITS2) - このコードの書くにあたり、ベースとなる部分を参考にさせていただきました。
|
||||||
- [Googlefan](https://github.com/Googlefan256) - 彼にモデルをONNXヘ変換および効率化をする方法を教わりました。
|
- [Googlefan](https://github.com/Googlefan256) - 彼にモデルをONNXヘ変換および効率化をする方法を教わりました。
|
||||||
|
|||||||
50
convert/convert_deberta.py
Normal file
50
convert/convert_deberta.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
from transformers.convert_slow_tokenizer import BertConverter
|
||||||
|
from style_bert_vits2.nlp import bert_models
|
||||||
|
from style_bert_vits2.constants import Languages
|
||||||
|
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument("--model", default="ku-nlp/deberta-v2-large-japanese-char-wwm")
|
||||||
|
args = parser.parse_args()
|
||||||
|
model_name = args.model
|
||||||
|
|
||||||
|
bert_models.load_tokenizer(Languages.JP, model_name)
|
||||||
|
tokenizer = bert_models.load_tokenizer(Languages.JP)
|
||||||
|
converter = BertConverter(tokenizer)
|
||||||
|
tokenizer = converter.converted()
|
||||||
|
tokenizer.save("../models/tokenizer.json")
|
||||||
|
|
||||||
|
|
||||||
|
class ORTDeberta(nn.Module):
|
||||||
|
def __init__(self, model_name):
|
||||||
|
super(ORTDeberta, self).__init__()
|
||||||
|
self.model = AutoModelForMaskedLM.from_pretrained(model_name)
|
||||||
|
|
||||||
|
def forward(self, input_ids, token_type_ids, attention_mask):
|
||||||
|
inputs = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
}
|
||||||
|
res = self.model(**inputs, output_hidden_states=True)
|
||||||
|
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
model = ORTDeberta(model_name)
|
||||||
|
inputs = AutoTokenizer.from_pretrained(model_name)(
|
||||||
|
"今日はいい天気ですね", return_tensors="pt"
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
model,
|
||||||
|
(inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]),
|
||||||
|
"../models/deberta.onnx",
|
||||||
|
input_names=["input_ids", "token_type_ids", "attention_mask"],
|
||||||
|
output_names=["output"],
|
||||||
|
verbose=True,
|
||||||
|
dynamic_axes={"input_ids": {1: "batch_size"}, "attention_mask": {1: "batch_size"}},
|
||||||
|
)
|
||||||
143
convert/convert_model.py
Normal file
143
convert/convert_model.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
from style_bert_vits2.nlp import bert_models
|
||||||
|
from style_bert_vits2.constants import Languages
|
||||||
|
from style_bert_vits2.models.infer import get_net_g, get_text
|
||||||
|
from style_bert_vits2.models.hyper_parameters import HyperParameters
|
||||||
|
import torch
|
||||||
|
from style_bert_vits2.constants import (
|
||||||
|
DEFAULT_ASSIST_TEXT_WEIGHT,
|
||||||
|
DEFAULT_STYLE,
|
||||||
|
DEFAULT_STYLE_WEIGHT,
|
||||||
|
Languages,
|
||||||
|
)
|
||||||
|
from style_bert_vits2.tts_model import TTSModel
|
||||||
|
import numpy as np
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument("--style_file", required=True)
|
||||||
|
parser.add_argument("--config_file", required=True)
|
||||||
|
parser.add_argument("--model_file", required=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
style_file = args.style_file
|
||||||
|
config_file = args.config_file
|
||||||
|
model_file = args.model_file
|
||||||
|
|
||||||
|
bert_models.load_model(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")
|
||||||
|
bert_models.load_tokenizer(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")
|
||||||
|
|
||||||
|
array = np.load(style_file)
|
||||||
|
data = array.tolist()
|
||||||
|
hyper_parameters = HyperParameters.load_from_json(config_file)
|
||||||
|
out_name = hyper_parameters.model_name
|
||||||
|
|
||||||
|
with open(f"../models/style_vectors_{out_name}.json", "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"data": data,
|
||||||
|
"shape": array.shape,
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
text = "今日はいい天気ですね。"
|
||||||
|
|
||||||
|
bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
|
||||||
|
text,
|
||||||
|
Languages.JP,
|
||||||
|
hyper_parameters,
|
||||||
|
"cpu",
|
||||||
|
assist_text=None,
|
||||||
|
assist_text_weight=DEFAULT_ASSIST_TEXT_WEIGHT,
|
||||||
|
given_phone=None,
|
||||||
|
given_tone=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
tts_model = TTSModel(
|
||||||
|
model_path=model_file,
|
||||||
|
config_path=config_file,
|
||||||
|
style_vec_path=style_file,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
device = "cpu"
|
||||||
|
style_id = tts_model.style2id[DEFAULT_STYLE]
|
||||||
|
|
||||||
|
|
||||||
|
def get_style_vector(style_id, weight):
|
||||||
|
style_vectors = np.load(style_file)
|
||||||
|
mean = style_vectors[0]
|
||||||
|
style_vec = style_vectors[style_id]
|
||||||
|
style_vec = mean + (style_vec - mean) * weight
|
||||||
|
return style_vec
|
||||||
|
|
||||||
|
|
||||||
|
style_vector = get_style_vector(style_id, DEFAULT_STYLE_WEIGHT)
|
||||||
|
|
||||||
|
x_tst = phones.to(device).unsqueeze(0)
|
||||||
|
tones = tones.to(device).unsqueeze(0)
|
||||||
|
lang_ids = lang_ids.to(device).unsqueeze(0)
|
||||||
|
bert = bert.to(device).unsqueeze(0)
|
||||||
|
ja_bert = ja_bert.to(device).unsqueeze(0)
|
||||||
|
en_bert = en_bert.to(device).unsqueeze(0)
|
||||||
|
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
|
||||||
|
style_vec_tensor = torch.from_numpy(style_vector).to(device).unsqueeze(0)
|
||||||
|
|
||||||
|
model = get_net_g(
|
||||||
|
model_file,
|
||||||
|
hyper_parameters.version,
|
||||||
|
device,
|
||||||
|
hyper_parameters,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(x, x_len, sid, tone, lang, bert, style, length_scale, sdp_ratio):
|
||||||
|
return model.infer(
|
||||||
|
x,
|
||||||
|
x_len,
|
||||||
|
sid,
|
||||||
|
tone,
|
||||||
|
lang,
|
||||||
|
bert,
|
||||||
|
style,
|
||||||
|
sdp_ratio=sdp_ratio,
|
||||||
|
length_scale=length_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
model.forward = forward
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
model,
|
||||||
|
(
|
||||||
|
x_tst,
|
||||||
|
x_tst_lengths,
|
||||||
|
torch.LongTensor([0]).to(device),
|
||||||
|
tones,
|
||||||
|
lang_ids,
|
||||||
|
bert,
|
||||||
|
style_vec_tensor,
|
||||||
|
torch.tensor(1.0),
|
||||||
|
torch.tensor(0.0),
|
||||||
|
),
|
||||||
|
f"../models/model_{out_name}.onnx",
|
||||||
|
verbose=True,
|
||||||
|
dynamic_axes={
|
||||||
|
"x_tst": {1: "batch_size"},
|
||||||
|
"x_tst_lengths": {0: "batch_size"},
|
||||||
|
"tones": {1: "batch_size"},
|
||||||
|
"language": {1: "batch_size"},
|
||||||
|
"bert": {2: "batch_size"},
|
||||||
|
},
|
||||||
|
input_names=[
|
||||||
|
"x_tst",
|
||||||
|
"x_tst_lengths",
|
||||||
|
"sid",
|
||||||
|
"tones",
|
||||||
|
"language",
|
||||||
|
"bert",
|
||||||
|
"style_vec",
|
||||||
|
"length_scale",
|
||||||
|
"sdp_ratio",
|
||||||
|
],
|
||||||
|
output_names=["output"],
|
||||||
|
)
|
||||||
3
convert/requirements.txt
Normal file
3
convert/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
style-bert-vits2
|
||||||
|
onnxsim
|
||||||
|
numpy<3
|
||||||
9
docker/cpu.Dockerfile
Normal file
9
docker/cpu.Dockerfile
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
FROM rust AS builder
|
||||||
|
WORKDIR /work
|
||||||
|
COPY . .
|
||||||
|
RUN cargo build -r --bin sbv2_api
|
||||||
|
FROM gcr.io/distroless/cc-debian12
|
||||||
|
WORKDIR /work
|
||||||
|
COPY --from=builder /work/target/release/sbv2_api /work/main
|
||||||
|
COPY --from=builder /work/target/release/*.so /work
|
||||||
|
CMD ["/work/main"]
|
||||||
10
docker/cuda.Dockerfile
Normal file
10
docker/cuda.Dockerfile
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
FROM rust AS builder
|
||||||
|
WORKDIR /work
|
||||||
|
COPY . .
|
||||||
|
RUN cargo build -r --bin sbv2_api -F cuda,cuda_tf32
|
||||||
|
|
||||||
|
FROM nvidia/cuda:12.6.1-cudnn-runtime-ubuntu24.04
|
||||||
|
WORKDIR /work
|
||||||
|
COPY --from=builder /work/target/release/sbv2_api /work/main
|
||||||
|
COPY --from=builder /work/target/release/*.so /work
|
||||||
|
CMD ["/work/main"]
|
||||||
1
docker/run.sh
Normal file
1
docker/run.sh
Normal file
@@ -0,0 +1 @@
|
|||||||
|
docker run -it --rm -p 3000:3000 --name sbv2 -v ./models:/work/models --env-file .env sbv2
|
||||||
BIN
output.wav
BIN
output.wav
Binary file not shown.
@@ -6,7 +6,14 @@ edition = "2021"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
axum = "0.7.5"
|
axum = "0.7.5"
|
||||||
dotenvy = "0.15.7"
|
dotenvy.workspace = true
|
||||||
|
env_logger = "0.11.5"
|
||||||
|
log = "0.4.22"
|
||||||
sbv2_core = { version = "0.1.0", path = "../sbv2_core" }
|
sbv2_core = { version = "0.1.0", path = "../sbv2_core" }
|
||||||
serde = { version = "1.0.210", features = ["derive"] }
|
serde = { version = "1.0.210", features = ["derive"] }
|
||||||
tokio = { version = "1.40.0", features = ["full"] }
|
tokio = { version = "1.40.0", features = ["full"] }
|
||||||
|
|
||||||
|
[features]
|
||||||
|
cuda = ["sbv2_core/cuda"]
|
||||||
|
cuda_tf32 = ["sbv2_core/cuda_tf32"]
|
||||||
|
dynamic = ["sbv2_core/dynamic"]
|
||||||
|
|||||||
@@ -5,58 +5,130 @@ use axum::{
|
|||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
Json, Router,
|
Json, Router,
|
||||||
};
|
};
|
||||||
use sbv2_core::tts::TTSModel;
|
use sbv2_core::tts::TTSModelHolder;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::env;
|
use std::env;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use tokio::fs;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
|
|
||||||
mod error;
|
mod error;
|
||||||
use crate::error::AppResult;
|
use crate::error::AppResult;
|
||||||
|
|
||||||
|
async fn models(State(state): State<AppState>) -> AppResult<impl IntoResponse> {
|
||||||
|
Ok(Json(state.tts_model.lock().await.models()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sdp_default() -> f32 {
|
||||||
|
0.0
|
||||||
|
}
|
||||||
|
|
||||||
|
fn length_default() -> f32 {
|
||||||
|
1.0
|
||||||
|
}
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct SynthesizeRequest {
|
struct SynthesizeRequest {
|
||||||
text: String,
|
text: String,
|
||||||
|
ident: String,
|
||||||
|
#[serde(default = "sdp_default")]
|
||||||
|
sdp_ratio: f32,
|
||||||
|
#[serde(default = "length_default")]
|
||||||
|
length_scale: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn synthesize(
|
async fn synthesize(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<AppState>,
|
||||||
Json(SynthesizeRequest { text }): Json<SynthesizeRequest>,
|
Json(SynthesizeRequest {
|
||||||
|
text,
|
||||||
|
ident,
|
||||||
|
sdp_ratio,
|
||||||
|
length_scale,
|
||||||
|
}): Json<SynthesizeRequest>,
|
||||||
) -> AppResult<impl IntoResponse> {
|
) -> AppResult<impl IntoResponse> {
|
||||||
|
log::debug!("processing request: text={text}, ident={ident}, sdp_ratio={sdp_ratio}, length_scale={length_scale}");
|
||||||
let buffer = {
|
let buffer = {
|
||||||
let mut tts_model = state.tts_model.lock().await;
|
let tts_model = state.tts_model.lock().await;
|
||||||
let tts_model = if let Some(tts_model) = &*tts_model {
|
|
||||||
tts_model
|
|
||||||
} else {
|
|
||||||
*tts_model = Some(TTSModel::new(
|
|
||||||
&env::var("BERT_MODEL_PATH")?,
|
|
||||||
&env::var("MAIN_MODEL_PATH")?,
|
|
||||||
&env::var("STYLE_VECTORS_PATH")?,
|
|
||||||
)?);
|
|
||||||
tts_model.as_ref().unwrap()
|
|
||||||
};
|
|
||||||
let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(&text)?;
|
let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(&text)?;
|
||||||
let style_vector = tts_model.get_style_vector(0, 1.0)?;
|
let style_vector = tts_model.get_style_vector(&ident, 0, 1.0)?;
|
||||||
tts_model.synthesize(bert_ori.to_owned(), phones, tones, lang_ids, style_vector)?
|
tts_model.synthesize(
|
||||||
|
ident,
|
||||||
|
bert_ori.to_owned(),
|
||||||
|
phones,
|
||||||
|
tones,
|
||||||
|
lang_ids,
|
||||||
|
style_vector,
|
||||||
|
sdp_ratio,
|
||||||
|
length_scale,
|
||||||
|
)?
|
||||||
};
|
};
|
||||||
Ok(([(CONTENT_TYPE, "audio/wav")], buffer))
|
Ok(([(CONTENT_TYPE, "audio/wav")], buffer))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
struct AppState {
|
struct AppState {
|
||||||
tts_model: Arc<Mutex<Option<TTSModel>>>,
|
tts_model: Arc<Mutex<TTSModelHolder>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppState {
|
||||||
|
pub async fn new() -> anyhow::Result<Self> {
|
||||||
|
let mut tts_model = TTSModelHolder::new(
|
||||||
|
&fs::read(env::var("BERT_MODEL_PATH")?).await?,
|
||||||
|
&fs::read(env::var("TOKENIZER_PATH")?).await?,
|
||||||
|
)?;
|
||||||
|
let models = env::var("MODELS_PATH").unwrap_or("models".to_string());
|
||||||
|
let mut f = fs::read_dir(&models).await?;
|
||||||
|
let mut entries = vec![];
|
||||||
|
while let Ok(Some(e)) = f.next_entry().await {
|
||||||
|
let name = e.file_name().to_string_lossy().to_string();
|
||||||
|
if name.ends_with(".onnx") && name.starts_with("model_") {
|
||||||
|
let name_len = name.len();
|
||||||
|
let name = name.chars();
|
||||||
|
entries.push(
|
||||||
|
name.collect::<Vec<_>>()[6..name_len - 5]
|
||||||
|
.iter()
|
||||||
|
.collect::<String>(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for entry in entries {
|
||||||
|
log::info!("Try loading: {entry}");
|
||||||
|
let style_vectors_bytes =
|
||||||
|
match fs::read(format!("{models}/style_vectors_{entry}.json")).await {
|
||||||
|
Ok(b) => b,
|
||||||
|
Err(e) => {
|
||||||
|
log::warn!("Error loading style_vectors_bytes from file {entry}: {e}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let vits2_bytes = match fs::read(format!("{models}/model_{entry}.onnx")).await {
|
||||||
|
Ok(b) => b,
|
||||||
|
Err(e) => {
|
||||||
|
log::warn!("Error loading vits2_bytes from file {entry}: {e}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if let Err(e) = tts_model.load(&entry, style_vectors_bytes, vits2_bytes) {
|
||||||
|
log::warn!("Error loading {entry}: {e}");
|
||||||
|
};
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
tts_model: Arc::new(Mutex::new(tts_model)),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
dotenvy::dotenv().ok();
|
dotenvy::dotenv().ok();
|
||||||
|
env_logger::init();
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/", get(|| async { "Hello, World!" }))
|
.route("/", get(|| async { "Hello, World!" }))
|
||||||
.route("/synthesize", post(synthesize))
|
.route("/synthesize", post(synthesize))
|
||||||
.with_state(Arc::new(AppState {
|
.route("/models", get(models))
|
||||||
tts_model: Arc::new(Mutex::new(None)),
|
.with_state(AppState::new().await?);
|
||||||
}));
|
let addr = env::var("ADDR").unwrap_or("0.0.0.0:3000".to_string());
|
||||||
|
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
||||||
let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await?;
|
log::info!("Listening on {addr}");
|
||||||
axum::serve(listener, app).await?;
|
axum::serve(listener, app).await?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -5,14 +5,20 @@ edition = "2021"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
|
dotenvy.workspace = true
|
||||||
hound = "3.5.1"
|
hound = "3.5.1"
|
||||||
jpreprocess = { version = "0.10.0", features = ["naist-jdic"] }
|
jpreprocess = { version = "0.10.0", features = ["naist-jdic"] }
|
||||||
ndarray = "0.16.1"
|
ndarray = "0.16.1"
|
||||||
num_cpus = "1.16.0"
|
num_cpus = "1.16.0"
|
||||||
once_cell = "1.19.0"
|
once_cell = "1.19.0"
|
||||||
ort = { git = "https://github.com/pykeio/ort.git", version = "2.0.0-rc.5" }
|
ort = { git = "https://github.com/pykeio/ort.git", version = "2.0.0-rc.6" }
|
||||||
regex = "1.10.6"
|
regex = "1.10.6"
|
||||||
serde = { version = "1.0.210", features = ["derive"] }
|
serde = { version = "1.0.210", features = ["derive"] }
|
||||||
serde_json = "1.0.128"
|
serde_json = "1.0.128"
|
||||||
thiserror = "1.0.63"
|
thiserror = "1.0.63"
|
||||||
tokenizers = "0.20.0"
|
tokenizers = "0.20.0"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
cuda = ["ort/cuda"]
|
||||||
|
cuda_tf32 = []
|
||||||
|
dynamic = ["ort/load-dynamic"]
|
||||||
@@ -193,4 +193,4 @@ for mora, consonant, vowel in __MORA_LIST_ADDITIONAL:
|
|||||||
|
|
||||||
|
|
||||||
with open("src/mora_list.json", "w") as f:
|
with open("src/mora_list.json", "w") as f:
|
||||||
json.dump(data, f, ensure_ascii=False, indent=4)
|
json.dump(data, f, ensure_ascii=False, indent=4)
|
||||||
@@ -18,6 +18,8 @@ pub enum Error {
|
|||||||
IoError(#[from] std::io::Error),
|
IoError(#[from] std::io::Error),
|
||||||
#[error("hound error: {0}")]
|
#[error("hound error: {0}")]
|
||||||
HoundError(#[from] hound::Error),
|
HoundError(#[from] hound::Error),
|
||||||
|
#[error("model not found error")]
|
||||||
|
ModelNotFoundError(String),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Result<T> = std::result::Result<T, Error>;
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|||||||
@@ -4,13 +4,13 @@ use crate::norm::{replace_punctuation, PUNCTUATIONS};
|
|||||||
use jpreprocess::*;
|
use jpreprocess::*;
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
|
use std::cmp::Reverse;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
|
|
||||||
type JPreprocessType = JPreprocess<DefaultFetcher>;
|
type JPreprocessType = JPreprocess<DefaultFetcher>;
|
||||||
|
|
||||||
fn get_jtalk() -> Result<JPreprocessType> {
|
fn initialize_jtalk() -> Result<JPreprocessType> {
|
||||||
let config = JPreprocessConfig {
|
let config = JPreprocessConfig {
|
||||||
dictionary: SystemDictionaryConfig::Bundled(kind::JPreprocessDictionaryKind::NaistJdic),
|
dictionary: SystemDictionaryConfig::Bundled(kind::JPreprocessDictionaryKind::NaistJdic),
|
||||||
user_dictionary: None,
|
user_dictionary: None,
|
||||||
@@ -50,7 +50,7 @@ pub struct JTalk {
|
|||||||
|
|
||||||
impl JTalk {
|
impl JTalk {
|
||||||
pub fn new() -> Result<Self> {
|
pub fn new() -> Result<Self> {
|
||||||
let jpreprocess = Arc::new(get_jtalk()?);
|
let jpreprocess = Arc::new(initialize_jtalk()?);
|
||||||
Ok(Self { jpreprocess })
|
Ok(Self { jpreprocess })
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,7 +64,7 @@ impl JTalk {
|
|||||||
static KATAKANA_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"[\u30A0-\u30FF]+").unwrap());
|
static KATAKANA_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"[\u30A0-\u30FF]+").unwrap());
|
||||||
static MORA_PATTERN: Lazy<Vec<String>> = Lazy::new(|| {
|
static MORA_PATTERN: Lazy<Vec<String>> = Lazy::new(|| {
|
||||||
let mut sorted_keys: Vec<String> = MORA_KATA_TO_MORA_PHONEMES.keys().cloned().collect();
|
let mut sorted_keys: Vec<String> = MORA_KATA_TO_MORA_PHONEMES.keys().cloned().collect();
|
||||||
sorted_keys.sort_by(|a, b| b.len().cmp(&a.len()));
|
sorted_keys.sort_by_key(|b| Reverse(b.len()));
|
||||||
sorted_keys
|
sorted_keys
|
||||||
});
|
});
|
||||||
static LONG_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"(\w)(ー*)").unwrap());
|
static LONG_PATTERN: Lazy<Regex> = Lazy::new(|| Regex::new(r"(\w)(ー*)").unwrap());
|
||||||
@@ -128,8 +128,8 @@ impl JTalkProcess {
|
|||||||
JTalkProcess::align_tones(phone_w_punct, phone_tone_list_wo_punct)?;
|
JTalkProcess::align_tones(phone_w_punct, phone_tone_list_wo_punct)?;
|
||||||
|
|
||||||
let mut sep_tokenized: Vec<Vec<String>> = Vec::new();
|
let mut sep_tokenized: Vec<Vec<String>> = Vec::new();
|
||||||
for i in 0..seq_text.len() {
|
for seq_text_item in &seq_text {
|
||||||
let text = seq_text[i].clone();
|
let text = seq_text_item.clone();
|
||||||
if !PUNCTUATIONS.contains(&text.as_str()) {
|
if !PUNCTUATIONS.contains(&text.as_str()) {
|
||||||
sep_tokenized.push(text.chars().map(|x| x.to_string()).collect());
|
sep_tokenized.push(text.chars().map(|x| x.to_string()).collect());
|
||||||
} else {
|
} else {
|
||||||
@@ -390,22 +390,3 @@ impl JTalkProcess {
|
|||||||
Ok(phones)
|
Ok(phones)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_tokenizer() -> Result<Tokenizer> {
|
|
||||||
let tokenizer = Tokenizer::from_file("tokenizer.json")?;
|
|
||||||
Ok(tokenizer)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn tokenize(text: &str, tokenizer: &Tokenizer) -> Result<(Vec<i64>, Vec<i64>)> {
|
|
||||||
let mut token_ids = vec![1];
|
|
||||||
let mut attention_masks = vec![1];
|
|
||||||
for content in text.chars() {
|
|
||||||
let token = tokenizer.encode(content.to_string(), false)?;
|
|
||||||
let ids = token.get_ids();
|
|
||||||
token_ids.extend(ids.iter().map(|&x| x as i64));
|
|
||||||
attention_masks.extend(token.get_attention_mask().iter().map(|&x| x as i64));
|
|
||||||
}
|
|
||||||
token_ids.push(2);
|
|
||||||
attention_masks.push(1);
|
|
||||||
Ok((token_ids, attention_masks))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -6,20 +6,6 @@ pub mod mora;
|
|||||||
pub mod nlp;
|
pub mod nlp;
|
||||||
pub mod norm;
|
pub mod norm;
|
||||||
pub mod style;
|
pub mod style;
|
||||||
|
pub mod tokenizer;
|
||||||
pub mod tts;
|
pub mod tts;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
|
||||||
pub fn add(left: usize, right: usize) -> usize {
|
|
||||||
left + right
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn it_works() {
|
|
||||||
let result = add(2, 2);
|
|
||||||
assert_eq!(result, 4);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,20 +1,49 @@
|
|||||||
use sbv2_core::{error, tts};
|
use std::{fs, time::Instant};
|
||||||
|
|
||||||
fn main() -> error::Result<()> {
|
use sbv2_core::tts;
|
||||||
|
use std::env;
|
||||||
|
|
||||||
|
fn main() -> anyhow::Result<()> {
|
||||||
|
dotenvy::dotenv().ok();
|
||||||
let text = "眠たい";
|
let text = "眠たい";
|
||||||
|
let ident = "aaa";
|
||||||
let tts_model = tts::TTSModel::new(
|
let mut tts_holder = tts::TTSModelHolder::new(
|
||||||
"models/debert.onnx",
|
&fs::read(env::var("BERT_MODEL_PATH")?)?,
|
||||||
"models/model_opt.onnx",
|
&fs::read(env::var("TOKENIZER_PATH")?)?,
|
||||||
"models/style_vectors.json",
|
)?;
|
||||||
|
tts_holder.load(
|
||||||
|
ident,
|
||||||
|
fs::read(env::var("STYLE_VECTORS_PATH")?)?,
|
||||||
|
fs::read(env::var("MODEL_PATH")?)?,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let (bert_ori, phones, tones, lang_ids) = tts_model.parse_text(text)?;
|
let (bert_ori, phones, tones, lang_ids) = tts_holder.parse_text(text)?;
|
||||||
|
|
||||||
let style_vector = tts_model.get_style_vector(0, 1.0)?;
|
|
||||||
let data = tts_model.synthesize(bert_ori.to_owned(), phones, tones, lang_ids, style_vector)?;
|
|
||||||
|
|
||||||
|
let style_vector = tts_holder.get_style_vector(ident, 0, 1.0)?;
|
||||||
|
let data = tts_holder.synthesize(
|
||||||
|
ident,
|
||||||
|
bert_ori.to_owned(),
|
||||||
|
phones.clone(),
|
||||||
|
tones.clone(),
|
||||||
|
lang_ids.clone(),
|
||||||
|
style_vector.clone(),
|
||||||
|
0.0,
|
||||||
|
0.5,
|
||||||
|
)?;
|
||||||
std::fs::write("output.wav", data)?;
|
std::fs::write("output.wav", data)?;
|
||||||
|
let now = Instant::now();
|
||||||
|
for _ in 0..10 {
|
||||||
|
tts_holder.synthesize(
|
||||||
|
ident,
|
||||||
|
bert_ori.to_owned(),
|
||||||
|
phones.clone(),
|
||||||
|
tones.clone(),
|
||||||
|
lang_ids.clone(),
|
||||||
|
style_vector.clone(),
|
||||||
|
0.0,
|
||||||
|
1.0,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
println!("Time taken: {}", now.elapsed().as_millis());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,17 +4,29 @@ use ndarray::{array, s, Array1, Array2, Axis};
|
|||||||
use ort::{GraphOptimizationLevel, Session};
|
use ort::{GraphOptimizationLevel, Session};
|
||||||
use std::io::Cursor;
|
use std::io::Cursor;
|
||||||
|
|
||||||
pub fn load_model(model_file: &str) -> Result<Session> {
|
#[allow(clippy::vec_init_then_push)]
|
||||||
let session = Session::builder()?
|
pub fn load_model<P: AsRef<[u8]>>(model_file: P) -> Result<Session> {
|
||||||
|
let mut exp = Vec::new();
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
{
|
||||||
|
let mut cuda = ort::CUDAExecutionProvider::default()
|
||||||
|
.with_conv_algorithm_search(ort::CUDAExecutionProviderCuDNNConvAlgoSearch::Default);
|
||||||
|
#[cfg(feature = "cuda_tf32")]
|
||||||
|
{
|
||||||
|
cuda = cuda.with_tf32(true);
|
||||||
|
}
|
||||||
|
exp.push(cuda.build());
|
||||||
|
}
|
||||||
|
exp.push(ort::CPUExecutionProvider::default().build());
|
||||||
|
Ok(Session::builder()?
|
||||||
|
.with_execution_providers(exp)?
|
||||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||||
.with_intra_threads(1)?
|
|
||||||
.with_intra_threads(num_cpus::get_physical())?
|
.with_intra_threads(num_cpus::get_physical())?
|
||||||
.with_parallel_execution(true)?
|
.with_parallel_execution(true)?
|
||||||
.with_inter_threads(num_cpus::get_physical())?
|
.with_inter_threads(num_cpus::get_physical())?
|
||||||
.commit_from_file(model_file)?;
|
.commit_from_memory(model_file.as_ref())?)
|
||||||
Ok(session)
|
|
||||||
}
|
}
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn synthesize(
|
pub fn synthesize(
|
||||||
session: &Session,
|
session: &Session,
|
||||||
bert_ori: Array2<f32>,
|
bert_ori: Array2<f32>,
|
||||||
@@ -22,6 +34,8 @@ pub fn synthesize(
|
|||||||
tones: Array1<i64>,
|
tones: Array1<i64>,
|
||||||
lang_ids: Array1<i64>,
|
lang_ids: Array1<i64>,
|
||||||
style_vector: Array1<f32>,
|
style_vector: Array1<f32>,
|
||||||
|
sdp_ratio: f32,
|
||||||
|
length_scale: f32,
|
||||||
) -> Result<Vec<u8>> {
|
) -> Result<Vec<u8>> {
|
||||||
let bert = bert_ori.insert_axis(Axis(0));
|
let bert = bert_ori.insert_axis(Axis(0));
|
||||||
let x_tst_lengths: Array1<i64> = array![x_tst.shape()[0] as i64];
|
let x_tst_lengths: Array1<i64> = array![x_tst.shape()[0] as i64];
|
||||||
@@ -36,7 +50,9 @@ pub fn synthesize(
|
|||||||
"tones" => tones,
|
"tones" => tones,
|
||||||
"language" => lang_ids,
|
"language" => lang_ids,
|
||||||
"bert" => bert,
|
"bert" => bert,
|
||||||
"ja_bert" => style_vector,
|
"style_vec" => style_vector,
|
||||||
|
"sdp_ratio" => array![sdp_ratio],
|
||||||
|
"length_scale" => array![length_scale],
|
||||||
}?)?;
|
}?)?;
|
||||||
|
|
||||||
let audio_array = outputs
|
let audio_array = outputs
|
||||||
|
|||||||
@@ -8,8 +8,8 @@ pub struct Data {
|
|||||||
pub data: Vec<Vec<f32>>,
|
pub data: Vec<Vec<f32>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_style(path: &str) -> Result<Array2<f32>> {
|
pub fn load_style<P: AsRef<[u8]>>(path: P) -> Result<Array2<f32>> {
|
||||||
let data: Data = serde_json::from_str(&std::fs::read_to_string(path)?)?;
|
let data: Data = serde_json::from_slice(path.as_ref())?;
|
||||||
Ok(Array2::from_shape_vec(
|
Ok(Array2::from_shape_vec(
|
||||||
data.shape,
|
data.shape,
|
||||||
data.data.iter().flatten().copied().collect(),
|
data.data.iter().flatten().copied().collect(),
|
||||||
@@ -17,7 +17,7 @@ pub fn load_style(path: &str) -> Result<Array2<f32>> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_style_vector(
|
pub fn get_style_vector(
|
||||||
style_vectors: Array2<f32>,
|
style_vectors: &Array2<f32>,
|
||||||
style_id: i32,
|
style_id: i32,
|
||||||
weight: f32,
|
weight: f32,
|
||||||
) -> Result<Array1<f32>> {
|
) -> Result<Array1<f32>> {
|
||||||
|
|||||||
21
sbv2_core/src/tokenizer.rs
Normal file
21
sbv2_core/src/tokenizer.rs
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
use crate::error::Result;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
pub fn get_tokenizer<P: AsRef<[u8]>>(p: P) -> Result<Tokenizer> {
|
||||||
|
let tokenizer = Tokenizer::from_bytes(p)?;
|
||||||
|
Ok(tokenizer)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Encodes `text` one character at a time and returns `(token_ids, attention_masks)`.
///
/// Every `char` of `text` is encoded on its own (with the `false` flag passed
/// to `Tokenizer::encode`), and the resulting ids and attention-mask entries
/// are appended in order. The id sequence is wrapped in a leading `1` and a
/// trailing `2`, and the mask gets a matching `1` at both ends.
///
/// NOTE(review): `1` / `2` are presumably the begin/end special-token ids of
/// the bundled tokenizer — confirm against `tokenizer.json`.
///
/// # Errors
///
/// Returns the converted tokenizer error as soon as any character fails to
/// encode.
pub fn tokenize(text: &str, tokenizer: &Tokenizer) -> Result<(Vec<i64>, Vec<i64>)> {
    let mut ids_out: Vec<i64> = vec![1];
    let mut mask_out: Vec<i64> = vec![1];

    for ch in text.chars() {
        let encoding = tokenizer.encode(ch.to_string(), false)?;
        for &id in encoding.get_ids() {
            ids_out.push(id as i64);
        }
        for &bit in encoding.get_attention_mask() {
            mask_out.push(bit as i64);
        }
    }

    ids_out.push(2);
    mask_out.push(1);
    Ok((ids_out, mask_out))
}
|
||||||
@@ -1,33 +1,91 @@
|
|||||||
use crate::error::Result;
|
use crate::error::{Error, Result};
|
||||||
use crate::{bert, jtalk, model, nlp, norm, style, utils};
|
use crate::{bert, jtalk, model, nlp, norm, style, tokenizer, utils};
|
||||||
use ndarray::{concatenate, s, Array, Array1, Array2, Axis};
|
use ndarray::{concatenate, s, Array, Array1, Array2, Axis};
|
||||||
use ort::Session;
|
use ort::Session;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
/// Name under which a loaded TTS model is registered and looked up.
///
/// A thin newtype around `String`; two identifiers are equal exactly when
/// their underlying strings are equal.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct TTSIdent(String);

impl std::fmt::Display for TTSIdent {
    /// Writes the identifier text.
    ///
    /// Uses `Formatter::pad` rather than `write_str` so that width, fill,
    /// alignment and precision flags (e.g. `{:>8}`) are honored the same way
    /// they are for a plain `str`. Unformatted output (`{}` / `to_string()`)
    /// is unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.0)
    }
}

impl<S> From<S> for TTSIdent
where
    S: AsRef<str>,
{
    /// Creates an identifier by copying the string contents of `value`.
    fn from(value: S) -> Self {
        TTSIdent(value.as_ref().to_owned())
    }
}
|
||||||
|
|
||||||
pub struct TTSModel {
|
pub struct TTSModel {
|
||||||
bert: Session,
|
|
||||||
vits2: Session,
|
vits2: Session,
|
||||||
style_vectors: Array2<f32>,
|
style_vectors: Array2<f32>,
|
||||||
|
ident: TTSIdent,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct TTSModelHolder {
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
bert: Session,
|
||||||
|
models: Vec<TTSModel>,
|
||||||
jtalk: jtalk::JTalk,
|
jtalk: jtalk::JTalk,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TTSModel {
|
impl TTSModelHolder {
|
||||||
pub fn new(
|
pub fn new<P: AsRef<[u8]>>(bert_model_bytes: P, tokenizer_bytes: P) -> Result<Self> {
|
||||||
bert_model_path: &str,
|
let bert = model::load_model(bert_model_bytes)?;
|
||||||
main_model_path: &str,
|
|
||||||
style_vector_path: &str,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let bert = model::load_model(bert_model_path)?;
|
|
||||||
let vits2 = model::load_model(main_model_path)?;
|
|
||||||
let style_vectors = style::load_style(style_vector_path)?;
|
|
||||||
let jtalk = jtalk::JTalk::new()?;
|
let jtalk = jtalk::JTalk::new()?;
|
||||||
Ok(TTSModel {
|
let tokenizer = tokenizer::get_tokenizer(tokenizer_bytes)?;
|
||||||
|
Ok(TTSModelHolder {
|
||||||
bert,
|
bert,
|
||||||
vits2,
|
models: vec![],
|
||||||
style_vectors,
|
|
||||||
jtalk,
|
jtalk,
|
||||||
|
tokenizer,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn models(&self) -> Vec<String> {
|
||||||
|
self.models.iter().map(|m| m.ident.to_string()).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load<I: Into<TTSIdent>, P: AsRef<[u8]>>(
|
||||||
|
&mut self,
|
||||||
|
ident: I,
|
||||||
|
style_vectors_bytes: P,
|
||||||
|
vits2_bytes: P,
|
||||||
|
) -> Result<()> {
|
||||||
|
let ident = ident.into();
|
||||||
|
if self.find_model(ident.clone()).is_err() {
|
||||||
|
self.models.push(TTSModel {
|
||||||
|
vits2: model::load_model(vits2_bytes)?,
|
||||||
|
style_vectors: style::load_style(style_vectors_bytes)?,
|
||||||
|
ident,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Removes the model registered under `ident`.
///
/// Returns `true` if a model with that identifier was found and removed,
/// `false` if no such model was loaded. Only the first match is removed.
pub fn unload<I: Into<TTSIdent>>(&mut self, ident: I) -> bool {
    let target = ident.into();
    match self.models.iter().position(|m| m.ident == target) {
        Some(index) => {
            self.models.remove(index);
            true
        }
        None => false,
    }
}
|
||||||
|
|
||||||
|
#[allow(clippy::type_complexity)]
|
||||||
pub fn parse_text(
|
pub fn parse_text(
|
||||||
&self,
|
&self,
|
||||||
text: &str,
|
text: &str,
|
||||||
@@ -40,13 +98,11 @@ impl TTSModel {
|
|||||||
let phones = utils::intersperse(&phones, 0);
|
let phones = utils::intersperse(&phones, 0);
|
||||||
let tones = utils::intersperse(&tones, 0);
|
let tones = utils::intersperse(&tones, 0);
|
||||||
let lang_ids = utils::intersperse(&lang_ids, 0);
|
let lang_ids = utils::intersperse(&lang_ids, 0);
|
||||||
for i in 0..word2ph.len() {
|
for item in &mut word2ph {
|
||||||
word2ph[i] *= 2;
|
*item *= 2;
|
||||||
}
|
}
|
||||||
word2ph[0] += 1;
|
word2ph[0] += 1;
|
||||||
|
let (token_ids, attention_masks) = tokenizer::tokenize(&normalized_text, &self.tokenizer)?;
|
||||||
let tokenizer = jtalk::get_tokenizer()?;
|
|
||||||
let (token_ids, attention_masks) = jtalk::tokenize(&normalized_text, &tokenizer)?;
|
|
||||||
|
|
||||||
let bert_content = bert::predict(&self.bert, token_ids, attention_masks)?;
|
let bert_content = bert::predict(&self.bert, token_ids, attention_masks)?;
|
||||||
|
|
||||||
@@ -58,9 +114,9 @@ impl TTSModel {
|
|||||||
);
|
);
|
||||||
|
|
||||||
let mut phone_level_feature = vec![];
|
let mut phone_level_feature = vec![];
|
||||||
for i in 0..word2ph.len() {
|
for (i, reps) in word2ph.iter().enumerate() {
|
||||||
let repeat_feature = {
|
let repeat_feature = {
|
||||||
let (reps_rows, reps_cols) = (word2ph[i], 1);
|
let (reps_rows, reps_cols) = (*reps, 1);
|
||||||
let arr_len = bert_content.slice(s![i, ..]).len();
|
let arr_len = bert_content.slice(s![i, ..]).len();
|
||||||
|
|
||||||
let mut results: Array2<f32> =
|
let mut results: Array2<f32> =
|
||||||
@@ -92,25 +148,44 @@ impl TTSModel {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_style_vector(&self, style_id: i32, weight: f32) -> Result<Array1<f32>> {
|
fn find_model<I: Into<TTSIdent>>(&self, ident: I) -> Result<&TTSModel> {
|
||||||
style::get_style_vector(self.style_vectors.clone(), style_id, weight)
|
let ident = ident.into();
|
||||||
|
self.models
|
||||||
|
.iter()
|
||||||
|
.find(|m| m.ident == ident)
|
||||||
|
.ok_or(Error::ModelNotFoundError(ident.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn synthesize(
|
pub fn get_style_vector<I: Into<TTSIdent>>(
|
||||||
&self,
|
&self,
|
||||||
|
ident: I,
|
||||||
|
style_id: i32,
|
||||||
|
weight: f32,
|
||||||
|
) -> Result<Array1<f32>> {
|
||||||
|
style::get_style_vector(&self.find_model(ident)?.style_vectors, style_id, weight)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn synthesize<I: Into<TTSIdent>>(
|
||||||
|
&self,
|
||||||
|
ident: I,
|
||||||
bert_ori: Array2<f32>,
|
bert_ori: Array2<f32>,
|
||||||
phones: Array1<i64>,
|
phones: Array1<i64>,
|
||||||
tones: Array1<i64>,
|
tones: Array1<i64>,
|
||||||
lang_ids: Array1<i64>,
|
lang_ids: Array1<i64>,
|
||||||
style_vector: Array1<f32>,
|
style_vector: Array1<f32>,
|
||||||
|
sdp_ratio: f32,
|
||||||
|
length_scale: f32,
|
||||||
) -> Result<Vec<u8>> {
|
) -> Result<Vec<u8>> {
|
||||||
let buffer = model::synthesize(
|
let buffer = model::synthesize(
|
||||||
&self.vits2,
|
&self.find_model(ident)?.vits2,
|
||||||
bert_ori.to_owned(),
|
bert_ori.to_owned(),
|
||||||
phones,
|
phones,
|
||||||
tones,
|
tones,
|
||||||
lang_ids,
|
lang_ids,
|
||||||
style_vector,
|
style_vector,
|
||||||
|
sdp_ratio,
|
||||||
|
length_scale,
|
||||||
)?;
|
)?;
|
||||||
Ok(buffer)
|
Ok(buffer)
|
||||||
}
|
}
|
||||||
|
|||||||
12
test.py
12
test.py
@@ -1,8 +1,8 @@
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
res = requests.post(
|
||||||
res = requests.post('http://localhost:3000/synthesize', json={
|
"http://localhost:3001/synthesize",
|
||||||
"text": "初めて神戸に移り住んだ時に地元の人に教わった「阪急はオシャレして乗らなあかん。阪神はスリッパで乗っていい。JRは早い。」、好きすぎていまだに東京の人に説明するとき使ってる。"
|
json={"text": "おはようございます", "ident": "tsukuyomi"},
|
||||||
})
|
)
|
||||||
with open('output.wav', 'wb') as f:
|
with open("output.wav", "wb") as f:
|
||||||
f.write(res.content)
|
f.write(res.content)
|
||||||
|
|||||||
22116
tokenizer.json
22116
tokenizer.json
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user