fix: build script

This commit is contained in:
Googlefan
2025-02-22 08:26:59 +00:00
parent a11e57d175
commit 5b364a3c10
2 changed files with 13 additions and 11 deletions

View File

@@ -5,6 +5,7 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch
from torch import nn
from argparse import ArgumentParser
import os
parser = ArgumentParser()
parser.add_argument("--model", default="ku-nlp/deberta-v2-large-japanese-char-wwm")
@@ -15,7 +16,7 @@ bert_models.load_tokenizer(Languages.JP, model_name)
tokenizer = bert_models.load_tokenizer(Languages.JP)
converter = BertConverter(tokenizer)
tokenizer = converter.converted()
tokenizer.save("../models/tokenizer.json")
tokenizer.save("../../models/tokenizer.json")
class ORTDeberta(nn.Module):
@@ -42,9 +43,10 @@ inputs = AutoTokenizer.from_pretrained(model_name)(
torch.onnx.export(
model,
(inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]),
"../models/deberta.onnx",
"../../models/deberta.onnx",
input_names=["input_ids", "token_type_ids", "attention_mask"],
output_names=["output"],
verbose=True,
dynamic_axes={"input_ids": {1: "batch_size"}, "attention_mask": {1: "batch_size"}},
)
os.system("onnxsim ../../models/deberta.onnx ../../models/deberta.onnx")