# trainer.py
# Training script for the complexNDM model (JAX/Flax model, Optax optimizer,
# PyTorch DataLoader for batching).
import argparse
import os
import time

import flax
import optax
from flax.training import train_state
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from utils import *
from models import *
from data import *
def parse_arguments():
    """Parse the training script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace: parsed options with attributes ``seed``,
        ``prediction_length``, ``estimation_length``, ``num_samples``,
        ``down_rate``, ``hidden_size``, ``output_size``, ``num_epochs``,
        ``layer_num``, ``phase``, ``r_max``, ``r_min`` and ``scan``.
    """
    # (flag, type, default, help) -- listed in the order shown by --help.
    # A help of None is argparse's own default, i.e. "no help text".
    valued_options = (
        ('--seed', int, 1111, 'random seed'),
        ('--prediction_length', int, 8, None),
        ('--estimation_length', int, 128, None),
        ('--num_samples', int, 200000, None),
        ('--down_rate', int, 8, None),
        ('--hidden_size', int, 32, 'hidden state space size'),
        ('--output_size', int, 4, None),
        ('--num_epochs', int, 300, None),
        ('--layer_num', int, 2, 'number of hidden layers of f_0 and f_u'),
        ('--phase', float, 0.314, 'phase range of eigenvalues'),
        ('--r_max', float, 1.0, None),
        ('--r_min', float, 0.9, None),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in valued_options:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    # Boolean switch: False unless --scan is given on the command line.
    cli.add_argument('--scan', action="store_true", help='parallel or serial')
    return cli.parse_args()
def trainer(arguments):
    """Train a complexNDM model, report validation/test errors, save the params.

    Loads the train/valid/test splits via ``loadData``, fits the model with
    AdamW under a warmup + cosine-decay learning-rate schedule, prints the
    validation RMSE after every epoch and the test MSE / RMSE / max-abs error
    after training, then serializes the final parameters to disk.

    Args:
        arguments: Parsed command-line namespace. The attributes read here are
            ``prediction_length``, ``estimation_length``, ``num_samples``,
            ``down_rate``, ``hidden_size``, ``output_size``, ``layer_num``,
            ``r_min``, ``r_max``, ``scan``, ``phase``, ``seed`` and
            ``num_epochs``.

    Side effects:
        Prints progress to stdout and writes the serialized parameters to
        ``/root/autodl-tmp/Jax_ComplexNDM/checkpoints/best_model.flax``.
    """
    train, valid, test = loadData(
        (arguments.prediction_length, arguments.estimation_length),
        arguments.num_samples,
        arguments.down_rate,
    )
    print("train shape: ", train.shape)
    print("valid shape: ", valid.shape)
    print("test shape: ", test.shape)

    train = torch.tensor(train, dtype=torch.float32)
    batch_size = 1024
    split = arguments.prediction_length

    # Every sample is cut at `split` along axis 1:
    #   x1: the first `split` steps of columns 10:, flattened per sample
    #   x2: the remaining steps of columns 0:10
    #   y : the remaining steps of columns 10:
    # NOTE(review): presumably columns 0:10 are exogenous inputs and columns
    # 10: are the predicted signals -- confirm against loadData.
    dataset = TensorDataset(
        train[:, 0:split, 10:].reshape(train.shape[0], -1),
        train[:, split:, 0:10],
        train[:, split:, 10:],
    )
    train_data = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    x_valid = (valid[:, 0:split, 10:].reshape(valid.shape[0], -1),
               valid[:, split:, 0:10])
    y_valid = valid[:, split:, 10:]
    x_test = (test[:, 0:split, 10:].reshape(test.shape[0], -1),
              test[:, split:, 0:10])
    y_test = test[:, split:, 10:]

    model = complexNDM(hidden_size=arguments.hidden_size,
                       output_size=arguments.output_size,
                       layer_num=arguments.layer_num,
                       sigma_min=arguments.r_min,
                       sigma_max=arguments.r_max,
                       scan=arguments.scan,
                       phase=arguments.phase)
    rng = jax.random.PRNGKey(arguments.seed)

    # Use the real length of the second input's time axis rather than a
    # hard-coded 128, so initialization matches the data for any
    # --estimation_length / --prediction_length combination.
    sequence_length = dataset.tensors[1].shape[1]
    dummy_input = (
        jnp.ones((1, arguments.prediction_length * arguments.output_size)),
        jnp.ones((1, sequence_length, 10)),
    )
    params = model.init(rng, dummy_input)
    print(model.tabulate(rng, dummy_input))

    # The DataLoader keeps the last partial batch, so the number of optimizer
    # steps per epoch is len(train_data) (a ceiling division), not
    # len(dataset) // batch_size. The floor version ends the schedule early
    # and evaluates to 0 steps when the dataset is smaller than one batch.
    steps_per_epoch = len(train_data)
    total_steps = arguments.num_epochs * steps_per_epoch
    warmup_steps = int(0.1 * total_steps)  # schedule step counts are integers
    schedule = optax.schedules.warmup_cosine_decay_schedule(
        init_value=1e-7,
        peak_value=2e-4,
        warmup_steps=warmup_steps,
        decay_steps=total_steps,
        end_value=1e-7,
    )
    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optax.adamw(schedule),
    )

    def loss_fn(params, x, y):
        """Return the prediction loss plus a weighted hidden-state smoothness term."""
        y_pred, hidden_states = model.apply(params, x)
        loss_total = jnp.mean(smoothl1loss(y, y_pred))

        # Penalize differences between consecutive entries along axis 0 of the
        # hidden states.
        # NOTE(review): presumably axis 0 is the time axis -- confirm in models.
        diff = jnp.abs(hidden_states[0:-1] - hidden_states[1:])
        loss_smth = jnp.mean(smoothl1loss(diff, jnp.zeros_like(diff)))

        # `ratio` is treated as a constant by autodiff, so the added term
        # always has the value loss_total / 10 while its gradient is that of
        # loss_smth scaled by 1 / (10 * ratio).
        # NOTE: this is 0 / 0 (NaN) if loss_smth is exactly zero.
        ratio = jax.lax.stop_gradient(jnp.divide(loss_smth, loss_total))
        return loss_total + loss_smth / (10 * ratio)

    @jax.jit
    def train_step(state, x, y):
        """Apply one optimizer update and return the new state and batch loss."""
        def batch_loss(params):
            return loss_fn(params, x, y)

        loss_value, grads = jax.value_and_grad(batch_loss)(state.params)
        state = state.apply_gradients(grads=grads)
        return state, loss_value

    for epoch in range(arguments.num_epochs):
        epoch_loss_avg = 0.0
        print(f"Epoch {epoch + 1}/{arguments.num_epochs}\n---------------")
        with tqdm(train_data, desc="Training", unit="batch") as tqdm_bar:
            start_time = time.time()
            for x1, x2, y in tqdm_bar:
                # Hand the torch batches to JAX via NumPy.
                x1 = jnp.array(x1.numpy())
                x2 = jnp.array(x2.numpy())
                y = jnp.array(y.numpy())
                state, current_loss = train_step(state, (x1, x2), y)
                epoch_loss_avg += current_loss
                current_lr = schedule(state.step)
                tqdm_bar.set_postfix(loss=current_loss, lr=current_lr)
            end_time = time.time()

        epoch_loss_avg /= steps_per_epoch
        batch_avg_time = (end_time - start_time) / steps_per_epoch
        print("Epoch Avg Loss: {:.5f}".format(epoch_loss_avg))
        print("Batch Avg Time: {:.3f} s".format(batch_avg_time))

        # Validation RMSE, with predictions and targets both scaled by 100.
        validations, _ = model.apply(state.params, x_valid)
        valid_loss = jnp.sqrt(jnp.mean(jnp.square(100 * validations - 100 * y_valid)))
        print("Valid Loss RMSE: {:.4f}".format(valid_loss) + '\n')

    # Final evaluation on the held-out test split (same x100 scaling).
    predictions, _ = model.apply(state.params, x_test)
    test_loss = jnp.mean(jnp.square(100 * predictions - 100 * y_test))
    l_max = jnp.max(jnp.abs(100 * predictions - 100 * y_test))
    print("Test Loss MSE: {:.4f}".format(test_loss))
    print("Test Loss RMSE: {:.4f}".format(jnp.sqrt(test_loss)))
    print("Test Loss l_max: {:.4f}".format(l_max))

    # These are the parameters after the last epoch; no best-on-validation
    # selection happens despite the file name.
    checkpoint_path = '/root/autodl-tmp/Jax_ComplexNDM/checkpoints/best_model.flax'
    # Create the directory first so a missing folder cannot discard the
    # result of the whole training run.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    bytes_output = flax.serialization.to_bytes(state.params)
    with open(checkpoint_path, 'wb') as f:
        f.write(bytes_output)
def main():
    """Script entry point: parse the CLI options, seed, and run training.

    Calls ``seed_random`` with the parsed ``--seed`` value, prints whether
    ``--scan`` was given, then hands the parsed namespace to ``trainer``.
    """
    args = parse_arguments()
    seed_random(args.seed)
    print(f'Scan: {args.scan}')
    trainer(args)


# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()