Mirror of https://github.com/neodyland/sbv2-api.git (synced 2025-12-22 23:49:58 +00:00)
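
# Convert a Style-Bert-VITS2 model to ONNX and bundle it, together with its style
# vectors, into a single .sbv2 archive under ../models/.
# Example invocation (script name and input paths are illustrative):
#   python convert_model.py \
#       --style_file model_assets/<name>/style_vectors.npy \
#       --config_file model_assets/<name>/config.json \
#       --model_file model_assets/<name>/<name>.safetensors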
import json
import os
from argparse import ArgumentParser
from io import BytesIO
from tarfile import open as taropen, TarInfo

import numpy as np
import torch
from zstandard import ZstdCompressor

from style_bert_vits2.constants import (
    DEFAULT_ASSIST_TEXT_WEIGHT,
    DEFAULT_STYLE,
    DEFAULT_STYLE_WEIGHT,
    Languages,
)
from style_bert_vits2.models.hyper_parameters import HyperParameters
from style_bert_vits2.models.infer import get_net_g, get_text
from style_bert_vits2.nlp import bert_models
from style_bert_vits2.tts_model import TTSModel

# Paths to the style vectors (.npy), the hyperparameter config (.json)
# and the trained model weights are taken from the command line.
parser = ArgumentParser()
parser.add_argument("--style_file", required=True)
parser.add_argument("--config_file", required=True)
parser.add_argument("--model_file", required=True)
args = parser.parse_args()
style_file = args.style_file
config_file = args.config_file
model_file = args.model_file

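# Load the Japanese DeBERTa (char-wwm) model and tokenizer used by get_text()
# below to compute BERT features.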
bert_models.load_model(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")
bert_models.load_tokenizer(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")

# Re-export the style vectors as JSON so they can be bundled alongside the ONNX model.
array = np.load(style_file)
data = array.tolist()
hyper_parameters = HyperParameters.load_from_json(config_file)
out_name = hyper_parameters.model_name

with open(f"../models/style_vectors_{out_name}.json", "w") as f:
    json.dump(
        {
            "data": data,
            "shape": array.shape,
        },
        f,
    )
text = "今日はいい天気ですね。"  # dummy sentence ("Nice weather today, isn't it?")

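# Build example inputs from the dummy sentence: phonemes, tones, language IDs and
# BERT features. torch.onnx.export() later traces the model with these tensors.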
bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
    text,
    Languages.JP,
    hyper_parameters,
    "cpu",
    assist_text=None,
    assist_text_weight=DEFAULT_ASSIST_TEXT_WEIGHT,
    given_phone=None,
    given_tone=None,
)

# The TTSModel instance is only used here to resolve the default style name to its ID.
tts_model = TTSModel(
    model_path=model_file,
    config_path=config_file,
    style_vec_path=style_file,
    device="cpu",
)
device = "cpu"
style_id = tts_model.style2id[DEFAULT_STYLE]

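# Blend the selected style vector toward the mean vector stored in row 0 of the
# style file: weight 0 yields the mean, weight 1 the raw style vector.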
def get_style_vector(style_id, weight):
    style_vectors = np.load(style_file)
    mean = style_vectors[0]
    style_vec = style_vectors[style_id]
    style_vec = mean + (style_vec - mean) * weight
    return style_vec


style_vector = get_style_vector(style_id, DEFAULT_STYLE_WEIGHT)

# Add a batch dimension to every example input and move it to the export device.
# (ja_bert and en_bert are prepared but not among the exported inputs below.)
x_tst = phones.to(device).unsqueeze(0)
tones = tones.to(device).unsqueeze(0)
lang_ids = lang_ids.to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
ja_bert = ja_bert.to(device).unsqueeze(0)
en_bert = en_bert.to(device).unsqueeze(0)
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
style_vec_tensor = torch.from_numpy(style_vector).to(device).unsqueeze(0)

model = get_net_g(
    model_file,
    hyper_parameters.version,
    device,
    hyper_parameters,
)

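# torch.onnx.export() traces a module's forward(); wrap model.infer() in a flat
# positional signature matching the example-input tuple passed to the exporter below.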
def forward(x, x_len, sid, tone, lang, bert, style, length_scale, sdp_ratio):
    return model.infer(
        x,
        x_len,
        sid,
        tone,
        lang,
        bert,
        style,
        sdp_ratio=sdp_ratio,
        length_scale=length_scale,
    )


model.forward = forward

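# Export with a single dummy utterance. Length axes are declared dynamic so the graph
# accepts arbitrary phoneme-sequence lengths; length_scale and sdp_ratio are passed as
# scalar tensors and exposed as graph inputs.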
torch.onnx.export(
    model,
    (
        x_tst,
        x_tst_lengths,
        torch.LongTensor([0]).to(device),  # speaker ID 0
        tones,
        lang_ids,
        bert,
        style_vec_tensor,
        torch.tensor(1.0),  # length_scale
        torch.tensor(0.0),  # sdp_ratio
    ),
    f"../models/model_{out_name}.onnx",
    verbose=True,
    dynamic_axes={
        "x_tst": {0: "batch_size", 1: "x_tst_max_length"},
        "x_tst_lengths": {0: "batch_size"},
        "sid": {0: "batch_size"},
        "tones": {0: "batch_size", 1: "x_tst_max_length"},
        "language": {0: "batch_size", 1: "x_tst_max_length"},
        "bert": {0: "batch_size", 2: "x_tst_max_length"},
        "style_vec": {0: "batch_size"},
    },
    input_names=[
        "x_tst",
        "x_tst_lengths",
        "sid",
        "tones",
        "language",
        "bert",
        "style_vec",
        "length_scale",
        "sdp_ratio",
    ],
    output_names=["output"],
)
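# Simplify the exported graph in place with onnxsim, then package everything as .sbv2:
# a zstd-compressed tar containing version.txt, model.onnx and style_vectors.json.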
os.system(f"onnxsim ../models/model_{out_name}.onnx ../models/model_{out_name}.onnx")
|
|
onnxfile = open(f"../models/model_{out_name}.onnx", "rb").read()
|
|
stylefile = open(f"../models/style_vectors_{out_name}.json", "rb").read()
|
|
version = bytes("1", "utf8")
|
|
with taropen(f"../models/tmp_{out_name}.sbv2tar", "w") as w:
|
|
|
|
def add_tar(f, b):
|
|
t = TarInfo(f)
|
|
t.size = len(b)
|
|
w.addfile(t, BytesIO(b))
|
|
|
|
add_tar("version.txt", version)
|
|
add_tar("model.onnx", onnxfile)
|
|
add_tar("style_vectors.json", stylefile)
|
|
open(f"../models/{out_name}.sbv2", "wb").write(
|
|
ZstdCompressor(threads=-1, level=22).compress(
|
|
open(f"../models/tmp_{out_name}.sbv2tar", "rb").read()
|
|
)
|
|
)
|
|
os.unlink(f"../models/tmp_{out_name}.sbv2tar")
|