mirror of
https://github.com/neodyland/sbv2-api.git
synced 2026-05-25 10:10:37 +00:00
refactor
This commit is contained in:
50
scripts/convert/convert_deberta.py
Normal file
50
scripts/convert/convert_deberta.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from transformers.convert_slow_tokenizer import BertConverter
|
||||
from style_bert_vits2.nlp import bert_models
|
||||
from style_bert_vits2.constants import Languages
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import torch
|
||||
from torch import nn
|
||||
from argparse import ArgumentParser
|
||||
|
||||
# --- Command line ------------------------------------------------------------
# One option only: the identifier (or path) of the model to convert.
parser = ArgumentParser()
parser.add_argument(
    "--model",
    default="ku-nlp/deberta-v2-large-japanese-char-wwm",
)
args = parser.parse_args()
model_name = args.model

# --- Tokenizer export --------------------------------------------------------
# The first call registers/loads the Japanese tokenizer for `model_name`; the
# second call, made without a model name, is what actually hands the tokenizer
# back (presumably from a cache inside bert_models -- confirm against
# style_bert_vits2).
bert_models.load_tokenizer(Languages.JP, model_name)
loaded_tokenizer = bert_models.load_tokenizer(Languages.JP)

# Run the tokenizer through BertConverter and write the converted result out.
# NOTE: the output path is relative to the current working directory.
tokenizer = BertConverter(loaded_tokenizer).converted()
tokenizer.save("../models/tokenizer.json")
|
||||
|
||||
|
||||
class ORTDeberta(nn.Module):
    """Wrap a pretrained masked-LM so it returns a single hidden-state tensor.

    ``forward`` accepts the three tokenizer tensors positionally, which lets
    the module be passed to ``torch.onnx.export`` with a plain tuple of inputs.
    """

    def __init__(self, model_name):
        """Load the pretrained masked-LM weights identified by *model_name*."""
        super().__init__()
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Run the wrapped model and return its third-from-last hidden state.

        Only index 0 along the first dimension is kept, and the result is
        moved to the CPU before being returned.
        """
        outputs = self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # The slice [-3:-2] holds exactly one entry (third from the end). The
        # concatenation over the last axis is kept as-is so the computation
        # performed here stays the same as before.
        picked_layers = outputs["hidden_states"][-3:-2]
        merged = torch.cat(picked_layers, -1)
        return merged[0].cpu()
|
||||
|
||||
|
||||
# Build the export wrapper and tokenize one sample sentence; the resulting
# tensors are the example inputs torch.onnx.export traces the model with.
model = ORTDeberta(model_name)
inputs = AutoTokenizer.from_pretrained(model_name)(
    "今日はいい天気ですね", return_tensors="pt"
)

# Export to ONNX. NOTE: the output path is relative to the current working
# directory.
torch.onnx.export(
    model,
    (inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]),
    "../models/deberta.onnx",
    input_names=["input_ids", "token_type_ids", "attention_mask"],
    output_names=["output"],
    verbose=True,
    # Axis 1 must be dynamic for *all three* inputs: they come from the same
    # tokenizer call, so leaving token_type_ids out would pin it to the sample
    # sentence's length while the other two inputs are allowed to vary.
    # NOTE(review): axis 1 is presumably the token axis of the tokenizer
    # output, so the label "batch_size" is a misnomer; it is kept unchanged so
    # the dimension name in the exported model stays the same as before.
    dynamic_axes={
        "input_ids": {1: "batch_size"},
        "token_type_ids": {1: "batch_size"},
        "attention_mask": {1: "batch_size"},
    },
)
|
||||
Reference in New Issue
Block a user