diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py
index 375e3834dfc..6e056589a8c 100644
--- a/torchrl/objectives/cql.py
+++ b/torchrl/objectives/cql.py
@@ -892,7 +892,7 @@ def alpha_loss(self, tensordict: TensorDictBase) -> Tensor:
 
     @property
     def _alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
         alpha = self.log_alpha.data.exp()
         return alpha
diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py
index 801180901a7..22e84673641 100644
--- a/torchrl/objectives/crossq.py
+++ b/torchrl/objectives/crossq.py
@@ -677,7 +677,7 @@ def alpha_loss(self, log_prob: Tensor) -> Tensor:
 
     @property
     def _alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
         with torch.no_grad():
             alpha = self.log_alpha.exp()
diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py
index 013e28713bf..a0d193acbfc 100644
--- a/torchrl/objectives/decision_transformer.py
+++ b/torchrl/objectives/decision_transformer.py
@@ -171,7 +171,7 @@ def _forward_value_estimator_keys(self, **kwargs):
 
     @property
     def alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
         with torch.no_grad():
             alpha = self.log_alpha.exp()
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index dafff17011e..eae6b7feb34 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -846,7 +846,7 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor:
 
     @property
     def _alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data.clamp_(self.min_log_alpha, self.max_log_alpha)
         with torch.no_grad():
             alpha = self.log_alpha.exp()
@@ -1374,7 +1374,7 @@ def _alpha_loss(self, log_prob: Tensor) -> Tensor:
 
     @property
     def _alpha(self):
-        if self.min_log_alpha is not None:
+        if self.min_log_alpha is not None or self.max_log_alpha is not None:
             self.log_alpha.data = self.log_alpha.data.clamp(
                 self.min_log_alpha, self.max_log_alpha
             )