diff --git a/example/StreamflowExample-DI.py b/example/StreamflowExample-DI.py index 7bcd07d..624aacc 100644 --- a/example/StreamflowExample-DI.py +++ b/example/StreamflowExample-DI.py @@ -45,6 +45,7 @@ Action = [0] gpuid = -1 torch.cuda.set_device(gpuid) +device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu" ) # Set hyperparameters EPOCH = 300 @@ -207,7 +208,7 @@ if flow_regime==0: lossFun = RmseLoss() elif flow_regime==1: - lossFun = NSELossBatch(np.nanstd(yTrain, axis=1)) + lossFun = NSELossBatch(np.nanstd(yTrain, axis=1),device = device) # the loaded loss should be consistent with the 'name' in optLoss Dict above for logging purpose # update and write the dictionary variable to out folder for logging and future testing masterDict = wrapMaster(out, optData, optModel, optLoss, optTrain)