mirror of https://github.com/neodyland/sbv2-api.git
synced 2026-01-09 07:52:57 +00:00
convert/convert_deberta.py (new file, 50 lines)
@@ -0,0 +1,50 @@
from transformers.convert_slow_tokenizer import BertConverter
from style_bert_vits2.nlp import bert_models
from style_bert_vits2.constants import Languages
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
from torch import nn
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--model", default="ku-nlp/deberta-v2-large-japanese-char-wwm")
args = parser.parse_args()
model_name = args.model

# Register the slow tokenizer with style_bert_vits2, then convert it to a
# fast (Rust) tokenizer and save it as a standalone tokenizer.json.
bert_models.load_tokenizer(Languages.JP, model_name)
tokenizer = bert_models.load_tokenizer(Languages.JP)
converter = BertConverter(tokenizer)
tokenizer = converter.converted()
tokenizer.save("../models/tokenizer.json")
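
A quick way to sanity-check the converted tokenizer is to load the saved tokenizer.json back with the tokenizers package (pulled in by transformers); a minimal sketch, not part of the script above:

# Sketch: confirm the fast tokenizer round-trips a Japanese string.
from tokenizers import Tokenizer

tok = Tokenizer.from_file("../models/tokenizer.json")
enc = tok.encode("今日はいい天気ですね")
print(enc.ids)     # token ids the ONNX DeBERTa will consume
print(enc.tokens)  # char-level pieces, per the char-wwm vocabulary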


class ORTDeberta(nn.Module):
    """Wraps the masked-LM model so forward() returns a single hidden-state
    tensor, keeping the exported ONNX graph's output simple."""

    def __init__(self, model_name):
        super(ORTDeberta, self).__init__()
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)

    def forward(self, input_ids, token_type_ids, attention_mask):
        inputs = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask,
        }
        res = self.model(**inputs, output_hidden_states=True)
        # hidden_states[-3:-2] is a one-element slice, so the concatenation
        # simply selects the third-from-last layer; [0] drops the batch dim.
        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
        return res


model = ORTDeberta(model_name)
inputs = AutoTokenizer.from_pretrained(model_name)(
    "今日はいい天気ですね", return_tensors="pt"
)

# Trace the wrapper with one sample sentence; axis 1 of the token inputs is
# marked dynamic so the graph accepts arbitrary sequence lengths.
torch.onnx.export(
    model,
    (inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]),
    "../models/deberta.onnx",
    input_names=["input_ids", "token_type_ids", "attention_mask"],
    output_names=["output"],
    verbose=True,
    dynamic_axes={"input_ids": {1: "batch_size"}, "attention_mask": {1: "batch_size"}},
)
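
The exported graph can be smoke-tested with onnxruntime; a minimal sketch (onnxruntime is an assumed extra dependency, not listed in convert/requirements.txt):

# Sketch: run deberta.onnx and inspect the hidden-state output.
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ku-nlp/deberta-v2-large-japanese-char-wwm")
enc = tokenizer("今日はいい天気ですね", return_tensors="np")

sess = ort.InferenceSession("../models/deberta.onnx")
(hidden,) = sess.run(
    ["output"],
    {
        "input_ids": enc["input_ids"].astype(np.int64),
        "token_type_ids": enc["token_type_ids"].astype(np.int64),
        "attention_mask": enc["attention_mask"].astype(np.int64),
    },
)
print(hidden.shape)  # (sequence_length, hidden_size) after the [0] in forward()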

convert/convert_model.py (new file, 129 lines)
@@ -0,0 +1,129 @@
import json
from argparse import ArgumentParser

import numpy as np
import torch

from style_bert_vits2.nlp import bert_models
from style_bert_vits2.models.infer import get_net_g, get_text
from style_bert_vits2.models.hyper_parameters import HyperParameters
from style_bert_vits2.constants import (
    DEFAULT_ASSIST_TEXT_WEIGHT,
    DEFAULT_STYLE,
    DEFAULT_STYLE_WEIGHT,
    Languages,
)
from style_bert_vits2.tts_model import TTSModel

parser = ArgumentParser()
parser.add_argument("--style_file", required=True)
parser.add_argument("--config_file", required=True)
parser.add_argument("--model_file", required=True)
args = parser.parse_args()
style_file = args.style_file
config_file = args.config_file
model_file = args.model_file

# The Japanese BERT front end is required by get_text below.
bert_models.load_model(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")
bert_models.load_tokenizer(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")

# Re-serialize the style vectors (a .npy array) as JSON so the runtime can
# read them without a NumPy dependency.
array = np.load(style_file)
data = array.tolist()
hyper_parameters = HyperParameters.load_from_json(config_file)
out_name = hyper_parameters.model_name

with open(f"../models/style_vectors_{out_name}.json", "w") as f:
    json.dump(
        {
            "data": data,
            "shape": array.shape,
        },
        f,
    )

# Build one set of sample inputs to trace the generator with.
text = "今日はいい天気ですね。"

bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
    text,
    Languages.JP,
    hyper_parameters,
    "cpu",
    assist_text=None,
    assist_text_weight=DEFAULT_ASSIST_TEXT_WEIGHT,
    given_phone=None,
    given_tone=None,
)
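
Before tracing the export it can help to confirm what get_text produced; a small sketch over the variables above (the exact shapes are assumptions from this codebase's conventions, not guaranteed):

# Sketch: print shapes of the sample inputs prepared by get_text.
for name, t in [
    ("phones", phones),
    ("tones", tones),
    ("lang_ids", lang_ids),
    ("bert", bert),
    ("ja_bert", ja_bert),
    ("en_bert", en_bert),
]:
    print(name, tuple(t.shape), t.dtype)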

tts_model = TTSModel(
    model_path=model_file,
    config_path=config_file,
    style_vec_path=style_file,
    device="cpu",
)
device = "cpu"
style_id = tts_model.style2id[DEFAULT_STYLE]


def get_style_vector(style_id, weight):
    # Blend the chosen style with the mean style (row 0) by `weight`.
    style_vectors = np.load(style_file)
    mean = style_vectors[0]
    style_vec = style_vectors[style_id]
    style_vec = mean + (style_vec - mean) * weight
    return style_vec
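
The weight in get_style_vector linearly interpolates between the mean style and the chosen style, and extrapolates past it for weights above 1; a tiny standalone illustration:

# Sketch: weight 0.0 yields the mean, 1.0 the raw style vector, and
# larger weights exaggerate the style direction.
import numpy as np

mean = np.array([0.0, 0.0])
style = np.array([1.0, 2.0])
for w in (0.0, 1.0, 2.0):
    print(w, mean + (style - mean) * w)
# 0.0 [0. 0.]
# 1.0 [1. 2.]
# 2.0 [2. 4.]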


style_vector = get_style_vector(style_id, DEFAULT_STYLE_WEIGHT)

# Add a batch dimension to every input, mirroring what infer() expects.
x_tst = phones.to(device).unsqueeze(0)
tones = tones.to(device).unsqueeze(0)
lang_ids = lang_ids.to(device).unsqueeze(0)
bert = bert.to(device).unsqueeze(0)
ja_bert = ja_bert.to(device).unsqueeze(0)
en_bert = en_bert.to(device).unsqueeze(0)
x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
style_vec_tensor = torch.from_numpy(style_vector).to(device).unsqueeze(0)

model = get_net_g(
    model_file,
    hyper_parameters.version,
    device,
    hyper_parameters,
)


def forward(*args):
    return model.infer(*args)


# torch.onnx.export traces Module.forward, so route it to infer().
model.forward = forward

torch.onnx.export(
    model,
    (
        x_tst,
        x_tst_lengths,
        torch.LongTensor([0]).to(device),  # speaker id (sid)
        tones,
        lang_ids,
        bert,
        style_vec_tensor,
    ),
    f"../models/model_{out_name}.onnx",
    verbose=True,
    # The axis names are arbitrary labels; what matters is which axes stay dynamic.
    dynamic_axes={
        "x_tst": {1: "batch_size"},
        "x_tst_lengths": {0: "batch_size"},
        "tones": {1: "batch_size"},
        "language": {1: "batch_size"},
        "bert": {2: "batch_size"},
    },
    input_names=[
        "x_tst",
        "x_tst_lengths",
        "sid",
        "tones",
        "language",
        "bert",
        "style_vec",
    ],
    output_names=["output"],
)
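
As with the BERT graph, the exported generator can be smoke-tested with onnxruntime by reusing the tensors prepared above (a sketch; onnxruntime is an assumed extra dependency, and the first output is taken on the assumption that it is the waveform, as in infer()):

# Sketch: feed the exported generator the same sample inputs.
import onnxruntime as ort

sess = ort.InferenceSession(f"../models/model_{out_name}.onnx")
outputs = sess.run(
    ["output"],
    {
        "x_tst": x_tst.numpy(),
        "x_tst_lengths": x_tst_lengths.numpy(),
        "sid": torch.LongTensor([0]).numpy(),
        "tones": tones.numpy(),
        "language": lang_ids.numpy(),
        "bert": bert.numpy(),
        "style_vec": style_vec_tensor.numpy(),
    },
)
print(outputs[0].shape)  # assumed to be the raw audio tensor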

convert/requirements.txt (new file, 3 lines)
@@ -0,0 +1,3 @@
style-bert-vits2
onnxsim
numpy<2
@@ -193,4 +193,4 @@ for mora, consonant, vowel in __MORA_LIST_ADDITIONAL:
 
 
 with open("src/mora_list.json", "w") as f:
-    json.dump(data, f, ensure_ascii=False, indent=4)
+    json.dump(data, f, ensure_ascii=False, indent=4)
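
To confirm the generated table is valid JSON from the Python side, a minimal sketch:

# Sketch: load the mora table back and count its entries.
import json

with open("src/mora_list.json") as f:
    mora_data = json.load(f)
print(len(mora_data))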

test.py (12 lines)
@@ -1,8 +1,8 @@
 import requests
 
-res = requests.post('http://localhost:3000/synthesize', json={
-    "text": "おはようございます",
-    "ident": "tsukuyomi"
-})
-with open('output.wav', 'wb') as f:
-    f.write(res.content)
+res = requests.post(
+    "http://localhost:3000/synthesize",
+    json={"text": "おはようございます", "ident": "tsukuyomi"},
+)
+with open("output.wav", "wb") as f:
+    f.write(res.content)
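
The reformatted client still writes the body even on an error response; a slightly defensive variant of the same request (a sketch, same endpoint and payload):

# Sketch: fail loudly instead of writing an error body to output.wav.
import requests

res = requests.post(
    "http://localhost:3000/synthesize",
    json={"text": "おはようございます", "ident": "tsukuyomi"},
    timeout=30,
)
res.raise_for_status()
with open("output.wav", "wb") as f:
    f.write(res.content)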