-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils_pruning.py
106 lines (73 loc) · 2.68 KB
/
utils_pruning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import copy
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
def generate_mask(model):
    """Attach a pruning mask to every ``nn.Conv2d`` that preserves its current zeros.

    Each Conv2d ``weight`` is re-parametrised in place with
    ``prune.CustomFromMask``. The mask is a float tensor holding 1 where the
    weight is non-zero and 0 where it is exactly zero, so those entries are
    masked out of the layer's effective weight from now on.

    Args:
        model: module whose Conv2d layers receive a mask.
    """
    conv_layers = (mod for mod in model.modules() if isinstance(mod, nn.Conv2d))
    for conv in conv_layers:
        keep = conv.weight.ne(0).float()
        prune.CustomFromMask.apply(conv, 'weight', mask=keep)
def pruning_model(model, px):
    """Globally prune the smallest-magnitude ``nn.Conv2d`` weights (L1 criterion).

    All Conv2d weights in ``model`` are ranked together, not per layer, by
    ``prune.global_unstructured`` with ``prune.L1Unstructured``. The lowest
    ``px`` of them are masked out in place.

    Args:
        model: module whose Conv2d layers are pruned.
        px: forwarded unchanged as ``amount`` to ``prune.global_unstructured``.
    """
    print('start unstructured pruning')
    targets = tuple(
        (layer, 'weight')
        for _, layer in model.named_modules()
        if isinstance(layer, nn.Conv2d)
    )
    prune.global_unstructured(targets, pruning_method=prune.L1Unstructured, amount=px)
def pruning_model_random(model, px):
    """Globally prune a random subset of all ``nn.Conv2d`` weights.

    The Conv2d weights of ``model`` are pooled and pruned together by
    ``prune.global_unstructured`` using ``prune.RandomUnstructured``. The
    selection therefore depends on torch's global RNG state.

    Args:
        model: module whose Conv2d layers are pruned in place.
        px: forwarded unchanged as ``amount`` to ``prune.global_unstructured``.
    """
    convs = [mod for mod in model.modules() if isinstance(mod, nn.Conv2d)]
    prune.global_unstructured(
        tuple((conv, 'weight') for conv in convs),
        pruning_method=prune.RandomUnstructured,
        amount=px,
    )
def prune_model_custom(model, mask_dict):
    """Apply externally supplied masks to every ``nn.Conv2d`` weight in ``model``.

    For each Conv2d registered under ``name``, the tensor at
    ``mask_dict[name + '.weight_mask']`` is attached with
    ``prune.CustomFromMask``. This is the key format a pruned model's
    ``state_dict()`` uses for its weight masks.

    Args:
        model: module whose Conv2d layers are masked in place.
        mask_dict: mapping from ``'<module name>.weight_mask'`` to a mask tensor
            with the same shape as that layer's weight. It may live on any
            device; it is moved to the layer's device before use.

    Raises:
        KeyError: if a Conv2d layer has no entry in ``mask_dict``.
    """
    print('start unstructured pruning with custom mask')
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            mask = mask_dict[name + '.weight_mask']
            # CustomFromMask does not move the mask itself; a mask loaded onto
            # CPU would otherwise fail against a model living on an accelerator.
            mask = mask.to(device=m.weight.device)
            prune.CustomFromMask.apply(m, 'weight', mask=mask)
def remove_prune(model):
    """Make the pruning of every pruned ``nn.Conv2d`` weight permanent.

    For each Conv2d whose ``weight`` is currently re-parametrised by
    ``torch.nn.utils.prune`` (detected by its ``weight_orig`` parameter),
    ``prune.remove`` writes the masked values into a plain ``weight``
    parameter. It also drops ``weight_orig``, ``weight_mask`` and the pruning
    hook.

    Conv2d layers whose weight is not pruned are skipped. ``prune.remove``
    raises ``ValueError`` for those, which used to make this function fail on
    partially pruned models or when called twice.

    Args:
        model: module whose Conv2d layers are finalised in place.
    """
    print('remove pruning')
    for m in model.modules():
        if isinstance(m, nn.Conv2d) and hasattr(m, 'weight_orig'):
            prune.remove(m, 'weight')
def extract_mask(model_dict):
    """Return deep copies of every entry of ``model_dict`` whose key contains ``'mask'``.

    Used on ``model.state_dict()`` of a pruned model to snapshot its
    ``<module>.weight_mask`` buffers. Matching is a plain substring test, so
    any other key containing ``'mask'`` is collected as well.

    Args:
        model_dict: mapping of names to values, e.g. a state dict.

    Returns:
        A new dict with the selected keys, in the original order, mapped to
        independent copies of their values.
    """
    return {
        key: copy.deepcopy(value)
        for key, value in model_dict.items()
        if 'mask' in key
    }
def check_sparsity(model):
    """Print and return the percentage of ``nn.Conv2d`` weights that are non-zero.

    Every Conv2d ``weight`` in ``model`` is inspected; for a pruned layer this
    is the effective masked weight. The result is
    ``100 * (1 - zeros / total)``.

    Args:
        model: module whose Conv2d weights are counted.

    Returns:
        The remaining-weight percentage as a float. A model without any Conv2d
        weights reports 100.0 (nothing was pruned) instead of raising
        ``ZeroDivisionError``.
    """
    total = 0
    zeros = 0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            total += m.weight.nelement()
            # Integer counts keep the tally exact for very large models.
            zeros += int(torch.sum(m.weight == 0))
    remain = 100 * (1 - zeros / total) if total else 100.0
    print('* remain weight = ', remain, '%')
    return remain
def check_sparsity_mask(mask_dict):
    """Print and return the percentage of mask entries that are non-zero.

    All tensors in ``mask_dict`` (e.g. the output of ``extract_mask``) are
    pooled. The result is ``100 * (1 - zeros / total)``.

    Args:
        mask_dict: mapping whose values are mask tensors.

    Returns:
        The remaining-weight percentage as a float. An empty ``mask_dict``
        (e.g. masks extracted from an unpruned model) reports 100.0 instead of
        raising ``ZeroDivisionError``.
    """
    total = 0
    zeros = 0
    for mask in mask_dict.values():
        total += mask.nelement()
        zeros += int(torch.sum(mask == 0))
    remain = 100 * (1 - zeros / total) if total else 100.0
    print('* remain weight = ', remain, '%')
    return remain
def pruning_with_rewind(model, px, init):
    """Prune ``model`` by global L1 magnitude, then rewind its weights to ``init``.

    The new pruning masks are kept, but the weights themselves are restored:

    1. ``pruning_model`` prunes all Conv2d weights globally by ``px``.
    2. The resulting masks are snapshotted from the state dict.
    3. Pruning is removed so the model has plain ``weight`` parameters again.
    4. ``init`` is loaded into the model.
    5. The snapshotted masks are re-applied on top of the restored weights.

    Args:
        model: module to prune and rewind in place.
        px: pruning amount forwarded to ``pruning_model``.
        init: state dict loaded with ``model.load_state_dict`` (strict) after
            pruning is removed. Its keys must therefore match an unpruned
            model, e.g. ``<layer>.weight`` rather than ``<layer>.weight_orig``.

    Returns:
        The remaining-weight percentage reported by ``check_sparsity``.
    """
    pruning_model(model, px)
    # Snapshot the masks before remove_prune deletes the weight_mask buffers.
    current_mask = extract_mask(model.state_dict())
    remove_prune(model)
    model.load_state_dict(init)
    prune_model_custom(model, current_mask)
    return check_sparsity(model)