-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathonnx_export.py
74 lines (70 loc) · 2.59 KB
/
onnx_export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
import time
import numpy as np
import onnx
from onnxsim import simplify
import onnxruntime as ort
import onnxoptimizer
import torch
from model_onnx import SynthesizerTrn
import utils
from hubert import hubert_model_onnx
import torch_directml
# DirectML device handle, created once at import time (importing this module
# therefore initialises torch_directml). main() moves the models and the dummy
# export inputs onto this device before calling torch.onnx.export.
dml = torch_directml.device()
def _export_hubert(device):
    """Export the HuBERT-soft model loaded from ``hubert/model.pt`` to ``hubert3.0.onnx``.

    The graph is traced with a random ``(1, 1, 16000)`` input. Axis 2 of the
    ``source`` input is exported as the dynamic ``sample_length`` axis.
    """
    hubert_soft = hubert_model_onnx.hubert_soft("hubert/model.pt")
    test_input = torch.rand(1, 1, 16000)
    torch.onnx.export(hubert_soft.to(device),
                      test_input.to(device),
                      "hubert3.0.onnx",
                      dynamic_axes={
                          "source": {
                              2: "sample_length"
                          }
                      },
                      verbose=False,
                      opset_version=13,
                      input_names=["source"],
                      output_names=["embed"])


def _export_synthesizer(path, device):
    """Export the SynthesizerTrn checkpoint in ``checkpoints/<path>/`` to ONNX.

    Reads ``config.json`` and ``model.pth`` from that directory and writes
    ``model.onnx`` next to them.
    """
    hps = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
    net = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model)
    utils.load_checkpoint(f"checkpoints/{path}/model.pth", net, None)
    net.eval().to(device)
    # Export-only use: freeze every parameter (same effect as looping over
    # net.parameters() and clearing requires_grad on each one).
    net.requires_grad_(False)

    # Dummy inputs only fix the traced shapes/dtypes. They are created in the
    # same order as before (hidden_unit, lengths, pitch, sid).
    dummy_inputs = (
        torch.rand(1, 50, 256).to(device),   # hidden_unit
        torch.LongTensor([50]).to(device),   # lengths
        torch.rand(1, 50).to(device),        # pitch
        torch.LongTensor([0]).to(device),    # sid
    )
    # NOTE(review): only hidden_unit declares a dynamic axis 0; lengths, pitch
    # and sid keep a fixed first dimension of 1 -- confirm this is intended.
    torch.onnx.export(net,
                      dummy_inputs,
                      f"checkpoints/{path}/model.onnx",
                      dynamic_axes={
                          "hidden_unit": [0, 1],
                          "pitch": [1]
                      },
                      do_constant_folding=False,
                      opset_version=16,
                      verbose=False,
                      input_names=["hidden_unit", "lengths", "pitch", "sid"],
                      output_names=["audio"])


def main(HubertExport, NetExport, *, path="NyaruTaffy", device=None):
    """Export the selected models to ONNX.

    Args:
        HubertExport: if truthy, export HuBERT-soft to ``hubert3.0.onnx``.
        NetExport: if truthy, export the synthesizer checkpoint in
            ``checkpoints/<path>/`` to ``checkpoints/<path>/model.onnx``.
        path: checkpoint directory name under ``checkpoints/``. Defaults to
            the previously hard-coded ``"NyaruTaffy"``.
        device: torch device used for tracing. ``None`` (the default) means
            the module-level DirectML device ``dml``, as before.
    """
    if device is None:
        device = dml
    if HubertExport:
        _export_hubert(device)
    if NetExport:
        _export_synthesizer(path, device)
if __name__ == '__main__':
    # Script entry point: skip the HuBERT export, export only the synthesizer.
    main(HubertExport=False, NetExport=True)