Skip to content

Commit

Permalink
Github workflow: All loss calculations on CPU
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed May 22, 2024
1 parent 381ca7d commit c7f494c
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions tests/lib/model/losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@

from keras import device, losses as k_losses, Variable


from lib.model import losses
from lib.utils import get_backend

_PARAMS = [(losses.GeneralizedLoss(), (2, 16, 16)),
(losses.GradientLoss(), (2, 16, 16)),
# TODO Make sure these output dimensions are correct
(losses.GMSDLoss(), (2, 1, 1)),
# TODO Make sure these output dimensions are correct
(losses.LInfNorm(), (2, 1, 1))]
with device("cpu"):
_PARAMS = [(losses.GeneralizedLoss(), (2, 16, 16)),
(losses.GradientLoss(), (2, 16, 16)),
# TODO Make sure these output dimensions are correct
(losses.GMSDLoss(), (2, 1, 1)),
# TODO Make sure these output dimensions are correct
(losses.LInfNorm(), (2, 1, 1))]
_IDS = ["GeneralizedLoss", "GradientLoss", "GMSDLoss", "LInfNorm"]
_IDS = [f"{loss}[{get_backend().upper()}]" for loss in _IDS]

Expand Down

0 comments on commit c7f494c

Please sign in to comment.