diff --git a/convert/convert_deberta.py b/convert/convert_deberta.py
new file mode 100644
index 0000000..2c8ccb9
--- /dev/null
+++ b/convert/convert_deberta.py
@@ -0,0 +1,50 @@
+from transformers.convert_slow_tokenizer import BertConverter
+from style_bert_vits2.nlp import bert_models
+from style_bert_vits2.constants import Languages
+from transformers import AutoModelForMaskedLM, AutoTokenizer
+import torch
+from torch import nn
+from argparse import ArgumentParser
+
+parser = ArgumentParser()
+parser.add_argument("--model", default="ku-nlp/deberta-v2-large-japanese-char-wwm")
+args = parser.parse_args()
+model_name = args.model
+
+bert_models.load_tokenizer(Languages.JP, model_name)
+tokenizer = bert_models.load_tokenizer(Languages.JP)
+converter = BertConverter(tokenizer)
+tokenizer = converter.converted()
+tokenizer.save("../models/tokenizer.json")
+
+
+class ORTDeberta(nn.Module):
+    def __init__(self, model_name):
+        super(ORTDeberta, self).__init__()
+        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
+
+    def forward(self, input_ids, token_type_ids, attention_mask):
+        inputs = {
+            "input_ids": input_ids,
+            "token_type_ids": token_type_ids,
+            "attention_mask": attention_mask,
+        }
+        res = self.model(**inputs, output_hidden_states=True)
+        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
+        return res
+
+
+model = ORTDeberta(model_name)
+inputs = AutoTokenizer.from_pretrained(model_name)(
+    "今日はいい天気ですね", return_tensors="pt"
+)
+
+torch.onnx.export(
+    model,
+    (inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]),
+    "../models/deberta.onnx",
+    input_names=["input_ids", "token_type_ids", "attention_mask"],
+    output_names=["output"],
+    verbose=True,
+    dynamic_axes={"input_ids": {1: "batch_size"}, "attention_mask": {1: "batch_size"}},
+)
diff --git a/convert/convert_model.py b/convert/convert_model.py
new file mode 100644
index 0000000..9ec97b7
--- /dev/null
+++ b/convert/convert_model.py
@@ -0,0 +1,129 @@
+import numpy as np
+import json
+from style_bert_vits2.nlp import bert_models
+from style_bert_vits2.constants import Languages
+from style_bert_vits2.models.infer import get_net_g, get_text
+from style_bert_vits2.models.hyper_parameters import HyperParameters
+import torch
+from style_bert_vits2.constants import (
+    DEFAULT_ASSIST_TEXT_WEIGHT,
+    DEFAULT_STYLE,
+    DEFAULT_STYLE_WEIGHT,
+    Languages,
+)
+from style_bert_vits2.tts_model import TTSModel
+import numpy as np
+from argparse import ArgumentParser
+
+parser = ArgumentParser()
+parser.add_argument("--style_file", required=True)
+parser.add_argument("--config_file", required=True)
+parser.add_argument("--model_file", required=True)
+args = parser.parse_args()
+style_file = args.style_file
+config_file = args.config_file
+model_file = args.model_file
+
+bert_models.load_model(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")
+bert_models.load_tokenizer(Languages.JP, "ku-nlp/deberta-v2-large-japanese-char-wwm")
+
+array = np.load(style_file)
+data = array.tolist()
+hyper_parameters = HyperParameters.load_from_json(config_file)
+out_name = hyper_parameters.model_name
+
+with open(f"../models/style_vectors_{out_name}.json", "w") as f:
+    json.dump(
+        {
+            "data": data,
+            "shape": array.shape,
+        },
+        f,
+    )
+text = "今日はいい天気ですね。"
+
+bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
+    text,
+    Languages.JP,
+    hyper_parameters,
+    "cpu",
+    assist_text=None,
+    assist_text_weight=DEFAULT_ASSIST_TEXT_WEIGHT,
+    given_phone=None,
+    given_tone=None,
+)
+
+tts_model = TTSModel(
+    model_path=model_file,
+    config_path=config_file,
+    style_vec_path=style_file,
+    device="cpu",
+)
+device = "cpu"
+style_id = tts_model.style2id[DEFAULT_STYLE]
+
+
+def get_style_vector(style_id, weight):
+    style_vectors = np.load(style_file)
+    mean = style_vectors[0]
+    style_vec = style_vectors[style_id]
+    style_vec = mean + (style_vec - mean) * weight
+    return style_vec
+
+
+style_vector = get_style_vector(style_id, DEFAULT_STYLE_WEIGHT)
+
+x_tst = phones.to(device).unsqueeze(0)
+tones = tones.to(device).unsqueeze(0)
+lang_ids = lang_ids.to(device).unsqueeze(0)
+bert = bert.to(device).unsqueeze(0)
+ja_bert = ja_bert.to(device).unsqueeze(0)
+en_bert = en_bert.to(device).unsqueeze(0)
+x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
+style_vec_tensor = torch.from_numpy(style_vector).to(device).unsqueeze(0)
+
+model = get_net_g(
+    model_file,
+    hyper_parameters.version,
+    device,
+    hyper_parameters,
+)
+
+
+def forward(*args):
+    return model.infer(*args)
+
+
+model.forward = forward
+
+torch.onnx.export(
+    model,
+    (
+        x_tst,
+        x_tst_lengths,
+        torch.LongTensor([0]).to(device),
+        tones,
+        lang_ids,
+        bert,
+        style_vec_tensor,
+    ),
+    f"../models/model_{out_name}.onnx",
+    verbose=True,
+    dynamic_axes={
+        "x_tst": {1: "batch_size"},
+        "x_tst_lengths": {0: "batch_size"},
+        "tones": {1: "batch_size"},
+        "language": {1: "batch_size"},
+        "bert": {2: "batch_size"},
+    },
+    input_names=[
+        "x_tst",
+        "x_tst_lengths",
+        "sid",
+        "tones",
+        "language",
+        "bert",
+        "style_vec",
+    ],
+    output_names=["output"],
)
diff --git a/convert/requirements.txt b/convert/requirements.txt
new file mode 100644
index 0000000..069ec48
--- /dev/null
+++ b/convert/requirements.txt
@@ -0,0 +1,3 @@
+style-bert-vits2
+onnxsim
+numpy<2
\ No newline at end of file
diff --git a/sbv2_core/mora_convert.py b/sbv2_core/mora_convert.py
index 7e946ee..512729c 100644
--- a/sbv2_core/mora_convert.py
+++ b/sbv2_core/mora_convert.py
@@ -193,4 +193,4 @@ for mora, consonant, vowel in __MORA_LIST_ADDITIONAL:
 
 
 with open("src/mora_list.json", "w") as f:
-    json.dump(data, f, ensure_ascii=False, indent=4)
\ No newline at end of file
+    json.dump(data, f, ensure_ascii=False, indent=4)
diff --git a/test.py b/test.py
index 2ec36c9..907c2a4 100644
--- a/test.py
+++ b/test.py
@@ -1,8 +1,8 @@
 import requests
 
-res = requests.post('http://localhost:3000/synthesize', json={
-    "text": "おはようございます",
-    "ident": "tsukuyomi"
-})
-with open('output.wav', 'wb') as f:
-    f.write(res.content)
\ No newline at end of file
+res = requests.post(
+    "http://localhost:3000/synthesize",
+    json={"text": "おはようございます", "ident": "tsukuyomi"},
+)
+with open("output.wav", "wb") as f:
+    f.write(res.content)
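As a quick sanity check on the DeBERTa export, something like the following could be run after convert_deberta.py. This is a minimal sketch, not part of the patch: it assumes onnxruntime is installed (it is not listed in convert/requirements.txt) and that ../models/deberta.onnx exists; the input and output names are taken from the torch.onnx.export call above.

```python
# Hypothetical smoke test for the exported DeBERTa graph (assumes
# onnxruntime is available and convert_deberta.py has already run).
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ku-nlp/deberta-v2-large-japanese-char-wwm")
session = ort.InferenceSession("../models/deberta.onnx")

# Same probe sentence the converter traces with.
enc = tokenizer("今日はいい天気ですね", return_tensors="np")
hidden = session.run(
    ["output"],
    {
        "input_ids": enc["input_ids"],
        "token_type_ids": enc["token_type_ids"],
        "attention_mask": enc["attention_mask"],
    },
)[0]
print(hidden.shape)  # (sequence_length, hidden_size) for the sliced hidden state
```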
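convert/requirements.txt pulls in onnxsim, but neither converter script invokes it, so simplification is presumably a separate post-export step. A hedged sketch of what that step might look like; the file path comes from convert_deberta.py, and the in-place overwrite is an assumption, not something the patch does:

```python
# Possible simplification pass using onnxsim from convert/requirements.txt.
import onnx
from onnxsim import simplify

model = onnx.load("../models/deberta.onnx")
simplified, ok = simplify(model)
assert ok, "onnxsim could not validate the simplified graph"
onnx.save(simplified, "../models/deberta.onnx")  # overwrite is an assumption
```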