From f701b32b4161fc61cf7b53f5fad77e03523e4f3f Mon Sep 17 00:00:00 2001 From: Yalan-Song <83627309+Yalan-Song@users.noreply.github.com> Date: Thu, 30 Nov 2023 20:57:55 -0500 Subject: [PATCH] Update StreamflowExample-DI.py --- example/StreamflowExample-DI.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/example/StreamflowExample-DI.py b/example/StreamflowExample-DI.py index 7bcd07d..624aacc 100644 --- a/example/StreamflowExample-DI.py +++ b/example/StreamflowExample-DI.py @@ -45,6 +45,7 @@ Action = [0] gpuid = -1 torch.cuda.set_device(gpuid) +device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu" ) # Set hyperparameters EPOCH = 300 @@ -207,7 +208,7 @@ if flow_regime==0: lossFun = RmseLoss() elif flow_regime==1: - lossFun = NSELossBatch(np.nanstd(yTrain, axis=1)) + lossFun = NSELossBatch(np.nanstd(yTrain, axis=1),device = device) # the loaded loss should be consistent with the 'name' in optLoss Dict above for logging purpose # update and write the dictionary variable to out folder for logging and future testing masterDict = wrapMaster(out, optData, optModel, optLoss, optTrain)