Files
sbv2-api/scripts/convert/convert_deberta.py
2025-02-22 08:26:59 +00:00

52 lines
1.8 KiB
Python

from transformers.convert_slow_tokenizer import BertConverter
from style_bert_vits2.nlp import bert_models
from style_bert_vits2.constants import Languages
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
from torch import nn
from argparse import ArgumentParser
import os
parser = ArgumentParser()
parser.add_argument("--model", default="ku-nlp/deberta-v2-large-japanese-char-wwm")
args = parser.parse_args()
model_name = args.model
bert_models.load_tokenizer(Languages.JP, model_name)
tokenizer = bert_models.load_tokenizer(Languages.JP)
converter = BertConverter(tokenizer)
tokenizer = converter.converted()
tokenizer.save("../../models/tokenizer.json")
class ORTDeberta(nn.Module):
def __init__(self, model_name):
super(ORTDeberta, self).__init__()
self.model = AutoModelForMaskedLM.from_pretrained(model_name)
def forward(self, input_ids, token_type_ids, attention_mask):
inputs = {
"input_ids": input_ids,
"token_type_ids": token_type_ids,
"attention_mask": attention_mask,
}
res = self.model(**inputs, output_hidden_states=True)
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()
return res
model = ORTDeberta(model_name)
inputs = AutoTokenizer.from_pretrained(model_name)(
"今日はいい天気ですね", return_tensors="pt"
)
torch.onnx.export(
model,
(inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]),
"../../models/deberta.onnx",
input_names=["input_ids", "token_type_ids", "attention_mask"],
output_names=["output"],
verbose=True,
dynamic_axes={"input_ids": {1: "batch_size"}, "attention_mask": {1: "batch_size"}},
)
os.system("onnxsim ../../models/deberta.onnx ../../models/deberta.onnx")