-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathParams_GFLOP_Counting.py
27 lines (19 loc) · 1.19 KB
/
Params_GFLOP_Counting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import os
import torch
from ptflops import get_model_complexity_info
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.utilities.get_network_from_plans1 import get_network_from_plans
from batchgenerators.utilities.file_and_folder_operations import load_json
# Count the parameters and multiply-accumulate operations (MACs) of an
# nnU-Net network built from a plans file, using ptflops.
#
# Note: ptflops reports MACs (shown with a "Mac" unit, e.g. "GMac"), not
# FLOPs. Under the common convention, GFLOPs ~= 2 * GMac.

# The paths can be overridden through environment variables, so the script
# runs without editing; the placeholders below are only the fallback.
plan_path = os.environ.get('NNUNET_PLANS_JSON', 'path_to_network_plan_file')
dataset_json_path = os.environ.get('NNUNET_DATASET_JSON', 'path_to_dataset_json_file')
configuration = '3d_fullres'
num_input_channels = 1
# Per-sample input shape (channels, D, H, W). ptflops prepends a batch
# dimension of 1 itself, so the network receives (1, 1, 32, 320, 320).
# The channel entry is tied to num_input_channels so the two cannot drift apart.
input_res = (num_input_channels, 32, 320, 320)

plans = PlansManager(plan_path)
configuration_manager = plans.get_configuration(configuration)
dataset_json = load_json(dataset_json_path)
model = get_network_from_plans(plans, dataset_json, configuration_manager,
                               num_input_channels=num_input_channels)

# Counting needs only a forward pass. Disabling autograd avoids keeping
# activations for a backward pass, which is costly for a large 3D input.
with torch.no_grad():
    macs, params = get_model_complexity_info(model, input_res, as_strings=True,
                                             print_per_layer_stat=False, verbose=True)

# Recent ptflops versions catch exceptions raised during the forward pass,
# print the traceback, and return (None, None). Without this check, the
# format calls below would fail with an unrelated TypeError.
if macs is None or params is None:
    raise RuntimeError('ptflops failed to estimate model complexity; '
                       'see the traceback printed above.')

print('{:<30} {:<8}'.format('Computational complexity: ', macs))
print('{:<30} {:<8}'.format('Number of parameters: ', params))