# test.py
# -*- coding: utf-8 -*-
import time
import numpy as np
import torch
import h5py
import functools
import math
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
from datetime import datetime
import parameters
from scipy import io
import sys
sys.path.append("..")
# ---------------------------------------------------------------------------
# Model setup: choose the compute device, build the network and restore the
# weights of the checkpoint selected in the `parameters` module.
# ---------------------------------------------------------------------------
print('current time:', datetime.now())

## GPU
# Fall back to the CPU when no CUDA device is visible so the script still runs.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

## import net
from source_transformation import wavenet2d as myNet

net = myNet()
# Checkpoint file is "<result_path><epoch>.pkl".
checkpoint_path = parameters.result_path + str(parameters.test_checkpoint_epoch) + '.pkl'
# map_location remaps the stored tensors onto `device`; without it a checkpoint
# saved from a GPU cannot be loaded on a CPU-only machine.
net.load_state_dict(torch.load(checkpoint_path, map_location=device))
net = net.to(device)
# This script only runs inference: switch dropout / batch-norm layers (if the
# network has any) to evaluation behaviour.
net.eval()
### test
# Sizes are read from the project-local `parameters` module.
time_span = parameters.timespan_input  # time-axis length of the network input
time_span2 = parameters.timespan       # number of time samples copied from X
receiver = parameters.receiver         # width of the stitched prediction built later
trace = parameters.trace               # number of columns in one tile
testnum = parameters.testnum           # number of tiles cut from X

## load testdata
# The context manager closes the HDF5 file even if reading a dataset fails.
with h5py.File("./data/marmousi_testdata_shot77.h5", "r") as f:
    X = f['X'][:]
    Y = f['Y'][:]

# Cut X into `testnum` side-by-side tiles of `trace` columns each and stack
# them as a (testnum, 1, time_span, trace) batch. Only the first `time_span2`
# rows of every tile are filled; any remaining rows stay zero.
# NOTE(review): assumes X is 2-D with shape (time_span2, >= testnum * trace)
# -- the reshape below raises if a tile has a different size.
Xinput = np.zeros([testnum, 1, time_span, trace])
for i in range(testnum):
    tile = X[:, trace * i:trace * i + trace]
    Xinput[i, 0, :time_span2, :] = tile.reshape((time_span2, trace))
# Run the network on the whole batch of tiles without tracking gradients.
with torch.no_grad():
    # .float().to(device) works on CPU and GPU alike; casting to
    # torch.cuda.FloatTensor fails when `device` is the CPU fallback.
    Xt = torch.from_numpy(Xinput).float().to(device)
    Youtput = net(Xt).cpu().numpy()

# Stitch the per-tile predictions back together, left to right, into one
# (time_span2, receiver) array.
# NOTE(review): assumes Youtput[i][0] has shape (time_span2, trace) and that
# testnum * trace <= receiver -- confirm against the network definition.
Y_hat = np.zeros([time_span2, receiver])
for i in range(testnum):
    Y_hat[:, i * trace:i * trace + trace] = Youtput[i][0]
## plot
def L2_loss(a, b):
    """Return the mean squared difference between two equally shaped arrays.

    NOTE(review): the script called ``L2_loss`` without defining or importing
    it; the mean squared error is assumed here -- confirm it matches the loss
    used during training.
    """
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    return float(np.mean(diff ** 2))


def _panel(arr, sample_id):
    """Return a 2-D image: ``arr`` itself if 2-D, otherwise ``arr[sample_id]``."""
    return arr if arr.ndim == 2 else arr[sample_id]


extent = [0, 1, 1, 0]
sample_id = parameters.sample_id_test

# Y_hat is always a single 2-D panel, so it is plotted whole; X and Y are
# indexed by sample only when they carry a leading sample axis.
x_img = _panel(X, sample_id)
y_img = _panel(Y, sample_id)
y_hat_img = Y_hat

# Symmetric colour range shared by all three panels.
# NOTE(review): the original referenced an undefined name `dat`; the largest
# absolute amplitude of the target is assumed -- adjust if a tighter clip is
# wanted.
cmax = np.max(np.abs(y_img))
cmin = -cmax
colour = 'gray'

panels = [
    ('X, L2_loss=%.4f' % (L2_loss(x_img, y_img)), x_img),
    ('y_hat, L2_loss=%.4f' % (L2_loss(y_hat_img, y_img)), y_hat_img),
    ('Y', y_img),
]

plt.figure(figsize=(14, 5))
for position, (title, image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(title, fontsize=18)
    plt.imshow(image, vmax=cmax, vmin=cmin, extent=extent, cmap=colour)
    plt.yticks(size=15)
    plt.xticks(size=15)
plt.tight_layout()

# An empty file name gives no usable output path; save next to the checkpoint
# instead. NOTE(review): the file name is an assumption -- rename as needed.
figure_path = parameters.result_path + 'test_' + str(parameters.test_checkpoint_epoch) + '.png'
plt.savefig(figure_path)
plt.show()