-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathnnutils.lua
48 lines (41 loc) · 1.38 KB
/
nnutils.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
require "torch"
--- Recursively apply `func` to one field of a module tree and render the
-- results as a single string.
--
-- A module that has `module[field]` or a `module.modules` list is rendered as
-- `torch.typename(module) .. ": "`, followed by `func(module[field])` when the
-- field is present, followed by the bracketed renderings of its submodules
-- when `module.modules` is present. A module with neither renders as "".
-- Submodules that render as "" are omitted, and the remaining ones are
-- separated by exactly one space (no leading or trailing separator).
--
-- @param module object to inspect (presumably an nn module -- only `[field]`
--               and `.modules` are read here)
-- @param field  key looked up on every visited module, e.g. "weight"
-- @param func   called with `module[field]`; its result is concatenated into
--               the output, so it must be a string or a number
-- @return the rendered string
function recursive_map(module, field, func)
  local value = module[field]
  local children = module.modules
  if not (value or children) then
    return ""
  end

  local str = torch.typename(module) .. ": "
  if value then
    str = str .. func(value)
  end
  if children then
    local parts = {}
    for _, submodule in ipairs(children) do
      local submodule_str = recursive_map(submodule, field, func)
      -- Keep only visible entries so that a trailing run of empty submodules
      -- cannot leave a dangling separator before the closing bracket.
      if #submodule_str > 0 then
        parts[#parts + 1] = submodule_str
      end
    end
    str = str .. "[" .. table.concat(parts, " ") .. "]"
  end
  return str
end
--- Mean of the absolute values of `w`.
-- `w` is cloned and the clone converted with `:float()` before `torch.abs`
-- and `torch.mean` are applied.
-- @param w object providing `:clone()` (presumably a Torch tensor)
-- @return the result of `torch.mean` over the absolute values
function abs_mean(w)
  local as_float = w:clone():float()
  local magnitudes = torch.abs(as_float)
  return torch.mean(magnitudes)
end
--- Largest absolute value in `w`.
-- `w` is cloned and the clone converted with `:float()`; `torch.abs` is
-- applied to that and the result's `:max()` is returned.
-- @param w object providing `:clone()` (presumably a Torch tensor)
-- @return whatever `:max()` yields on the absolute-valued copy
function abs_max(w)
  local as_float = w:clone():float()
  local magnitudes = torch.abs(as_float)
  return magnitudes:max()
end
-- Build a two-section report for the given network: first the mean absolute
-- value (abs_mean) of each module's "weight" field, then the maximum absolute
-- value (abs_max) of the same field, both laid out by recursive_map.
function get_weight_norms(module)
  local mean_section = recursive_map(module, "weight", abs_mean)
  local max_section = recursive_map(module, "weight", abs_max)
  return "Weight norms:\n" .. mean_section .. "\nWeight max:\n" .. max_section
end
-- Build a two-section report for the given network: first the mean absolute
-- value (abs_mean) of each module's "gradWeight" field, then the maximum
-- absolute value (abs_max) of the same field, both laid out by recursive_map.
function get_grad_norms(module)
  local mean_section = recursive_map(module, "gradWeight", abs_mean)
  local max_section = recursive_map(module, "gradWeight", abs_max)
  return "Weight grad norms:\n" .. mean_section ..
    "\nWeight grad max:\n" .. max_section
end