mirror of https://github.com/neodyland/sbv2-api.git (synced 2025-12-22 23:49:58 +00:00)
fix: build script
@@ -5,6 +5,7 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
 import torch
 from torch import nn
 from argparse import ArgumentParser
+import os

 parser = ArgumentParser()
 parser.add_argument("--model", default="ku-nlp/deberta-v2-large-japanese-char-wwm")
@@ -15,7 +16,7 @@ bert_models.load_tokenizer(Languages.JP, model_name)
 tokenizer = bert_models.load_tokenizer(Languages.JP)
 converter = BertConverter(tokenizer)
 tokenizer = converter.converted()
-tokenizer.save("../models/tokenizer.json")
+tokenizer.save("../../models/tokenizer.json")


 class ORTDeberta(nn.Module):
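Side note (not part of the commit): the tokenizer.json written above comes from the Hugging Face tokenizers converter, so it should load back with that library. A minimal sketch, assuming the file ends up under the new ../../models path:

# Illustrative round-trip check only; path and sample text are placeholders.
from tokenizers import Tokenizer

tok = Tokenizer.from_file("../../models/tokenizer.json")
print(tok.encode("こんにちは").ids)  # ids produced by the converted tokenizer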
@@ -42,9 +43,10 @@ inputs = AutoTokenizer.from_pretrained(model_name)(
 torch.onnx.export(
     model,
     (inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"]),
-    "../models/deberta.onnx",
+    "../../models/deberta.onnx",
     input_names=["input_ids", "token_type_ids", "attention_mask"],
     output_names=["output"],
     verbose=True,
     dynamic_axes={"input_ids": {1: "batch_size"}, "attention_mask": {1: "batch_size"}},
 )
+os.system("onnxsim ../../models/deberta.onnx ../../models/deberta.onnx")
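For reference, the exported and onnxsim-simplified deberta.onnx can be sanity-checked with onnxruntime against the input/output names declared in the export call above. This is only an illustrative sketch, not code from the repository:

# Hypothetical smoke test: list the graph's inputs/outputs and confirm the names
# match the export call (input_ids / token_type_ids / attention_mask -> output).
import onnxruntime as ort

session = ort.InferenceSession("../../models/deberta.onnx")
for inp in session.get_inputs():
    print(inp.name, inp.shape)  # note: only dim 1 of input_ids/attention_mask was exported as dynamic
print([out.name for out in session.get_outputs()])  # expected: ["output"]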
@@ -36,7 +36,7 @@ data = array.tolist()
 hyper_parameters = HyperParameters.load_from_json(config_file)
 out_name = hyper_parameters.model_name

-with open(f"../models/style_vectors_{out_name}.json", "w") as f:
+with open(f"../../models/style_vectors_{out_name}.json", "w") as f:
     json.dump(
         {
             "data": data,
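The style-vector file written above wraps array.tolist() under a "data" key, so it can be read back into a NumPy array. A minimal sketch, with out_name as a placeholder:

# Illustrative loader; only the "data" key visible in the diff is assumed.
import json

import numpy as np

out_name = "example"  # placeholder model name
with open(f"../../models/style_vectors_{out_name}.json") as f:
    vectors = np.array(json.load(f)["data"], dtype=np.float32)
print(vectors.shape)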
@@ -127,7 +127,7 @@ torch.onnx.export(
         torch.tensor(0.6777),
         torch.tensor(0.8),
     ),
-    f"../models/model_{out_name}.onnx",
+    f"../../models/model_{out_name}.onnx",
     verbose=True,
     dynamic_axes={
         "x_tst": {0: "batch_size", 1: "x_tst_max_length"},
@@ -153,11 +153,11 @@ torch.onnx.export(
     ],
     output_names=["output"],
 )
-os.system(f"onnxsim ../models/model_{out_name}.onnx ../models/model_{out_name}.onnx")
-onnxfile = open(f"../models/model_{out_name}.onnx", "rb").read()
-stylefile = open(f"../models/style_vectors_{out_name}.json", "rb").read()
+os.system(f"onnxsim ../../models/model_{out_name}.onnx ../../models/model_{out_name}.onnx")
+onnxfile = open(f"../../models/model_{out_name}.onnx", "rb").read()
+stylefile = open(f"../../models/style_vectors_{out_name}.json", "rb").read()
 version = bytes("1", "utf8")
-with taropen(f"../models/tmp_{out_name}.sbv2tar", "w") as w:
+with taropen(f"../../models/tmp_{out_name}.sbv2tar", "w") as w:

     def add_tar(f, b):
         t = TarInfo(f)
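The body of add_tar is cut off by this hunk; a typical helper of this shape would set the member size on the TarInfo and write the raw bytes, roughly like the hypothetical sketch below (names and details are assumptions, not the repository's code):

# Hypothetical stand-in for the truncated add_tar helper.
from io import BytesIO
from tarfile import TarFile, TarInfo

def add_tar_sketch(w: TarFile, name: str, payload: bytes) -> None:
    info = TarInfo(name)
    info.size = len(payload)  # tar members need an explicit size
    w.addfile(info, BytesIO(payload))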
@@ -167,9 +167,9 @@ with taropen(f"../models/tmp_{out_name}.sbv2tar", "w") as w:
     add_tar("version.txt", version)
     add_tar("model.onnx", onnxfile)
     add_tar("style_vectors.json", stylefile)
-open(f"../models/{out_name}.sbv2", "wb").write(
+open(f"../../models/{out_name}.sbv2", "wb").write(
     ZstdCompressor(threads=-1, level=22).compress(
-        open(f"../models/tmp_{out_name}.sbv2tar", "rb").read()
+        open(f"../../models/tmp_{out_name}.sbv2tar", "rb").read()
     )
 )
-os.unlink(f"../models/tmp_{out_name}.sbv2tar")
+os.unlink(f"../../models/tmp_{out_name}.sbv2tar")
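Taken together, the packaging steps above tar version.txt, model.onnx and style_vectors.json, compress the tar with zstd level 22, write the result as {out_name}.sbv2, and delete the temporary tar. A rough sketch of the inverse (unpacking) operation, assuming only what the script writes and that ZstdCompressor comes from the zstandard package:

# Hypothetical unpacker for a .sbv2 produced by the script above; not part of this commit.
import io
from tarfile import open as taropen

from zstandard import ZstdDecompressor

out_name = "example"  # placeholder model name
compressed = open(f"../../models/{out_name}.sbv2", "rb").read()
tar_bytes = ZstdDecompressor().decompress(compressed)  # frame from one-shot compress() carries its size

with taropen(fileobj=io.BytesIO(tar_bytes), mode="r") as tar:
    version = tar.extractfile("version.txt").read().decode("utf8")
    model_onnx = tar.extractfile("model.onnx").read()
    style_vectors = tar.extractfile("style_vectors.json").read()
print(version, len(model_onnx), len(style_vectors))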