Skip to content

Commit

Permalink
Merge pull request #327 from hughperkins/add-opencl
Browse files Browse the repository at this point in the history
Add opencl
  • Loading branch information
nicholas-leonard authored Sep 4, 2016
2 parents 1cb5b3c + 8d0c555 commit 4f8401d
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
build/
6 changes: 5 additions & 1 deletion MaskZero.lua
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ function MaskZero:updateOutput(input)
local vectorDim = rmi:dim()
self._zeroMask = self._zeroMask or rmi.new()
self._zeroMask:norm(rmi, 2, vectorDim)
self.zeroMask = self.zeroMask or ((torch.type(rmi) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor())
self.zeroMask = self.zeroMask or (
(torch.type(rmi) == 'torch.CudaTensor') and torch.CudaByteTensor()
or (torch.type(rmi) == 'torch.ClTensor') and torch.ClTensor()
or torch.ByteTensor()
)
self._zeroMask.eq(self.zeroMask, self._zeroMask, 0)

-- forward through decorated module
Expand Down
3 changes: 2 additions & 1 deletion MaskZeroCriterion.lua
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ function MaskZeroCriterion:updateOutput(input, target)
self._zeroMask = self._zeroMask or rmi.new()
self._zeroMask:norm(rmi, 2, vectorDim)
local zeroMask = self._zeroMask
if torch.isTypeOf(zeroMask, 'torch.CudaTensor') then
if torch.isTypeOf(zeroMask, 'torch.CudaTensor') or
torch.isTypeOf(zeroMask, 'torch.ClTensor') then
self.__zeroMask = self.__zeroMask or torch.FloatTensor()
self.__zeroMask:resize(self._zeroMask:size()):copy(self._zeroMask)
zeroMask = self._zeroMask
Expand Down

0 comments on commit 4f8401d

Please sign in to comment.