# Mirror of https://github.com/neodyland/sbv2-api.git (synced 2025-12-25 00:29:57 +00:00)
from argparse import ArgumentParser
import os

import torch
from torch import nn
from transformers import AutoModelForMaskedLM, AutoTokenizer
from transformers.convert_slow_tokenizer import BertConverter

from style_bert_vits2.constants import Languages
from style_bert_vits2.nlp import bert_models

# Select the Japanese BERT/DeBERTa model to export (defaults to the model
# used by Style-Bert-VITS2).
parser = ArgumentParser()
parser.add_argument("--model", default="ku-nlp/deberta-v2-large-japanese-char-wwm")
args = parser.parse_args()
model_name = args.model

# Register the slow tokenizer with style_bert_vits2, convert it to a
# Hugging Face "fast" tokenizer via BertConverter, and save it as
# tokenizer.json for downstream use.
bert_models.load_tokenizer(Languages.JP, model_name)
tokenizer = bert_models.load_tokenizer(Languages.JP)
converter = BertConverter(tokenizer)
tokenizer = converter.converted()
tokenizer.save("../../models/tokenizer.json")
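
# Optional sanity check (not in the upstream script): reload the saved fast
# tokenizer with the `tokenizers` library and confirm it tokenizes a sample
# sentence, so a broken conversion is caught before the ONNX export below.
from tokenizers import Tokenizer

_tok = Tokenizer.from_file("../../models/tokenizer.json")
_enc = _tok.encode("今日はいい天気ですね")
assert len(_enc.ids) > 0, "converted tokenizer produced no tokens"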


class ORTDeberta(nn.Module):
    """Wraps the masked-LM model with a plain-tensor signature so that
    torch.onnx.export can trace it, returning hidden-state features
    rather than MLM logits."""

    def __init__(self, model_name):
        super(ORTDeberta, self).__init__()
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)

    def forward(self, input_ids, token_type_ids, attention_mask):
        inputs = {
            "input_ids": input_ids,
            "token_type_ids": token_type_ids,
            "attention_mask": attention_mask,
        }
        res = self.model(**inputs, output_hidden_states=True)
        # hidden_states[-3:-2] is a one-element slice, so the cat is a no-op
        # that selects the third-from-last hidden layer; [0] drops the batch
        # dimension, giving a (seq_len, hidden_size) feature matrix.
        res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
        return res


model = ORTDeberta(model_name)
# Tokenize a short Japanese sample ("Nice weather today, isn't it?") to use
# as example inputs for tracing during export.
inputs = AutoTokenizer.from_pretrained(model_name)(
    "今日はいい天気ですね", return_tensors="pt"
)
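
# Optional eager-mode smoke test (not in the upstream script): run the
# wrapper once in PyTorch to confirm the forward pass works and that the
# features have the expected (seq_len, hidden_size) shape.
with torch.no_grad():
    _feats = model(
        inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]
    )
print("eager feature shape:", tuple(_feats.shape))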

# Export to ONNX. Note: the dynamic axis marked here is axis 1, which for
# (batch, seq_len) inputs is the sequence-length dimension, even though it
# is labelled "batch_size"; the label is only metadata in the ONNX graph.
torch.onnx.export(
    model,
    (inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]),
    "../../models/deberta.onnx",
    input_names=["input_ids", "token_type_ids", "attention_mask"],
    output_names=["output"],
    verbose=True,
    dynamic_axes={"input_ids": {1: "batch_size"}, "attention_mask": {1: "batch_size"}},
)
# Shrink the exported graph in place with onnx-simplifier.
os.system("onnxsim ../../models/deberta.onnx ../../models/deberta.onnx")
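
# Optional verification sketch (assumes onnxruntime is installed; not part
# of the upstream script): load the simplified model and run it on the same
# example inputs to confirm it executes and matches the eager output shape.
import onnxruntime as ort

_sess = ort.InferenceSession(
    "../../models/deberta.onnx", providers=["CPUExecutionProvider"]
)
(_onnx_feats,) = _sess.run(
    None,
    {
        "input_ids": inputs["input_ids"].numpy(),
        "token_type_ids": inputs["token_type_ids"].numpy(),
        "attention_mask": inputs["attention_mask"].numpy(),
    },
)
print("onnx feature shape:", _onnx_feats.shape)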