From 5e4688d508f72ead579f58366fbba2c8e37219cf Mon Sep 17 00:00:00 2001
From: Udbhav
Date: Tue, 3 Aug 2021 09:17:51 +0530
Subject: [PATCH 01/10] improved flop budget and add double skip resnet18

---
 models/base_model.py |  75 ++++++++++---
 models/ds_resnet.py  | 259 +++++++++++++++++++++++++++++++++++++++++++
 models/models.py     |   3 +
 models/resnet.py     |  75 ++++++-------
 4 files changed, 353 insertions(+), 59 deletions(-)
 create mode 100755 models/ds_resnet.py

diff --git a/models/base_model.py b/models/base_model.py
index fcafb03..4a4e695 100755
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -9,7 +9,6 @@ def __init__(self):
         super(BaseModel, self).__init__()
         self.prunable_modules = []
         self.prev_module = defaultdict()
-#         self.next_module = defaultdict()
         pass
 
     def set_threshold(self, threshold):
@@ -48,9 +47,10 @@ def calculate_prune_threshold(self, Vc, budget_type = 'channel_ratio'):
     def smoothRound(self, x, steepness=20.):
         return 1./(1.+torch.exp(-1*steepness*(x-0.5)))
 
-    def n_remaining(self, m, steepness=20.):
-        return (m.pruned_zeta if m.is_pruned else self.smoothRound(m.get_zeta_t(), steepness)).sum()
-
+    def n_remaining(self, m, steepness=20., do_sum=True):
+        rem = (m.pruned_zeta if m.is_pruned else self.smoothRound(m.get_zeta_t(), steepness))
+        return rem.sum() if do_sum else rem
+
     def is_all_pruned(self, m):
         return self.n_remaining(m) == 0
 
@@ -72,13 +72,57 @@ def get_remaining(self, steepness=20., budget_type = 'channel_ratio'):
                 n_rem += self.n_remaining(l_block, steepness)*prev_remaining*k*k
                 n_total += l_block.num_gates*prev_total*k*k
             elif budget_type == 'flops_ratio':
-                k = l_block._conv_module.kernel_size[0]
-                output_area = l_block._conv_module.output_area
-                prev_total = 3 if self.prev_module[l_block] is None else self.prev_module[l_block].num_gates
-                prev_remaining = 3 if self.prev_module[l_block] is None else self.n_remaining(self.prev_module[l_block], steepness)
+                k1 = l_block._conv_module.kernel_size[0]
+                k2 = l_block._conv_module.kernel_size[1]
+                active_elements_count = l_block._conv_module.output_area
+                if self.prev_module[l_block] is None:
+                    prev_total = 3
+                    prev_remaining = 3
+                elif isinstance(self.prev_module[l_block], nn.BatchNorm2d):
+                    prev_total = self.prev_module[l_block].num_gates
+                    prev_remaining = self.n_remaining(self.prev_module[l_block], steepness)
+                else:
+                    prev_total = self.prev_module[l_block][-1].num_gates
+                    def cal_max(prev):
+                        if isinstance(prev[0], nn.BatchNorm2d):
+                            prev1 = self.n_remaining(prev[0], steepness, do_sum=False)
+                            prev2 = self.n_remaining(prev[1], steepness, do_sum=False)
+                            return (torch.maximum(prev1, prev2) + torch.maximum(prev2, prev1))/2
+                        prev2 = self.n_remaining(prev[-1], steepness, do_sum=False)
+                        list_ = cal_max(prev[0])
+                        return (torch.maximum(list_, prev2) + torch.maximum(prev2, list_))/2
+
+                    prev_remaining = cal_max(self.prev_module[l_block]).sum()
                 curr_remaining = self.n_remaining(l_block, steepness)
-                n_rem += curr_remaining*prev_remaining*k*k*output_area + curr_remaining*output_area
-                n_total += l_block.num_gates*prev_total*k*k*output_area + l_block.num_gates*output_area
+
+                ## Pruned
+                # conv
+                conv_per_position_flops = k1 * k2 * prev_remaining * curr_remaining
+                n_rem += conv_per_position_flops * active_elements_count
+                if l_block._conv_module.bias is not None:
+                    n_rem += curr_remaining * active_elements_count
+
+                # bn
+                batch_flops = curr_remaining * active_elements_count
+                n_rem += batch_flops  ## ReLU flops
+                if l_block.affine:
+                    batch_flops *= 2
+                n_rem += batch_flops
+
+                ## normal
+                # conv
+                conv_per_position_flops = k1 * k2 * prev_total * l_block.num_gates
+                n_total += conv_per_position_flops * active_elements_count
+                if l_block._conv_module.bias is not None:
+                    n_total += l_block.num_gates * active_elements_count
+
+                # bn
+                batch_flops = l_block.num_gates * active_elements_count
+                n_total += batch_flops  ## ReLU flops
+                if l_block.affine:
+                    batch_flops *= 2
+                n_total += batch_flops
         return n_rem/n_total
 
     def give_zetas(self):
@@ -128,7 +172,7 @@ def prune(self, Vc, budget_type = 'channel_ratio', finetuning=False, threshold=N
                 high = mid-1
             else:
                 low = mid+1
-        elif budget_type == 'flops_ratio':
+        elif budget_type == 'flops_ratio' and threshold==None:
             zetas = sorted(self.give_zetas())
             high = len(zetas)-1
             low = 0
@@ -138,12 +182,11 @@ def prune(self, Vc, budget_type = 'channel_ratio', finetuning=False, threshold=N
         for l_block in self.prunable_modules:
             l_block.prune(threshold)
         self.remove_orphans()
-        if self.flops()<Vc:
+        if self.get_remaining(steepness=20., budget_type='flops_ratio')<Vc:
[... patch text lost in extraction: the remainder of this hunk, and the header and beginning of the new file models/ds_resnet.py ...]
+        do_downsample = True if self.width>1 else False
+        for l in self.layers_size:
+            for i in range(l):
+                if do_downsample:
+                    downsample_n = a[current_loc]
+                    ans+=current_max*a[current_loc]
+                    current_loc+=1
+
+                ans+=current_max*a[current_loc]*9
+                ans+=a[current_loc]*a[current_loc+1]*9
+                if do_downsample:
+                    current_max = max(downsample_n, a[current_loc+1])
+                else:
+                    current_max = max(current_max, a[current_loc+1])
+                do_downsample = False
+                current_loc+=2
+            do_downsample = True
+        return ans + a[-1]*self.num_classes + 2*np.sum(a)
+
+    def params(self):
+        a = [3]
+        b = [3]
+        for i in self.prunable_modules:
+            a.append(int(i.pruned_zeta.sum()))
+            b.append(len(i.pruned_zeta))
+        return self.__calc_params(a)/self.__calc_params(b)
+
+def make_resnet18(num_classes, insize):
+    model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, insize=insize)
+    return model
+
+# def make_resnet50(num_classes, insize):
+#     model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, insize=insize)
+#     return model
+
+# def make_resnet101(num_classes, insize):
+#     model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, insize=insize)
+#     return model
+
+# def make_resnet152(num_classes, insize):
+#     model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, insize=insize)
+#     return model
+
+def get_ds_resnet_model(model, method, num_classes, insize):
+    """Returns the requested model, ready for training/pruning with the specified method.
+
+    :param model: str, dsr18
+    :param method: full or prune
+    :param num_classes: int, num classes in the dataset
+    :return: A prunable ResNet model
+    """
+    ModuleInjection.pruning_method = method
+    ModuleInjection.prunable_modules = []
+    if model == 'dsr18':
+        net = make_resnet18(num_classes, insize)
+    # elif model == 'ds50':
+    #     net = make_resnet50(num_classes, insize)
+    # elif model == 'dsr101':
+    #     net = make_resnet101(num_classes, insize)
+    # elif model == 'dsr152':
+    #     net = make_resnet152(num_classes, insize)
+    net.prunable_modules = ModuleInjection.prunable_modules
+    return net
\ No newline at end of file
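[Editorial note] The flops_ratio arithmetic in the hunk above is easier to audit outside the gating machinery. The sketch below is not part of the patch; the function name and the 64-to-128 toy channel counts are assumptions for illustration, but the conv + bias + ReLU + BN accounting mirrors the n_rem/n_total terms line for line.

    def conv_bn_relu_flops(k1, k2, c_in, c_out, output_area, bias=False, affine=True):
        # conv: k1*k2*c_in multiply-accumulates per output element, for c_out maps
        flops = k1 * k2 * c_in * c_out * output_area
        if bias:
            flops += c_out * output_area             # one add per output element
        per_element = c_out * output_area
        flops += per_element                         # ReLU: one op per output element
        flops += per_element * (2 if affine else 1)  # BN: scale + shift when affine
        return flops

    # Toy budget check: a 3x3, 64->128 conv on a 32x32 map, pruned to 40->90 channels.
    total = conv_bn_relu_flops(3, 3, 64, 128, 32 * 32)
    remaining = conv_bn_relu_flops(3, 3, 40, 90, 32 * 32)
    print(remaining / total)  # the per-layer ratio that get_remaining() accumulates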
diff --git a/models/models.py b/models/models.py
index 311273d..8be1263 100755
--- a/models/models.py
+++ b/models/models.py
@@ -1,6 +1,7 @@
 from .resnet import get_resnet_model
 from .network_slimming_resnet import get_network_slimming_model
 from .mobilenet import get_mobilenet
+from .ds_resnet import get_ds_resnet_model
 
 def get_model(model, method, num_classes, insize):
     """Returns the requested model, ready for training/pruning with the specified method.
@@ -12,6 +13,8 @@ def get_model(model, method, num_classes, insize):
     if model in ['wrn', 'r50', 'r101','r110', 'r152', 'r32', 'r18', 'r56', 'r20']:
         net = get_resnet_model(model, method, num_classes, insize)
+    elif model in ["dsr18"]:
+        net = get_ds_resnet_model(model, method, num_classes, insize)
     elif model in ['r164']:
         net = get_network_slimming_model(method, num_classes)
     elif model in ['mobilenetv2']:

diff --git a/models/resnet.py b/models/resnet.py
index 06d1c3d..b870a4e 100755
--- a/models/resnet.py
+++ b/models/resnet.py
@@ -113,8 +113,9 @@ def __init__(self, block, layers, width=1, num_classes=1000, insize=32):
                     self.prev_module[b.bn2] = b.bn1
                     if b.downsample is not None:
                         self.prev_module[b.downsample[1]] = prev
-                prev = b.bn2
-
+                        prev = (b.downsample[1], b.bn2)
+                    else:
+                        prev = (prev, b.bn2)
 
     def _make_layer(self, block, planes, blocks, stride=1):
         downsample = None
@@ -193,32 +194,6 @@ def __calc_params(self, a):
             do_downsample = True
         return ans + a[-1]*self.num_classes + 2*np.sum(a)
 
-    def __calc_flops(self, a):
-        ans=a[0]*a[1]*9*self.insize**2 + a[1]*self.insize**2
-        current_loc = 2
-        current_max = a[1]
-        downsample_n = a[2]
-        size = self.insize*2
-        do_downsample = True if self.width>1 else False
-        for l in self.layers_size:
-            for i in range(l):
-                if do_downsample:
-                    downsample_n = a[current_loc]
-                    size = size//2
-                    ans+=(current_max+1)*a[current_loc]*size**2
-                    current_loc+=1
-
-                ans+=current_max*a[current_loc]*9*size**2 + a[current_loc]*size**2
-                ans+=a[current_loc]*a[current_loc+1]*9*size**2 + a[current_loc+1]*size**2
-                if do_downsample:
-                    current_max = max(downsample_n, a[current_loc+1])
-                else:
-                    current_max = max(current_max, a[current_loc+1])
-                do_downsample = False
-                current_loc+=2
-            do_downsample = True
-        return 2*ans + 2*(current_max-1)*100
-
     def params(self):
         a = [3]
         b = [3]
@@ -226,22 +201,16 @@ def __calc_params(self, a):
             a.append(int(i.pruned_zeta.sum()))
             b.append(len(i.pruned_zeta))
         return self.__calc_params(a)/self.__calc_params(b)
-
-    def flops(self):
-        a = [3]
-        b = [3]
-        for i in self.prunable_modules:
-            a.append(int(i.pruned_zeta.sum()))
-            b.append(len(i.pruned_zeta))
-        return self.__calc_flops(a)/self.__calc_flops(b)
-
 
 class ResNet(BaseModel):
     def __init__(self, block, layers, width=1, num_classes=1000, produce_vectors=False, init_weights=True, insize=32):
         super(ResNet, self).__init__()
+        self.insize = insize
+        self.width = width
         self.produce_vectors = produce_vectors
         self.block_type = block.__class__.__name__
         self.inplanes = 64
+        self.layers_size = layers
         if insize<128:
             self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
         else:
@@ ... @@ def __init__(self, block, layers, width=1, num_classes=1000, produce_vectors=Fal
         self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)  # Global Avg Pool
         self.fc = nn.Linear(512 * block.expansion * width, num_classes)
+        self.prev_module[self.bn1] = None
         self.init_weights()
-
-        for l in [self.layer1, self.layer2, self.layer3, self.layer4]:
-            for b in l.children():
-                downs = next(b.downsample.children()) if b.downsample is not None else None
+        if self.block_type =="BasicBlock":
+            prev = self.bn1
+            for l_block in [self.layer1 , self.layer2 , self.layer3 , self.layer4]:
+                for b in l_block:
+                    self.prev_module[b.bn1] = prev
+                    self.prev_module[b.bn2] = b.bn1
+                    if b.downsample is not None:
+                        self.prev_module[b.downsample[1]] = prev
+                        prev = (b.downsample[1], b.bn2)
+                    else:
+                        prev = (prev, b.bn2)
+
+        else:
+            prev = self.bn1
+            for l_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
+                for b in l_block:
+                    self.prev_module[b.bn1] = prev
+                    self.prev_module[b.bn2] = b.bn1
+                    self.prev_module[b.bn3] = b.bn2
+                    if b.downsample is not None:
+                        self.prev_module[b.downsample[1]] = prev
+                        prev = (b.downsample[1], b.bn3)
+                    else:
+                        prev = (prev, b.bn3)
 
     def _make_layer(self, block, planes, blocks, stride=1):
         downsample = None
@@ -334,7 +324,6 @@ def remove_orphans(self):
                     m2.pruned_zeta.data.copy_(torch.zeros_like(m2.pruned_zeta))
         return num_removed
 
-
 def make_wide_resnet(num_classes, insize):
     model = ResNetCifar(BasicBlock, [4, 4, 4], width=12, num_classes=num_classes, insize=insize)
     return model
@@ -400,4 +389,4 @@ def get_resnet_model(model, method, num_classes, insize):
     elif model == 'r152':
         net = make_resnet152(num_classes, insize)
     net.prunable_modules = ModuleInjection.prunable_modules
-    return net
+    return net
\ No newline at end of file
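[Editorial note] On the prev = (b.downsample[1], b.bn2) / prev = (prev, b.bn2) bookkeeping PATCH 01 introduces above: a channel of a residual sum stays alive if any branch writing into it keeps that channel, and since torch.maximum is symmetric, the (torch.maximum(p, q) + torch.maximum(q, p))/2 expression in cal_max is just an elementwise maximum. A sketch with plain 0/1 tensors standing in for the BatchNorm gate masks (the union helper and the mask values are illustrative, not from the patch):

    import torch

    bn_a = torch.tensor([1., 0., 1., 0.])  # stand-ins for n_remaining(..., do_sum=False)
    bn_b = torch.tensor([0., 1., 1., 0.])
    bn_c = torch.tensor([0., 0., 1., 1.])

    prev = ((bn_a, bn_b), bn_c)            # nested exactly like prev = (prev, b.bn2)

    def union(node):
        left, right = node
        left = union(left) if isinstance(left, tuple) else left
        return torch.maximum(left, right)  # alive if any branch keeps the channel

    print(union(prev))        # tensor([1., 1., 1., 1.])
    print(union(prev).sum())  # plays the role of prev_remaining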
From a57e72d7463798ef003dc9d86fcca3c589cb6fe9 Mon Sep 17 00:00:00 2001
From: Udbhav Bamba
Date: Wed, 4 Aug 2021 16:54:45 +0530
Subject: [PATCH 02/10] fix block block_type

---
 models/ds_resnet.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/models/ds_resnet.py b/models/ds_resnet.py
index f90885d..557457e 100755
--- a/models/ds_resnet.py
+++ b/models/ds_resnet.py
@@ -50,7 +50,6 @@ def forward(self, x):
 
 # class Bottleneck(nn.Module):
 #     expansion = 4
-#     name = "Bottleneck"
 
 #     def __init__(self, inplanes, planes, stride=1, downsample=None):
 #         super(Bottleneck, self).__init__()
@@ -117,7 +116,7 @@ def __init__(self, block, layers, width=1, num_classes=1000, produce_vectors=Fal
 
         self.init_weights()
 
-        if block.name =="BasicBlock":
+        if self.block_type =="BasicBlock":
             prev = self.bn1
             for l_block in [self.layer1 , self.layer2 , self.layer3 , self.layer4]:
                 for b in l_block:
@@ -256,4 +255,4 @@ def get_ds_resnet_model(model, method, num_classes, insize):
     # elif model == 'dsr152':
     #     net = make_resnet152(num_classes, insize)
     net.prunable_modules = ModuleInjection.prunable_modules
-    return net
\ No newline at end of file
+    return net

From 9ca109eb57bd6bdf19dbd20b81f752a608e02c8e Mon Sep 17 00:00:00 2001
From: gktejus
Date: Wed, 4 Aug 2021 18:09:22 +0530
Subject: [PATCH 03/10] fixed block name

---
 models/ds_resnet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/models/ds_resnet.py b/models/ds_resnet.py
index 557457e..ddae8e2 100755
--- a/models/ds_resnet.py
+++ b/models/ds_resnet.py
@@ -95,7 +95,7 @@ def __init__(self, block, layers, width=1, num_classes=1000, produce_vectors=Fal
         self.insize=insize
         self.width = width
         self.produce_vectors = produce_vectors
-        self.block_type = block.__class__.__name__
+        self.block_type = block.__name__
         self.inplanes = 64
         self.layers_size = layers
         if insize<128:

From 269651d3a4972adc03e9afb7c34cead148186bcb Mon Sep 17 00:00:00 2001
From: Udbhav Bamba
Date: Sat, 7 Aug 2021 18:16:38 +0530
Subject: [PATCH 04/10] correct prune function for flops

---
 models/base_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/models/base_model.py b/models/base_model.py
index 4a4e695..f2c5f9f 100755
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -182,7 +182,7 @@ def prune(self, Vc, budget_type = 'channel_ratio', finetuning=False, threshold=N
         for l_block in self.prunable_modules:
             l_block.prune(threshold)
         self.remove_orphans()
-        if self.get_remaining(steepness=20., budget_type='flops_ratio')<Vc:
[... patch text lost in extraction: the replacement line of this hunk and the opening header of the next commit ...]
Date: Sat, 7 Aug 2021 18:33:48 +0530
Subject: [PATCH 05/10] restore old

---
 models/base_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/models/base_model.py b/models/base_model.py
index f2c5f9f..364284a 100755
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -182,7 +182,7 @@ def prune(self, Vc, budget_type = 'channel_ratio', finetuning=False, threshold=N
         for l_block in self.prunable_modules:
             l_block.prune(threshold)
         self.remove_orphans()
-        if self.get_remaining(steepness=20., budget_type='flops_ratio')
[... patch text lost in extraction: the rest of this hunk and the opening header of the next commit ...]
Date: Sat, 7 Aug 2021 18:50:38 +0530
Subject: [PATCH 06/10] added label smoothing

---
 finetuning.py | 22 ++++++++++++++++++----
 utils.py      | 16 ++++++++++++++++
 2 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/finetuning.py b/finetuning.py
index f93ca67..83f0b22 100644
--- a/finetuning.py
+++ b/finetuning.py
@@ -30,6 +30,8 @@
 ap.add_argument('--test_only', '-t', type=bool, default=False, help='test the best model')
 ap.add_argument('--workers', default=0, type=int, help='number of workers')
 ap.add_argument('--cuda_id', '-id', type=str, default='0', help='gpu number')
+ap.add_argument('--label_smoothing', '-ls', type=float, default=0, help='set label smoothing')
+
 args = ap.parse_args()
 
 valid_size=args.valid_size
@@ -58,10 +60,22 @@
 state = torch.load(model_path)['state_dict']
 model.load_state_dict(state, strict=False)
 CE = nn.CrossEntropyLoss()
-def criterion(model, y_pred, y_true):
+def criterion_test(model, y_pred, y_true):
     ce_loss = CE(y_pred, y_true)
     return ce_loss
 
+if args.label_smoothing>0:
+    CE_smooth = CrossEntropyLabelSmooth(data_object.num_classes , args.label_smoothing)
+    def criterion_train(model, y_pred, y_true):
+        ce_loss = CE_smooth(y_pred, y_true)
+        return ce_loss
+else:
+    def criterion_train(model, y_pred, y_true):
+        ce_loss = CE(y_pred, y_true)
+        return ce_loss
+
+
 optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.decay)
 device = torch.device(f"cuda:{str(args.cuda_id)}")
 model.to(device)
@@ -123,8 +137,8 @@ def test(model, loss_fn, optimizer, phase):
     for epoch in range(num_epochs):
         adjust_learning_rate(optimizer, epoch, args)
         print('Starting epoch %d / %d' % (epoch + 1, num_epochs))
-        train_loss = train(model, criterion, optimizer)
-        accuracy, valid_loss = test(model, criterion, optimizer, "val")
+        train_loss = train(model, criterion_train, optimizer)
+        accuracy, valid_loss = test(model, criterion_test, optimizer, "val")
         remaining = model.get_remaining(20.,args.budget_type).item()
 
         if accuracy>best_accuracy:
@@ -146,5 +160,5 @@ def test(model, loss_fn, optimizer, phase):
 state = torch.load(f"checkpoints/{args.name}_{args.dataset}_finetuned.pth")
 model.load_state_dict(state['state_dict'],strict=True)
-acc, v_loss = test(model, criterion, optimizer, "test")
+acc, v_loss = test(model, criterion_test, optimizer, "test")
 print(f"Test Accuracy: {acc} | Valid Accuracy: {state['acc']}")
\ No newline at end of file
diff --git a/utils.py b/utils.py
index 1e2583f..8fa5d89 100755
--- a/utils.py
+++ b/utils.py
@@ -14,6 +14,22 @@ def seed_everything(seed):
     torch.cuda.manual_seed(seed)
     torch.backends.cudnn.deterministic = True
 
+
+class CrossEntropyLabelSmooth(nn.Module):
+
+    def __init__(self, num_classes, epsilon):
+        super(CrossEntropyLabelSmooth, self).__init__()
+        self.num_classes = num_classes
+        self.epsilon = epsilon
+        self.logsoftmax = nn.LogSoftmax(dim=1)
+
+    def forward(self, inputs, targets):
+        log_probs = self.logsoftmax(inputs)
+        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
+        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
+        loss = (-targets * log_probs).mean(0).sum()
+        return loss
+
 def get_mask_dict(own_state, state_dict):
     for name, param in state_dict.items():
         if name not in own_state:
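[Editorial note] CrossEntropyLabelSmooth above mixes the one-hot target with a uniform distribution: target = (1 - epsilon) * one_hot + epsilon / num_classes. A quick smoke test, assuming the class is importable from utils as committed (the shapes and seed are arbitrary): at epsilon = 0 it must agree with nn.CrossEntropyLoss exactly, since mean(0).sum() over the batch equals the default mean reduction.

    import torch
    import torch.nn as nn
    from utils import CrossEntropyLabelSmooth

    torch.manual_seed(0)
    logits = torch.randn(8, 10)
    labels = torch.randint(0, 10, (8,))

    plain = nn.CrossEntropyLoss()(logits, labels)
    smooth0 = CrossEntropyLabelSmooth(num_classes=10, epsilon=0.0)(logits, labels)
    smooth1 = CrossEntropyLabelSmooth(num_classes=10, epsilon=0.1)(logits, labels)

    print(torch.allclose(plain, smooth0))  # True: eps=0 leaves the one-hot target untouched
    print(plain.item(), smooth1.item())    # eps=0.1 shifts the loss value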
From 21cc4b8a646e85650ca7cfe13ee78e7f27d13fd0 Mon Sep 17 00:00:00 2001
From: gktejus
Date: Wed, 11 Aug 2021 19:49:04 +0530
Subject: [PATCH 07/10] fixed naming

---
 finetuning.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/finetuning.py b/finetuning.py
index 83f0b22..ace4924 100644
--- a/finetuning.py
+++ b/finetuning.py
@@ -133,6 +133,9 @@ def test(model, loss_fn, optimizer, phase):
 train_losses = []
 valid_losses = []
 valid_accuracy = []
+name = f'{args.name}_{args.dataset}_finetuned'
+if args.label_smoothing>0:
+    name += '_label_smoothing'
 if args.test_only == False:
     for epoch in range(num_epochs):
         adjust_learning_rate(optimizer, epoch, args)
@@ -149,16 +152,16 @@ def test(model, loss_fn, optimizer, phase):
                 "state_dict" : model.state_dict(),
                 "acc" : best_accuracy,
                 "rem" : remaining,
-                }, f"checkpoints/{args.name}_{args.dataset}_finetuned.pth")
+                }, f"checkpoints/{name}.pth")
 
         train_losses.append(train_loss)
         valid_losses.append(valid_loss)
         valid_accuracy.append(accuracy)
         df_data=np.array([train_losses, valid_losses, valid_accuracy]).T
         df = pd.DataFrame(df_data,columns = ['train_losses','valid_losses','valid_accuracy'])
-        df.to_csv(f"logs/{args.name}_{args.dataset}_finetuned.csv")
+        df.to_csv(f"logs/{name}.csv")
 
-state = torch.load(f"checkpoints/{args.name}_{args.dataset}_finetuned.pth")
+state = torch.load(f"checkpoints/{name}.pth")
 model.load_state_dict(state['state_dict'],strict=True)
 acc, v_loss = test(model, criterion_test, optimizer, "test")
 print(f"Test Accuracy: {acc} | Valid Accuracy: {state['acc']}")
\ No newline at end of file

From 34599cc3529f35209d9790a65adcc14f37e0fe44 Mon Sep 17 00:00:00 2001
From: Udbhav Bamba
Date: Wed, 8 Sep 2021 16:26:58 +0530
Subject: [PATCH 08/10] change ds arch

---
 models/ds_resnet.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/models/ds_resnet.py b/models/ds_resnet.py
index ddae8e2..56dd642 100755
--- a/models/ds_resnet.py
+++ b/models/ds_resnet.py
@@ -105,13 +105,14 @@ def __init__(self, block, layers, width=1, num_classes=1000, produce_vectors=Fal
         self.bn1 = nn.BatchNorm2d(64)
         self.conv1, self.bn1 = ModuleInjection.make_prunable(self.conv1, self.bn1)
         self.prev_module[self.bn1]=None
-        self.activ = nn.ReLU(inplace=True)
-        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        # self.activ = nn.ReLU(inplace=True)
+        # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         self.layer1 = self._make_layer(block, 64 * width, layers[0])
         self.layer2 = self._make_layer(block, 128 * width, layers[1], stride=2)
         self.layer3 = self._make_layer(block, 256 * width, layers[2], stride=2)
         self.layer4 = self._make_layer(block, 512 * width, layers[3], stride=2)
-        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1) # Global Avg Pool
+        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1) # Global Avg Pool4
+        self.bn2 = nn.BatchNorm1d(512 * block.expansion * width)
         self.fc = nn.Linear(512 * block.expansion * width, num_classes)
 
         self.init_weights()
@@ -128,7 +129,7 @@ def __init__(self, block, layers, width=1, num_classes=1000, produce_vectors=Fal
                         prev = (prev ,b.bn1)
                     self.prev_module[b.bn2] = prev
                     prev = (prev , b.bn2)
-
+
         else: raise ValueError("Only BasicBlock supported")
             # prev = self.bn1
@@ -164,8 +165,8 @@ def _make_layer(self, block, planes, blocks, stride=1):
     def forward(self, x):
         x = self.conv1(x)
         x = self.bn1(x)
-        x = self.activ(x)
-        x = self.maxpool(x)
+        # x = self.activ(x)
+        # x = self.maxpool(x)
 
         x = self.layer1(x)
         x = self.layer2(x)
@@ -174,7 +175,8 @@ def forward(self, x):
         x = self.avgpool(x)
         feature_vectors = x.view(x.size(0), -1)
-        x = self.fc(feature_vectors)
+        x = self.bn2(feature_vectors)
+        x = self.fc(x)
 
         if self.produce_vectors:
             return x, feature_vectors

From 045495bec6caa5026c2a5f4a49cd0f79843d83ee Mon Sep 17 00:00:00 2001
From: Udbhav Bamba
Date: Wed, 23 Mar 2022 21:16:36 +0530
Subject: [PATCH 09/10] Update resnet.py

---
 models/resnet.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/models/resnet.py b/models/resnet.py
index b870a4e..27e8571 100755
--- a/models/resnet.py
+++ b/models/resnet.py
@@ -12,7 +12,7 @@ def conv3x3(in_planes, out_planes, stride=1):
 
 class BasicBlock(nn.Module):
     expansion = 1
-
+    block_type = "BasicBlock"
     def __init__(self, inplanes, planes, stride=1, downsample=None):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(inplanes, planes, stride)
@@ -46,6 +46,7 @@ def forward(self, x):
 
 class Bottleneck(nn.Module):
     expansion = 4
+    block_type = "Bottleneck"
 
     def __init__(self, inplanes, planes, stride=1, downsample=None):
         super(Bottleneck, self).__init__()
@@ -389,4 +390,4 @@ def get_resnet_model(model, method, num_classes, insize):
     elif model == 'r152':
         net = make_resnet152(num_classes, insize)
     net.prunable_modules = ModuleInjection.prunable_modules
-    return net
\ No newline at end of file
+    return net

From 7a053266c88a71c4f1c82f6fe3c501ff1606df9a Mon Sep 17 00:00:00 2001
From: Udbhav Bamba
Date: Wed, 23 Mar 2022 21:18:19 +0530
Subject: [PATCH 10/10] Update resnet.py

---
 models/resnet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/models/resnet.py b/models/resnet.py
index 27e8571..677f07c 100755
--- a/models/resnet.py
+++ b/models/resnet.py
@@ -209,7 +209,7 @@ def __init__(self, block, layers, width=1, num_classes=1000, produce_vectors=Fal
         self.insize = insize
         self.width = width
         self.produce_vectors = produce_vectors
-        self.block_type = block.__class__.__name__
+        self.block_type = block.block_type
         self.inplanes = 64
         self.layers_size = layers
         if insize<128:
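[Editorial note] PATCH 03, PATCH 09 and PATCH 10 all chase the same pitfall: block is passed into these constructors as a class, not an instance, so block.__class__.__name__ names the metaclass and every branch keyed on "BasicBlock" silently took the wrong path. A minimal illustration (class body trimmed to the relevant attributes):

    class BasicBlock:
        expansion = 1
        block_type = "BasicBlock"  # the class attribute PATCH 09 adds

    print(BasicBlock.__class__.__name__)  # 'type', never 'BasicBlock'
    print(BasicBlock.__name__)            # 'BasicBlock' (the PATCH 03 fix)
    print(BasicBlock.block_type)          # 'BasicBlock' (the PATCH 09/10 fix)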