wip: Update tasks, parts 1-3
adjavon committed Jul 25, 2024
1 parent 354005a commit 14d8e72
Showing 1 changed file with 77 additions and 72 deletions.
149 changes: 77 additions & 72 deletions solution.py
@@ -1,23 +1,28 @@
# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
# # Exercise 8: Knowledge Extraction from a Convolutional Neural Network
#
# The goal of this exercise is to learn how to probe what a pre-trained classifier has learned about the data it was trained on.
#
# We will be working with a simple example: a fun variation on the MNIST dataset that you will have seen in previous exercises in this course.
# Unlike regular MNIST, our dataset is classified not by number, but by color!
#
# We will:
# 1. Load a pre-trained classifier and try applying conventional attribution methods
# 2. Train a GAN to create counterfactual images - translating images from one class to another
# 3. Evaluate the GAN - see how good it is at fooling the classifier
# 4. Create attributions from the counterfactual, and learn the differences between the classes.
#
# If time permits, we will try to apply this all over again as a bonus exercise to a much more complex and more biologically relevant problem.
# ### Acknowledgments
#
# This notebook was written by Diane Adjavon, from a previous version written by Jan Funke and modified by Tri Nguyen, using code from Nils Eckstein.
#
# %% [markdown]
# <div class="alert alert-danger">
# Set your python kernel to <code>08_knowledge_extraction</code>
# </div>

# %% [markdown]
# <div class="alert alert-block alert-success"><h1>Start here (AKA checkpoint 0)</h1>
#
# </div>

# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
# # Part 1: Setup
#
# In this part of the notebook, we will load the same dataset as in the previous exercise.
@@ -34,9 +39,7 @@
# - There are four classes named after the matplotlib colormaps from which we sample the data: spring, summer, autumn, and winter.
# Let's plot a few examples.
# %%
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# Show some examples
fig, axs = plt.subplots(4, 4, figsize=(8, 8))
@@ -47,57 +50,45 @@
    ax.set_title(f"Class {y}")
    ax.axis("off")


# TODO move this to the "classification" exercise
# TODO modify so that we can show examples as well at different places in the range
def plot_color_gradients(cmap_list):
    gradient = np.linspace(0, 1, 256)
    gradient = np.vstack((gradient, gradient))

    # Create figure and adjust figure height to number of colormaps
    nrows = len(cmap_list)
    figh = 0.35 + 0.15 + (nrows + (nrows - 1) * 0.1) * 0.22
    fig, axs = plt.subplots(nrows=nrows + 1, figsize=(6.4, figh))
    fig.subplots_adjust(top=1 - 0.35 / figh, bottom=0.15 / figh, left=0.2, right=0.99)

    for ax, name in zip(axs, cmap_list):
        ax.imshow(gradient, aspect="auto", cmap=mpl.colormaps[name])
        ax.text(
            -0.01,
            0.5,
            name,
            va="center",
            ha="right",
            fontsize=10,
            transform=ax.transAxes,
        )

    # Turn off *all* ticks & spines, not just the ones with colormaps.
    for ax in axs:
        ax.set_axis_off()


plot_color_gradients(["spring", "summer", "winter", "autumn"])
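# %% [markdown]
# To build a little intuition for what "classified by color" means, here is a minimal sketch of how a grayscale digit could be tinted with a color sampled from one of these colormaps. This is only an illustration, not necessarily how the dataset is actually constructed, and `digit` below is just a random array standing in for a 28x28 grayscale image.
# %%
# Hypothetical illustration: tint a grayscale image with a color drawn from the "spring" colormap
digit = np.random.rand(28, 28)  # stand-in for a grayscale MNIST digit, values in [0, 1]
color = np.array(mpl.colormaps["spring"](np.random.rand())[:3])  # sample an RGB color from the colormap
colored_digit = digit[None, ...] * color[:, None, None]  # shape (3, 28, 28), one channel per RGB value
print(colored_digit.shape)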
# %% [markdown]
# In the Failure Modes exercise, we trained a classifier on this dataset. Let's load that classifier now!
# %% [markdown]
# <div class="alert alert-block alert-info"><h3>Task 1.1: Load the classifier</h3>
# We have written a slightly more general version of the `DenseModel` that you used in the previous exercise. Ours requires two inputs:
# - `input_shape`: the shape of the input images, as a tuple
# - `num_classes`: the number of classes in the dataset
#
# Create a dense model with the right inputs and load the weights from the checkpoint.
# </div>
# %%
import torch
from classifier.model import DenseModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# TODO Load the model with the correct input shape
model = DenseModel(input_shape=(...), num_classes=4)

# TODO modify this with the location of your classifier checkpoint
checkpoint = torch.load(...)
model.load_state_dict(checkpoint)
model = model.to(device)
# %% tags=["solution"]
import torch
from classifier.model import DenseModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Load the model
model = DenseModel(input_shape=(3, 28, 28), num_classes=4)
# Load the checkpoint
checkpoint = torch.load("extras/checkpoints/model.pth")
model.load_state_dict(checkpoint)
model = model.to(device)
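
# %% [markdown]
# As a quick sanity check (a sketch, not one of the exercise tasks), the loaded classifier should map a batch of 3x28x28 images to 4 class scores.
# %%
# Run the classifier on a random dummy batch and check the output shape
with torch.no_grad():
    dummy_input = torch.randn(1, 3, 28, 28, device=device)
    print(model(dummy_input).shape)  # expected: torch.Size([1, 4])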

# %% [markdown]
# # Part 2: Using Integrated Gradients to find what the classifier knows
#
# In this section we will make a first attempt at highlighting differences between the "real" and "fake" images that are most important to change the decision of the classifier.
#
@@ -159,6 +150,7 @@ def plot_color_gradients(cmap_list):

# %% editable=true slideshow={"slide_type": ""} tags=[]
from captum.attr import visualization as viz
import numpy as np


def visualize_attribution(attribution, original_image):
@@ -315,7 +307,6 @@ def visualize_color_attribution(attribution, original_image):
# <div class="alert alert-block alert-info"><h2>BONUS Task: Using different attributions.</h2>
#
#
# [`captum`](https://captum.ai/tutorials/Resnet_TorchVision_Interpret) has access to various different attribution algorithms.
#
# Replace `IntegratedGradients` with different attribution methods. Are they consistent with each other?
@@ -325,12 +316,10 @@ def visualize_color_attribution(attribution, original_image):
# <div class="alert alert-block alert-success"><h2>Checkpoint 2</h2>
# Let us know on the exercise chat when you've reached this point!
#
# At this point we have:
#
# - Loaded a classifier that classifies MNIST-like images by color, but we don't know how!
# - Tried applying Integrated Gradients to find out what the classifier is looking at - with little success.
# - Discovered the effect of changing the baseline on the output of integrated gradients.
#
# Coming up in the next section, we will learn how to create counterfactual images.
@@ -339,11 +328,11 @@ def visualize_color_attribution(attribution, original_image):
# </div>
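
# %% [markdown]
# Before moving on, here is a minimal, self-contained recap of the baseline idea from Part 2. It uses a random dummy image rather than the exercise data, purely to show how `captum`'s `IntegratedGradients` accepts an explicit `baselines` argument; swapping that baseline changes the reference point the attributions are measured against.
# %%
from captum.attr import IntegratedGradients

# Hypothetical example on a random image; in the exercise this is done on real samples
x_dummy = torch.randn(1, 3, 28, 28, device=device)
baseline = torch.zeros_like(x_dummy)  # an all-black image as the reference point
integrated_gradients = IntegratedGradients(model)
attribution = integrated_gradients.attribute(x_dummy, baselines=baseline, target=0)
print(attribution.shape)  # attributions have the same shape as the input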


# %% [markdown]
# # Part 3: Train a GAN to Translate Images
#
# To gain insight into how the trained network classifies images, we will use [Discriminative Attribution from Counterfactuals](https://arxiv.org/abs/2109.13412), a feature attribution with counterfactual explanations methodology.
# This method employs a StarGAN to translate images from one class to another to make counterfactual explanations.
#
# **What is a counterfactual?**
#
@@ -358,28 +347,20 @@ def visualize_color_attribution(attribution, original_image):
#
# **Counterfactual synapses**
#

# In this example, we will train a StarGAN network that is able to take any of our special MNIST images and change its class.
# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
# ### The model
#
# In the following, we create a [StarGAN model](https://arxiv.org/abs/1711.09020).
# It is a Generative Adversarial model that is trained to turn one class of images X into a different class of images Y.
#
# The model is made up of three networks:
# - The generator - this will be the bulk of the model, and will be responsible for transforming the images: we're going to use a `UNet`
# - The discriminator - this will be responsible for telling the difference between real and fake images: we're going to use a `DenseModel`
# - The style mapping - this will be responsible for encoding the style of the image: we're going to use a `DenseModel`
#
# Let's start by creating these!
# %%
from dlmbl_unet import UNet
from torch import nn
@@ -405,18 +386,42 @@ def forward(self, x, y):
        x = torch.cat([x, style], dim=1)
        return self.generator(x)

# %% [markdown]
# <div class="alert alert-block alert-info"><h3>Task 3.1: Create the models</h3>
#
# We are going to create the models for the generator, discriminator, and style mapping.
#
# Given the Generator structure above, fill in the missing parts for the unet and the style mapping.
# </div>
# %%
style_mapping = DenseModel(
    input_shape=..., num_classes=... # How big is the style space?
)
unet = UNet(depth=..., in_channels=..., out_channels=..., final_activation=nn.Sigmoid())

generator = Generator(unet, style_mapping=style_mapping)
# %% tags=["solution"]
# Here is an example of a working solution.
style_mapping = DenseModel(input_shape=(3, 28, 28), num_classes=3)
unet = UNet(depth=2, in_channels=6, out_channels=3, final_activation=nn.Sigmoid())
generator = Generator(unet, style_mapping=style_mapping)
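
# %% [markdown]
# A quick check of the pieces we just built (a sketch, not one of the exercise tasks): the style mapping should compress a (3, 28, 28) style image into a 3-dimensional style vector, which, concatenated with the 3 channels of the source image, accounts for the UNet's 6 input channels.
# %%
with torch.no_grad():
    style_vector = style_mapping(torch.randn(1, 3, 28, 28))
    print(style_vector.shape)  # expected: torch.Size([1, 3])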

# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
# <div class="alert alert-block alert-info"><h3>Task 3.2: Create the discriminator</h3>
#
# We want the discriminator to be like a classifier, so it is able to look at an image and tell not only whether it is real, but also which class it came from.
# The discriminator will take as input either a real image or a fake image.
# Fill in the following code to create a discriminator that can classify the images into the correct number of classes.
# </div>
# %% tags=[]
discriminator = DenseModel(input_shape=..., num_classes=...)
# %% tags=["solution"]
discriminator = DenseModel(input_shape=(3, 28, 28), num_classes=4)
# %% [markdown]
# Let's move all models onto the GPU
# %%
generator = generator.to(device)
discriminator = discriminator.to(device)


# %% [markdown] editable=true slideshow={"slide_type": ""} tags=[]
# ## Training a GAN
#
