diff --git a/get_miou.py b/get_miou.py index 05dff4b..d9ed0f0 100644 --- a/get_miou.py +++ b/get_miou.py @@ -4,7 +4,7 @@ from tqdm import tqdm from deeplab import DeeplabV3 -from utils.utils_metrics import compute_mIoU +from utils.utils_metrics import compute_mIoU, show_results ''' 进行指标评估需要注意以下几点: @@ -34,9 +34,10 @@ #-------------------------------------------------------# VOCdevkit_path = 'VOCdevkit' - image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/val.txt"),'r').read().splitlines() - gt_dir = os.path.join(VOCdevkit_path, "VOC2007/SegmentationClass/") - pred_dir = "miou_out" + image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/val.txt"),'r').read().splitlines() + gt_dir = os.path.join(VOCdevkit_path, "VOC2007/SegmentationClass/") + miou_out_path = "miou_out" + pred_dir = os.path.join(miou_out_path, 'detection-results') if miou_mode == 0 or miou_mode == 1: if not os.path.exists(pred_dir): @@ -56,5 +57,6 @@ if miou_mode == 0 or miou_mode == 2: print("Get miou.") - compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes) # 执行计算mIoU的函数 + hist, IoUs, PA_Recall, Precision = compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes) # 执行计算mIoU的函数 print("Get miou done.") + show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes) \ No newline at end of file diff --git a/utils/utils_metrics.py b/utils/utils_metrics.py index eff6254..3ef5aa6 100644 --- a/utils/utils_metrics.py +++ b/utils/utils_metrics.py @@ -1,5 +1,8 @@ +import csv +import os from os.path import join +import matplotlib.pyplot as plt import numpy as np import torch import torch.nn.functional as F @@ -42,9 +45,15 @@ def fast_hist(a, b, n): def per_class_iu(hist): return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1) -def per_class_PA(hist): +def per_class_PA_Recall(hist): return np.diag(hist) / np.maximum(hist.sum(1), 1) +def per_class_Precision(hist): + return np.diag(hist) / np.maximum(hist.sum(0), 1) + +def per_Accuracy(hist): + return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1) + def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes): print('Num classes', num_classes) #-----------------------------------------# @@ -86,22 +95,87 @@ def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes): hist += fast_hist(label.flatten(), pred.flatten(),num_classes) # 每计算10张就输出一下目前已计算的图片中所有类别平均的mIoU值 if ind > 0 and ind % 10 == 0: - print('{:d} / {:d}: mIou-{:0.2f}; mPA-{:0.2f}'.format(ind, len(gt_imgs), - 100 * np.nanmean(per_class_iu(hist)), - 100 * np.nanmean(per_class_PA(hist)))) + print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format( + ind, + len(gt_imgs), + 100 * np.nanmean(per_class_iu(hist)), + 100 * np.nanmean(per_class_PA_Recall(hist)), + 100 * per_Accuracy(hist) + ) + ) #------------------------------------------------# # 计算所有验证集图片的逐类别mIoU值 #------------------------------------------------# - mIoUs = per_class_iu(hist) - mPA = per_class_PA(hist) + IoUs = per_class_iu(hist) + PA_Recall = per_class_PA_Recall(hist) + Precision = per_class_Precision(hist) #------------------------------------------------# # 逐类别输出一下mIoU值 #------------------------------------------------# for ind_class in range(num_classes): - print('===>' + name_classes[ind_class] + ':\tmIou-' + str(round(mIoUs[ind_class] * 100, 2)) + '; mPA-' + str(round(mPA[ind_class] * 100, 2))) + print('===>' + name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2)) \ + + '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 2))+ '; Precision-' + str(round(Precision[ind_class] * 100, 2))) #-----------------------------------------------------------------# # 在所有验证集图像上求所有类别平均的mIoU值,计算时忽略NaN值 #-----------------------------------------------------------------# - print('===> mIoU: ' + str(round(np.nanmean(mIoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(mPA) * 100, 2))) - return mIoUs + print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2))) + return np.array(hist, np.int), IoUs, PA_Recall, Precision + +def adjust_axes(r, t, fig, axes): + bb = t.get_window_extent(renderer=r) + text_width_inches = bb.width / fig.dpi + current_fig_width = fig.get_figwidth() + new_fig_width = current_fig_width + text_width_inches + propotion = new_fig_width / current_fig_width + x_lim = axes.get_xlim() + axes.set_xlim([x_lim[0], x_lim[1] * propotion]) + +def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size = 12, plt_show = True): + fig = plt.gcf() + axes = plt.gca() + plt.barh(range(len(values)), values, color='royalblue') + plt.title(plot_title, fontsize=tick_font_size + 2) + plt.xlabel(x_label, fontsize=tick_font_size) + plt.yticks(range(len(values)), name_classes, fontsize=tick_font_size) + r = fig.canvas.get_renderer() + for i, val in enumerate(values): + str_val = " " + str(val) + if val < 1.0: + str_val = " {0:.2f}".format(val) + t = plt.text(val, i, str_val, color='royalblue', va='center', fontweight='bold') + if i == (len(values)-1): + adjust_axes(r, t, fig, axes) + + fig.tight_layout() + fig.savefig(output_path) + if plt_show: + plt.show() + plt.close() + +def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size = 12): + draw_plot_func(IoUs, name_classes, "mIoU = {0:.2f}%".format(np.nanmean(IoUs)*100), "Intersection over Union", \ + os.path.join(miou_out_path, "mIoU.png"), tick_font_size = tick_font_size, plt_show = True) + print("Save mIoU out to " + os.path.join(miou_out_path, "mIoU.png")) + + draw_plot_func(PA_Recall, name_classes, "mPA = {0:.2f}%".format(np.nanmean(PA_Recall)*100), "Pixel Accuracy", \ + os.path.join(miou_out_path, "mPA.png"), tick_font_size = tick_font_size, plt_show = False) + print("Save mPA out to " + os.path.join(miou_out_path, "mPA.png")) + + draw_plot_func(PA_Recall, name_classes, "mRecall = {0:.2f}%".format(np.nanmean(PA_Recall)*100), "Recall", \ + os.path.join(miou_out_path, "Recall.png"), tick_font_size = tick_font_size, plt_show = False) + print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png")) + + draw_plot_func(Precision, name_classes, "mPrecision = {0:.2f}%".format(np.nanmean(Precision)*100), "Precision", \ + os.path.join(miou_out_path, "Precision.png"), tick_font_size = tick_font_size, plt_show = False) + print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png")) + + with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f: + writer = csv.writer(f) + writer_list = [] + writer_list.append([' '] + [str(c) for c in name_classes]) + for i in range(len(hist)): + writer_list.append([name_classes[i]] + [str(x) for x in hist[i]]) + writer.writerows(writer_list) + print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv")) + \ No newline at end of file