Skip to content

Commit

Permalink
Merge pull request #18 from mhpi/cpu-seed-fixer
Browse files Browse the repository at this point in the history
seed fix 1
  • Loading branch information
taddyb authored Nov 29, 2023
2 parents 50b7f2f + 73cc5ab commit b03a83b
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 33 deletions.
6 changes: 3 additions & 3 deletions example/StreamflowExample-DI.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@
# 0: train base model without DI
# 1: train DI model
# 0,1: do both base and DI model
# 2: test trained models
Action = [2]
gpuid = 6
# 2: test trained modelsRAPID_output_202311
Action = [0]
gpuid = -1
torch.cuda.set_device(gpuid)

# Set hyperparameters
Expand Down
7 changes: 5 additions & 2 deletions hydroDL/model/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@


def createMask(x, dr, seed):
generator = torch.Generator(device="cuda")
generator.manual_seed(seed)
if torch.cuda.is_available():
generator = torch.Generator(device="cuda")
generator.manual_seed(seed)
else:
torch.manual_seed(seed)
mask = x.new().resize_as_(x).bernoulli_(1 - dr).div_(1 - dr).detach_()
# print('droprate='+str(dr))
return mask
Expand Down
15 changes: 8 additions & 7 deletions hydroDL/model/rnn/LSTMcell_tied.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def __init__(
mode="train",
dr=0.5,
drMethod="drX+drW+drC",
gpu=1
gpu=1,
seed=42
):
super(LSTMcell_tied, self).__init__()

Expand All @@ -26,7 +27,7 @@ def __init__(
self.dr = dr
self.name = "LSTMcell_tied"
self.is_legacy = True

self.seed=seed
self.w_ih = Parameter(torch.Tensor(hiddenSize * 4, inputSize))
self.w_hh = Parameter(torch.Tensor(hiddenSize * 4, hiddenSize))
self.b_ih = Parameter(torch.Tensor(hiddenSize * 4))
Expand Down Expand Up @@ -55,11 +56,11 @@ def reset_parameters(self):
weight.data.uniform_(-stdv, stdv)

def reset_mask(self, x, h, c):
self.maskX = createMask(x, self.dr)
self.maskH = createMask(h, self.dr)
self.maskC = createMask(c, self.dr)
self.maskW_ih = createMask(self.w_ih, self.dr)
self.maskW_hh = createMask(self.w_hh, self.dr)
self.maskX = createMask(x, self.dr, self.seed)
self.maskH = createMask(h, self.dr, self.seed)
self.maskC = createMask(c, self.dr, self.seed)
self.maskW_ih = createMask(self.w_ih, self.dr, self.seed)
self.maskW_hh = createMask(self.w_hh, self.dr, self.seed)

def forward(self, x, hidden, *, resetMask=True, doDropMC=False):
if self.dr > 0 and (doDropMC is True or self.training is True):
Expand Down
43 changes: 22 additions & 21 deletions hydroDL/model/rnn/LSTMcell_untied.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@

class LSTMcell_untied(torch.nn.Module):
def __init__(
self, *, inputSize, hiddenSize, train=True, dr=0.5, drMethod="gal+sem", gpu=0
self, *, inputSize, hiddenSize, train=True, dr=0.5, drMethod="gal+sem", gpu=0, seed=42
):
super(LSTMcell_untied, self).__init__()
self.inputSize = inputSize
self.hiddenSize = inputSize
self.dr = dr
self.seed = seed
self.name = "LSTMcell_untied"
self.is_legacy = True

Expand Down Expand Up @@ -50,26 +51,26 @@ def reset_parameters(self):
w.data.uniform_(-std, std)

def init_mask(self, x, h, c):
self.maskX_i = createMask(x, self.dr)
self.maskX_f = createMask(x, self.dr)
self.maskX_c = createMask(x, self.dr)
self.maskX_o = createMask(x, self.dr)

self.maskH_i = createMask(h, self.dr)
self.maskH_f = createMask(h, self.dr)
self.maskH_c = createMask(h, self.dr)
self.maskH_o = createMask(h, self.dr)

self.maskC = createMask(c, self.dr)

self.maskW_xi = createMask(self.w_xi, self.dr)
self.maskW_xf = createMask(self.w_xf, self.dr)
self.maskW_xc = createMask(self.w_xc, self.dr)
self.maskW_xo = createMask(self.w_xo, self.dr)
self.maskW_hi = createMask(self.w_hi, self.dr)
self.maskW_hf = createMask(self.w_hf, self.dr)
self.maskW_hc = createMask(self.w_hc, self.dr)
self.maskW_ho = createMask(self.w_ho, self.dr)
self.maskX_i = createMask(x, self.dr, self.seed)
self.maskX_f = createMask(x, self.dr, self.seed)
self.maskX_c = createMask(x, self.dr, self.seed)
self.maskX_o = createMask(x, self.dr, self.seed)

self.maskH_i = createMask(h, self.dr, self.seed)
self.maskH_f = createMask(h, self.dr, self.seed)
self.maskH_c = createMask(h, self.dr, self.seed)
self.maskH_o = createMask(h, self.dr, self.seed)

self.maskC = createMask(c, self.dr, self.seed)

self.maskW_xi = createMask(self.w_xi, self.dr, self.seed)
self.maskW_xf = createMask(self.w_xf, self.dr, self.seed)
self.maskW_xc = createMask(self.w_xc, self.dr, self.seed)
self.maskW_xo = createMask(self.w_xo, self.dr, self.seed)
self.maskW_hi = createMask(self.w_hi, self.dr, self.seed)
self.maskW_hf = createMask(self.w_hf, self.dr, self.seed)
self.maskW_hc = createMask(self.w_hc, self.dr, self.seed)
self.maskW_ho = createMask(self.w_ho, self.dr, self.seed)

def forward(self, x, hidden):
h0, c0 = hidden
Expand Down

0 comments on commit b03a83b

Please sign in to comment.