-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpostprocessing.py
65 lines (45 loc) · 2.4 KB
/
postprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Name: postprocessing
plot function
Requirement:
interpolationroutine, numpy, matplotlib
Inputs:
predicted results
true results
Output:
error, predicted, true plots
"""
from interpolationroutine import westernAustraliaLocal
import numpy as np
import matplotlib.pyplot as plt
# Index window (eta_rho, xi_rho) of the sub-region saved under "sharkbay" names.
_SHARK_BAY = {"eta_rho": slice(380, 500), "xi_rho": slice(350, 460)}


def _grid_shape(ds):
    """Return the (eta_rho, xi_rho) shape of *ds*, falling back to 640 x 480.

    The fallback is the grid size this module originally hard-coded.
    """
    sizes = getattr(ds, "sizes", None) or {}
    return (sizes.get("eta_rho", 640), sizes.get("xi_rho", 480))


def _to_masked_grid(values, shape):
    """Return a float copy of *values* reshaped to *shape*, with zeros set to NaN.

    Copying keeps the caller's array untouched: a bare ``reshape`` can return
    a view, so writing NaNs into it would silently modify the caller's data.
    """
    grid = np.array(values, dtype=float).reshape(shape)
    grid[grid == 0] = np.nan
    return grid


def _save_plot(field, filename, figsize, **plot_kwargs):
    """Plot *field* on lon_rho/lat_rho, save it as a 300-dpi PNG, then close the figure."""
    field.plot(x="lon_rho", y="lat_rho", figsize=figsize, **plot_kwargs)
    fig = plt.gcf()
    fig.savefig(filename, format="png", dpi=300)
    # Passing figsize makes xarray open a new figure on every call; close it so
    # repeated calls (e.g. one per index) do not accumulate open figures.
    plt.close(fig)


def postprocessing(y_test, y_pred, ds_local, var_local, T, depth, index, path,
                   *, vrange=(25, 35), zoom_vrange=(24, 30)):
    """Store predicted/reference/error fields on *ds_local* and save six PNG maps.

    Parameters
    ----------
    y_test : array-like
        Reference values (stored as ``ROMS``); reshaped to the
        (eta_rho, xi_rho) grid. Zeros are treated as missing (NaN).
        The caller's array is not modified.
    y_pred : array-like
        Predicted values, same layout as *y_test*; stored as ``Predicted``.
        Zeros and cells that are NaN in the field returned by
        ``westernAustraliaLocal`` (presumably land) become NaN.
    ds_local : Dataset
        Object with ``eta_rho``/``xi_rho`` dims and ``lon_rho``/``lat_rho``
        coordinates (used for plotting). Mutated in place: the variables
        ``Predicted``, ``ROMS`` and ``Error`` are added/overwritten.
    var_local, T, depth
        Forwarded to ``westernAustraliaLocal``; *var_local* also appears in
        the output file names.
    index
        Appended to the output file names.
    path
        Output prefix; file names are appended to ``str(path)`` by plain
        string concatenation, so include a trailing separator for a directory.
    vrange, zoom_vrange : (float, float), keyword-only
        Colour range (vmin, vmax) of the full-domain and Shark Bay maps of
        the predicted and reference fields.

    Returns
    -------
    None. Writes ``predicted_``, ``ROMS_``, ``predicted_sharkbay_``,
    ``ROMS_sharkbay_``, ``Error_`` and ``Error_sharkbay_`` PNG files.
    """
    # Only the last returned field is needed: its NaN cells define the mask.
    _, _, _, _, ref_field = westernAustraliaLocal(ds_local, var_local, T, depth)
    masked_cells = np.argwhere(np.isnan(ref_field))
    shape = _grid_shape(ds_local)

    ypred = _to_masked_grid(y_pred, shape)
    ypred[masked_cells[:, 0], masked_cells[:, 1]] = np.nan
    ds_local["Predicted"] = (("eta_rho", "xi_rho"), ypred)

    ytest = _to_masked_grid(y_test, shape)
    ds_local["ROMS"] = (("eta_rho", "xi_rho"), ytest)

    # Relative error; zeros in ytest are already NaN, so there is no 0-division.
    ds_local["Error"] = (("eta_rho", "xi_rho"), (ytest - ypred) / ytest)

    prefix = str(path)
    suffix = f"{var_local!s}{index!s}.png"
    # Explicit vmin/vmax (instead of a matplotlib `clim` kwarg) so the range is
    # actually applied and predicted/reference maps share one colour scale.
    full = {"vmin": vrange[0], "vmax": vrange[1]}
    zoom = {"vmin": zoom_vrange[0], "vmax": zoom_vrange[1]}

    _save_plot(ds_local["Predicted"], f"{prefix}predicted_{suffix}", (15, 7), **full)
    _save_plot(ds_local["ROMS"], f"{prefix}ROMS_{suffix}", (15, 7), **full)
    _save_plot(ds_local["Predicted"].isel(**_SHARK_BAY),
               f"{prefix}predicted_sharkbay_{suffix}", (7, 3), **zoom)
    _save_plot(ds_local["ROMS"].isel(**_SHARK_BAY),
               f"{prefix}ROMS_sharkbay_{suffix}", (7, 3), **zoom)
    # Relative error is a signed, dimensionless quantity: centre the colour
    # map on 0 rather than reusing the field's value range.
    _save_plot(ds_local["Error"], f"{prefix}Error_{suffix}", (15, 7), center=0)
    _save_plot(ds_local["Error"].isel(**_SHARK_BAY),
               f"{prefix}Error_sharkbay_{suffix}", (15, 7), center=0)