Started working on Bagging
muellerdo committed May 17, 2022
1 parent 6cb54f7 commit 1917482
Showing 3 changed files with 131 additions and 8 deletions.
5 changes: 3 additions & 2 deletions aucmedi/ensemble/__init__.py
@@ -29,8 +29,8 @@
| Technique | Description |
| ------------------------------------------ | ---------------------------------------------------------------------------------------------------- |
| [Augmenting][aucmedi.ensemble.augmenting] | Inference Augmenting (test-time augmentation) function for augmenting unknown images for prediction. |
| Bagging | Coming soon. |
| [Bagging][aucmedi.ensemble.bagging] | Cross-Validation based Bagging for equal models trained with different sampling. |
| Stacking | Coming soon. |
???+ info
@@ -45,3 +45,4 @@
# Library imports #
#-----------------------------------------------------#
from aucmedi.ensemble.augmenting import predict_augmenting
from aucmedi.ensemble.bagging import Bagging
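
A minimal usage sketch of the ensemble interface this diff extends, for orientation. The Bagging constructor call matches the test added in this commit; the sample list, image path and labels are placeholders, and the commented-out train()/predict() calls are assumptions about the eventual API, which is still a stub at this point.

from aucmedi import DataGenerator, Neural_Network
from aucmedi.ensemble import predict_augmenting, Bagging

# Model and DataGenerator setup (sample list, image path and labels are placeholders)
model = Neural_Network(n_labels=4, channels=3, architecture="2D.Vanilla")
datagen = DataGenerator(samples, "images/", labels=labels_ohe, batch_size=10,
                        resize=None, standardize_mode="tf")

# Inference Augmenting (already available)
preds_tta = predict_augmenting(model, datagen, n_cycles=5, aggregate="mean")

# Cross-validation based Bagging (work in progress in this commit)
el = Bagging(model=model, k_fold=5)
# el.train(datagen)            # not yet implemented
# preds = el.predict(datagen)  # not yet implemented
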
95 changes: 95 additions & 0 deletions aucmedi/ensemble/bagging.py
@@ -0,0 +1,95 @@
#==============================================================================#
# Author: Dominik Müller #
# Copyright: 2022 IT-Infrastructure for Translational Medical Research, #
# University of Augsburg #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <http://www.gnu.org/licenses/>. #
#==============================================================================#
#-----------------------------------------------------#
# Library imports #
#-----------------------------------------------------#
# External libraries
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.optimizers import Adam
# Internal libraries
from aucmedi import DataGenerator
from aucmedi.ensemble.aggregate import aggregate_dict

#-----------------------------------------------------#
# Ensemble Learning: Bagging #
#-----------------------------------------------------#
class Bagging:
    """ A Bagging class providing functionality for cross-validation based ensemble learning.
    """
    def __init__(self, model, k_fold=3):
        """ Initialization function for creating a Bagging object.

        Args:
            model (Neural_Network):    Neural network model template which is used for all folds.
            k_fold (int):              Number of folds (models) for the cross-validation based Bagging.
        """
        # Cache class variables
        self.model_template = model
        self.k_fold = k_fold
        self.model_list = []

        # Create k models based on template
        # (placeholder: the template reference is reused until
        #  per-fold model creation is implemented)
        for i in range(k_fold):
            self.model_list.append(model)

    def train(self, training_generator, epochs=20,
              iterations=None, callbacks=[], class_weights=None,
              transfer_learning=False):
        """ Training function for the Bagging ensemble (not yet implemented). """
        # Apply cross-validation
        pass
        # for loop over folds
        #   model.training
        # return combined history object


    def predict(self, prediction_generator, aggregate="mean"):
        """ Prediction function for the Bagging ensemble (not yet implemented). """
        pass
        # for loop over fold models
        #   model.predict
        # aggregate predictions
        # return


    # Dump model to file
    def dump(self, file_path):
        """ Store the model template to disk.

            Recommended to utilize the file format ".hdf5".

        Args:
            file_path (str):    Path to store the model on disk.
        """
        # Placeholder: stores only the template model until
        # per-fold model storage is implemented
        self.model_template.model.save(file_path)


    # Load model from file
    def load(self, file_path, custom_objects={}):
        """ Load a neural network model and its weights from a file.

            After loading, the model will be compiled.
            If loading a model in ".hdf5" format, it is not necessary to define any custom_objects.

        Args:
            file_path (str):        Input path, from which the model will be loaded.
            custom_objects (dict):  Dictionary of custom objects for compiling
                                    (e.g. non-TensorFlow based loss functions or architectures).
        """
        # Placeholder: loads the weights into the template model until
        # per-fold model loading is implemented
        self.model_template.model = load_model(file_path, custom_objects,
                                               compile=False)
        # Compile model
        self.model_template.model.compile(
            optimizer=Adam(learning_rate=self.model_template.learning_rate),
            loss=self.model_template.loss,
            metrics=self.model_template.metrics)
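
To make the comment outlines in train() and predict() concrete, here is a rough sketch of the intended cross-validation workflow. It is not the committed implementation: the helper name, the per-fold data generators and the Neural_Network.train()/predict() calls are assumptions, and a plain mean stands in for the aggregate_dict lookup imported above.

import numpy as np

def bagging_workflow_sketch(models, fold_generators, prediction_generator):
    # Training phase: fit one model per cross-validation fold
    # and collect the individual history objects
    histories = []
    for model, datagen in zip(models, fold_generators):
        histories.append(model.train(datagen, epochs=20))

    # Prediction phase: run every fold model on the same unseen data
    preds_ensemble = [model.predict(prediction_generator) for model in models]

    # Aggregate the per-model predictions (stand-in for aggregate_dict)
    preds_final = np.mean(np.array(preds_ensemble), axis=0)
    return preds_final, histories
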
39 changes: 33 additions & 6 deletions tests/test_ensemble.py
@@ -58,6 +58,11 @@ def setUpClass(self):
            path_sampleGRAY = os.path.join(self.tmp_data.name, index)
            np.save(path_sampleGRAY, img_gray)
            self.sampleList3D.append(index)
        # Create classification labels
        self.labels_ohe = np.zeros((3, 4), dtype=np.uint8)
        for i in range(0, 3):
            class_index = np.random.randint(0, 4)
            self.labels_ohe[i][class_index] = 1
        # Initialize model
        self.model2D = Neural_Network(n_labels=4, channels=3,
                                      architecture="2D.Vanilla",
@@ -75,17 +80,15 @@ def test_Augmenting_2D_functionality(self):
        # Test functionality with batch_size 10 and n_cycles = 1
        datagen = DataGenerator(self.sampleList2D, self.tmp_data.name,
                                batch_size=10, resize=None, data_aug=None,
                                grayscale=False, two_dim=False, subfunctions=[],
                                standardize_mode="tf")
                                grayscale=False, subfunctions=[], standardize_mode="tf")
        preds = predict_augmenting(self.model2D, datagen,
                                   n_cycles=1, aggregate="mean")
        self.assertTrue(np.array_equal(preds.shape, (3, 4)))

        # Test functionality with batch_size 10 and n_cycles = 5
        datagen = DataGenerator(self.sampleList2D, self.tmp_data.name,
                                batch_size=10, resize=None, data_aug=None,
                                grayscale=False, two_dim=False, subfunctions=[],
                                standardize_mode="tf")
                                grayscale=False, subfunctions=[], standardize_mode="tf")
        preds = predict_augmenting(self.model2D, datagen,
                                   n_cycles=5, aggregate="mean")
        self.assertTrue(np.array_equal(preds.shape, (3, 4)))
@@ -95,8 +98,7 @@ def test_Augmenting_2D_customAug(self):
        my_aug = Image_Augmentation()
        datagen = DataGenerator(self.sampleList2D, self.tmp_data.name,
                                batch_size=10, resize=None, data_aug=my_aug,
                                grayscale=False, two_dim=False, subfunctions=[],
                                standardize_mode="tf")
                                grayscale=False, subfunctions=[], standardize_mode="tf")
        preds = predict_augmenting(self.model2D, datagen,
                                   n_cycles=1, aggregate="mean")
        self.assertTrue(np.array_equal(preds.shape, (3, 4)))
@@ -130,3 +132,28 @@ def test_Augmenting_3D_customAug(self):
        preds = predict_augmenting(self.model3D, datagen,
                                   n_cycles=1, aggregate="mean")
        self.assertTrue(np.array_equal(preds.shape, (3, 4)))

    #-------------------------------------------------#
    #                     Bagging                      #
    #-------------------------------------------------#
    def test_Bagging_create(self):
        # Initialize Bagging object
        el = Bagging(model=self.model2D, k_fold=5)
        self.assertIsInstance(el, Bagging)

    # def test_Bagging_initialize(self):
    #     # Initialize training DataGenerator
    #     datagen = DataGenerator(self.sampleList2D, self.tmp_data.name,
    #                             labels=self.labels_ohe, batch_size=3, resize=None,
    #                             data_aug=None, grayscale=False, subfunctions=[],
    #                             standardize_mode="tf")
    #     # Bagging
    #     pass
    #     # Test functionality with batch_size 10 and n_cycles = 1
    #     datagen = DataGenerator(self.sampleList2D, self.tmp_data.name,
    #                             batch_size=10, resize=None, data_aug=None,
    #                             grayscale=False, two_dim=False, subfunctions=[],
    #                             standardize_mode="tf")
    #     preds = predict_augmenting(self.model2D, datagen,
    #                                n_cycles=1, aggregate="mean")
    #     self.assertTrue(np.array_equal(preds.shape, (3, 4)))
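
Once train() and predict() are implemented, the commented-out draft above could be completed along these lines; the method calls and their arguments are assumptions mirroring the existing Augmenting tests, not part of this commit.

    # Hypothetical follow-up test (assumes Bagging.train/predict are implemented)
    def test_Bagging_training_and_prediction(self):
        datagen = DataGenerator(self.sampleList2D, self.tmp_data.name,
                                labels=self.labels_ohe, batch_size=3, resize=None,
                                data_aug=None, grayscale=False, subfunctions=[],
                                standardize_mode="tf")
        el = Bagging(model=self.model2D, k_fold=3)
        el.train(datagen, epochs=1)                    # assumed signature
        preds = el.predict(datagen, aggregate="mean")  # assumed signature
        self.assertTrue(np.array_equal(preds.shape, (3, 4)))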
