-
Notifications
You must be signed in to change notification settings - Fork 1
/
ssr.lua
31 lines (26 loc) · 918 Bytes
/
ssr.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
require 'math'
--- nn.SignedSquareRoot: elementwise signed square root, sign(x) * sqrt(|x|).
-- Registers the class through torch.class with 'nn.Module' named as its
-- parent. `parent` is the second value returned by torch.class and is used
-- by the constructor to invoke the base-class __init.
local SignedSquareRoot, parent = torch.class('nn.SignedSquareRoot', 'nn.Module')
--- Constructor.
-- Runs the parent initializer, then stores in self.module a sequential
-- container that applies nn.Abs followed by nn.Sqrt.
-- @param args accepted for interface compatibility; not read by this method
function SignedSquareRoot:__init(args)
   parent.__init(self)

   -- Build the |x| -> sqrt pipeline one stage at a time.
   local pipeline = nn.Sequential()
   pipeline:add(nn.Abs())
   pipeline:add(nn.Sqrt())

   self.module = pipeline
end
--- Forward pass: output = sign(input) * sqrt(|input|), elementwise.
-- Side effects: stores the result of self.module:forward(input) in
-- self.output_ (read later by updateGradInput) and the elementwise sign of
-- the input in self.sign.
-- @param input tensor to transform
-- @return self.output
function SignedSquareRoot:updateOutput(input)
   -- Lazily create the buffers with the same tensor type as the input.
   if not self.output then
      self.output = input.new()
   end
   if not self.sign then
      self.sign = input.new()
   end

   -- Elementwise sign of the input, written into the reusable sign buffer.
   self.sign:resizeAs(input)
   self.sign:sign(input)

   -- Result of the Abs -> Sqrt pipeline, kept for the backward pass.
   self.output_ = self.module:forward(input)

   -- Re-attach the sign to the square-rooted magnitudes.
   torch.cmul(self.output, self.output_, self.sign)
   return self.output
end
--- Backward pass for sign(x) * sqrt(|x|).
-- Computes gradInput = gradOutput / (2 * self.output_), where self.output_
-- holds the values cached by the preceding updateOutput call, and writes 0
-- wherever self.output_ is exactly 0.
-- @param input tensor given to the forward pass; only used to allocate
--   self.gradInput when that buffer does not exist yet
-- @param gradOutput gradient with respect to this module's output
--   (assumed to have the same size as self.output_ -- TODO confirm with callers)
-- @return self.gradInput
function SignedSquareRoot:updateGradInput(input, gradOutput)
   -- The cached forward result is required; fail with a clear message
   -- instead of an obscure nil-arithmetic error.
   assert(self.output_, 'updateOutput must be called before updateGradInput')

   -- Was `inout.new()` (undefined global) in the fallback branch.
   self.gradInput = self.gradInput or input.new()

   -- (gradOutput / 2) / output_, computed in place inside gradInput so no
   -- temporary tensor is allocated on every backward call.
   self.gradInput:resizeAs(gradOutput):copy(gradOutput):div(2):cdiv(self.output_)

   -- Division by zero produces inf/nan; force those entries to 0 so they do
   -- not blow up the gradients flowing to earlier layers.
   self.gradInput[self.output_:eq(0)] = 0
   return self.gradInput
end